diff options
Diffstat (limited to 'tensorflow/core')
532 files changed, 23159 insertions, 6262 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 47f16ac747..d575604a56 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -85,11 +85,12 @@ load( "tf_cc_tests", "tf_copts", "tf_cuda_library", + "tf_features_nomodules_if_android", "tf_gen_op_libs", "tf_generate_proto_text_sources", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_android", - "tf_features_nomodules_if_android", + "transitive_hdrs", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl") load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") @@ -120,16 +121,16 @@ load( "tf_additional_libdevice_srcs", "tf_additional_minimal_lib_srcs", "tf_additional_mpi_lib_defines", - "tf_additional_proto_hdrs", "tf_additional_proto_compiler_hdrs", + "tf_additional_proto_hdrs", "tf_additional_proto_srcs", "tf_additional_test_deps", "tf_additional_test_srcs", "tf_additional_verbs_lib_defines", "tf_jspb_proto_library", "tf_kernel_tests_linkstatic", - "tf_lib_proto_parsing_deps", "tf_lib_proto_compiler_deps", + "tf_lib_proto_parsing_deps", "tf_nano_proto_library", "tf_platform_hdrs", "tf_platform_srcs", @@ -143,6 +144,7 @@ load( ) load( "//tensorflow/core:platform/default/build_config_root.bzl", + "if_dynamic_kernels", "if_static", "tf_cuda_tests_tags", ) @@ -168,6 +170,7 @@ COMMON_PROTO_SRCS = [ "example/example.proto", "example/feature.proto", "framework/allocation_description.proto", + "framework/api_def.proto", "framework/attr_value.proto", "framework/cost_graph.proto", "framework/device_attributes.proto", @@ -179,7 +182,6 @@ COMMON_PROTO_SRCS = [ "framework/log_memory.proto", "framework/node_def.proto", "framework/op_def.proto", - "framework/api_def.proto", "framework/reader_base.proto", "framework/remote_fused_graph_execute_info.proto", "framework/resource_handle.proto", @@ -299,6 +301,7 @@ filegroup( name = "platform_base_hdrs", srcs = [ "platform/byte_order.h", + "platform/cord.h", "platform/env_time.h", "platform/logging.h", "platform/macros.h", @@ -695,15 +698,29 @@ cc_library( 
visibility = ["//visibility:public"], deps = [ ":lib_internal", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", "@com_google_absl//absl/types:optional", ], ) cc_library( + name = "feature_util", + srcs = ["example/feature_util.cc"], + hdrs = ["example/feature_util.h"], + visibility = ["//visibility:public"], + deps = [ + ":core_stringpiece", + ":lib_proto_parsing", + ":protos_all_cc", + ], +) + +cc_library( name = "abi", srcs = ["platform/abi.cc"], hdrs = ["platform/abi.h"], + deps = [":platform_base"], ) cc_library( @@ -858,7 +875,6 @@ tf_cuda_library( "util/bcast.h", "util/cuda_kernel_helper.h", "util/device_name_utils.h", - "util/env_var.h", "util/events_writer.h", "util/example_proto_fast_parsing.h", "util/example_proto_helper.h", @@ -1038,7 +1054,6 @@ tf_gen_op_libs( "random_grad", "random_ops", "remote_fused_graph_ops", - "resource_variable_ops", "rpc_ops", "scoped_allocator_ops", "sdca_ops", @@ -1049,7 +1064,6 @@ tf_gen_op_libs( "spectral_ops", "state_ops", "stateless_random_ops", - "string_ops", "summary_ops", "training_ops", ], @@ -1057,6 +1071,13 @@ tf_gen_op_libs( tf_gen_op_libs( op_lib_names = [ + "string_ops", + ], + deps = ["@com_google_absl//absl/strings"], +) + +tf_gen_op_libs( + op_lib_names = [ "array_ops", ], deps = [":protos_all_cc"], @@ -1074,6 +1095,14 @@ tf_gen_op_libs( deps = ["//tensorflow/core/kernels:debug_ops"], ) +tf_gen_op_libs( + is_external = False, + op_lib_names = [ + "resource_variable_ops", + ], + deps = [":lib"], +) + # And one for all user ops cc_library( name = "user_ops_op_lib", @@ -1268,8 +1297,8 @@ cc_library( # This includes implementations of all kernels built into TensorFlow. 
cc_library( - name = "all_kernels", - visibility = ["//visibility:public"], + name = "all_kernels_statically_linked", + visibility = ["//visibility:private"], deps = [ "//tensorflow/core/kernels:array", "//tensorflow/core/kernels:audio", @@ -1312,6 +1341,7 @@ cc_library( "//tensorflow/core/kernels:rpc_op", "//tensorflow/core/kernels:scoped_allocator_ops", "//tensorflow/core/kernels:sdca_ops", + "//tensorflow/core/kernels:searchsorted_op", "//tensorflow/core/kernels:set_kernels", "//tensorflow/core/kernels:sparse", "//tensorflow/core/kernels:state", @@ -1339,6 +1369,7 @@ cc_library( "//tensorflow/core/kernels:mkl_reshape_op", "//tensorflow/core/kernels:mkl_slice_op", "//tensorflow/core/kernels:mkl_softmax_op", + "//tensorflow/core/kernels:mkl_transpose_op", "//tensorflow/core/kernels:mkl_tfconv_op", "//tensorflow/core/kernels:mkl_aggregate_ops", ]) + if_cuda([ @@ -1347,6 +1378,15 @@ cc_library( ]), ) +cc_library( + name = "all_kernels", + visibility = ["//visibility:public"], + deps = if_dynamic_kernels( + [], + otherwise = [":all_kernels_statically_linked"], + ), +) + tf_cuda_library( name = "tensorflow_opensource", copts = tf_copts(), @@ -1410,9 +1450,11 @@ cc_library( ":test", ":testlib_ops", "//tensorflow/cc:scope", + "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:ops_testutil", "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/kernels:random_ops", ], ) @@ -1902,6 +1944,13 @@ tf_pyclif_proto_library( ) tf_pyclif_proto_library( + name = "protobuf/config_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "protobuf/config.proto", + visibility = ["//visibility:public"], +) + +tf_pyclif_proto_library( name = "protobuf/device_properties_pyclif", proto_lib = ":protos_all_cc", proto_srcfile = "protobuf/device_properties.proto", @@ -2040,6 +2089,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [ "platform/snappy.h", "platform/tensor_coding.h", "platform/tracing.h", + "util/env_var.h", 
] # Replicated for lib_internal and lib_internal_impl. @@ -2067,6 +2117,7 @@ cc_library( deps = tf_additional_lib_deps() + [ "@com_google_absl//absl/strings", "//third_party/eigen3", + "@com_google_absl//absl/base:core_headers", "//tensorflow/core/platform/default/build_config:platformlib", ] + if_static([":lib_internal_impl"]), ) @@ -2079,6 +2130,7 @@ cc_library( "platform/*.cc", "platform/profile_utils/**/*.cc", "framework/resource_handle.cc", + "util/env_var.cc", ], exclude = [ "**/*test*", @@ -2258,6 +2310,7 @@ cc_library( deps = [ "//tensorflow/core/platform/default/build_config:jpeg", "//tensorflow/core/platform/default/build_config:logging", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], ) @@ -2290,6 +2343,7 @@ cc_library( deps = [ "//tensorflow/core/platform/default/build_config:gif", "//tensorflow/core/platform/default/build_config:logging", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/strings", ], ) @@ -2434,7 +2488,6 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [ "framework/unique_tensor_references.h", "framework/variant.h", "util/command_line_flags.h", - "util/env_var.h", "util/equal_graph_def.h", "util/presized_cuckoo_map.h", "util/tensor_slice_set.h", @@ -2463,7 +2516,12 @@ tf_cuda_library( cc_header_only_library( name = "framework_internal_headers_lib", - includes = ["../../external/com_google_absl"], + # Fully depend on external repositories, because identifying the headers + # is fragile. 
+ extra_deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], deps = [ ":lib", ":lib_internal", @@ -2501,6 +2559,7 @@ tf_cuda_library( "**/*test*", "**/*main.cc", "example/example_parser_configuration.*", + "example/feature_util.cc", "util/reporter.cc", "framework/fake_input.*", "framework/op_gen_lib.*", @@ -2510,6 +2569,7 @@ tf_cuda_library( "util/memmapped_file_system_writer.*", "util/stats_calculator.*", "util/version_info.cc", + "util/env_var.cc", ], ) + select({ "//tensorflow:windows": [], @@ -2529,6 +2589,7 @@ tf_cuda_library( ], }), deps = [ + ":feature_util", ":lib", ":lib_internal", ":protos_all_proto_text", @@ -2548,11 +2609,12 @@ tf_cuda_library( cc_header_only_library( name = "framework_headers_lib", + # Fully depend on external repositories, because identifying the headers + # is fragile. extra_deps = [ - # ABSL headers get dropped, so we add them back here. "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", ], - includes = ["../../external/com_google_absl"], visibility = ["//visibility:public"], deps = [ ":framework", @@ -2562,7 +2624,12 @@ cc_header_only_library( cc_header_only_library( name = "stream_executor_headers_lib", - includes = ["../../external/com_google_absl"], + # Fully depend on external repositories, because identifying the headers + # is fragile. 
+ extra_deps = [ + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + ], visibility = ["//visibility:public"], deps = [ ":stream_executor", @@ -2753,8 +2820,6 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/stats_publisher_interface.h", "common_runtime/step_stats_collector.h", "common_runtime/threadpool_device.h", - "common_runtime/tracing_device.h", - "common_runtime/visitable_allocator.h", "common_runtime/process_state.h", "common_runtime/pool_allocator.h", "graph/gradients.h", @@ -2950,12 +3015,16 @@ tf_cuda_library( ] + tf_additional_device_tracer_deps(), ) -cc_library( - name = "session_ref", - srcs = ["common_runtime/session_ref.cc"], - hdrs = ["common_runtime/session_ref.h"], - copts = tf_copts(), - deps = [":core_cpu_base"], +tf_proto_library_cc( + name = "replay_log_proto", + srcs = ["protobuf/replay_log.proto"], + cc_api_version = 2, + protodeps = [ + ":master_proto", + ] + tf_additional_all_protos(), + visibility = [ + "//tensorflow:internal", + ], ) cc_library( @@ -3221,7 +3290,6 @@ tf_cc_tests( "lib/gtl/edit_distance_test.cc", "lib/gtl/flatmap_test.cc", "lib/gtl/flatset_test.cc", - "lib/gtl/inlined_vector_test.cc", "lib/gtl/int_type_test.cc", "lib/gtl/iterator_range_test.cc", "lib/gtl/manual_constructor_test.cc", @@ -3713,6 +3781,7 @@ tf_cc_test_mkl( ":core_cpu_internal", ":framework", ":framework_internal", + ":lib", ":test", ":test_main", ":testlib", @@ -4693,6 +4762,18 @@ cc_library( ] + tf_additional_libdevice_deps(), ) +transitive_hdrs( + name = "headers", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:stream_executor", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets go here (must be at the end). 
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt new file mode 100644 index 0000000000..cdaeb5091c --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt @@ -0,0 +1,34 @@ +op { + graph_op_name: "BoostedTreesBucketize" + visibility: HIDDEN + in_arg { + name: "float_values" + description: <<END +float; List of Rank 2 Tensor each containing float values for a single feature. +END + } + in_arg { + name: "bucket_boundaries" + description: <<END +float; List of Rank 1 Tensors each containing the bucket boundaries for a single +feature. +END + } + out_arg { + name: "buckets" + description: <<END +int; List of Rank 2 Tensors each containing the bucketized values for a single feature. +END + } + attr { + name: "num_features" + description: <<END +inferred int; number of features. +END + } + summary: "Bucketize each feature based on bucket boundaries." + description: <<END +An op that returns a list of float tensors, where each tensor represents the +bucketized values for a single feature. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt new file mode 100644 index 0000000000..20da1295f6 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt @@ -0,0 +1,29 @@ +op { + graph_op_name: "BoostedTreesCreateQuantileStreamResource" + visibility: HIDDEN + in_arg { + name: "quantile_stream_resource_handle" + description: <<END +resource; Handle to quantile stream resource. +END + } + in_arg { + name: "epsilon" + description: <<END +float; The required approximation error of the stream resource. +END + } + in_arg { + name: "num_streams" + description: <<END +int; The number of streams managed by the resource that shares the same epsilon. 
+END + } + attr { + name: "max_elements" + description : <<END +int; The maximum number of data points that can be fed to the stream. +END + } + summary: "Create the Resource for Quantile Streams." +} diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt new file mode 100644 index 0000000000..ca111af312 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt @@ -0,0 +1,40 @@ +op { + graph_op_name: "BoostedTreesMakeQuantileSummaries" + visibility: HIDDEN + in_arg { + name: "float_values" + description: <<END +float; List of Rank 2 Tensors each containing values for a single feature. +END + } + in_arg { + name: "example_weights" + description: <<END +float; Rank 1 Tensor with weights per instance. +END + } + in_arg { + name: "epsilon" + description: <<END +float; The required maximum approximation error. +END + } + out_arg { + name: "summaries" + description: <<END +float; List of Rank 2 Tensors each containing the quantile summary (value, weight, +min_rank, max_rank) of a single feature. +END + } + attr { + name: "num_features" + description: <<END +int; Inferred from the size of float_values. +The number of float features. +END + } + summary: "Makes the summary of quantiles for the batch." + description: <<END +An op that takes a list of tensors and outputs the quantile summaries for each tensor. 
+END +} diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt new file mode 100644 index 0000000000..bbeecbf32b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt @@ -0,0 +1,22 @@ +op { + graph_op_name: "BoostedTreesQuantileStreamResourceAddSummaries" + visibility: HIDDEN + in_arg { + name: "quantile_stream_resource_handle" + description: <<END +resource handle referring to a QuantileStreamResource. +END + } + in_arg { + name: "summaries" + description: <<END +string; List of Rank 2 Tensor each containing the summaries for a single feature. +END + } + summary: "Add the quantile summaries to each quantile stream resource." + description: <<END +An op that adds a list of quantile summaries to a quantile stream resource. Each +summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank) +for a single feature. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt new file mode 100644 index 0000000000..2fd94efa10 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt @@ -0,0 +1,31 @@ +op { + graph_op_name: "BoostedTreesQuantileStreamResourceFlush" + visibility: HIDDEN + in_arg { + name: "quantile_stream_resource_handle" + description: <<END +resource handle referring to a QuantileStreamResource. +END + } + in_arg { + name: "num_buckets", + description: <<END +int; approximate number of buckets unless using generate_quantiles. 
+END + } + attr { + name: "generate_quantiles" + description: <<END +bool; If True, the output will be the num_quantiles for each stream where the ith +entry is the ith quantile of the input with an approximation error of epsilon. +Duplicate values may be present. +If False, the output will be the points in the histogram that we got which roughly +translates to 1/epsilon boundaries and without any duplicates. +Default to False. +END + } + summary: "Flush the summaries for a quantile stream resource." + description: <<END +An op that flushes the summaries for a quantile stream resource. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt new file mode 100644 index 0000000000..206672802f --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt @@ -0,0 +1,27 @@ +op { + graph_op_name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries" + visibility: HIDDEN + in_arg { + name: "quantile_stream_resource_handle" + description: <<END +resource handle referring to a QuantileStreamResource. +END + } + out_arg { + name: "bucket_boundaries" + description: <<END +float; List of Rank 1 Tensors each containing the bucket boundaries for a feature. +END + } + attr { + name: "num_features" + description: <<END +inferred int; number of features to get bucket boundaries for. +END + } + summary: "Generate the bucket boundaries for each feature based on accumulated summaries." + description: <<END +An op that returns a list of float tensors for a quantile stream resource. Each +tensor is Rank 1 containing bucket boundaries for a single feature. 
+END +} diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt new file mode 100644 index 0000000000..cb7786c051 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt @@ -0,0 +1,5 @@ +op { + graph_op_name: "BoostedTreesQuantileStreamResourceHandleOp" + visibility: HIDDEN + summary: "Creates a handle to a BoostedTreesQuantileStreamResource." +} diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt index e39213cbc7..440800704e 100644 --- a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt @@ -11,7 +11,8 @@ END name: "record_defaults" description: <<END One tensor per column of the input record, with either a -scalar default value for that column or empty if the column is required. +scalar default value for that column or an empty vector if the column is +required. END } out_arg { diff --git a/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt new file mode 100644 index 0000000000..3c8a455983 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt @@ -0,0 +1,49 @@ +op { + graph_op_name: "ExtractVolumePatches" + in_arg { + name: "input" + description: <<END +5-D Tensor with shape `[batch, in_planes, in_rows, in_cols, depth]`. +END + } + out_arg { + name: "patches" + description: <<END +5-D Tensor with shape `[batch, out_planes, out_rows, out_cols, +ksize_planes * ksize_rows * ksize_cols * depth]` containing patches +with size `ksize_planes x ksize_rows x ksize_cols x depth` vectorized +in the "depth" dimension. Note `out_planes`, `out_rows` and `out_cols` +are the dimensions of the output patches. 
+END + } + attr { + name: "ksizes" + description: <<END +The size of the sliding window for each dimension of `input`. +END + } + attr { + name: "strides" + description: <<END +1-D of length 5. How far the centers of two consecutive patches are in +`input`. Must be: `[1, stride_planes, stride_rows, stride_cols, 1]`. +END + } + attr { + name: "padding" + description: <<END +The type of padding algorithm to use. + +We specify the size-related attributes as: + +```python + ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1] + strides = [1, stride_planes, strides_rows, strides_cols, 1] +``` +END + } + summary: <<END +Extract `patches` from `input` and put them in the "depth" output +dimension. 3D extension of `extract_image_patches`. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt new file mode 100644 index 0000000000..758eeb96f0 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt @@ -0,0 +1,20 @@ +op { + graph_op_name: "IsBoostedTreesQuantileStreamResourceInitialized" + visibility: HIDDEN + in_arg { + name: "quantile_stream_resource_handle" + description: <<END +resource; The reference to quantile stream resource handle. +END + } + out_arg { + name: "is_initialized" + description: <<END +bool; True if the resource is initialized, False otherwise. +END + } + summary: "Checks whether a quantile stream has been initialized." + description: <<END +An Op that checks if quantile stream resource is initialized. 
+END +} diff --git a/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt new file mode 100644 index 0000000000..5ce825ae04 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt @@ -0,0 +1,45 @@ +op { + graph_op_name: "LowerBound" + visibility: HIDDEN + in_arg { + name: "sorted_inputs" + description: <<END +2-D Tensor where each row is ordered. +END + } + in_arg { + name: "values" + description: <<END +2-D Tensor with the same numbers of rows as `sorted_search_values`. Contains +the values that will be searched for in `sorted_search_values`. +END + } + out_arg { + name: "output" + description: <<END +A `Tensor` with the same shape as `values`. It contains the first scalar index +into the last dimension where values can be inserted without changing the +ordered property. +END + } + summary: "Applies lower_bound(sorted_search_values, values) along each row." + description: <<END +Each set of rows with the same index in (sorted_inputs, values) is treated +independently. The resulting row is the equivalent of calling +`np.searchsorted(sorted_inputs, values, side='left')`. + +The result is not a global index to the entire +`Tensor`, but rather just the index in the last dimension. + +A 2-D example: + sorted_sequence = [[0, 3, 9, 9, 10], + [1, 2, 3, 4, 5]] + values = [[2, 4, 9], + [0, 2, 6]] + + result = LowerBound(sorted_sequence, values) + + result == [[1, 2, 2], + [0, 1, 5]] +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt new file mode 100644 index 0000000000..171add16d4 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt @@ -0,0 +1,14 @@ +op { + graph_op_name: "ModelDataset" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <<END +A variant tensor representing the input dataset. 
+END + } + summary: "Identity transformation that models performance." + description: <<END +Identity transformation that models performance. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt new file mode 100644 index 0000000000..4b0a5d8f65 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt @@ -0,0 +1,43 @@ +op { + graph_op_name: "MultiDeviceIterator" + out_arg { + name: "handle" + description: <<END +Handle to the resource created. +END + } + attr { + name: "devices" + description: <<END +A list of devices the iterator works across. +END + } + attr { + name: "shared_name" + description: <<END +If non-empty, this resource will be shared under the given name +across multiple sessions. +END + } + attr { + name: "container" + description: <<END +If non-empty, this resource is placed in the given container. +Otherwise, a default container is used. +END + } + attr { + name: "output_types" + description: <<END +The type list for the return values. +END + } + attr { + name: "output_shapes" + description: <<END +The list of shapes being produced. +END + } + summary: "Creates a MultiDeviceIterator resource." + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt new file mode 100644 index 0000000000..adaacd8ab7 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt @@ -0,0 +1,29 @@ +op { + graph_op_name: "MultiDeviceIteratorFromStringHandle" + in_arg { + name: "string_handle" + description: <<END +String representing the resource. +END + } + out_arg { + name: "multi_device_iterator" + description: <<END +A MultiDeviceIterator resource. +END + } + attr { + name: "output_types" + description: <<END +The type list for the return values. 
+END + } + attr { + name: "output_shapes" + description: <<END +The list of shapes being produced. +END + } + summary: "Generates a MultiDeviceIterator resource from its provided string handle." + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt new file mode 100644 index 0000000000..f9be9188cc --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt @@ -0,0 +1,41 @@ +op { + graph_op_name: "MultiDeviceIteratorGetNextFromShard" + in_arg { + name: "multi_device_iterator" + description: <<END +A MultiDeviceIterator resource. +END + } + in_arg { + name: "shard_num" + description: <<END +Integer representing which shard to fetch data for. +END + } + in_arg { + name: "incarnation_id" + description: <<END +Which incarnation of the MultiDeviceIterator is running. +END + } + out_arg { + name: "components" + description: <<END +Result of the get_next on the dataset. +END + } + attr { + name: "output_types" + description: <<END +The type list for the return values. +END + } + attr { + name: "output_shapes" + description: <<END +The list of shapes being produced. +END + } + summary: "Gets next element for the provided shard number." + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt new file mode 100644 index 0000000000..6b54fa1307 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt @@ -0,0 +1,30 @@ +op { + graph_op_name: "MultiDeviceIteratorInit" + in_arg { + name: "dataset" + description: <<END +Dataset to be iterated upon. +END + } + in_arg { + name: "multi_device_iterator" + description: <<END +A MultiDeviceIteratorResource. 
+END + } + in_arg { + name: "max_buffer_size" + description: <<END +The maximum size of the host side per device buffer to keep. +END + } + out_arg { + name: "incarnation_id" + description: <<END +An int64 indicating which incarnation of the MultiDeviceIterator +is running. +END + } + summary: "Initializes the multi device iterator with the given dataset." + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt new file mode 100644 index 0000000000..1f1fdf99b4 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt @@ -0,0 +1,17 @@ +op { + graph_op_name: "MultiDeviceIteratorToStringHandle" + in_arg { + name: "multi_device_iterator" + description: <<END +A MultiDeviceIterator resource. +END + } + out_arg { + name: "string_handle" + description: <<END +A string representing the resource. +END + } + summary: "Produces a string handle for the given MultiDeviceIterator." + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt new file mode 100644 index 0000000000..27bc4013c3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt @@ -0,0 +1,13 @@ +op { + graph_op_name: "ParallelInterleaveDatasetV2" + visibility: HIDDEN + attr { + name: "f" + description: <<END +A function mapping elements of `input_dataset`, concatenated with +`other_arguments`, to a Dataset variant that contains elements matching +`output_types` and `output_shapes`. +END + } + summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`." 
+} diff --git a/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt new file mode 100644 index 0000000000..4cb8955dcb --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt @@ -0,0 +1,19 @@ +op { + graph_op_name: "PrintV2" + in_arg { + name: "input" + description: <<END +The string scalar to print. +END + } + attr { + name: "output_stream" + description: <<END +A string specifying the output stream or logging level to print to. +END + } + summary: "Prints a string scalar." + description: <<END +Prints a string scalar to the desired output_stream. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt index 8cef243aee..30fd97a0d7 100644 --- a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt @@ -9,7 +9,7 @@ END in_arg { name: "pattern" description: <<END -A 1-D string tensor of the regular expression to match the input. +A scalar string tensor containing the regular expression to match the input. END } out_arg { diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt index 35f55fe106..d33a36ce06 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. 
END } diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt index 70a07d9b4c..afdc39da96 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. END } diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt index b2e3eece38..026b5b3991 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. END } diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt index 7bac02e23d..a168eed87f 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. 
END } diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt index a73306a892..876b860824 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. END } diff --git a/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt new file mode 100644 index 0000000000..6d9d9908ca --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt @@ -0,0 +1,29 @@ +op { + graph_op_name: "StaticRegexFullMatch" + in_arg { + name: "input" + description: <<END +A string tensor of the text to be processed. +END + } + out_arg { + name: "output" + description: <<END +A bool tensor with the same shape as `input`. +END + } + attr { + name: "pattern" + description: "The regular expression to match the input." + } + summary: "Check if the input matches the regex pattern." + description: <<END +The input is a string tensor of any shape. The pattern is the +regular expression to be matched with every element of the input tensor. +The boolean values (True or False) of the output tensor indicate +if the input matches the regex pattern provided. 
+ +The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +END + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt new file mode 100644 index 0000000000..a82dae9e48 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt @@ -0,0 +1,38 @@ +op { + graph_op_name: "StringFormat" + in_arg { + name: "inputs" + description: <<END +The list of tensors to format into the placeholder string. +END + } + + out_arg { + name: "output" + description: <<END += The resulting string scalar. +END + } + attr { + name: "template" + description: <<END +A string, the template to format tensor summaries into. +END + } + attr { + name: "placeholder" + description: <<END +A string, at each placeholder in the template a subsequent tensor summary will be inserted. +END + } + attr { + name: "summarize" + description: <<END +When formatting the tensor summaries print the first and last summarize entries of each tensor dimension. +END + } + summary: "Formats a string template using a list of tensors." + description: <<END +Formats a string template using a list of tensors, pretty-printing tensor summaries. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt index cc21ddc815..7d2fbcd00b 100644 --- a/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt @@ -1,5 +1,15 @@ op { graph_op_name: "StringLength" + attr { + name: "unit" + description: <<END +The unit that is counted to compute string length. One of: `"BYTE"` (for +the number of bytes in each string) or `"UTF8_CHAR"` (for the number of UTF-8 +encoded Unicode code points in each string). Results are undefined +if `unit=UTF8_CHAR` and the `input` strings do not contain structurally +valid UTF-8. 
+END + } in_arg { name: "input" description: <<END diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt index 8fc1e5cba3..5246090ab3 100644 --- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt @@ -32,8 +32,10 @@ For each string in the input `Tensor`, creates a substring starting at index If `len` defines a substring that would extend beyond the length of the input string, then as many characters as possible are used. -If `pos` is negative or specifies a character index larger than any of the input -strings, then an `InvalidArgumentError` is thrown. +A negative `pos` indicates distance within the string backwards from the end. + +If `pos` specifies an index which is out of range for any of the input strings, +then an `InvalidArgumentError` is thrown. `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on Op creation. diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt index 907c6d2022..7a60e4387a 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt @@ -3,15 +3,14 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s -first dimension. -END +A tensor whose shape is a prefix of `data.shape`.END } out_arg { name: "output" description: <<END -Has same shape as data, except for dimension 0 which -has size `num_segments`. +Has same shape as data, except for the first `segment_ids.rank` +dimensions, which are replaced with a single dimension which has size +`num_segments`. END } summary: "Computes the maximum along segments of a tensor." 
@@ -24,13 +23,16 @@ This operator is similar to the unsorted segment sum operator found [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). Instead of computing the sum over segments, it computes the maximum such that: -\\(output_i = \max_j data_j\\) where max is over `j` such -that `segment_ids[j] == i`. +\\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such +that `segment_ids[j...] == i`. If the maximum is empty for a given segment ID `i`, it outputs the smallest possible value for the specific numeric type, `output[i] = numeric_limits<T>::lowest()`. +If the given segment ID `i` is negative, then the corresponding value is +dropped, and will not be included in the result. + <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt> </div> diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt index 37dd973b23..7e139ddf4d 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt @@ -3,15 +3,15 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s -first dimension. +A tensor whose shape is a prefix of `data.shape`. END } out_arg { name: "output" description: <<END -Has same shape as data, except for dimension 0 which -has size `num_segments`. +Has same shape as data, except for the first `segment_ids.rank` +dimensions, which are replaced with a single dimension which has size +`num_segments`. END } summary: "Computes the minimum along segments of a tensor." @@ -24,11 +24,14 @@ This operator is similar to the unsorted segment sum operator found [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). 
Instead of computing the sum over segments, it computes the minimum such that: -\\(output_i = \min_j data_j\\) where min is over `j` such -that `segment_ids[j] == i`. +\\(output_i = \min_{j...} data_[j...]\\) where min is over tuples `j...` such +that `segment_ids[j...] == i`. If the minimum is empty for a given segment ID `i`, it outputs the largest possible value for the specific numeric type, `output[i] = numeric_limits<T>::max()`. + +If the given segment ID `i` is negative, then the corresponding value is +dropped, and will not be included in the result. END } diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt index efbc023705..9c8ea3b620 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt @@ -3,15 +3,15 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s -first dimension. +A tensor whose shape is a prefix of `data.shape`. END } out_arg { name: "output" description: <<END -Has same shape as data, except for dimension 0 which -has size `num_segments`. +Has same shape as data, except for the first `segment_ids.rank` +dimensions, which are replaced with a single dimension which has size +`num_segments`. END } summary: "Computes the product along segments of a tensor." @@ -25,9 +25,12 @@ This operator is similar to the unsorted segment sum operator found Instead of computing the sum over segments, it computes the product of all entries belonging to a segment such that: -\\(output_i = \prod_j data_j\\) where the product is over `j` such -that `segment_ids[j] == i`. +\\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples +`j...` such that `segment_ids[j...] == i`. If there is no entry for a given segment ID `i`, it outputs 1. 
+ +If the given segment ID `i` is negative, then the corresponding value is +dropped, and will not be included in the result. END } diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt index a8874950eb..7e5d9265c2 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt @@ -21,7 +21,7 @@ Read for an explanation of segments. Computes a tensor such that -\\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such +\\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids` need not be sorted and need not cover all values in the full range of valid values. diff --git a/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt b/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt new file mode 100644 index 0000000000..0630f6e82a --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt @@ -0,0 +1,45 @@ +op { + graph_op_name: "UpperBound" + visibility: HIDDEN + in_arg { + name: "sorted_inputs" + description: <<END +2-D Tensor where each row is ordered. +END + } + in_arg { + name: "values" + description: <<END +2-D Tensor with the same numbers of rows as `sorted_search_values`. Contains +the values that will be searched for in `sorted_search_values`. +END + } + out_arg { + name: "output" + description: <<END +A `Tensor` with the same shape as `values`. It contains the last scalar index +into the last dimension where values can be inserted without changing the +ordered property. +END + } + summary: "Applies upper_bound(sorted_search_values, values) along each row." + description: <<END +Each set of rows with the same index in (sorted_inputs, values) is treated +independently. 
The resulting row is the equivalent of calling +`np.searchsorted(sorted_inputs, values, side='right')`. + +The result is not a global index to the entire +`Tensor`, but rather just the index in the last dimension. + +A 2-D example: + sorted_sequence = [[0, 3, 9, 9, 10], + [1, 2, 3, 4, 5]] + values = [[2, 4, 9], + [0, 2, 6]] + + result = UpperBound(sorted_sequence, values) + + result == [[1, 2, 4], + [0, 2, 5]] +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt index 1bc3660479..01387b7527 100644 --- a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt @@ -2,10 +2,31 @@ op { visibility: HIDDEN graph_op_name: "WindowDataset" in_arg { - name: "window_size" + name: "size" description: <<END A scalar representing the number of elements to accumulate in a window. END } + in_arg { + name: "shift" + description: <<END +A scalar representing the steps moving the sliding window forward in one +iteration. It must be positive. +END + } + in_arg { + name: "stride" + description: <<END +A scalar representing the stride of the input elements of the sliding window. +It must be positive. +END + } + in_arg { + name: "drop_remainder" + description: <<END +A scalar representing whether a window should be dropped in case its size is +smaller than desired. +END + } summary: "A dataset that creates window datasets from the input dataset." } diff --git a/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt b/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt new file mode 100644 index 0000000000..ca107abc6b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "Xdivy" + summary: "Returns 0 if x == 0, and x / y otherwise, elementwise." 
+} diff --git a/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt b/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt new file mode 100644 index 0000000000..da625f7836 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "Xlogy" + summary: "Returns 0 if x == 0, and x * log(y) otherwise, elementwise." +} diff --git a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt index 9552fc92e3..e395e333bf 100644 --- a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "BatchToSpaceND" endpoint { - name: "manip.batch_to_space_nd" + name: "batch_to_space_nd" } endpoint { - name: "batch_to_space_nd" + name: "manip.batch_to_space_nd" deprecated: true } } diff --git a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt index 71257c8855..598f23bde3 100644 --- a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "GatherNd" endpoint { - name: "manip.gather_nd" + name: "gather_nd" } endpoint { - name: "gather_nd" + name: "manip.gather_nd" deprecated: true } } diff --git a/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt new file mode 100644 index 0000000000..e22d980424 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "PrintV2" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt index c469665b66..b3d596de7a 100644 --- a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt +++ 
b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "Reshape" endpoint { - name: "manip.reshape" + name: "reshape" } endpoint { - name: "reshape" + name: "manip.reshape" deprecated: true } } diff --git a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt index 77f595927b..51478b7c34 100644 --- a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "ReverseV2" endpoint { - name: "manip.reverse" + name: "reverse" } endpoint { - name: "reverse" + name: "manip.reverse" deprecated: true } endpoint { diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt index a65a19b542..85888da45a 100644 --- a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "ScatterNd" endpoint { - name: "manip.scatter_nd" + name: "scatter_nd" } endpoint { - name: "scatter_nd" + name: "manip.scatter_nd" deprecated: true } } diff --git a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt index af323a6cf3..146b97f444 100644 --- a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "SpaceToBatchND" endpoint { - name: "manip.space_to_batch_nd" + name: "space_to_batch_nd" } endpoint { - name: "space_to_batch_nd" + name: "manip.space_to_batch_nd" deprecated: true } } diff --git a/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt new file mode 100644 index 0000000000..8f0b1db45d --- /dev/null +++ 
b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "StringFormat" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt index 01c02e1f70..df012414e3 100644 --- a/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt @@ -1,6 +1,4 @@ op { graph_op_name: "StringLength" - endpoint { - name: "strings.length" - } + visibility: HIDDEN } diff --git a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt index c34061c941..1d8695f1fd 100644 --- a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "Tile" endpoint { - name: "manip.tile" + name: "tile" } endpoint { - name: "tile" + name: "manip.tile" deprecated: true } } diff --git a/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt b/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt new file mode 100644 index 0000000000..984442ba2b --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt @@ -0,0 +1,6 @@ +op { + graph_op_name: "Xdivy" + endpoint { + name: "math.xdivy" + } +} diff --git a/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt b/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt new file mode 100644 index 0000000000..b4a5299256 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt @@ -0,0 +1,6 @@ +op { + graph_op_name: "Xlogy" + endpoint { + name: "math.xlogy" + } +} diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc index 3bf0532491..3843ea9e60 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.cc +++ b/tensorflow/core/common_runtime/bfc_allocator.cc @@ -31,7 +31,7 @@ namespace tensorflow { 
BFCAllocator::BFCAllocator(SubAllocator* sub_allocator, size_t total_memory, bool allow_growth, const string& name) - : suballocator_(sub_allocator), + : sub_allocator_(sub_allocator), name_(name), free_chunks_list_(kInvalidChunkHandle), next_allocation_id_(1) { @@ -72,7 +72,7 @@ BFCAllocator::~BFCAllocator() { VLOG(2) << "Number of regions allocated: " << region_manager_.regions().size(); for (const auto& region : region_manager_.regions()) { - suballocator_->Free(region.ptr(), region.memory_size()); + sub_allocator_->Free(region.ptr(), region.memory_size()); } for (BinNum b = 0; b < kNumBins; b++) { @@ -108,7 +108,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) { // Try allocating. size_t bytes = std::min(curr_region_allocation_bytes_, available_bytes); - void* mem_addr = suballocator_->Alloc(alignment, bytes); + void* mem_addr = sub_allocator_->Alloc(alignment, bytes); if (mem_addr == nullptr && !started_backpedal_) { // Only backpedal once. started_backpedal_ = true; @@ -119,7 +119,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) { while (mem_addr == nullptr) { bytes = RoundedBytes(bytes * kBackpedalFactor); if (bytes < rounded_bytes) break; - mem_addr = suballocator_->Alloc(alignment, bytes); + mem_addr = sub_allocator_->Alloc(alignment, bytes); } } @@ -158,10 +158,6 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) { // Insert the chunk into the right bin. InsertFreeChunkIntoBin(h); - // Invoke visitors on newly allocated region. 
- for (const auto& visitor : region_visitors_) { - visitor(mem_addr, bytes); - } return true; } @@ -490,15 +486,6 @@ void BFCAllocator::FreeAndMaybeCoalesce(BFCAllocator::ChunkHandle h) { InsertFreeChunkIntoBin(coalesced_chunk); } -void BFCAllocator::AddAllocVisitor(Visitor visitor) { - VLOG(1) << "AddVisitor"; - mutex_lock l(lock_); - region_visitors_.push_back(visitor); - for (const auto& region : region_manager_.regions()) { - visitor(region.ptr(), region.memory_size()); - } -} - bool BFCAllocator::TracksAllocationSizes() { return true; } size_t BFCAllocator::RequestedSize(const void* ptr) { @@ -596,7 +583,7 @@ string BFCAllocator::RenderOccupancy() { region_offset += region.memory_size(); } - return std::string(rendered, resolution); + return string(rendered, resolution); } void BFCAllocator::DumpMemoryLog(size_t num_bytes) { diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h index 20e1dab1d5..2d74bf2b28 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.h +++ b/tensorflow/core/common_runtime/bfc_allocator.h @@ -23,7 +23,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/common_runtime/allocator_retry.h" -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/macros.h" @@ -42,7 +42,7 @@ namespace tensorflow { // coalescing. One assumption we make is that the process using this // allocator owns pretty much all of the memory, and that nearly // all requests to allocate memory go through this interface. -class BFCAllocator : public VisitableAllocator { +class BFCAllocator : public Allocator { public: // Takes ownership of sub_allocator. 
BFCAllocator(SubAllocator* sub_allocator, size_t total_memory, @@ -55,11 +55,6 @@ class BFCAllocator : public VisitableAllocator { const AllocationAttributes& allocation_attr) override; void DeallocateRaw(void* ptr) override; - void AddAllocVisitor(Visitor visitor) override; - - // Does nothing, because memory is never freed. - void AddFreeVisitor(Visitor visitor) override {} - bool TracksAllocationSizes() override; size_t RequestedSize(const void* ptr) override; @@ -309,7 +304,7 @@ class BFCAllocator : public VisitableAllocator { }; // Returns 'bytes' rounded up to the next highest kMinAllocationSize. - size_t RoundedBytes(size_t bytes); + static size_t RoundedBytes(size_t bytes); // Try to add a new memory region that can satisfy an allocation of // 'rounded_bytes' bytes. Returns true on success and false on @@ -423,7 +418,7 @@ class BFCAllocator : public VisitableAllocator { // of the available memory. bool started_backpedal_ = false; - std::unique_ptr<SubAllocator> suballocator_; + std::unique_ptr<SubAllocator> sub_allocator_; string name_; // Structures mutable after construction @@ -435,9 +430,6 @@ class BFCAllocator : public VisitableAllocator { // Pointer to head of linked list of free Chunks ChunkHandle free_chunks_list_ GUARDED_BY(lock_); - // Called once on each region, ASAP. - std::vector<Visitor> region_visitors_ GUARDED_BY(lock_); - // Counter containing the next unique identifier to assign to a // newly-created chunk. 
int64 next_allocation_id_ GUARDED_BY(lock_); diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 97b6971c5b..419867ff58 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -61,6 +61,7 @@ bool ReadPartialShapesFromShapeMap( shape_map, std::vector<PartialTensorShape>* input_shapes) { CHECK(shape_map != nullptr); + input_shapes->resize(n->num_inputs()); for (const Edge* in : n->in_edges()) { // Don't need to check if incoming control edges have known shapes. if (in->IsControlEdge()) continue; @@ -71,7 +72,9 @@ bool ReadPartialShapesFromShapeMap( } const auto& known_shape = known_shape_iter->second; CHECK_GT(known_shape.size(), in->src_output()) << known_shape_iter->first; - input_shapes->push_back(known_shape[in->src_output()]); + DCHECK_GE(in->dst_input(), 0); + DCHECK_LT(in->dst_input(), input_shapes->size()); + (*input_shapes)[in->dst_input()] = known_shape[in->src_output()]; } return true; } @@ -467,19 +470,19 @@ bool ReplaceTensorWithConstant( const ConstantFoldNameGenerator& generate_new_name) { // Be conservative when replacing a tensor with a constant, when not // running on CPU. - // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY + // 1) Do not replace another constant. + // 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY // constraint, do not replace it. - // 2) If the destination tensor is an int32 tensor, but has DEVICE_MEMORY - // constraint, do not replace it. - // 3) If the constant op created does not have a kernel implementation - // for the device, do not use it. - // 4) If the size of the constant in bytes is too large (> + // 3) If the size of the constant in bytes is too large (> // max_constant_in_bytes), do not replace it. This prevents the size of the // Graph from growing too large. 
+ // 4) If the constant op created does not have a kernel implementation + // for the device, do not use it. // TODO(keveman): Consider adding a new constant op that has a kernel // implementation for all types, but with HostMemory constraint on it's // output. - // 5) Do not replace another constant. + // 5) If the constant op for the device has different output memory type + // from the original op output memory type, do not replace it. if (tensor.first->IsConstant()) { return false; } @@ -494,8 +497,7 @@ bool ReplaceTensorWithConstant( return false; } bool is_int32 = tensor.first->output_type(tensor.second) == DT_INT32; - if ((memory_type == HOST_MEMORY && !is_int32) || - (memory_type == DEVICE_MEMORY && is_int32)) { + if (memory_type == HOST_MEMORY && !is_int32) { return false; } } @@ -533,6 +535,23 @@ bool ReplaceTensorWithConstant( if (!NodeBuilder(builder).Finalize(graph, &constant_node).ok()) { return false; } + if (partition_device && device_type != DEVICE_CPU) { + MemoryType original_output_memory_type; + if (!MemoryTypeForOutput(device_type, graph, tensor.first, tensor.second, + &original_output_memory_type) + .ok()) { + return false; + } + MemoryType const_output_memory_type; + if (!MemoryTypeForOutput(device_type, graph, constant_node, 0, + &const_output_memory_type) + .ok()) { + return false; + } + if (original_output_memory_type != const_output_memory_type) { + return false; + } + } for (auto edge : edges_to_remove) { graph->AddEdge(constant_node, 0, edge->dst(), edge->dst_input()); graph->RemoveEdge(edge); diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index f8cb854b52..d800a86199 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -347,7 +347,12 @@ namespace { static Status WrappedTensorDeviceCopy( const Tensor& from, Tensor* to, const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) { - if (DMAHelper::CanUseDMA(&from)) { 
+ if (from.dtype() == DT_VARIANT) { + // TODO(b/116349787): Implement support for nested variants. + return errors::Unimplemented( + "Support for copying nested variants to device has not yet been " + "implemented."); + } else if (DMAHelper::CanUseDMA(&from)) { TF_RETURN_IF_ERROR(copy(from, to)); } else { *to = from; @@ -358,7 +363,7 @@ static Status WrappedTensorDeviceCopy( #define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION) \ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ - Tensor, DIRECTION, "tensorflow::Tensor", WrappedTensorDeviceCopy) + Tensor, DIRECTION, WrappedTensorDeviceCopy) REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h index 81d68e3be4..2ef1547cd9 100644 --- a/tensorflow/core/common_runtime/device.h +++ b/tensorflow/core/common_runtime/device.h @@ -101,11 +101,21 @@ class Device : public DeviceBase { } } + // If true, and tracing is enabled, the `tracing::ScopedAnnotation()` tracing + // mechanism will be used instead of `tracing::ScopedActivity()`. Some devices + // may override this method to use annotations, which enable child activities + // (such as GPU kernel launches) to be related to the OpKernel invocation. + virtual bool TraceUsingAnnotations() const { return false; } + // Blocks until all operations queued on the device at the time of // the call have completed. Returns any error pending on the device // at completion. virtual Status Sync() = 0; + // Override this to return true for devices that require a Sync() call before + // session completion. + virtual bool RequiresSyncOnCompletion() const { return false; } + // Optionally modify the device's GraphDef before execution. 
// // This method should be considered experimental and is supplied to enable diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index eb388202fa..af5d5b17e7 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -1202,14 +1202,11 @@ Status DirectSession::CreateExecutors( auto opseg = device->op_segment(); params.create_kernel = [this, lib, opseg](const NodeDef& ndef, OpKernel** kernel) { - // We do not share the kernel via the OpSegment if the node is - // stateless, or a function. // NOTE(mrry): We must not share function kernels (implemented // using `CallOp`) between subgraphs, because `CallOp::handle_` // is tied to a particular subgraph. Even if the function itself // is stateful, the `CallOp` that invokes it is not. - if (!lib->IsStateful(ndef.op()) || - lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) { + if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) { return lib->CreateKernel(ndef, kernel); } auto create_fn = [lib, &ndef](OpKernel** kernel) { @@ -1222,13 +1219,11 @@ Status DirectSession::CreateExecutors( create_fn); }; params.delete_kernel = [lib](OpKernel* kernel) { - // If the node is stateful, opseg owns it. Otherwise, delete it. - if (kernel && !lib->IsStateful(kernel->type_string())) { + if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) delete kernel; - } }; - optimizer.Optimize(lib, options_.env, device, &iter->second, + optimizer.Optimize(lib, options_.env, device, &partition_graph, /*shape_map=*/nullptr); // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph. 
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 3f2355e530..65e816c202 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -1255,7 +1255,7 @@ TEST(DirectSessionTest, RunHandleTest) { ASSERT_TRUE(s.ok()); ASSERT_EQ(1, outputs.size()); - ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()(); + const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()(); Tensor string_handle(DT_STRING, {}); string_handle.flat<string>().setConstant(resource_handle.name()); @@ -1308,7 +1308,7 @@ TEST(DirectSessionTest, RunHandleTest_Callable) { ASSERT_TRUE(s.ok()); ASSERT_EQ(1, outputs.size()); - ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()(); + const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()(); Tensor string_handle(DT_STRING, {}); string_handle.flat<string>().setConstant(resource_handle.name()); diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index 0b096a14a3..2c63b8704e 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -77,6 +77,9 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { options.config.mutable_graph_options() ->mutable_rewrite_options() ->set_min_graph_nodes(-1); + options.config.mutable_graph_options() + ->mutable_rewrite_options() + ->set_pin_to_host_optimization(RewriterConfig::OFF); std::unique_ptr<Session> session(NewSession(options)); TF_ASSERT_OK(session->Create(def)); std::vector<std::pair<string, Tensor>> inputs; @@ -105,7 +108,7 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { EXPECT_EQ(2, shape.dim(0).size()); EXPECT_EQ(1, shape.dim(1).size()); if (node->name() == 
y->name()) { -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) // if MKL is used, it goes through various additional // graph rewrite pass. In TF, everytime a graph pass // happens, "constant" nodes are allocated @@ -114,16 +117,16 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { // which increments the value of AllocationId. // Thus AllocationId becomes more than TF if MKL // is used. Now IDs for MKL are 8 more than TF. - EXPECT_EQ(29, cm->AllocationId(node, 0)); -#else EXPECT_EQ(21, cm->AllocationId(node, 0)); -#endif - } else { -#ifdef INTEL_MKL - EXPECT_EQ(30, cm->AllocationId(node, 0)); #else + EXPECT_EQ(13, cm->AllocationId(node, 0)); +#endif // INTEL_MKL && ENABLE_MKL + } else { +#if defined(INTEL_MKL) && defined(ENABLE_MKL) EXPECT_EQ(22, cm->AllocationId(node, 0)); -#endif +#else + EXPECT_EQ(14, cm->AllocationId(node, 0)); +#endif // INTEL_MKL && ENABLE_MKL } } EXPECT_LE(0, cm->MaxExecutionTime(node)); diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 39a3b49cd1..18420b60fd 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -32,34 +32,55 @@ bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) { return default_val; } +std::unique_ptr<thread::ThreadPool> EagerThreadPool( + const SessionOptions& opts) { + SessionOptions opts_copy(opts); + if (opts_copy.config.inter_op_parallelism_threads() == 0) { + // Eager defaults to a single thread when no threads are specified. 
+ opts_copy.config.set_inter_op_parallelism_threads(1); + } + + return std::unique_ptr<thread::ThreadPool>( + NewThreadPoolFromSessionOptions(opts_copy)); +} + } // namespace EagerContext::EagerContext(const SessionOptions& opts, ContextDevicePlacementPolicy default_policy, - bool async, std::unique_ptr<DeviceMgr> device_mgr, + bool async, + std::unique_ptr<const DeviceMgr> device_mgr, Rendezvous* rendezvous) + : EagerContext(opts, default_policy, async, device_mgr.release(), + /*device_mgr_owned*/ true, rendezvous) {} + +EagerContext::EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, + bool async, const DeviceMgr* device_mgr, + bool device_mgr_owned, Rendezvous* rendezvous) : policy_(default_policy), - local_device_manager_(std::move(device_mgr)), - local_unowned_device_manager_(nullptr), - devices_(local_device_manager_->ListDevices()), + devices_(device_mgr->ListDevices()), rendezvous_(rendezvous), - thread_pool_(NewThreadPoolFromSessionOptions(opts)), + thread_pool_(EagerThreadPool(opts)), pflr_(new ProcessFunctionLibraryRuntime( - local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION, - &func_lib_def_, {}, thread_pool_.get())), + device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {}, + thread_pool_.get())), log_device_placement_(opts.config.log_device_placement()), num_active_steps_(0), async_default_(async), + log_memory_(LogMemory::IsEnabled()), env_(opts.env), use_send_tensor_rpc_(false) { - InitDeviceMapAndAsync(); - if (opts.config.inter_op_parallelism_threads() > 0) { - runner_ = [this](std::function<void()> closure) { - this->thread_pool_->Schedule(closure); - }; + if (device_mgr_owned) { + local_device_manager_.reset(device_mgr); + local_unowned_device_manager_ = nullptr; } else { - runner_ = [](std::function<void()> closure) { closure(); }; + local_unowned_device_manager_ = device_mgr; } + InitDeviceMapAndAsync(); + runner_ = [this](std::function<void()> closure) { + 
this->thread_pool_->Schedule(std::move(closure)); + }; } void EagerContext::InitDeviceMapAndAsync() { diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 3c95ac590d..5ed6057ec6 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #endif +#include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -65,10 +66,17 @@ enum ContextDevicePlacementPolicy { class EagerContext { public: - explicit EagerContext(const SessionOptions& opts, - ContextDevicePlacementPolicy default_policy, bool async, - std::unique_ptr<DeviceMgr> device_mgr, - Rendezvous* rendezvous); + // TODO: remove this constructor once we migrate all callers to the next one. + EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, bool async, + std::unique_ptr<const DeviceMgr> device_mgr, + Rendezvous* rendezvous); + + EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, bool async, + const DeviceMgr* device_mgr, bool device_mgr_owned, + Rendezvous* rendezvous); + ~EagerContext(); // Returns the function library runtime for the given device. @@ -134,6 +142,7 @@ class EagerContext { void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel); bool LogDevicePlacement() { return log_device_placement_; } + bool LogMemory() { return log_memory_; } Rendezvous* GetRendezvous() { return rendezvous_; } @@ -207,8 +216,8 @@ class EagerContext { thread_local_policies_ GUARDED_BY(policy_map_mu_); // Only one of the below is set. 
- std::unique_ptr<DeviceMgr> local_device_manager_; - DeviceMgr* local_unowned_device_manager_; + std::unique_ptr<const DeviceMgr> local_device_manager_; + const DeviceMgr* local_unowned_device_manager_; std::unique_ptr<DeviceMgr> remote_device_manager_; // Devices owned by device_manager @@ -254,6 +263,8 @@ class EagerContext { std::unordered_map<std::thread::id, bool> thread_local_async_ GUARDED_BY(async_map_mu_); + const bool log_memory_; + Env* const env_; #ifndef __ANDROID__ diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 5b3a64ba98..1bc63616d0 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -251,26 +251,6 @@ Status EagerLocalExecute(EagerOperation* op, EagerContext* ctx = op->EagerContext(); auto status = ctx->GetStatus(); if (!status.ok()) return status; - // Ensure all resource-touching ops run in the device the resource is, - // regardless of anything else that has been specified. This is identical to - // the graph mode behavior. - for (int i = 0; i < op->Inputs().size(); ++i) { - Device* input_op_device = nullptr; - status = op->Inputs()[i]->OpDevice(&input_op_device); - if (!status.ok()) return status; - VLOG(2) << "for op " << op->Name() << " input " << i << " " - << DataTypeString(op->Inputs()[i]->dtype) << " " - << (input_op_device == nullptr ? "cpu" : input_op_device->name()) - << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name()); - if (op->Inputs()[i]->dtype == DT_RESOURCE && - (input_op_device != op->Device() || input_op_device == nullptr)) { - Device* d = input_op_device == nullptr ? 
ctx->HostCPU() : input_op_device; - VLOG(1) << "Changing device of operation " << op->Name() << " to " - << d->name() << " because input #" << i - << " is a resource in this device."; - op->SetDevice(d); - } - } Device* device = op->Device(); Fprint128 cache_key = op->MutableAttrs()->CacheKey( @@ -296,7 +276,7 @@ Status EagerLocalExecute(EagerOperation* op, LOG(INFO) << "Executing op " << ndef.op() << " in device " << device->name(); } - kernel = new KernelAndDevice(ctx->GetRendezvous()); + kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory()); auto* flr = ctx->func_lib(device); if (flr == nullptr) { @@ -604,6 +584,27 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, Status EagerExecute(EagerOperation* op, gtl::InlinedVector<TensorHandle*, 2>* retvals, int* num_retvals) { + // Ensure all resource-touching ops run in the device the resource is, + // regardless of anything else that has been specified. This is identical to + // the graph mode behavior. + EagerContext* ctx = op->EagerContext(); + for (int i = 0; i < op->Inputs().size(); ++i) { + Device* input_op_device = nullptr; + auto status = op->Inputs()[i]->OpDevice(&input_op_device); + if (!status.ok()) return status; + VLOG(2) << "for op " << op->Name() << " input " << i << " " + << DataTypeString(op->Inputs()[i]->dtype) << " " + << (input_op_device == nullptr ? "cpu" : input_op_device->name()) + << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name()); + if (op->Inputs()[i]->dtype == DT_RESOURCE && + (input_op_device != op->Device() || input_op_device == nullptr)) { + Device* d = input_op_device == nullptr ? 
ctx->HostCPU() : input_op_device; + VLOG(1) << "Changing device of operation " << op->Name() << " to " + << d->name() << " because input #" << i + << " is a resource in this device."; + op->SetDevice(d); + } + } bool op_is_local = IsLocal(op->EagerContext(), op->Device()); if (op_is_local) { diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 3d61ff4dc2..83d8425477 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -32,21 +32,6 @@ limitations under the License. namespace tensorflow { // static -Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef, - KernelAndDevice* out) { - OpKernel* k = nullptr; - Status s = CreateOpKernel(device->device_type().c_str(), device, - device->GetAllocator(AllocatorAttributes()), - nullptr, ndef, TF_GRAPH_DEF_VERSION, &k); - out->device_ = device; - out->kernel_.reset(k); - out->flib_ = nullptr; - out->runner_ = nullptr; - out->default_runner_ = [](std::function<void()> f) { f(); }; - return s; -} - -// static Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, std::function<void(std::function<void()>)>* runner, KernelAndDevice* out) { @@ -95,6 +80,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container, params.slice_reader_cache = &slice_reader_cache_; params.rendezvous = rendez_; params.cancellation_manager = &cm_; + params.log_memory = log_memory_; if (stats != nullptr) { params.track_allocations = true; } diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index 0ef419cbaa..04151a1171 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -52,12 +52,12 @@ class KernelAndDevice { static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, 
std::function<void(std::function<void()>)>* runner, KernelAndDevice* out); - // TODO(ashankar): Remove this - static Status InitOp(Device* device, const NodeDef& ndef, - KernelAndDevice* out); - KernelAndDevice(tensorflow::Rendezvous* rendez) - : device_(nullptr), flib_(nullptr), rendez_(rendez) {} + KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory) + : device_(nullptr), + flib_(nullptr), + rendez_(rendez), + log_memory_(log_memory) {} // TODO(ashankar): Handle list-valued inputs. Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs, @@ -87,6 +87,7 @@ class KernelAndDevice { DataTypeVector output_dtypes_; std::function<void(std::function<void()>)>* runner_; std::function<void(std::function<void()>)> default_runner_; + const bool log_memory_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc index 6abe98f53c..da280b2317 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc @@ -104,7 +104,7 @@ void BM_KernelAndDeviceInit(int iters) { .NumInputs(2) .BuildNodeDef()); TestEnv env; - KernelAndDevice k(nullptr); + KernelAndDevice k(nullptr, false); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(), @@ -127,7 +127,7 @@ void BM_KernelAndDeviceRun(int iters) { .NumInputs(inputs.size()) .BuildNodeDef()); TestEnv env; - KernelAndDevice kernel(nullptr); + KernelAndDevice kernel(nullptr, false); TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(), nullptr, &kernel)); tensorflow::testing::StartTiming(); diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index b912f7d37b..d58724cbfa 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ 
b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -125,7 +125,6 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) { Status TensorHandle::NumDims(int* num_dims) { if (IsRemote()) { TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); - CHECK(remote_shape_ != nullptr); *num_dims = remote_shape_->dims(); } else { TF_RETURN_IF_ERROR(WaitReady()); @@ -153,6 +152,21 @@ Status TensorHandle::Dim(int dim_index, int64* dim) { return Status::OK(); } +Status TensorHandle::NumElements(int64* num_elements) { + if (IsRemote()) { + TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); + *num_elements = remote_shape_->num_elements(); + } else { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + DCHECK(num_elements != nullptr); + + *num_elements = tensor_.NumElements(); + } + + return Status::OK(); +} + Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) { if (!IsRemote()) { return errors::FailedPrecondition( diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 1bc9c6531a..e55f1a0338 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -113,6 +113,7 @@ class TensorHandle : public core::RefCounted { Status NumDims(int* num_dims); Status Dim(int dim_index, int64* dim); + Status NumElements(int64* num_elements); // Return the op_id and output num if the handle refers to a remote tensor. 
Status RemoteAddress(int64* op_id, int32* output_num); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 84865397bc..2c48084cab 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -76,56 +76,47 @@ bool IsInitializationOp(const Node* node) { namespace nodestats { inline int64 NowInNsec() { return Env::Default()->NowNanos(); } -void SetScheduled(NodeExecStatsWrapper* stats, int64 micros) { +void SetScheduled(NodeExecStatsInterface* stats, int64 micros) { if (!stats) return; stats->SetScheduled(micros * EnvTime::kMicrosToNanos); } -void SetAllStart(NodeExecStatsWrapper* stats) { +void SetAllStart(NodeExecStatsInterface* stats) { if (!stats) return; stats->RecordExecutorStarted(); } -void SetOpStart(NodeExecStatsWrapper* stats) { +void SetOpStart(NodeExecStatsInterface* stats) { if (!stats) return; stats->RecordComputeStarted(); } -void SetOpEnd(NodeExecStatsWrapper* stats) { +void SetOpEnd(NodeExecStatsInterface* stats) { if (!stats) return; stats->RecordComputeEnded(); } -void SetAllEnd(NodeExecStatsWrapper* stats) { +void SetAllEnd(NodeExecStatsInterface* stats) { if (!stats) return; stats->RecordExecutorEnded(); } -void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) { +void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) { if (!stats) return; stats->SetOutput(slot, v); } -void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) { +void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) { if (!stats) return; stats->SetMemory(ctx); } -void SetReferencedTensors(NodeExecStatsWrapper* stats, +void SetReferencedTensors(NodeExecStatsInterface* stats, const TensorReferenceVector& tensors) { if (!stats) return; stats->SetReferencedTensors(tensors); } -// Sets the timeline_label field of *stats, using data from *node. -// Returns true iff the node is a transfer node. 
-bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) { - if (!stats) { - return false; - } - return stats->SetTimelineLabel(node); -} - } // namespace nodestats class ExecutorImpl; @@ -152,6 +143,8 @@ struct NodeItem { bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr bool is_merge : 1; // True iff IsMerge(node) bool is_enter : 1; // True iff IsEnter(node) + bool is_constant_enter : 1; // True iff IsEnter(node) and + // node->GetAttr("is_constant") == true. bool is_exit : 1; // True iff IsExit(node) bool is_control_trigger : 1; // True iff IsControlTrigger(node) bool is_sink : 1; // True iff IsSink(node) @@ -635,6 +628,14 @@ Status ExecutorImpl::Initialize() { item->kernel_is_async = (item->kernel->AsAsync() != nullptr); item->is_merge = IsMerge(n); item->is_enter = IsEnter(n); + if (item->is_enter) { + bool is_constant_enter; + TF_RETURN_IF_ERROR( + GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter)); + item->is_constant_enter = is_constant_enter; + } else { + item->is_constant_enter = false; + } item->is_exit = IsExit(n); item->is_control_trigger = IsControlTrigger(n); item->is_sink = IsSink(n); @@ -1237,6 +1238,9 @@ class ExecutorState { // Step-local container. ScopedStepContainer* step_container_; StepStatsCollectorInterface* const stats_collector_; + const tracing::TraceCollector* const trace_collector_; + const tracing::EventCollector* const event_collector_; + // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper // instead of a pointer? (avoids having to delete). checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_; @@ -1245,6 +1249,7 @@ class ExecutorState { CancellationManager* cancellation_manager_; Executor::Args::Runner runner_; bool sync_on_finish_; + const bool trace_using_annotations_; // Owned. @@ -1301,7 +1306,7 @@ class ExecutorState { // After item->kernel computation is done, processes its outputs. 
Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, - EntryVector* outputs, NodeExecStatsWrapper* stats); + EntryVector* outputs, NodeExecStatsInterface* stats); // After processing the outputs, propagates the outputs to their dsts. // Contents of *outputs are left in an indeterminate state after @@ -1312,7 +1317,7 @@ class ExecutorState { // "node" just finishes. Takes ownership of "stats". Returns true if // execution has completed. bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, - NodeExecStatsWrapper* stats, + NodeExecStatsInterface* stats, TaggedNodeReadyQueue* inline_ready); // Schedule all the expensive nodes in 'ready', and put all the inexpensive @@ -1359,12 +1364,16 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) tensor_store_(args.tensor_store), step_container_(args.step_container), stats_collector_(args.stats_collector), + trace_collector_(tracing::GetTraceCollector()), + event_collector_( + tracing::GetEventCollector(tracing::EventCategory::kCompute)), slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper), call_frame_(args.call_frame), impl_(impl), cancellation_manager_(args.cancellation_manager), runner_(args.runner), sync_on_finish_(args.sync_on_finish), + trace_using_annotations_(impl->params_.device->TraceUsingAnnotations()), num_outstanding_ops_(0) { // We start the entire execution in iteration 0 of the root frame // so let us create the root frame and the state for iteration 0. 
@@ -1513,7 +1522,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) { struct ExecutorState::AsyncState { AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node, const NodeItem* _item, Entry* _first_input, - NodeExecStatsWrapper* _stats) + NodeExecStatsInterface* _stats) : saved_inputs(*p.inputs), saved_input_device_contexts(*p.input_device_contexts), saved_input_alloc_attrs(*p.input_alloc_attrs), @@ -1538,7 +1547,7 @@ struct ExecutorState::AsyncState { const NodeItem* item; Entry* first_input; OpKernelContext ctx; - NodeExecStatsWrapper* stats; + NodeExecStatsInterface* stats; private: OpKernelContext::Params* ParamsButClearingEigenGPUDevice( @@ -1550,6 +1559,32 @@ struct ExecutorState::AsyncState { } }; +// Returns true if `item` might be traced by the given trace and event +// collectors. Returns false only if `item` definitely will not be traced. +bool MightTrace(const NodeItem& item, + const tracing::TraceCollector* trace_collector, + const tracing::EventCollector* event_collector, + bool using_annotations) { + // Tracing will only be enabled if either `event_collector` is non null, + // or `trace_collector` is non-null and enabled for this particular kernel. + // Although `tracing::ScopedActivity`, + // `tracing::ScopedAnnotation`, and `tracing::ScopedRegion` check subsets of + // these properties internally in their constructors, the cost of passing the + // necessary arguments to them can be significant, so we avoid constructing + // them in the common case (when we know they will not be used). 
+ if (event_collector != nullptr) { + return true; + } + if (trace_collector) { + if (using_annotations) { + return trace_collector->IsEnabledForAnnotations(); + } else { + return trace_collector->IsEnabledForActivities(item.kernel_is_expensive); + } + } + return false; +} + void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { const GraphView& gview = impl_->gview_; TaggedNodeSeq ready; @@ -1583,7 +1618,8 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { params.stats_collector = stats_collector_; Status s; - NodeExecStatsWrapper* stats = nullptr; + NodeExecStatsInterface* stats = nullptr; + EntryVector outputs; bool completed = false; inline_ready.push_back(tagged_node); @@ -1613,7 +1649,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { if (stats_collector_ && !tagged_node.is_dead) { // track allocations if and only if we are collecting statistics params.track_allocations = true; - stats = new NodeExecStatsWrapper(node->name()); + stats = stats_collector_->CreateNodeExecStats(node); nodestats::SetScheduled(stats, scheduled_nsec); nodestats::SetAllStart(stats); } @@ -1671,7 +1707,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { auto done = [this, state]() { Device* device = impl_->params_.device; - NodeExecStatsWrapper* stats = state->stats; // Shorthand + NodeExecStatsInterface* stats = state->stats; // Shorthand Entry* first_input = state->first_input; // Shorthand nodestats::SetOpEnd(stats); @@ -1720,7 +1756,32 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { // Synchronous computes. 
OpKernelContext ctx(&params, item.num_outputs); nodestats::SetOpStart(stats); - device->Compute(CHECK_NOTNULL(op_kernel), &ctx); + + if (TF_PREDICT_FALSE(MightTrace(item, trace_collector_, + event_collector_, + trace_using_annotations_))) { + const string& op_name = op_kernel->name(); + tracing::ScopedRegion region(tracing::EventCategory::kCompute, + op_name); + if (trace_using_annotations_) { + // The OpKernel may create child activities (such as GPU kernel + // launches), so use a `ScopedAnnotation` to relate these activities + // in the trace. + tracing::ScopedAnnotation activity(op_name, + op_kernel->type_string()); + device->Compute(op_kernel, &ctx); + } else { + // Use the cheaper `ScopedActivity` to trace just the OpKernel + // execution. + tracing::ScopedActivity activity(op_name, op_kernel->type_string(), + item.kernel_is_expensive); + device->Compute(op_kernel, &ctx); + } + } else { + // In the common case, avoid creating any tracing objects. + device->Compute(op_kernel, &ctx); + } + nodestats::SetOpEnd(stats); s = ProcessOutputs(item, &ctx, &outputs, stats); if (s.ok() && impl_->device_record_tensor_accesses_) { @@ -1862,7 +1923,7 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input, Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, EntryVector* outputs, - NodeExecStatsWrapper* stats) { + NodeExecStatsInterface* stats) { const Node* node = item.node; DCHECK_EQ(0, outputs->size()); outputs->resize(item.num_outputs); @@ -1997,15 +2058,12 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, is_frame_done = input_frame->DecrementOutstandingOpsLocked( &impl_->gview_, input_iter, ready); } else if (item->is_enter) { - bool is_constant; - const Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant); - DCHECK(s.ok()) << s; FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame); output_iter = 0; { const NodeItem* item = impl_->gview_.node(node->id()); mutex_lock 
l(output_frame->mu); - if (is_constant) { + if (item->is_constant_enter) { // Propagate to all active iterations if this is a loop invariant. output_frame->AddLoopInv(item, (*outputs)[0], ready); } else { @@ -2080,16 +2138,15 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, bool ExecutorState::NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, - NodeExecStatsWrapper* stats, + NodeExecStatsInterface* stats, TaggedNodeReadyQueue* inline_ready) { nodestats::SetAllEnd(stats); - if (stats_collector_ != nullptr && - !nodestats::SetTimelineLabel(node, stats)) { - // Only record non-transfer nodes. - // Transfers 'stats' ownership to 'stats_collector_'. - stats_collector_->Save(impl_->params_.device->name(), stats); - } else if (stats) { - delete stats; + if (stats) { + if (stats_collector_) { + stats->Done(impl_->params_.device->name()); + } else { + delete stats; + } } bool abort_run = false; @@ -2311,13 +2368,15 @@ void ExecutorState::Finish() { auto done_cb = std::move(done_cb_); auto runner = std::move(runner_); mu_.unlock(); - if (sync_on_finish_ && status.ok()) { + Device* device = impl_->params_.device; + if ((sync_on_finish_ && status.ok()) || device->RequiresSyncOnCompletion()) { // Block until the device has finished all queued operations. For // devices like GPUs that continue to execute Ops after their Compute // methods have completed, this ensures that control is not returned to // the user until the step (and its side-effects) has actually completed. 
- status = impl_->params_.device->Sync(); + status.Update(device->Sync()); } + delete this; CHECK(done_cb != nullptr); runner([=]() { done_cb(status); }); diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 46bb8d92f8..472865ca43 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -414,9 +414,8 @@ Status FunctionLibraryRuntimeImpl::CreateKernel( device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef, &fbody->fdef.signature(), this, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, graph_def_version_, &s); - *kernel = new CallOp(handle, &construction); - if (!s.ok()) { - delete *kernel; + if (s.ok()) { + *kernel = new CallOp(handle, &construction); } return s; } @@ -615,11 +614,14 @@ void PruneFunctionBody(Graph* g) { std::unordered_set<const Node*> nodes; for (auto n : g->nodes()) { // NOTE(mrry): "_Retval" nodes are stateful, and so will be added - // to the seed set of `nodes`. + // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we + // specifically exclude them as seeds, to avoid unconditionally executing + // unused argument nodes (e.g. in a function like `lambda x, y: y`). // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is // still needed. It would be preferable to prune entire loops and/or // conditionals if they are not used in the graph. - if (n->IsControlFlow() || n->op_def().is_stateful()) { + if (n->IsControlFlow() || + (n->op_def().is_stateful() && n->type_string() != kArgOp)) { nodes.insert(n); } } @@ -925,29 +927,18 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, } DCHECK(run_opts.runner != nullptr); - Executor::Args* exec_args = new Executor::Args; + Executor::Args exec_args; // Inherit the step_id from the caller. 
- exec_args->step_id = run_opts.step_id; - exec_args->rendezvous = run_opts.rendezvous; - exec_args->stats_collector = run_opts.stats_collector; - exec_args->cancellation_manager = run_opts.cancellation_manager; - exec_args->collective_executor = run_opts.collective_executor; - exec_args->step_container = run_opts.step_container; - exec_args->runner = *run_opts.runner; - exec_args->call_frame = frame; - - item->exec->RunAsync( - // Executor args - *exec_args, - // Done callback. - std::bind( - [item, frame, exec_args](DoneCallback done, - // Start unbound arguments. - const Status& status) { - delete exec_args; - done(status); - }, - std::move(done), std::placeholders::_1)); + exec_args.step_id = run_opts.step_id; + exec_args.rendezvous = run_opts.rendezvous; + exec_args.stats_collector = run_opts.stats_collector; + exec_args.cancellation_manager = run_opts.cancellation_manager; + exec_args.collective_executor = run_opts.collective_executor; + exec_args.step_container = run_opts.step_container; + exec_args.runner = *run_opts.runner; + exec_args.call_frame = frame; + + item->exec->RunAsync(exec_args, std::move(done)); } bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) { diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 120f480198..7bab9be9a6 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -802,9 +802,9 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { // Name "SquareAndAddOneWithStatefulNodes", // Args - {"x: int32"}, + {"x: int32", "y: float32"}, // Return values - {"y: int32"}, + {"z: int32"}, // Attrs {}, // Nodes @@ -822,12 +822,13 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { "RandomUniform", {"shape"}, {{"T", T}, {"dtype", DT_FLOAT}}}, - // y = Add<T>(a, o) - {{"y"}, "Add", {"a", "o"}, {{"T", T}}}}); + // z = Add<T>(a, o) + {{"z"}, "Add", {"a", "o"}, {{"T", T}}}}); Init({stateful_func}); auto x = 
test::AsTensor<int32>({1, 2, 3, 4}); - Tensor y; + auto y = test::AsTensor<float>({1.0, 2.0, 3.0, 4.0}); + Tensor z; FunctionLibraryRuntime::Handle handle; TF_CHECK_OK( @@ -837,18 +838,19 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { StepStatsCollector stats_collector(&stats); FunctionLibraryRuntime::Options opts; opts.stats_collector = &stats_collector; - TF_CHECK_OK(Run(flr0_, handle, opts, {x}, {&y})); + TF_CHECK_OK(Run(flr0_, handle, opts, {x, y}, {&z})); TF_CHECK_OK(flr0_->ReleaseHandle(handle)); TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {}, - {x}, {&y})); - test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({2, 5, 10, 17})); + {x, y}, {&z})); + test::ExpectTensorEqual<int>(z, test::AsTensor<int32>({2, 5, 10, 17})); stats_collector.FinalizeAndSwap(&stats); - // Note that we do not expect the nodes named "x1", "x2", or "x3" to execute. + // Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to + // execute. std::set<string> expected_node_names( - {"_SOURCE", "shape", "x", "o", "a", "keep_me", "y", "y_RetVal"}); + {"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"}); std::set<string> executed_node_names; for (const auto& node_stats : stats.dev_stats()[0].node_stats()) { executed_node_names.insert(node_stats.node_name()); diff --git a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h index 636cd43575..6bd29ef775 100644 --- a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h +++ b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h @@ -26,8 +26,12 @@ namespace tensorflow { class CUDAHostAllocator : public SubAllocator { public: // Note: stream_exec cannot be null. 
- explicit CUDAHostAllocator(se::StreamExecutor* stream_exec) - : stream_exec_(stream_exec) { + explicit CUDAHostAllocator(se::StreamExecutor* stream_exec, int numa_node, + const std::vector<Visitor>& alloc_visitors, + const std::vector<Visitor>& free_visitors) + : SubAllocator(alloc_visitors, free_visitors), + stream_exec_(stream_exec), + numa_node_(numa_node) { CHECK(stream_exec_ != nullptr); } ~CUDAHostAllocator() override {} @@ -39,19 +43,23 @@ class CUDAHostAllocator : public SubAllocator { if (ptr == nullptr) { LOG(WARNING) << "could not allocate pinned host memory of size: " << num_bytes; + return ptr; } + VisitAlloc(ptr, numa_node_, num_bytes); } return ptr; } void Free(void* ptr, size_t num_bytes) override { if (ptr != nullptr) { + VisitFree(ptr, numa_node_, num_bytes); stream_exec_->HostMemoryDeallocate(ptr); } } private: se::StreamExecutor* stream_exec_; // not owned, non-null + const int numa_node_; TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator); }; diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc index 2d4c8d0201..42021e51f3 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc @@ -22,18 +22,48 @@ limitations under the License. namespace tensorflow { -GPUBFCAllocator::GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory, - const string& name) - : GPUBFCAllocator(cuda_gpu_id, total_memory, GPUOptions(), name) {} +bool GPUBFCAllocator::GetAllowGrowthValue(const GPUOptions& gpu_options) { + const char* force_allow_growth_string = + std::getenv("TF_FORCE_GPU_ALLOW_GROWTH"); + if (force_allow_growth_string == nullptr) { + return gpu_options.allow_growth(); + } + + if (strcmp("false", force_allow_growth_string) == 0) { + if (gpu_options.allow_growth()) { + LOG(WARNING) + << "Overriding allow_growth setting because the" + << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. 
Original" + << " config value was " << gpu_options.allow_growth() << "."; + } + return false; + } else if (strcmp("true", force_allow_growth_string) == 0) { + if (!gpu_options.allow_growth()) { + LOG(WARNING) + << "Overriding allow_growth setting because the" + << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original" + << " config value was " << gpu_options.allow_growth() << "."; + } + return true; + } + + LOG(ERROR) + << "The TF_FORCE_GPU_ALLOW_GROWTH environment variable is set but could" + << " not be parsed: \"" << force_allow_growth_string << "\". Valid" + << " values are \"true\" or \"false\". Using original config value" + << " of " << gpu_options.allow_growth() << "."; + return gpu_options.allow_growth(); +} + +GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator, + size_t total_memory, const string& name) + : GPUBFCAllocator(sub_allocator, total_memory, GPUOptions(), name) {} -GPUBFCAllocator::GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory, +GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator, + size_t total_memory, const GPUOptions& gpu_options, const string& name) - : BFCAllocator( - new GPUMemAllocator( - GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(), - gpu_options.per_process_gpu_memory_fraction() > 1.0 || - gpu_options.experimental().use_unified_memory()), - total_memory, gpu_options.allow_growth(), name) {} + : BFCAllocator(sub_allocator, total_memory, + GPUBFCAllocator::GetAllowGrowthValue(gpu_options), name) {} } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h index f1cc2eace1..d4c9cee89a 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h @@ -31,28 +31,20 @@ limitations under the License. namespace tensorflow { -// A GPU memory allocator that implements a 'best-fit with coalescing' -// algorithm. 
-class GPUBFCAllocator : public BFCAllocator { - public: - // 'cuda_gpu_id' refers to the ID of the GPU device within - // the process and must reference a valid ID in the process. - GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory, - const string& name); - GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory, - const GPUOptions& gpu_options, const string& name); - virtual ~GPUBFCAllocator() {} - - TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator); -}; - // Suballocator for GPU memory. class GPUMemAllocator : public SubAllocator { public: + // 'platform_gpu_id' refers to the ID of the GPU device within + // the process and must reference a valid ID in the process. // Note: stream_exec cannot be null. explicit GPUMemAllocator(se::StreamExecutor* stream_exec, - bool use_unified_memory) - : stream_exec_(stream_exec), use_unified_memory_(use_unified_memory) { + PlatformGpuId gpu_id, bool use_unified_memory, + const std::vector<Visitor>& alloc_visitors, + const std::vector<Visitor>& free_visitors) + : SubAllocator(alloc_visitors, free_visitors), + stream_exec_(stream_exec), + gpu_id_(gpu_id), + use_unified_memory_(use_unified_memory) { CHECK(stream_exec_ != nullptr); } ~GPUMemAllocator() override {} @@ -65,12 +57,14 @@ class GPUMemAllocator : public SubAllocator { } else { ptr = stream_exec_->AllocateArray<char>(num_bytes).opaque(); } + VisitAlloc(ptr, gpu_id_.value(), num_bytes); } return ptr; } void Free(void* ptr, size_t num_bytes) override { if (ptr != nullptr) { + VisitFree(ptr, gpu_id_.value(), num_bytes); if (use_unified_memory_) { stream_exec_->UnifiedMemoryDeallocate(ptr); } else { @@ -82,11 +76,28 @@ class GPUMemAllocator : public SubAllocator { private: se::StreamExecutor* stream_exec_; // not owned, non-null + const PlatformGpuId gpu_id_; const bool use_unified_memory_ = false; TF_DISALLOW_COPY_AND_ASSIGN(GPUMemAllocator); }; +// A GPU memory allocator that implements a 'best-fit with coalescing' +// algorithm. 
+class GPUBFCAllocator : public BFCAllocator { + public: + GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory, + const string& name); + GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory, + const GPUOptions& gpu_options, const string& name); + ~GPUBFCAllocator() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator); + + private: + static bool GetAllowGrowthValue(const GPUOptions& gpu_options); +}; + } // namespace tensorflow #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc index 67caeb3495..60e82ed13b 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc @@ -21,6 +21,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h" #include "tensorflow/core/common_runtime/gpu/gpu_init.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -46,7 +47,11 @@ static void CheckStats(Allocator* a, int64 num_allocs, int64 bytes_in_use, } TEST(GPUBFCAllocatorTest, NoDups) { - GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); CheckStats(&a, 0, 0, 0, 0); // Allocate a lot of raw pointers @@ -75,7 +80,11 @@ TEST(GPUBFCAllocatorTest, NoDups) { } TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) { - GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + 
GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); // Allocate 256 raw pointers of sizes between 100 bytes and about // a meg random::PhiloxRandom philox(123, 17); @@ -133,7 +142,11 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) { } TEST(GPUBFCAllocatorTest, ExerciseCoalescing) { - GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); CheckStats(&a, 0, 0, 0, 0); float* first_ptr = a.Allocate<float>(1024); @@ -168,18 +181,30 @@ TEST(GPUBFCAllocatorTest, ExerciseCoalescing) { } TEST(GPUBFCAllocatorTest, AllocateZeroBufSize) { - GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); float* ptr = a.Allocate<float>(0); EXPECT_EQ(nullptr, ptr); } TEST(GPUBFCAllocatorTest, TracksSizes) { - GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); EXPECT_EQ(true, a.TracksAllocationSizes()); } TEST(GPUBFCAllocatorTest, AllocatedVsRequested) { - GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + 
GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); float* t1 = a.Allocate<float>(1); EXPECT_EQ(4, a.RequestedSize(t1)); EXPECT_EQ(256, a.AllocatedSize(t1)); @@ -187,8 +212,12 @@ TEST(GPUBFCAllocatorTest, AllocatedVsRequested) { } TEST(GPUBFCAllocatorTest, TestCustomMemoryLimit) { + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); // Configure a 1MiB byte limit - GPUBFCAllocator a(CudaGpuId(0), 1 << 20, "GPU_0_bfc"); + GPUBFCAllocator a(sub_allocator, 1 << 20, "GPU_0_bfc"); float* first_ptr = a.Allocate<float>(1 << 6); float* second_ptr = a.Allocate<float>(1 << 20); @@ -203,7 +232,11 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocationsWithGrowth) { options.set_allow_growth(true); // Max of 2GiB, but starts out small. - GPUBFCAllocator a(CudaGpuId(0), 1LL << 31, options, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1LL << 31, "GPU_0_bfc"); // Allocate 10 raw pointers of sizes between 100 bytes and about // 64 megs. 
@@ -264,8 +297,15 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocationsWithGrowth) { } TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) { - GPUBFCAllocator a(CudaGpuId(0), 1UL << 60, "GPU_0_bfc"); - GPUBFCAllocator b(CudaGpuId(0), 1UL << 60, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1UL << 60, "GPU_0_bfc"); + sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator b(sub_allocator, 1UL << 60, "GPU_0_bfc"); void* amem = a.AllocateRaw(1, 1); void* bmem = b.AllocateRaw(1, 1 << 30); a.DeallocateRaw(amem); @@ -273,7 +313,11 @@ TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) { } static void BM_Allocation(int iters) { - GPUBFCAllocator a(CudaGpuId(0), 1uLL << 33, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc"); // Exercise a few different allocation sizes std::vector<size_t> sizes = {256, 4096, 16384, 524288, 512, 1048576, 10485760, 104857600, @@ -289,7 +333,11 @@ static void BM_Allocation(int iters) { BENCHMARK(BM_Allocation); static void BM_AllocationThreaded(int iters, int num_threads) { - GPUBFCAllocator a(CudaGpuId(0), 1uLL << 33, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc"); thread::ThreadPool 
pool(Env::Default(), "test", num_threads); std::atomic_int_fast32_t count(iters); mutex done_lock; @@ -325,7 +373,11 @@ BENCHMARK(BM_AllocationThreaded)->Arg(1)->Arg(4)->Arg(16); // A more complex benchmark that defers deallocation of an object for // "delay" allocations. static void BM_AllocationDelayed(int iters, int delay) { - GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); // Exercise a few different allocation sizes std::vector<int> sizes = {256, 4096, 16384, 4096, 512, 1024, 1024}; int size_index = 0; @@ -358,12 +410,18 @@ BENCHMARK(BM_AllocationDelayed)->Arg(1)->Arg(10)->Arg(100)->Arg(1000); class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test { protected: + void SetUp() override { CHECK_EQ(unsetenv("TF_FORCE_GPU_ALLOW_GROWTH"), 0); } + // The following test methods are called from tests. The reason for this is // that this class is a friend class to BFCAllocator, but tests are not, so // only methods inside this class can access private members of BFCAllocator. 
void TestBinDebugInfo() { - GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); std::vector<void*> initial_ptrs; std::vector<size_t> initial_ptrs_allocated_sizes; @@ -441,7 +499,11 @@ class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test { } void TestLog2FloorNonZeroSlow() { - GPUBFCAllocator a(CudaGpuId(0), 1 /* total_memory */, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 /* total_memory */, "GPU_0_bfc"); EXPECT_EQ(-1, a.Log2FloorNonZeroSlow(0)); EXPECT_EQ(0, a.Log2FloorNonZeroSlow(1)); EXPECT_EQ(1, a.Log2FloorNonZeroSlow(2)); @@ -450,6 +512,56 @@ class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test { EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1024)); EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1025)); } + + void TestForceAllowGrowth() { + PlatformGpuId platform_gpu_id(0); + GPUOptions options; + // Unset flag value uses provided option. + unsetenv("TF_FORCE_GPU_ALLOW_GROWTH"); + options.set_allow_growth(true); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator unset_flag_allocator(sub_allocator, 1LL << 31, options, + "GPU_0_bfc"); + EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}), + unset_flag_allocator.curr_region_allocation_bytes_); + + // Unparseable flag value uses provided option. 
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "unparseable", 1); + options.set_allow_growth(true); + sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator unparsable_flag_allocator(sub_allocator, 1LL << 31, options, + "GPU_1_bfc"); + EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}), + unparsable_flag_allocator.curr_region_allocation_bytes_); + + // Max of 2GiB total memory. Env variable set forces allow_growth, which + // does an initial allocation of 1MiB. + setenv("TF_FORCE_GPU_ALLOW_GROWTH", "true", 1); + options.set_allow_growth(false); + sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator force_allow_growth_allocator(sub_allocator, 1LL << 31, + options, "GPU_2_bfc"); + EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}), + force_allow_growth_allocator.curr_region_allocation_bytes_); + + // If env variable forces allow_growth disabled, all available memory is + // allocated. 
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "false", 1); + options.set_allow_growth(true); + sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator force_no_allow_growth_allocator(sub_allocator, 1LL << 31, + options, "GPU_3_bfc"); + EXPECT_EQ(GPUBFCAllocator::RoundedBytes(1LL << 31), + force_no_allow_growth_allocator.curr_region_allocation_bytes_); + } }; TEST_F(GPUBFCAllocatorPrivateMethodsTest, BinDebugInfo) { TestBinDebugInfo(); } @@ -458,6 +570,10 @@ TEST_F(GPUBFCAllocatorPrivateMethodsTest, Log2FloorNonZeroSlow) { TestLog2FloorNonZeroSlow(); } +TEST_F(GPUBFCAllocatorPrivateMethodsTest, ForceAllowGrowth) { + TestForceAllowGrowth(); +} + } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc index 934a57a5fb..d85ca8892f 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc @@ -27,10 +27,11 @@ limitations under the License. 
namespace tensorflow { -GPUcudaMallocAllocator::GPUcudaMallocAllocator(VisitableAllocator* allocator, - CudaGpuId cuda_gpu_id) +GPUcudaMallocAllocator::GPUcudaMallocAllocator(Allocator* allocator, + PlatformGpuId platform_gpu_id) : base_allocator_(allocator) { - stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); + stream_exec_ = + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); } GPUcudaMallocAllocator::~GPUcudaMallocAllocator() { delete base_allocator_; } @@ -60,14 +61,6 @@ void GPUcudaMallocAllocator::DeallocateRaw(void* ptr) { #endif // GOOGLE_CUDA } -void GPUcudaMallocAllocator::AddAllocVisitor(Visitor visitor) { - return base_allocator_->AddAllocVisitor(visitor); -} - -void GPUcudaMallocAllocator::AddFreeVisitor(Visitor visitor) { - return base_allocator_->AddFreeVisitor(visitor); -} - bool GPUcudaMallocAllocator::TracksAllocationSizes() { return false; } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h index 856fdc34b4..8df3724bc4 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h @@ -19,7 +19,7 @@ limitations under the License. #include <memory> #include "tensorflow/core/common_runtime/gpu/gpu_id.h" -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" @@ -29,20 +29,18 @@ namespace tensorflow { // An allocator that wraps a GPU allocator and adds debugging // functionality that verifies that users do not write outside their // allocated memory. 
-class GPUcudaMallocAllocator : public VisitableAllocator { +class GPUcudaMallocAllocator : public Allocator { public: - explicit GPUcudaMallocAllocator(VisitableAllocator* allocator, - CudaGpuId cuda_gpu_id); + explicit GPUcudaMallocAllocator(Allocator* allocator, + PlatformGpuId platform_gpu_id); ~GPUcudaMallocAllocator() override; string Name() override { return "gpu_debug"; } void* AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void* ptr) override; - void AddAllocVisitor(Visitor visitor) override; - void AddFreeVisitor(Visitor visitor) override; bool TracksAllocationSizes() override; private: - VisitableAllocator* base_allocator_ = nullptr; // owned + Allocator* base_allocator_ = nullptr; // owned se::StreamExecutor* stream_exec_; // Not owned. diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc index e4c834b30d..989ddbe4af 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc @@ -73,10 +73,11 @@ void InitMask(se::StreamExecutor* exec, void* ptr, int64* mask) { // ----------------------------------------------------------------------------- // GPUDebugAllocator // ----------------------------------------------------------------------------- -GPUDebugAllocator::GPUDebugAllocator(VisitableAllocator* allocator, - CudaGpuId cuda_gpu_id) +GPUDebugAllocator::GPUDebugAllocator(Allocator* allocator, + PlatformGpuId platform_gpu_id) : base_allocator_(allocator) { - stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); + stream_exec_ = + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); } GPUDebugAllocator::~GPUDebugAllocator() { delete base_allocator_; } @@ -111,14 +112,6 @@ void GPUDebugAllocator::DeallocateRaw(void* ptr) { base_allocator_->DeallocateRaw(ptr); } -void GPUDebugAllocator::AddAllocVisitor(Visitor visitor) { - return 
base_allocator_->AddAllocVisitor(visitor); -} - -void GPUDebugAllocator::AddFreeVisitor(Visitor visitor) { - return base_allocator_->AddFreeVisitor(visitor); -} - bool GPUDebugAllocator::TracksAllocationSizes() { return true; } size_t GPUDebugAllocator::RequestedSize(const void* ptr) { @@ -158,10 +151,11 @@ bool GPUDebugAllocator::CheckFooter(void* ptr) { // ----------------------------------------------------------------------------- // GPUNanResetAllocator // ----------------------------------------------------------------------------- -GPUNanResetAllocator::GPUNanResetAllocator(VisitableAllocator* allocator, - CudaGpuId cuda_gpu_id) +GPUNanResetAllocator::GPUNanResetAllocator(Allocator* allocator, + PlatformGpuId platform_gpu_id) : base_allocator_(allocator) { - stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); + stream_exec_ = + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); } GPUNanResetAllocator::~GPUNanResetAllocator() { delete base_allocator_; } @@ -200,14 +194,6 @@ void GPUNanResetAllocator::DeallocateRaw(void* ptr) { base_allocator_->DeallocateRaw(ptr); } -void GPUNanResetAllocator::AddAllocVisitor(Visitor visitor) { - return base_allocator_->AddAllocVisitor(visitor); -} - -void GPUNanResetAllocator::AddFreeVisitor(Visitor visitor) { - return base_allocator_->AddFreeVisitor(visitor); -} - size_t GPUNanResetAllocator::RequestedSize(const void* ptr) { return base_allocator_->RequestedSize(ptr); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h index 0f9b72040c..17757a106c 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h @@ -21,7 +21,7 @@ limitations under the License. 
#include <unordered_map> #include "tensorflow/core/common_runtime/gpu/gpu_id.h" -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" @@ -31,16 +31,14 @@ namespace tensorflow { // An allocator that wraps a GPU allocator and adds debugging // functionality that verifies that users do not write outside their // allocated memory. -class GPUDebugAllocator : public VisitableAllocator { +class GPUDebugAllocator : public Allocator { public: - explicit GPUDebugAllocator(VisitableAllocator* allocator, - CudaGpuId cuda_gpu_id); + explicit GPUDebugAllocator(Allocator* allocator, + PlatformGpuId platform_gpu_id); ~GPUDebugAllocator() override; string Name() override { return "gpu_debug"; } void* AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void* ptr) override; - void AddAllocVisitor(Visitor visitor) override; - void AddFreeVisitor(Visitor visitor) override; bool TracksAllocationSizes() override; size_t RequestedSize(const void* ptr) override; size_t AllocatedSize(const void* ptr) override; @@ -53,7 +51,7 @@ class GPUDebugAllocator : public VisitableAllocator { bool CheckFooter(void* ptr); private: - VisitableAllocator* base_allocator_ = nullptr; // owned + Allocator* base_allocator_ = nullptr; // owned se::StreamExecutor* stream_exec_; // Not owned. @@ -63,23 +61,21 @@ class GPUDebugAllocator : public VisitableAllocator { // An allocator that wraps a GPU allocator and resets the memory on // allocation and free to 'NaN', helping to identify cases where the // user forgets to initialize the memory. 
-class GPUNanResetAllocator : public VisitableAllocator { +class GPUNanResetAllocator : public Allocator { public: - explicit GPUNanResetAllocator(VisitableAllocator* allocator, - CudaGpuId cuda_gpu_id); + explicit GPUNanResetAllocator(Allocator* allocator, + PlatformGpuId platform_gpu_id); ~GPUNanResetAllocator() override; string Name() override { return "gpu_nan_reset"; } void* AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void* ptr) override; - void AddAllocVisitor(Visitor visitor) override; - void AddFreeVisitor(Visitor visitor) override; size_t RequestedSize(const void* ptr) override; size_t AllocatedSize(const void* ptr) override; void GetStats(AllocatorStats* stats) override; void ClearStats() override; private: - VisitableAllocator* base_allocator_ = nullptr; // owned + Allocator* base_allocator_ = nullptr; // owned se::StreamExecutor* stream_exec_; // Not owned. diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc index 236a0afa0b..aca08a7e33 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc @@ -34,10 +34,14 @@ namespace tensorflow { namespace { TEST(GPUDebugAllocatorTest, OverwriteDetection_None) { - const CudaGpuId cuda_gpu_id(0); - GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""), - cuda_gpu_id); - auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); + const PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), + platform_gpu_id); + auto stream_exec = + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); for (int s : {8}) { std::vector<int64> cpu_array(s); 
@@ -58,11 +62,14 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Header) { for (int s : {8, 211}) { EXPECT_DEATH( { - const CudaGpuId cuda_gpu_id(0); - GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""), - cuda_gpu_id); + const PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), + platform_gpu_id); auto stream_exec = - GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); std::vector<int64> cpu_array(s); memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64)); @@ -91,11 +98,14 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) { for (int s : {8, 22}) { EXPECT_DEATH( { - const CudaGpuId cuda_gpu_id(0); - GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""), - cuda_gpu_id); + const PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), + platform_gpu_id); auto stream_exec = - GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); std::vector<int64> cpu_array(s); memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64)); @@ -121,10 +131,14 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) { } TEST(GPUDebugAllocatorTest, ResetToNan) { - const CudaGpuId cuda_gpu_id(0); - GPUNanResetAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""), - cuda_gpu_id); - auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); + const PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new 
GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUNanResetAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), + platform_gpu_id); + auto stream_exec = + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); std::vector<float> cpu_array(1024); std::vector<float> cpu_array_result(1024); @@ -161,13 +175,17 @@ TEST(GPUDebugAllocatorTest, ResetToNan) { } TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) { - const CudaGpuId cuda_gpu_id(0); + const PlatformGpuId platform_gpu_id(0); // NaN reset must be the outer-most allocator. + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); GPUNanResetAllocator a( - new GPUDebugAllocator(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""), - cuda_gpu_id), - cuda_gpu_id); - auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); + new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), + platform_gpu_id), + platform_gpu_id); + auto stream_exec = + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); std::vector<float> cpu_array(1024); std::vector<float> cpu_array_result(1024); @@ -204,18 +222,24 @@ TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) { } TEST(GPUDebugAllocatorTest, TracksSizes) { - const CudaGpuId cuda_gpu_id(0); - GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""), - cuda_gpu_id); + const PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), + platform_gpu_id); EXPECT_EQ(true, a.TracksAllocationSizes()); } TEST(GPUDebugAllocatorTest, AllocatedVsRequested) { - const 
CudaGpuId cuda_gpu_id(0); + const PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); GPUNanResetAllocator a( - new GPUDebugAllocator(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""), - cuda_gpu_id), - cuda_gpu_id); + new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), + platform_gpu_id), + platform_gpu_id); float* t1 = a.Allocate<float>(1); EXPECT_EQ(4, a.RequestedSize(t1)); EXPECT_EQ(256, a.AllocatedSize(t1)); diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 2763ac0d4a..d8ebdeff5d 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -41,7 +41,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/gpu_device_context.h" #include "tensorflow/core/common_runtime/local_device.h" -#include "tensorflow/core/common_runtime/visitable_allocator.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" @@ -105,9 +104,9 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface { reinterpret_cast<unsigned int*>(scratch + Eigen::kCudaScratchSize); stream_ = cuda_stream; allocator_ = alloc; - CudaGpuId cuda_gpu_id; - TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id)); - device_prop_ = &Eigen::m_deviceProperties[cuda_gpu_id.value()]; + PlatformGpuId platform_gpu_id; + TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id)); + device_prop_ = &Eigen::m_deviceProperties[platform_gpu_id.value()]; } const cudaStream_t& stream() const override { return *stream_; } @@ -285,6 +284,38 @@ BaseGPUDevice::~BaseGPUDevice() { for (auto ctx : device_contexts_) 
ctx->Unref(); } +// This should be idempotent if already initialized. +Status BaseGPUDevice::InitScratchBuffers() { + mutex_lock l(scratch_init_mutex_); + if (scratch_.size() < max_streams_) { + for (int i = 0; i < max_streams_; i++) { + DCHECK(streams_[i]); + if (scratch_.size() > i && scratch_[i]) continue; + size_t scratch_buffer_size = + Eigen::kCudaScratchSize + sizeof(unsigned int); + void* scratch_buffer = gpu_allocator_->AllocateRaw( + Allocator::kAllocatorAlignment, scratch_buffer_size); + if (scratch_buffer == nullptr) { + return errors::FailedPrecondition( + "Failed to allocate scratch buffer for device ", + tf_gpu_id_.value()); + } + se::DeviceMemory<char> mem( + se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size)); + + bool ok = executor_->SynchronousMemZero( + &mem, Eigen::kCudaScratchSize + sizeof(unsigned int)); + if (!ok) { + return errors::FailedPrecondition( + "Failed to memcopy into scratch buffer for device ", + tf_gpu_id_.value()); + } + scratch_.push_back(static_cast<char*>(scratch_buffer)); + } + } + return Status::OK(); +} + Status BaseGPUDevice::Init(const SessionOptions& options) { auto executor_status = GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id_); if (!executor_status.status().ok()) { @@ -303,27 +334,6 @@ Status BaseGPUDevice::Init(const SessionOptions& options) { for (int i = 0; i < max_streams_; i++) { streams_.push_back(StreamGroupFactory::Global().GetOrCreate( tf_gpu_id_, i, executor_, options.config.gpu_options())); - - size_t scratch_buffer_size = Eigen::kCudaScratchSize + sizeof(unsigned int); - void* scratch_buffer = gpu_allocator_->AllocateRaw( - Allocator::kAllocatorAlignment, scratch_buffer_size); - if (scratch_buffer == nullptr) { - return errors::FailedPrecondition( - "Failed to allocate scratch buffer for device ", tf_gpu_id_.value()); - } - scratch_.push_back(static_cast<char*>(scratch_buffer)); - - se::DeviceMemory<char> mem( - se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size)); - - bool ok = 
executor_->SynchronousMemZero( - &mem, Eigen::kCudaScratchSize + sizeof(unsigned int)); - if (!ok) { - return errors::FailedPrecondition( - "Failed to memcopy into scratch buffer for device ", - tf_gpu_id_.value()); - } - device_contexts_.push_back(new GPUDeviceContext( i, streams_.back()->compute, streams_.back()->host_to_device, streams_.back()->device_to_host, streams_.back()->device_to_device)); @@ -332,9 +342,10 @@ Status BaseGPUDevice::Init(const SessionOptions& options) { gpu_device_info_->stream = streams_[0]->compute; gpu_device_info_->default_context = device_contexts_[0]; gpu_device_info_->event_mgr = em_.get(); - CudaGpuId cuda_gpu_id; - TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id)); - gpu_device_info_->gpu_id = cuda_gpu_id.value(); + PlatformGpuId platform_gpu_id; + TF_RETURN_IF_ERROR( + GpuIdManager::TfToPlatformGpuId(tf_gpu_id_, &platform_gpu_id)); + gpu_device_info_->gpu_id = platform_gpu_id.value(); set_tensorflow_gpu_device_info(gpu_device_info_); // Whether and how the GPU device uses its own threadpool. @@ -423,9 +434,6 @@ Status BaseGPUDevice::FillContextMap(const Graph* graph, } void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { - tracing::ScopedRegion region(tracing::EventCategory::kCompute, - op_kernel->name()); - // NOTE(tucker): We need to discriminate between Eigen GPU // operations and all others. If an operation is Eigen // implemented (or otherwise tries to launch a cuda kernel @@ -439,8 +447,6 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { context->SetStatus(errors::Internal( "Invalid synchronous 'Compute' on GPU for '_Recv' op")); } else { - tracing::ScopedAnnotation annotation(op_kernel->name(), - op_kernel->type_string()); ComputeHelper(op_kernel, context); } } @@ -690,9 +696,9 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice { Eigen::GpuDevice device_; }; -// Parse 'visible_device_list' into a list of CUDA GPU ids. 
+// Parse 'visible_device_list' into a list of platform GPU ids. Status ParseVisibleDeviceList(const string& visible_device_list, - std::vector<CudaGpuId>* visible_gpu_order) { + std::vector<PlatformGpuId>* visible_gpu_order) { visible_gpu_order->clear(); se::Platform* gpu_manager = GPUMachineManager(); @@ -707,26 +713,28 @@ Status ParseVisibleDeviceList(const string& visible_device_list, } else { const std::vector<string> order_str = str_util::Split(visible_device_list, ','); - for (const string& cuda_gpu_id_str : order_str) { - int32 cuda_gpu_id; - if (!strings::safe_strto32(cuda_gpu_id_str, &cuda_gpu_id)) { + for (const string& platform_gpu_id_str : order_str) { + int32 platform_gpu_id; + if (!strings::safe_strto32(platform_gpu_id_str, &platform_gpu_id)) { return errors::InvalidArgument( "Could not parse entry in 'visible_device_list': '", - cuda_gpu_id_str, "'. visible_device_list = ", visible_device_list); + platform_gpu_id_str, "'. visible_device_list = ", + visible_device_list); } - if (cuda_gpu_id < 0 || cuda_gpu_id >= gpu_manager->VisibleDeviceCount()) { + if (platform_gpu_id < 0 || + platform_gpu_id >= gpu_manager->VisibleDeviceCount()) { return errors::InvalidArgument( - "'visible_device_list' listed an invalid GPU id '", cuda_gpu_id, + "'visible_device_list' listed an invalid GPU id '", platform_gpu_id, "' but visible device count is ", gpu_manager->VisibleDeviceCount()); } - visible_gpu_order->push_back(CudaGpuId(cuda_gpu_id)); + visible_gpu_order->push_back(PlatformGpuId(platform_gpu_id)); } } // Validate no repeats. 
- std::set<CudaGpuId> visible_device_set(visible_gpu_order->begin(), - visible_gpu_order->end()); + std::set<PlatformGpuId> visible_device_set(visible_gpu_order->begin(), + visible_gpu_order->end()); if (visible_device_set.size() != visible_gpu_order->size()) { return errors::InvalidArgument( "visible_device_list contained a duplicate entry: ", @@ -737,8 +745,8 @@ Status ParseVisibleDeviceList(const string& visible_device_list, Status VerifyVirtualDeviceSettings( const size_t num_gpus_to_use, const GPUOptions& gpu_options, - const std::vector<CudaGpuId>& visible_gpu_order, - const std::vector<CudaGpuId>& valid_cuda_gpu_ids) { + const std::vector<PlatformGpuId>& visible_gpu_order, + const std::vector<PlatformGpuId>& valid_platform_gpu_ids) { const auto& virtual_devices = gpu_options.experimental().virtual_devices(); CHECK(!virtual_devices.empty()); if (gpu_options.per_process_gpu_memory_fraction() > 0) { @@ -760,11 +768,11 @@ Status VerifyVirtualDeviceSettings( " #GPUs in visible_device_list: ", visible_gpu_order.size(), " virtual_devices.size(): ", virtual_devices.size()); } - if (valid_cuda_gpu_ids.size() != virtual_devices.size()) { + if (valid_platform_gpu_ids.size() != virtual_devices.size()) { return errors::Unknown( "The number of valid GPUs doesn't match the number of elements in " "the virtual_devices list.", - " #valid GPUs: ", valid_cuda_gpu_ids.size(), + " #valid GPUs: ", valid_platform_gpu_ids.size(), " virtual_devices.size(): ", virtual_devices.size()); } return Status::OK(); @@ -806,18 +814,18 @@ int64 MinSystemMemory(int64 available_memory) { } // Get the memory limit for the virtual device being created on GPU with -// 'cuda_gpu_id', when that virtual device is the only virtual device being +// 'platform_gpu_id', when that virtual device is the only virtual device being // created on that GPU. 
Status SingleVirtualDeviceMemoryLimit(const GPUOptions& gpu_options, - CudaGpuId cuda_gpu_id, + PlatformGpuId platform_gpu_id, int64* memory_limit) { int64 total_memory = 0; int64 available_memory = 0; se::StreamExecutor* se = - GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); if (!se->DeviceMemoryUsage(&available_memory, &total_memory)) { return errors::Unknown("Failed to query available memory for GPU ", - cuda_gpu_id.value()); + platform_gpu_id.value()); } int64 allocated_memory = 0; @@ -867,10 +875,11 @@ PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice() { return new ConcretePerOpGpuDevice(); } -void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context, - PerOpGpuDevice* device, - DeviceContext* dc, - Allocator* allocator) { +Status BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context, + PerOpGpuDevice* device, + DeviceContext* dc, + Allocator* allocator) { + TF_RETURN_IF_ERROR(InitScratchBuffers()); if (dc) { const GPUDeviceContext* gpu_dc = static_cast<GPUDeviceContext*>(dc); const int stream_id = gpu_dc->stream_id(); @@ -881,6 +890,7 @@ void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context, } else { ReinitializeDevice(context, device, 0, allocator); } + return Status::OK(); } Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr, @@ -916,8 +926,8 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options, num_gpus_to_use = iter->second; } const auto& gpu_options = options.config.gpu_options(); - std::vector<CudaGpuId> visible_gpu_order; - std::vector<CudaGpuId> valid_cuda_gpu_ids; + std::vector<PlatformGpuId> visible_gpu_order; + std::vector<PlatformGpuId> valid_platform_gpu_ids; // If we aren't going to use any GPUs, don't initialize them. 
// We don't want to call ParseVisibleDeviceList if num_gpus_to_use is 0, // because it treats an empty gpu_options.visible_device_list as 'all GPUs are @@ -926,12 +936,12 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options, TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(), &visible_gpu_order)); TF_RETURN_IF_ERROR( - GetValidDeviceIds(visible_gpu_order, &valid_cuda_gpu_ids)); + GetValidDeviceIds(visible_gpu_order, &valid_platform_gpu_ids)); } - if (num_gpus_to_use > valid_cuda_gpu_ids.size()) { - num_gpus_to_use = valid_cuda_gpu_ids.size(); + if (num_gpus_to_use > valid_platform_gpu_ids.size()) { + num_gpus_to_use = valid_platform_gpu_ids.size(); } - if (!valid_cuda_gpu_ids.empty()) { + if (!valid_platform_gpu_ids.empty()) { // Save the original device. int original_device = 0; cudaError_t err = cudaGetDevice(&original_device); @@ -941,17 +951,18 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options, } // Force to implicitly initialize CUDA runtime on each valid GPU before // CreateGPUDevice(). - for (CudaGpuId cuda_gpu_id : valid_cuda_gpu_ids) { - err = cudaSetDevice(cuda_gpu_id.value()); + for (PlatformGpuId platform_gpu_id : valid_platform_gpu_ids) { + err = cudaSetDevice(platform_gpu_id.value()); if (err != cudaSuccess) { - return errors::Internal("cudaSetDevice() on GPU:", cuda_gpu_id.value(), - " failed. Status: ", cudaGetErrorString(err)); + return errors::Internal("cudaSetDevice() on GPU:", + platform_gpu_id.value(), " failed. Status: ", + cudaGetErrorString(err)); } err = cudaFree(nullptr); if (err != cudaSuccess) { - return errors::Internal( - "CUDA runtime implicit initialization on GPU:", cuda_gpu_id.value(), - " failed. Status: ", cudaGetErrorString(err)); + return errors::Internal("CUDA runtime implicit initialization on GPU:", + platform_gpu_id.value(), " failed. Status: ", + cudaGetErrorString(err)); } } // Reset to the original device. 
@@ -977,10 +988,10 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options, LOG(INFO) << line_buf; for (int i = 0; i < visible_gpu_order.size(); ++i) { line_buf = strings::StrCat(visible_gpu_order[i].value(), ": "); - CudaGpuId cuda_id_i = visible_gpu_order[i]; + PlatformGpuId gpu_id_i = visible_gpu_order[i]; for (int j = 0; j < visible_gpu_order.size(); ++j) { - CudaGpuId cuda_id_j = visible_gpu_order[j]; - if (im.directed_links.find({cuda_id_i, cuda_id_j}) != + PlatformGpuId gpu_id_j = visible_gpu_order[j]; + if (im.directed_links.find({gpu_id_i, gpu_id_j}) != im.directed_links.end()) { line_buf.append("Y "); } else { @@ -993,22 +1004,23 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options, const auto& virtual_devices = gpu_options.experimental().virtual_devices(); if (!virtual_devices.empty()) { - TF_RETURN_IF_ERROR(VerifyVirtualDeviceSettings( - num_gpus_to_use, gpu_options, visible_gpu_order, valid_cuda_gpu_ids)); + TF_RETURN_IF_ERROR(VerifyVirtualDeviceSettings(num_gpus_to_use, gpu_options, + visible_gpu_order, + valid_platform_gpu_ids)); // We've verified that num_gpus_to_use >= virtual_devices.size(). 
num_gpus_to_use = virtual_devices.size(); CHECK(gpu_options.visible_device_list().empty() || - valid_cuda_gpu_ids == visible_gpu_order); + valid_platform_gpu_ids == visible_gpu_order); } int next_tf_gpu_id = 0; std::vector<int64> memory_limit_bytes; for (int i = 0; i < num_gpus_to_use; ++i) { - const CudaGpuId cuda_gpu_id = valid_cuda_gpu_ids[i]; + const PlatformGpuId platform_gpu_id = valid_platform_gpu_ids[i]; if (virtual_devices.empty() || virtual_devices.Get(i).memory_limit_mb_size() == 0) { int64 single_virtual_device_memory_limit = 0; TF_RETURN_IF_ERROR(SingleVirtualDeviceMemoryLimit( - gpu_options, cuda_gpu_id, &single_virtual_device_memory_limit)); + gpu_options, platform_gpu_id, &single_virtual_device_memory_limit)); memory_limit_bytes.push_back(single_virtual_device_memory_limit); } else { const auto& memory_limit_mb = virtual_devices.Get(i).memory_limit_mb(); @@ -1021,7 +1033,7 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options, TfGpuId tf_gpu_id(next_tf_gpu_id); ++next_tf_gpu_id; TF_RETURN_IF_ERROR( - GpuIdManager::InsertTfCudaGpuIdPair(tf_gpu_id, cuda_gpu_id)); + GpuIdManager::InsertTfPlatformGpuIdPair(tf_gpu_id, platform_gpu_id)); } } const int num_tf_gpus = next_tf_gpu_id; @@ -1046,7 +1058,7 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options, return Status::OK(); } -static string GetShortDeviceDescription(CudaGpuId cuda_gpu_id, +static string GetShortDeviceDescription(PlatformGpuId platform_gpu_id, const se::DeviceDescription& desc) { int cc_major; int cc_minor; @@ -1055,9 +1067,8 @@ static string GetShortDeviceDescription(CudaGpuId cuda_gpu_id, cc_minor = 0; } // LINT.IfChange - return strings::StrCat("device: ", cuda_gpu_id.value(), - ", name: ", desc.name(), - ", pci bus id: ", desc.pci_bus_id(), + return strings::StrCat("device: ", platform_gpu_id.value(), ", name: ", + desc.name(), ", pci bus id: ", desc.pci_bus_id(), ", compute capability: ", cc_major, ".", cc_minor); // 
LINT.ThenChange(//tensorflow/python/platform/test.py) } @@ -1072,12 +1083,13 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options, const string device_name = strings::StrCat(name_prefix, "/device:GPU:", tf_gpu_id.value()); GpuIdUtil::CheckValidTfGpuId(tf_gpu_id); - CudaGpuId cuda_gpu_id; - TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id)); + PlatformGpuId platform_gpu_id; + TF_RETURN_IF_ERROR( + GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id)); int numa_node = dev_locality.numa_node(); se::StreamExecutor* se = - GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); const se::DeviceDescription& desc = se->GetDeviceDescription(); GPUProcessState* process_state = GPUProcessState::singleton(); Allocator* gpu_allocator = process_state->GetGPUAllocator( @@ -1098,11 +1110,11 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options, // TODO(laigd): report error if memory_limit doesn't match stats.bytes_limit. 
BaseGPUDevice* gpu_device = CreateGPUDevice( options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality, - tf_gpu_id, GetShortDeviceDescription(cuda_gpu_id, desc), gpu_allocator, - ProcessState::singleton()->GetCPUAllocator(numa_node)); + tf_gpu_id, GetShortDeviceDescription(platform_gpu_id, desc), + gpu_allocator, ProcessState::singleton()->GetCPUAllocator(numa_node)); LOG(INFO) << "Created TensorFlow device (" << device_name << " with " << (stats.bytes_limit >> 20) << " MB memory) -> physical GPU (" - << GetShortDeviceDescription(cuda_gpu_id, desc) << ")"; + << GetShortDeviceDescription(platform_gpu_id, desc) << ")"; TF_RETURN_IF_ERROR(gpu_device->Init(options)); devices->push_back(gpu_device); @@ -1110,18 +1122,21 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options, } namespace { -std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>> +std::unique_ptr<std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>> GetPeerAccessMap(se::Platform* platform, - const std::vector<CudaGpuId>& visible_gpu_order) { - std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>> map( - new std::map<std::pair<CudaGpuId, CudaGpuId>, bool>); - for (CudaGpuId cuda_gpu_i : visible_gpu_order) { - for (CudaGpuId cuda_gpu_j : visible_gpu_order) { + const std::vector<PlatformGpuId>& visible_gpu_order) { + std::unique_ptr<std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>> map( + new std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>); + for (PlatformGpuId platform_gpu_i : visible_gpu_order) { + for (PlatformGpuId platform_gpu_j : visible_gpu_order) { se::StreamExecutor* from = - GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie(); + GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_i) + .ValueOrDie(); se::StreamExecutor* to = - GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie(); - (*map)[{cuda_gpu_i, cuda_gpu_j}] = from->CanEnablePeerAccessTo(to); + 
GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_j) + .ValueOrDie(); + (*map)[{platform_gpu_i, platform_gpu_j}] = + from->CanEnablePeerAccessTo(to); } } @@ -1131,19 +1146,19 @@ GetPeerAccessMap(se::Platform* platform, } // namespace Status BaseGPUDeviceFactory::GetInterconnectMaps( - const std::vector<CudaGpuId>& visible_gpu_order, se::Platform* gpu_manager, - std::vector<InterconnectMap>* maps) { + const std::vector<PlatformGpuId>& visible_gpu_order, + se::Platform* gpu_manager, std::vector<InterconnectMap>* maps) { // The default interconnect map is obtained from the StreamExecutor. auto access_map = GetPeerAccessMap(gpu_manager, visible_gpu_order); maps->resize(1); InterconnectMap& imap = maps->at(0); imap.name = "StreamExecutor"; imap.strength = InterconnectMap::kStreamExecutorStrength; - for (CudaGpuId cuda_id_i : visible_gpu_order) { - for (CudaGpuId cuda_id_j : visible_gpu_order) { - if (cuda_id_i == cuda_id_j) continue; - if ((*access_map)[{cuda_id_i, cuda_id_j}]) { - imap.directed_links.insert({cuda_id_i, cuda_id_j}); + for (PlatformGpuId gpu_id_i : visible_gpu_order) { + for (PlatformGpuId gpu_id_j : visible_gpu_order) { + if (gpu_id_i == gpu_id_j) continue; + if ((*access_map)[{gpu_id_i, gpu_id_j}]) { + imap.directed_links.insert({gpu_id_i, gpu_id_j}); } } } @@ -1158,13 +1173,14 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities( all_tf_gpu_ids.push_back(TfGpuId(i)); } for (TfGpuId tf_gpu_id : all_tf_gpu_ids) { - CudaGpuId cuda_gpu_id; - TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id)); + PlatformGpuId platform_gpu_id; + TF_RETURN_IF_ERROR( + GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id)); // Get GPU bus_id from its reported NUMA affinity. Because GPUs are // virtualized in some environments, we can't just use the GPU id. // NUMA locales are indexed from 0, buses are indexed from 1. 
se::StreamExecutor* se = - GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(); + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); const se::DeviceDescription& desc = se->GetDeviceDescription(); int numa_node = desc.numa_node(); if (numa_node < 0) { @@ -1174,7 +1190,8 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities( // may run into trouble later with data transfer operations. The // trouble may manifest as slower than expected performance, or // outright failures. - LOG(INFO) << "Could not identify NUMA node of CUDA gpu id " << cuda_gpu_id + LOG(INFO) << "Could not identify NUMA node of platform GPU id " + << platform_gpu_id << ", defaulting to 0. Your kernel may not have been built " << "with NUMA support."; numa_node = 0; @@ -1187,10 +1204,10 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities( LocalLinks* links = dev_locality.mutable_links(); for (const InterconnectMap& imap : interconnects) { for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) { - CudaGpuId cuda_gpu_dst; + PlatformGpuId platform_gpu_dst; TF_RETURN_IF_ERROR( - GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst)); - if (imap.directed_links.find({cuda_gpu_id, cuda_gpu_dst}) != + GpuIdManager::TfToPlatformGpuId(tf_gpu_dst, &platform_gpu_dst)); + if (imap.directed_links.find({platform_gpu_id, platform_gpu_dst}) != imap.directed_links.end()) { InterconnectLink* ilink = links->add_link(); ilink->set_device_id(tf_gpu_dst.value()); @@ -1204,10 +1221,10 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities( // add high strength links to the others. 
for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) { if (tf_gpu_id == tf_gpu_dst) continue; - CudaGpuId cuda_gpu_dst; + PlatformGpuId platform_gpu_dst; TF_RETURN_IF_ERROR( - GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst)); - if (cuda_gpu_id == cuda_gpu_dst) { + GpuIdManager::TfToPlatformGpuId(tf_gpu_dst, &platform_gpu_dst)); + if (platform_gpu_id == platform_gpu_dst) { InterconnectLink* ilink = links->add_link(); ilink->set_device_id(tf_gpu_dst.value()); ilink->set_type("SAME_DEVICE"); @@ -1216,9 +1233,9 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities( } (*localities)[tf_gpu_id] = dev_locality; - VLOG(1) << "GPUDevice CudaGpuId " << cuda_gpu_id << " TfGpuId " << tf_gpu_id - << " on bus " << dev_locality.bus_id() << " numa: " << numa_node - << " pci: " << desc.pci_bus_id() + VLOG(1) << "GPUDevice PlatformGpuId " << platform_gpu_id << " TfGpuId " + << tf_gpu_id << " on bus " << dev_locality.bus_id() + << " numa: " << numa_node << " pci: " << desc.pci_bus_id() << " DeviceLocality: " << dev_locality.DebugString(); } return Status::OK(); @@ -1226,14 +1243,14 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities( static int GetDefaultMinGPUMultiprocessorCount( se::Platform* gpu_manager, - const std::vector<CudaGpuId>& visible_gpu_order) { + const std::vector<PlatformGpuId>& visible_gpu_order) { static const int kDefaultMinGPUMultiprocessorCount = 8; // Find the highest multi-processor count across all visible GPUs. 
int max_count = -1; for (int i = 0; i < visible_gpu_order.size(); ++i) { auto exec_status = - GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, visible_gpu_order[i]); + GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_order[i]); if (!exec_status.ok()) { continue; } @@ -1252,7 +1269,7 @@ static int GetDefaultMinGPUMultiprocessorCount( static int GetMinGPUMultiprocessorCount( se::Platform* gpu_manager, - const std::vector<CudaGpuId>& visible_gpu_order) { + const std::vector<PlatformGpuId>& visible_gpu_order) { const char* tf_min_gpu_core_count = getenv("TF_MIN_GPU_MULTIPROCESSOR_COUNT"); if (tf_min_gpu_core_count == nullptr || @@ -1330,18 +1347,20 @@ std::vector<CudaVersion> GetSupportedCudaComputeCapabilities() { } Status EnablePeerAccess(se::Platform* platform, - const std::vector<CudaGpuId>& visible_gpu_order) { + const std::vector<PlatformGpuId>& visible_gpu_order) { int possible_peer_count = 0; int enabled_peer_count = 0; for (int i = 0; i < visible_gpu_order.size(); ++i) { - const CudaGpuId cuda_gpu_i = visible_gpu_order[i]; + const PlatformGpuId platform_gpu_i = visible_gpu_order[i]; for (int j = 0; j < visible_gpu_order.size(); ++j) { - const CudaGpuId cuda_gpu_j = visible_gpu_order[j]; + const PlatformGpuId platform_gpu_j = visible_gpu_order[j]; // We have already validated that ExecutorForDevice() calls return OK. 
se::StreamExecutor* from = - GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie(); + GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_i) + .ValueOrDie(); se::StreamExecutor* to = - GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie(); + GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_j) + .ValueOrDie(); if (from->CanEnablePeerAccessTo(to)) { ++possible_peer_count; @@ -1349,7 +1368,8 @@ Status EnablePeerAccess(se::Platform* platform, if (!status.ok()) { LOG(WARNING) << "Unable to enable peer access between device ordinals " - << cuda_gpu_i << " and " << cuda_gpu_j << ", status: " << status; + << platform_gpu_i << " and " << platform_gpu_j + << ", status: " << status; } else { ++enabled_peer_count; } @@ -1372,22 +1392,23 @@ Status EnablePeerAccess(se::Platform* platform, } // namespace Status BaseGPUDeviceFactory::GetValidDeviceIds( - const std::vector<CudaGpuId>& visible_gpu_order, - std::vector<CudaGpuId>* ids) { + const std::vector<PlatformGpuId>& visible_gpu_order, + std::vector<PlatformGpuId>* ids) { se::Platform* gpu_manager = GPUMachineManager(); bool new_gpu_found = false; for (int i = 0; i < visible_gpu_order.size(); ++i) { - const CudaGpuId cuda_gpu_id = visible_gpu_order[i]; + const PlatformGpuId visible_gpu_id = visible_gpu_order[i]; - // Only perform this once per visible cuda gpu id. - if (visible_gpu_initialized_[cuda_gpu_id.value()]) { + // Only perform this once per visible platform gpu id. 
+ if (visible_gpu_initialized_[visible_gpu_id.value()]) { continue; } - visible_gpu_initialized_[cuda_gpu_id.value()] = true; + visible_gpu_initialized_[visible_gpu_id.value()] = true; new_gpu_found = true; - auto executor = GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, cuda_gpu_id); + auto executor = + GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_id); if (!executor.ok()) { return executor.status(); } @@ -1435,9 +1456,9 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds( // Filter out devices that don't have the right capability or power. for (int i = 0; i < visible_gpu_order.size(); ++i) { - const CudaGpuId visible_gpu_id = visible_gpu_order[i]; + const PlatformGpuId visible_gpu_id = visible_gpu_order[i]; auto exec_status = - GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, visible_gpu_id); + GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_id); if (!exec_status.ok()) { LOG(INFO) << "Ignoring visible gpu device " << visible_gpu_id << " whose executor is in invalid state: " @@ -1486,7 +1507,7 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds( if (!ids->empty()) { std::vector<int> raw_ids(ids->size()); std::transform(ids->begin(), ids->end(), raw_ids.begin(), - [](CudaGpuId id) -> int { return id.value(); }); + [](PlatformGpuId id) -> int { return id.value(); }); LOG(INFO) << "Adding visible gpu devices: " << str_util::Join(raw_ids, ", "); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h index 56d03d7a8c..674e8384d5 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.h +++ b/tensorflow/core/common_runtime/gpu/gpu_device.h @@ -65,6 +65,11 @@ class BaseGPUDevice : public LocalDevice { // completes. bool RequiresRecordingAccessedTensors() const override; + // GPU kernel execution requires us to use `tracing::ScopedAnnotation()` + // rather than `tracing::ScopedActivity()`, in order to relate asynchronously + // launched GPU kernels to the OpKernel. 
+ bool TraceUsingAnnotations() const { return true; } + void ConsumeListOfAccessedTensors( DeviceContext* device_context, const TensorReferenceVector& tensor_refs) override; @@ -86,15 +91,16 @@ class BaseGPUDevice : public LocalDevice { // The caller owns the returned device. PerOpGpuDevice* MakeGpuDevice() override; - void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, - DeviceContext* dc, Allocator* allocator) override; + Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, + DeviceContext* dc, + Allocator* allocator) override; - // Returns the CUDA GPU id of this device within the native driver system; + // Returns the platform GPU id of this device within the native driver system; // e.g., for CUDA this is the ordinal of the GPU within the system. int gpu_id() const { - CudaGpuId cuda_gpu_id; - TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id)); - return cuda_gpu_id.value(); + PlatformGpuId platform_gpu_id; + TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id_, &platform_gpu_id)); + return platform_gpu_id.value(); } // The executor that provides control for the device; e.g., for CUDA this @@ -125,6 +131,7 @@ class BaseGPUDevice : public LocalDevice { class StreamGroupFactory; gtl::InlinedVector<StreamGroup*, 4> streams_; + mutex scratch_init_mutex_; gtl::InlinedVector<char*, 4> scratch_; std::vector<GPUDeviceContext*> device_contexts_; GpuDeviceInfo* gpu_device_info_ = nullptr; @@ -135,6 +142,9 @@ class BaseGPUDevice : public LocalDevice { std::unique_ptr<EventMgr> em_; std::unique_ptr<thread::ThreadPool> thread_pool_; + // Initialize scractch buffers used by Eigen. 
+ Status InitScratchBuffers(); + void ReinitializeDevice(OpKernelContext* context, PerOpGpuDevice* device, int stream_id, Allocator* allocator); @@ -168,14 +178,14 @@ class BaseGPUDeviceFactory : public DeviceFactory { int32 strength; static const int kSameDeviceStrength; static const int kStreamExecutorStrength; - std::set<std::pair<CudaGpuId, CudaGpuId>> directed_links; + std::set<std::pair<PlatformGpuId, PlatformGpuId>> directed_links; }; protected: // Populates *maps with interconnect maps for all local direct access // pathways between GPUs. virtual Status GetInterconnectMaps( - const std::vector<CudaGpuId>& visible_gpu_order, + const std::vector<PlatformGpuId>& visible_gpu_order, se::Platform* gpu_manager, std::vector<InterconnectMap>* maps); struct TfGpuIdHash { @@ -207,16 +217,16 @@ class BaseGPUDeviceFactory : public DeviceFactory { Allocator* gpu_allocator, Allocator* cpu_allocator) = 0; - // Returns into 'ids' the list of valid CUDA GPU ids, in the order that + // Returns into 'ids' the list of valid platform GPU ids, in the order that // they should map to TF GPU ids "/device:GPU:0", "/device:GPU:1", etc, // based upon 'visible_gpu_order' which was generated by parsing // GPUOptions::visible_device_list which is a comma-separated list of CUDA GPU // ids. - Status GetValidDeviceIds(const std::vector<CudaGpuId>& visible_gpu_order, - std::vector<CudaGpuId>* ids); + Status GetValidDeviceIds(const std::vector<PlatformGpuId>& visible_gpu_order, + std::vector<PlatformGpuId>* ids); - // visible_gpu_initialized_[cuda_gpu_id] is true if visible GPU cuda_gpu_id - // has been initialized by the process. + // visible_gpu_initialized_[platform_gpu_id] is true if visible GPU + // platform_gpu_id has been initialized by the process. 
std::unordered_map<int, bool> visible_gpu_initialized_; }; diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc index daf59f0560..36294094e9 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc @@ -30,18 +30,21 @@ namespace tensorflow { namespace { const char* kDeviceNamePrefix = "/job:localhost/replica:0/task:0"; -int64 GetTotalGPUMemory(CudaGpuId gpu_id) { +int64 GetTotalGPUMemory(PlatformGpuId gpu_id) { se::StreamExecutor* se = - GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie(); + GpuIdUtil::ExecutorForPlatformGpuId(GPUMachineManager(), gpu_id) + .ValueOrDie(); int64 total_memory, available_memory; CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory)); return total_memory; } -Status GetComputeCapability(CudaGpuId gpu_id, int* cc_major, int* cc_minor) { +Status GetComputeCapability(PlatformGpuId gpu_id, int* cc_major, + int* cc_minor) { se::StreamExecutor* se = - GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie(); + GpuIdUtil::ExecutorForPlatformGpuId(GPUMachineManager(), gpu_id) + .ValueOrDie(); if (!se->GetDeviceDescription().cuda_compute_capability(cc_major, cc_minor)) { *cc_major = 0; *cc_minor = 0; @@ -223,7 +226,7 @@ TEST_F(GPUDeviceTest, MultipleVirtualDevices) { // error. TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) { int cc_major, cc_minor; - TF_ASSERT_OK(GetComputeCapability(CudaGpuId(0), &cc_major, &cc_minor)); + TF_ASSERT_OK(GetComputeCapability(PlatformGpuId(0), &cc_major, &cc_minor)); // Exit early while running on Pascal or later GPUs. if (cc_major >= 6) { return; @@ -244,10 +247,10 @@ TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) { // more memory than what is available on the device. 
TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) { static constexpr double kGpuMemoryFraction = 1.2; - static constexpr CudaGpuId kCudaGpuId(0); + static constexpr PlatformGpuId kPlatformGpuId(0); int cc_major, cc_minor; - TF_ASSERT_OK(GetComputeCapability(kCudaGpuId, &cc_major, &cc_minor)); + TF_ASSERT_OK(GetComputeCapability(kPlatformGpuId, &cc_major, &cc_minor)); // Exit early if running on pre-Pascal GPUs. if (cc_major < 6) { LOG(INFO) @@ -262,7 +265,7 @@ TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) { ASSERT_EQ(1, devices.size()); int64 memory_limit = devices[0]->attributes().memory_limit(); - ASSERT_EQ(memory_limit, static_cast<int64>(GetTotalGPUMemory(kCudaGpuId) * + ASSERT_EQ(memory_limit, static_cast<int64>(GetTotalGPUMemory(kPlatformGpuId) * kGpuMemoryFraction)); AllocatorAttributes allocator_attributes = AllocatorAttributes(); diff --git a/tensorflow/core/common_runtime/gpu/gpu_id.h b/tensorflow/core/common_runtime/gpu/gpu_id.h index 2a6caea296..f0d9022821 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id.h +++ b/tensorflow/core/common_runtime/gpu/gpu_id.h @@ -25,10 +25,10 @@ namespace tensorflow { // physical machine, it can be filtered by CUDA environment variable // CUDA_VISIBLE_DEVICES. Note that this id is not visible to Tensorflow, but // result after filtering by CUDA_VISIBLE_DEVICES is visible to TF and is -// called CUDA GPU id as below. See +// called platform GPU id as below. See // http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars // for more details. -// - CUDA GPU id (also called *visible* GPU id in +// - *platform* GPU id (also called *visible* GPU id in // third_party/tensorflow/core/protobuf/config.proto): this is the id that is // visible to Tensorflow after filtering by CUDA_VISIBLE_DEVICES, and is // generated by the CUDA GPU driver. 
It starts from 0 and is used for CUDA API @@ -39,14 +39,14 @@ namespace tensorflow { // field of the device name "/device:GPU:<id>", and is also the identifier of // a BaseGPUDevice. Note that the configuration allows us to create multiple // BaseGPUDevice per GPU hardware in order to use multi CUDA streams on the -// hardware, so the mapping between TF GPU id and CUDA GPU id is not a 1:1 +// hardware, so the mapping between TF GPU id and platform GPU id is not a 1:1 // mapping, see the example below. // // For example, assuming that in the machine we have GPU device with index 0, 1, // 2 and 3 (physical GPU id). Setting "CUDA_VISIBLE_DEVICES=1,2,3" will create -// the following mapping between CUDA GPU id and physical GPU id: +// the following mapping between platform GPU id and physical GPU id: // -// CUDA GPU id -> physical GPU id +// platform GPU id -> physical GPU id // 0 -> 1 // 1 -> 2 // 2 -> 3 @@ -56,32 +56,32 @@ namespace tensorflow { // // Assuming we configure the Session to create one BaseGPUDevice per GPU // hardware, then setting GPUOptions::visible_device_list to "2,0" will create -// the following mappting between TF GPU id and CUDA GPU id: +// the following mappting between TF GPU id and platform GPU id: // -// TF GPU id -> CUDA GPU ID +// TF GPU id -> platform GPU ID // 0 (i.e. /device:GPU:0) -> 2 // 1 (i.e. /device:GPU:1) -> 0 // -// Note that CUDA GPU id 1 is filtered out by GPUOptions::visible_device_list, -// so it won't be used by the TF process. +// Note that platform GPU id 1 is filtered out by +// GPUOptions::visible_device_list, so it won't be used by the TF process. // // On the other hand, if we configure it to create 2 BaseGPUDevice per GPU // hardware, then setting GPUOptions::visible_device_list to "2,0" will create -// the following mappting between TF GPU id and CUDA GPU id: +// the following mappting between TF GPU id and platform GPU id: // -// TF GPU id -> CUDA GPU ID +// TF GPU id -> platform GPU ID // 0 (i.e. 
/device:GPU:0) -> 2 // 1 (i.e. /device:GPU:1) -> 2 // 2 (i.e. /device:GPU:2) -> 0 // 3 (i.e. /device:GPU:3) -> 0 // -// We create strong-typed integer classes for both TF GPU id and CUDA GPU id to -// minimize programming errors and improve code readability. Except for the +// We create strong-typed integer classes for both TF GPU id and platform GPU id +// to minimize programming errors and improve code readability. Except for the // StreamExecutor interface (as we don't change its API), whenever we need a -// TF GPU id (or CUDA GPU id) we should use TfGpuId (or CudaGpuId) instead of a -// raw integer. +// TF GPU id (or platform GPU id) we should use TfGpuId (or PlatformGpuId) +// instead of a raw integer. TF_LIB_GTL_DEFINE_INT_TYPE(TfGpuId, int32); -TF_LIB_GTL_DEFINE_INT_TYPE(CudaGpuId, int32); +TF_LIB_GTL_DEFINE_INT_TYPE(PlatformGpuId, int32); } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc index b5099dc8ef..2b40730119 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc @@ -26,26 +26,27 @@ limitations under the License. namespace tensorflow { namespace { -// Manages the map between TfGpuId and CUDA GPU id. -class TfToCudaGpuIdMap { +// Manages the map between TfGpuId and platform GPU id. 
+class TfToPlatformGpuIdMap { public: - static TfToCudaGpuIdMap* singleton() { - static auto* id_map = new TfToCudaGpuIdMap; + static TfToPlatformGpuIdMap* singleton() { + static auto* id_map = new TfToPlatformGpuIdMap; return id_map; } - Status Insert(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id) LOCKS_EXCLUDED(mu_) { + Status Insert(TfGpuId tf_gpu_id, PlatformGpuId platform_gpu_id) + LOCKS_EXCLUDED(mu_) { std::pair<IdMapType::iterator, bool> result; { mutex_lock lock(mu_); - result = id_map_.insert({tf_gpu_id.value(), cuda_gpu_id.value()}); + result = id_map_.insert({tf_gpu_id.value(), platform_gpu_id.value()}); } - if (!result.second && cuda_gpu_id.value() != result.first->second) { + if (!result.second && platform_gpu_id.value() != result.first->second) { return errors::AlreadyExists( "TensorFlow device (GPU:", tf_gpu_id.value(), ") is being mapped to " "multiple CUDA devices (", - cuda_gpu_id.value(), " now, and ", result.first->second, + platform_gpu_id.value(), " now, and ", result.first->second, " previously), which is not supported. 
" "This may be the result of providing different GPU configurations " "(ConfigProto.gpu_options, for example different visible_device_list)" @@ -56,17 +57,17 @@ class TfToCudaGpuIdMap { return Status::OK(); } - bool Find(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) const + bool Find(TfGpuId tf_gpu_id, PlatformGpuId* platform_gpu_id) const LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); auto result = id_map_.find(tf_gpu_id.value()); if (result == id_map_.end()) return false; - *cuda_gpu_id = result->second; + *platform_gpu_id = result->second; return true; } private: - TfToCudaGpuIdMap() = default; + TfToPlatformGpuIdMap() = default; void TestOnlyReset() LOCKS_EXCLUDED(mu_) { mutex_lock lock(mu_); @@ -78,17 +79,18 @@ class TfToCudaGpuIdMap { IdMapType id_map_ GUARDED_BY(mu_); friend class ::tensorflow::GpuIdManager; - TF_DISALLOW_COPY_AND_ASSIGN(TfToCudaGpuIdMap); + TF_DISALLOW_COPY_AND_ASSIGN(TfToPlatformGpuIdMap); }; } // namespace -Status GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, - CudaGpuId cuda_gpu_id) { - return TfToCudaGpuIdMap::singleton()->Insert(tf_gpu_id, cuda_gpu_id); +Status GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId tf_gpu_id, + PlatformGpuId platform_gpu_id) { + return TfToPlatformGpuIdMap::singleton()->Insert(tf_gpu_id, platform_gpu_id); } -Status GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) { - if (TfToCudaGpuIdMap::singleton()->Find(tf_gpu_id, cuda_gpu_id)) { +Status GpuIdManager::TfToPlatformGpuId(TfGpuId tf_gpu_id, + PlatformGpuId* platform_gpu_id) { + if (TfToPlatformGpuIdMap::singleton()->Find(tf_gpu_id, platform_gpu_id)) { return Status::OK(); } return errors::NotFound("TensorFlow device GPU:", tf_gpu_id.value(), @@ -96,7 +98,7 @@ Status GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) { } void GpuIdManager::TestOnlyReset() { - TfToCudaGpuIdMap::singleton()->TestOnlyReset(); + TfToPlatformGpuIdMap::singleton()->TestOnlyReset(); } } // namespace tensorflow diff --git 
a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h index 491d92ccdd..62df4310c4 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h +++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h @@ -21,15 +21,17 @@ limitations under the License. namespace tensorflow { -// Class that maintains a map from TfGpuId to CudaGpuId, and manages the +// Class that maintains a map from TfGpuId to PlatformGpuId, and manages the // translation between them. class GpuIdManager { public: - // Adds a mapping from tf_gpu_id to cuda_gpu_id. - static Status InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id); + // Adds a mapping from tf_gpu_id to platform_gpu_id. + static Status InsertTfPlatformGpuIdPair(TfGpuId tf_gpu_id, + PlatformGpuId platform_gpu_id); - // Gets the cuda_gpu_id associated with tf_gpu_id. Returns OK if found. - static Status TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id); + // Gets the platform_gpu_id associated with tf_gpu_id. Returns OK if found. + static Status TfToPlatformGpuId(TfGpuId tf_gpu_id, + PlatformGpuId* platform_gpu_id); // Clears the map. Used in unit tests only. static void TestOnlyReset(); diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc index a663ec7051..8bf3c6a308 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc @@ -22,38 +22,38 @@ limitations under the License. 
namespace tensorflow { namespace { -CudaGpuId TfToCudaGpuId(TfGpuId tf) { - CudaGpuId cuda; - TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf, &cuda)); - return cuda; +PlatformGpuId TfToPlatformGpuId(TfGpuId tf) { + PlatformGpuId platform_gpu_id; + TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf, &platform_gpu_id)); + return platform_gpu_id; } TEST(GpuIdManagerTest, Basics) { TfGpuId key_0(0); - CudaGpuId value_0(0); - TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0)); - EXPECT_EQ(value_0, TfToCudaGpuId(key_0)); + PlatformGpuId value_0(0); + TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_0, value_0)); + EXPECT_EQ(value_0, TfToPlatformGpuId(key_0)); // Multiple calls to map the same value is ok. - TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0)); - EXPECT_EQ(value_0, TfToCudaGpuId(key_0)); + TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_0, value_0)); + EXPECT_EQ(value_0, TfToPlatformGpuId(key_0)); // Map a different TfGpuId to a different value. TfGpuId key_1(3); - CudaGpuId value_1(2); - TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_1, value_1)); - EXPECT_EQ(value_1, TfToCudaGpuId(key_1)); + PlatformGpuId value_1(2); + TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_1, value_1)); + EXPECT_EQ(value_1, TfToPlatformGpuId(key_1)); // Mapping a different TfGpuId to the same value is ok. TfGpuId key_2(10); - TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_1)); - EXPECT_EQ(value_1, TfToCudaGpuId(key_2)); + TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_2, value_1)); + EXPECT_EQ(value_1, TfToPlatformGpuId(key_2)); // Mapping the same TfGpuId to a different value. - ASSERT_FALSE(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_0).ok()); + ASSERT_FALSE(GpuIdManager::InsertTfPlatformGpuIdPair(key_2, value_0).ok()); // Getting a nonexistent mapping. 
- ASSERT_FALSE(GpuIdManager::TfToCudaGpuId(TfGpuId(100), &value_0).ok()); + ASSERT_FALSE(GpuIdManager::TfToPlatformGpuId(TfGpuId(100), &value_0).ok()); } } // namespace diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h index b9c66b3328..b1f10fb1dc 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h +++ b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h @@ -24,34 +24,37 @@ limitations under the License. namespace tensorflow { -// Utility methods for translation between Tensorflow GPU ids and CUDA GPU ids. +// Utility methods for translation between Tensorflow GPU ids and platform GPU +// ids. class GpuIdUtil { public: // Convenient methods for getting the associated executor given a TfGpuId or - // CudaGpuId. - static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId( - se::Platform* gpu_manager, CudaGpuId cuda_gpu_id) { - return gpu_manager->ExecutorForDevice(cuda_gpu_id.value()); + // PlatformGpuId. 
+ static se::port::StatusOr<se::StreamExecutor*> ExecutorForPlatformGpuId( + se::Platform* gpu_manager, PlatformGpuId platform_gpu_id) { + return gpu_manager->ExecutorForDevice(platform_gpu_id.value()); } - static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId( - CudaGpuId cuda_gpu_id) { - return ExecutorForCudaGpuId(GPUMachineManager(), cuda_gpu_id); + static se::port::StatusOr<se::StreamExecutor*> ExecutorForPlatformGpuId( + PlatformGpuId platform_gpu_id) { + return ExecutorForPlatformGpuId(GPUMachineManager(), platform_gpu_id); } static se::port::StatusOr<se::StreamExecutor*> ExecutorForTfGpuId( TfGpuId tf_gpu_id) { - CudaGpuId cuda_gpu_id; - TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id)); - return ExecutorForCudaGpuId(cuda_gpu_id); + PlatformGpuId platform_gpu_id; + TF_RETURN_IF_ERROR( + GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id)); + return ExecutorForPlatformGpuId(platform_gpu_id); } - // Verify that the cuda_gpu_id associated with a TfGpuId is legitimate. + // Verify that the platform_gpu_id associated with a TfGpuId is legitimate. static void CheckValidTfGpuId(TfGpuId tf_gpu_id) { - CudaGpuId cuda_gpu_id; - TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id)); + PlatformGpuId platform_gpu_id; + TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id)); const int visible_device_count = GPUMachineManager()->VisibleDeviceCount(); - CHECK_LT(cuda_gpu_id.value(), visible_device_count) - << "cuda_gpu_id is outside discovered device range." - << " TF GPU id: " << tf_gpu_id << " CUDA GPU id: " << cuda_gpu_id + CHECK_LT(platform_gpu_id.value(), visible_device_count) + << "platform_gpu_id is outside discovered device range." 
+ << " TF GPU id: " << tf_gpu_id + << " platform GPU id: " << platform_gpu_id << " visible device count: " << visible_device_count; } }; diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc index b18688174d..3e95374fda 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc @@ -76,12 +76,16 @@ GPUProcessState::GPUProcessState() : gpu_device_enabled_(false) { // This function is defined for debugging problems with the allocators. GPUProcessState::~GPUProcessState() { CHECK_EQ(this, instance_); - for (auto p : gpu_allocators_) { - delete p; - } instance_ = nullptr; } +int GPUProcessState::BusIdForGPU(TfGpuId tf_gpu_id) { + // Return the NUMA node associated with the GPU's StreamExecutor. + se::StreamExecutor* se = + GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie(); + return se->GetDeviceDescription().numa_node(); +} + Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options, TfGpuId tf_gpu_id, size_t total_bytes) { @@ -93,64 +97,63 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options, if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) { gpu_allocators_.resize(tf_gpu_id.value() + 1); - if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) - gpu_al_.resize(tf_gpu_id.value() + 1); } - if (gpu_allocators_[tf_gpu_id.value()] == nullptr) { - VisitableAllocator* gpu_allocator; - + AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()]; + if (allocator_parts.allocator.get() == nullptr) { // Validate allocator types. 
if (!allocator_type.empty() && allocator_type != "BFC") { LOG(ERROR) << "Invalid allocator type: " << allocator_type; return nullptr; } - CudaGpuId cuda_gpu_id; - TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id)); - gpu_allocator = - new GPUBFCAllocator(cuda_gpu_id, total_bytes, options, + PlatformGpuId platform_gpu_id; + TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id)); + int bus_id = BusIdForGPU(tf_gpu_id); + while (bus_id >= gpu_visitors_.size()) { + gpu_visitors_.push_back({}); + } + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, + (options.per_process_gpu_memory_fraction() > 1.0 || + options.experimental().use_unified_memory()), + gpu_visitors_[bus_id], {}); + Allocator* gpu_allocator = + new GPUBFCAllocator(sub_allocator, total_bytes, options, strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc")); // If true, checks for memory overwrites by writing // distinctive patterns on both ends of allocated memory. if (useCudaMemoryGuardAllocator()) { - gpu_allocator = new GPUDebugAllocator(gpu_allocator, cuda_gpu_id); - gpu_allocator = new GPUNanResetAllocator(gpu_allocator, cuda_gpu_id); + gpu_allocator = new GPUDebugAllocator(gpu_allocator, platform_gpu_id); + gpu_allocator = new GPUNanResetAllocator(gpu_allocator, platform_gpu_id); } else if (useCudaMallocAllocator()) { // If true, passes all allocation requests through to cudaMalloc // useful for doing memory debugging with tools like cuda-memcheck // **WARNING** probably will not work in a multi-gpu scenario - gpu_allocator = new GPUcudaMallocAllocator(gpu_allocator, cuda_gpu_id); - } - gpu_allocators_[tf_gpu_id.value()] = gpu_allocator; - - // If there are any pending AllocVisitors for this bus, add - // them now. 
- se::StreamExecutor* se = - GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie(); - int bus_id = se->GetDeviceDescription().numa_node(); - if (bus_id >= 0 && bus_id < static_cast<int64>(gpu_visitors_.size())) { - for (const auto& v : gpu_visitors_[bus_id]) { - gpu_allocator->AddAllocVisitor(v); - } + gpu_allocator = + new GPUcudaMallocAllocator(gpu_allocator, platform_gpu_id); } + + Allocator* recording_allocator = nullptr; if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) { ProcessState::MemDesc md; md.loc = ProcessState::MemDesc::GPU; - md.dev_index = cuda_gpu_id.value(); + md.dev_index = platform_gpu_id.value(); md.gpu_registered = false; md.nic_registered = true; - if (static_cast<int64>(gpu_al_.size()) <= tf_gpu_id.value()) { - gpu_al_.resize(tf_gpu_id.value() + 1); - } - gpu_al_[tf_gpu_id.value()] = new internal::RecordingAllocator( + recording_allocator = new internal::RecordingAllocator( &process_state_->mem_desc_map_, gpu_allocator, md, &mu_); } + allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator), sub_allocator, + std::unique_ptr<Allocator>(recording_allocator)}; + } + if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) { + return allocator_parts.recording_allocator.get(); + } else { + return allocator_parts.allocator.get(); } - if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) - return gpu_al_[tf_gpu_id.value()]; - return gpu_allocators_[tf_gpu_id.value()]; #else LOG(FATAL) << "GPUAllocator unavailable. 
Not compiled with --config=cuda."; return nullptr; @@ -172,11 +175,12 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) { tf_shared_lock lock(mu_); if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types && - static_cast<int>(cuda_al_.size()) > 0) { - return cuda_al_[0]; + !cuda_host_allocators_.empty() && + cuda_host_allocators_[0].recording_allocator != nullptr) { + return cuda_host_allocators_[0].recording_allocator.get(); } if (static_cast<int>(cuda_host_allocators_.size()) > numa_node) { - return cuda_host_allocators_[0]; + return cuda_host_allocators_[0].allocator.get(); } } @@ -190,7 +194,7 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) { // it knows is valid. se::StreamExecutor* se = nullptr; for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) { - if (gpu_allocators_[i] != nullptr) { + if (gpu_allocators_[i].allocator != nullptr) { se = GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie(); break; } @@ -199,6 +203,15 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) { CHECK_NE(nullptr, se); while (static_cast<int>(cuda_host_allocators_.size()) <= numa_node) { + while (cuda_host_alloc_visitors_.size() <= numa_node) { + cuda_host_alloc_visitors_.push_back({}); + } + while (cuda_host_free_visitors_.size() <= numa_node) { + cuda_host_free_visitors_.push_back({}); + } + SubAllocator* sub_allocator = new CUDAHostAllocator( + se, numa_node, cuda_host_alloc_visitors_[numa_node], + cuda_host_free_visitors_[numa_node]); // TODO(zheng-xq): evaluate whether 64GB by default is the best choice. 
int64 cuda_host_mem_limit_in_mb = -1; Status status = ReadInt64FromEnvVar("TF_CUDA_HOST_MEM_LIMIT_IN_MB", @@ -208,62 +221,92 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) { LOG(ERROR) << "GetCUDAHostAllocator: " << status.error_message(); } int64 cuda_host_mem_limit = cuda_host_mem_limit_in_mb * (1LL << 20); - VisitableAllocator* allocator = - new BFCAllocator(new CUDAHostAllocator(se), cuda_host_mem_limit, + Allocator* allocator = + new BFCAllocator(sub_allocator, cuda_host_mem_limit, true /*allow_growth*/, "cuda_host_bfc" /*name*/); - if (LogMemory::IsEnabled()) { + if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) { // Wrap the allocator to track allocation ids for better logging // at the cost of performance. - allocator = new TrackingVisitableAllocator(allocator, true); + allocator = new TrackingAllocator(allocator, true); } - cuda_host_allocators_.push_back(allocator); + cuda_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator), + sub_allocator, + std::unique_ptr<Allocator>(nullptr)}); + AllocatorParts& allocator_parts = cuda_host_allocators_.back(); if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) { ProcessState::MemDesc md; md.loc = ProcessState::MemDesc::CPU; md.dev_index = 0; md.gpu_registered = true; md.nic_registered = false; - cuda_al_.push_back(new internal::RecordingAllocator( - &process_state_->mem_desc_map_, cuda_host_allocators_.back(), md, - &mu_)); + allocator_parts.recording_allocator.reset( + new internal::RecordingAllocator(&process_state_->mem_desc_map_, + allocator_parts.allocator.get(), md, + &mu_)); } } - if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) - return cuda_al_[0]; - return cuda_host_allocators_[0]; + if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) { + return cuda_host_allocators_[0].recording_allocator.get(); + } else { + return cuda_host_allocators_[0].allocator.get(); + } } void GPUProcessState::AddGPUAllocVisitor(int 
bus_id, - const AllocVisitor& visitor) { - CHECK(process_state_); + const SubAllocator::Visitor& visitor) { #if GOOGLE_CUDA mutex_lock lock(mu_); - for (int i = 0; i < static_cast<int64>(gpu_allocators_.size()); ++i) { - se::StreamExecutor* se = - GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie(); - if (gpu_allocators_[i] && - (se->GetDeviceDescription().numa_node() + 1) == bus_id) { - gpu_allocators_[i]->AddAllocVisitor(visitor); - } - } + CHECK(gpu_allocators_.empty()) // Crash OK + << "AddGPUAllocVisitor must be called before " + "first call to GetGPUAllocator."; while (bus_id >= static_cast<int64>(gpu_visitors_.size())) { - gpu_visitors_.push_back(std::vector<AllocVisitor>()); + gpu_visitors_.push_back(std::vector<SubAllocator::Visitor>()); } gpu_visitors_[bus_id].push_back(visitor); #endif // GOOGLE_CUDA } +void GPUProcessState::AddCUDAHostAllocVisitor( + int numa_node, const SubAllocator::Visitor& visitor) { +#if GOOGLE_CUDA + mutex_lock lock(mu_); + CHECK(cuda_host_allocators_.empty()) // Crash OK + << "AddCUDAHostAllocVisitor must be called before " + "first call to GetCUDAHostAllocator."; + while (numa_node >= static_cast<int64>(cuda_host_alloc_visitors_.size())) { + cuda_host_alloc_visitors_.push_back(std::vector<SubAllocator::Visitor>()); + } + cuda_host_alloc_visitors_[numa_node].push_back(visitor); +#endif // GOOGLE_CUDA +} + +void GPUProcessState::AddCUDAHostFreeVisitor( + int numa_node, const SubAllocator::Visitor& visitor) { +#if GOOGLE_CUDA + mutex_lock lock(mu_); + CHECK(cuda_host_allocators_.empty()) // Crash OK + << "AddCUDAHostFreeVisitor must be called before " + "first call to GetCUDAHostAllocator."; + while (numa_node >= static_cast<int64>(cuda_host_free_visitors_.size())) { + cuda_host_free_visitors_.push_back(std::vector<SubAllocator::Visitor>()); + } + cuda_host_free_visitors_[numa_node].push_back(visitor); +#endif // GOOGLE_CUDA +} + void GPUProcessState::TestOnlyReset() { - process_state_->ProcessState::TestOnlyReset(); + if 
(process_state_) { + process_state_->ProcessState::TestOnlyReset(); + } { mutex_lock lock(mu_); gpu_device_enabled_ = false; + gpu_allocators_.clear(); gpu_visitors_.clear(); - gtl::STLDeleteElements(&gpu_allocators_); - gtl::STLDeleteElements(&cuda_host_allocators_); - gtl::STLDeleteElements(&gpu_al_); - gtl::STLDeleteElements(&cuda_al_); + cuda_host_allocators_.clear(); + cuda_host_alloc_visitors_.clear(); + cuda_host_free_visitors_.clear(); } } diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.h b/tensorflow/core/common_runtime/gpu/gpu_process_state.h index cb41c3c6bd..43e9a31660 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_process_state.h +++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.h @@ -32,7 +32,6 @@ limitations under the License. namespace tensorflow { class Allocator; -class VisitableAllocator; class PoolAllocator; // Singleton that manages per-process state when GPUs are present. @@ -72,18 +71,30 @@ class GPUProcessState { virtual Allocator* GetCUDAHostAllocator(int numa_node); - // Registers a function to be called once on every new Region - // allocated by every GPURegionAllocator proximate to the specified - // bus. The AllocVisitor is provided with a memory pointer and the - // size of the area it identifies. The pointer is not guaranteed to - // be valid after the call terminates. The intention is for this - // interface to be used for network device memory registration. - // "bus_id" is platform-specific. On many platforms it - // should be 0. On machines with multiple PCIe buses, it should be - // the index of one of the PCIe buses. If the bus_id is invalid, - // results are undefined. - typedef std::function<void(void*, size_t)> AllocVisitor; - virtual void AddGPUAllocVisitor(int bus_id, const AllocVisitor& visitor); + // Registers a Visitor to be invoked on new chunks of memory allocated by the + // SubAllocator of every GPU proximate to the specified bus. 
The Visitor + // is provided with a memory pointer, a GPU id, and the size of the area it + // identifies. The pointer is not guaranteed to be valid after the call + // terminates. The intention is for this interface to be used for network + // device memory registration. "bus_id" is platform-specific. On many + // platforms it should be 0. On machines with multiple PCIe buses, it should + // be the index of one of the PCIe buses (maybe the NUMA node at which the + // PCIe is rooted). If the bus_id is invalid, results are undefined. + virtual void AddGPUAllocVisitor(int bus_id, + const SubAllocator::Visitor& visitor); + + // Registers a Visitor to be invoked on new chunks of memory allocated by + // the SubAllocator of the CUDAHostAllocator for the given numa_node. + virtual void AddCUDAHostAllocVisitor(int numa_node, + const SubAllocator::Visitor& visitor); + + // Registers a Visitor to be invoked on each chunk handed back for freeing to + // the SubAllocator of the CUDAHostAllocator for the given numa_node. + virtual void AddCUDAHostFreeVisitor(int numa_node, + const SubAllocator::Visitor& visitor); + + // Returns bus_id for the given GPU id.
+ virtual int BusIdForGPU(TfGpuId tf_gpu_id); protected: GPUProcessState(); @@ -103,16 +114,21 @@ class GPUProcessState { mutex mu_; - std::vector<VisitableAllocator*> gpu_allocators_ GUARDED_BY(mu_); - std::vector<std::vector<AllocVisitor>> gpu_visitors_ GUARDED_BY(mu_); - std::vector<Allocator*> cuda_host_allocators_ GUARDED_BY(mu_); + struct AllocatorParts { + std::unique_ptr<Allocator> allocator; + SubAllocator* sub_allocator; // owned by allocator + std::unique_ptr<Allocator> recording_allocator; + }; + std::vector<AllocatorParts> gpu_allocators_ GUARDED_BY(mu_); + std::vector<std::vector<SubAllocator::Visitor>> gpu_visitors_ GUARDED_BY(mu_); - virtual ~GPUProcessState(); + std::vector<AllocatorParts> cuda_host_allocators_ GUARDED_BY(mu_); + std::vector<std::vector<SubAllocator::Visitor>> cuda_host_alloc_visitors_ + GUARDED_BY(mu_); + std::vector<std::vector<SubAllocator::Visitor>> cuda_host_free_visitors_ + GUARDED_BY(mu_); - // Optional RecordingAllocators that wrap the corresponding - // Allocators for runtime attribute use analysis. 
- std::vector<Allocator*> gpu_al_ GUARDED_BY(mu_); - std::vector<Allocator*> cuda_al_ GUARDED_BY(mu_); + virtual ~GPUProcessState(); friend class GPUDeviceTest; }; diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc index 583bff2c07..6b2f6547b0 100644 --- a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc @@ -31,7 +31,8 @@ TEST(PoolAllocatorTest, ZeroSizeBuffers) { 2 /*pool_size_limit*/, false /*auto_resize*/, new CUDAHostAllocator( platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .ValueOrDie()), + .ValueOrDie(), + 0 /*numa_node*/, {}, {}), new NoopRounder, "pool"); EXPECT_EQ(nullptr, pool.AllocateRaw(4 /*alignment*/, 0 /*num_bytes*/)); @@ -49,7 +50,8 @@ TEST(PoolAllocatorTest, ZeroSizePool) { 0 /*pool_size_limit*/, false /*auto_resize*/, new CUDAHostAllocator( platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .ValueOrDie()), + .ValueOrDie(), + 0 /*numa_node*/, {}, {}), new NoopRounder, "pool"); EXPECT_EQ(0, pool.get_from_pool_count()); @@ -82,7 +84,8 @@ TEST(PoolAllocatorTest, Alignment) { 0 /*pool_size_limit*/, false /*auto_resize*/, new CUDAHostAllocator( platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .ValueOrDie()), + .ValueOrDie(), + 0 /*numa_node*/, {}, {}), new NoopRounder, "pool"); for (int i = 0; i < 16; ++i) { size_t alignment = 1 << i; @@ -97,8 +100,8 @@ TEST(PoolAllocatorTest, Alignment) { TEST(PoolAllocatorTest, AutoResize) { PoolAllocator pool(2 /*pool_size_limit*/, true /*auto_resize*/, - new BasicCPUAllocator(0 /*numa_node*/), new NoopRounder, - "pool"); + new BasicCPUAllocator(0 /*numa_node*/, {}, {}), + new NoopRounder, "pool"); // Alloc/dealloc 10 sizes just a few times, confirming pool size // stays at 2. 
@@ -123,14 +126,32 @@ TEST(PoolAllocatorTest, AutoResize) { } TEST(PoolAllocatorTest, CudaHostAllocator) { + int alloc_count = 0; + int64 alloc_size = 0; + SubAllocator::Visitor alloc_visitor = + [&alloc_count, &alloc_size](void* ptr, int numa_node, int64 size) { + ++alloc_count; + alloc_size += size; + }; + int free_count = 0; + int64 free_size = 0; + SubAllocator::Visitor free_visitor = + [&free_count, &free_size](void* ptr, int numa_node, int64 size) { + ++free_count; + free_size += size; + }; se::Platform* platform = se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie(); - PoolAllocator pool( - 2 /*pool_size_limit*/, false /*auto_resize*/, - new CUDAHostAllocator( - platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .ValueOrDie()), - new NoopRounder, "pool"); + CUDAHostAllocator* sub_allocator = new CUDAHostAllocator( + platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) + .ValueOrDie(), + 0 /*numa_node*/, {alloc_visitor}, {free_visitor}); + PoolAllocator pool(2 /*pool_size_limit*/, false /*auto_resize*/, + sub_allocator, new NoopRounder, "pool"); + EXPECT_EQ(0, alloc_count); + EXPECT_EQ(0, alloc_size); + EXPECT_EQ(0, free_count); + EXPECT_EQ(0, free_size); // Repeatedly Get a 16-byte value, confirming that there's only // one real allocation. @@ -138,6 +159,10 @@ TEST(PoolAllocatorTest, CudaHostAllocator) { EXPECT_EQ(0, pool.get_from_pool_count()); EXPECT_EQ(1, pool.allocated_count()); EXPECT_NE(nullptr, p1_16); + EXPECT_EQ(1, alloc_count); // Underlying suballoc of 16 bytes + // Each suballocation includes a 16B ChunkPrefix. + static const int kChunkPrefixSize = 16; + EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size); pool.DeallocateRaw(p1_16); // Pool contents {16} EXPECT_EQ(1, pool.put_count()); @@ -148,6 +173,9 @@ TEST(PoolAllocatorTest, CudaHostAllocator) { pool.DeallocateRaw(p2_16); // Put it back. 
// Pool contents {16} EXPECT_EQ(2, pool.put_count()); + EXPECT_EQ(1, alloc_count); // Underlying suballoc of 16 bytes + EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size); + EXPECT_EQ(0, free_count); // Get two more values of different sizes. void* p3_4 = pool.AllocateRaw(4, 4); @@ -160,6 +188,9 @@ TEST(PoolAllocatorTest, CudaHostAllocator) { void* p4_2 = pool.AllocateRaw(4, 2); // Get a third size buffer. EXPECT_NE(nullptr, p4_2); EXPECT_EQ(0, pool.evicted_count()); + EXPECT_EQ(3, alloc_count); + EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size); + EXPECT_EQ(0, free_count); // The pool is full: when we put back p4_2, the 16-byte buffer // should be evicted since it was least recently inserted. @@ -167,6 +198,10 @@ TEST(PoolAllocatorTest, CudaHostAllocator) { // Pool contents {2, 4} EXPECT_EQ(4, pool.put_count()); EXPECT_EQ(1, pool.evicted_count()); + EXPECT_EQ(3, alloc_count); + EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size); + EXPECT_EQ(1, free_count); + EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size); // Re-getting and putting size 2 or 4 should not alter pool size or // num-evicted. 
@@ -180,12 +215,20 @@ TEST(PoolAllocatorTest, CudaHostAllocator) { EXPECT_EQ(6, pool.put_count()); EXPECT_EQ(3, pool.allocated_count()); EXPECT_EQ(1, pool.evicted_count()); + EXPECT_EQ(3, alloc_count); + EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size); + EXPECT_EQ(1, free_count); + EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size); pool.Clear(); EXPECT_EQ(0, pool.get_from_pool_count()); EXPECT_EQ(0, pool.put_count()); EXPECT_EQ(0, pool.allocated_count()); EXPECT_EQ(0, pool.evicted_count()); + EXPECT_EQ(3, alloc_count); + EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size); + EXPECT_EQ(3, free_count); + EXPECT_EQ(16 + 4 + 2 + (free_count * kChunkPrefixSize), free_size); } TEST(PoolAllocatorTest, Pow2Rounder) { @@ -206,7 +249,8 @@ TEST(PoolAllocatorTest, Name) { 2 /*pool_size_limit*/, false /*auto_resize*/, new CUDAHostAllocator( platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .ValueOrDie()), + .ValueOrDie(), + 0 /*numa_node*/, {}, {}), new NoopRounder, "pool"); EXPECT_EQ("pool", pool.Name()); } diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 7f260b3139..4475fa979e 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -561,6 +561,10 @@ Status GraphExecutionState::OptimizeGraph( grappler::GrapplerItem item; item.id = "tf_graph"; graph_->ToGraphDef(&item.graph); + // TODO(b/114748242): Add a unit test to test this bug fix. 
+ if (flib_def_) { + *item.graph.mutable_library() = flib_def_->ToProto(); + } item.fetch.insert(item.fetch.end(), options.callable_options.fetch().begin(), diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc index 96ecfb41d4..37a979a8f1 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.cc +++ b/tensorflow/core/common_runtime/graph_optimizer.cc @@ -38,7 +38,8 @@ void GraphOptimizer::Optimize( std::unique_ptr<Graph>* graph, const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map, - const std::function<bool(const Node*)>& cse_consider_fn) { + const std::function<bool(const Node*)>& cse_consider_fn, + const std::function<bool(const Node*)>& cf_consider_fn) { Graph* g = graph->get(); DumpGraph("Initial", g); @@ -62,6 +63,7 @@ void GraphOptimizer::Optimize( if (opts_.do_constant_folding()) { ConstantFoldingOptions cf_opts; cf_opts.shape_map = shape_map; + cf_opts.consider = cf_consider_fn; if (opts_.max_folded_constant_in_bytes() > 0) { cf_opts.max_constant_size_in_bytes = opts_.max_folded_constant_in_bytes(); diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index 80246281cd..789cc56942 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -45,12 +45,15 @@ class GraphOptimizer { // // If cse_consider_fn is not null then only nodes for which cse_consider_fn // returns true will be considered for CSE. + // If cf_consider_fn is not null then only nodes for which cf_consider_fn + // returns true will be considered for CF. 
void Optimize( FunctionLibraryRuntime* runtime, Env* env, Device* device, std::unique_ptr<Graph>* graph, const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map, - const std::function<bool(const Node*)>& cse_consider_fn = nullptr); + const std::function<bool(const Node*)>& cse_consider_fn = nullptr, + const std::function<bool(const Node*)>& cf_consider_fn = nullptr); const OptimizerOptions& options() { return opts_; } diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 0a1797fa19..f9aef3af70 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -56,7 +56,7 @@ class SimpleRendezvous : public Rendezvous { } mutex_lock l(mu_); - string edge_name = std::string(parsed.edge_name); + string edge_name(parsed.edge_name); if (table_.count(edge_name) > 0) { return errors::Internal("Send of an already sent tensor"); } @@ -69,7 +69,7 @@ class SimpleRendezvous : public Rendezvous { Tensor tensor; Status status = Status::OK(); { - string key = std::string(parsed.edge_name); + string key(parsed.edge_name); mutex_lock l(mu_); if (table_.count(key) <= 0) { status = errors::Internal("Did not find key ", key); diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc index db5022d56e..873182371e 100644 --- a/tensorflow/core/common_runtime/local_device.cc +++ b/tensorflow/core/common_runtime/local_device.cc @@ -62,7 +62,7 @@ struct LocalDevice::EigenThreadPoolInfo { LocalDevice::LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes) - : TracingDevice(options.env, attributes), owned_tp_info_(nullptr) { + : Device(options.env, attributes), owned_tp_info_(nullptr) { // Log info messages if TensorFlow is not compiled with instructions that // could speed up performance and are available on the current CPU. 
port::InfoAboutUnusedCPUFeatures(); diff --git a/tensorflow/core/common_runtime/local_device.h b/tensorflow/core/common_runtime/local_device.h index 9a82fb7204..226f121bf3 100644 --- a/tensorflow/core/common_runtime/local_device.h +++ b/tensorflow/core/common_runtime/local_device.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_DEVICE_H_ #include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/common_runtime/tracing_device.h" #include "tensorflow/core/framework/device_attributes.pb.h" #include "tensorflow/core/platform/macros.h" @@ -32,7 +31,7 @@ struct SessionOptions; // initializes a shared Eigen compute device used by both. This // should eventually be removed once we refactor ThreadPoolDevice and // GPUDevice into more 'process-wide' abstractions. -class LocalDevice : public TracingDevice { +class LocalDevice : public Device { public: LocalDevice(const SessionOptions& options, const DeviceAttributes& attributes); diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index 6b76e7e0e7..538a70668a 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -23,10 +23,11 @@ limitations under the License. 
#include <cstdlib> #include "tensorflow/core/common_runtime/bfc_allocator.h" -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/common_runtime/pool_allocator.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/numa.h" #ifndef INTEL_MKL_DNN_ONLY #include "i_malloc.h" @@ -38,19 +39,113 @@ typedef unsigned int uint; namespace tensorflow { -class MklSubAllocator : public SubAllocator { +class MklSubAllocator : public BasicCPUAllocator { public: + MklSubAllocator() : BasicCPUAllocator(port::kNUMANoAffinity, {}, {}) {} ~MklSubAllocator() override {} +}; + +// CPU allocator that handles small-size allocations by calling +// suballocator directly. Mostly, it is just a wrapper around a suballocator +// (that calls malloc and free directly) with support for bookkeeping. +class MklSmallSizeAllocator : public Allocator { + public: + MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory, + const string& name) + : sub_allocator_(sub_allocator), name_(name) { + stats_.bytes_limit = total_memory; + } + ~MklSmallSizeAllocator() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(MklSmallSizeAllocator); + + inline string Name() override { return name_; } + + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + void* ptr = sub_allocator_->Alloc(alignment, num_bytes); + if (ptr != nullptr) { + std::pair<void*, size_t> map_val(ptr, num_bytes); + mutex_lock l(mutex_); + // Check that insertion in the hash map was successful. + CHECK(map_.insert(map_val).second); + // Increment statistics for small-size allocations. 
+ IncrementStats(num_bytes); + } + return ptr; + } + + void DeallocateRaw(void* ptr) override { + if (ptr == nullptr) { + LOG(ERROR) << "tried to deallocate nullptr"; + return; + } + + mutex_lock l(mutex_); + auto map_iter = map_.find(ptr); + if (map_iter != map_.end()) { + // Call free visitors. + size_t dealloc_bytes = map_iter->second; + sub_allocator_->Free(ptr, dealloc_bytes); + DecrementStats(dealloc_bytes); + map_.erase(map_iter); + } else { + LOG(ERROR) << "tried to deallocate invalid pointer"; + return; + } + } + + inline bool IsSmallSizeAllocation(const void* ptr) const { + mutex_lock l(mutex_); + return map_.find(ptr) != map_.end(); + } + + void GetStats(AllocatorStats* stats) override { + mutex_lock l(mutex_); + *stats = stats_; + } - void* Alloc(size_t alignment, size_t num_bytes) override { - return port::AlignedMalloc(num_bytes, alignment); + void ClearStats() override { + mutex_lock l(mutex_); + stats_.Clear(); } - void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); } + + private: + // Increment statistics for the allocator handling small allocations. + inline void IncrementStats(size_t alloc_size) + EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + ++stats_.num_allocs; + stats_.bytes_in_use += alloc_size; + stats_.max_bytes_in_use = + std::max(stats_.max_bytes_in_use, stats_.bytes_in_use); + stats_.max_alloc_size = + std::max(alloc_size, static_cast<size_t>(stats_.max_alloc_size)); + } + + // Decrement statistics for the allocator handling small allocations. + inline void DecrementStats(size_t dealloc_size) + EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + stats_.bytes_in_use -= dealloc_size; + } + + SubAllocator* sub_allocator_; // Not owned by this class. + + // Mutex for protecting updates to map of allocations. + mutable mutex mutex_; + + // Allocator name + string name_; + + // Hash map to keep track of "small" allocations + // We do not use BFC allocator for small allocations. 
+ std::unordered_map<const void*, size_t> map_ GUARDED_BY(mutex_); + + // Allocator stats for small allocs + AllocatorStats stats_ GUARDED_BY(mutex_); }; /// CPU allocator for MKL that wraps BFC allocator and intercepts /// and redirects memory allocation calls from MKL. -class MklCPUAllocator : public VisitableAllocator { +class MklCPUAllocator : public Allocator { public: // Constructor and other standard functions @@ -62,7 +157,10 @@ class MklCPUAllocator : public VisitableAllocator { MklCPUAllocator() { TF_CHECK_OK(Initialize()); } - ~MklCPUAllocator() override { delete allocator_; } + ~MklCPUAllocator() override { + delete small_size_allocator_; + delete large_size_allocator_; + } Status Initialize() { VLOG(2) << "MklCPUAllocator: In MklCPUAllocator"; @@ -96,8 +194,15 @@ class MklCPUAllocator : public VisitableAllocator { } VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes; - allocator_ = new BFCAllocator(new MklSubAllocator, max_mem_bytes, - kAllowGrowth, kName); + + sub_allocator_ = new MklSubAllocator(); + + // SubAllocator is owned by BFCAllocator, so we do not need to deallocate + // it in MklSmallSizeAllocator. + small_size_allocator_ = + new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName); + large_size_allocator_ = + new BFCAllocator(sub_allocator_, max_mem_bytes, kAllowGrowth, kName); #ifndef INTEL_MKL_DNN_ONLY // For redirecting all allocations from MKL to this allocator // From: http://software.intel.com/en-us/node/528565 @@ -112,23 +217,45 @@ class MklCPUAllocator : public VisitableAllocator { inline string Name() override { return kName; } inline void* AllocateRaw(size_t alignment, size_t num_bytes) override { - return allocator_->AllocateRaw(alignment, num_bytes); + // If the allocation size is less than threshold, call small allocator, + // otherwise call large-size allocator (BFC). 
We found that BFC allocator + // does not deliver good performance for small allocations when + // inter_op_parallelism_threads is high. + return (num_bytes < kSmallAllocationsThreshold) + ? small_size_allocator_->AllocateRaw(alignment, num_bytes) + : large_size_allocator_->AllocateRaw(alignment, num_bytes); } inline void DeallocateRaw(void* ptr) override { - allocator_->DeallocateRaw(ptr); + // Check if ptr is for "small" allocation. If it is, then call Free + // directly. Otherwise, call BFC to handle free. + if (small_size_allocator_->IsSmallSizeAllocation(ptr)) { + small_size_allocator_->DeallocateRaw(ptr); + } else { + large_size_allocator_->DeallocateRaw(ptr); + } } - void GetStats(AllocatorStats* stats) override { allocator_->GetStats(stats); } - - void ClearStats() override { allocator_->ClearStats(); } - - void AddAllocVisitor(Visitor visitor) override { - allocator_->AddAllocVisitor(visitor); + void GetStats(AllocatorStats* stats) override { + AllocatorStats l_stats, s_stats; + small_size_allocator_->GetStats(&s_stats); + large_size_allocator_->GetStats(&l_stats); + + // Combine statistics from small-size and large-size allocator. + stats->num_allocs = l_stats.num_allocs + s_stats.num_allocs; + stats->bytes_in_use = l_stats.bytes_in_use + s_stats.bytes_in_use; + stats->max_bytes_in_use = + l_stats.max_bytes_in_use + s_stats.max_bytes_in_use; + + // Since small-size allocations go to MklSmallSizeAllocator, + // max_alloc_size from large_size_allocator would be the maximum + // size allocated by MklCPUAllocator. 
+ stats->max_alloc_size = l_stats.max_alloc_size; } - void AddFreeVisitor(Visitor visitor) override { - allocator_->AddFreeVisitor(visitor); + void ClearStats() override { + small_size_allocator_->ClearStats(); + large_size_allocator_->ClearStats(); } private: @@ -148,26 +275,33 @@ class MklCPUAllocator : public VisitableAllocator { Status s = Status(error::Code::UNIMPLEMENTED, "Unimplemented case for hooking MKL function."); TF_CHECK_OK(s); // way to assert with an error message - return nullptr; // return a value and make static code analyzers happy + return nullptr; // return a value and make static code analyzers happy } static inline void* ReallocHook(void* ptr, size_t size) { Status s = Status(error::Code::UNIMPLEMENTED, "Unimplemented case for hooking MKL function."); TF_CHECK_OK(s); // way to assert with an error message - return nullptr; // return a value and make static code analyzers happy + return nullptr; // return a value and make static code analyzers happy } - /// Do we allow growth in BFC Allocator + // Do we allow growth in BFC Allocator static const bool kAllowGrowth = true; - /// Name + // Name static constexpr const char* kName = "mklcpu"; - /// The alignment that we need for the allocations + // The alignment that we need for the allocations static constexpr const size_t kAlignment = 64; - VisitableAllocator* allocator_; // owned by this class + Allocator* large_size_allocator_; // owned by this class + MklSmallSizeAllocator* small_size_allocator_; // owned by this class. + + SubAllocator* sub_allocator_; // not owned by this class + + // Size in bytes that defines the upper-bound for "small" allocations. + // Any allocation below this threshold is "small" allocation. 
+ static constexpr const size_t kSmallAllocationsThreshold = 4096; // Prevent copying and assignment TF_DISALLOW_COPY_AND_ASSIGN(MklCPUAllocator); diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc index a67411cd2e..e08ab57638 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #include "tensorflow/core/common_runtime/mkl_cpu_allocator.h" @@ -50,4 +50,4 @@ TEST(MKLBFCAllocatorTest, TestMaxLimit) { } // namespace tensorflow -#endif // INTEL_MKL +#endif // INTEL_MKL && ENABLE_MKL diff --git a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc index f9f36443a8..6af4ca4d96 100644 --- a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc +++ b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc @@ -50,8 +50,8 @@ class ParallelConcatRemovePass : public GraphOptimizationPass { } for (Node* n : matches) { AttrSlice n_attrs = n->attrs(); - auto base_make_node = [n, g, &n_attrs](const string& op, - const string& name) { + auto base_make_node = [n, &n_attrs](const string& op, + const string& name) { NodeBuilder node_builder(name, op); node_builder.Device(n->requested_device()); string colo; @@ -60,7 +60,7 @@ class ParallelConcatRemovePass : public GraphOptimizationPass { } return node_builder; }; - auto make_node = [n, g, &n_attrs, &base_make_node](string op) { + auto make_node = [n, g, &base_make_node](string op) { return base_make_node( op, g->NewName(strings::StrCat(n->name(), "/Internal"))); }; diff --git a/tensorflow/core/common_runtime/placer.cc 
b/tensorflow/core/common_runtime/placer.cc index 7f3c25d81d..3b59995433 100644 --- a/tensorflow/core/common_runtime/placer.cc +++ b/tensorflow/core/common_runtime/placer.cc @@ -254,9 +254,11 @@ class ColocationGraph { old_root_member.device_name, allow_soft_placement_); if (!s.ok()) { - return errors::InvalidArgument("Cannot colocate nodes '", x.name(), - "' and '", y.name(), ": ", - s.error_message()); + return errors::InvalidArgument( + "Cannot colocate nodes ", + errors::FormatColocationNodeForError(x.name()), " and ", + errors::FormatColocationNodeForError(y.name()), ": ", + s.error_message()); } // Ensure that the common root has at least one supported device @@ -267,8 +269,10 @@ class ColocationGraph { old_root_member.supported_device_types); if (new_root_member.supported_device_types.empty()) { return errors::InvalidArgument( - "Cannot colocate nodes '", x.name(), "' and '", y.name(), - "' because no device type supports both of those nodes and the " + "Cannot colocate nodes ", + errors::FormatColocationNodeForError(x.name()), " and ", + errors::FormatColocationNodeForError(y.name()), + " because no device type supports both of those nodes and the " "other nodes colocated with them.", DebugInfo(x_root), DebugInfo(y_root)); } @@ -376,8 +380,9 @@ class ColocationGraph { // merged set device is different, so print both. 
return errors::InvalidArgument( "Could not satisfy explicit device specification '", - node->requested_device(), - "' because the node was colocated with a group of nodes that " + node->requested_device(), "' because the node ", + errors::FormatColocationNodeForError(node->name()), + " was colocated with a group of nodes that ", "required incompatible device '", DeviceNameUtils::ParsedNameToString( members_[node_root].device_name), @@ -809,10 +814,10 @@ Status Placer::Run() { std::vector<Device*>* devices; Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { - return AttachDef(errors::InvalidArgument( - "Cannot assign a device for operation ", - RichNodeName(node), ": ", status.error_message()), - *node); + return AttachDef( + errors::InvalidArgument("Cannot assign a device for operation ", + node->name(), ": ", status.error_message()), + *node); } // Returns the first device in sorted devices list so we will always @@ -856,10 +861,10 @@ Status Placer::Run() { std::vector<Device*>* devices; Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { - return AttachDef(errors::InvalidArgument( - "Cannot assign a device for operation ", - RichNodeName(node), ": ", status.error_message()), - *node); + return AttachDef( + errors::InvalidArgument("Cannot assign a device for operation ", + node->name(), ": ", status.error_message()), + *node); } int assigned_device = -1; @@ -925,21 +930,4 @@ void Placer::LogDeviceAssignment(const Node* node) const { } } -bool Placer::ClientHandlesErrorFormatting() const { - return options_ != nullptr && - options_->config.experimental().client_handles_error_formatting(); -} - -// Returns the node name in single quotes. If the client handles formatted -// errors, appends a formatting tag which the client will reformat into, for -// example, " (defined at filename:123)". -// TODO(shikharagarwal): Remove this function once -// client_handles_error_formatting flag is removed. 
-string Placer::RichNodeName(const Node* node) const { - if (ClientHandlesErrorFormatting()) { - return errors::FormatNodeNameForError(node->name()); - } - return strings::StrCat("'", node->name(), "'"); -} - } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h index cefcdd25db..f97ffe7372 100644 --- a/tensorflow/core/common_runtime/placer.h +++ b/tensorflow/core/common_runtime/placer.h @@ -87,8 +87,6 @@ class Placer { // placement if the SessionOptions entry in 'options_' requests it. void AssignAndLog(int assigned_device, Node* node) const; void LogDeviceAssignment(const Node* node) const; - bool ClientHandlesErrorFormatting() const; - string RichNodeName(const Node* node) const; Graph* const graph_; // Not owned. const DeviceSet* const devices_; // Not owned. diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 83d27e2730..9b8a95e3b6 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -800,11 +800,11 @@ TEST_F(PlacerTest, TestInvalidMultipleColocationGroups) { } Status s = Place(&g); - EXPECT_TRUE( - str_util::StrContains(s.error_message(), - "Cannot colocate nodes 'foo' and 'in' because no " - "device type supports both of those nodes and the " - "other nodes colocated with them")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), + "Cannot colocate nodes {{colocation_node foo}} and " + "{{colocation_node in}} because no device type supports both of those " + "nodes and the other nodes colocated with them")); } TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) { @@ -867,9 +867,9 @@ TEST_F(PlacerTest, TestColocationGroupWithUnsatisfiableReferenceConnections) { Status s = Place(&g); EXPECT_TRUE(str_util::StrContains( s.error_message(), - "Cannot colocate nodes 'var3' and 'assign3' because no " - "device type supports both of those nodes and the other " - 
"nodes colocated with them.")); + "Cannot colocate nodes {{colocation_node var3}} and {{colocation_node " + "assign3}} because no device type supports both of those nodes and the " + "other nodes colocated with them.")); } TEST_F(PlacerTest, TestColocationAndReferenceConnections) { @@ -1154,35 +1154,12 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) { } SessionOptions options; - options.config.mutable_experimental()->set_client_handles_error_formatting( - true); Status s = Place(&g, &options); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); LOG(WARNING) << s.error_message(); - EXPECT_TRUE(str_util::StrContains( - s.error_message(), "Cannot assign a device for operation {{node in}}")); -} - -// Test that the "Cannot assign a device" error message does not contain a -// format tag when not it shouldn't -TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementNoFormatTag) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. - GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestDevice", - b.opts().WithName("in").WithDevice("/device:fakegpu:11")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - options.config.mutable_experimental()->set_client_handles_error_formatting( - false); - Status s = Place(&g, &options); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(str_util::StrContains( - s.error_message(), "Cannot assign a device for operation 'in'")); - EXPECT_FALSE(str_util::StrContains( - s.error_message(), "'in' (defined at ^^node:in:${file}:${line}^^)")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), + "Cannot assign a device for operation in")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "{{node in}}")); } // Test that placement fails when a node requests an explicit device that is not @@ -1288,8 +1265,9 @@ TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) { Status s = Place(&g); 
EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(str_util::StrContains( - s.error_message(), "Cannot colocate nodes 'var' and 'assign'")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), + "Cannot colocate nodes {{colocation_node " + "var}} and {{colocation_node assign}}")); } // Test that a generator node follows its consumers (where there are several diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc index 10a24ed14c..66dc8f3322 100644 --- a/tensorflow/core/common_runtime/pool_allocator.cc +++ b/tensorflow/core/common_runtime/pool_allocator.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -39,8 +40,7 @@ PoolAllocator::PoolAllocator(size_t pool_size_limit, bool auto_resize, auto_resize_(auto_resize), pool_size_limit_(pool_size_limit), allocator_(allocator), - size_rounder_(size_rounder), - allocation_begun_(false) { + size_rounder_(size_rounder) { if (auto_resize) { CHECK_LT(size_t{0}, pool_size_limit) << "size limit must be > 0 if auto_resize is true."; @@ -92,7 +92,6 @@ ChunkPrefix* FindPrefix(void* user_ptr) { } // namespace void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { - if (!allocation_begun_) allocation_begun_ = true; if (num_bytes == 0) return nullptr; // If alignment is larger than kPoolAlignment, increase num_bytes so that we @@ -128,9 +127,6 @@ void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { return PrepareChunk(r, alignment, num_bytes); } else { void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes); - for (const auto& v : alloc_visitors_) { - v(ptr, num_bytes); - } return PrepareChunk(ptr, alignment, num_bytes); } } @@ -140,9 +136,6 @@ void PoolAllocator::DeallocateRaw(void* ptr) { ChunkPrefix* cp 
= FindPrefix(ptr); CHECK_LE((void*)cp, (void*)ptr); if (!has_size_limit_ && !auto_resize_) { - for (const auto& v : free_visitors_) { - v(cp, cp->num_bytes); - } allocator_->Free(cp, cp->num_bytes); } else { mutex_lock lock(mutex_); @@ -163,9 +156,6 @@ void PoolAllocator::Clear() { mutex_lock lock(mutex_); for (auto iter : pool_) { PtrRecord* pr = iter.second; - for (const auto& v : free_visitors_) { - v(pr->ptr, pr->num_bytes); - } allocator_->Free(pr->ptr, pr->num_bytes); delete pr; } @@ -220,9 +210,6 @@ void PoolAllocator::EvictOne() { DCHECK(iter != pool_.end()); } pool_.erase(iter); - for (const auto& v : free_visitors_) { - v(prec->ptr, prec->num_bytes); - } allocator_->Free(prec->ptr, prec->num_bytes); delete prec; ++evicted_count_; @@ -268,28 +255,19 @@ void PoolAllocator::EvictOne() { } } -void PoolAllocator::AddAllocVisitor(Visitor visitor) { - mutex_lock lock(mutex_); - CHECK(!allocation_begun_) - << "AddAllocVisitor may not be called after pool allocation " - << "has begun."; - alloc_visitors_.push_back(visitor); -} - -void PoolAllocator::AddFreeVisitor(Visitor visitor) { - mutex_lock lock(mutex_); - CHECK(!allocation_begun_) - << "AddFreeVisitor may not be called after pool allocation " - << "has begun."; - free_visitors_.push_back(visitor); -} - void* BasicCPUAllocator::Alloc(size_t alignment, size_t num_bytes) { - return port::AlignedMalloc(num_bytes, static_cast<int>(alignment)); + void* ptr = nullptr; + if (num_bytes > 0) { + ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment)); + VisitAlloc(ptr, numa_node_, num_bytes); + } + return ptr; } void BasicCPUAllocator::Free(void* ptr, size_t num_bytes) { - port::AlignedFree(ptr); + if (num_bytes > 0) { + VisitFree(ptr, numa_node_, num_bytes); + port::AlignedFree(ptr); + } } - } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/pool_allocator.h b/tensorflow/core/common_runtime/pool_allocator.h index 607734445b..5b4623ba10 100644 --- 
a/tensorflow/core/common_runtime/pool_allocator.h +++ b/tensorflow/core/common_runtime/pool_allocator.h @@ -16,14 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_ -// Simple LRU pool allocators for various flavors of CPU RAM that -// implement the VisitableAllocator interface. +// Simple LRU pool allocators for various flavors of CPU RAM. #include <atomic> #include <map> #include <memory> #include <vector> -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -41,7 +40,7 @@ class RoundUpInterface { // Size-limited pool of memory buffers obtained from a SubAllocator // instance. Pool eviction policy is LRU. -class PoolAllocator : public VisitableAllocator { +class PoolAllocator : public Allocator { public: // "pool_size_limit" is the maximum number of returned, re-usable // memory buffers to keep in the pool. If pool_size_limit == 0, the @@ -64,14 +63,6 @@ class PoolAllocator : public VisitableAllocator { void DeallocateRaw(void* ptr) override; - // REQUIRES: The following functions may only be called prior - // to the first Allocate*() call. Once allocation has begun, it is - // illegal to register another visitor. - - void AddAllocVisitor(Visitor visitor) override; - - void AddFreeVisitor(Visitor visitor) override; - // Allocate an unused memory region of size "num_bytes". Fetch from // the pool if available, otherwise call allocator_. void* Get(size_t num_bytes); @@ -141,12 +132,6 @@ class PoolAllocator : public VisitableAllocator { int64 put_count_ GUARDED_BY(mutex_) = 0; int64 allocated_count_ GUARDED_BY(mutex_) = 0; int64 evicted_count_ GUARDED_BY(mutex_) = 0; - // Write access to these is guarded by mutex_, but not read - // access. 
They may only be modified prior to the first - // allocation. Later attempts to modify will fail. - std::vector<Visitor> alloc_visitors_; - std::vector<Visitor> free_visitors_; - std::atomic<bool> allocation_begun_; }; // Do-nothing rounder. Passes through sizes unchanged. @@ -166,7 +151,9 @@ class Pow2Rounder : public RoundUpInterface { class BasicCPUAllocator : public SubAllocator { public: // Argument numa_node is currently ignored. - explicit BasicCPUAllocator(int numa_node) : numa_node_(numa_node) {} + BasicCPUAllocator(int numa_node, const std::vector<Visitor>& alloc_visitors, + const std::vector<Visitor>& free_visitors) + : SubAllocator(alloc_visitors, free_visitors), numa_node_(numa_node) {} ~BasicCPUAllocator() override {} @@ -176,6 +163,8 @@ class BasicCPUAllocator : public SubAllocator { private: int numa_node_; + + TF_DISALLOW_COPY_AND_ASSIGN(BasicCPUAllocator); }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc index 447338e7bd..bcaa37fc8a 100644 --- a/tensorflow/core/common_runtime/process_state.cc +++ b/tensorflow/core/common_runtime/process_state.cc @@ -71,20 +71,28 @@ ProcessState::MemDesc ProcessState::PtrType(const void* ptr) { return MemDesc(); } -VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) { +Allocator* ProcessState::GetCPUAllocator(int numa_node) { CHECK_GE(numa_node, 0); if (!numa_enabled_) numa_node = 0; mutex_lock lock(mu_); while (cpu_allocators_.size() <= static_cast<size_t>(numa_node)) { + // If visitors have been defined we need an Allocator built from + // a SubAllocator. Prefer BFCAllocator, but fall back to PoolAllocator + // depending on env var setting. + const bool alloc_visitors_defined = + (!cpu_alloc_visitors_.empty() || !cpu_free_visitors_.empty()); bool use_bfc_allocator = false; - // TODO(reedwm): Switch default to BGFAllocator if it's at least as fast and - // efficient. 
- Status status = ReadBoolFromEnvVar("TF_CPU_ALLOCATOR_USE_BFC", false, - &use_bfc_allocator); + Status status = ReadBoolFromEnvVar( + "TF_CPU_ALLOCATOR_USE_BFC", alloc_visitors_defined, &use_bfc_allocator); if (!status.ok()) { LOG(ERROR) << "GetCPUAllocator: " << status.error_message(); } - VisitableAllocator* allocator; + Allocator* allocator = nullptr; + SubAllocator* sub_allocator = + (alloc_visitors_defined || use_bfc_allocator) + ? new BasicCPUAllocator(numa_enabled_ ? numa_node : -1, + cpu_alloc_visitors_, cpu_free_visitors_) + : nullptr; if (use_bfc_allocator) { // TODO(reedwm): evaluate whether 64GB by default is the best choice. int64 cpu_mem_limit_in_mb = -1; @@ -95,34 +103,63 @@ VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) { LOG(ERROR) << "GetCPUAllocator: " << status.error_message(); } int64 cpu_mem_limit = cpu_mem_limit_in_mb * (1LL << 20); - allocator = new BFCAllocator( - new BasicCPUAllocator(numa_enabled_ ? numa_node : -1), cpu_mem_limit, - true /*allow_growth*/, "bfc_cpu_allocator_for_gpu" /*name*/); + DCHECK(sub_allocator); + allocator = + new BFCAllocator(sub_allocator, cpu_mem_limit, true /*allow_growth*/, + "bfc_cpu_allocator_for_gpu" /*name*/); VLOG(2) << "Using BFCAllocator with memory limit of " << cpu_mem_limit_in_mb << " MB for ProcessState CPU allocator"; - } else { - allocator = new PoolAllocator( - 100 /*pool_size_limit*/, true /*auto_resize*/, - new BasicCPUAllocator(numa_enabled_ ? 
numa_node : -1), - new NoopRounder, "cpu_pool"); + } else if (alloc_visitors_defined) { + DCHECK(sub_allocator); + allocator = + new PoolAllocator(100 /*pool_size_limit*/, true /*auto_resize*/, + sub_allocator, new NoopRounder, "cpu_pool"); VLOG(2) << "Using PoolAllocator for ProcessState CPU allocator " << "numa_enabled_=" << numa_enabled_ << " numa_node=" << numa_node; + } else { + DCHECK(!sub_allocator); + allocator = cpu_allocator(); } - if (LogMemory::IsEnabled()) { + if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) { // Wrap the allocator to track allocation ids for better logging // at the cost of performance. - allocator = new TrackingVisitableAllocator(allocator, true); + allocator = new TrackingAllocator(allocator, true); } cpu_allocators_.push_back(allocator); + if (!sub_allocator) { + DCHECK(cpu_alloc_visitors_.empty() && cpu_free_visitors_.empty()); + } } return cpu_allocators_[numa_node]; } +void ProcessState::AddCPUAllocVisitor(SubAllocator::Visitor visitor) { + VLOG(1) << "AddCPUAllocVisitor"; + mutex_lock lock(mu_); + CHECK_EQ(0, cpu_allocators_.size()) // Crash OK + << "AddCPUAllocVisitor must be called prior to first call to " + "ProcessState::GetCPUAllocator"; + cpu_alloc_visitors_.push_back(std::move(visitor)); +} + +void ProcessState::AddCPUFreeVisitor(SubAllocator::Visitor visitor) { + mutex_lock lock(mu_); + CHECK_EQ(0, cpu_allocators_.size()) // Crash OK + << "AddCPUFreeVisitor must be called prior to first call to " + "ProcessState::GetCPUAllocator"; + cpu_free_visitors_.push_back(std::move(visitor)); +} + void ProcessState::TestOnlyReset() { mutex_lock lock(mu_); + // Don't delete this value because it's static. 
+ Allocator* default_cpu_allocator = cpu_allocator(); mem_desc_map_.clear(); - gtl::STLDeleteElements(&cpu_allocators_); + for (Allocator* a : cpu_allocators_) { + if (a != default_cpu_allocator) delete a; + } + cpu_allocators_.clear(); gtl::STLDeleteElements(&cpu_al_); } diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h index 2892677333..cac312d849 100644 --- a/tensorflow/core/common_runtime/process_state.h +++ b/tensorflow/core/common_runtime/process_state.h @@ -30,7 +30,6 @@ limitations under the License. namespace tensorflow { class Allocator; -class VisitableAllocator; class PoolAllocator; // Singleton that manages per-process state, e.g. allocation of @@ -65,7 +64,15 @@ class ProcessState { // Returns the one CPUAllocator used for the given numa_node. // TEMPORARY: ignores numa_node. - VisitableAllocator* GetCPUAllocator(int numa_node); + Allocator* GetCPUAllocator(int numa_node); + + // Registers alloc visitor for the CPU allocator(s). + // REQUIRES: must be called before GetCPUAllocator. + void AddCPUAllocVisitor(SubAllocator::Visitor v); + + // Registers free visitor for the CPU allocator(s). + // REQUIRES: must be called before GetCPUAllocator. 
+ void AddCPUFreeVisitor(SubAllocator::Visitor v); typedef std::unordered_map<const void*, MemDesc> MDMap; @@ -87,7 +94,9 @@ class ProcessState { mutex mu_; - std::vector<VisitableAllocator*> cpu_allocators_ GUARDED_BY(mu_); + std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_); + std::vector<SubAllocator::Visitor> cpu_alloc_visitors_ GUARDED_BY(mu_); + std::vector<SubAllocator::Visitor> cpu_free_visitors_ GUARDED_BY(mu_); virtual ~ProcessState(); diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h index 103eee03b3..c00789a556 100644 --- a/tensorflow/core/common_runtime/renamed_device.h +++ b/tensorflow/core/common_runtime/renamed_device.h @@ -58,6 +58,15 @@ class RenamedDevice : public Device { return underlying_->GetAllocator(attr); } + Allocator* GetScopedAllocator(AllocatorAttributes attr, + int64 step_id) override { + return underlying_->GetScopedAllocator(attr, step_id); + } + + ScopedAllocatorMgr* GetScopedAllocatorMgr() const override { + return underlying_->GetScopedAllocatorMgr(); + } + const Eigen::ThreadPoolDevice* eigen_cpu_device() override { return underlying_->eigen_cpu_device(); } @@ -72,9 +81,10 @@ class RenamedDevice : public Device { return underlying_->MakeGpuDevice(); } - void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, - DeviceContext* dc, Allocator* allocator) override { - underlying_->ReinitializeGpuDevice(context, device, dc, allocator); + Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, + DeviceContext* dc, + Allocator* allocator) override { + return underlying_->ReinitializeGpuDevice(context, device, dc, allocator); } Status MakeTensorFromProto(const TensorProto& tensor_proto, diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc index 1e3fed0d6f..43ca3f1e3e 100644 --- a/tensorflow/core/common_runtime/rendezvous_util.cc +++ 
b/tensorflow/core/common_runtime/rendezvous_util.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/rendezvous_util.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/util/reffed_status_callback.h" diff --git a/tensorflow/core/common_runtime/session_ref.cc b/tensorflow/core/common_runtime/session_ref.cc deleted file mode 100644 index b931ef4229..0000000000 --- a/tensorflow/core/common_runtime/session_ref.cc +++ /dev/null @@ -1,170 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/core/common_runtime/session_ref.h" - -#include <utility> - -namespace tensorflow { - -namespace { - -// Scope helper to track active calls and manage session lifetime. 
-struct RunCounter { - std::shared_ptr<Session> session; - uint64* value; - mutex* m; - condition_variable* cv; - - explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m, - condition_variable* cv) - : session(std::move(s)), value(v), m(m), cv(cv) { - mutex_lock l(*m); - ++*value; - } - - ~RunCounter() { - mutex_lock l(*m); - if (--*value == 0) { - cv->notify_all(); - } - } -}; - -} // namespace - -Status SessionRef::CheckNotClosed() { - mutex_lock l(run_lock_); - if (session_ == nullptr) return errors::Cancelled("Session has been closed."); - return ::tensorflow::Status::OK(); -} - -Status SessionRef::Run(const RunOptions& run_options, - const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_tensor_names, - const std::vector<string>& target_node_names, - std::vector<Tensor>* outputs, - RunMetadata* run_metadata) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Run(run_options, inputs, output_tensor_names, - target_node_names, outputs, run_metadata); -} - -Status SessionRef::Create(const GraphDef& graph) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Create(graph); -} - -Status SessionRef::Create(const RunOptions& run_options, - const GraphDef& graph) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Create(run_options, graph); -} - -Status SessionRef::Extend(const RunOptions& run_options, - const GraphDef& graph) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Extend(run_options, graph); -} - -Status SessionRef::Extend(const GraphDef& graph) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Extend(graph); -} - 
-Status SessionRef::Close(const RunOptions& run_options) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - mutex_lock l(run_lock_); - Status status = session_->Close(run_options); - session_.reset(); - while (run_count_ > 0) { - run_finished_.wait(l); - } - return status; -} - -Status SessionRef::Close() { - TF_RETURN_IF_ERROR(CheckNotClosed()); - mutex_lock l(run_lock_); - Status status = session_->Close(); - session_.reset(); - while (run_count_ > 0) { - run_finished_.wait(l); - } - return status; -} - -Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_tensor_names, - const std::vector<string>& target_node_names, - std::vector<Tensor>* outputs) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->Run(inputs, output_tensor_names, target_node_names, - outputs); -} - -Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->ListDevices(response); -} - -Status SessionRef::PRunSetup(const std::vector<string>& input_names, - const std::vector<string>& output_names, - const std::vector<string>& target_nodes, - string* handle) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->PRunSetup(input_names, output_names, target_nodes, handle); -} - -Status SessionRef::PRun(const string& handle, - const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_names, - std::vector<Tensor>* outputs) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->PRun(handle, inputs, output_names, outputs); -} - -Status SessionRef::MakeCallable(const CallableOptions& callable_options, - CallableHandle* out_handle) { - 
TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->MakeCallable(callable_options, out_handle); -} - -Status SessionRef::RunCallable(CallableHandle handle, - const std::vector<Tensor>& feed_tensors, - std::vector<Tensor>* fetch_tensors, - RunMetadata* run_metadata) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->RunCallable(handle, feed_tensors, fetch_tensors, - run_metadata); -} - -Status SessionRef::ReleaseCallable(CallableHandle handle) { - TF_RETURN_IF_ERROR(CheckNotClosed()); - RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_); - return rc.session->ReleaseCallable(handle); -} - -} // namespace tensorflow diff --git a/tensorflow/core/common_runtime/session_ref.h b/tensorflow/core/common_runtime/session_ref.h deleted file mode 100644 index 9459e7edbe..0000000000 --- a/tensorflow/core/common_runtime/session_ref.h +++ /dev/null @@ -1,86 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_ - -#include <memory> - -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/public/session.h" - -namespace tensorflow { - -// A `SessionRef` manages the lifetime of a wrapped `Session` pointer. -// -// SessionRef blocks the return of Close() until all pending operations have -// been completed or cancelled and underlying session has been freed. Any -// subsequent operations on the SessionRef object will return errors::Cancelled. -class SessionRef : public Session { - public: - SessionRef(Session* session) : session_(session) {} - virtual ~SessionRef() {} - - Status Create(const GraphDef& graph) override; - Status Extend(const GraphDef& graph) override; - Status Create(const RunOptions& run_options, const GraphDef& graph) override; - Status Extend(const RunOptions& run_options, const GraphDef& graph) override; - Status Run(const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_tensor_names, - const std::vector<string>& target_node_names, - std::vector<Tensor>* outputs) override; - - Status ListDevices(std::vector<DeviceAttributes>* response) override; - - Status Close() override; - Status Close(const RunOptions& run_options) override; - - Status Run(const RunOptions& run_options, - const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_tensor_names, - const std::vector<string>& target_node_names, - std::vector<Tensor>* outputs, RunMetadata* run_metadata) override; - - Status PRunSetup(const std::vector<string>& input_names, - const std::vector<string>& output_names, - const std::vector<string>& target_nodes, - string* handle) override; - - Status PRun(const string& handle, - const std::vector<std::pair<string, Tensor> >& inputs, - const std::vector<string>& output_names, - 
std::vector<Tensor>* outputs) override; - - Status MakeCallable(const CallableOptions& callable_options, - CallableHandle* out_handle) override; - - Status RunCallable(CallableHandle handle, - const std::vector<Tensor>& feed_tensors, - std::vector<Tensor>* fetch_tensors, - RunMetadata* run_metadata) override; - - Status ReleaseCallable(CallableHandle handle) override; - - private: - mutex run_lock_; - condition_variable run_finished_; - uint64 run_count_ GUARDED_BY(run_lock_) = {0}; - std::shared_ptr<Session> session_; - - Status CheckNotClosed(); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_ diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc index 65ff356e73..5b1915755d 100644 --- a/tensorflow/core/common_runtime/session_state.cc +++ b/tensorflow/core/common_runtime/session_state.cc @@ -70,7 +70,7 @@ Status TensorStore::SaveTensors(const std::vector<string>& output_names, // Save only the tensors in output_names in the session. for (const string& name : output_names) { TensorId id(ParseTensorName(name)); - const string& op_name = std::string(id.first); + const string op_name(id.first); auto it = tensors_.find(op_name); if (it != tensors_.end()) { // Save the tensor to the session state. diff --git a/tensorflow/core/common_runtime/single_threaded_cpu_device.h b/tensorflow/core/common_runtime/single_threaded_cpu_device.h index 04d5af9087..22650b0d83 100644 --- a/tensorflow/core/common_runtime/single_threaded_cpu_device.h +++ b/tensorflow/core/common_runtime/single_threaded_cpu_device.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/threadpool.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index 9c2510e6a9..a70ab93d4a 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -27,6 +27,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace { @@ -40,46 +41,24 @@ struct AllocStats { }; } // namespace -NodeExecStatsWrapper::NodeExecStatsWrapper(const string& node_name) - : NodeExecStatsWrapper(new NodeExecStats) { - stats_->set_node_name(node_name); -} -NodeExecStatsWrapper::NodeExecStatsWrapper(NodeExecStats* stats) - : stats_(stats) {} - -void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* v) { - DCHECK(v); - NodeOutput* no = stats_->add_output(); - no->set_slot(slot); - v->FillDescription(no->mutable_tensor_description()); -} - -void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) { - for (const auto& allocator_pair : ctx->wrapped_allocators()) { - AddAllocation(allocator_pair.first, allocator_pair.second); - } - auto* ms = stats_->mutable_memory_stats(); - ms->set_temp_memory_size(ctx->temp_memory_allocated()); - for (const auto& alloc_id : ctx->persistent_alloc_ids()) { - ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); - } - ms->set_persistent_memory_size(ctx->persistent_memory_allocated()); +NodeExecStatsWrapper::NodeExecStatsWrapper( + const Node* node, StepStatsCollector* step_stats_collector) + : 
NodeExecStatsWrapper(MakeUnique<NodeExecStats>(), node, + step_stats_collector) { + stats_->set_node_name(node->name()); } -void NodeExecStatsWrapper::SetReferencedTensors( - const TensorReferenceVector& tensors) { - // be careful not to increment the reference count on any tensor - // while recording the information - for (size_t i = 0; i < tensors.size(); ++i) { - AllocationDescription* description = stats_->add_referenced_tensor(); - tensors.at(i).FillDescription(description); - } -} - -// TODO(tucker): merge with the DetailText function in session.cc -// in a common location. -bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) { - bool is_transfer_node = false; +NodeExecStatsWrapper::NodeExecStatsWrapper( + std::unique_ptr<NodeExecStats> stats, const Node* node, + StepStatsCollector* step_stats_collector) + : stats_(std::move(stats)), + node_(node), + step_stats_collector_(step_stats_collector) {} + +void NodeExecStatsWrapper::Done(const string& device) { + // TODO(tucker): merge with the DetailText function in session.cc in a common + // location. 
+ DCHECK(node_); string memory; for (auto& all : stats_->memory()) { int64 tot = all.total_bytes(); @@ -96,31 +75,96 @@ bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) { } } } - const AttrSlice attrs = node->attrs(); + const AttrSlice attrs = node_->attrs(); string text; - if (IsSend(node)) { + if (IsSend(node_)) { string tensor_name; TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name)); string recv_device; TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device)); - text = strings::StrCat(memory, node->name(), " = ", node->type_string(), + text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(", tensor_name, " @", recv_device); - is_transfer_node = true; - } else if (IsRecv(node)) { + } else if (IsRecv(node_)) { string tensor_name; TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name)); string send_device; TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device)); - text = strings::StrCat(memory, node->name(), " = ", node->type_string(), + text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(", tensor_name, " @", send_device); - is_transfer_node = true; } else { text = - strings::StrCat(memory, node->name(), " = ", node->type_string(), "(", - str_util::Join(node->requested_inputs(), ", "), ")"); + strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(", + str_util::Join(node_->requested_inputs(), ", "), ")"); } stats_->set_timeline_label(text); - return is_transfer_node; + step_stats_collector_->Save(device, this); +} + +void NodeExecStatsWrapper::RecordExecutorStarted() { + int64 now_nanos = Env::Default()->NowNanos(); + stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos); + stats_->set_all_start_nanos(now_nanos); +} + +void NodeExecStatsWrapper::RecordComputeStarted() { + int64 now_nanos = Env::Default()->NowNanos(); + DCHECK_NE(stats_->all_start_micros(), 0); + DCHECK_NE(stats_->all_start_nanos(), 0); + stats_->set_op_start_rel_micros(now_nanos / 
EnvTime::kMicrosToNanos - + stats_->all_start_micros()); + stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos()); +} + +void NodeExecStatsWrapper::RecordComputeEnded() { + int64 now_nanos = Env::Default()->NowNanos(); + DCHECK_NE(stats_->all_start_micros(), 0); + DCHECK_NE(stats_->all_start_nanos(), 0); + stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos - + stats_->all_start_micros()); + stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos()); +} + +void NodeExecStatsWrapper::RecordExecutorEnded() { + int64 now_nanos = Env::Default()->NowNanos(); + DCHECK_NE(stats_->all_start_micros(), 0); + DCHECK_NE(stats_->all_start_nanos(), 0); + stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos - + stats_->all_start_micros()); + stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos()); +} + +void NodeExecStatsWrapper::SetScheduled(int64 nanos) { + stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos); + stats_->set_scheduled_nanos(nanos); +} + +void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) { + for (const auto& allocator_pair : ctx->wrapped_allocators()) { + AddAllocation(allocator_pair.first, allocator_pair.second); + } + auto* ms = stats_->mutable_memory_stats(); + ms->set_temp_memory_size(ctx->temp_memory_allocated()); + for (const auto& alloc_id : ctx->persistent_alloc_ids()) { + ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); + } + ms->set_persistent_memory_size(ctx->persistent_memory_allocated()); +} + +void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* tensor) { + DCHECK(tensor); + NodeOutput* node_output = stats_->add_output(); + node_output->set_slot(slot); + tensor->FillDescription(node_output->mutable_tensor_description()); +} + +void NodeExecStatsWrapper::SetReferencedTensors( + const TensorReferenceVector& tensors) { + // be careful not to increment the reference count on any tensor + // while recording the information + for (size_t i = 0; i < 
tensors.size(); ++i) { + AllocationDescription* description = stats_->add_referenced_tensor(); + tensors.at(i).FillDescription(description); + } } void NodeExecStatsWrapper::AddAllocation( @@ -150,8 +194,8 @@ void NodeExecStatsWrapper::Finalize() { allocations_.clear(); } -StepStatsCollector::StepStatsCollector(StepStats* ss) - : finalized_(false), step_stats_(ss) {} +StepStatsCollector::StepStatsCollector(StepStats* step_stats) + : finalized_(false), step_stats_(step_stats) {} static int ExtractGpuWithStreamAll(string device_name) { // Check if the device name matches the ".*gpu:(\\d+)/stream:all$" regexp, @@ -176,7 +220,7 @@ static int ExtractGpuWithStreamAll(string device_name) { } else { // Convert the captured string into an integer. But first we need to put // the digits back in order - string ordered_capture = std::string(capture); + string ordered_capture(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; CHECK(strings::safe_strto32(ordered_capture, &gpu_id)); @@ -205,7 +249,7 @@ static int ExtractGpuWithoutStream(string device_name) { } else { // Convert the captured string into an integer. But first we need to put // the digits back in order - string ordered_capture = std::string(capture); + string ordered_capture(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; CHECK(strings::safe_strto32(ordered_capture, &gpu_id)); @@ -252,7 +296,7 @@ void StepStatsCollector::BuildCostModel( for (auto& itr : per_device_stats) { const StringPiece device_name = itr.first; - const int gpu_id = ExtractGpuWithoutStream(std::string(device_name)); + const int gpu_id = ExtractGpuWithoutStream(string(device_name)); if (gpu_id >= 0) { // Reference the gpu hardware stats in addition to the regular stats // for this gpu device if they're available. 
@@ -338,28 +382,40 @@ void StepStatsCollector::BuildCostModel( } } -void StepStatsCollector::Save(const string& device, NodeExecStats* nt) { - Save(device, new NodeExecStatsWrapper(nt)); +void StepStatsCollector::Save(const string& device, + NodeExecStats* node_stats_pb) { + Save(device, + new NodeExecStatsWrapper(std::unique_ptr<NodeExecStats>(node_stats_pb), + nullptr, this)); } void StepStatsCollector::Save(const string& device, - NodeExecStatsWrapper* stats) { - if (!stats) return; - VLOG(1) << "Save dev " << device << " nt " << stats->stats(); + NodeExecStatsWrapper* node_stats) { + if (!node_stats) return; + VLOG(1) << "Save dev " << device << " node stats " << node_stats->stats(); { mutex_lock l(mu_); if (finalized_) { LOG(WARNING) << "stats saved after finalize will not be collected."; } - if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) { + if (!step_stats_ || collected_nodes_ >= kMaxCollectedNodes) { VLOG(1) << "step_stats_ nullptr or already collected too many nodes."; - delete stats; + delete node_stats; return; } - auto& dss = dev_stats_[device]; - dss.push_back(std::unique_ptr<NodeExecStatsWrapper>(stats)); - collectedNodes++; + auto& device_stats = dev_stats_[device]; + device_stats.push_back(std::unique_ptr<NodeExecStatsWrapper>(node_stats)); + collected_nodes_++; + } +} + +NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats( + const Node* node) { + // Only collect statistics for non-transfer nodes. 
+ if (IsSend(node) || IsRecv(node)) { + return nullptr; } + return new NodeExecStatsWrapper(node, this); } string StepStatsCollector::ReportAllocsOnResourceExhausted(const string& err) { @@ -446,12 +502,12 @@ void StepStatsCollector::Finalize() { FinalizeInternal(); } -void StepStatsCollector::FinalizeAndSwap(StepStats* ss) { +void StepStatsCollector::FinalizeAndSwap(StepStats* step_stats) { mutex_lock l(mu_); CHECK(step_stats_); FinalizeInternal(); - ss->Swap(step_stats_); - collectedNodes = 0; + step_stats->Swap(step_stats_); + collected_nodes_ = 0; } void StepStatsCollector::FinalizeInternal() { diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h index 7206fbf427..4365b11b19 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.h +++ b/tensorflow/core/common_runtime/step_stats_collector.h @@ -36,81 +36,78 @@ class Node; class NodeExecStats; class OpKernelContext; class StepStats; +class StepStatsCollector; class Tensor; class TrackingAllocator; -// Wraps NodeExecStats and adds allocation to it. -class NodeExecStatsWrapper { +// Statistics collection interface for individual node execution. +// +// See `NodeExecStatsWrapper` for a concrete implementation of this interface +// that interfaces with the `Session` layer. +class NodeExecStatsInterface { public: - NodeExecStatsWrapper(const string& node_name); - // Owns 'stats'. - NodeExecStatsWrapper(NodeExecStats* stats); + virtual ~NodeExecStatsInterface() {} - // Destructor calls Finalize() to release the TrackingAllocators. - ~NodeExecStatsWrapper() { Finalize(); } - - // Records the absolute time in nanoseconds at which this node became - // runnable (i.e. was scheduled for execution). - void SetScheduled(int64 nanos) { - stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos); - stats_->set_scheduled_nanos(nanos); - } + // Called when the statistics collection for the node has finished. 
Once this + // method is called, the caller should not make assumptions about the validity + // of this object. + virtual void Done(const string& device) = 0; // Called immediately after this node starts being processed by the executor. - void RecordExecutorStarted() { - int64 now_nanos = Env::Default()->NowNanos(); - stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos); - stats_->set_all_start_nanos(now_nanos); - } + virtual void RecordExecutorStarted() = 0; // Called immediately before this node's `Compute()` or `ComputeAsync()` // method is called. - void RecordComputeStarted() { - int64 now_nanos = Env::Default()->NowNanos(); - DCHECK_NE(stats_->all_start_micros(), 0); - DCHECK_NE(stats_->all_start_nanos(), 0); - stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos - - stats_->all_start_micros()); - stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos()); - } + virtual void RecordComputeStarted() = 0; // Called immediately after this node's `Compute()` method returned (or, for // asynchronous operations, the callback passed to its `ComputeAsync()` method // was called). - void RecordComputeEnded() { - int64 now_nanos = Env::Default()->NowNanos(); - DCHECK_NE(stats_->all_start_micros(), 0); - DCHECK_NE(stats_->all_start_nanos(), 0); - stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos - - stats_->all_start_micros()); - stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos()); - } + virtual void RecordComputeEnded() = 0; // Called immediately after this executor finishes processing this node. 
- void RecordExecutorEnded() { - int64 now_nanos = Env::Default()->NowNanos(); - DCHECK_NE(stats_->all_start_micros(), 0); - DCHECK_NE(stats_->all_start_nanos(), 0); - stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos - - stats_->all_start_micros()); - stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos()); - } - - // Records information about the tensor produced by this node at the given - // output slot. - void SetOutput(int slot, const Tensor* v); + virtual void RecordExecutorEnded() = 0; // Records information about the memory allocated during the execution of this // node. - void SetMemory(OpKernelContext* ctx); + virtual void SetMemory(OpKernelContext* ctx) = 0; + + // Records information about the tensor produced by this node at the given + // output slot. + virtual void SetOutput(int slot, const Tensor* tensor) = 0; // Records information about the tensors that were accessed during the // execution of this node. - void SetReferencedTensors(const TensorReferenceVector& tensors); + virtual void SetReferencedTensors(const TensorReferenceVector& tensors) = 0; - // Sets the timeline_label field of the wrapped NodeExecStats, using data - // from *node. Returns true iff the node is a transfer node. - bool SetTimelineLabel(const Node* node); + // Records the absolute time in nanoseconds at which this node became + // runnable (i.e. was scheduled for execution). + virtual void SetScheduled(int64 nanos) = 0; +}; + +// Wraps NodeExecStats and adds allocation to it. +class NodeExecStatsWrapper : public NodeExecStatsInterface { + public: + // Does not take ownership of `node` or `step_stats_collector`. + NodeExecStatsWrapper(const Node* node, + StepStatsCollector* step_stats_collector); + + // Takes ownership of 'stats' but not `node` or `step_stats_collector`. 
+ NodeExecStatsWrapper(std::unique_ptr<NodeExecStats> stats, const Node* node, + StepStatsCollector* step_stats_collector); + + // Destructor calls Finalize() to release the TrackingAllocators. + ~NodeExecStatsWrapper() { Finalize(); } + + void Done(const string& device) override; + void RecordExecutorStarted() override; + void RecordComputeStarted() override; + void RecordComputeEnded() override; + void RecordExecutorEnded() override; + void SetMemory(OpKernelContext* ctx) override; + void SetOutput(int slot, const Tensor* tensor) override; + void SetReferencedTensors(const TensorReferenceVector& tensors) override; + void SetScheduled(int64 nanos) override; private: friend class StepStatsCollector; @@ -128,9 +125,11 @@ class NodeExecStatsWrapper { gtl::InlinedVector<std::pair<AllocatorMemoryUsed*, TrackingAllocator*>, 2> allocations_; std::unique_ptr<NodeExecStats> stats_; + const Node* const node_; // Not owned. + StepStatsCollector* const step_stats_collector_; // Not owned. }; -// Statistics collection interface for individual node execution. +// Statistics collection interface for step execution. // // See `StepStatsCollector` for a concrete implementation of this interface // that interfaces with the `Session` layer. @@ -138,8 +137,9 @@ class StepStatsCollectorInterface { public: virtual ~StepStatsCollectorInterface() {} - // Saves `stats` to the collector. - virtual void Save(const string& device, NodeExecStatsWrapper* stats) = 0; + // Creates an instance of `NodeExecStatsInterface` that should be used for + // collecting statistics about individual node execution. + virtual NodeExecStatsInterface* CreateNodeExecStats(const Node* node) = 0; // Generates a string reporting the currently used memory based // on ResourceExhausted OOM `err` message. @@ -154,8 +154,8 @@ class StepStatsCollectorInterface { // Each DeviceStats object holds multiple NodeExecStats. 
class StepStatsCollector : public StepStatsCollectorInterface { public: - // Does not take ownership of `ss`. - explicit StepStatsCollector(StepStats* ss); + // Does not take ownership of `step_stats`. + explicit StepStatsCollector(StepStats* step_stats); // BuildCostModel builds or updates a CostModel managed by cost_model_manager, // using the currently collected DeviceStats associated with the devices in @@ -164,11 +164,12 @@ class StepStatsCollector : public StepStatsCollectorInterface { CostModelManager* cost_model_manager, const std::unordered_map<string, const Graph*>& device_map); - // Save saves nt to the DeviceStats object associated with device. + // Saves node statistics to the DeviceStats object associated with device. // Should be called before Finalize. - void Save(const string& device, NodeExecStats* nt); - void Save(const string& device, NodeExecStatsWrapper* stats) override; + void Save(const string& device, NodeExecStats* node_stats_pb); + void Save(const string& device, NodeExecStatsWrapper* node_stats); + NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override; string ReportAllocsOnResourceExhausted(const string& err) override; // The following 2 Finalize methods populate the StepStats passed @@ -176,20 +177,22 @@ class StepStatsCollector : public StepStatsCollectorInterface { // User shouldn't call Save() methods after Finalize. void Finalize(); // swaps the content of StepStats* from constructor with 'ss'. - void FinalizeAndSwap(StepStats* ss); + void FinalizeAndSwap(StepStats* step_stats); private: + // TODO(suharshs): Make this configurable if its not possible to find a value + // that works for all cases. 
+ static const uint64 kMaxCollectedNodes = 1 << 20; + + typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeStatsVector; + void FinalizeInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_); - typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeExecStatsVec; - // TODO(suharshs): Make this configurable if its not possible to find a value - // that works for all cases. - const uint64 kMaxCollectedNodes = 1 << 20; mutex mu_; bool finalized_ GUARDED_BY(mu_); - std::unordered_map<string, NodeExecStatsVec> dev_stats_ GUARDED_BY(mu_); + std::unordered_map<string, NodeStatsVector> dev_stats_ GUARDED_BY(mu_); StepStats* step_stats_ GUARDED_BY(mu_); - uint64 collectedNodes GUARDED_BY(mu_) = 0; + uint64 collected_nodes_ GUARDED_BY(mu_) = 0; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc index 0fbc20b34b..8587d1783a 100644 --- a/tensorflow/core/common_runtime/threadpool_device.cc +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -113,8 +113,11 @@ class MklCPUAllocatorFactory : public AllocatorFactory { } }; +#ifdef ENABLE_MKL REGISTER_MEM_ALLOCATOR("MklCPUAllocator", 200, MklCPUAllocatorFactory); +#endif // ENABLE_MKL + } // namespace -#endif +#endif // INTEL_MKL } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/tracing_device.h b/tensorflow/core/common_runtime/tracing_device.h deleted file mode 100644 index 39215efa35..0000000000 --- a/tensorflow/core/common_runtime/tracing_device.h +++ /dev/null @@ -1,57 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_TRACING_DEVICE_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_TRACING_DEVICE_H_ - -#include "tensorflow/core/common_runtime/device.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/tracing.h" - -namespace tensorflow { - -namespace test { -class Benchmark; -} -struct SessionOptions; - -// This class implements tracing functionality that is shared by its subclasses -// (including ThreadPoolDevice and XlaDevice). -class TracingDevice : public Device { - public: - TracingDevice(Env* env, const DeviceAttributes& attributes) - : Device(env, attributes) {} - - void Compute(OpKernel* op_kernel, OpKernelContext* context) override { - if (TF_PREDICT_FALSE( - tracing::GetTraceCollector() || - tracing::GetEventCollector(tracing::EventCategory::kCompute))) { - const string& op_name = op_kernel->name(); - tracing::ScopedActivity activity(op_name, op_kernel->type_string(), - op_kernel->IsExpensive()); - tracing::ScopedRegion region(tracing::EventCategory::kCompute, op_name); - op_kernel->Compute(context); - } else { - op_kernel->Compute(context); - } - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(TracingDevice); -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_TRACING_DEVICE_H_ diff --git a/tensorflow/core/common_runtime/visitable_allocator.h b/tensorflow/core/common_runtime/visitable_allocator.h deleted file mode 100644 index ae0563a96a..0000000000 --- a/tensorflow/core/common_runtime/visitable_allocator.h 
+++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_ - -#include <functional> -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/tracking_allocator.h" - -namespace tensorflow { - -// Subclass VisitableAllocator instead of Allocator when a memory -// allocator needs to enable some kind of registration/deregistration -// of memory areas. -class VisitableAllocator : public Allocator { - public: - // Visitor gets called with a pointer to a memory area and its - // size in bytes. - typedef std::function<void(void*, size_t)> Visitor; - - // Register a visitor guaranteed to be called exactly once on each - // chunk of memory newly allocated from the underlying device. - // Typically, chunks will be reused and possibly sub-divided by a - // pool manager, so the calls will happen only once per process - // execution, not once per tensor (re)allocation. - virtual void AddAllocVisitor(Visitor visitor) = 0; - - // Register a visitor guaranteed to be called on each chunk of - // memory returned to the underlying device. - virtual void AddFreeVisitor(Visitor visitor) = 0; -}; - -// Needed for cases when a VisitableAllocator gets wrapped for tracking. 
-// Multiple-inheritance is considered acceptable in this case because -// VisitableAllocator is a pure virtual interface and only TrackingAllocator -// has default implementation. -class TrackingVisitableAllocator : public TrackingAllocator, - public VisitableAllocator { - public: - TrackingVisitableAllocator(VisitableAllocator* allocator, bool track_ids) - : TrackingAllocator(allocator, track_ids), allocator_(allocator) {} - ~TrackingVisitableAllocator() override {} - - string Name() override { return TrackingAllocator::Name(); } - - void* AllocateRaw(size_t alignment, size_t num_bytes) override { - return TrackingAllocator::AllocateRaw(alignment, num_bytes); - } - - void DeallocateRaw(void* ptr) override { - TrackingAllocator::DeallocateRaw(ptr); - } - - void AddAllocVisitor(Visitor visitor) override { - allocator_->AddAllocVisitor(visitor); - } - - void AddFreeVisitor(Visitor visitor) override { - allocator_->AddFreeVisitor(visitor); - } - - protected: - VisitableAllocator* allocator_; -}; -} // namespace tensorflow -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_ diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc index 38863db1cc..6994dec3b5 100644 --- a/tensorflow/core/debug/debug_io_utils.cc +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -693,6 +693,7 @@ uint64 DebugFileIO::diskBytesUsed = 0; mutex DebugFileIO::bytes_mu(LINKER_INITIALIZED); bool DebugFileIO::requestDiskByteUsage(uint64 bytes) { + mutex_lock l(bytes_mu); if (globalDiskBytesLimit == 0) { const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT"); if (env_tfdbg_disk_bytes_limit == nullptr || @@ -707,7 +708,6 @@ bool DebugFileIO::requestDiskByteUsage(uint64 bytes) { if (bytes == 0) { return true; } - mutex_lock l(bytes_mu); if (diskBytesUsed + bytes < globalDiskBytesLimit) { diskBytesUsed += bytes; return true; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc 
b/tensorflow/core/distributed_runtime/graph_mgr.cc index 6c146036ae..3361819e43 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -233,14 +233,11 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, params.function_library = lib; params.create_kernel = [session, lib, opseg](const NodeDef& ndef, OpKernel** kernel) { - // We do not share the kernel via the OpSegment if the node is - // stateless, or a function. // NOTE(mrry): We must not share function kernels (implemented // using `CallOp`) between subgraphs, because `CallOp::handle_` // is tied to a particular subgraph. Even if the function itself // is stateful, the `CallOp` that invokes it is not. - if (!lib->IsStateful(ndef.op()) || - lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) { + if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) { return lib->CreateKernel(ndef, kernel); } auto create_fn = [lib, &ndef](OpKernel** kernel) { @@ -252,8 +249,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn); }; params.delete_kernel = [lib](OpKernel* kernel) { - // If the node is stateful, opseg owns it. Otherwise, delete it. - if (kernel && !lib->IsStateful(kernel->type_string())) { + if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) { delete kernel; } }; @@ -479,10 +475,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id, delete step_container; }); Executor::Args args; - { - mutex_lock l(mu_); - args.step_id = ++next_id_; - } + args.step_id = step_id; args.rendezvous = rendezvous; args.collective_executor = ce_handle ? 
ce_handle->get() : nullptr; args.cancellation_manager = cancellation_manager; diff --git a/tensorflow/core/example/example.proto b/tensorflow/core/example/example.proto index e7142a4ef9..e36e51d8d5 100644 --- a/tensorflow/core/example/example.proto +++ b/tensorflow/core/example/example.proto @@ -199,7 +199,13 @@ message Example { // to determine if all features within the FeatureList must // have the same size. The same holds for this FeatureList across multiple // examples. -// +// - For sequence modeling, e.g.: +// http://colah.github.io/posts/2015-08-Understanding-LSTMs/ +// https://github.com/tensorflow/nmt +// the feature lists represent a sequence of frames. +// In this scenario, all FeatureLists in a SequenceExample have the same +// number of Feature messages, so that the ith element in each FeatureList +// is part of the ith frame (or time step). // Examples of conformant and non-conformant examples' FeatureLists: // // Conformant FeatureLists: diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h index ec93b9aad9..016d1a92c1 100644 --- a/tensorflow/core/example/feature_util.h +++ b/tensorflow/core/example/feature_util.h @@ -103,6 +103,7 @@ limitations under the License. #include <iterator> #include <type_traits> +#include "absl/base/macros.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/example/feature.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -113,10 +114,10 @@ namespace tensorflow { namespace internal { -// DEPRECATED: Use GetFeature instead. // TODO(gorban): Update all clients in a followup CL. // Returns a reference to a feature corresponding to the name. // Note: it will create a new Feature if it is missing in the example. 
+ABSL_DEPRECATED("Use GetFeature instead.") Feature& ExampleFeature(const string& name, Example* example); // Specializations of RepeatedFieldTrait define a type of RepeatedField @@ -314,9 +315,9 @@ bool HasFeature(const string& key, const Example& example) { return HasFeature<FeatureType...>(key, GetFeatures(example)); } -// DEPRECATED: use HasFeature instead. // TODO(gorban): update all clients in a followup CL. template <typename... FeatureType> +ABSL_DEPRECATED("Use HasFeature instead.") bool ExampleHasFeature(const string& key, const Example& example) { return HasFeature<FeatureType...>(key, example); } diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index 888ed0c57b..84cee5569c 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/allocator_registry.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/tracking_allocator.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" @@ -56,6 +57,14 @@ void RunResourceDtor(ResourceHandle* p, size_t n) { for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle(); } +void Allocator::RunVariantCtor(Variant* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) new (p) Variant(); +} + +void Allocator::RunVariantDtor(Variant* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) p->~Variant(); +} + // If true, cpu allocator collects more stats. static bool cpu_allocator_collect_stats = false; // If true, cpu allocator collects full stats. 
@@ -187,7 +196,7 @@ class CPUAllocatorFactory : public AllocatorFactory { class CPUSubAllocator : public SubAllocator { public: explicit CPUSubAllocator(CPUAllocator* cpu_allocator) - : cpu_allocator_(cpu_allocator) {} + : SubAllocator({}, {}), cpu_allocator_(cpu_allocator) {} void* Alloc(size_t alignment, size_t num_bytes) override { return cpu_allocator_->AllocateRaw(alignment, num_bytes); @@ -213,4 +222,22 @@ Allocator* cpu_allocator() { } return cpu_alloc; } + +SubAllocator::SubAllocator(const std::vector<Visitor>& alloc_visitors, + const std::vector<Visitor>& free_visitors) + : alloc_visitors_(alloc_visitors), free_visitors_(free_visitors) {} + +void SubAllocator::VisitAlloc(void* ptr, int index, size_t num_bytes) { + for (const auto& v : alloc_visitors_) { + v(ptr, index, num_bytes); + } +} + +void SubAllocator::VisitFree(void* ptr, int index, size_t num_bytes) { + // Although we don't guarantee any order of visitor application, strive + // to apply free visitors in reverse order of alloc visitors. + for (int i = free_visitors_.size() - 1; i >= 0; --i) { + free_visitors_[i](ptr, index, num_bytes); + } +} } // namespace tensorflow diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index 774b1fe137..8c23604625 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -23,12 +23,14 @@ limitations under the License. #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/type_traits.h" -#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { +class Variant; + // Attributes for a single allocation call. Different calls to the same // allocator could potentially have different allocation attributes. 
struct AllocationAttributes { @@ -228,13 +230,9 @@ class Allocator { for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle(); } - virtual void RunVariantCtor(Variant* p, size_t n) { - for (size_t i = 0; i < n; ++p, ++i) new (p) Variant(); - } + virtual void RunVariantCtor(Variant* p, size_t n); - virtual void RunVariantDtor(Variant* p, size_t n) { - for (size_t i = 0; i < n; ++p, ++i) p->~Variant(); - } + virtual void RunVariantDtor(Variant* p, size_t n); // TODO(jeff): Maybe provide some interface to give info about // current allocation state (total number of bytes available for @@ -390,13 +388,36 @@ void EnableCPUAllocatorStats(bool enable); // full statistics. By default, it's disabled. void EnableCPUAllocatorFullStats(bool enable); -// Abstract interface of an object that does the underlying suballoc/free of -// memory for a higher-level allocator. +// An object that does the underlying suballoc/free of memory for a higher-level +// allocator. The expectation is that the higher-level allocator is doing some +// kind of cache or pool management so that it will call SubAllocator::Alloc and +// Free relatively infrequently, compared to the number of times its own +// AllocateRaw and Free methods are called. class SubAllocator { public: + // Visitor gets called with a pointer to a memory area and its + // size in bytes. The index value will be numa_node for a CPU + // allocator and GPU id for a GPU allocator. + typedef std::function<void(void*, int index, size_t)> Visitor; + + SubAllocator(const std::vector<Visitor>& alloc_visitors, + const std::vector<Visitor>& free_visitors); + virtual ~SubAllocator() {} virtual void* Alloc(size_t alignment, size_t num_bytes) = 0; virtual void Free(void* ptr, size_t num_bytes) = 0; + + protected: + // Implementation of Alloc() method must call this on newly allocated + // value. 
+ void VisitAlloc(void* ptr, int index, size_t num_bytes); + + // Implementation of Free() method must call this on value to be + // freed immediately before deallocation. + void VisitFree(void* ptr, int index, size_t num_bytes); + + const std::vector<Visitor> alloc_visitors_; + const std::vector<Visitor> free_visitors_; }; } // namespace tensorflow diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h index 24f282ce84..e907c52ba9 100644 --- a/tensorflow/core/framework/allocator_registry.h +++ b/tensorflow/core/framework/allocator_registry.h @@ -21,6 +21,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/numa.h" namespace tensorflow { diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc index 1a3994736c..4ffd732f8e 100644 --- a/tensorflow/core/framework/attr_value_util_test.cc +++ b/tensorflow/core/framework/attr_value_util_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include <numeric> #include <vector> #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc index 1258e40c93..af59500aee 100644 --- a/tensorflow/core/framework/cancellation.cc +++ b/tensorflow/core/framework/cancellation.cc @@ -89,6 +89,16 @@ bool CancellationManager::DeregisterCallback(CancellationToken token) { } } +bool CancellationManager::TryDeregisterCallback(CancellationToken token) { + mutex_lock lock(mu_); + if (is_cancelled_ || is_cancelling_) { + return false; + } else { + callbacks_.erase(token); + return true; + } +} + CancellationManager::~CancellationManager() { if (!callbacks_.empty()) { StartCancel(); diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h index acdaaf6a90..7a5d942486 100644 --- a/tensorflow/core/framework/cancellation.h +++ b/tensorflow/core/framework/cancellation.h @@ -122,6 +122,15 @@ class CancellationManager { // cancellation manager. bool DeregisterCallback(CancellationToken token); + // Deregister the callback that, when registered, was associated + // with the given cancellation token. Returns true iff the callback + // was deregistered and will not be invoked; otherwise returns false + // immediately, with no guarantee that the callback has completed. + // + // This method is guaranteed to return true if StartCancel has not been + // called. 
+ bool TryDeregisterCallback(CancellationToken token); + private: bool is_cancelling_; std::atomic_bool is_cancelled_; diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc index e3f18240b5..bf7593bc5f 100644 --- a/tensorflow/core/framework/cancellation_test.cc +++ b/tensorflow/core/framework/cancellation_test.cc @@ -115,4 +115,56 @@ TEST(Cancellation, IsCancelled) { delete cm; } +TEST(Cancellation, TryDeregisterWithoutCancel) { + bool is_cancelled = false; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback( + token, [&is_cancelled]() { is_cancelled = true; }); + EXPECT_TRUE(registered); + bool deregistered = manager->TryDeregisterCallback(token); + EXPECT_TRUE(deregistered); + delete manager; + EXPECT_FALSE(is_cancelled); +} + +TEST(Cancellation, TryDeregisterAfterCancel) { + bool is_cancelled = false; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback( + token, [&is_cancelled]() { is_cancelled = true; }); + EXPECT_TRUE(registered); + manager->StartCancel(); + EXPECT_TRUE(is_cancelled); + bool deregistered = manager->TryDeregisterCallback(token); + EXPECT_FALSE(deregistered); + delete manager; +} + +TEST(Cancellation, TryDeregisterDuringCancel) { + Notification cancel_started, finish_callback, cancel_complete; + CancellationManager* manager = new CancellationManager(); + auto token = manager->get_cancellation_token(); + bool registered = manager->RegisterCallback(token, [&]() { + cancel_started.Notify(); + finish_callback.WaitForNotification(); + }); + EXPECT_TRUE(registered); + + thread::ThreadPool w(Env::Default(), "test", 1); + w.Schedule([&]() { + manager->StartCancel(); + cancel_complete.Notify(); + }); + cancel_started.WaitForNotification(); + + bool deregistered = 
manager->TryDeregisterCallback(token); + EXPECT_FALSE(deregistered); + + finish_callback.Notify(); + cancel_complete.WaitForNotification(); + delete manager; +} + } // namespace tensorflow diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index 9ffd8e1ee0..284dafb886 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/graph/node_builder.h" namespace tensorflow { - +namespace data { namespace { // A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor. @@ -329,4 +329,5 @@ void BackgroundWorker::WorkerLoop() { } } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 04865a1d4f..697e0604bf 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/dataset_stateful_op_whitelist.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/model.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -40,6 +41,15 @@ limitations under the License. namespace tensorflow { +// Forward declarations to avoid introducing a dependency on headers in +// "tensorflow/core/graph/...". +class GraphDefBuilder; +class Node; + +namespace data { +// A constant that can be used to enable auto-tuning. +constexpr int kAutoTune = -1; + class DatasetBase; class SerializationContext; @@ -66,11 +76,6 @@ class IteratorStateWriter { virtual ~IteratorStateWriter() {} }; -// Forward declarations to avoid introducing a dependency on headers in -// "tensorflow/core/graph/...". -class GraphDefBuilder; -class Node; - // Wrapper around GraphDefBuilder. 
Used to serialize Dataset graph. class GraphDefBuilderWrapper { public: @@ -222,8 +227,7 @@ class GraphDefBuilderWrapper { return (str_util::EndsWith(op_def->name(), "Dataset") && op_def->output_arg_size() == 1 && op_def->output_arg(0).type() == DT_VARIANT) || - dataset::WhitelistedStatefulOpRegistry::Global()->Contains( - op_def->name()); + WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name()); } bool HasAttr(const string& op_type_name, const string& attr_name) const; @@ -290,6 +294,9 @@ class IteratorContext { // The Allocator to be used to allocate the output of an iterator. std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr; + + // If non-null, identifies the object used for performance modeling. + std::shared_ptr<model::Model> model = nullptr; }; explicit IteratorContext(Params params) : params_(std::move(params)) {} @@ -341,6 +348,10 @@ class IteratorContext { return params_.stats_aggregator_getter; } + std::shared_ptr<model::Model> model() { return params_.model; } + + Params params() { return params_; } + private: Params params_; }; @@ -375,7 +386,11 @@ class SerializationContext { // defined below. class IteratorBase { public: - virtual ~IteratorBase() {} + virtual ~IteratorBase() { + for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) { + (*rit)(); + } + } // Gets the next output from the range that this iterator is traversing. // @@ -409,6 +424,10 @@ class IteratorBase { // in the outputs of this iterator. virtual const std::vector<PartialTensorShape>& output_shapes() const = 0; + // Returns a string that identifies the sequence of iterators leading up to + // this iterator. + virtual const string& prefix() const = 0; + // Performs initialization that needs to happen outside of a constructor to // properly propagate errors. 
virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); } @@ -448,6 +467,18 @@ class IteratorBase { IteratorStateReader* reader) { return errors::Unimplemented("RestoreInternal"); } + + private: + friend class DatasetBase; // for access to `AddCleanupFunction` + + // Registers a cleanup function to be called upon object destruction. + // + // Registered functions are invoked in the reverse order of registration. + void AddCleanupFunction(std::function<void()>&& cleanup_fn) { + cleanup_fns_.push_back(std::move(cleanup_fn)); + } + + std::vector<std::function<void()>> cleanup_fns_; }; // Represents runtime information needed to construct a dataset. @@ -497,6 +528,13 @@ class DatasetBase : public core::RefCounted { Status MakeIterator(IteratorContext* ctx, const string& prefix, std::unique_ptr<IteratorBase>* iterator) const { *iterator = MakeIteratorInternal(prefix); + if (ctx->model()) { + ctx->model()->AddNode((*iterator)->prefix(), prefix); + std::shared_ptr<model::Model> model = ctx->model(); + const string& prefix = (*iterator)->prefix(); + (*iterator)->AddCleanupFunction( + [model, prefix]() { model->RemoveNode(prefix); }); + } return (*iterator)->Initialize(ctx); } @@ -523,6 +561,8 @@ class DatasetBase : public core::RefCounted { IteratorStateWriter* writer) const; protected: + friend class DatasetToGraphOp; // For access to graph related members. + class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { public: DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {} @@ -540,8 +580,6 @@ class DatasetBase : public core::RefCounted { virtual std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const = 0; - friend class DatasetToGraphOp; // For access to graph related members. - private: const string name_; }; @@ -564,7 +602,7 @@ class DatasetBaseIterator : public IteratorBase { ~DatasetBaseIterator() override { params_.dataset->Unref(); } // The sequence of iterators leading up to this iterator.
- const string& prefix() const { return params_.prefix; } + const string& prefix() const override { return params_.prefix; } const DataTypeVector& output_dtypes() const override { return params_.dataset->output_dtypes(); @@ -577,7 +615,10 @@ class DatasetBaseIterator : public IteratorBase { Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) final { tracing::ScopedActivity activity(params_.prefix); + RecordStart(ctx, true /* stop_output */); Status s = GetNextInternal(ctx, out_tensors, end_of_sequence); + if (s.ok() && !*end_of_sequence) RecordElement(ctx); + RecordStop(ctx, true /* start_output */); if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) { s = errors::Internal( "Iterator \"", params_.prefix, @@ -604,6 +645,54 @@ class DatasetBaseIterator : public IteratorBase { return strings::StrCat(params_.prefix, ":", name); } + // When performance modeling is enabled, this method adds a constant parameter + // to the model node corresponding to this iterator. + void AddConstantParameter(IteratorContext* ctx, const string& name, + int64 value) { + if (ctx->model()) { + ctx->model()->AddConstantParameter(prefix(), name, value); + } + } + + // When performance modeling is enabled, this method adds a tunable parameter + // to the model node corresponding to this iterator. + // + // The performance modeling logic may use `value` to set the value of the + // tunable parameter at any point during the lifetime of this iterator. When + // it does, it notifies `cond_var`. + void AddTunableParameter(IteratorContext* ctx, const string& name, + std::atomic<int64>* value, int64 min, int64 max, + condition_variable* cond_var) { + if (ctx->model()) { + ctx->model()->AddTunableParameter(prefix(), name, value, min, max, + cond_var); + } + } + + // When performance modeling is enabled, this method records the fact that + // this iterator has produced an element. 
+ void RecordElement(IteratorContext* ctx) { + if (ctx->model()) { + ctx->model()->RecordElement(prefix()); + } + } + + // When performance modeling is enabled, this method records the fact that + // a thread of this iterator has started work. + void RecordStart(IteratorContext* ctx, bool stop_output = false) { + if (ctx->model()) { + ctx->model()->RecordStart(prefix(), stop_output); + } + } + + // When performance modeling is enabled, this method records the fact that + // a thread of this iterator has stopped work. + void RecordStop(IteratorContext* ctx, bool start_output = false) { + if (ctx->model()) { + ctx->model()->RecordStop(prefix(), start_output); + } + } + private: BaseParams params_; }; @@ -751,6 +840,21 @@ class BackgroundWorker { std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_); }; +} // namespace data + +// TODO(b/114112161): Remove these aliases when all users have moved over to the +// `tensorflow::data` namespace. +using data::DatasetBase; +using data::DatasetContext; +using data::DatasetIterator; +using data::DatasetOpKernel; +using data::IteratorBase; +using data::IteratorContext; +using data::IteratorStateReader; +using data::IteratorStateWriter; +using data::SerializationContext; +using data::UnaryDatasetOpKernel; + } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ diff --git a/tensorflow/core/framework/dataset_stateful_op_whitelist.h b/tensorflow/core/framework/dataset_stateful_op_whitelist.h index 3b48999edb..74bd39cb61 100644 --- a/tensorflow/core/framework/dataset_stateful_op_whitelist.h +++ b/tensorflow/core/framework/dataset_stateful_op_whitelist.h @@ -16,38 +16,38 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_ #define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_ +#include <unordered_set> #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -namespace dataset { +namespace data { // Registry for stateful ops that need to be used in dataset functions. // See below macro for usage details. class WhitelistedStatefulOpRegistry { public: - Status Add(StringPiece op_name) { - op_names_.insert(op_name); + Status Add(string op_name) { + op_names_.insert(std::move(op_name)); return Status::OK(); } - bool Contains(StringPiece op_name) { - return op_names_.find(op_name) != op_names_.end(); - } + bool Contains(const string& op_name) { return op_names_.count(op_name); } static WhitelistedStatefulOpRegistry* Global() { - static WhitelistedStatefulOpRegistry* reg = - new WhitelistedStatefulOpRegistry; + static auto* reg = new WhitelistedStatefulOpRegistry; return reg; } private: - WhitelistedStatefulOpRegistry() {} - WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy); + WhitelistedStatefulOpRegistry() = default; + WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy) = + delete; WhitelistedStatefulOpRegistry operator=( - WhitelistedStatefulOpRegistry const& copy); - std::set<StringPiece> op_names_; + WhitelistedStatefulOpRegistry const& copy) = delete; + + std::unordered_set<string> op_names_; }; -} // namespace dataset +} // namespace data // Use this macro to whitelist an op that is marked stateful but needs to be // used inside a map_fn in an input pipeline. 
This is only needed if you wish @@ -67,10 +67,9 @@ class WhitelistedStatefulOpRegistry { WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name) #define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \ WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) -#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \ - static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \ - ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \ - name) +#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \ + static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::data::WhitelistedStatefulOpRegistry::Global()->Add(name) } // namespace tensorflow diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 794250a2c1..446c31b17f 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -20,6 +20,7 @@ limitations under the License. #include <string> #include <vector> +#include "absl/base/macros.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/refcount.h" @@ -176,9 +177,9 @@ class DeviceBase { return nullptr; } - // DEPRECATED: Use `this->GetAllocator()` or `this->GetScopedAllocator()`. // This method is provided for backwards compatibility, and will be removed // in a future release. + ABSL_DEPRECATED("Use `this->GetAllocator()` or `this->GetScopedAllocator()`.") Allocator* GetStepAllocator(AllocatorAttributes attr, ResourceMgr*) { return GetAllocator(attr); } @@ -214,10 +215,12 @@ class DeviceBase { // This is overridden by GPU devices to reinitialize the derived // type returned by MakeGpuDevice. 
- virtual void ReinitializeGpuDevice(OpKernelContext* /*context*/, - PerOpGpuDevice* /*device*/, - DeviceContext* /*dc*/, - Allocator* /*allocator*/) {} + virtual Status ReinitializeGpuDevice(OpKernelContext* /*context*/, + PerOpGpuDevice* /*device*/, + DeviceContext* /*dc*/, + Allocator* /*allocator*/) { + return Status::OK(); + } // Unimplemented by default virtual const DeviceAttributes& attributes() const; diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 26f32677af..a17959a448 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1154,6 +1154,17 @@ Status FunctionLibraryDefinition::LookUp( return default_registry_->LookUp(op, op_reg_data); } +string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const { + tf_shared_lock l(mu_); + int index = 0; + string name = strings::StrCat(prefix, index); + while (function_defs_.find(name) != function_defs_.end()) { + ++index; + name = strings::StrCat(prefix, index); + } + return name; +} + const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( const NodeDef& ndef) const { if (ndef.op() != kGradientOp) { @@ -1283,6 +1294,18 @@ FunctionDef FunctionDefHelper::Create( for (const auto& r : ret_def) { fdef.mutable_ret()->insert({r.first, r.second}); } + + auto* op_def_registry = OpRegistry::Global(); + // Check if any op is stateful. + for (const auto& n : node_def) { + const OpDef* op_def = nullptr; + auto status = op_def_registry->LookUpOpDef(n.op, &op_def); + // Lookup can fail if e.g. we are calling a function that was not yet + // defined. If it happens, conservatively assume the op is stateful. 
+ if (!status.ok() || op_def->is_stateful()) { + fdef.mutable_signature()->set_is_stateful(true); + } + } return fdef; } @@ -1344,6 +1367,7 @@ FunctionDef FunctionDefHelper::Define(const string& name, strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first); } } + if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true); } // Returns diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 03296a7761..e01eb7503d 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -358,6 +358,10 @@ class FunctionLibraryDefinition : public OpRegistryInterface { const OpRegistrationData** op_reg_data) const override LOCKS_EXCLUDED(mu_); + // Generates new function name with the specified prefix that is unique + // across this library. + string UniqueFunctionName(StringPiece prefix) const LOCKS_EXCLUDED(mu_); + // Ops created for function arguments bear the name given by `kArgOp`; those // created for return values bear the name given by `kRetOp`. 
static constexpr const char* const kArgOp = "_Arg"; diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index 46b169dddc..0445c242e9 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -91,6 +91,31 @@ FunctionDef IsZero() { }); } +FunctionDef RandomUniform() { + const Tensor kZero = test::AsScalar<int64>(0); + + return FDH::Define( + // Name + "RandomUniform", + // Args + {"x: T"}, + // Return values + {"random_uniform: int64"}, + // Attr def + {"T:{float, double, int32, int64, string}"}, + {{{"random_uniform/shape"}, + "Const", + {}, + {{"value", kZero}, {"dtype", DT_INT64}}}, + {{"random_uniform"}, + "RandomUniform", + {"random_uniform/shape"}, + {{"T", DT_INT32}, + {"Tout", DT_FLOAT}, + {"seed", 87654321}, + {"seed2", 42}}}}); +} + FunctionDef XTimesTwo() { const Tensor kTwo = test::AsScalar<int64>(2); return FDH::Define( @@ -110,6 +135,22 @@ FunctionDef XTimesTwo() { }); } +FunctionDef XAddX() { + return FDH::Define( + // Name + "XAddX", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}}, + }); +} + FunctionDef XTimesTwoInt32() { const Tensor kTwo = test::AsScalar<int64>(2); return FDH::Define( diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h index 6d6476b936..a01743423b 100644 --- a/tensorflow/core/framework/function_testlib.h +++ b/tensorflow/core/framework/function_testlib.h @@ -63,6 +63,9 @@ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes, // x:T -> x * 2. FunctionDef XTimesTwo(); +// x:T -> x + x. +FunctionDef XAddX(); + // x:T -> x * 2, where x is int32. FunctionDef XTimesTwoInt32(); @@ -81,6 +84,9 @@ FunctionDef NonZero(); // x: T -> bool. 
FunctionDef IsZero(); +// x: T -> int64 +FunctionDef RandomUniform(); + // x:T, y:T -> y:T, x:T FunctionDef Swap(); diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc new file mode 100644 index 0000000000..b0330ec990 --- /dev/null +++ b/tensorflow/core/framework/model.cc @@ -0,0 +1,419 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/model.h" + +#include <memory> + +namespace tensorflow { +namespace data { +namespace model { + +// TODO(jsimsa): Use `Node` subclassing instead of types and node statements. 
+void Model::Node::CollectTunables( + std::vector<std::shared_ptr<Node::Tunable>>* tunables) { + tf_shared_lock l(mu_); + for (auto input : inputs_) { + input->CollectTunables(tunables); + } + switch (type_) { + case Type::MAP_AND_BATCH: + case Type::PARALLEL_INTERLEAVE_V2: + case Type::PARALLEL_MAP: { + if (auto* tunable_param = + gtl::FindOrNull(tunable_params_, "parallelism")) { + tunables->push_back(*tunable_param); + } + return; + } + default: + return; + } +} + +int64 Model::Node::GetParameterValue(const string& name) { + if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) { + return (*tunable_param)->value; + } + return constant_params_[name]; +} + +int64 Model::Node::ProcessingTimeLocked() { + switch (type_) { + case Type::BATCH: + case Type::MAP_AND_BATCH: + case Type::PADDED_BATCH: { + int64 batch_size = GetParameterValue("batch_size"); + return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs(); + } + case Type::FILTER: { + std::shared_ptr<Node> input = inputs_.front(); + double ratio = static_cast<double>(input->num_elements()) / + static_cast<double>(num_elements_); + return NanosPerElementLocked() + + static_cast<int64>(ratio * + static_cast<double>(ProcessingTimeForInputs())); + } + case Type::FLAT_MAP: + case Type::INTERLEAVE: + case Type::PARALLEL_INTERLEAVE: + case Type::PARALLEL_INTERLEAVE_V2: { + // TODO(jsimsa): model the first input + // TODO(jsimsa): use processing time history as a prior for future inputs + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); + } + int64 processing_time = + ProcessingTimeForInputs() - inputs_.front()->ProcessingTime(); + return NanosPerElementLocked() + + static_cast<double>(processing_time) / + static_cast<double>(inputs_.size() - 1); + } + case Type::CACHE: + case Type::CONCATENATE: + case Type::MAP: + case Type::PARALLEL_MAP: + case Type::PREFETCH: + // TODO(jsimsa): use processing time history as a prior for future inputs + case Type::REPEAT: + case Type::SHUFFLE: + 
case Type::SKIP: + case Type::TAKE: + case Type::ZIP: { + return NanosPerElementLocked() + ProcessingTimeForInputs(); + } + default: + return NanosPerElementLocked(); + } +} + +int64 Model::Node::OutputTimeLocked(std::vector<int64>* input_times) { + switch (type_) { + case Type::BATCH: + case Type::PADDED_BATCH: { + double batch_size = GetParameterValue("batch_size"); + int64 old_value = (*input_times)[input_times->size() - 1]; + (*input_times)[input_times->size() - 1] = static_cast<int64>( + static_cast<double>(old_value + NanosPerElementLocked()) / + batch_size); + auto cleanup = gtl::MakeCleanup([input_times, old_value]() { + (*input_times)[input_times->size() - 1] = old_value; + }); + return NanosPerElementLocked() + + batch_size * OutputTimeForInputs(input_times); + } + case Type::FILTER: { + std::shared_ptr<Node> input = inputs_.front(); + int64 old_value = (*input_times)[input_times->size() - 1]; + double ratio = static_cast<double>(input->num_elements()) / + static_cast<double>(num_elements_); + (*input_times)[input_times->size() - 1] = static_cast<int64>( + static_cast<double>(old_value + NanosPerElementLocked()) / ratio); + auto cleanup = gtl::MakeCleanup([input_times, old_value]() { + (*input_times)[input_times->size() - 1] = old_value; + }); + return NanosPerElementLocked() + + static_cast<int64>( + static_cast<double>(OutputTimeForInputs(input_times)) * ratio); + } + case Type::FLAT_MAP: + case Type::INTERLEAVE: { + // TODO(jsimsa): model the first input + // TODO(jsimsa): use cycle length metadata instead of `inputs_.size() - 1` + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); + } + int64 delta = + static_cast<int64>(static_cast<double>(NanosPerElementLocked()) * + static_cast<double>(inputs_.size() - 1)); + (*input_times)[input_times->size() - 1] += delta; + auto cleanup = gtl::MakeCleanup([input_times, delta]() { + (*input_times)[input_times->size() - 1] -= delta; + }); + int64 output_time = OutputTimeForInputs(input_times) - + 
inputs_.front()->OutputTime(input_times); + return NanosPerElementLocked() + + static_cast<double>(output_time) / + static_cast<double>(inputs_.size() - 1); + } + case Type::MAP_AND_BATCH: { + double batch_size = GetParameterValue("batch_size"); + double parallelism = GetParameterValue("parallelism"); + int64 delta = + static_cast<int64>(static_cast<double>(NanosPerElementLocked()) / + (batch_size * parallelism)); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 output_time = static_cast<int64>( + static_cast<double>(NanosPerElementLocked()) / parallelism + + batch_size * OutputTimeForInputs(input_times)); + return std::max(0LL, + output_time - input_times->at(input_times->size() - 2)); + } + case Type::PARALLEL_INTERLEAVE: { + // TODO(jsimsa): model the first input + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); + } + int64 delta = static_cast<double>(NanosPerElementLocked()) * + static_cast<double>(inputs_.size() - 1); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 inputs_output_time = OutputTimeForInputs(input_times) - + inputs_.front()->OutputTime(input_times); + double parallelism = GetParameterValue("parallelism"); + int64 output_time = + NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) / + static_cast<double>(inputs_.size() - 1)) / + parallelism); + return std::max(0LL, + output_time - input_times->at(input_times->size() - 2)); + } + case Type::PARALLEL_INTERLEAVE_V2: { + // TODO(jsimsa): model the first input + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); + } + int64 delta = static_cast<double>(NanosPerElementLocked()) * + static_cast<double>(inputs_.size() - 1); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 inputs_output_time = OutputTimeForInputs(input_times) - + 
inputs_.front()->OutputTime(input_times); + double parallelism = + std::min(static_cast<int>(GetParameterValue("cycle_length")), + static_cast<int>(GetParameterValue("parallelism"))); + int64 output_time = + NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) / + static_cast<double>(inputs_.size() - 1)) / + parallelism); + return std::max(0LL, + output_time - input_times->at(input_times->size() - 2)); + } + case Type::PARALLEL_MAP: { + double parallelism = + std::min(port::NumSchedulableCPUs(), + static_cast<int>(GetParameterValue("parallelism"))); + int64 delta = static_cast<int64>( + static_cast<double>(NanosPerElementLocked()) / parallelism); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 output_time = + static_cast<double>(NanosPerElementLocked()) / parallelism + + OutputTimeForInputs(input_times); + return std::max(0LL, + output_time - input_times->at(input_times->size() - 2)); + } + case Type::PREFETCH: { + int64 delta = NanosPerElementLocked(); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + return std::max(0LL, NanosPerElementLocked() + + OutputTimeForInputs(input_times) - + input_times->at(input_times->size() - 2)); + } + case Type::CACHE: + case Type::CONCATENATE: + case Type::MAP: + case Type::REPEAT: + case Type::SHUFFLE: + case Type::SKIP: + case Type::TAKE: + case Type::ZIP: { + int64 delta = NanosPerElementLocked(); + (*input_times)[input_times->size() - 1] += delta; + auto cleanup = gtl::MakeCleanup([input_times, delta]() { + (*input_times)[input_times->size() - 1] -= delta; + }); + return NanosPerElementLocked() + OutputTimeForInputs(input_times); + } + default: + return NanosPerElementLocked(); + } +} + +void Model::AddConstantParameter(const string& node_name, + const string& parameter_name, int64 value) { + tf_shared_lock l(mu_); + auto node = gtl::FindOrNull(lookup_table_, node_name); 
+ if (node) { + (*node)->add_constant_param(parameter_name, value); + } +} + +void Model::AddNode(const string& name, const string& output_name) { + // The name captures the sequence of iterators joined by `::`. We use the full + // sequence as the key in the lookup table, but only the last element of the + // sequence as the node name. + std::vector<string> tokens = + str_util::Split(name, ':', str_util::SkipEmpty()); + // The output name might contain an index. We need to strip it to make it + // possible for the model to successfully identify the output node. + string sanitized_output_name = output_name; + if (str_util::EndsWith(output_name, "]")) { + sanitized_output_name = output_name.substr(0, output_name.rfind('[')); + } + std::shared_ptr<Node> output; + mutex_lock l(mu_); + auto it = lookup_table_.find(sanitized_output_name); + if (it != lookup_table_.end()) { + output = it->second; + } + std::shared_ptr<Node> node(new Node(id_counter_++, tokens.back(), output)); + if (!output_) { + output_ = node; + } + if (output) { + output->add_input(node); + } + lookup_table_.insert(std::make_pair(name, node)); +} + +void Model::AddProcessingTime(const string& name, int64 delta) { + tf_shared_lock l(mu_); + auto node = gtl::FindOrNull(lookup_table_, name); + if (node) { + (*node)->add_processing_time(delta); + } +} + +void Model::AddTunableParameter(const string& node_name, + const string& parameter_name, + std::atomic<int64>* value, int64 min, int64 max, + condition_variable* cond_var) { + tf_shared_lock l(mu_); + // FindOrNull yields nullptr for an unknown node name, so the result must be + // checked before it is dereferenced (as the sibling Add*/Record* methods do). + auto node = gtl::FindOrNull(lookup_table_, node_name); + DCHECK(node); + if (node) { + (*node)->add_tunable_param(parameter_name, value, min, max, cond_var); + } +} + +// The optimization algorithm starts by setting all tunable parallelism +// parameters to 1. It then repeatedly identifies the parameter whose increase +// in parallelism decreases the output time the most.
This process is repeated +// until all parameters reach their maximum values or the projected output time +// is less than or equal to the processing time needed to produce an element +// divided by CPU budget. +void Model::Optimize(int64 cpu_budget) { + tf_shared_lock lock(mu_); + std::vector<std::shared_ptr<Model::Node::Tunable>> tunables; + const int64 processing_time = ProcessingTime(); + tunables = CollectTunables(); + for (auto tunable : tunables) { + tunable->value = 1; + } + while (true) { + const int64 output_time = OutputTime(); + bool all_tunables = true; + for (auto& tunable : tunables) { + if (tunable->value < tunable->max) { + all_tunables = false; + break; + } + } + if (output_time < processing_time / cpu_budget || all_tunables) { + break; + } + int64 best_delta = -1; + Model::Node::Tunable* best_tunable = nullptr; + for (auto& tunable : tunables) { + if (tunable->value == tunable->max) { + continue; + } + tunable->value++; + int64 delta = output_time - OutputTime(); + if (delta > best_delta) { + best_delta = delta; + best_tunable = tunable.get(); + } + tunable->value--; + } + if (!best_tunable) { + // NOTE: This can happen because we are performing the optimization + // while the model data is changing. If this becomes an issue, we should + // look into performing the optimization using a model snapshot. 
+ break; + } + best_tunable->value++; + } + VLOG(2) << "Number of knobs: " << tunables.size(); + for (auto& tunable : tunables) { + VLOG(2) << "Setting tunable parameter: " << tunable->value; + tunable->value_ptr->store(tunable->value); + if (tunable->cond_var) { + tunable->cond_var->notify_all(); + } + } +} + +void Model::RecordElement(const string& name) { + tf_shared_lock l(mu_); + auto node = gtl::FindOrNull(lookup_table_, name); + if (node) { + (*node)->record_element(); + } +} + +void Model::RecordStart(const string& name, bool stop_output) { + tf_shared_lock l(mu_); + auto node = gtl::FindOrNull(lookup_table_, name); + if (node) { + if (stop_output && (*node)->output()) { + (*node)->output()->record_stop(); + } + (*node)->record_start(); + } +} + +void Model::RecordStop(const string& name, bool start_output) { + tf_shared_lock l(mu_); + auto node = gtl::FindOrNull(lookup_table_, name); + if (node) { + (*node)->record_stop(); + if (start_output && (*node)->output()) { + (*node)->output()->record_start(); + } + } +} + +void Model::RemoveNode(const string& name) { + mutex_lock l(mu_); + auto node = gtl::FindOrNull(lookup_table_, name); + if (node && (*node)->output()) { + (*node)->output()->remove_input(*node); + } + lookup_table_.erase(name); +} + +std::vector<std::shared_ptr<Model::Node::Tunable>> Model::CollectTunables() { + std::vector<std::shared_ptr<Model::Node::Tunable>> tunables; + output_->CollectTunables(&tunables); + return tunables; +} + +int64 Model::OutputTime() { + std::vector<int64> input_times(1, 0); + return output_->OutputTime(&input_times); +} + +int64 Model::ProcessingTime() { return output_->ProcessingTime(); } + +} // namespace model +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h new file mode 100644 index 0000000000..26402f5cd3 --- /dev/null +++ b/tensorflow/core/framework/model.h @@ -0,0 +1,404 @@ +/* Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ + +#include <list> +#include <memory> +#include <string> +#include <thread> // (b/114492873): move this include into core/platform +#include <utility> +#include <vector> + +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" + +namespace tensorflow { +namespace data { +namespace model { + +// Abstract representation of a TensorFlow input pipeline that can be used +// for collecting runtime information and optimizing performance. It collects +// runtime information about execution of the input pipeline that is used to +// create a performance model, which is in turn used to identify optimal values +// of tunable parameters. +// +// Developers of tf.data transformations are not expected to interact with this +// class directly. Boiler plate code for creating the abstract representation of +// the input pipeline and collecting runtime information has been added to the +// implementation of `DatasetBase` and `DatasetBaseIterator` respectively. 
+class Model {
+ public:
+  Model() = default;
+
+  // Adds a constant parameter for the given node.
+  void AddConstantParameter(const string& node_name,
+                            const string& parameter_name, int64 value)
+      LOCKS_EXCLUDED(mu_);
+
+  // Adds a node with the given name and given output (identified by name).
+  void AddNode(const string& name, const string& output_name)
+      LOCKS_EXCLUDED(mu_);
+
+  // Increments the processing time for the given node.
+  void AddProcessingTime(const string& name, int64 delta) LOCKS_EXCLUDED(mu_);
+
+  // Adds a tunable parameter for the given node.
+  void AddTunableParameter(const string& node_name,
+                           const string& parameter_name,
+                           std::atomic<int64>* value, int64 min, int64 max,
+                           condition_variable* cond_var) LOCKS_EXCLUDED(mu_);
+
+  // Runs optimization.
+  void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);
+
+  // Records that a node has produced an element.
+  void RecordElement(const string& name) LOCKS_EXCLUDED(mu_);
+
+  // Records that the given node has started work. If `stop_output` is set, it
+  // also records that the output of the given node has stopped work.
+  void RecordStart(const string& name, bool stop_output) LOCKS_EXCLUDED(mu_);
+
+  // Records that the given node has stopped work. If `start_output` is set, it
+  // also records that the output of the given node has started work.
+  void RecordStop(const string& name, bool start_output) LOCKS_EXCLUDED(mu_);
+
+  // Removes the given node.
+  void RemoveNode(const string& name) LOCKS_EXCLUDED(mu_);
+
+ private:
+  // Abstract representation of a TensorFlow input pipeline node. It collects
+  // information about inputs to this node, processing time spent executing the
+  // node logic, number of elements produced by the node, and various other
+  // information (e.g. batch size or execution parallelism).
+  //
+  // Developers of tf.data transformations are not expected to interact with
+  // this class directly.
Boiler plate code for creating the abstract + // representation of the input pipeline and collecting common information has + // been added to the implementation of `DatasetBase` and `DatasetBaseIterator` + // respectively. + // + // In addition, `DatasetBaseIterator` provides wrappers that can be used for + // transformation-specific information collection. The `SetMetadata` wrapper + // can be used to pass arbitrary metadata to the modeling framework, while the + // `StartWork` and `StopWork` wrappers should be used to correctly account for + // processing time of multi-threaded transformation that yield the CPU; such + // transformations should invoke `StartWork()` when a transformation thread + // starts executing (e.g. when created or woken up) and `StopWork()` when a + // transformation thread stops executing (e.g. when returning or waiting). + // + // TODO(jsimsa): Create an API to capture the abstract semantics of each + // tf.data transformation and replace switch-case blocks with inheritance. + class Node { + public: + // Represents a tunable parameter. + struct Tunable { + Tunable(std::atomic<int64>* value, int64 min, int64 max, + condition_variable* cond_var) + : value(*value), + min(min), + max(max), + value_ptr(value), + cond_var(cond_var) {} + + // Identifies the model value of the parameter. This can be different from + // the actual value (e.g. during optimization search). + int64 value; + + // Identifies the minimum value of the parameter. + int64 min; + + // Identifies the maximum value of the parameter. + int64 max; + + // Points to the actual value of the parameter. Not owned. + std::atomic<int64>* value_ptr; + + // If non-null, this condition variable is notified when the model updates + // the actual value of the parameter (via `value_ptr`). Not owned. 
+ condition_variable* cond_var; + }; + + Node(int64 id, const string& name, std::shared_ptr<Node> output) + : id_(id), name_(name), type_(TypeFromName(name)), output_(output) {} + + // Adds a constant parameter. + void add_constant_param(const string& name, int64 value) + LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + constant_params_[name] = value; + } + + // Adds an input. + void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + inputs_.push_back(node); + } + + // Increments the aggregate processing time by the given delta. + void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + processing_time_ += delta; + } + + // Adds a tunable parameter. + void add_tunable_param(const string& name, std::atomic<int64>* value, + int64 min, int64 max, condition_variable* cond_var) + LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + tunable_params_[name] = + std::make_shared<Tunable>(value, min, max, cond_var); + } + + // Returns the unique node ID. + int64 id() LOCKS_EXCLUDED(mu_) { return id_; } + + // Returns the node inputs. + std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return inputs_; + } + + // Returns the node name. + const string& name() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return name_; + } + + // Returns the number of elements produced by the node. + int64 num_elements() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return num_elements_; + } + + // Returns the node output. + std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return output_; + } + + // Records that the node produced an element. + void record_element() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + num_elements_++; + } + + // Records that a node thread has started executing. 
+ void record_start() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos(); + } + + // Records that a node thread has stopped executing. + void record_stop() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + std::thread::id tid = std::this_thread::get_id(); + auto start_time = gtl::FindOrNull(work_start_, tid); + DCHECK(start_time) + << "Encountered a stop event that was not preceded by a start event."; + if (start_time) { + processing_time_ += Env::Default()->NowNanos() - *start_time; + work_start_.erase(tid); + } + } + + // Removes an input. + void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + inputs_.remove(input); + } + + // Set the node output. + void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + output_ = output; + } + + // Collects tunable parameters in the subtree rooted in this node. + void CollectTunables(std::vector<std::shared_ptr<Tunable>>* tunables) + LOCKS_EXCLUDED(mu_); + + // Returns the per-element output time for this node. + int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return OutputTimeLocked(input_times); + } + + // Returns the per-element processing time spent in the subtree rooted in + // this node. + int64 ProcessingTime() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return ProcessingTimeLocked(); + } + + private: + enum class Type { + BATCH = 0, + CACHE, + CONCATENATE, + FILTER, + FLAT_MAP, + INTERLEAVE, + MAP, + MAP_AND_BATCH, + PADDED_BATCH, + PARALLEL_INTERLEAVE, + PARALLEL_INTERLEAVE_V2, + PARALLEL_MAP, + PREFETCH, + REPEAT, + SHUFFLE, + SKIP, + TAKE, + ZIP, + UNKNOWN, + }; + + // Gets a value of the given parameter (tunable or constant). + int64 GetParameterValue(const string& name) SHARED_LOCKS_REQUIRED(mu_); + + // Returns the per-element processing time spent in this node. 
+ int64 NanosPerElement() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return NanosPerElementLocked(); + } + + int64 NanosPerElementLocked() SHARED_LOCKS_REQUIRED(mu_) { + if (num_elements_ == 0) { + return 0; + } + return (int64)((double)processing_time_ / (double)num_elements_); + } + + int64 OutputTimeLocked(std::vector<int64>* input_times) + SHARED_LOCKS_REQUIRED(mu_); + + int64 OutputTimeForInputs(std::vector<int64>* input_times) + SHARED_LOCKS_REQUIRED(mu_) { + int64 sum = 0; + for (auto input : inputs_) { + sum += input->OutputTime(input_times); + } + return sum; + } + + int64 ProcessingTimeLocked() SHARED_LOCKS_REQUIRED(mu_); + + // Returns the per-element processing time spent in the inputs of this node. + int64 ProcessingTimeForInputs() SHARED_LOCKS_REQUIRED(mu_) { + int64 sum = 0; + for (auto input : inputs_) { + sum += input->ProcessingTime(); + } + return sum; + } + + Type TypeFromName(const string& name) SHARED_LOCKS_REQUIRED(mu_) { + if (name_ == "Batch") { + return Type::BATCH; + } + if (str_util::EndsWith(name_, "Cache")) { + return Type::CACHE; + } + if (name_ == "Concatenate") { + return Type::CONCATENATE; + } + if (name_ == "Filter") { + return Type::FILTER; + } + if (name_ == "FlatMap") { + return Type::FLAT_MAP; + } + if (name_ == "Interleave") { + return Type::INTERLEAVE; + } + if (name_ == "Map") { + return Type::MAP; + } + if (name_ == "MapAndBatch") { + return Type::MAP_AND_BATCH; + } + if (name_ == "PaddedBatch") { + return Type::PADDED_BATCH; + } + if (name_ == "ParallelInterleave") { + return Type::PARALLEL_INTERLEAVE; + } + if (name_ == "ParallelInterleaveV2") { + return Type::PARALLEL_INTERLEAVE_V2; + } + if (name_ == "ParallelMap") { + return Type::PARALLEL_MAP; + } + if (name_ == "Prefetch") { + return Type::PREFETCH; + } + if (str_util::EndsWith(name_, "Repeat")) { + return Type::REPEAT; + } + if (name_ == "Shuffle") { + return Type::SHUFFLE; + } + if (str_util::EndsWith(name_, "Skip")) { + return Type::SKIP; + } + if 
(str_util::EndsWith(name_, "Take")) { + return Type::TAKE; + } + if (name_ == "Zip") { + return Type::ZIP; + } + return Type::UNKNOWN; + } + + mutex mu_; + const int64 id_; + const string name_; + const Type type_; + int64 processing_time_ GUARDED_BY(mu_) = 0; + int64 num_elements_ GUARDED_BY(mu_) = 0; + std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_); + std::map<string, int64> constant_params_ GUARDED_BY(mu_); + // Tunables are shared with the model during optimization. + std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_); + std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_); + std::shared_ptr<Node> output_ GUARDED_BY(mu_); + }; + + std::vector<std::shared_ptr<Node::Tunable>> CollectTunables() + SHARED_LOCKS_REQUIRED(mu_); + + int64 OutputTime() SHARED_LOCKS_REQUIRED(mu_); + + int64 ProcessingTime() SHARED_LOCKS_REQUIRED(mu_); + + // Used for coordination between different input pipeline threads. Exclusive + // access is required only when adding or removing nodes. Concurrent access to + // existing nodes is protected by a node mutex. 
+ mutex mu_; + int64 id_counter_ GUARDED_BY(mu_) = 1; + std::shared_ptr<Node> output_ GUARDED_BY(mu_); + std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_); +}; + +} // namespace model +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index bacc1d72c4..43ac1d0ada 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -372,6 +372,14 @@ Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, node_def.name()); } +Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs) { + for (const auto& arg : op_def.input_arg()) { + TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); + } + return Status::OK(); +} + Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, int output_port, DataType* output_type) { DataTypeVector output_types; @@ -397,12 +405,18 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def, Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, DataTypeVector* inputs, DataTypeVector* outputs) { - for (const auto& arg : op_def.input_arg()) { - TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs)); - } + TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs)); return OutputTypesForNode(node_def, op_def, outputs); } +Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def, + int* num_outputs) { + DataTypeVector outputs; + TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs)); + *num_outputs = outputs.size(); + return Status::OK(); +} + Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { if (node_def.op() != op_def.name()) { return errors::InvalidArgument("NodeDef op '", node_def.op(), diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 
499034cab2..187bfa2c88 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -249,6 +249,10 @@ const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name); // REQUIRES: ValidateOpDef(op_def).ok() Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def, int input_port, DataType* input_type); +// Computes the input types for a specific node. +// REQUIRES: ValidateOpDef(op_def).ok() +Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def, + DataTypeVector* inputs); // Computes the output type for a specific node output. // REQUIRES: ValidateOpDef(op_def).ok() Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def, @@ -261,6 +265,10 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def, // REQUIRES: ValidateOpDef(op_def).ok() Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, DataTypeVector* inputs, DataTypeVector* outputs); +// Computes the number of outputs for a specific node. +// REQUIRES: ValidateOpDef(op_def).ok() +Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def, + int* num_outputs); // Validates that the NodeDef: // * Defines all expected attrs from the OpDef. 
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc index 74cc594863..d9d437024a 100644 --- a/tensorflow/core/framework/node_def_util_test.cc +++ b/tensorflow/core/framework/node_def_util_test.cc @@ -370,6 +370,48 @@ TEST(NodeDefUtilTest, ValidSyntax) { "Illegal op input name 'a:00"); } +TEST(InputTypesForNode, Simple) { + const OpDef op_def = ToOpDef(OpDefBuilder("Simple") + .Input("a: float") + .Input("b: int32") + .Output("c: string") + .Output("d: bool")); + const NodeDef node_def = ToNodeDef( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + DataTypeVector types; + EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok()); + EXPECT_EQ(types[0], DT_FLOAT); + EXPECT_EQ(types[1], DT_INT32); + + DataType type; + EXPECT_TRUE(InputTypeForNode(node_def, op_def, 0, &type).ok()); + EXPECT_EQ(type, DT_FLOAT); + EXPECT_TRUE(InputTypeForNode(node_def, op_def, 1, &type).ok()); + EXPECT_EQ(type, DT_INT32); + EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok()); +} + +TEST(OutputTypesForNode, Simple) { + const OpDef op_def = ToOpDef(OpDefBuilder("Simple") + .Input("a: float") + .Input("b: int32") + .Output("c: string") + .Output("d: bool")); + const NodeDef node_def = ToNodeDef( + NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput())); + DataTypeVector types; + EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok()); + EXPECT_EQ(types[0], DT_STRING); + EXPECT_EQ(types[1], DT_BOOL); + + DataType type; + EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 0, &type).ok()); + EXPECT_EQ(type, DT_STRING); + EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 1, &type).ok()); + EXPECT_EQ(type, DT_BOOL); + EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok()); +} + TEST(NameRangesForNodeTest, Simple) { const OpDef op_def = ToOpDef(OpDefBuilder("Simple") .Input("a: float") diff --git a/tensorflow/core/framework/op_kernel.cc 
b/tensorflow/core/framework/op_kernel.cc index c694e10193..3e34bf0418 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -80,10 +81,8 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs, // OpKernel ------------------------------------------------------------------ -// TODO(mrry): Convert to std::make_unique when available. OpKernel::OpKernel(OpKernelConstruction* context) - : OpKernel(context, - std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {} + : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {} OpKernel::OpKernel(OpKernelConstruction* context, std::unique_ptr<const NodeDef> node_def) @@ -266,9 +265,12 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs) params_->ensure_eigen_gpu_device(); if (params_->eigen_gpu_device != nullptr) { Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); - params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device, - params_->op_device_context, - eigen_gpu_allocator); + Status s = params_->device->ReinitializeGpuDevice( + this, params_->eigen_gpu_device, params_->op_device_context, + eigen_gpu_allocator); + if (!s.ok()) { + SetStatus(s); + } } if (params_->record_tensor_accesses) { referenced_tensors_.Init(); @@ -525,10 +527,8 @@ std::unique_ptr<Tensor> OpKernelContext::forward_input( return nullptr; } } - // TODO(rmlarsen): Use MakeUnique here. There is already a copy in - // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of - // general cleanup of ownership in this code. 
- std::unique_ptr<Tensor> output_tensor(new Tensor()); + + auto output_tensor = MakeUnique<Tensor>(); CHECK(output_tensor->CopyFrom(*input.tensor, output_shape)); return output_tensor; } diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index e752599de1..4bbd6c3d7d 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -372,18 +372,37 @@ class OpKernelConstruction { template <typename ListType, typename ElementType> class OpArgIterator { public: - typedef OpArgIterator<ListType, ElementType> ME; + using iterator_category = std::forward_iterator_tag; + using value_type = ElementType; + using pointer = ElementType*; + using reference = ElementType&; + using difference_type = ptrdiff_t; + OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {} - bool operator==(const ME& rhs) { + + bool operator==(const OpArgIterator& rhs) { DCHECK(list_ == rhs.list_); return i_ == rhs.i_; } - bool operator!=(const ME& rhs) { + + bool operator!=(const OpArgIterator& rhs) { DCHECK(list_ == rhs.list_); return i_ != rhs.i_; } - void operator++() { ++i_; } - ElementType& operator*() { return (*list_)[i_]; } + + OpArgIterator operator++() { // prefix ++it + ++i_; + return *this; + } + + OpArgIterator operator++(int) { // postfix it++ + OpArgIterator old_value = *this; + ++i_; + return old_value; + } + + reference operator*() { return (*list_)[i_]; } + pointer operator->() { return &(*list_)[i_]; } private: const ListType* const list_; @@ -394,7 +413,7 @@ class OpArgIterator { // that are passed to the op as a single named argument. 
class OpInputList { public: - typedef OpArgIterator<OpInputList, const Tensor&> Iterator; + typedef OpArgIterator<OpInputList, const Tensor> Iterator; OpInputList() : ctx_(nullptr), start_(0), stop_(0) {} OpInputList(OpKernelContext* ctx, int start, int stop) : ctx_(ctx), start_(start), stop_(stop) {} diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc index dfc5aa7747..75ed4a4eaf 100644 --- a/tensorflow/core/framework/op_segment.cc +++ b/tensorflow/core/framework/op_segment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/op_segment.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -99,4 +100,11 @@ void OpSegment::RemoveHold(const string& session_handle) { delete item; } +bool OpSegment::ShouldOwnKernel(FunctionLibraryRuntime* lib, + const string& node_op) { + // OpSegment should not own kernel if the node is stateless, or a function. + return lib->IsStateful(node_op) && + lib->GetFunctionLibraryDefinition()->Find(node_op) == nullptr; +} + } // end namespace tensorflow diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h index 4433a2554f..37d939ea2b 100644 --- a/tensorflow/core/framework/op_segment.h +++ b/tensorflow/core/framework/op_segment.h @@ -60,6 +60,10 @@ class OpSegment { Status FindOrCreate(const string& session_handle, const string& node_name, OpKernel** kernel, CreateKernelFn create_fn); + // Returns true if OpSegment should own the kernel. 
+ static bool ShouldOwnKernel(FunctionLibraryRuntime* lib, + const string& node_op); + private: // op name -> OpKernel typedef std::unordered_map<string, OpKernel*> KernelMap; diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index 0a19861efd..508a8d3149 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -271,7 +271,7 @@ string ContainerInfo::DebugString() const { "]"); } -ResourceHandle HandleFromInput(OpKernelContext* ctx, int input) { +const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input) { return ctx->input(input).flat<ResourceHandle>()(0); } @@ -288,4 +288,13 @@ Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { return ctx->resource_manager()->Delete(p); } +Status ResourceHandlesShape(shape_inference::InferenceContext* c) { + int n; + TF_RETURN_IF_ERROR(c->GetAttr("N", &n)); + for (int i = 0; i < n; ++i) { + c->set_output(i, c->Scalar()); + } + return Status::OK(); +} + } // end namespace tensorflow diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index f8a587c9b5..abb6635984 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_ #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_ +#include <memory> #include <string> #include <typeindex> #include <typeinfo> @@ -79,7 +80,7 @@ class ResourceBase : public core::RefCounted { virtual string DebugString() = 0; // Returns memory used by this resource. - virtual int64 MemoryUsed() const { return 0; }; + virtual int64 MemoryUsed() const { return 0; } }; // Container used for per-step resources. 
@@ -127,6 +128,14 @@ class ResourceMgr {
   Status Lookup(const string& container, const string& name,
                 T** resource) const TF_MUST_USE_RESULT;

+  // Similar to Lookup, but looks up multiple resources at once, with only a
+  // single lock acquisition.
+  template <typename T>
+  Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
+                        containers_and_names,
+                    std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
+                        resource) const TF_MUST_USE_RESULT;
+
   // If "container" has a resource "name", returns it in
   // "*resource". Otherwise, invokes creator() to create the resource.
   // The caller takes the ownership of one ref on "*resource".
@@ -234,7 +243,7 @@ ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
                                          const string& name);

 // Returns a resource handle from a numbered op input.
-ResourceHandle HandleFromInput(OpKernelContext* ctx, int input);
+const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
 Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
                        ResourceHandle* handle);

@@ -246,6 +255,12 @@ Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
 template <typename T>
 Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);

+// Looks up multiple resources pointed to by a sequence of resource handles.
+template <typename T>
+Status LookupResources(
+    OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p,
+    std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values);
+
 // Looks up or creates a resource.
template <typename T> Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, @@ -348,6 +363,8 @@ class ResourceHandleOp : public OpKernel { void Compute(OpKernelContext* ctx) override; + bool IsExpensive() override { return false; } + private: string container_; string name_; @@ -356,6 +373,26 @@ class ResourceHandleOp : public OpKernel { std::atomic<bool> initialized_{false}; }; +// Utility op kernel to produce a handle to a resource of type T. +template <typename T> +class ResourceHandlesOp : public OpKernel { + public: + explicit ResourceHandlesOp(OpKernelConstruction* context); + + void Compute(OpKernelContext* ctx) override; + + bool IsExpensive() override { return false; } + + private: + std::vector<string> containers_; + std::vector<string> names_; + mutex mutex_; + std::vector<Tensor> resources_; + std::atomic<bool> initialized_{false}; +}; + +Status ResourceHandlesShape(shape_inference::InferenceContext* c); + // Registers a kernel for an op which produces a handle to a resource of the // specified type. 
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type) \ @@ -388,6 +425,24 @@ Status ResourceMgr::Lookup(const string& container, const string& name, } template <typename T> +Status ResourceMgr::LookupMany( + absl::Span<std::pair<const string*, const string*> const> + containers_and_names, + std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const { + CheckDeriveFromResourceBase<T>(); + tf_shared_lock l(mu_); + resources->resize(containers_and_names.size()); + for (size_t i = 0; i < containers_and_names.size(); ++i) { + T* resource; + TF_RETURN_IF_ERROR(LookupInternal(*containers_and_names[i].first, + *containers_and_names[i].second, + &resource)); + (*resources)[i].reset(resource); + } + return Status::OK(); +} + +template <typename T> Status ResourceMgr::LookupInternal(const string& container, const string& name, T** resource) const { ResourceBase* found = nullptr; @@ -497,6 +552,19 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, } template <typename T> +Status LookupResources( + OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p, + std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values) { + std::vector<std::pair<const string*, const string*>> containers_and_names( + p.size()); + for (size_t i = 0; i < p.size(); ++i) { + TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i])); + containers_and_names[i] = {&p[i]->container(), &p[i]->name()}; + } + return ctx->resource_manager()->LookupMany(containers_and_names, values); +} + +template <typename T> Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, T** value, std::function<Status(T**)> creator) { TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); @@ -553,6 +621,46 @@ void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) { ctx->set_output(0, resource_); } +template <typename T> +ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context) + : OpKernel(context) { + int n; + OP_REQUIRES_OK(context, 
context->GetAttr("N", &n)); + OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_)); + OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_)); + OP_REQUIRES( + context, containers_.size() == n, + errors::InvalidArgument("Number of containers (", containers_.size(), + ") must be equal to N (", n, ")")); + OP_REQUIRES(context, names_.size() == n, + errors::InvalidArgument("Number of names (", names_.size(), + ") must be equal to N (", n, ")")); + resources_.resize(n); +} + +template <typename T> +void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) { + if (!initialized_.load()) { + mutex_lock ml(mutex_); + // Checking again to see if another thread has initialized the resource. + if (!initialized_.load()) { + AllocatorAttributes attr; + attr.set_on_host(true); + for (size_t i = 0; i < resources_.size(); ++i) { + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), + &resources_[i], attr)); + ResourceHandle h = + MakeResourceHandle<T>(ctx, containers_[i], names_[i]); + resources_[i].template scalar<ResourceHandle>()() = h; + } + initialized_.store(true); + } + } + for (size_t i = 0; i < resources_.size(); ++i) { + ctx->set_output(i, resources_[i]); + } +} + } // end namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_ diff --git a/tensorflow/core/framework/stats_aggregator.h b/tensorflow/core/framework/stats_aggregator.h index 4a18efc940..af53ed0a3c 100644 --- a/tensorflow/core/framework/stats_aggregator.h +++ b/tensorflow/core/framework/stats_aggregator.h @@ -25,6 +25,8 @@ namespace tensorflow { class Summary; +namespace data { + // A `StatsAggregator` accumulates statistics incrementally. A // `StatsAggregator` can accumulate multiple different statistics, distinguished // by a string name. 
@@ -87,6 +89,7 @@ class StatsAggregatorResource : public ResourceBase { const std::shared_ptr<StatsAggregator> stats_aggregator_; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_ diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 516afa517d..1dea6da911 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -812,6 +812,28 @@ Tensor Tensor::Slice(int64 start, int64 limit) const { return ret; } +Tensor Tensor::SubSlice(int64 index) const { + CHECK_GE(dims(), 1); // Crash ok. + CHECK_LE(0, index); // Crash ok. + int64 dim0_size = shape_.dim_size(0); + CHECK_LT(index, dim0_size); // Crash ok. + Tensor ret; + ret.shape_ = shape_; + ret.shape_.RemoveDim(0); + ret.set_dtype(dtype()); + ret.buf_ = nullptr; + if (dim0_size > 0) { + const int64 elems_per_dim0 = NumElements() / dim0_size; + const int64 delta = index * elems_per_dim0; + const int64 num_elems = elems_per_dim0; + if (buf_) { + DataType dt = dtype(); + CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems)); + } + } + return ret; +} + bool Tensor::FromProto(const TensorProto& proto) { return FromProto(cpu_allocator(), proto); } @@ -948,9 +970,69 @@ void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape, } } +// Appends the spacing between elements for a given dim onto a result string +void PrintDimSpacing(int dim_index, int num_dims, string* result) { + if (dim_index == num_dims - 1) { + strings::StrAppend(result, " "); + return; + } + for (int j = 0; j < num_dims - dim_index - 1; j++) { + strings::StrAppend(result, "\n"); + } + for (int j = 0; j <= dim_index; j++) { + strings::StrAppend(result, " "); + } +} + +// Print from left dim to right dim recursively. 
+template <typename T> +void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape, + int64 num_elts_at_ends, int num_dims, const T* data, + int64 data_index, string* result) { + // We have recursed beyond all the dimensions into a single element + // of the tensor. + if (dim_index == num_dims) { + strings::StrAppend(result, PrintOneElement(data[data_index])); + return; + } + + strings::StrAppend(result, "["); + int64 element_count = shape[dim_index]; + int64 start_of_end = + std::max(num_elts_at_ends, element_count - num_elts_at_ends); + + // Loop every element of one dim. + int64 elements_per_iter = 1; + for (int i = dim_index + 1; i < num_dims; i++) { + elements_per_iter *= shape[i]; + } + for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) { + if (i > 0) { + PrintDimSpacing(dim_index, num_dims, result); + } + + // As for each element, print the sub-dim. + PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data, + data_index + elements_per_iter * i, result); + } + if (element_count > 2 * num_elts_at_ends) { + PrintDimSpacing(dim_index, num_dims, result); + strings::StrAppend(result, "..."); + } + for (int64 i = start_of_end; i < element_count; i++) { + // As for each element, print the sub-dim. 
+ PrintDimSpacing(dim_index, num_dims, result); + PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data, + data_index + elements_per_iter * i, result); + } + + strings::StrAppend(result, "]"); +} + template <typename T> string SummarizeArray(int64 limit, int64 num_elts, - const TensorShape& tensor_shape, const char* data) { + const TensorShape& tensor_shape, const char* data, + const bool print_v2) { string ret; const T* array = reinterpret_cast<const T*>(data); @@ -963,17 +1045,26 @@ string SummarizeArray(int64 limit, int64 num_elts, if (num_elts > limit) strings::StrAppend(&ret, "..."); return ret; } - int64 data_index = 0; - const int shape_size = tensor_shape.dims(); - PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret); + if (print_v2) { + const int num_dims = tensor_shape.dims(); + PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret); + } else { + int64 data_index = 0; + const int shape_size = tensor_shape.dims(); + PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret); + + if (num_elts > limit) strings::StrAppend(&ret, "..."); + } - if (num_elts > limit) strings::StrAppend(&ret, "..."); return ret; } } // namespace -string Tensor::SummarizeValue(int64 max_entries) const { +string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const { const int64 num_elts = NumElements(); + if (max_entries < 0) { + max_entries = num_elts; + } size_t limit = std::min(max_entries, num_elts); if ((limit > 0) && (buf_ == nullptr)) { return strings::StrCat("uninitialized Tensor of ", num_elts, @@ -982,50 +1073,54 @@ string Tensor::SummarizeValue(int64 max_entries) const { const char* data = limit > 0 ? 
tensor_data().data() : nullptr; switch (dtype()) { case DT_HALF: - return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data); + return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data, + print_v2); break; case DT_FLOAT: - return SummarizeArray<float>(limit, num_elts, shape_, data); + return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2); break; case DT_DOUBLE: - return SummarizeArray<double>(limit, num_elts, shape_, data); + return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2); break; case DT_UINT32: - return SummarizeArray<uint32>(limit, num_elts, shape_, data); + return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2); break; case DT_INT32: - return SummarizeArray<int32>(limit, num_elts, shape_, data); + return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2); break; case DT_UINT8: case DT_QUINT8: - return SummarizeArray<uint8>(limit, num_elts, shape_, data); + return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2); break; case DT_UINT16: case DT_QUINT16: - return SummarizeArray<uint16>(limit, num_elts, shape_, data); + return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2); break; case DT_INT16: case DT_QINT16: - return SummarizeArray<int16>(limit, num_elts, shape_, data); + return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2); break; case DT_INT8: case DT_QINT8: - return SummarizeArray<int8>(limit, num_elts, shape_, data); + return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2); break; case DT_UINT64: - return SummarizeArray<uint64>(limit, num_elts, shape_, data); + return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2); break; case DT_INT64: - return SummarizeArray<int64>(limit, num_elts, shape_, data); + return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2); break; case DT_BOOL: // TODO(tucker): Is it better to emit "True False..."? This // will emit "1 0..." which is more compact. 
- return SummarizeArray<bool>(limit, num_elts, shape_, data); + return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2); break; default: { // All irregular cases string ret; + if (print_v2) { + strings::StrAppend(&ret, "["); + } // TODO(irving): Don't call flat every time around this // loop. for (size_t i = 0; i < limit; ++i) { @@ -1045,6 +1140,9 @@ string Tensor::SummarizeValue(int64 max_entries) const { } } if (max_entries < num_elts) strings::StrAppend(&ret, "..."); + if (print_v2) { + strings::StrAppend(&ret, "]"); + } return ret; } } diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 1b19ab5da3..d0f9eb56e2 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -37,11 +37,12 @@ namespace tensorflow { class AllocationDescription; class Allocator; class OpKernelContext; +class Tensor; class TensorBuffer; class TensorCApi; class TensorDescription; class TensorProto; -class VariantTensorData; + namespace batch_util { Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index); Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index); @@ -153,7 +154,7 @@ class Tensor { /// Returns the estimated memory usage of this tensor. size_t TotalBytes() const; - // Returns the size of sallocated memory for this tensor. + // Returns the size of allocated memory for this tensor. size_t AllocatedBytes() const; /// Returns true iff this tensor is aligned. @@ -199,10 +200,29 @@ class Tensor { /// must check the returned tensor's alignment before calling certain /// methods that have alignment requirement (e.g., `flat()`, `tensor()`). /// + /// NOTE: When fed with an N-dimensional tensor, this method returns a tensor + /// also with N dimensions. If you want to select a sub tensor, see SubSlice. 
+ /// /// REQUIRES: `dims()` >= 1 /// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)` Tensor Slice(int64 dim0_start, int64 dim0_limit) const; + /// \brief Select a subslice from this tensor along the 1st dimension. + /// + /// When fed with an N-dimensional tensor, this method returns a tensor with + /// N-1 dimensions, where the returned tensor is a subslice of the input + /// tensor along the first dimension. The N-1 dimensions of the returned + /// tensor are the last N-1 dimensions of the input tensor. + /// + /// NOTE: The returned tensor may not satisfy the same alignment + /// requirement as this tensor depending on the shape. The caller + /// must check the returned tensor's alignment before calling certain + /// methods that have alignment requirement (e.g., `flat()`, `tensor()`). + /// + /// REQUIRES: `dims()` >= 1 + /// REQUIRES: `0 <= index < dim_size(0)` + Tensor SubSlice(int64 index) const; + /// \brief Parse `other` and construct the tensor. /// Returns `true` iff the parsing succeeds. If the parsing fails, @@ -429,7 +449,7 @@ class Tensor { int64 begin) const; /// Render the first `max_entries` values in `*this` into a string. - string SummarizeValue(int64 max_entries) const; + string SummarizeValue(int64 max_entries, bool print_v2 = false) const; /// A human-readable summary of the tensor suitable for debugging. string DebugString() const; diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index 84a373c196..c596604143 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/math/math_util.h" @@ -1227,6 +1228,45 @@ TEST(Tensor, Slice_Basic) { } } +TEST(Tensor, SubSlice_Basic) { + { // General + Tensor x(DT_FLOAT, TensorShape({10, 4, 36})); + // Fills in known values. + for (int i = 0; i < 10; ++i) { + x.SubSlice(i).flat<float>().setConstant(i * 1.f); + } + // A simple sub-slice along dim0. + Tensor y = x.SubSlice(5); + EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 36}))); + auto tx = x.tensor<float, 3>(); + auto ty = y.tensor<float, 2>(); + for (int j = 0; j < 4; ++j) { + for (int k = 0; k < 36; ++k) { + EXPECT_EQ(ty(j, k), 5.0); + EXPECT_EQ(&tx(5, j, k), &ty(j, k)); + } + } + Tensor z = y.SubSlice(3).SubSlice(31); + auto tz = z.unaligned_flat<float>(); + EXPECT_EQ(*tz.data(), 5.0); + } + { + // Test unaligned access via a SubSlice. + Tensor x(DT_FLOAT, TensorShape({30, 5})); + x.flat<float>().setConstant(0.0); + + // Take an unaligned subslice. + Tensor y = x.SubSlice(1); +#if EIGEN_MAX_ALIGN_BYTES > 0 + EXPECT_FALSE(y.IsAligned()); +#endif + y.unaligned_flat<float>().setConstant(1.0); + for (int64 i = 0; i < y.NumElements(); ++i) { + EXPECT_EQ(1.0, y.unaligned_flat<float>()(i)); + } + } +} + template <typename T> Tensor MkTensor(DataType dt, const TensorShape& shape, std::vector<T> init_values) { @@ -1294,6 +1334,63 @@ TEST(SummarizeValue, STRING) { EXPECT_EQ("one two three four five one...", x.SummarizeValue(6)); } +TEST(SummarizeValue, INT32_PRINT_V2) { + Tensor x = MkTensor<int>(DT_INT32, TensorShape({5}), {1, 2, 3, 4, 0}); + EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true)); + EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true)); + EXPECT_EQ("[1 2 ... 
4 0]", x.SummarizeValue(2, true)); + EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true)); + x = MkTensor<int>(DT_INT32, TensorShape({2, 2}), {1, 2, 3, 4, 0}); + EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true)); + x = MkTensor<int>(DT_INT32, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0}); + EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]", + x.SummarizeValue(16, true)); + x = MkTensor<int>(DT_INT32, TensorShape({0}), {}); + EXPECT_EQ("[]", x.SummarizeValue(16, true)); +} + +TEST(SummarizeValue, INT32Dims_PRINT_V2) { + Tensor x = MkTensor<int>(DT_INT32, TensorShape({3, 4}), + {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}); + EXPECT_EQ("[[1 ... 4]\n ...\n [9 ... 12]]", x.SummarizeValue(1, true)); + EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]", + x.SummarizeValue(10, true)); + EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]", + x.SummarizeValue(-1, true)); +} + +TEST(SummarizeValue, FLOAT_PRINT_V2) { + Tensor x = MkTensor<float>(DT_FLOAT, TensorShape({5}), {1, 2, 3, 4, 0}); + EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true)); + EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true)); + EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true)); + EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true)); + x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2}), {1, 2, 3, 4, 0}); + EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true)); + x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0}); + EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]", + x.SummarizeValue(16, true)); + x = MkTensor<float>(DT_FLOAT, TensorShape({0}), {}); + EXPECT_EQ("[]", x.SummarizeValue(16, true)); +} + +TEST(SummarizeValue, BOOL_PRINT_V2) { + Tensor x = MkTensor<bool>(DT_BOOL, TensorShape({5}), {false, true, true}); + EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(16, true)); + EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(-1, true)); + EXPECT_EQ("[0 1 ... 
0 1]", x.SummarizeValue(2, true)); +} + +TEST(SummarizeValue, STRING_PRINT_V2) { + Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}), + {"one", "two", "three", "four", "five"}); + EXPECT_EQ("[one two three four five]", x.SummarizeValue(16, true)); + EXPECT_EQ("[one two three four five]", x.SummarizeValue(-1, true)); + x = MkTensor<string>(DT_STRING, TensorShape({5, 1, 5}), + {"one", "two", "three", "four", "five"}); + EXPECT_EQ("[one two three four five one...]", x.SummarizeValue(6, true)); +} + void BM_CreateAndDestroy(int iters) { TensorShape shape({10, 20}); while (--iters) { diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h index 4bda8f9eb8..a7cf600bab 100644 --- a/tensorflow/core/framework/tensor_util.h +++ b/tensorflow/core/framework/tensor_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include <vector> diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index 15b1add2c1..2e96b05787 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -39,6 +38,8 @@ limitations under the License. 
namespace tensorflow { +class Variant; + // MemoryType is used to describe whether input or output Tensors of // an OpKernel should reside in "Host memory" (e.g., CPU memory) or // "Device" Memory (CPU memory for CPU devices, GPU memory for GPU diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc index 5a507804b0..d43e3c72ec 100644 --- a/tensorflow/core/framework/variant.cc +++ b/tensorflow/core/framework/variant.cc @@ -23,11 +23,11 @@ limitations under the License. namespace tensorflow { -bool Variant::TryDecode(Variant* out) const { - const VariantTensorDataProto* p = get<VariantTensorDataProto>(); - if (p == nullptr) return false; - VariantTensorData data(*p); - return out->Decode(data); +bool Variant::Decode(VariantTensorData data) { + if (!is_empty()) { + return value_->Decode(std::move(data)); + } + return true; } template <> @@ -54,13 +54,12 @@ string TypeNameVariant(const VariantTensorDataProto& value) { template <> void EncodeVariant(const VariantTensorDataProto& value, VariantTensorData* data) { - data->FromProto(value); + data->FromConstProto(value); } template <> -bool DecodeVariant(const VariantTensorData& data, - VariantTensorDataProto* value) { - data.ToProto(value); +bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value) { + data->ToProto(value); return true; } @@ -70,8 +69,8 @@ void EncodeVariant(const VariantTensorDataProto& value, string* buf) { } template <> -bool DecodeVariant(const string& buf, VariantTensorDataProto* value) { - return value->ParseFromString(buf); +bool DecodeVariant(string* buf, VariantTensorDataProto* value) { + return value->ParseFromString(*buf); } void EncodeVariantList(const Variant* variant_array, int64 n, @@ -93,8 +92,10 @@ bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d, if (variant_array[i].is_empty()) { variant_array[i] = VariantTensorDataProto(); } + // TODO(ebrevdo): Replace with StringPiece? 
Any way to make this a + // zero-copy operation that keeps a reference to the data in d? string str(d->Data(sizes[i]), sizes[i]); - if (!variant_array[i].Decode(str)) return false; + if (!variant_array[i].Decode(std::move(str))) return false; if (!DecodeUnaryVariant(&variant_array[i])) { LOG(ERROR) << "Could not decode variant with type_name: \"" << variant_array[i].TypeName() diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h index 52732801a0..10eabbc85f 100644 --- a/tensorflow/core/framework/variant.h +++ b/tensorflow/core/framework/variant.h @@ -23,7 +23,6 @@ limitations under the License. #include <unordered_map> #include <utility> -#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/core/status.h" @@ -38,17 +37,19 @@ string TypeNameVariant(const T& value); template <typename T> string DebugStringVariant(const T& value); +// Allows for specializations of Variant Decoding. `data` may be modified in +// the process of decoding to `value`. template <typename T> -void EncodeVariant(const T& value, VariantTensorData* data); +bool DecodeVariant(VariantTensorData* data, T* value); template <typename T> -bool DecodeVariant(const VariantTensorData& data, T* value); +bool DecodeVariant(string* buf, T* value); template <typename T> -void EncodeVariant(const T& value, string* buf); +void EncodeVariant(const T& value, VariantTensorData* data); template <typename T> -bool DecodeVariant(const string& buf, T* value); +void EncodeVariant(const T& value, string* buf); // This is an implementation of a type-erased container that can store an // object of any type. 
The implementation is very similar to std::any, but has @@ -67,7 +68,7 @@ bool DecodeVariant(const string& buf, T* value); // // string TypeName() const; // void Encode(VariantTensorData* data) const; -// void Decode(const VariantTensorData& data); +// bool Decode(VariantTensorData data); // // Simple POD types can elide the Encode/Decode functions, they are provided by // helper methods. @@ -121,7 +122,7 @@ bool DecodeVariant(const string& buf, T* value); // x.Encode(&serialized_f); // // Variant y = Foo(); // default constructed Foo. -// y.Decode(&serialized_f); +// y.Decode(std::move(serialized_f)); // EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>()); // // @@ -145,10 +146,6 @@ bool DecodeVariant(const string& buf, T* value); // EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo. // EXPECT_EQ(MakeTypeIndex<VariantTensorDataProto>(), // y_type_unknown.TypeId()); -// // Decode and get y_type_unknown; compare to value in x. -// Foo f_decoded; -// EXPECT_TRUE(x.MaybeDecodeAndCopy(&f_decoded)); -// EXPECT_EQ(f_decoded, f); // class Variant { public: @@ -241,12 +238,7 @@ class Variant { } // Deserialize `data` and update the stored object. - bool Decode(const VariantTensorData& data) { - if (!is_empty()) { - return value_->Decode(data); - } - return true; - } + bool Decode(VariantTensorData data); // Helper methods to directly serialize/deserialize from strings. 
void Encode(string* buf) const { @@ -254,31 +246,13 @@ class Variant { value_->Encode(buf); } } - bool Decode(const string& buf) { + bool Decode(string buf) { if (!is_empty()) { - return value_->Decode(buf); + return value_->Decode(std::move(buf)); } return true; } - template <typename T> - bool MaybeDecodeAndCopy(T* out) const { - const T* ret = get<T>(); - if (ret != nullptr) { - *out = std::move(*ret); - return true; - }; - Variant decoded = T(); - if (!TryDecode(&decoded)) return false; - T* decoded_ret = decoded.get<T>(); - CHECK_NOTNULL(decoded_ret); - *out = std::move(*decoded_ret); - return true; - } - - private: - bool TryDecode(Variant* out) const; - private: struct in_place_t {}; static constexpr in_place_t in_place{}; @@ -292,9 +266,9 @@ class Variant { virtual string TypeName() const = 0; virtual string DebugString() const = 0; virtual void Encode(VariantTensorData* data) const = 0; - virtual bool Decode(const VariantTensorData& data) = 0; + virtual bool Decode(VariantTensorData data) = 0; virtual void Encode(string* buf) const = 0; - virtual bool Decode(const string& data) = 0; + virtual bool Decode(string data) = 0; }; template <typename T> @@ -325,15 +299,13 @@ class Variant { EncodeVariant(value, data); } - bool Decode(const VariantTensorData& data) override { - return DecodeVariant(data, &value); + bool Decode(VariantTensorData data) override { + return DecodeVariant(&data, &value); } void Encode(string* buf) const override { EncodeVariant(value, buf); } - bool Decode(const string& buf) override { - return DecodeVariant(buf, &value); - } + bool Decode(string buf) override { return DecodeVariant(&buf, &value); } T value; }; diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h index f155aa4892..5e08e5a7a6 100644 --- a/tensorflow/core/framework/variant_encode_decode.h +++ b/tensorflow/core/framework/variant_encode_decode.h @@ -22,6 +22,7 @@ limitations under the License. 
#include <vector> #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/abi.h" @@ -81,7 +82,7 @@ void EncodeVariantImpl(const T& value, // Specialization for POD type template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, true /* is_pod */, false /* Tensor */, false /* protobuf */>, T* value) { @@ -90,7 +91,7 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for tensorflow::Tensor template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, false /* is_pod */, true /* Tensor */, false /* protobuf */>, T* value) { @@ -100,7 +101,7 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for protobuf template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, false /* is_pod */, false /* Tensor */, true /* protobuf */>, T* value) { @@ -111,11 +112,11 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for other types template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, false /* is_pod */, false /* Tensor */, false /* protobuf */>, T* value) { - return value->Decode(data); + return value->Decode(std::move(data)); } template <typename C, typename = void> @@ -224,8 +225,8 @@ void EncodeVariant(const T& value, VariantTensorData* data) { } template <typename T> -bool DecodeVariant(const VariantTensorData& data, T* value) { - return DecodeVariantImpl(data, TypeResolver<T>(), value); +bool DecodeVariant(VariantTensorData* data, T* value) { + return DecodeVariantImpl(std::move(*data), 
TypeResolver<T>(), value); } template <typename T> @@ -238,26 +239,31 @@ void EncodeVariant(const T& value, string* buf) { } template <typename T> -bool DecodeVariant(const string& buf, T* value) { +bool DecodeVariant(string* buf, T* value) { VariantTensorData data; - if (!data.ParseFromString(buf)) return false; - if (!DecodeVariantImpl(data, TypeResolver<T>(), value)) return false; + if (!data.ParseFromString(*buf)) return false; + if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) { + return false; + } return true; } // Specializations for VariantTensorDataProto template <> string TypeNameVariant(const VariantTensorDataProto& value); + template <> void EncodeVariant(const VariantTensorDataProto& value, VariantTensorData* data); + template <> -bool DecodeVariant(const VariantTensorData& data, - VariantTensorDataProto* value); +bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value); + template <> void EncodeVariant(const VariantTensorDataProto& value, string* buf); + template <> -bool DecodeVariant(const string& buf, VariantTensorDataProto* value); +bool DecodeVariant(string* buf, VariantTensorDataProto* value); // Encodes an array of Variant objects in to the given StringListEncoder. // `variant_array` is assumed to point to an array of `n` Variant objects. 
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc index 60fa7bd559..daa744e877 100644 --- a/tensorflow/core/framework/variant_op_copy_test.cc +++ b/tensorflow/core/framework/variant_op_copy_test.cc @@ -90,15 +90,15 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(StoredTensorValue, "StoredTensorValue"); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( StoredTensorValue, VariantDeviceCopyDirection::HOST_TO_DEVICE, - "StoredTensorValue", StoredTensorValue::CopyCPUToGPU); + StoredTensorValue::CopyCPUToGPU); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_HOST, - "StoredTensorValue", StoredTensorValue::CopyGPUToCPU); + StoredTensorValue::CopyGPUToCPU); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_DEVICE, - "StoredTensorValue", StoredTensorValue::CopyGPUToGPU); + StoredTensorValue::CopyGPUToGPU); REGISTER_OP("CreateTestVariant") .Input("input: T") diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc index ee07db1aee..ef5b240aea 100644 --- a/tensorflow/core/framework/variant_op_registry.cc +++ b/tensorflow/core/framework/variant_op_registry.cc @@ -38,21 +38,19 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() { } UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn( - StringPiece type_name) { - auto found = shape_fns.find(type_name); + const TypeIndex& type_index) { + auto found = shape_fns.find(type_index); if (found == shape_fns.end()) return nullptr; return &found->second; } -void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name, +void UnaryVariantOpRegistry::RegisterShapeFn(const TypeIndex& type_index, const VariantShapeFn& shape_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape"; - VariantShapeFn* existing = GetShapeFn(type_name); + 
VariantShapeFn* existing = GetShapeFn(type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantShapeFn for type_name: " << type_name - << " already registered"; - shape_fns.insert(std::pair<StringPiece, VariantShapeFn>( - GetPersistentStringPiece(type_name), shape_fn)); + << "Unary VariantShapeFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered"; + shape_fns.insert(std::pair<TypeIndex, VariantShapeFn>(type_index, shape_fn)); } Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) { @@ -60,11 +58,11 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) { CHECK_EQ(variant_tensor.dims(), 0); const Variant& v = variant_tensor.scalar<Variant>()(); UnaryVariantOpRegistry::VariantShapeFn* shape_fn = - UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName()); + UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeId()); if (shape_fn == nullptr) { return errors::Internal( - "No unary variant shape function found for Variant type_name: ", - v.TypeName()); + "No unary variant shape function found for Variant type_index: ", + port::MaybeAbiDemangle(v.TypeId().name())); } return (*shape_fn)(v, shape); } @@ -79,7 +77,7 @@ Status ScalarShape(const T&, TensorShape* shape) { } // namespace #define REGISTER_VARIANT_SHAPE_TYPE(T) \ - REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>); + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, ScalarShape<T>); // No encode/shape registered for std::complex<> and Eigen::half // objects yet. 
@@ -143,25 +141,24 @@ REGISTER_VARIANT_DECODE_TYPE(double); UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* UnaryVariantOpRegistry::GetDeviceCopyFn( - const VariantDeviceCopyDirection direction, StringPiece type_name) { - auto found = device_copy_fns.find(std::make_pair(direction, type_name)); + const VariantDeviceCopyDirection direction, const TypeIndex& type_index) { + auto found = device_copy_fns.find(std::make_pair(direction, type_index)); if (found == device_copy_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterDeviceCopyFn( - const VariantDeviceCopyDirection direction, const string& type_name, + const VariantDeviceCopyDirection direction, const TypeIndex& type_index, const AsyncVariantDeviceCopyFn& device_copy_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy"; - AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name); + AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index); CHECK_EQ(existing, nullptr) << "UnaryVariantDeviceCopy for direction: " << direction - << " and type_name: " << type_name << " already registered"; + << " and type_index: " << port::MaybeAbiDemangle(type_index.name()) + << " already registered"; device_copy_fns.insert( - std::pair<std::pair<VariantDeviceCopyDirection, StringPiece>, - AsyncVariantDeviceCopyFn>( - std::make_pair(direction, GetPersistentStringPiece(type_name)), - device_copy_fn)); + std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>, + AsyncVariantDeviceCopyFn>(std::make_pair(direction, type_index), + device_copy_fn)); } Status VariantDeviceCopy( @@ -170,35 +167,34 @@ Status VariantDeviceCopy( const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) { UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction, - from.TypeName()); + from.TypeId()); if (device_copy_fn == nullptr) { return errors::Internal( "No unary variant device 
copy function found for direction: ", - direction, " and Variant type_name: ", from.TypeName()); + direction, " and Variant type_index: ", + port::MaybeAbiDemangle(from.TypeId().name())); } return (*device_copy_fn)(from, to, copy_fn); } // Special casing UnaryOpFn per op and per device. UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn( - VariantUnaryOp op, StringPiece device, StringPiece type_name) { - auto found = unary_op_fns.find({op, device, type_name}); + VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) { + auto found = unary_op_fns.find({op, device, type_index}); if (found == unary_op_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterUnaryOpFn( - VariantUnaryOp op, const string& device, const string& type_name, + VariantUnaryOp op, const string& device, const TypeIndex& type_index, const VariantUnaryOpFn& unary_op_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp"; - VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name); + VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantUnaryOpFn for type_name: " << type_name + << "Unary VariantUnaryOpFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered for device type: " << device; unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>( - {op, GetPersistentStringPiece(device), - GetPersistentStringPiece(type_name)}, - unary_op_fn)); + {op, GetPersistentStringPiece(device), type_index}, unary_op_fn)); } namespace { @@ -212,7 +208,7 @@ Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t, #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \ - DEVICE_CPU, T, TF_STR(T), \ + DEVICE_CPU, T, \ ZerosLikeVariantPrimitiveType<T>); // No zeros_like registered for std::complex<> or Eigen::half objects yet. 
@@ -226,24 +222,22 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool); // Special casing BinaryOpFn per op and per device. UnaryVariantOpRegistry::VariantBinaryOpFn* UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device, - StringPiece type_name) { - auto found = binary_op_fns.find({op, device, type_name}); + const TypeIndex& type_index) { + auto found = binary_op_fns.find({op, device, type_index}); if (found == binary_op_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterBinaryOpFn( - VariantBinaryOp op, const string& device, const string& type_name, + VariantBinaryOp op, const string& device, const TypeIndex& type_index, const VariantBinaryOpFn& add_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp"; - VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name); + VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantBinaryOpFn for type_name: " << type_name + << "Unary VariantBinaryOpFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered for device type: " << device; binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>( - {op, GetPersistentStringPiece(device), - GetPersistentStringPiece(type_name)}, - add_fn)); + {op, GetPersistentStringPiece(device), type_index}, add_fn)); } namespace { @@ -257,8 +251,7 @@ Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b, #define REGISTER_VARIANT_ADD_TYPE(T) \ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \ - T, TF_STR(T), \ - AddVariantPrimitiveType<T>); + T, AddVariantPrimitiveType<T>); // No add registered for std::complex<> or Eigen::half objects yet. 
REGISTER_VARIANT_ADD_TYPE(int); diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index e6a2665a56..7eb37e859f 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -22,10 +22,14 @@ limitations under the License. #define EIGEN_USE_THREADS +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/abi.h" namespace tensorflow { @@ -90,10 +94,11 @@ class UnaryVariantOpRegistry { AsyncVariantDeviceCopyFn; // Add a shape lookup function to the registry. - void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn); + void RegisterShapeFn(const TypeIndex& type_index, + const VariantShapeFn& shape_fn); - // Returns nullptr if no shape function was found for the given TypeName. - VariantShapeFn* GetShapeFn(StringPiece type_name); + // Returns nullptr if no shape function was found for the given TypeIndex. + VariantShapeFn* GetShapeFn(const TypeIndex& type_index); // Add a decode function to the registry. void RegisterDecodeFn(const string& type_name, @@ -104,33 +109,33 @@ class UnaryVariantOpRegistry { // Add a copy-to-GPU function to the registry. void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction, - const string& type_name, + const TypeIndex& type_index, const AsyncVariantDeviceCopyFn& device_copy_fn); // Returns nullptr if no copy function was found for the given // TypeName and direction. 
AsyncVariantDeviceCopyFn* GetDeviceCopyFn( - const VariantDeviceCopyDirection direction, StringPiece type_name); + const VariantDeviceCopyDirection direction, const TypeIndex& type_index); // Add a unary op function to the registry. void RegisterUnaryOpFn(VariantUnaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const VariantUnaryOpFn& unary_op_fn); // Returns nullptr if no unary op function was found for the given // op, device, and TypeName. VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device, - StringPiece type_name); + const TypeIndex& type_index); // Add a binary op function to the registry. void RegisterBinaryOpFn(VariantBinaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const VariantBinaryOpFn& add_fn); // Returns nullptr if no binary op function was found for the given // op, device and TypeName. VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device, - StringPiece type_name); + const TypeIndex& type_index); // Get a pointer to a global UnaryVariantOpRegistry object static UnaryVariantOpRegistry* Global(); @@ -145,24 +150,26 @@ class UnaryVariantOpRegistry { static std::unordered_set<string>* PersistentStringStorage(); private: - std::unordered_map<StringPiece, VariantShapeFn, StringPieceHasher> shape_fns; - std::unordered_map<StringPiece, VariantDecodeFn, StringPieceHasher> - decode_fns; + struct TypeIndexHash { + std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); } + }; + + gtl::FlatMap<TypeIndex, VariantShapeFn, TypeIndexHash> shape_fns; + gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns; // Map std::pair<Direction, type_name> to function. struct PairHash { template <typename Direction> - std::size_t operator()(const std::pair<Direction, StringPiece>& x) const { + std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const { // The hash of an enum is just its value as a std::size_t. 
std::size_t ret = static_cast<std::size_t>(std::get<0>(x)); - ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x))); + ret = Hash64Combine(ret, std::get<1>(x).hash_code()); return ret; } - StringPieceHasher sp_hasher_; }; - std::unordered_map<std::pair<VariantDeviceCopyDirection, StringPiece>, - AsyncVariantDeviceCopyFn, PairHash> + gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>, + AsyncVariantDeviceCopyFn, PairHash> device_copy_fns; // Map std::tuple<Op, device, type_name> to function. @@ -172,10 +179,11 @@ class UnaryVariantOpRegistry { // and references therein template <typename Op> struct FuncTuple { - FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname) - : op_type_(op), device_(dev), typename_(tname){}; + FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index) + : op_type_(op), device_(dev), type_index_(type_index) {} Op op_type_; - StringPiece device_, typename_; + StringPiece device_; + TypeIndex type_index_; }; // friend declaration for operator== // needed for clang @@ -184,11 +192,11 @@ class UnaryVariantOpRegistry { struct TupleHash { template <typename Op> std::size_t operator()( - const std::tuple<Op, StringPiece, StringPiece>& x) const { + const std::tuple<Op, StringPiece, TypeIndex>& x) const { // The hash of an enum is just its value as a std::size_t. std::size_t ret = static_cast<std::size_t>(std::get<0>(x)); ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x))); - ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x))); + ret = Hash64Combine(ret, std::get<2>(x).hash_code()); return ret; } @@ -197,14 +205,14 @@ class UnaryVariantOpRegistry { // The hash of an enum is just its value as a std::size_t. 
std::size_t ret = static_cast<std::size_t>(x.op_type_); ret = Hash64Combine(ret, sp_hasher_(x.device_)); - ret = Hash64Combine(ret, sp_hasher_(x.typename_)); + ret = Hash64Combine(ret, x.type_index_.hash_code()); return ret; } StringPieceHasher sp_hasher_; }; - std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash> + gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash> unary_op_fns; - std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash> + gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash> binary_op_fns; // Find or insert a string into a persistent string storage @@ -225,7 +233,7 @@ template <typename Op> inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs, const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) { return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) && - (lhs.typename_ == rhs.typename_); + (lhs.type_index_ == rhs.type_index_); } // Gets a TensorShape from a Tensor containing a scalar Variant. 
// Returns an Internal error if the Variant does not have a registered shape @@ -276,7 +284,7 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v, Variant* v_out) { const string& device = DeviceName<Device>::value; UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn = - UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName()); + UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId()); if (unary_op_fn == nullptr) { return errors::Internal( "No unary variant unary_op function found for unary variant op enum: ", @@ -297,15 +305,15 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v, template <typename Device> Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op, const Variant& a, const Variant& b, Variant* out) { - if (a.TypeName() != b.TypeName()) { + if (a.TypeId() != b.TypeId()) { return errors::Internal( "BianryOpVariants: Variants a and b have different " - "type names: '", + "type ids. Type names: '", a.TypeName(), "' vs. 
'", b.TypeName(), "'"); } const string& device = DeviceName<Device>::value; UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn = - UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName()); + UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId()); if (binary_op_fn == nullptr) { return errors::Internal( "No unary variant binary_op function found for binary variant op " @@ -323,16 +331,18 @@ class UnaryVariantShapeRegistration { public: typedef std::function<Status(const T& t, TensorShape*)> LocalVariantShapeFn; - UnaryVariantShapeRegistration(const string& type_name, + UnaryVariantShapeRegistration(const TypeIndex& type_index, const LocalVariantShapeFn& shape_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterShapeFn( - type_name, - [type_name, shape_fn](const Variant& v, TensorShape* s) -> Status { + type_index, + [type_index_name, shape_fn](const Variant& v, + TensorShape* s) -> Status { const T* t = v.get<T>(); if (t == nullptr) { return errors::Internal( - "VariantShapeFn: Could not access object, type_name: ", - type_name); + "VariantShapeFn: Could not access object, type_index: ", + type_index_name); } return shape_fn(*t, s); }); @@ -355,11 +365,11 @@ class UnaryVariantDecodeRegistration { return false; } Variant decoded = T(); - VariantTensorData data(*t); - if (!decoded.Decode(data)) { + VariantTensorData data(std::move(*t)); + if (!decoded.Decode(std::move(data))) { return false; } - *v = std::move(decoded); + std::swap(decoded, *v); return true; }); } @@ -372,11 +382,12 @@ class UnaryVariantDeviceCopyRegistration { UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)> LocalVariantDeviceCopyFn; UnaryVariantDeviceCopyRegistration( - const VariantDeviceCopyDirection direction, const string& type_name, + const VariantDeviceCopyDirection direction, const TypeIndex& type_index, const LocalVariantDeviceCopyFn& device_copy_fn) { + const string type_index_name 
= port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn( - direction, type_name, - [type_name, device_copy_fn]( + direction, type_index, + [type_index_name, device_copy_fn]( const Variant& from, Variant* to, UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn device_copy_tensor_fn) -> Status { @@ -384,8 +395,8 @@ class UnaryVariantDeviceCopyRegistration { *to = T(); if (from.get<T>() == nullptr) { return errors::Internal( - "VariantCopyToGPUFn: Could not access object, type_name: ", - type_name); + "VariantCopyToGPUFn: Could not access object, type_index: ", + type_index_name); } const T& t = *from.get<T>(); T* t_out = to->get<T>(); @@ -401,18 +412,19 @@ class UnaryVariantUnaryOpRegistration { public: UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const LocalVariantUnaryOpFn& unary_op_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn( - op, device, type_name, - [type_name, unary_op_fn](OpKernelContext* ctx, const Variant& v, - Variant* v_out) -> Status { + op, device, type_index, + [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v, + Variant* v_out) -> Status { DCHECK_NE(v_out, nullptr); *v_out = T(); if (v.get<T>() == nullptr) { return errors::Internal( - "VariantUnaryOpFn: Could not access object, type_name: ", - type_name); + "VariantUnaryOpFn: Could not access object, type_index: ", + type_index_name); } const T& t = *v.get<T>(); T* t_out = v_out->get<T>(); @@ -429,23 +441,25 @@ class UnaryVariantBinaryOpRegistration { public: UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const LocalVariantBinaryOpFn& binary_op_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn( - op, 
device, type_name, - [type_name, binary_op_fn](OpKernelContext* ctx, const Variant& a, - const Variant& b, Variant* out) -> Status { + op, device, type_index, + [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a, + const Variant& b, + Variant* out) -> Status { DCHECK_NE(out, nullptr); *out = T(); if (a.get<T>() == nullptr) { return errors::Internal( - "VariantBinaryOpFn: Could not access object 'a', type_name: ", - type_name); + "VariantBinaryOpFn: Could not access object 'a', type_index: ", + type_index_name); } if (b.get<T>() == nullptr) { return errors::Internal( - "VariantBinaryOpFn: Could not access object 'b', type_name: ", - type_name); + "VariantBinaryOpFn: Could not access object 'b', type_index: ", + type_index_name); } const T& t_a = *a.get<T>(); const T& t_b = *b.get<T>(); @@ -459,19 +473,19 @@ class UnaryVariantBinaryOpRegistration { // Register a unary shape variant function with the signature: // Status ShapeFn(const T& t, TensorShape* s); -// to Variants having TypeName type_name. -#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, type_name, shape_function) \ - REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name, \ - shape_function) +// to Variants having TypeIndex type_index. 
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, shape_function) \ + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, T, MakeTypeIndex<T>(), shape_function) -#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_name, \ - shape_function) \ - REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, shape_function) +#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_index, \ + shape_function) \ + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, shape_function) -#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, \ +#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, \ shape_function) \ static variant_op_registry_fn_registration::UnaryVariantShapeRegistration<T> \ - register_unary_variant_op_shape_registration_fn_##ctr(type_name, \ + register_unary_variant_op_shape_registration_fn_##ctr(type_index, \ shape_function) // Register a unary decode variant function for the given type. @@ -519,63 +533,63 @@ class UnaryVariantBinaryOpRegistration { // ****** NOTE ****** // FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE. 
// ****** NOTE ****** -#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ - T, direction, type_name, device_copy_fn) \ - INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ - __COUNTER__, T, direction, type_name, device_copy_fn) +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction, \ + device_copy_fn) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, T, direction, MakeTypeIndex<T>(), device_copy_fn) #define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ - ctr, T, direction, type_name, device_copy_fn) \ + ctr, T, direction, type_index, device_copy_fn) \ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ - ctr, T, direction, type_name, device_copy_fn) + ctr, T, direction, type_index, device_copy_fn) -#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ - ctr, T, direction, type_name, device_copy_fn) \ - static variant_op_registry_fn_registration:: \ - UnaryVariantDeviceCopyRegistration<T> \ - register_unary_variant_op_device_copy_fn_##ctr(direction, type_name, \ - device_copy_fn) +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ + ctr, T, direction, type_index, device_copy_fn) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantDeviceCopyRegistration<T> \ + register_unary_variant_op_device_copy_fn_##ctr( \ + direction, type_index, device_copy_fn) // Register a unary unary_op variant function with the signature: // Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out); -// to Variants having TypeName type_name, for device string device, +// to Variants having TypeIndex type_index, for device string device, // for UnaryVariantOp enum op. 
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \ - unary_op_function) \ - REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ - __COUNTER__, op, device, T, type_name, unary_op_function) +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, \ + unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, MakeTypeIndex<T>(), unary_op_function) -#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ - ctr, op, device, T, type_name, unary_op_function) \ - REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \ - unary_op_function) +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + ctr, op, device, T, type_index, unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \ + type_index, unary_op_function) #define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \ - ctr, op, device, T, type_name, unary_op_function) \ + ctr, op, device, T, type_index, unary_op_function) \ static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \ T> \ - register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \ unary_op_function) // Register a binary_op variant function with the signature: // Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out); -// to Variants having TypeName type_name, for device string device, +// to Variants having TypeIndex type_index, for device string device, // for BinaryVariantOp enum OP. 
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \ - binary_op_function) \ - REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ - __COUNTER__, op, device, T, type_name, binary_op_function) +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, \ + binary_op_function) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, MakeTypeIndex<T>(), binary_op_function) #define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ - ctr, op, device, T, type_name, binary_op_function) \ + ctr, op, device, T, type_index, binary_op_function) \ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ - ctr, op, device, T, type_name, binary_op_function) + ctr, op, device, T, type_index, binary_op_function) -#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ - ctr, op, device, T, type_name, binary_op_function) \ - static variant_op_registry_fn_registration:: \ - UnaryVariantBinaryOpRegistration<T> \ - register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \ +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_index, binary_op_function) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantBinaryOpRegistration<T> \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \ binary_op_function) } // end namespace tensorflow diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc index 7055e62c0e..b2443e8676 100644 --- a/tensorflow/core/framework/variant_op_registry_test.cc +++ b/tensorflow/core/framework/variant_op_registry_test.cc @@ -89,41 +89,37 @@ struct VariantValue { int value; }; -REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue", - VariantValue::ShapeFn); +REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, VariantValue::ShapeFn); REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue"); 
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE, - "TEST VariantValue", VariantValue::CPUToGPUCopyFn); + VariantValue::CPUToGPUCopyFn); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, VariantValue, - "TEST VariantValue", VariantValue::CPUZerosLikeFn); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, VariantValue, - "TEST VariantValue", VariantValue::GPUZerosLikeFn); REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, - VariantValue, "TEST VariantValue", - VariantValue::CPUAddFn); + VariantValue, VariantValue::CPUAddFn); REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU, - VariantValue, "TEST VariantValue", - VariantValue::GPUAddFn); + VariantValue, VariantValue::GPUAddFn); } // namespace TEST(VariantOpShapeRegistryTest, TestBasic) { - EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"), + class Blah {}; + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex<Blah>()), nullptr); - auto* shape_fn = - UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue"); + auto* shape_fn = UnaryVariantOpRegistry::Global()->GetShapeFn( + MakeTypeIndex<VariantValue>()); EXPECT_NE(shape_fn, nullptr); TensorShape shape; @@ -142,10 +138,11 @@ TEST(VariantOpShapeRegistryTest, TestBasic) { TEST(VariantOpShapeRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantShapeFn f; - string kTypeName = "fjfjfj"; - registry.RegisterShapeFn(kTypeName, f); - EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f), - "fjfjfj already registered"); + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); + registry.RegisterShapeFn(kTypeIndex, f); + EXPECT_DEATH(registry.RegisterShapeFn(kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpDecodeRegistryTest, TestBasic) { @@ -180,13 +177,14 @@ 
TEST(VariantOpDecodeRegistryTest, TestDuplicate) { TEST(VariantOpCopyToGPURegistryTest, TestBasic) { // No registered copy fn for GPU<->GPU. - EXPECT_EQ( - UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( - VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"), - nullptr); + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( + VariantDeviceCopyDirection::DEVICE_TO_DEVICE, + MakeTypeIndex<VariantValue>()), + nullptr); auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( - VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue"); + VariantDeviceCopyDirection::HOST_TO_DEVICE, + MakeTypeIndex<VariantValue>()); EXPECT_NE(copy_to_gpu_fn, nullptr); VariantValue vv{true /* early_exit */}; @@ -208,17 +206,19 @@ TEST(VariantOpCopyToGPURegistryTest, TestBasic) { TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE, - kTypeName, f); + kTypeIndex, f); EXPECT_DEATH(registry.RegisterDeviceCopyFn( - VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f), - "fjfjfj already registered"); + VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( - ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; @@ -242,8 +242,9 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { #if GOOGLE_CUDA TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( - ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU 
SHALL NOT PASS"), + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; @@ -269,25 +270,26 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantUnaryOpFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); - registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName, - f); + registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, + kTypeIndex, f); EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, - DEVICE_CPU, kTypeName, f), - "fjfjfj already registered"); + DEVICE_CPU, kTypeIndex, f), + "FjFjFj already registered"); - registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName, - f); + registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, + kTypeIndex, f); EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, - DEVICE_GPU, kTypeName, f), - "fjfjfj already registered"); + DEVICE_GPU, kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpAddRegistryTest, TestBasicCPU) { - return; + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( - ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), + ADD_VARIANT_BINARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; @@ -312,8 +314,9 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) { #if GOOGLE_CUDA TEST(VariantOpAddRegistryTest, TestBasicGPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( - ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"), + ADD_VARIANT_BINARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; @@ -340,17 +343,18 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) 
{ TEST(VariantOpAddRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantBinaryOpFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); - registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f); + registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f); EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, - kTypeName, f), - "fjfjfj already registered"); + kTypeIndex, f), + "FjFjFj already registered"); - registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f); + registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeIndex, f); EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, - kTypeName, f), - "fjfjfj already registered"); + kTypeIndex, f), + "FjFjFj already registered"); } } // namespace tensorflow diff --git a/tensorflow/core/framework/variant_tensor_data.cc b/tensorflow/core/framework/variant_tensor_data.cc index 99712dc114..3e67e4a864 100644 --- a/tensorflow/core/framework/variant_tensor_data.cc +++ b/tensorflow/core/framework/variant_tensor_data.cc @@ -22,8 +22,8 @@ namespace tensorflow { VariantTensorData::VariantTensorData() {} -VariantTensorData::VariantTensorData(const VariantTensorDataProto& proto) { - FromProto(proto); +VariantTensorData::VariantTensorData(VariantTensorDataProto proto) { + FromProto(std::move(proto)); } VariantTensorData::~VariantTensorData() {} @@ -52,7 +52,19 @@ void VariantTensorData::ToProto(VariantTensorDataProto* proto) const { } } -bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) { +bool VariantTensorData::FromProto(VariantTensorDataProto proto) { + // TODO(ebrevdo): Do this lazily. 
+ set_type_name(proto.type_name()); + set_metadata(proto.metadata()); + for (const auto& tensor : proto.tensors()) { + Tensor tmp; + if (!tmp.FromProto(tensor)) return false; + tensors_.push_back(tmp); + } + return true; +} + +bool VariantTensorData::FromConstProto(const VariantTensorDataProto& proto) { set_type_name(proto.type_name()); set_metadata(proto.metadata()); for (const auto& tensor : proto.tensors()) { @@ -75,10 +87,10 @@ bool VariantTensorData::SerializeToString(string* buf) { return proto.SerializeToString(buf); } -bool VariantTensorData::ParseFromString(const string& s) { +bool VariantTensorData::ParseFromString(string s) { VariantTensorDataProto proto; const bool status = proto.ParseFromString(s); - if (status) FromProto(proto); + if (status) FromProto(std::move(proto)); return status; } diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h index 7500e77d43..8a240ee1e3 100644 --- a/tensorflow/core/framework/variant_tensor_data.h +++ b/tensorflow/core/framework/variant_tensor_data.h @@ -19,13 +19,13 @@ limitations under the License. #include <algorithm> #include <vector> +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { class VariantTensorDataProto; -class Tensor; // The serialization format for Variant objects. Objects with references to // other Tensors can simply store those tensors in the `tensors` field, and @@ -38,7 +38,7 @@ class Tensor; class VariantTensorData { public: VariantTensorData(); - VariantTensorData(const VariantTensorDataProto& proto); + VariantTensorData(VariantTensorDataProto proto); ~VariantTensorData(); // Name of the type of objects being serialized. 
@@ -68,12 +68,14 @@ class VariantTensorData { // Conversion to and from VariantTensorDataProto void ToProto(VariantTensorDataProto* proto) const; - bool FromProto(const VariantTensorDataProto& proto); + // This allows optimizations via std::move. + bool FromProto(VariantTensorDataProto proto); + bool FromConstProto(const VariantTensorDataProto& proto); // Serialization via VariantTensorDataProto string SerializeAsString() const; bool SerializeToString(string* buf); - bool ParseFromString(const string& s); + bool ParseFromString(string s); string DebugString() const; diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc index eef5c47d15..08d09de7b8 100644 --- a/tensorflow/core/framework/variant_test.cc +++ b/tensorflow/core/framework/variant_test.cc @@ -144,8 +144,8 @@ TEST(VariantTest, TypeMismatch) { struct TensorList { void Encode(VariantTensorData* data) const { data->tensors_ = vec; } - bool Decode(const VariantTensorData& data) { - vec = data.tensors_; + bool Decode(VariantTensorData data) { + vec = std::move(data.tensors_); return true; } @@ -186,7 +186,7 @@ TEST(VariantTest, TensorListTest) { x.Encode(&serialized); Variant y = TensorList(); - y.Decode(serialized); + y.Decode(std::move(serialized)); const TensorList& decoded_vec = *y.get<TensorList>(); for (int i = 0; i < 4; ++i) { @@ -204,15 +204,6 @@ TEST(VariantTest, TensorListTest) { EXPECT_EQ(y_unknown.DebugString(), strings::StrCat( "Variant<type: TensorList value: ", data.DebugString(), ">")); - - TensorList unknown_decoded_vec; - EXPECT_TRUE(y_unknown.MaybeDecodeAndCopy(&unknown_decoded_vec)); - for (int i = 0; i < 4; ++i) { - EXPECT_EQ(unknown_decoded_vec.vec[i].flat<int>()(0), i); - } - for (int i = 0; i < 4; ++i) { - EXPECT_EQ(unknown_decoded_vec.vec[i + 4].flat<float>()(0), 2 * i); - } } TEST(VariantTest, VariantArray) { diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index ee10194142..eeb5c14eaa 
100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -1042,12 +1042,12 @@ Status GraphConstructor::Convert() { } if (processed < node_defs_.size()) { - LOG(WARNING) << "IN " << __func__ << (node_defs_.size() - processed) + LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed) << " NODES IN A CYCLE"; for (int64 i = 0; i < node_defs_.size(); i++) { if (pending_count_[i] != 0) { LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i]) - << "WITH PENDING COUNT = " << pending_count_[i]; + << " WITH PENDING COUNT = " << pending_count_[i]; } } return errors::InvalidArgument(node_defs_.size() - processed, @@ -1162,7 +1162,9 @@ Status GraphConstructor::PopulateMissingUnusedInputMapKeys() { const NodeDef* node_def = node_defs_[pair->second.gdef_index]; const OpDef* op_def; TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); - if (key.second >= op_def->output_arg_size()) { + int num_outputs; + TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs)); + if (key.second >= num_outputs) { // key's index out of bounds missing_unused_input_map_keys_->push_back(key); } diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 73142ebde7..3eef6bd2bd 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -199,6 +199,10 @@ REGISTER_OP("TestOneInputOneOutput") .Output("y: T") .Attr("T: {float, int64}") .SetShapeFn(shape_inference::UnchangedShape); +REGISTER_OP("TestVariadicOutput") + .Output("outputs: N * int32") + .Attr("N: int >= 0") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("TestDefaultAttr") .Attr("default_int: int=31415") .SetShapeFn(shape_inference::NoOutputs); @@ -1463,12 +1467,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) { opts.input_map[TensorId("DNE", 0)] = TensorId("input", 0); // Unused 
but not missing opts.input_map[TensorId("t1", 0)] = TensorId("W1", 0); + // Unused but not missing + opts.input_map[TensorId("variadic", 4)] = TensorId("input", 0); ExpectOK( R"EOF( node { name: 'W2' op: 'TestParams' } node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] } node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] } - node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] } + node { name: 'variadic' op: 'TestVariadicOutput' + attr { key: "N" value { i: 5 } } } )EOF", opts, &refiner, &results); diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 50fd6bae12..06d3fefef1 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" @@ -976,7 +977,9 @@ std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_; // nodes. Do not change the ordering of the Mkl passes. const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = OptimizationPassRegistry::POST_PARTITIONING; +#ifdef ENABLE_MKL REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); +#endif // ENABLE_MKL ////////////////////////////////////////////////////////////////////////// // Helper functions for creating new node @@ -3155,7 +3158,9 @@ MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_; // nodes. Do not change the ordering of the Mkl passes. 
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = OptimizationPassRegistry::POST_PARTITIONING; +#ifdef ENABLE_MKL REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); +#endif // ENABLE_MKL ////////////////////////////////////////////////////////////////////////// // Helper functions for creating new node diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index 7f96a18023..77640e287c 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #include "tensorflow/core/graph/mkl_layout_pass.h" #include "tensorflow/core/graph/mkl_graph_util.h" @@ -3606,4 +3606,4 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000); } // namespace tensorflow -#endif /* INTEL_MKL */ +#endif // INTEL_MKL && ENABLE_MKL diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index b67a321fc1..8c5ffd71a3 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -133,7 +133,9 @@ class MklToTfConversionPass : public GraphOptimizationPass { // complete picture of inputs and outputs of the nodes in the graphs. 
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup = OptimizationPassRegistry::POST_PARTITIONING; +#ifdef ENABLE_MKL REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass); +#endif // ENABLE_MKL Status MklToTfConversionPass::InsertConversionNodeOnEdge( std::unique_ptr<Graph>* g, Edge* e) { diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc index ebcb6de551..319437a801 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #include "tensorflow/core/graph/mkl_tfconversion_pass.h" #include "tensorflow/core/graph/mkl_graph_util.h" @@ -304,4 +304,4 @@ BENCHMARK(BM_RunMklToTfConversionPass)->Arg(1000)->Arg(10000); } // namespace } // namespace tensorflow -#endif /* INTEL_MKL */ +#endif // INTEL_MKL && ENABLE_MKL diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc index ea7788f654..0a38aa1c91 100644 --- a/tensorflow/core/graph/testlib.cc +++ b/tensorflow/core/graph/testlib.cc @@ -485,6 +485,33 @@ Node* DiagPart(Graph* g, Node* in, DataType type) { return ret; } +Node* CheckNumerics(Graph* g, Node* in, const string& message) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics") + .Input(in) + .Attr("message", message) + .Finalize(g, &ret)); + return ret; +} + +Node* Arg(Graph* g, int64 index, DataType type) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg") + .Attr("T", type) + .Attr("index", index) + .Finalize(g, &ret)); + return ret; +} + +Node* Retval(Graph* g, int64 index, Node* in) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval") + .Input(in) + .Attr("index", index) + 
.Finalize(g, &ret)); + return ret; +} + void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } } // end namespace graph diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h index 8585b35a19..b00196f587 100644 --- a/tensorflow/core/graph/testlib.h +++ b/tensorflow/core/graph/testlib.h @@ -32,7 +32,7 @@ namespace test { namespace graph { // Converts "g" into its corresponding GraphDef "def". -// DEPRECATED: call g->ToGraphDef(def) instead. +ABSL_DEPRECATED("Call g->ToGraphDef(def) instead.") void ToGraphDef(Graph* g, GraphDef* def); // A few helpers to construct a graph. @@ -209,6 +209,15 @@ Node* Diag(Graph* g, Node* in, DataType type); // Add a DiagPart node in "g". Node* DiagPart(Graph* g, Node* in, DataType type); +// Add a CheckNumerics node in "g". +Node* CheckNumerics(Graph* g, Node* in, const string& message); + +// Add an _Arg node in "g". +Node* Arg(Graph* g, int64 index, DataType type); + +// Add a _Retval node in "g". +Node* Retval(Graph* g, int64 index, Node* in); + } // end namespace graph } // end namespace test } // end namespace tensorflow diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc index 7171ae059b..3b1d7d8347 100644 --- a/tensorflow/core/grappler/clusters/cluster.cc +++ b/tensorflow/core/grappler/clusters/cluster.cc @@ -83,6 +83,7 @@ void Cluster::DisableOptimizer(bool disable) { rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT); rewriter_config->set_shape_optimization(RewriterConfig::OFF); rewriter_config->set_remapping(RewriterConfig::OFF); + rewriter_config->set_pin_to_host_optimization(RewriterConfig::OFF); rewriter_config->mutable_auto_parallel()->set_enable(false); rewriter_config->clear_optimizers(); } else { diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc index b97603c890..e4f6bf7c86 100644 --- a/tensorflow/core/grappler/clusters/single_machine.cc +++ 
b/tensorflow/core/grappler/clusters/single_machine.cc @@ -93,13 +93,13 @@ Status SingleMachine::Provision() { strings::StrCat("Not able to parse GPU device name: ", dev.name())); } TfGpuId tf_gpu_id(parsed.id); - CudaGpuId cuda_gpu_id; - Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); + PlatformGpuId platform_gpu_id; + Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id); if (!s.ok()) { return errors::Unavailable("Unknown TF GPU device with id ", tf_gpu_id.value(), ": ", s.ToString()); } - attr = GetLocalGPUInfo(cuda_gpu_id); + attr = GetLocalGPUInfo(platform_gpu_id); } else if (dev.device_type().find("XLA") == string::npos) { // Filter out the fake XLA devices to avoid double counting the actual // hardware resources that are available. diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc index a7519725a5..567e7c075e 100644 --- a/tensorflow/core/grappler/clusters/utils.cc +++ b/tensorflow/core/grappler/clusters/utils.cc @@ -70,13 +70,14 @@ DeviceProperties GetLocalCPUInfo() { return device; } -DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id) { +DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id) { DeviceProperties device; device.set_type("GPU"); #if GOOGLE_CUDA cudaDeviceProp properties; - cudaError_t error = cudaGetDeviceProperties(&properties, cuda_gpu_id.value()); + cudaError_t error = + cudaGetDeviceProperties(&properties, platform_gpu_id.value()); if (error != cudaSuccess) { device.set_type("UNKNOWN"); LOG(ERROR) << "Failed to get device properties, error code: " << error; @@ -122,15 +123,15 @@ DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device) { } else if (device.type == "GPU") { if (device.has_id) { TfGpuId tf_gpu_id(device.id); - CudaGpuId cuda_gpu_id; - Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); + PlatformGpuId platform_gpu_id; + Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id); if 
(!s.ok()) { LOG(ERROR) << s; return unknown; } - return GetLocalGPUInfo(cuda_gpu_id); + return GetLocalGPUInfo(platform_gpu_id); } else { - return GetLocalGPUInfo(CudaGpuId(0)); + return GetLocalGPUInfo(PlatformGpuId(0)); } } return unknown; diff --git a/tensorflow/core/grappler/clusters/utils.h b/tensorflow/core/grappler/clusters/utils.h index ca15c48006..f0a342b728 100644 --- a/tensorflow/core/grappler/clusters/utils.h +++ b/tensorflow/core/grappler/clusters/utils.h @@ -28,7 +28,7 @@ DeviceProperties GetLocalCPUInfo(); // Returns the DeviceProperties for the specified GPU attached to the server on // which grappler is running. -DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id); +DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id); // Returns the DeviceProperties of the specified device DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device); diff --git a/tensorflow/core/grappler/clusters/utils_test.cc b/tensorflow/core/grappler/clusters/utils_test.cc index 74218adbac..3863d62980 100644 --- a/tensorflow/core/grappler/clusters/utils_test.cc +++ b/tensorflow/core/grappler/clusters/utils_test.cc @@ -31,22 +31,22 @@ TEST(UtilsTest, GetLocalGPUInfo) { LOG(INFO) << "CUDA is enabled."; DeviceProperties properties; - // Invalid CUDA GPU ID. - properties = GetLocalGPUInfo(CudaGpuId(100)); + // Invalid platform GPU ID. + properties = GetLocalGPUInfo(PlatformGpuId(100)); EXPECT_EQ("UNKNOWN", properties.type()); - // Succeed when a valid CUDA GPU id was inserted. - properties = GetLocalGPUInfo(CudaGpuId(0)); + // Succeed when a valid platform GPU id was inserted. 
+ properties = GetLocalGPUInfo(PlatformGpuId(0)); EXPECT_EQ("GPU", properties.type()); EXPECT_EQ("NVIDIA", properties.vendor()); #else LOG(INFO) << "CUDA is not enabled."; DeviceProperties properties; - properties = GetLocalGPUInfo(CudaGpuId(0)); + properties = GetLocalGPUInfo(PlatformGpuId(0)); EXPECT_EQ("GPU", properties.type()); - properties = GetLocalGPUInfo(CudaGpuId(100)); + properties = GetLocalGPUInfo(PlatformGpuId(100)); EXPECT_EQ("GPU", properties.type()); #endif } @@ -74,20 +74,20 @@ TEST(UtilsTest, GetDeviceInfo) { EXPECT_EQ("NVIDIA", properties.vendor()); #endif - // TF to CUDA GPU id mapping entry doesn't exist. + // TF to platform GPU id mapping entry doesn't exist. device.has_id = true; device.id = 0; properties = GetDeviceInfo(device); EXPECT_EQ("UNKNOWN", properties.type()); #if GOOGLE_CUDA - // Invalid CUDA GPU id. - GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(0), CudaGpuId(100)); + // Invalid platform GPU id. + GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(0), PlatformGpuId(100)); properties = GetDeviceInfo(device); EXPECT_EQ("UNKNOWN", properties.type()); - // Valid CUDA GPU id. - GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(1), CudaGpuId(0)); + // Valid platform GPU id. 
+ GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(1), PlatformGpuId(0)); device.id = 1; properties = GetDeviceInfo(device); EXPECT_EQ("GPU", properties.type()); diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 6710ff9df3..56c8339d57 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -260,13 +260,13 @@ typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) { } bool IsEnqueue(const NodeDef& n) { - return (n.op().find("Enqueue") != std::string::npos && - n.op().find("EnqueueMany") == std::string::npos); + return (n.op().find("Enqueue") != string::npos && + n.op().find("EnqueueMany") == string::npos); } bool IsDequeue(const NodeDef& n) { - return (n.op().find("Dequeue") != std::string::npos && - n.op().find("DequeueMany") == std::string::npos); + return (n.op().find("Dequeue") != string::npos && + n.op().find("DequeueMany") == string::npos); } bool HasAnyUnknownDimensions(const TensorShapeProto& proto) { @@ -345,6 +345,56 @@ void VerboseLogUnknownDimensionSources( } } +bool IsShapeFullyDefinedIntegerVectorOrScalar( + InferenceContext* ic, const ShapeHandle& shape, + const ShapeHandle& tensor_as_shape, const DataType& dtype) { + if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 || + !ic->FullyDefined(tensor_as_shape) || + (dtype != DT_INT32 && dtype != DT_INT64)) { + return false; + } + return true; +} + +// Returned tensor's shape is like `shape`, and its values and dtype are from +// `tensor_as_shape` and `dtype`. 
+TensorProto MakeTensorProtoFromShape(InferenceContext* ic, + const ShapeHandle& shape, + const ShapeHandle& tensor_as_shape, + const DataType& dtype) { + TensorProto tensor_proto; + tensor_proto.set_dtype(dtype); + auto* shape_proto = tensor_proto.mutable_tensor_shape(); + if (ic->Rank(shape) == 1) { + shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape)); + } + // For a scalar tensor, tensor_shape field will be left empty; no dim. + for (int i = 0; i < ic->Rank(tensor_as_shape); i++) { + int64 value = ic->Value(ic->Dim(tensor_as_shape, i)); + if (dtype == DT_INT32) { + tensor_proto.add_int_val(value); + } else { + tensor_proto.add_int64_val(value); + } + } + return tensor_proto; +} + +// Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`, +// and dtype = `dtype`. +NodeDef MakeConstNodeDefFromShape(InferenceContext* ic, + const ShapeHandle& shape, + const ShapeHandle& tensor_as_shape, + const DataType& dtype) { + NodeDef const_node; + const_node.set_name("const_from_shape"); + const_node.set_op("Const"); + auto* attr = const_node.mutable_attr(); + (*attr)["dtype"].set_type(dtype); + auto* tensor = (*attr)["value"].mutable_tensor(); + *tensor = MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype); + return const_node; +} } // namespace // Queue of nodes to process. Nodes can be enqueued in any order, but will be @@ -429,18 +479,22 @@ class SymbolicShapeRefiner { // perform shape inference on the function body. // // Propagate shape information of final function body node - // to function node `node`. + // to function node `function_node`. // - // In the event of an error, UpdateNode will simply set `node`'s + // In the event of an error, UpdateNode will simply set `function_node`'s // output shape to be Unknown. 
- Status UpdateFunction(const NodeDef* node) { - auto it = fun_to_grappler_function_item_.find(node->op()); + Status UpdateFunction(const NodeDef* function_node) { + auto it = fun_to_grappler_function_item_.find(function_node->op()); if (it == fun_to_grappler_function_item_.end()) { return errors::InvalidArgument( - node->op(), " was not previously added to SymbolicShapeRefiner."); + function_node->op(), + " was not previously added to SymbolicShapeRefiner."); } - GrapplerFunctionItem& grappler_function_item = it->second; + // Copy (not reference) so that changes we make here (e.g., replacing + // Placeholder with Const) don't affect one in + // fun_to_grappler_function_item_. + GrapplerFunctionItem grappler_function_item = it->second; GraphView gv(&grappler_function_item.graph); // Forward shapes from function input nodes to argument nodes. @@ -453,7 +507,7 @@ class SymbolicShapeRefiner { "supported."); } NodeDef* fun_node = gv.GetNode(fun_input.input_name); - const string& input = node->input(i); + const string& input = function_node->input(i); const string& node_name = NodeName(input); if (IsControlInput(input)) { @@ -478,17 +532,48 @@ class SymbolicShapeRefiner { TensorShapeProto proto; const auto& handle = input_inference_context->output(output_port_num); input_inference_context->ShapeHandleToProto(handle, &proto); + // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1. + for (int i = 0; i < proto.dim_size(); i++) { + if (proto.dim(i).size() < -1) { + proto.mutable_dim(i)->set_size(-1); + } + } *attr_output_shape.mutable_shape() = proto; (*fun_node->mutable_attr())["shape"] = attr_output_shape; } + // Replace input Placeholders with Consts, if values are known. Note that + // we don't check exceptions here as it's done in the above loop. 
+ auto* ctx = GetNodeContext(function_node); + auto* ic = ctx->inference_context.get(); + for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) { + const string& input = function_node->input(i); + const string& node_name = NodeName(input); + NodeDef* input_node = graph_.GetNode(node_name); + if (IsConstant(*input_node)) { + TF_CHECK_OK( + ReplaceInputWithConst(*input_node, i, &grappler_function_item)); + } else if (ic->input_tensors_as_shapes().size() > i && + IsShapeFullyDefinedIntegerVectorOrScalar( + ic, ic->input(i), ic->input_tensors_as_shapes()[i], + ctx->input_types[i])) { + // We have fully defined input_tensors_as_shapes for this input; use it + // as a const input to the function node. + NodeDef const_input_node = MakeConstNodeDefFromShape( + ic, ic->input(i), ic->input_tensors_as_shapes()[i], + ctx->input_types[i]); + TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i, + &grappler_function_item)); + } + } + // Perform inference on function body. GraphProperties gp(grappler_function_item); TF_RETURN_IF_ERROR(gp.InferStatically(true)); // Add return nodes for output shapes. 
- auto ic = GetContext(node); int output = 0; + ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size()); for (auto const& out_arg : grappler_function_item.outputs()) { if (out_arg.output_tensors.size() > 1) { // TODO(jmdecker): Handle case of multiple output tensors @@ -505,8 +590,9 @@ class SymbolicShapeRefiner { const NodeDef* retnode = gv.GetNode(node_name); if (retnode == nullptr) { - return errors::FailedPrecondition("Unable to find return node ", - node_name, " for ", node->name()); + return errors::FailedPrecondition( + "Unable to find return function_node ", node_name, " for ", + function_node->name()); } auto output_properties = gp.GetOutputProperties(retnode->name()); @@ -520,6 +606,14 @@ class SymbolicShapeRefiner { ShapeHandle out; TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out)); ic->set_output(output, out); + if (outprop.has_value()) { + // Forward tensor value to output_tensors_as_shape. + Tensor tensor; + if (tensor.FromProto(outprop.value())) { + MaybeSetTensorValueToShape(ic, tensor, + &ctx->output_tensors_as_shapes[output]); + } + } output++; } @@ -562,21 +656,9 @@ class SymbolicShapeRefiner { if (const_values[dst_input].FromProto( input->attr().at("value").tensor())) { input_tensors[dst_input] = &const_values[dst_input]; - // Integer tensors of rank one can also be interpreted as a shape - // provided all their values are >= -1. 
- if (const_values[dst_input].dims() == 1 && - (const_values[dst_input].dtype() == DT_INT32 || - const_values[dst_input].dtype() == DT_INT64)) { - ShapeHandle tensor_shape = inference_context->Vector( - const_values[dst_input].NumElements()); - ShapeHandle shp; - if (inference_context - ->MakeShapeFromTensor(input_tensors[dst_input], - tensor_shape, &shp) - .ok()) { - input_tensors_as_shapes[dst_input] = shp; - } - } + MaybeSetTensorValueToShape(inference_context, + const_values[dst_input], + &input_tensors_as_shapes[dst_input]); } } else if (IsRank(*input)) { if (c->inference_context->RankKnown(c->inference_context->input(0))) { @@ -671,11 +753,13 @@ class SymbolicShapeRefiner { // true, as the updates to the call node will have changed, even if it's // the same function being called twice with the same input shapes. // Example: simple_function.pbtxt - if (UpdateFunction(node).ok()) { + auto s = UpdateFunction(node); + if (s.ok()) { return Status::OK(); } else { VLOG(1) << "UpdateFunction failed for " << node->op() - << ". Defaulting to ShapeUnknown."; + << ". Defaulting to ShapeUnknown.\n" + << s.ToString(); } } @@ -942,13 +1026,25 @@ class SymbolicShapeRefiner { : t->scalar<int64>()(); dims.push_back(size < 0 ? ic->UnknownDim() : ic->MakeDim(size)); } else { - dims.push_back(ic->UnknownDim()); + // Don't have tensor value, but use input_tensors_as_shapes, if + // possible. + const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i]; + if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 && + ic->ValueKnown(ic->Dim(shape_handle, 0))) { + dims.push_back(ic->Dim(shape_handle, 0)); + } else { + dims.push_back(ic->UnknownDim()); + } } } if (valid) { c->output_tensors_as_shapes.resize(1); c->output_tensors_as_shapes[0] = ic->MakeShape(dims); } + } else if (IsIdentity(node)) { + // Pass input_tensors_as_shapes to output_tensors_as_shapes. 
+ c->output_tensors_as_shapes.resize(1); + c->output_tensors_as_shapes[0] = ic->input_tensors_as_shapes()[0]; } else if (IsSlice(node)) { ShapeHandle input = ic->input_tensors_as_shapes()[0]; bool valid = ic->RankKnown(input); @@ -1053,6 +1149,46 @@ class SymbolicShapeRefiner { } private: + bool IsIntegerVector(const Tensor& tensor) { + if (tensor.dims() == 1 && + (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) { + return true; + } + return false; + } + + bool IsIntegerScalar(const Tensor& tensor) { + if (tensor.dims() == 0 && + (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) && + tensor.NumElements() == 1) { + return true; + } + return false; + } + + void MaybeSetTensorValueToShape(InferenceContext* ic, const Tensor& tensor, + ShapeHandle* tensors_as_shapes) { + // Integer tensors of rank one can also be interpreted as a shape + // provided all their values are >= -1. + if (IsIntegerVector(tensor)) { + ShapeHandle tensor_shape = ic->Vector(tensor.NumElements()); + ShapeHandle shp; + // Note that MakeShapeFromTensor filters out invalid values (e.g., < -1). + if (ic->MakeShapeFromTensor(&tensor, tensor_shape, &shp).ok()) { + *tensors_as_shapes = shp; + } + } else if (IsIntegerScalar(tensor)) { + // Scalar constant. + int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0) + : tensor.flat<int64>()(0); + // Ideally, values can be < -1, but MakeDim() fails with a value < -1. + // It's a limitation as we use ShapeHandle as a means to pass values. + if (value >= -1) { + *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)}); + } + } + } + const GraphView& graph_; int graph_def_version_; std::unordered_map<const NodeDef*, NodeContext> node_to_context_; @@ -1528,6 +1664,8 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) { continue; } + auto* ic = ctx->inference_context.get(); + // Fill input properties. 
{ auto& input_properties = input_properties_[node.name()]; @@ -1535,19 +1673,26 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) { // Should always be empty, node names in graph are supposed to be unique. CHECK_EQ(input_properties.size(), 0); - input_properties.resize(ctx->inference_context->num_inputs()); + input_properties.resize(ic->num_inputs()); GraphView::InputPort input(&node, -1); - for (int i = 0; i < ctx->inference_context->num_inputs(); ++i) { - shape_manager.AsTensorProperties(ctx->inference_context->input(i), - ctx->input_types[i], + for (int i = 0; i < ic->num_inputs(); ++i) { + shape_manager.AsTensorProperties(ic->input(i), ctx->input_types[i], &input_properties[i]); input.port_id = i; GraphView::OutputPort fanin = graph_view.GetRegularFanin(input); - if (!IsConstant(*fanin.node)) { - continue; + // Export tensor value (either const tensor or input_tensors_as_shapes) + // to input_properties.value. + if (IsConstant(*fanin.node)) { + const TensorProto& raw_val = fanin.node->attr().at("value").tensor(); + *input_properties[i].mutable_value() = raw_val; + } else if (ic->input_tensors_as_shapes().size() > i && + IsShapeFullyDefinedIntegerVectorOrScalar( + ic, ic->input(i), ic->input_tensors_as_shapes()[i], + ctx->input_types[i])) { + *input_properties[i].mutable_value() = MakeTensorProtoFromShape( + ic, ic->input(i), ic->input_tensors_as_shapes()[i], + ctx->input_types[i]); } - const TensorProto& raw_val = fanin.node->attr().at("value").tensor(); - *input_properties[i].mutable_value() = raw_val; } } @@ -1558,11 +1703,23 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) { // Should always be empty, node names in graph are supposed to be unique. 
CHECK_EQ(output_properties.size(), 0); - output_properties.resize(ctx->inference_context->num_outputs()); - for (int i = 0; i < ctx->inference_context->num_outputs(); ++i) { - shape_manager.AsTensorProperties(ctx->inference_context->output(i), - ctx->output_types[i], + output_properties.resize(ic->num_outputs()); + for (int i = 0; i < ic->num_outputs(); ++i) { + shape_manager.AsTensorProperties(ic->output(i), ctx->output_types[i], &output_properties[i]); + // Export tensor value (either const tensor or output_tensors_as_shapes) + // to output_properties.value. + if (IsConstant(node)) { + const TensorProto& raw_val = node.attr().at("value").tensor(); + *output_properties[i].mutable_value() = raw_val; + } else if (ctx->output_tensors_as_shapes.size() > i && + IsShapeFullyDefinedIntegerVectorOrScalar( + ic, ic->output(i), ctx->output_tensors_as_shapes[i], + ctx->output_types[i])) { + *output_properties[i].mutable_value() = MakeTensorProtoFromShape( + ic, ic->output(i), ctx->output_tensors_as_shapes[i], + ctx->output_types[i]); + } } } } diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 8938b7c32e..362092a6cf 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -44,6 +44,30 @@ class GraphPropertiesTest : public ::testing::Test { // Provision a single machine with 3 cpu cores cluster_.reset(new SingleMachine(5 * 60, 3, 0)); TF_CHECK_OK(cluster_->Provision()); + + // This function is simply + // out = Fill(shape, value), but + // Fill requires values in the shape input, not just shape of it, to infer + // output shape. 
+ auto f = FunctionDefHelper::Create( + // Name + "MyFillFunc", + // Inputs + {"shape: int32", "value: float"}, + // Outputs + {"out: float"}, + // Attrs + {}, + // Nodes + { + {{"a"}, + "Fill", + {"shape", "value"}, + {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}}, + }, + // Returns + {{"out", "a:output:0"}}); + function_lib_.add_function()->Swap(&f); } void TearDown() override { @@ -69,7 +93,29 @@ class GraphPropertiesTest : public ::testing::Test { return s; } + // Compare values of integer (DT_INT32 or DT_INT64) tensor against expected + // ones. + void ExpectTensorValues(const std::vector<int64>& expected, + const TensorProto& tensor_proto_to_compare) { + Tensor tensor; + EXPECT_TRUE(tensor.FromProto(tensor_proto_to_compare)); + EXPECT_EQ(expected.size(), tensor.NumElements()); + // We're interested in only integer tensors as only shapes are exported as + // graph properties values. + CHECK(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64); + if (tensor.dtype() == DT_INT32) { + for (int i = 0; i < tensor.NumElements(); i++) { + EXPECT_EQ(expected[i], tensor.flat<int32>()(i)); + } + } else { + for (int i = 0; i < tensor.NumElements(); i++) { + EXPECT_EQ(expected[i], tensor.flat<int64>()(i)); + } + } + } + std::unique_ptr<SingleMachine> cluster_; + FunctionDefLibrary function_lib_; }; TEST_F(GraphPropertiesTest, StaticProperties) { @@ -785,7 +831,220 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) { EXPECT_EQ("float: [128,256]", PropToString(prop)); } -TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) { +TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2}); + Output a1 = ops::Identity(s.WithOpName("a1"), a); + Output b = ops::Const(s.WithOpName("b"), 99, {}); + Output b1 = ops::Identity(s.WithOpName("b1"), b); + Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4}); + Output c1 = 
ops::Identity(s.WithOpName("c1"), c); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + + // Check output shapes. + EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0])); + EXPECT_EQ("int32: [2]", + PropToString(properties.GetOutputProperties("a1")[0])); + EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0])); + EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0])); + EXPECT_EQ("int32: [4,4,4]", + PropToString(properties.GetOutputProperties("c")[0])); + EXPECT_EQ("int32: [4,4,4]", + PropToString(properties.GetOutputProperties("c1")[0])); + + // Check has_value. + EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value()); + EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value()); + EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value()); + EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value()); + EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value()); + EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value()); + EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value()); + EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value()); + // Note that we propagate tensro value of only 1D vector and scalar. + EXPECT_FALSE(properties.GetOutputProperties("c1")[0].has_value()); + + // Check values. 
+ ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value()); + ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value()); + ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value()); + ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value()); + ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value()); + ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value()); + std::vector<int64> c_values; + for (int i = 0; i < 4 * 4 * 4; i++) { + c_values.push_back(1); + } + ExpectTensorValues({c_values}, + properties.GetOutputProperties("c")[0].value()); + ExpectTensorValues({c_values}, + properties.GetInputProperties("c1")[0].value()); + // No output value for c1, as it's neither 1D vector nor scalar. +} + +TEST_F(GraphPropertiesTest, IdentityPassingShape) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 5, {2}); + Output b = ops::Identity(s.WithOpName("b"), a); + Output c = ops::Const(s.WithOpName("const"), 0.1f, {}); + // Fill needs not only e's shape but also the value of e to figure out output + // shape; hence, Identity op (b) should pass a's value as + // output_tensors_as_shape. + Output d = ops::Fill(s.WithOpName("fill"), b, c); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto out_props = properties.GetOutputProperties("fill"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ("float: [5,5]", PropToString(out_prop0)); +} + +TEST_F(GraphPropertiesTest, PackWithConstInput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 1, {}); + Output b = ops::Const(s.WithOpName("b"), 2, {}); + Output c = ops::Const(s.WithOpName("c"), 3, {}); + Output d = ops::Const(s.WithOpName("d"), 4, {}); + // Note ops::Stack instantiates Pack op. 
+ Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d}); + // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4} + Output f = ops::Const(s.WithOpName("const"), 0.1f, {}); + // Fill needs not only e's shape but also its value to figure out output + // shape. + Output g = ops::Fill(s.WithOpName("fill"), e, f); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto out_props = properties.GetOutputProperties("fill"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0)); +} + +TEST_F(GraphPropertiesTest, PackWithIdentityInput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + // Same to PackWithConstInput test case, but a, b, c, and d are Identity ops + // from Const. + // If output_tensors_as_shape is not not set for those Shape ops or Pack op + // doesn't take input_tensors_as_shape, Fill op's input doesn't have value; + // hence, its output shape becomes unknown. + Output a0 = ops::Const(s.WithOpName("a0"), 1, {}); + Output b0 = ops::Const(s.WithOpName("b0"), 2, {}); + Output c0 = ops::Const(s.WithOpName("c0"), 3, {}); + Output d0 = ops::Const(s.WithOpName("d0"), 4, {}); + Output a = ops::Identity(s.WithOpName("a"), a0); + Output b = ops::Identity(s.WithOpName("b"), b0); + Output c = ops::Identity(s.WithOpName("c"), c0); + Output d = ops::Identity(s.WithOpName("d"), d0); + // Note ops::Stack instantiates Pack op. + Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d}); + // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4} + Output f = ops::Const(s.WithOpName("const"), 0.1f, {}); + // Fill needs not only e's shape but also its value to figure out output + // shape. 
+ Output g = ops::Fill(s.WithOpName("fill"), e, f); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto out_props = properties.GetOutputProperties("fill"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0)); +} + +TEST_F(GraphPropertiesTest, FunctionWithConstInput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_)); + Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4}); + Output value = ops::Const(s.WithOpName("value"), 0.1f, {}); + auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc", + s.graph()->op_registry()); + tensorflow::Node* func_op; + auto _shape = tensorflow::ops::AsNodeOut(s, shape); + auto _value = tensorflow::ops::AsNodeOut(s, value); + TF_CHECK_OK( + builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op)); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto out_props = properties.GetOutputProperties("MyFillFunc"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0)); +} + +TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) { + // Same to FunctionWithConstInput, but function inputs are Identity of Const, + // so tensor shapes, not tensor value, should be used as Const input to + // function. 
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_)); + Output shape_ = ops::Const(s.WithOpName("shape_"), {1, 2, 3, 4}); + Output shape = ops::Identity(s.WithOpName("shape"), shape_); + Output value = ops::Const(s.WithOpName("value"), 0.1f, {}); + auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc", + s.graph()->op_registry()); + tensorflow::Node* func_op; + auto _shape = tensorflow::ops::AsNodeOut(s, shape); + auto _value = tensorflow::ops::AsNodeOut(s, value); + TF_CHECK_OK( + builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op)); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto out_props = properties.GetOutputProperties("MyFillFunc"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0)); +} + +TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) { + FunctionDefLibrary library; + *library.add_function() = FunctionDefHelper::Create( + "MyFunc", // Name + {"x: int32"}, // Inputs + {"out: int32"}, // Outputs + {}, // Attrs + {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}}, // Nodes + {{"out", "a:output:0"}}); // Returns + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + TF_CHECK_OK(s.graph()->AddFunctionLibrary(library)); + + // MyFunc takes Const (shape) and passes it with Identity. Expect function + // output has the same shape as well as value (output_tensors_as_shape) as + // input Const tensor. 
+ Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2}); + auto _shape = tensorflow::ops::AsNodeOut(s, shape); + auto builder = + tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry()); + tensorflow::Node* func_op; + TF_CHECK_OK(builder.Input(_shape).Finalize(s.graph(), &func_op)); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(true)); + const auto out_props = properties.GetOutputProperties("MyFunc"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ("int32: [2]", PropToString(out_prop0)); + EXPECT_TRUE(out_prop0.has_value()); + ExpectTensorValues({5, 7}, out_prop0.value()); + ExpectTensorValues({5, 7}, + properties.GetInputProperties("MyFunc")[0].value()); +} + +TEST_F(GraphPropertiesTest, FunctionWithScalarInput) { // Create graph with a function that takes a scalar value so that we use // Placeholder with scalar as for input to the function shape inference. // Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of @@ -818,7 +1077,7 @@ TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) { // MyFunc output shouldn't be unknown rank. 
GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically(false)); + TF_CHECK_OK(properties.InferStatically(true)); const auto out_props = properties.GetOutputProperties("MyFunc"); const OpInfo::TensorProperties out_prop0 = out_props[0]; EXPECT_EQ(DT_FLOAT, out_prop0.dtype()); @@ -856,18 +1115,10 @@ TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) { EXPECT_EQ(2, in_props.size()); const OpInfo::TensorProperties& in_prop = in_props[0]; - EXPECT_EQ(DT_FLOAT, in_prop.dtype()); - EXPECT_FALSE(in_prop.shape().unknown_rank()); - EXPECT_EQ(2, in_prop.shape().dim_size()); - EXPECT_EQ(1, in_prop.shape().dim(0).size()); - EXPECT_EQ(2, in_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop)); const OpInfo::TensorProperties& in_prop1 = in_props[1]; - EXPECT_EQ(DT_FLOAT, in_prop1.dtype()); - EXPECT_FALSE(in_prop1.shape().unknown_rank()); - EXPECT_EQ(2, in_prop1.shape().dim_size()); - EXPECT_EQ(1, in_prop1.shape().dim(0).size()); - EXPECT_EQ(2, in_prop1.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop1)); } TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) { @@ -882,51 +1133,25 @@ TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) { EXPECT_EQ(2, out_props.size()); const OpInfo::TensorProperties& out_prop0 = out_props[0]; - EXPECT_EQ(DT_FLOAT, out_prop0.dtype()); - EXPECT_EQ(4, out_prop0.shape().dim_size()); - EXPECT_EQ(128, out_prop0.shape().dim(0).size()); - EXPECT_EQ(112, out_prop0.shape().dim(1).size()); - EXPECT_EQ(112, out_prop0.shape().dim(2).size()); - EXPECT_EQ(64, out_prop0.shape().dim(3).size()); + EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0)); const OpInfo::TensorProperties& out_prop1 = out_props[1]; - EXPECT_EQ(DT_FLOAT, out_prop1.dtype()); - EXPECT_EQ(128, out_prop1.shape().dim(0).size()); - EXPECT_EQ(112, out_prop1.shape().dim(1).size()); - EXPECT_EQ(112, out_prop1.shape().dim(2).size()); - EXPECT_EQ(24, out_prop1.shape().dim(3).size()); + 
EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1)); const auto in_props = properties.GetInputProperties("y0"); EXPECT_EQ(4, in_props.size()); const OpInfo::TensorProperties& in_prop0 = in_props[0]; - EXPECT_EQ(DT_FLOAT, in_prop0.dtype()); - EXPECT_EQ(1, in_prop0.shape().dim_size()); - EXPECT_EQ(64, in_prop0.shape().dim(0).size()); + EXPECT_EQ("float: [64]", PropToString(in_prop0)); const OpInfo::TensorProperties& in_prop1 = in_props[1]; - EXPECT_EQ(DT_FLOAT, in_prop1.dtype()); - EXPECT_EQ(4, in_prop1.shape().dim_size()); - EXPECT_EQ(1, in_prop1.shape().dim(0).size()); - EXPECT_EQ(1, in_prop1.shape().dim(1).size()); - EXPECT_EQ(24, in_prop1.shape().dim(2).size()); - EXPECT_EQ(64, in_prop1.shape().dim(3).size()); + EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1)); const OpInfo::TensorProperties& in_prop2 = in_props[2]; - EXPECT_EQ(DT_FLOAT, in_prop2.dtype()); - EXPECT_EQ(4, in_prop2.shape().dim_size()); - EXPECT_EQ(128, in_prop2.shape().dim(0).size()); - EXPECT_EQ(224, in_prop2.shape().dim(1).size()); - EXPECT_EQ(224, in_prop2.shape().dim(2).size()); - EXPECT_EQ(3, in_prop2.shape().dim(3).size()); + EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2)); const OpInfo::TensorProperties& in_prop3 = in_props[3]; - EXPECT_EQ(DT_FLOAT, in_prop3.dtype()); - EXPECT_EQ(4, in_prop3.shape().dim_size()); - EXPECT_EQ(7, in_prop3.shape().dim(0).size()); - EXPECT_EQ(7, in_prop3.shape().dim(1).size()); - EXPECT_EQ(3, in_prop3.shape().dim(2).size()); - EXPECT_EQ(8, in_prop3.shape().dim(3).size()); + EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3)); } TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) { @@ -986,18 +1211,10 @@ TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) { EXPECT_EQ(2, in_props.size()); const OpInfo::TensorProperties& in_prop = in_props[0]; - EXPECT_EQ(DT_FLOAT, in_prop.dtype()); - EXPECT_FALSE(in_prop.shape().unknown_rank()); - EXPECT_EQ(2, in_prop.shape().dim_size()); - EXPECT_EQ(1, 
in_prop.shape().dim(0).size()); - EXPECT_EQ(2, in_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop)); const OpInfo::TensorProperties& in_prop1 = in_props[1]; - EXPECT_EQ(DT_FLOAT, in_prop1.dtype()); - EXPECT_FALSE(in_prop1.shape().unknown_rank()); - EXPECT_EQ(2, in_prop1.shape().dim_size()); - EXPECT_EQ(1, in_prop1.shape().dim(0).size()); - EXPECT_EQ(2, in_prop1.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop1)); } TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) { @@ -1022,27 +1239,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) { const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o"); const OpInfo::TensorProperties& out_prop = out_props[0]; EXPECT_EQ(DT_FLOAT, out_prop.dtype()); - EXPECT_FALSE(out_prop.shape().unknown_rank()); - EXPECT_EQ(2, out_prop.shape().dim_size()); - EXPECT_EQ(1, out_prop.shape().dim(0).size()); - EXPECT_EQ(2, out_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(out_prop)); const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o"); EXPECT_EQ(2, in_props.size()); const OpInfo::TensorProperties& in_prop = in_props[0]; - EXPECT_EQ(DT_FLOAT, in_prop.dtype()); - EXPECT_FALSE(in_prop.shape().unknown_rank()); - EXPECT_EQ(2, in_prop.shape().dim_size()); - EXPECT_EQ(1, in_prop.shape().dim(0).size()); - EXPECT_EQ(2, in_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop)); const OpInfo::TensorProperties& in_prop1 = in_props[1]; - EXPECT_EQ(DT_FLOAT, in_prop1.dtype()); - EXPECT_FALSE(in_prop1.shape().unknown_rank()); - EXPECT_EQ(2, in_prop1.shape().dim_size()); - EXPECT_EQ(1, in_prop1.shape().dim(0).size()); - EXPECT_EQ(2, in_prop1.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop1)); } TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) { @@ -1066,28 +1272,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) { 
TF_CHECK_OK(properties.InferStatically(false)); const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o"); const OpInfo::TensorProperties& out_prop = out_props[0]; - EXPECT_EQ(DT_FLOAT, out_prop.dtype()); - EXPECT_FALSE(out_prop.shape().unknown_rank()); - EXPECT_EQ(2, out_prop.shape().dim_size()); - EXPECT_EQ(1, out_prop.shape().dim(0).size()); - EXPECT_EQ(2, out_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(out_prop)); const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o"); EXPECT_EQ(2, in_props.size()); const OpInfo::TensorProperties& in_prop = in_props[0]; - EXPECT_EQ(DT_FLOAT, in_prop.dtype()); - EXPECT_FALSE(in_prop.shape().unknown_rank()); - EXPECT_EQ(2, in_prop.shape().dim_size()); - EXPECT_EQ(1, in_prop.shape().dim(0).size()); - EXPECT_EQ(2, in_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop)); const OpInfo::TensorProperties& in_prop1 = in_props[1]; - EXPECT_EQ(DT_FLOAT, in_prop1.dtype()); - EXPECT_FALSE(in_prop1.shape().unknown_rank()); - EXPECT_EQ(2, in_prop1.shape().dim_size()); - EXPECT_EQ(1, in_prop1.shape().dim(0).size()); - EXPECT_EQ(2, in_prop1.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop1)); } TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) { @@ -1115,28 +1309,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) { TF_CHECK_OK(properties.InferStatically(false)); const auto out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I"); const OpInfo::TensorProperties& out_prop = out_props[0]; - EXPECT_EQ(DT_FLOAT, out_prop.dtype()); - EXPECT_FALSE(out_prop.shape().unknown_rank()); - EXPECT_EQ(2, out_prop.shape().dim_size()); - EXPECT_EQ(1, out_prop.shape().dim(0).size()); - EXPECT_EQ(2, out_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(out_prop)); const auto in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I"); EXPECT_EQ(2, in_props.size()); const 
OpInfo::TensorProperties& in_prop = in_props[0]; - EXPECT_EQ(DT_FLOAT, in_prop.dtype()); - EXPECT_FALSE(in_prop.shape().unknown_rank()); - EXPECT_EQ(2, in_prop.shape().dim_size()); - EXPECT_EQ(1, in_prop.shape().dim(0).size()); - EXPECT_EQ(2, in_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop)); const OpInfo::TensorProperties& in_prop1 = in_props[1]; - EXPECT_EQ(DT_FLOAT, in_prop1.dtype()); - EXPECT_FALSE(in_prop1.shape().unknown_rank()); - EXPECT_EQ(2, in_prop1.shape().dim_size()); - EXPECT_EQ(1, in_prop1.shape().dim(0).size()); - EXPECT_EQ(3, in_prop1.shape().dim(1).size()); + EXPECT_EQ("float: [1,3]", PropToString(in_prop1)); } TEST_F(GraphPropertiesTest, SymbolicShapes) { diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index aad00ce039..5415324b48 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -127,7 +127,7 @@ static void ExtractExtraProperties( // For filename input, the file size can also be useful. if (op_def && i < op_def->input_arg_size() && - op_def->input_arg(i).name().find("filename") != std::string::npos) { + op_def->input_arg(i).name().find("filename") != string::npos) { Tensor tensor; if (!tensor.FromProto(t)) { continue; @@ -153,7 +153,7 @@ static void ExtractExtraProperties( // When the input is a handle (e.g. look up table handle), the information // in the op itself is not sufficient to predict the op memory. 
if (op_def && i < op_def->input_arg_size() && - op_def->input_arg(i).name().find("handle") != std::string::npos) { + op_def->input_arg(i).name().find("handle") != string::npos) { string new_key = strings::StrCat("parent_", i, "_op"); AttrValue attr; attr.set_s(input_node->op()); @@ -209,13 +209,13 @@ DeviceProperties GetDeviceInfo(const string& device_str) { if (DeviceNameUtils::ParseFullName(device_str, &parsed)) { if (parsed.type == "GPU") { TfGpuId tf_gpu_id(parsed.id); - CudaGpuId cuda_gpu_id; - Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id); + PlatformGpuId platform_gpu_id; + Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id); if (!s.ok()) { // We are probably running simulation without linking cuda libraries. - cuda_gpu_id = CudaGpuId(parsed.id); + platform_gpu_id = PlatformGpuId(parsed.id); } - return GetLocalGPUInfo(cuda_gpu_id); + return GetLocalGPUInfo(platform_gpu_id); } else if (parsed.type == "CPU") { return GetLocalCPUInfo(); } @@ -320,8 +320,8 @@ void TensorSizeHistogram::Merge(const TensorSizeHistogram& src) { buckets_.begin(), std::plus<uint64>()); } -std::string TensorSizeHistogram::ToString() const { - std::string r; +string TensorSizeHistogram::ToString() const { + string r; char buf[200]; snprintf(buf, sizeof(buf), "Count: %lld, Average: ", num_elem_); r.append(buf); diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h index d2c7c67666..5fd6717712 100644 --- a/tensorflow/core/grappler/costs/utils.h +++ b/tensorflow/core/grappler/costs/utils.h @@ -80,7 +80,7 @@ class TensorSizeHistogram { uint64 Max() const { return max_; } uint64 NumElem() const { return num_elem_; } uint64 SumElem() const { return sum_elem_; } - std::string ToString() const; + string ToString() const; protected: const int Index(const uint64 value) const; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index 
02a379fca8..80889afc86 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -1999,13 +1999,13 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { // Helper lambda to extract port num from _Send and _Recv op name. auto get_port_num = [](const string& name) -> int { - if (name.find("bn_0") != std::string::npos) { + if (name.find("bn_0") != string::npos) { return 0; - } else if (name.find("bn_1") != std::string::npos) { + } else if (name.find("bn_1") != string::npos) { return 1; - } else if (name.find("bn_2") != std::string::npos) { + } else if (name.find("bn_2") != string::npos) { return 2; - } else if (name.find("bn_minus1") != std::string::npos) { + } else if (name.find("bn_minus1") != string::npos) { return -1; } return -999; diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h index 26d38a4931..97626346c7 100644 --- a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h +++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h @@ -138,7 +138,7 @@ class GraphAnalyzer { // The entries are owned by collation_map_, so must be removed from // ordered_collation_ before removing them from collation_map_. struct ReverseLessByCount { - bool operator()(CollationEntry* left, CollationEntry* right) { + bool operator()(CollationEntry* left, CollationEntry* right) const { return left->count > right->count; // Reverse order. } }; diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc index a6b6b6f8b2..0b8cb5e919 100644 --- a/tensorflow/core/grappler/graph_view.cc +++ b/tensorflow/core/grappler/graph_view.cc @@ -14,11 +14,44 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/grappler/graph_view.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/grappler/utils.h" namespace tensorflow { namespace grappler { +int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) { + for (int output_arg_id = 0; output_arg_id < op.output_arg_size(); + ++output_arg_id) { + if (port_id < 0) { + return -1; + } else if (port_id == 0) { + return output_arg_id; + } + + // Default is 1 port per output arg. + int n = 1; + + const auto& output_arg = op.output_arg(output_arg_id); + if (!output_arg.number_attr().empty()) { + n = node.attr().at(output_arg.number_attr()).i(); + } else if (!output_arg.type_list_attr().empty()) { + n = node.attr().at(output_arg.type_list_attr()).list().type_size(); + } + + if (n < 0) { + // This should never happen. + DCHECK_GE(n, 0); + return -1; + } else if (port_id < n) { + return output_arg_id; + } + port_id -= n; + } + + return -1; +} + GraphView::GraphView(GraphDef* graph) : graph_(graph) { for (int i = 0; i < graph_->node_size(); i++) { auto node = graph_->mutable_node(i); @@ -39,7 +72,7 @@ void GraphView::AddUniqueNodeOrDie(NodeDef* node) { void GraphView::AddFanouts(NodeDef* node) { for (int i = 0; i < node->input_size(); ++i) { OutputPort fanin; - string fanin_name = ParseNodeName(node->input(i), &fanin.port_id); + const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id); fanin.node = nodes_[fanin_name]; InputPort input; diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index ac260f85a0..ec946ca3b5 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -20,11 +20,21 @@ limitations under the License. 
#include <unordered_set> #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { namespace grappler { +// Map a node/op's output port_id to arg_id. +// +// The port_id refers to the n-th tensor of the node, while the arg_id refers to +// the n-th arg of the op. These two can be different if an op's arg is a list +// of tensors. +// +// We return -1 for any invalid port_id (i.e., no corresponding arg_id). +int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id); + // A utility class to simplify the traversal of a GraphDef. class GraphView { public: diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc index 958eb921fb..3d7d2faf7c 100644 --- a/tensorflow/core/grappler/graph_view_test.cc +++ b/tensorflow/core/grappler/graph_view_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/grappler/graph_view.h" +#include "tensorflow/cc/ops/parsing_ops.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" @@ -25,6 +26,88 @@ namespace { class GraphViewTest : public ::testing::Test {}; +TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); + ops::ShapeN b(s.WithOpName("b"), {a, a, a}); + + GraphDef graph_def; + TF_CHECK_OK(s.ToGraphDef(&graph_def)); + GraphView graph_view(&graph_def); + + const NodeDef& a_node_def = *graph_view.GetNode("a"); + const NodeDef& b_node_def = *graph_view.GetNode("b"); + + const OpDef* a_op_def = nullptr; + const OpDef* b_op_def = nullptr; + EXPECT_TRUE( + OpRegistry::Global()->LookUpOpDef(a_node_def.op(), &a_op_def).ok()); + EXPECT_TRUE( + OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok()); + + EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0)); + EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1)); + + EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0)); + EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1)); + EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2)); + EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 3)); + EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4)); +} + +TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) { + for (int num_splits : {1, 2}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10}); + ops::SparseSplit b(s.WithOpName("b"), a, a, a, a, num_splits); + + GraphDef graph_def; + TF_CHECK_OK(s.ToGraphDef(&graph_def)); + GraphView graph_view(&graph_def); + + const NodeDef& b_node_def = 
*graph_view.GetNode("b"); + const OpDef* b_op_def = nullptr; + EXPECT_TRUE( + OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok()); + + for (int port_id = 0; port_id <= num_splits * 3; ++port_id) { + int arg_id = -1; + if (port_id < num_splits * 3) { + arg_id = port_id / num_splits; + } + EXPECT_EQ(arg_id, OpOutputPortIdToArgId(b_node_def, *b_op_def, port_id)); + } + } +} + +TEST_F(GraphViewTest, ParseSingleExample) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const<string>(s.WithOpName("a"), "", {}); + Output b = ops::Const<int64>(s.WithOpName("b"), 1, {1, 1}); + ops::ParseSingleExample c(s.WithOpName("c"), a, {b, b}, 2, {"w", "x"}, + {"y", "z"}, {DT_INT64, DT_INT64}, {{1}, {1}}); + + GraphDef graph_def; + TF_CHECK_OK(s.ToGraphDef(&graph_def)); + GraphView graph_view(&graph_def); + + const NodeDef& c_node_def = *graph_view.GetNode("c"); + + const OpDef* c_op_def = nullptr; + EXPECT_TRUE( + OpRegistry::Global()->LookUpOpDef(c_node_def.op(), &c_op_def).ok()); + + EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 0)); + EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 1)); + EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 2)); + EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 3)); + EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 4)); + EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 5)); + EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 6)); + EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 7)); + EXPECT_EQ(-1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 8)); +} + TEST_F(GraphViewTest, BasicGraph) { TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"}); GrapplerItem item; diff --git a/tensorflow/core/grappler/inputs/utils.cc b/tensorflow/core/grappler/inputs/utils.cc index 5029dff877..def9198a69 100644 --- a/tensorflow/core/grappler/inputs/utils.cc +++ b/tensorflow/core/grappler/inputs/utils.cc @@ -14,10 +14,11 @@ 
limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/inputs/utils.h" -#include "tensorflow/core/platform/env.h" #include <vector> +#include "tensorflow/core/platform/env.h" + namespace tensorflow { namespace grappler { @@ -29,12 +30,12 @@ bool FilesExist(const std::set<string>& files) { return FilesExist(std::vector<string>(files.begin(), files.end()), nullptr); } -bool FileExists(const std::string& file, Status* status) { +bool FileExists(const string& file, Status* status) { *status = Env::Default()->FileExists(file); return status->ok(); } -Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path, +Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path, GraphDef* result) { Status status; if (FileExists(graph_def_pbtxt_path, &status)) { diff --git a/tensorflow/core/grappler/inputs/utils.h b/tensorflow/core/grappler/inputs/utils.h index 627dd5359f..4b9cb0a9ad 100644 --- a/tensorflow/core/grappler/inputs/utils.h +++ b/tensorflow/core/grappler/inputs/utils.h @@ -29,9 +29,9 @@ bool FilesExist(const std::vector<string>& files, std::vector<Status>* status = nullptr); bool FilesExist(const std::set<string>& files); -bool FileExists(const std::string& file, Status* status); +bool FileExists(const string& file, Status* status); -Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path, +Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path, GraphDef* result); } // end namespace grappler diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 653b088b1d..3521669b63 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -135,16 +135,37 @@ bool IsDequeueOp(const NodeDef& node) { bool IsDiv(const NodeDef& node) { return node.op() == "Div"; } -bool IsElementWiseMonotonic(const NodeDef& node) { - static const std::unordered_set<string>* element_wise_monotonic_ops = +// 
Returns true if node represents a unary elementwise function that is +// monotonic. If *is_non_decreasing is true, the function is non-decreasing, +// e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing, +// e.g. inv. +bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) { + static const std::unordered_set<string>* monotonic_non_decreasing_ops = CHECK_NOTNULL((new std::unordered_set<string>{ - "Relu", - "Relu6", - "Sigmoid", - "Sqrt", - "Tanh", + "Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1", + "Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint", + "Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh", + })); + static const std::unordered_set<string>* monotonic_non_increasing_ops = + CHECK_NOTNULL((new std::unordered_set<string>{ + "Inv", + "Reciprocal", + "Erfc", + "Rsqrt", + "Neg", })); - return element_wise_monotonic_ops->count(node.op()) > 0; + if (monotonic_non_decreasing_ops->count(node.op()) > 0) { + if (is_non_decreasing) { + *is_non_decreasing = true; + } + return true; + } else if (monotonic_non_increasing_ops->count(node.op()) > 0) { + if (is_non_decreasing) { + *is_non_decreasing = false; + } + return true; + } + return false; } bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; } @@ -470,7 +491,7 @@ bool IsFreeOfSideEffect(const NodeDef& node) { } } // Queue ops modify the queue which is a side effect. 
- if (node.op().find("Queue") != std::string::npos) { + if (node.op().find("Queue") != string::npos) { return false; } return !ModifiesInputsInPlace(node); diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 94439265c9..25ab6b65ac 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -55,7 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node); bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); bool IsDiv(const NodeDef& node); -bool IsElementWiseMonotonic(const NodeDef& node); +bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing); bool IsEluGrad(const NodeDef& node); bool IsEnter(const NodeDef& node); bool IsEqual(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index a24004dc16..960d1addb3 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -8,10 +8,6 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") # Platform specific build config load( - "//tensorflow/core:platform/default/build_config.bzl", - "tf_protos_grappler", -) -load( "//tensorflow/core:platform/default/build_config_root.bzl", "if_static", ) @@ -97,7 +93,6 @@ cc_library( deps = [ ":evaluation_utils", ":graph_optimizer", - ":symbolic_shapes", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -107,6 +102,7 @@ cc_library( "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/utils:symbolic_shapes", ], ) @@ -261,7 +257,6 @@ cc_library( ":constant_folding", ":graph_optimizer", ":graph_optimizer_stage", - ":symbolic_shapes", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -270,6 +265,7 @@ cc_library( 
"//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/utils:symbolic_shapes", "//tensorflow/core/grappler/utils:topological_sort", ], ) @@ -515,12 +511,14 @@ cc_library( ":custom_graph_optimizer_registry", ":debug_stripper", ":dependency_optimizer", + ":experimental_implementation_selector", ":function_optimizer", ":graph_optimizer", ":layout_optimizer", ":loop_optimizer", ":memory_optimizer", ":model_pruner", + ":pin_to_host_optimizer", ":remapper", ":scoped_allocator_optimizer", ":shape_optimizer", @@ -647,7 +645,6 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":graph_optimizer", - ":symbolic_shapes", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -657,6 +654,7 @@ cc_library( "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/costs:graph_properties", "//tensorflow/core/grappler/utils:frame", + "//tensorflow/core/grappler/utils:symbolic_shapes", ], ) @@ -714,31 +712,6 @@ tf_cuda_cc_test( ) cc_library( - name = "symbolic_shapes", - srcs = ["symbolic_shapes.cc"], - hdrs = ["symbolic_shapes.h"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - ] + tf_protos_grappler(), -) - -tf_cc_test( - name = "symbolic_shapes_test", - srcs = ["symbolic_shapes_test.cc"], - deps = [ - ":symbolic_shapes", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - ], -) - -cc_library( name = "debug_stripper", srcs = ["debug_stripper.cc"], hdrs = [ @@ -846,3 +819,106 @@ tf_cc_test( "//third_party/eigen3", ], ) + +cc_library( + name = "function_api_info", + srcs = ["function_api_info.cc"], + hdrs = ["function_api_info.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + 
"//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "function_api_info_test", + size = "small", + srcs = ["function_api_info_test.cc"], + deps = [ + ":function_api_info", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "experimental_implementation_selector", + srcs = ["experimental_implementation_selector.cc"], + hdrs = ["experimental_implementation_selector.h"], + deps = [ + ":custom_graph_optimizer", + ":custom_graph_optimizer_registry", + ":function_api_info", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/costs:graph_properties", + ], +) + +tf_cc_test( + name = "experimental_implementation_selector_test", + size = "small", + srcs = ["experimental_implementation_selector_test.cc"], + deps = [ + ":custom_graph_optimizer", + ":custom_graph_optimizer_registry", + ":experimental_implementation_selector", + ":function_api_info", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + "//tensorflow/core/grappler/utils:grappler_test", + ], +) + +cc_library( + name = "pin_to_host_optimizer", + srcs = ["pin_to_host_optimizer.cc"], + hdrs = [ + "pin_to_host_optimizer.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_optimizer", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:graph_view", + 
"//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/utils:frame", + "//tensorflow/core/grappler/utils:symbolic_shapes", + "//tensorflow/core/grappler/utils:topological_sort", + ], +) + +tf_cuda_cc_test( + name = "pin_to_host_optimizer_test", + srcs = ["pin_to_host_optimizer_test.cc"], + deps = [ + ":pin_to_host_optimizer", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/utils:grappler_test", + ], +) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 4fed88d536..3388ee8035 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -35,8 +35,8 @@ limitations under the License. 
#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/constant_folding.h" #include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h" -#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -276,7 +276,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> { for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) { for (int i = 0; i < output->input_size(); ++i) { auto input = output->input(i); - string name = ParseNodeName(input, &position); + StringPiece name = ParseNodeNameAsStringPiece(input, &position); if (name == node.name() && /*control input*/ position < 0) { return true; } @@ -1121,11 +1121,8 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* node, string* simplified_node_name) override { TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node)); NodeDef* tail = node; - // TODO(rmlarsen): Enable after debugging breakage in Bayesflow. 
- if (ctx().opt_level == RewriterConfig::AGGRESSIVE) { - tail = GetTailOfIdempotentChain(*tail, *ctx().node_map, - *ctx().nodes_to_preserve); - } + tail = GetTailOfIdempotentChain(*tail, *ctx().node_map, + *ctx().nodes_to_preserve); NodeDef* first_transpose; TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose)); @@ -1328,38 +1325,26 @@ class RemoveNegationStage : public ArithmeticOptimizerStage { } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - const string node_name = node->name(); NodeDef* x; NodeDef* y; TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x)); TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y)); bool updated = false; - if (IsAdd(*node)) { - if (IsNeg(*x)) { - // (-a) + b = b - a - node->set_op("Sub"); - node->mutable_input()->SwapElements(0, 1); - node->set_input(1, x->input(0)); - node->add_input(AsControlDependency(x->name())); - ctx().node_map->AddOutput(NodeName(x->input(0)), node_name); - updated = true; - } else if (IsNeg(*y)) { - // a + (-b) = a - b - node->set_op("Sub"); - node->set_input(1, y->input(0)); - node->add_input(AsControlDependency(y->name())); - ctx().node_map->AddOutput(NodeName(y->input(0)), node_name); - updated = true; - } - } else if (IsSub(*node)) { - if (IsNeg(*y)) { - // a - (-b) = a + b - node->set_op("Add"); - node->set_input(1, y->input(0)); - node->add_input(AsControlDependency(y->name())); - ctx().node_map->AddOutput(NodeName(y->input(0)), node_name); - updated = true; - } + if (IsNeg(*y)) { + // a - (-b) = a + b or a + (-b) = a - b + ForwardControlDependencies(node, {y}); + ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0)); + node->set_op(IsAdd(*node) ? 
"Sub" : "Add"); + node->set_input(1, y->input(0)); + updated = true; + } else if (IsAdd(*node) && IsNeg(*x)) { + // (-a) + b = b - a + ForwardControlDependencies(node, {x}); + ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0)); + node->set_op("Sub"); + node->mutable_input()->SwapElements(0, 1); + node->set_input(1, x->input(0)); + updated = true; } if (updated) { AddToOptimizationQueue(node); @@ -1583,7 +1568,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { for (NodeDef* output : outputs) { if (IsControlInput(output->input(0))) continue; int port; - const string node_name = ParseNodeName(output->input(0), &port); + const StringPiece node_name = + ParseNodeNameAsStringPiece(output->input(0), &port); if (node_name == node.name()) { tails->insert(ChainLink(output, port)); } else { @@ -1633,7 +1619,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { } else { for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) { int port; - const string node_name = ParseNodeName(new_tail->input(0), &port); + const StringPiece node_name = + ParseNodeNameAsStringPiece(new_tail->input(0), &port); if (node_name != tail->name()) { return Status::OK(); } @@ -2382,26 +2369,24 @@ class ConvertPowStage : public ArithmeticOptimizerStage { } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1]; - for (int i = 0; i < p.shape().dim_size(); ++i) { - if (p.shape().dim(i).size() < 0) { + const auto& pow_props = + ctx().graph_properties->GetInputProperties(node->name())[1]; + for (int i = 0; i < pow_props.shape().dim_size(); ++i) { + if (pow_props.shape().dim(i).size() < 0) { // skip if p is is not fully defined. 
return Status::OK(); } } - if (TensorShape::IsValid(p.shape()) && p.has_value()) { - Tensor pow(p.dtype(), p.shape()); - if (!pow.FromProto(p.value())) { + if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) { + Tensor pow(pow_props.dtype(), pow_props.shape()); + if (!pow.FromProto(pow_props.value())) { return errors::InvalidArgument("Cannot parse tensor from proto: ", - p.value().DebugString()); + pow_props.value().DebugString()); } complex128 prev, curr; for (int i = 0; i < pow.NumElements(); ++i) { - if (!GetElementUnexhaustive(pow, i, - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_COMPLEX128}, - &curr)) { + if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) { // input data type is not supported by Pow. Skip. return Status::OK(); } @@ -2414,12 +2399,19 @@ class ConvertPowStage : public ArithmeticOptimizerStage { NodeDef *x, *y; TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x)); TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y)); + const auto& value_props = + ctx().graph_properties->GetInputProperties(node->name())[0]; + const TensorShapeProto& output_shape = + ctx().graph_properties->GetOutputProperties(node->name())[0].shape(); if (curr == complex128(2, 0)) { node->set_op("Square"); node->set_input(1, AsControlDependency(y->name())); AddToOptimizationQueue(node); AddToOptimizationQueue(y); - } else if (curr == complex128(1, 0)) { + } else if (curr == complex128(1, 0) && + ShapesSymbolicallyEqual(value_props.shape(), output_shape)) { + // Pow could be used to broadcast, so make sure the shapes of the two + // arguments are identical before replacing Pow with Identity. 
node->set_op("Identity"); node->set_input(1, AsControlDependency(y->name())); AddToOptimizationQueue(node); @@ -2429,20 +2421,20 @@ class ConvertPowStage : public ArithmeticOptimizerStage { node->set_input(1, AsControlDependency(y->name())); AddToOptimizationQueue(node); AddToOptimizationQueue(y); - } else if (curr == complex128(0, 0)) { - const auto& b = - ctx().graph_properties->GetInputProperties(node->name())[0]; - for (int i = 0; i < b.shape().dim_size(); ++i) { - if (b.shape().dim(i).size() < 0) { + } else if (curr == complex128(0, 0) && + ShapesSymbolicallyEqual(value_props.shape(), output_shape)) { + for (int i = 0; i < value_props.shape().dim_size(); ++i) { + if (value_props.shape().dim(i).size() < 0) { // skip if b is is not fully defined. return Status::OK(); } } - if (TensorShape::IsValid(b.shape()) && b.has_value()) { - Tensor base(b.dtype(), b.shape()); - if (!base.FromProto(b.value())) { + if (TensorShape::IsValid(value_props.shape()) && + value_props.has_value()) { + Tensor base(value_props.dtype(), value_props.shape()); + if (!base.FromProto(value_props.value())) { return errors::InvalidArgument("Cannot parse tensor from proto: ", - b.value().DebugString()); + value_props.value().DebugString()); } node->set_op("Const"); Tensor c(base.dtype(), base.shape()); @@ -2600,12 +2592,10 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage { ~ConvertExpm1Stage() override = default; bool IsSupported(const NodeDef* node) const override { - if (!IsSub(*node)) - return false; + if (!IsSub(*node)) return false; NodeDef* input; - if (!GetInputNode(node->input(0), &input).ok()) - return false; + if (!GetInputNode(node->input(0), &input).ok()) return false; return IsExp(*input); } @@ -2625,10 +2615,8 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage { return Status::OK(); } - const auto& t = - ctx().graph_properties->GetInputProperties(exp->name())[0]; - const auto& c = - ctx().graph_properties->GetInputProperties(node->name())[1]; + const auto& t 
= ctx().graph_properties->GetInputProperties(exp->name())[0]; + const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1]; for (int k = 0; k < c.shape().dim_size(); ++k) { // Skip if c shape is not fully determined. if (c.shape().dim(k).size() < 0) { @@ -2706,8 +2694,9 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { // 0. inner_function is not in the preserve set, // 1. inner_function's Op is element-wise monotonic // 2. inner_function's output is not being consumed elsewhere. + bool is_non_decreasing = false; if (!IsInPreserveSet(*inner_function) && - IsElementWiseMonotonic(*inner_function) && + IsElementWiseMonotonic(*inner_function, &is_non_decreasing) && ctx().node_map->GetOutputs(inner_function->name()).size() == 1) { // Swap the first inputs of the inner function Op & the reduction Op. NodeDef* inner_input; @@ -2719,7 +2708,12 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { UpdateConsumers(reduction_node, inner_function->name()); ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(), reduction_node->name()); - + if (!is_non_decreasing) { + // Flip Min<->Max if the function is non-increasing, e.g. + // Max(Neg(x)) = Neg(Min(x)). + const string opposite = IsMax(*reduction_node) ? 
"Min" : "Max"; + reduction_node->set_op(opposite); + } AddToOptimizationQueue(reduction_node); AddToOptimizationQueue(inner_function); AddToOptimizationQueue(inner_input); @@ -2937,8 +2931,8 @@ uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const { for (const auto& input : node.input()) { int pos; - string node_name = ParseNodeName(input, &pos); - h = Hash64CombineUnordered(Hash64(node_name), h); + const StringPiece node_name = ParseNodeNameAsStringPiece(input, &pos); + h = Hash64CombineUnordered(Hash64(node_name.data(), node_name.size()), h); h = Hash64CombineUnordered(std::hash<int>()(pos), h); } for (const auto& attr : node.attr()) { @@ -3050,6 +3044,13 @@ void ArithmeticOptimizer::DedupComputations() { return; } std::set<int> duplicates; + // Populate feed_inplace_op; + std::unordered_set<NodeDef*> feeds_inplace_op; + for (int i = 0; i < optimized_graph_->node_size(); ++i) { + if (FeedsInPlaceOp(graph_view, optimized_graph_->node(i))) { + feeds_inplace_op.insert(optimized_graph_->mutable_node(i)); + } + } do { stop = true; UniqueNodes nodes; @@ -3058,19 +3059,19 @@ void ArithmeticOptimizer::DedupComputations() { continue; } NodeDef* node = optimized_graph_->mutable_node(i); - if (!CanDedup(*node)) { + if (!CanDedup(*node) || + feeds_inplace_op.find(node) != feeds_inplace_op.end()) { continue; } NodeDef* rep = nodes.FindOrAddRepresentative(node); if (rep == node) { continue; } - // If either node feeds an inplace op, deduping them may cause data races. - // For example: If we dedup nodes initializing two independent inplace - // accumulations, they will write to the same buffer, clobbering each - // other's results. - if (FeedsInPlaceOp(graph_view, *rep) || - FeedsInPlaceOp(graph_view, *node)) { + // If either node or rep feeds an inplace op, deduping them may cause data + // races. For example: If we dedup nodes initializing two independent + // inplace accumulations, they will write to the same buffer, clobbering + // each other's results. 
+ if (feeds_inplace_op.find(rep) != feeds_inplace_op.end()) { continue; } VLOG(3) << "Remove duplicated node: node=" << node->name() @@ -3078,20 +3079,20 @@ void ArithmeticOptimizer::DedupComputations() { const std::set<NodeDef*>& fanouts = node_map_->GetOutputs(node->name()); for (NodeDef* fanout : fanouts) { for (int i = 0; i < fanout->input_size(); ++i) { - string* name = fanout->mutable_input(i); - int position; - const string nodename = ParseNodeName(*name, &position); - if (nodename == node->name()) { - // Update name in-place. - if (position > 0) { - *name = StrCat(rep->name(), ":", position); - } else if (position == 0) { - *name = rep->name(); - } else { - *name = StrCat("^", rep->name()); - } - node_map_->AddOutput(rep->name(), fanout->name()); + string* fanout_input = fanout->mutable_input(i); + const int position = + NodePositionIfSameNode(*fanout_input, node->name()); + // Update name in-place. + if (position < -1) { + continue; + } else if (position > 0) { + *fanout_input = StrCat(rep->name(), ":", position); + } else if (position == 0) { + *fanout_input = rep->name(); + } else { + *fanout_input = StrCat("^", rep->name()); } + node_map_->AddOutput(rep->name(), fanout->name()); } } duplicates.insert(i); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index bfccc0affd..77f3c64c65 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -581,7 +581,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { const NodeDef* new_const = node_map.GetNode(optimized_const_name); ASSERT_NE(new_const, nullptr); EXPECT_EQ("^x", new_const->input(0)); - EXPECT_EQ(std::string("\0\0\0@", 4), + EXPECT_EQ(string("\0\0\0@", 4), new_const->attr().at("value").tensor().tensor_content()); const NodeDef* new_mul = node_map.GetNode(optimized_mul_name); @@ -625,7 +625,7 @@ 
TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) { const NodeDef* new_const = node_map.GetNode(optimized_const_name); ASSERT_NE(new_const, nullptr); EXPECT_EQ("^x", new_const->input(0)); - EXPECT_EQ(std::string("\0\0\0@", 4), + EXPECT_EQ(string("\0\0\0@", 4), new_const->attr().at("value").tensor().tensor_content()); const NodeDef* new_mul = node_map.GetNode(optimized_mul_name); @@ -2353,9 +2353,14 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) { Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y); Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y); Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y); - auto add_all = ops::AddN(s.WithOpName("add_all"), - {add_x_y, add_negx_y, add_x_negy, add_negx_negy, - sub_x_y, sub_negx_y, sub_x_negy, sub_negx_negy}); + Output neg_x_with_dep = ops::Neg( + s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x); + Output add_negx_with_dep_y = + ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y); + auto add_all = + ops::AddN(s.WithOpName("add_all"), + {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y, + sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y}); GrapplerItem item; item.fetch = {"add_all"}; @@ -2370,7 +2375,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) { GraphDef output; ArithmeticOptimizer optimizer; EnableOnlyRemoveNegation(&optimizer); - OptimizeAndPrune(&optimizer, &item, &output); + OptimizeTwice(&optimizer, &item, &output); EXPECT_EQ(item.graph.node_size(), output.node_size()); int found = 0; @@ -2379,42 +2384,43 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) { if (node.name() == "Add_negx_y") { ++found; EXPECT_EQ("Sub", node.op()); - EXPECT_EQ(3, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("y", node.input(0)); EXPECT_EQ("x", node.input(1)); - EXPECT_EQ("^Neg_x", node.input(2)); } else if (node.name() == "Add_x_negy") { ++found; EXPECT_EQ("Sub", node.op()); - 
EXPECT_EQ(3, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("y", node.input(1)); - EXPECT_EQ("^Neg_y", node.input(2)); } else if (node.name() == "Add_negx_negy") { ++found; EXPECT_EQ("Sub", node.op()); - EXPECT_EQ(3, node.input_size()); - EXPECT_EQ("Neg_y", node.input(0)); - EXPECT_EQ("x", node.input(1)); - EXPECT_EQ("^Neg_x", node.input(2)); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("Neg_x", node.input(0)); + EXPECT_EQ("y", node.input(1)); } else if (node.name() == "Sub_x_negy") { ++found; EXPECT_EQ("Add", node.op()); - EXPECT_EQ(3, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("y", node.input(1)); - EXPECT_EQ("^Neg_y", node.input(2)); } else if (node.name() == "Sub_negx_negy") { ++found; EXPECT_EQ("Sub", node.op()); - EXPECT_EQ(4, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("y", node.input(0)); EXPECT_EQ("x", node.input(1)); - EXPECT_EQ("^Neg_y", node.input(2)); - EXPECT_EQ("^Neg_x", node.input(3)); + } else if (node.name() == "Add_negx_with_dep_y") { + ++found; + EXPECT_EQ("Sub", node.op()); + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("x", node.input(1)); + EXPECT_EQ("^Add_x_y", node.input(2)); } } - EXPECT_EQ(5, found); + EXPECT_EQ(6, found); auto tensors = EvaluateNodes(output, item.fetch, feed); EXPECT_EQ(1, tensors.size()); @@ -2468,6 +2474,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) { auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2}); auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2}); auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2}); + auto z = ops::Const(s.WithOpName("z"), {42.0f}, {}); + auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3}); + auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3}); Output out2 = ops::Pow(s.WithOpName("out2"), x, y2); Output out1 = ops::Pow(s.WithOpName("out1"), x, y1); 
Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5); @@ -2475,21 +2484,24 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) { Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5); Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1); Output out = ops::Pow(s.WithOpName("out"), x, y); + Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones); + Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros); GrapplerItem item; - item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"}; + item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", + "out_1", "out", "out_bcast1", "out_bcast2"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); - EXPECT_EQ(7, tensors_expected.size()); + EXPECT_EQ(9, tensors_expected.size()); GraphDef got; ArithmeticOptimizer optimizer; EnableOnlyConvertPow(&optimizer); OptimizeAndPrune(&optimizer, &item, &got); auto tensors = EvaluateNodes(got, item.fetch); - EXPECT_EQ(7, tensors.size()); + EXPECT_EQ(9, tensors.size()); - for (int i = 0; i < 7; ++i) { + for (int i = 0; i < tensors.size(); ++i) { EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements()); test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6); } @@ -2503,6 +2515,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) { AddNode("y_.5", "Const", {}, {}, &want); AddNode("y_1", "Const", {}, {}, &want); AddNode("y", "Const", {}, {}, &want); + AddNode("z", "Const", {}, {}, &want); + AddNode("ones", "Const", {}, {}, &want); + AddNode("zeros", "Const", {}, {}, &want); AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want); AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want); AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want); @@ -2511,6 +2526,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) { AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want); AddNode("out_1", "Reciprocal", {"x", 
AsControlDependency("y_1")}, {}, &want); AddNode("out", "Pow", {"x", "y"}, {}, &want); + AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want); + AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want); CompareGraphs(want, got); } @@ -3248,6 +3265,48 @@ TEST_F(ArithmeticOptimizerTest, VerifyGraphsMatch(item.graph, output, __LINE__); } +TEST_F(ArithmeticOptimizerTest, + OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output neg = ops::Neg(s.WithOpName("neg"), x); + Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0}); + Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max); + + GrapplerItem item; + item.fetch = {"final_out"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); + EXPECT_EQ(item.graph.node_size(), output.node_size()); + // Check if the inputs are switched + int required_node_count = 0; + for (int i = 0; i < output.node_size(); ++i) { + const NodeDef& node = output.node(i); + if (node.name() == "neg") { + EXPECT_EQ("Neg", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("reduce_max", node.input(0)); + ++required_node_count; + } else if (node.name() == "reduce_max") { + EXPECT_EQ("Min", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + ++required_node_count; + } + } + EXPECT_EQ(2, required_node_count); +} + TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); diff --git 
a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 99737a71eb..ca5d3a6dfd 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -32,8 +32,8 @@ limitations under the License. #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/evaluation_utils.h" -#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -437,25 +437,6 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { } namespace { -bool ShapesEqual(const TensorShapeProto& shape1, - const TensorShapeProto& shape2) { - if (shape1.unknown_rank() || shape2.unknown_rank()) { - return false; - } - if (shape1.dim_size() != shape2.dim_size()) { - return false; - } - for (int i = 0; i < shape1.dim_size(); ++i) { - if (shape1.dim(i).size() != shape2.dim(i).size()) { - return false; - } - if (shape1.dim(i).size() == -1 || shape2.dim(i).size() == -1) { - return false; - } - } - return true; -} - bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties, BCast::Vec* shape, int64* min_id) { if (shape_node.op() == "Shape") { @@ -2125,7 +2106,8 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) { Tensor axis_t(DT_INT32, TensorShape({})); NodeDef* axis_node = optimized_graph->add_node(); axis_node->set_name(OptimizedNodeName(*node, "_const_axis")); - const int axis = node->attr().at("axis").i(); + const int axis = + node->attr().count("axis") == 0 ? 
0 : node->attr().at("axis").i(); if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() || !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node) .ok()) { @@ -2348,7 +2330,8 @@ Status ConstantFolding::SimplifyArithmeticOperations( properties.GetInputProperties(node->name())[1].shape(); const bool x_is_zero = IsZeros(*x); const bool x_is_one = x_is_zero ? false : IsOnes(*x); - const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape); + const bool y_matches_output_shape = + ShapesSymbolicallyEqual(output_shape, y_shape); if (y_matches_output_shape && ((is_mul && x_is_one) || (is_add && x_is_zero))) { // 1 * y = y or 0 + y = y. @@ -2378,7 +2361,8 @@ Status ConstantFolding::SimplifyArithmeticOperations( properties.GetInputProperties(node->name())[0].shape(); const bool y_is_zero = IsZeros(*y); const bool y_is_one = y_is_zero ? false : IsOnes(*y); - const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape); + const bool x_matches_output_shape = + ShapesSymbolicallyEqual(output_shape, x_shape); if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) || ((is_add || is_sub) && y_is_zero))) { // x * 1 = x or x / 1 = x or x +/- 0 = x diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 2a19b3f95a..b09360a2c2 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -3015,37 +3015,48 @@ TEST_F(ConstantFoldingTest, TrivialPack) { auto stack = ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x}, ops::Stack::Axis(1)); + auto stack_no_axis = ops::Stack(scope.WithOpName("stack_no_axis"), {x}); GrapplerItem item; TF_CHECK_OK(scope.ToGraphDef(&item.graph)); - item.fetch.push_back("stack"); + item.fetch = {"stack", "stack_no_axis"}; ConstantFolding optimizer(nullptr /* cpu_device */); GraphDef output; Status status = optimizer.Optimize(nullptr, 
item, &output); TF_EXPECT_OK(status); - EXPECT_EQ(5, output.node_size()); + EXPECT_EQ(7, output.node_size()); + int found = 0; for (const auto& node : output.node()) { if (node.name() == "stack") { - EXPECT_EQ("stack", node.name()); EXPECT_EQ("ExpandDims", node.op()); EXPECT_EQ(3, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1)); EXPECT_EQ("^y", node.input(2)); + ++found; + } else if (node.name() == "stack_no_axis") { + EXPECT_EQ("ExpandDims", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("ConstantFolding/stack_no_axis_const_axis", node.input(1)); + ++found; } else if (node.name() == "ConstantFolding/stack_const_axis") { EXPECT_EQ("Const", node.op()); EXPECT_EQ(1, node.input_size()); EXPECT_EQ("^x", node.input(0)); + ++found; } } + EXPECT_EQ(found, 3); - std::vector<string> fetch = {"stack"}; + std::vector<string> fetch = {"stack", "stack_no_axis"}; auto tensors_expected = EvaluateNodes(item.graph, fetch); auto tensors = EvaluateNodes(output, fetch); - EXPECT_EQ(1, tensors_expected.size()); - EXPECT_EQ(1, tensors.size()); + EXPECT_EQ(2, tensors_expected.size()); + EXPECT_EQ(2, tensors.size()); EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape()); + EXPECT_EQ(tensors_expected[1].shape(), tensors[1].shape()); } // The test does not evaluate the optimized and original graphs to check if their diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 530c957068..d198a2a591 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -19,10 +19,10 @@ cc_library( "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/kernels:cast_op", "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", 
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -32,6 +32,7 @@ tf_cc_test( visibility = ["//visibility:public"], deps = [ ":filter_fusion", + ":graph_test_utils", ":graph_utils", "//tensorflow/core:framework", "//tensorflow/core:test", @@ -50,14 +51,15 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":graph_utils", + ":function_utils", "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", - "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:functional_ops", + "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core:lib_internal", ] + tf_protos_all(), @@ -68,6 +70,7 @@ tf_cc_test( srcs = ["fusion_utils_test.cc"], visibility = ["//visibility:public"], deps = [ + ":function_utils", ":fusion_utils", ":graph_utils", "//tensorflow/core:framework", @@ -79,6 +82,40 @@ tf_cc_test( ) cc_library( + name = "function_utils", + srcs = ["function_utils.cc"], + hdrs = [ + "function_utils.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_utils", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core/grappler:utils", + ] + tf_protos_all(), +) + +tf_cc_test( + name = "function_utils_test", + srcs = ["function_utils_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":function_utils", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + 
"//tensorflow/core/kernels:cast_op", + "//tensorflow/tools/graph_transforms:transform_utils", + ], +) + +cc_library( name = "graph_utils", srcs = ["graph_utils.cc"], hdrs = [ @@ -107,11 +144,66 @@ tf_cc_test( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "//tensorflow/core/kernels:cast_op", ], ) cc_library( + name = "graph_test_utils", + testonly = 1, + srcs = ["graph_test_utils.cc"], + hdrs = [ + "graph_test_utils.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core:testlib", + ] + tf_protos_all(), +) + +cc_library( + name = "hoist_random_uniform", + srcs = ["hoist_random_uniform.cc"], + hdrs = [ + "hoist_random_uniform.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":function_utils", + ":graph_utils", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core:lib_internal", + ] + tf_protos_all(), +) + +tf_cc_test( + name = "hoist_random_uniform_test", + srcs = ["hoist_random_uniform_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":graph_test_utils", + ":graph_utils", + ":hoist_random_uniform", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + ] + tf_protos_all(), +) + +cc_library( name = "latency_all_edges", 
srcs = ["latency_all_edges.cc"], hdrs = [ @@ -139,7 +231,9 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":function_utils", ":graph_utils", + ":vectorization_utils", "//tensorflow/core:lib", "//tensorflow/core/grappler:mutable_graph_view", "//tensorflow/core/grappler:grappler_item", @@ -164,7 +258,6 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/kernels:cast_op", # Must be linked for the testlib functions to work. ], ) @@ -221,7 +314,7 @@ cc_library( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "//tensorflow/core:ptr_util", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -230,6 +323,7 @@ tf_cc_test( srcs = ["map_and_filter_fusion_test.cc"], visibility = ["//visibility:public"], deps = [ + ":graph_test_utils", ":graph_utils", ":map_and_filter_fusion", "//tensorflow/core:framework", @@ -256,10 +350,10 @@ cc_library( "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/kernels:cast_op", "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -268,6 +362,7 @@ tf_cc_test( srcs = ["map_fusion_test.cc"], visibility = ["//visibility:public"], deps = [ + ":graph_test_utils", ":graph_utils", ":map_fusion", "//tensorflow/core:framework", @@ -275,6 +370,44 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/kernels:control_flow_ops", + ], +) + +cc_library( + name = "map_parallelization", + srcs = ["map_parallelization.cc"], + hdrs = [ + 
"map_parallelization.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_utils", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/utils:topological_sort", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + ] + tf_protos_all(), +) + +tf_cc_test( + name = "map_parallelization_test", + srcs = ["map_parallelization_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":graph_test_utils", + ":graph_utils", + ":map_parallelization", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", ], ) @@ -351,10 +484,12 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":filter_fusion", + ":hoist_random_uniform", ":latency_all_edges", ":map_and_batch_fusion", ":map_and_filter_fusion", ":map_fusion", + ":map_parallelization", ":map_vectorization", ":noop_elimination", ":shuffle_and_repeat_fusion", @@ -375,3 +510,43 @@ tf_cc_test( "//tensorflow/core/grappler:grappler_item", ], ) + +cc_library( + name = "vectorization_utils", + srcs = ["vectorization_utils.cc"], + hdrs = [ + "vectorization_utils.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":function_utils", + ":graph_utils", + "@com_google_absl//absl/strings", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/optimizers/data/vectorization", + "//tensorflow/core/grappler/utils:functions", + ] + tf_protos_all(), +) + +tf_cc_test( + name = "vectorization_utils_test", + srcs = 
["vectorization_utils_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":function_utils", + ":vectorization_utils", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core/kernels:cast_op", + "//tensorflow/tools/graph_transforms:transform_utils", + ] + tf_protos_all(), +) diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc index c71aa6e804..1ad495bbad 100644 --- a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc @@ -43,19 +43,14 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node, fused_node.set_op("FilterDataset"); fused_node.add_input(first_filter_node.input(0)); - auto copy_attribute = [](const string& attribute_name, const NodeDef& from, - NodeDef* to) { - (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name); - }; - auto attr = first_filter_node.attr().at("predicate"); *attr.mutable_func()->mutable_name() = fused_function.signature().name(); (*fused_node.mutable_attr())["predicate"] = std::move(attr); - copy_attribute("Targuments", first_filter_node, &fused_node); + graph_utils::CopyAttribute("Targuments", first_filter_node, &fused_node); for (auto key : {"output_shapes", "output_types"}) - copy_attribute(key, second_filter_node, &fused_node); + graph_utils::CopyAttribute(key, second_filter_node, &fused_node); return fused_node; } @@ -120,8 +115,8 @@ Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item, // functions, or make sure that optimization passes run after filter // fusion. 
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_predicate)); - // TODO(prazek): we could also remove map functions from library if they - // are not used anymore. + // TODO(b/116285210): we could also remove map functions from library if + // they are not used anymore. nodes_to_delete.insert(first_filter_node->name()); nodes_to_delete.insert(second_filter_node->name()); } diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc index 12b1924efd..c8becc5cc0 100644 --- a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" - #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -28,14 +28,7 @@ namespace tensorflow { namespace grappler { namespace { -NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) { - return test::function::NDef( - name, "FilterDataset", {string(input_node_name)}, - {{"predicate", FunctionDefHelper::FunctionRef("IsZero")}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} +using graph_tests_utils::MakeFilterNode; TEST(FilterFusionTest, FuseTwoFilterIntoOne) { using test::function::NDef; diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc new file mode 100644 index 0000000000..e3f6d8e1ea --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc @@ -0,0 +1,175 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" + +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { +namespace grappler { +namespace function_utils { + +FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name, + const string& output, int position) + : node_name(node_name), node_output(output), position(position) { + full_str = strings::StrCat(node_name, ":", node_output, ":", position); +} + +FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) { + // Parses node_name:node_output:position string into its components. 
+ full_str = input; + StringPiece capture; + StringPiece remaining; + + // Parse "node_name" + if (strings::Scanner(input) + .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE) + .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE) + .GetResult(&remaining, &capture)) { + node_name = string(capture.data(), capture.size()); + } + + // Parse "node_output" if it exists + if (strings::Scanner(remaining) + .OneLiteral(":") + .RestartCapture() + .One(strings::Scanner::LETTER) + .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE) + .GetResult(&remaining, &capture)) { + node_output = string(capture.data(), capture.size()); + } + + // Parse "position" if it exists + if (strings::Scanner(remaining) + .OneLiteral(":") + .RestartCapture() + .Many(strings::Scanner::DIGIT) + .GetResult(nullptr, &capture)) { + CHECK(strings::safe_strto32(capture, &position)); + } +} + +// TODO(rachelim): Create a utility class similar to MutableGraphView for +// FunctionDefs, and use that to manipulate functions. It'll be more +// performant if we kept mappings of nodes->inputs/outputs, so that we don't +// have to search over all nodes each time. +// Note that we're not using GrapplerFunctionItem because it doesn't cover +// some of our desired uses (eg changing the outputs of a function), and the +// FunctionDef -> GraphDef conversion isn't really necessary in this case. 
+void ReplaceReferences(const string& from, const string& to, + FunctionDef* func) { + for (NodeDef& n : *func->mutable_node_def()) { + std::replace(n.mutable_input()->begin(), n.mutable_input()->end(), from, + to); + } + + for (auto& p : *func->mutable_ret()) { + if (p.second == from) { + p.second = to; + } + } +} + +void AddFunctionOutputWithUniqueName(StringPiece prefix, + StringPiece output_tensor_name, + FunctionDef* function, DataType dt) { + string name = string(prefix); + int id = function->signature().output_arg_size(); + while (ContainsFunctionOutputWithName(name, *function)) { + name = strings::StrCat(prefix, "/_", id); + ++id; + } + auto* output = function->mutable_signature()->mutable_output_arg()->Add(); + output->set_name(name); + output->set_type(dt); + + (*function->mutable_ret())[name] = string(output_tensor_name); +} + +NodeDef* AddNode(StringPiece name, StringPiece op, + const std::vector<string>& inputs, + const std::vector<std::pair<string, AttrValue>>& attributes, + FunctionDef* fd) { + NodeDef* node = fd->add_node_def(); + if (!name.empty()) { + node->set_name(string(name)); + } else { + SetUniqueFunctionNodeName(op, fd, node); + } + node->set_op(string(op)); + for (const string& input : inputs) { + node->add_input(input); + } + for (auto attr : attributes) { + (*node->mutable_attr())[attr.first] = attr.second; + } + return node; +} + +bool ContainsFunctionNodeWithName(StringPiece name, + const FunctionDef& function) { + return FindFunctionNodeWithName(name, function) != -1; +} + +bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) { + return FindFunctionNodeWithOp(op, function) != -1; +} + +bool ContainsFunctionOutputWithName(StringPiece name, + const FunctionDef& function) { + return FindFunctionOutputWithName(name, function) != -1; +} + +int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) { + return graph_utils::GetFirstElementIndexWithPredicate( + [&name](const OpDef_ArgDef& arg) { return 
arg.name() == name; }, + function.signature().input_arg()); +} + +int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) { + return graph_utils::GetFirstElementIndexWithPredicate( + [&name](const OpDef_ArgDef& arg) { return arg.name() == name; }, + function.signature().output_arg()); +} + +int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) { + return graph_utils::GetFirstElementIndexWithPredicate( + [&name](const NodeDef& node) { return node.name() == name; }, + function.node_def()); +} + +int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) { + return graph_utils::GetFirstElementIndexWithPredicate( + [&op](const NodeDef& node) { return node.op() == op; }, + function.node_def()); +} + +void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, + NodeDef* node) { + string name = string(prefix); + int id = function->node_def_size(); + while (ContainsFunctionNodeWithName(name, *function)) { + name = strings::StrCat(prefix, "/_", id); + ++id; + } + node->set_name(std::move(name)); +} + +} // end namespace function_utils +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.h b/tensorflow/core/grappler/optimizers/data/function_utils.h new file mode 100644 index 0000000000..d4ce824652 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/function_utils.h @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace grappler { +namespace function_utils { +// This namespace contains utility functions for querying and modifying +// FunctionDefs. + +// Describes a FunctionDef input tensor. In FunctionDefs, input tensor strings +// have the format node_name:node_output:position (if they derive from nodes), +// or input_name (if they derive from an argument). +struct FunctionDefTensorDesc { + FunctionDefTensorDesc() = default; + + FunctionDefTensorDesc(const string& node_name, const string& output, + int position); + + // Parses node_name:node_output:position string into its components. + explicit FunctionDefTensorDesc(const string& input); + + // TODO(rachelim): Add provisions to deal with special formats, like how + // GrapplerFunctionItem expands node output range if position is not defined + string full_str; + string node_name; + string node_output; + int position = -1; +}; + +// Replaces all references to `from` tensor in func's nodes' inputs and retvals +// to `to` tensor. This is similar to `MutableGraphView::ReplaceInputs`. 
+void ReplaceReferences(const string& from, const string& to, FunctionDef* func); + +// Adds a function output to the function def, ensuring that the output key +// is unique, and maps to output_tensor_name in the ret dict. +void AddFunctionOutputWithUniqueName(StringPiece prefix, + StringPiece output_tensor_name, + FunctionDef* function, DataType dt); + +// Adds a node to a FunctionDef. +NodeDef* AddNode(StringPiece name, StringPiece op, + const std::vector<string>& inputs, + const std::vector<std::pair<string, AttrValue>>& attributes, + FunctionDef* fd); + +// Checks whether the function contains a node with the given name. +bool ContainsFunctionNodeWithName(StringPiece name, + const FunctionDef& function); + +// Checks whether the function contains a node with the given op. +bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function); + +// Checks whether the function contains an output with the given name. +bool ContainsFunctionOutputWithName(StringPiece name, + const FunctionDef& function); + +// Returns the index of the function input with the given name or -1 if the +// function input does not exist. +int FindFunctionInputWithName(StringPiece name, const FunctionDef& function); + +// Returns the index of the function output with the given name or -1 if the +// function output does not exist. +int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function); + +// Returns the index of the function node with the given name or -1 if the +// function node does not exist. +int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function); + +// Returns the index of the function node with the given op or -1 if the +// function node does not exist. +int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function); + +// Sets the function node name using the `prefix` as a prefix while guaranteeing +// the name is unique across the function's nodes. 
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, + NodeDef* node); + +} // end namespace function_utils +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_ diff --git a/tensorflow/core/grappler/optimizers/data/function_utils_test.cc b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc new file mode 100644 index 0000000000..3739e20eb1 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc @@ -0,0 +1,164 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" + +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/tools/graph_transforms/transform_utils.h" + +namespace tensorflow { +namespace grappler { +namespace function_utils { +namespace { + +TEST(FunctionDefTensorDesc, Parsing) { + FunctionDefTensorDesc f("Cast:y:0"); + EXPECT_EQ(f.full_str, "Cast:y:0"); + EXPECT_EQ(f.node_name, "Cast"); + EXPECT_EQ(f.node_output, "y"); + EXPECT_EQ(f.position, 0); + + FunctionDefTensorDesc f2("Arg0"); + EXPECT_EQ(f2.full_str, "Arg0"); + EXPECT_EQ(f2.node_name, "Arg0"); + EXPECT_EQ(f2.node_output, ""); + EXPECT_EQ(f2.position, -1); +} + +TEST(ReplaceReferencesTest, ReplaceReferencesTest) { + FunctionDef outer = FunctionDefHelper::Create( + "outer", {"arg0: int32"}, {"out: int32", "out2: int64"}, {}, {}, + {{"out", "MapDefun:output:0"}, {"out2", "Cast:y:0"}}); + NodeDef* derive_node = + AddNode("X", "Some_Op", {"MapDefun:output:0"}, {}, &outer); + // Check that both the input to "X" and retval of "outer" are replaced. 
+ ReplaceReferences("MapDefun:output:0", "arg0", &outer); + EXPECT_EQ(outer.ret().at("out"), "arg0"); + EXPECT_EQ(derive_node->input(0), "arg0"); +} + +TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) { + FunctionDef function = test::function::XTimesTwo(); + AddFunctionOutputWithUniqueName("y", "two", &function, DT_INT64); + EXPECT_TRUE(ContainsFunctionOutputWithName("y/_1", function)); + EXPECT_EQ(function.ret().at("y/_1"), "two"); +} + +TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_FALSE(ContainsFunctionNodeWithName( + "weird_name_that_should_not_be_there", function)); + EXPECT_TRUE(ContainsFunctionNodeWithName("two", function)); +} + +TEST(FunctionUtilsTest, ContainsFunctionNodeWithOp) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there", + function)); + EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function)); +} + +TEST(FunctionUtilsTest, ContainsFunctionOutputWithName) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_TRUE(ContainsFunctionOutputWithName("y", function)); + EXPECT_FALSE(ContainsFunctionOutputWithName("Add:z:0", function)); +} + +TEST(FunctionUtilsTest, FindFunctionNodeWithName) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_EQ( + FindFunctionNodeWithName("weird_name_that_should_not_be_there", function), + -1); + EXPECT_NE(FindFunctionNodeWithName("two", function), -1); +} + +TEST(FunctionUtilsTest, FindFunctionNodeWithOp) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_EQ( + FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function), + -1); + EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1); +} + +TEST(FunctionUtilsTest, FindFunctionInputWithName) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_EQ(FindFunctionInputWithName("x", function), 0); + EXPECT_EQ(FindFunctionInputWithName("not_a_name", 
function), -1); +} + +TEST(FunctionUtilsTest, FindFunctionOutputWithName) { + FunctionDef function = test::function::XTimesTwo(); + EXPECT_EQ(FindFunctionOutputWithName("y", function), 0); + EXPECT_EQ(FindFunctionOutputWithName("Add:z:0", function), -1); +} + +TEST(FunctionUtilsTest, SetUniqueFunctionNodeName) { + FunctionDef function = test::function::XTimesTwo(); + NodeDef node; + SetUniqueFunctionNodeName("abc", &function, &node); + for (const NodeDef& function_node : function.node_def()) { + EXPECT_NE(node.name(), function_node.name()); + } + auto* new_node = function.add_node_def(); + *new_node = node; + + NodeDef other; + SetUniqueFunctionNodeName("abc", &function, &other); + EXPECT_NE(other.name(), new_node->name()); +} + +TEST(FunctionUtilsTest, AddNodeToFunctionDef) { + FunctionDef func; + const char* op_name = "xxx"; + AddNode(op_name, op_name, {}, {}, &func); + + const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func)); + EXPECT_EQ(node1.op(), op_name); + EXPECT_EQ(node1.input_size(), 0); + EXPECT_EQ(node1.attr_size(), 0); + + const std::vector<string> inputs({"input1", "input2"}); + AddNode("", op_name, inputs, {}, &func); + const NodeDef& node2 = + func.node_def(FindFunctionNodeWithName("xxx/_2", func)); + EXPECT_EQ(node2.op(), op_name); + EXPECT_EQ(node2.attr_size(), 0); + EXPECT_EQ(node2.input_size(), inputs.size()); + for (size_t i = 0; i < inputs.size(); ++i) { + EXPECT_EQ(node2.input(i), inputs[i]); + } + + AttrValue a1, a2; + a1.set_type(DT_INT32); + a2.set_type(DT_INT64); + const std::vector<std::pair<string, AttrValue>> attrs( + {{"attr1", a1}, {"attr2", a2}}); + AddNode("", op_name, {}, attrs, &func); + const NodeDef& node3 = + func.node_def(FindFunctionNodeWithName("xxx/_3", func)); + EXPECT_EQ(node3.op(), op_name); + EXPECT_EQ(node3.input_size(), 0); + EXPECT_EQ(node3.attr_size(), attrs.size()); + for (size_t i = 0; i < attrs.size(); ++i) { + EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type()); + } +} 
+ +} // namespace +} // namespace function_utils +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc index 01a78c04b0..b3bfee138f 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/gtl/flatmap.h" @@ -407,7 +408,7 @@ void LazyConjunctionNodes(const FunctionDef& first_function, auto* if_node = fused_function->add_node_def(); // This is guaranteed to succeed. TF_CHECK_OK(if_builder.Finalize(if_node)); - graph_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node); + function_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node); GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0"; } diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc index d5c6466080..e667affeea 100644 --- a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -110,9 +111,9 @@ TEST(FusionUtilsTest, FuseFunctionWithPredicate) { CheckUniqueNames(*fused_function); ASSERT_TRUE( - graph_utils::ContainsFunctionNodeWithOp("Equal", *fused_function)); + function_utils::ContainsFunctionNodeWithOp("Equal", *fused_function)); const auto &equal_node = fused_function->node_def( - graph_utils::FindFunctionNodeWithOp("Equal", *fused_function)); + function_utils::FindFunctionNodeWithOp("Equal", *fused_function)); EXPECT_EQ(xtimes_two->signature().output_arg(0).name(), fused_function->signature().output_arg(0).name()); diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc new file mode 100644 index 0000000000..b2eec7220e --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" + +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { +namespace grappler { +namespace graph_tests_utils { + +NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name) { + return test::function::NDef( + name, "MapDataset", {string(input_node_name)}, + {{"f", FunctionDefHelper::FunctionRef(string(function_name))}, + {"Targuments", {}}, + {"output_shapes", gtl::ArraySlice<TensorShape>{}}, + {"output_types", gtl::ArraySlice<DataType>{}}}); +} + +NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name) { + return test::function::NDef( + name, "FilterDataset", {string(input_node_name)}, + {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))}, + {"Targuments", {}}, + {"output_shapes", gtl::ArraySlice<TensorShape>{}}, + {"output_types", gtl::ArraySlice<TensorShape>{}}}); +} + +} // end namespace graph_tests_utils +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h new file mode 100644 index 0000000000..ca0fde997d --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace grappler { +namespace graph_tests_utils { + +NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name = "XTimesTwo"); + +NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name = "IsZero"); + +} // end namespace graph_tests_utils +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_ diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index 5a7fe19265..3eaaf8fbef 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -88,6 +88,16 @@ NodeDef* AddScalarConstNodeHelper( } // namespace +NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) { + NodeDef node; + node.set_op("Placeholder"); + SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node); + (*node.mutable_attr())["dtype"].set_type(dtype); + TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape(); + shape->set_unknown_rank(false); + return graph->AddNode(std::move(node)); +} + NodeDef* AddNode(StringPiece name, 
StringPiece op, const std::vector<string>& inputs, const std::vector<std::pair<string, AttrValue>>& attributes, @@ -108,26 +118,6 @@ NodeDef* AddNode(StringPiece name, StringPiece op, return graph->AddNode(std::move(node)); } -NodeDef* AddNode(StringPiece name, StringPiece op, - const std::vector<string>& inputs, - const std::vector<std::pair<string, AttrValue>>& attributes, - FunctionDef* fd) { - NodeDef* node = fd->add_node_def(); - if (!name.empty()) { - node->set_name(string(name)); - } else { - SetUniqueFunctionNodeName(op, fd, node); - } - node->set_op(string(op)); - for (const string& input : inputs) { - node->add_input(input); - } - for (auto attr : attributes) { - (*node->mutable_attr())[attr.first] = attr.second; - } - return node; -} - template <> NodeDef* AddScalarConstNode(bool v, MutableGraphView* graph) { return AddScalarConstNodeHelper( @@ -196,6 +186,11 @@ bool Compare(const GraphDef& g1, const GraphDef& g2) { return true; } +bool ContainsGraphFunctionWithName(StringPiece name, + const FunctionDefLibrary& library) { + return FindGraphFunctionWithName(name, library) != -1; +} + bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) { return FindGraphNodeWithName(name, graph) != -1; } @@ -204,31 +199,24 @@ bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) { return FindGraphNodeWithOp(op, graph) != -1; } -bool ContainsGraphFunctionWithName(StringPiece name, - const FunctionDefLibrary& library) { - return FindGraphFunctionWithName(name, library) != -1; -} - -bool ContainsFunctionNodeWithName(StringPiece name, - const FunctionDef& function) { - return FindFunctionNodeWithName(name, function) != -1; -} - -bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) { - return FindFunctionNodeWithOp(op, function) != -1; +int FindGraphFunctionWithName(StringPiece name, + const FunctionDefLibrary& library) { + return GetFirstElementIndexWithPredicate( + [&name](const FunctionDef& function) { + return 
function.signature().name() == name; + }, + library.function()); } int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return GetFirstElementIndexWithPredicate( [&name](const NodeDef& node) { return node.name() == name; }, graph.node()); - return indices.empty() ? -1 : indices.front(); } int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return GetFirstElementIndexWithPredicate( [&op](const NodeDef& node) { return node.op() == op; }, graph.node()); - return indices.empty() ? -1 : indices.front(); } std::vector<int> FindAllGraphNodesWithOp(const string& op, @@ -237,31 +225,6 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op, [&op](const NodeDef& node) { return node.op() == op; }, graph.node()); } -int FindGraphFunctionWithName(StringPiece name, - const FunctionDefLibrary& library) { - std::vector<int> indices = GetElementIndicesWithPredicate( - [&name](const FunctionDef& function) { - return function.signature().name() == name; - }, - library.function()); - return indices.empty() ? -1 : indices.front(); -} - -int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( - [&name](const NodeDef& node) { return node.name() == name; }, - function.node_def()); - return indices.empty() ? -1 : indices.front(); -} - -int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( - [&op](const NodeDef& node) { return node.op() == op; }, - function.node_def()); - - return indices.empty() ? 
-1 : indices.front(); -} - NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) { if (node.input_size() == 0) return nullptr; GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0); @@ -273,7 +236,7 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, string name = string(prefix); int id = graph->node_size(); while (ContainsGraphNodeWithName(name, *graph)) { - if (name.rfind("_generated") != std::string::npos && + if (name.rfind("_generated") != string::npos && (name.rfind("_generated") == (name.size() - strlen("_generated")))) { name.insert(name.rfind("_generated"), strings::StrCat("/_", id)); } else { @@ -284,17 +247,6 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, node->set_name(std::move(name)); } -void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, - NodeDef* node) { - string name = string(prefix); - int id = function->node_def_size(); - while (ContainsFunctionNodeWithName(name, *function)) { - name = strings::StrCat(prefix, "/_", id); - ++id; - } - node->set_name(std::move(name)); -} - void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library, FunctionDef* function) { string name = string(prefix); @@ -306,6 +258,20 @@ void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library, function->mutable_signature()->set_name(std::move(name)); } +void CopyAttribute(const string& attribute_name, const NodeDef& from, + NodeDef* to_node) { + (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name); +} + +void ConcatAttributeList(const string& attribute_name, const NodeDef& first, + const NodeDef& second, NodeDef* to_node) { + CopyAttribute(attribute_name, first, to_node); + (*to_node->mutable_attr()) + .at(attribute_name) + .mutable_list() + ->MergeFrom(second.attr().at(attribute_name).list()); +} + } // end namespace graph_utils } // end namespace grappler } // end namespace tensorflow diff --git 
a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index 6f431c232d..5dd7819100 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -31,17 +31,29 @@ namespace tensorflow { namespace grappler { namespace graph_utils { +// Returns the index of the first element in collection that fulfills predicate. +// If no such element exists, returns -1. +template <typename Predicate, typename Collection> +int GetFirstElementIndexWithPredicate(const Predicate& predicate, + const Collection& collection) { + unsigned idx = 0; + for (auto&& element : collection) { + if (predicate(element)) { + return idx; + } + idx++; + } + return -1; +} + // Adds a node to the graph. NodeDef* AddNode(StringPiece name, StringPiece op, const std::vector<string>& inputs, const std::vector<std::pair<string, AttrValue>>& attributes, MutableGraphView* graph); -// Adds a node to a FunctionDef. -NodeDef* AddNode(StringPiece name, StringPiece op, - const std::vector<string>& inputs, - const std::vector<std::pair<string, AttrValue>>& attributes, - FunctionDef* fd); +// Adds Placeholder node for given type. +NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph); // Adds a Const node with the given value to the graph. template <typename T> @@ -76,13 +88,6 @@ bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph); bool ContainsGraphFunctionWithName(StringPiece name, const FunctionDefLibrary& library); -// Checks whether the function contains a node with the given name. -bool ContainsFunctionNodeWithName(StringPiece name, - const FunctionDef& function); - -// Checks whether the function contains a node with the given op. -bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function); - // Checks whether the graph contains a node with the given op. 
bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph); @@ -95,14 +100,6 @@ int FindGraphNodeWithName(StringPiece name, const GraphDef& graph); int FindGraphFunctionWithName(StringPiece name, const FunctionDefLibrary& library); -// Returns the index of the function node with the given name or -1 if the -// function node does not exist. -int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function); - -// Returns the index of the function node with the given op or -1 if the -// function node does not exist. -int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function); - // Returns the index of the first node with the given op or -1 if no such node // exists. int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph); @@ -119,16 +116,21 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op, // is unique across the graph. void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node); -// Sets the function node name using the `prefix` as a prefix while guaranteeing -// the name is unique across the functions nodes. -void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, - NodeDef* node); - // Sets the node name using the `prefix` name as a prefix while guaranteeing the // name is unique across the graph. void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library, FunctionDef* function); +// Copies attribute having name `attribute_name` from node `from` to node +// `to_node`. +void CopyAttribute(const string& attribute_name, const NodeDef& from, + NodeDef* to_node); + +// Concatenates list attribute having name `attribute_name` from `first` and +// `second` node, setting it to `to_node`. 
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first, + const NodeDef& second, NodeDef* to_node); + } // end namespace graph_utils } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index c19ac7b880..db986542b2 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -24,6 +24,18 @@ namespace grappler { namespace graph_utils { namespace { +TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) { + std::vector<int> vec({1, 2, 3, 4, 5, 6}); + auto result = GetFirstElementIndexWithPredicate( + [](int elem) { return elem % 3 == 0; }, vec); + + EXPECT_EQ(result, 2); + + result = GetFirstElementIndexWithPredicate( + [](int elem) { return elem % 7 == 0; }, vec); + EXPECT_EQ(result, -1); +} + TEST(GraphUtilsTest, AddScalarConstNodeBool) { GraphDef graph_def; MutableGraphView graph(&graph_def); @@ -112,20 +124,6 @@ TEST(GraphUtilsTest, ContainsGraphFunctionWithName) { ContainsGraphFunctionWithName(new_function->signature().name(), library)); } -TEST(GraphUtilsTest, ContainsFunctionNodeWithName) { - FunctionDef function = test::function::XTimesTwo(); - EXPECT_FALSE(ContainsFunctionNodeWithName( - "weird_name_that_should_not_be_there", function)); - EXPECT_TRUE(ContainsFunctionNodeWithName("two", function)); -} - -TEST(GraphUtilsTest, ContainsFunctionNodeWithOp) { - FunctionDef function = test::function::XTimesTwo(); - EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there", - function)); - EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function)); -} - TEST(GraphUtilsTest, ContainsNodeWithOp) { GraphDef graph_def; MutableGraphView graph(&graph_def); @@ -150,22 +148,6 @@ TEST(GraphUtilsTest, FindGraphNodeWithName) { EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1); } -TEST(GraphUtilsTest, 
FindFunctionNodeWithName) { - FunctionDef function = test::function::XTimesTwo(); - EXPECT_EQ( - FindFunctionNodeWithName("weird_name_that_should_not_be_there", function), - -1); - EXPECT_NE(FindFunctionNodeWithName("two", function), -1); -} - -TEST(GraphUtilsTest, FindFunctionNodeWithOp) { - FunctionDef function = test::function::XTimesTwo(); - EXPECT_EQ( - FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function), - -1); - EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1); -} - TEST(GraphUtilsTest, FindGraphFunctionWithName) { FunctionDefLibrary library; EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1); @@ -225,21 +207,6 @@ TEST(GraphUtilsTest, SetUniqueGraphNodeName) { EXPECT_NE(node2->name(), node3->name()); } -TEST(GraphUtilsTest, SetUniqueFunctionNodeName) { - FunctionDef function = test::function::XTimesTwo(); - NodeDef node; - SetUniqueFunctionNodeName("abc", &function, &node); - for (const NodeDef& function_node : function.node_def()) { - EXPECT_NE(node.name(), function_node.name()); - } - auto* new_node = function.add_node_def(); - *new_node = node; - - NodeDef other; - SetUniqueFunctionNodeName("abc", &function, &other); - EXPECT_NE(other.name(), new_node->name()); -} - TEST(GraphUtilsTest, SetUniqueGraphFunctionName) { FunctionDefLibrary library; FunctionDef* new_function = library.add_function(); @@ -251,43 +218,6 @@ TEST(GraphUtilsTest, SetUniqueGraphFunctionName) { other_function->signature().name()); } -TEST(GraphUtilsTest, AddNodeToFunctionDef) { - FunctionDef func; - const char* op_name = "xxx"; - AddNode(op_name, op_name, {}, {}, &func); - - const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func)); - EXPECT_EQ(node1.op(), op_name); - EXPECT_EQ(node1.input_size(), 0); - EXPECT_EQ(node1.attr_size(), 0); - - const std::vector<string> inputs({"input1", "input2"}); - AddNode("", op_name, inputs, {}, &func); - const NodeDef& node2 = - func.node_def(FindFunctionNodeWithName("xxx/_2", func)); - 
EXPECT_EQ(node2.op(), op_name); - EXPECT_EQ(node2.attr_size(), 0); - EXPECT_EQ(node2.input_size(), inputs.size()); - for (size_t i = 0; i < inputs.size(); ++i) { - EXPECT_EQ(node2.input(i), inputs[i]); - } - - AttrValue a1, a2; - a1.set_type(DT_INT32); - a2.set_type(DT_INT64); - const std::vector<std::pair<string, AttrValue>> attrs( - {{"attr1", a1}, {"attr2", a2}}); - AddNode("", op_name, {}, attrs, &func); - const NodeDef& node3 = - func.node_def(FindFunctionNodeWithName("xxx/_3", func)); - EXPECT_EQ(node3.op(), op_name); - EXPECT_EQ(node3.input_size(), 0); - EXPECT_EQ(node3.attr_size(), attrs.size()); - for (size_t i = 0; i < attrs.size(); ++i) { - EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type()); - } -} - TEST(GraphUtilsTest, GetInputNode) { GraphDef graph_def; MutableGraphView graph(&graph_def); diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc new file mode 100644 index 0000000000..ce0b2db039 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc @@ -0,0 +1,289 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace grappler { +namespace { + +NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node, + const FunctionDef& stateless_function, + MutableGraphView* graph) { + NodeDef stateless_map; + graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(), + &stateless_map); + + stateless_map.set_op("MapDataset"); + stateless_map.add_input(zip_node.name()); + // Add placeholders. 
+ for (int i = 1; i < map_node.input_size(); i++) + stateless_map.add_input(map_node.input(i)); + + auto attr = map_node.attr().at("f"); + *attr.mutable_func()->mutable_name() = stateless_function.signature().name(); + *attr.mutable_func()->mutable_attr() = stateless_function.attr(); + (*stateless_map.mutable_attr())["f"] = std::move(attr); + + graph_utils::CopyAttribute("Targuments", map_node, &stateless_map); + for (auto key : {"output_shapes", "output_types"}) + graph_utils::CopyAttribute(key, map_node, &stateless_map); + + if (const auto* attr = + gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism")) + (*stateless_map.mutable_attr())["use_inter_op_parallelism"] = *attr; + + return stateless_map; +} + +NodeDef MakeRandomDataset(const NodeDef& random_uniform_node, + MutableGraphView* graph) { + NodeDef random_dataset; + random_dataset.set_op("RandomDataset"); + graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(), + &random_dataset); + + const auto* seed = graph_utils::AddScalarConstNode<int64>( + random_uniform_node.attr().at("seed").i(), graph); + const auto* seed2 = graph_utils::AddScalarConstNode<int64>( + random_uniform_node.attr().at("seed2").i(), graph); + + random_dataset.add_input(seed->name()); + random_dataset.add_input(seed2->name()); + + (*random_dataset.mutable_attr())["output_shapes"].mutable_list()->add_shape(); + (*random_dataset.mutable_attr())["output_types"].mutable_list()->add_type( + DT_INT64); + + return random_dataset; +} + +NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) { + NodeDef batch_dataset; + batch_dataset.set_op("BatchDatasetV2"); + graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(), + &batch_dataset); + const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph); + const auto* drop_reminder = graph_utils::AddScalarConstNode(false, graph); + batch_dataset.add_input(random_dataset.name()); + batch_dataset.add_input(batch_size->name()); + 
batch_dataset.add_input(drop_reminder->name()); + + (*batch_dataset.mutable_attr())["output_shapes"] + .mutable_list() + ->add_shape() + ->mutable_dim() + ->Add() + ->set_size(-1); + (*batch_dataset.mutable_attr())["output_types"].mutable_list()->add_type( + DT_INT64); + + return batch_dataset; +} + +NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node, + MutableGraphView* graph) { + NodeDef zip_node; + graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(), + &zip_node); + + zip_node.set_op("ZipDataset"); + zip_node.add_input(first_node.name()); + zip_node.add_input(second_node.name()); + + for (auto key : {"output_shapes", "output_types"}) + graph_utils::ConcatAttributeList(key, first_node, second_node, &zip_node); + + (*zip_node.mutable_attr())["N"].set_i(2); + + return zip_node; +} + +// We need to insert our argument before the placeholders, which are the last +// arguments. +OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) { + int new_argument_idx = signature->input_arg_size() - num_placeholders; + signature->add_input_arg(); + for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) { + signature->mutable_input_arg()->SwapElements(i - 1, i); + } + auto* seed_arg = signature->mutable_input_arg(new_argument_idx); + seed_arg->set_name(strings::StrCat("seed_arg", new_argument_idx)); + seed_arg->set_type(DT_INT64); + + return seed_arg; +} + +// Make function that uses `StatelessRandomUniform` instead of `RandomUniform` +// to make it less stateful. The function can still be stateful, but when the +// other stateful ops are e.g. `Assert`, then it will be parallelizable. 
+const FunctionDef* MakeLessStatefulFunction(const FunctionDef& map_function, + bool is_stateful, + int num_placeholders, + FunctionDefLibrary* library) { + FunctionDef* stateless_function = library->add_function(); + *stateless_function = map_function; + if (is_stateful) + stateless_function->mutable_signature()->set_is_stateful(is_stateful); + graph_utils::SetUniqueGraphFunctionName("stateless_function", library, + stateless_function); + + auto* seed_arg = InsertSeedArgument(stateless_function->mutable_signature(), + num_placeholders); + + auto* const random_uniform = stateless_function->mutable_node_def( + function_utils::FindFunctionNodeWithOp("RandomUniform", + *stateless_function)); + + // Replace RandomUniform node with StatelessRandomUniform. + random_uniform->set_op("StatelessRandomUniform"); + random_uniform->add_input(seed_arg->name()); + (*random_uniform->mutable_attr())["Tseed"].set_type(DT_INT64); + random_uniform->mutable_attr()->erase("seed"); + random_uniform->mutable_attr()->erase("seed2"); + + return stateless_function; +} +// This function returns true if function is stateful and has single +// RandomUniform op and no other stateful ops except Assert. +// For a stateful function, `is_stateful_after_hoisting` is set to false only +// when no Assert op is present (RandomUniform is the only stateful op) and to +// true otherwise. +bool CanHoistRandomUniform(const FunctionDef& map_function, + const FunctionLibraryDefinition& library, + bool* is_stateful_after_hoisting, + const NodeDef** random_uniform_op) { + if (!map_function.signature().is_stateful()) return false; + *is_stateful_after_hoisting = true; + + bool have_other_stateful_ops = false; + + for (const auto& node : map_function.node_def()) { + const OpDef* op_def; + TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def)); + // Skip stateless nodes and assert, as it does not actually have a state. 
+ if (!op_def->is_stateful()) continue; + + if (op_def->name() == "Assert") { + have_other_stateful_ops = true; + continue; + } + + // TODO(prazek): For now we only handle RandomUniform, we should handle + // RandomUniformInt as well. + if (op_def->name() != "RandomUniform") return false; + + // TODO(prazek): For now we can only hoist single RandomUniform. + if (*random_uniform_op != nullptr) return false; + + *random_uniform_op = &node; + } + + if (!have_other_stateful_ops) *is_stateful_after_hoisting = false; + + // Have we found single RandomUniform? + return *random_uniform_op != nullptr; +} + +int NumberOfPlaceholders(const NodeDef& map_node) { + // First input of MapDataset is the argument to the function. Rest of the + // inputs are placeholders. + return map_node.input_size() - 1; +} + +} // namespace + +Status HoistRandomUniform::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + *output = item.graph; + + MutableGraphView graph(output); + std::set<string> nodes_to_delete; + FunctionLibraryDefinition function_library(OpRegistry::Global(), + item.graph.library()); + + auto get_map_node = [](const NodeDef& node) -> const NodeDef* { + // TODO(prazek): we could also handle ParallelMapDataset and + // MapAndBatchDataset. 
+ if (node.op() == "MapDataset") return &node; + return nullptr; + }; + + for (const NodeDef& node : item.graph.node()) { + const NodeDef* map_node = get_map_node(node); + if (!map_node) continue; + + const auto& fun = map_node->attr().at("f"); + const FunctionDef* func = function_library.Find(fun.func().name()); + + const NodeDef* random_uniform_op = nullptr; + bool is_stateful_after_hoisting = true; + if (!CanHoistRandomUniform(*func, function_library, + &is_stateful_after_hoisting, &random_uniform_op)) + continue; + const auto* random_seed_dataset = + graph.AddNode(MakeRandomDataset(*random_uniform_op, &graph)); + + const auto* batch_dataset = + graph.AddNode(MakeBatchTwo(*random_seed_dataset, &graph)); + + const NodeDef& parent_node = *graph_utils::GetInputNode(*map_node, graph); + + const auto* zip_node = + graph.AddNode(MakeZipNode(parent_node, *batch_dataset, &graph)); + + const auto* stateless_func = MakeLessStatefulFunction( + *func, is_stateful_after_hoisting, NumberOfPlaceholders(*map_node), + output->mutable_library()); + + const auto* stateless_map = graph.AddNode( + MakeStatelessMap(*map_node, *zip_node, *stateless_func, &graph)); + + graph.ReplaceInput(*map_node, *stateless_map); + + // TODO(b/116285210): we could also remove map functions from library if + // they are not used anymore. 
+ nodes_to_delete.insert(map_node->name()); + } + + graph.DeleteNodes(nodes_to_delete); + return Status::OK(); +} + +void HoistRandomUniform::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, + double result) { + // no-op +} + +REGISTER_GRAPH_OPTIMIZER_AS(HoistRandomUniform, "hoist_random_uniform"); + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h new file mode 100644 index 0000000000..d1bcf6782d --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// This optimization hoists instances of `random_uniform` out of a function +// with the aim of making it stateless. It creates a new function that takes a +// random seed as an extra argument and uses `stateless_random_uniform` instead +// of `random_uniform` to make it stateless. 
+// It also creates RandomDataset(seed).batch(2), which is zipped with old input +// to the map. The batching in RandomDataset is because we need 2 seeds for +// `stateless_random_uniform`. +// TODO(prazek): for now only `RandomUniform` is handled, but we could handle +// `RandomUniformInt` similarly. +class HoistRandomUniform : public CustomGraphOptimizer { + public: + HoistRandomUniform() = default; + ~HoistRandomUniform() override = default; + + string name() const override { return "hoist_random_uniform"; }; + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_ diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc new file mode 100644 index 0000000000..455459e3f6 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h" + +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +TEST(HoistRandomUniform, SimpleHoisting) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, + {{"output_shapes", gtl::ArraySlice<TensorShape>{}}, + {"output_types", gtl::ArraySlice<DataType>{}}}), + graph_tests_utils::MakeMapNode("map1", "range", "RandomUniform"), + NDef("cache", "CacheDataset", {"map1", "filename"}, {})}, + // FunctionLib + { + test::function::RandomUniform(), + }); + + HoistRandomUniform optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output)); + const int new_map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output); + const int zip_dataset_id = + graph_utils::FindGraphNodeWithOp("ZipDataset", output); + const int random_dataset_id = + graph_utils::FindGraphNodeWithOp("RandomDataset", output); + const int batch_random_id = + graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output); + 
ASSERT_NE(random_dataset_id, -1); + ASSERT_NE(zip_dataset_id, -1); + ASSERT_NE(new_map_id, -1); + ASSERT_NE(batch_random_id, -1); + + const auto& new_map = output.node(new_map_id); + const auto& zip = output.node(zip_dataset_id); + const auto& random = output.node(random_dataset_id); + const auto& batch = output.node(batch_random_id); + + ASSERT_EQ(new_map.input_size(), 1); + EXPECT_EQ(new_map.input(0), zip.name()); + + ASSERT_EQ(zip.input_size(), 2); + EXPECT_EQ(zip.input(0), "range"); + EXPECT_EQ(zip.input(1), batch.name()); + + ASSERT_EQ(batch.input_size(), 3); + EXPECT_EQ(batch.input(0), random.name()); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc index 63945b8b9e..e66766eb23 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc @@ -80,11 +80,12 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node, // Set `f` and `Targuments` attributes. for (auto key : {"f", "Targuments"}) { - (*new_node.mutable_attr())[key] = map_node.attr().at(key); + graph_utils::CopyAttribute(key, map_node, &new_node); } + // Set `output_types` and `output_shapes` attributes. for (auto key : {"output_shapes", "output_types"}) { - (*new_node.mutable_attr())[key] = batch_node.attr().at(key); + graph_utils::CopyAttribute(key, batch_node, &new_node); } return new_node; } diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc index f1844a141c..c4868eacbb 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { @@ -41,19 +42,18 @@ NodeDef MakeFusedNode(const NodeDef& map_node, fused_node.set_op("MapDataset"); fused_node.add_input(map_node.input(0)); - auto copy_attribute = [](const string& attribute_name, const NodeDef& from, - NodeDef* to) { - (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name); - }; - auto attr = map_node.attr().at("f"); attr.mutable_func()->set_name(fused_function.signature().name()); (*fused_node.mutable_attr())["f"] = std::move(attr); - copy_attribute("Targuments", map_node, &fused_node); + graph_utils::CopyAttribute("Targuments", map_node, &fused_node); for (auto key : {"output_shapes", "output_types"}) - copy_attribute(key, map_node, &fused_node); + graph_utils::CopyAttribute(key, map_node, &fused_node); + + if (const auto* attr = + gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism")) + (*fused_node.mutable_attr())["use_inter_op_parallelism"] = *attr; // Add the predicate output attributes. (*fused_node.mutable_attr())["output_types"] diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc index f029a093fa..6e6da37d7c 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -27,24 +28,8 @@ limitations under the License. namespace tensorflow { namespace grappler { namespace { - -NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) { - return test::function::NDef( - name, "MapDataset", {string(input_node_name)}, - {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} - -NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) { - return test::function::NDef( - name, "FilterDataset", {string(input_node_name)}, - {{"predicate", FunctionDefHelper::FunctionRef("IsZero")}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} +using graph_tests_utils::MakeFilterNode; +using graph_tests_utils::MakeMapNode; TEST(MapAndFilterFusionTest, FuseMapAndFilter) { using test::function::NDef; diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc index a78ecb09f7..bd943342e8 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { @@ -40,24 +41,31 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node, NodeDef fused_node; graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(), &fused_node); - fused_node.set_op("MapDataset"); fused_node.add_input(parent_map_node.input(0)); - auto copy_attribute = [](const string& attribute_name, const NodeDef& from, - NodeDef* to) { - (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name); - }; - auto attr = parent_map_node.attr().at("f"); *attr.mutable_func()->mutable_name() = fused_function.signature().name(); (*fused_node.mutable_attr())["f"] = std::move(attr); - copy_attribute("Targuments", parent_map_node, &fused_node); - + graph_utils::CopyAttribute("Targuments", parent_map_node, &fused_node); for (auto key : {"output_shapes", "output_types"}) - copy_attribute(key, map_node, &fused_node); + graph_utils::CopyAttribute(key, map_node, &fused_node); + auto value_or_false = [](const AttrValue* attr) { + if (!attr) return false; + return attr->b(); + }; + + const auto* first_parallelism = + gtl::FindOrNull(parent_map_node.attr(), "use_inter_op_parallelism"); + const auto* second_parallelism = + gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"); + // Some graphs cannot execute with use_inter_op_parallelism=False, so we need + // to set it to true if one of the ops have it set to true. + if (value_or_false(first_parallelism) || value_or_false(second_parallelism)) { + (*fused_node.mutable_attr())["use_inter_op_parallelism"].set_b(true); + } return fused_node; } @@ -123,8 +131,8 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item, // fusion. 
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function)); - // TODO(prazek): we could also remove map functions from library if they - // are not used anymore. + // TODO(b/116285210): we could also remove map functions from library if + // they are not used anymore. nodes_to_delete.insert(parent_map_node->name()); nodes_to_delete.insert(map_node->name()); } diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc index b25dfbd0b8..8889f9dddd 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -28,14 +29,7 @@ namespace tensorflow { namespace grappler { namespace { -NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) { - return test::function::NDef( - name, "MapDataset", {string(input_node_name)}, - {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} +using graph_tests_utils::MakeMapNode; TEST(MapFusionTest, FuseTwoMapNodesIntoOne) { using test::function::NDef; diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc new file mode 100644 index 0000000000..782c9f48b7 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc @@ -0,0 +1,103 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h" + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/grappler/utils.h" + +namespace tensorflow { +namespace grappler { +namespace { + +bool CanParallelize(const FunctionDef& function, + const FunctionLibraryDefinition& library) { + if (!function.signature().is_stateful()) return true; + + for (const auto& node : function.node_def()) { + const OpDef* op_def; + TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def)); + // Assert is marked as stateful, but it does not have any state (except + // changing io). Similarly to CUDA, we do not give guarantee that the + // assert operation that would fail would be the first one, so that we can + // parallelize it. 
+ if (op_def->is_stateful() && op_def->name() != "Assert") return false; + } + + return true; +} + +NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) { + NodeDef parallel_map = map_node; + graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(), + &parallel_map); + parallel_map.set_op("ParallelMapDataset"); + // TODO(b/114475558): We want to set `num_parallel_calls` to a special value, + // so that dynamic tuning will pick the optimal value at runtime. Because + // this feature is not yet implemented, we set it to 2, which is the smallest + // value that introduces parallelism. + auto* num_parallel_calls = graph_utils::AddScalarConstNode(2, graph); + parallel_map.add_input(num_parallel_calls->name()); + + return parallel_map; +} + +} // namespace + +Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + *output = item.graph; + MutableGraphView graph(output); + std::set<string> nodes_to_delete; + FunctionLibraryDefinition function_library(OpRegistry::Global(), + item.graph.library()); + auto get_map_node = [](const NodeDef& node) -> const NodeDef* { + if (node.op() == "MapDataset") return &node; + return nullptr; + }; + + for (const NodeDef& node : item.graph.node()) { + const NodeDef* map_node = get_map_node(node); + if (!map_node) continue; + + auto* function = + function_library.Find(map_node->attr().at("f").func().name()); + if (!CanParallelize(*function, function_library)) continue; + + auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph)); + graph.ReplaceInput(*map_node, *parallel_map); + nodes_to_delete.insert(map_node->name()); + } + + graph.DeleteNodes(nodes_to_delete); + return Status::OK(); +} + +void MapParallelization::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, + double result) { + // no-op +} + +REGISTER_GRAPH_OPTIMIZER_AS(MapParallelization, "map_parallelization"); + +} // end namespace grappler +} // end
namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.h b/tensorflow/core/grappler/optimizers/data/map_parallelization.h new file mode 100644 index 0000000000..ac9cf7e12a --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// This optimization parallelizes MapDataset when function is stateless. 
+class MapParallelization : public CustomGraphOptimizer { + public: + MapParallelization() = default; + ~MapParallelization() override = default; + + string name() const override { return "map_parallelization"; }; + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_ diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc new file mode 100644 index 0000000000..9fdfe8af30 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc @@ -0,0 +1,85 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h" + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +using graph_tests_utils::MakeMapNode; +const char stateless_fun_name[] = "XTimesTwo"; +const char stateful_fun_name[] = "RandomUniform"; + +TEST(MapParallelizationTest, ParallelizeSimpleMap) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + MakeMapNode("map1", "range", stateless_fun_name)}, + // FunctionLib + { + test::function::XTimesTwo(), + }); + + MapParallelization optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output)); +} + +TEST(MapParallelization, ParallelizeAssert) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, 
{"dtype", DT_INT32}}), + NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + MakeMapNode("map1", "range", stateful_fun_name), + MakeMapNode("map2", "map1", stateless_fun_name), + NDef("cache", "CacheDataset", {"map2", "filename"}, {})}, + // FunctionLib + { + test::function::XTimesTwo(), + test::function::RandomUniform(), + }); + + MapParallelization optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output)); + EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map1", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output)); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index a019b77eb7..32ab912619 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/grappler/optimizers/data/map_vectorization.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -24,6 +25,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -33,15 +35,11 @@ namespace tensorflow { namespace grappler { namespace { -void CopyAttribute(const string& attr_name, const NodeDef& from, NodeDef* to) { - (*to->mutable_attr())[attr_name] = from.attr().at(attr_name); -} - -FunctionDef* AddVectorizedFunction(const NodeDef& map_node, +// Returns a FunctionDef containing a MapDefun op that wraps the original +// function. +FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, const FunctionDef& orig_func, FunctionDefLibrary* library) { - // If we decide to use a different method of vectorization, we can just - // swap out this part. FunctionDef* vectorized_func = library->add_function(); // Function inputs and outputs are the same as original, just // with different shapes. @@ -52,14 +50,14 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node, // Add MapDefun node NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Add(); map_defun_node->set_op("MapDefun"); - graph_utils::SetUniqueFunctionNodeName(map_defun_node->op(), vectorized_func, - map_defun_node); + function_utils::SetUniqueFunctionNodeName(map_defun_node->op(), + vectorized_func, map_defun_node); // Set attrs and inputs for (const string& k : {"f", "output_types", "output_shapes"}) { // Function, output types and (unbatched) shapes are the same as the // original map node. 
- CopyAttribute(k, map_node, map_defun_node); + graph_utils::CopyAttribute(k, map_node, map_defun_node); } // Get types of input arguments from original map function @@ -81,6 +79,30 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node, return vectorized_func; } +FunctionDef* AddVectorizedFunction(const NodeDef& map_node, + const FunctionDef& orig_func, + FunctionDefLibrary* library) { + // Vectorizes orig_func naively by wrapping in a MapDefun op, then performing + // efficient vectorization with VectorizeMapDefun. + FunctionDef* vectorized_func = + CreateMapDefunWrapper(map_node, orig_func, library); + NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0); + DCHECK_EQ(map_defun_node->op(), "MapDefun"); + + // Create a copy of the original function so that we can mutate it, and + // attach that to the map defun node. + FunctionDef* map_defun_fn = library->add_function(); + *map_defun_fn = orig_func; + graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library, + map_defun_fn); + (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name( + map_defun_fn->signature().name()); + + vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn, + map_defun_node); + return vectorized_func; +} + bool IsOutputShapesFullyDefined(const NodeDef& node) { auto* shapes_attr = gtl::FindOrNull(node.attr(), "output_shapes"); if (shapes_attr == nullptr) return false; @@ -169,13 +191,16 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node, } // Set attrs - CopyAttribute("Targuments", old_map_node, &map_node); + graph_utils::CopyAttribute("Targuments", old_map_node, &map_node); auto& func_attr = (*map_node.mutable_attr())["f"]; func_attr.mutable_func()->set_name(vectorized_func.signature().name()); for (auto key : {"output_shapes", "output_types"}) { - CopyAttribute(key, old_batch_node, &map_node); + graph_utils::CopyAttribute(key, old_batch_node, &map_node); } + + 
(*map_node.mutable_attr())["use_inter_op_parallelism"].set_b(true); + return map_node; } diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc index a26f1000a3..cf5a19bab1 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc @@ -33,25 +33,27 @@ namespace { bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) { if (take_node.op() != "TakeDataset") return false; - const NodeDef& count_node = *graph.GetNode(take_node.input(1)); + const auto& count_node = *graph.GetNode(take_node.input(1)); + if (count_node.op() != "Const") return false; // We are looking only for 'take' with negative count. return count_node.attr().at("value").tensor().int64_val(0) < 0; } +bool IsConstNodeWithValue(const NodeDef& node, int value) { + if (node.op() != "Const") return false; + return node.attr().at("value").tensor().int64_val(0) == value; +} + bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) { if (skip_node.op() != "SkipDataset") return false; - - const NodeDef& count_node = *graph.GetNode(skip_node.input(1)); // We are looking only for skip(0) nodes. - return count_node.attr().at("value").tensor().int64_val(0) == 0; + return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0); } bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) { if (repeat_node.op() != "RepeatDataset") return false; - - const NodeDef& count_node = *graph.GetNode(repeat_node.input(1)); // We are looking only for repeat(1) nodes. 
- return count_node.attr().at("value").tensor().int64_val(0) == 1; + return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1); } bool IsNoOp(const NodeDef& node, const GraphView& graph) { diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc index f445e75aa7..be1a66df75 100644 --- a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc +++ b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc @@ -43,6 +43,14 @@ NodeDef *MakeUnaryNode(StringPiece node_type, int count, string input_node, GetCommonAttributes(), graph); } +NodeDef *MakeUnaryNonConstNode(StringPiece node_type, string input_node, + MutableGraphView *graph) { + NodeDef *node_count = graph_utils::AddScalarPlaceholder(DT_INT32, graph); + return graph_utils::AddNode("", node_type, + {std::move(input_node), node_count->name()}, + GetCommonAttributes(), graph); +} + NodeDef *MakeCacheNode(string input_node, MutableGraphView *graph) { NodeDef *node_filename = graph_utils::AddScalarConstNode<StringPiece>("", graph); @@ -205,6 +213,41 @@ INSTANTIATE_TEST_CASE_P( ::testing::Values(*kTakeNode, *kSkipNode, *kRepeatNode))); +struct NoOpPlaceholdersTest + : ::testing::TestWithParam<std::tuple<string, string>> {}; + +TEST_P(NoOpPlaceholdersTest, NonConstNoOpNode) { + GrapplerItem item; + MutableGraphView graph(&item.graph); + + static_assert(std::tuple_size<NodesTypes>::value == 2, + "Make sure to include everything in the test"); + const std::vector<string> noop_nodes = {std::get<0>(GetParam()), + std::get<1>(GetParam())}; + NodeDef *range_node = MakeRangeNode(&graph); + std::vector<string> nodes_to_keep; + nodes_to_keep.reserve(noop_nodes.size()); + NodeDef *previous = range_node; + + for (const auto &noop_node : noop_nodes) { + NodeDef *node = MakeUnaryNonConstNode(noop_node, previous->name(), &graph); + nodes_to_keep.push_back(node->name()); + previous = node; + } + + NoOpElimination 
optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + for (const auto &noop_node_name : nodes_to_keep) + EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName(noop_node_name, output)); +} + +INSTANTIATE_TEST_CASE_P( + DoNotRemovePlaceholders, NoOpPlaceholdersTest, + ::testing::Combine( + ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset"), + ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset"))); + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc index cb0ff670e8..99c4afa634 100644 --- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc @@ -64,7 +64,7 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster, // Set `output_types` and `output_shapes` attributes. for (auto key : {"output_shapes", "output_types"}) { - (*new_node.mutable_attr())[key] = repeat_node.attr().at(key); + graph_utils::CopyAttribute(key, repeat_node, &new_node); } return new_node; }; diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD new file mode 100644 index 0000000000..1462cb234d --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD @@ -0,0 +1,69 @@ +package( + default_visibility = ["//visibility:private"], +) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all") + +VECTORIZER_DEPS = [ + ":vectorizer_registry", + "//tensorflow/core/grappler/optimizers/data:function_utils", +] + tf_protos_all() + +cc_library( + name = "vectorizer", + hdrs = ["vectorizer.h"], + deps = [ + "//tensorflow/core:lib", + ] + tf_protos_all(), +) + +cc_library( + name = 
"vectorizer_registry", + srcs = ["vectorizer_registry.cc"], + hdrs = ["vectorizer_registry.h"], + deps = [ + ":vectorizer", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +cc_library( + name = "cast_vectorizer", + srcs = ["cast_vectorizer.cc"], + deps = VECTORIZER_DEPS, + alwayslink = 1, +) + +cc_library( + name = "unpack_vectorizer", + srcs = ["unpack_vectorizer.cc"], + deps = VECTORIZER_DEPS, + alwayslink = 1, +) + +cc_library( + name = "vectorization", + hdrs = ["vectorizer_registry.h"], + visibility = ["//visibility:public"], + deps = [ + ":cast_vectorizer", + ":unpack_vectorizer", + ":vectorizer", + ":vectorizer_registry", + ], +) + +tf_cc_test( + name = "vectorizer_registry_test", + srcs = ["vectorizer_registry_test.cc"], + deps = [ + ":vectorizer_registry", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ] + tf_protos_all(), +) diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc new file mode 100644 index 0000000000..c1739737a0 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc @@ -0,0 +1,54 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Vectorizer for the `Cast` op: the vectorized form of a Cast is the same op
+// applied to the batched input tensor.
+class CastVectorizer : public Vectorizer {
+ public:
+  // Clones `node` into `outer_scope` under a unique "vectorized/<name>" name,
+  // points its only input at `inputs[0]`, and records the mapping
+  // "<old name>:y:0" -> "<new name>:y:0" in `conversion_map`.
+  // Returns an Internal error unless exactly one input name is supplied.
+  Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+                   FunctionDef* outer_scope,
+                   std::map<string, string>* conversion_map) override {
+    if (inputs.size() != 1) {
+      return errors::Internal("Cast op should only have one input.");
+    }
+
+    // Clone the node and drop the copied name before asking for a unique one
+    // (presumably so the clone cannot collide with itself -- confirm against
+    // SetUniqueFunctionNodeName).
+    NodeDef* batched_cast = outer_scope->add_node_def();
+    *batched_cast = node;
+    batched_cast->clear_name();
+    const string name_hint = strings::StrCat("vectorized/", node.name());
+    function_utils::SetUniqueFunctionNodeName(name_hint, outer_scope,
+                                              batched_cast);
+    batched_cast->set_input(0, inputs[0]);
+
+    // Publish where the converted output tensor now lives.
+    const string old_output = strings::StrCat(node.name(), ":y:0");
+    const string new_output = strings::StrCat(batched_cast->name(), ":y:0");
+    (*conversion_map)[old_output] = new_output;
+
+    return Status::OK();
+  }
+};
+
+REGISTER_VECTORIZER("Cast", CastVectorizer);
+
+}  // namespace vectorization_utils
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
new file mode 100644
index 0000000000..776d3179c5
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -0,0 +1,61 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Vectorizer for the `Unpack` (unstack) op. The vectorized input carries one
+// extra leading (batch) dimension, so the op is re-emitted with its `axis`
+// adjusted to keep unpacking along the same logical dimension.
+class UnpackVectorizer : public Vectorizer {
+ public:
+  // Clones `node` into `outer_scope` under a unique "vectorized/<name>" name,
+  // rewires its single input to `inputs[0]`, shifts `axis` past the new
+  // leading dimension, and maps every "<old>:output:i" to "<new>:output:i" in
+  // `conversion_map` for i in [0, num).
+  //
+  // Returns Internal if the number of inputs is not one and InvalidArgument if
+  // the node has no "num" attr. All validation happens before `outer_scope` is
+  // touched, so a failed call leaves no half-built node behind.
+  Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+                   FunctionDef* outer_scope,
+                   std::map<string, string>* conversion_map) override {
+    if (inputs.size() != 1) {
+      return errors::Internal("Unpack op should only have one input.");
+    }
+
+    // Look attrs up with find(): protobuf's Map::at() aborts the process when
+    // the key is absent, and "axis" may legitimately be omitted from a NodeDef
+    // (the op then uses its default, assumed to be 0 -- see the Unpack OpDef).
+    const auto& attrs = node.attr();
+    const auto num_attr = attrs.find("num");
+    if (num_attr == attrs.end()) {
+      return errors::InvalidArgument(
+          "Unpack node is missing the required \"num\" attr: ", node.name());
+    }
+    const int64 num = num_attr->second.i();
+
+    const auto axis_attr = attrs.find("axis");
+    int64 axis = axis_attr == attrs.end() ? 0 : axis_attr->second.i();
+    // Only non-negative axes move when a leading dimension is added. Negative
+    // axes count from the end and already refer to the right dimension;
+    // incrementing e.g. -1 would produce 0 and unpack the batch dimension.
+    if (axis >= 0) {
+      ++axis;
+    }
+
+    // Add new Unpack node
+    NodeDef* new_unpack_node = outer_scope->add_node_def();
+    *new_unpack_node = node;
+    new_unpack_node->clear_name();
+    function_utils::SetUniqueFunctionNodeName(
+        strings::StrCat("vectorized/", node.name()), outer_scope,
+        new_unpack_node);
+
+    (*new_unpack_node->mutable_attr())["axis"].set_i(axis);
+    new_unpack_node->set_input(0, inputs[0]);
+
+    // Add the output mappings to conversion map
+    for (int64 i = 0; i < num; ++i) {
+      (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] =
+          strings::StrCat(new_unpack_node->name(), ":output:", i);
+    }
+
+    return Status::OK();
+  }
+};
+
+REGISTER_VECTORIZER("Unpack", UnpackVectorizer);
+
+}  // namespace vectorization_utils
+}  // namespace grappler
+}  // namespace tensorflow
diff --git
a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
new file mode 100644
index 0000000000..d341dbba7d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
+
+// std::map appears in the Vectorize() signature below; include it directly so
+// this header does not depend on a transitive include.
+#include <map>
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
+// for an example.
+class Vectorizer {
+ public:
+  virtual ~Vectorizer() = default;
+
+  // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope`
+  // that produce the same vector output(s) as executing `node`'s op
+  // on elements of the vector inputs, and adding mappings to `conversion_map`
+  // from old output tensor names to new (vectorized) output tensor names.
+  // The new node(s) collectively have the same number of inputs and outputs as
+  // the node being converted, and use the tensor names in `inputs` as their
+  // inputs.
+  //
+  // Args:
+  //   node:           the op inside the mapped function that is being lifted.
+  //   inputs:         names of the already-vectorized tensors the new node(s)
+  //                   must read from, one per input of `node`.
+  //   outer_scope:    function that receives the new node(s).
+  //   conversion_map: receives one "old tensor name" -> "new tensor name"
+  //                   entry per output of `node`.
+  //
+  // Returns a non-OK status when `node` cannot be vectorized.
+  virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+                           FunctionDef* outer_scope,
+                           std::map<string, string>* conversion_map) = 0;
+};
+
+}  // namespace vectorization_utils
+}  // namespace grappler
+}  // namespace tensorflow
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
new file mode 100644
index 0000000000..a6551e36ac
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Returns the process-wide registry. It is created on first use and never
+// destroyed.
+VectorizerRegistry* VectorizerRegistry::Global() {
+  static VectorizerRegistry* registry = new VectorizerRegistry;
+  return registry;
+}
+
+// Looks up the vectorizer registered for `op_type`; yields nullptr when no
+// vectorizer has been registered under that name.
+Vectorizer* VectorizerRegistry::Get(const string& op_type) {
+  const auto entry = vectorizers_.find(op_type);
+  return entry != vectorizers_.end() ? entry->second.get() : nullptr;
+}
+
+// Takes ownership of `vectorizer` and files it under `op_type`. Registering a
+// second vectorizer for the same op type is a fatal error.
+void VectorizerRegistry::Register(const string& op_type,
+                                  std::unique_ptr<Vectorizer> vectorizer) {
+  Vectorizer* const already_registered = Get(op_type);
+  CHECK_EQ(already_registered, nullptr)
+      << "Vectorizer for op type: " << op_type << " already registered";
+  vectorizers_.emplace(op_type, std::move(vectorizer));
+}
+}  // namespace vectorization_utils
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
new file mode 100644
index 0000000000..16159d47ca
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
@@ -0,0 +1,75 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
+
+#include <functional>
+#include <map>
+#include <memory>   // std::unique_ptr
+#include <utility>  // std::move
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// A global VectorizerRegistry is used to hold all the vectorizers.
+class VectorizerRegistry {
+ public:
+  // Returns a pointer to a global VectorizerRegistry object.
+  static VectorizerRegistry* Global();
+
+  // Returns a pointer to a vectorizer that can vectorize an op for the op type.
+  // The registry keeps ownership of the returned object.
+  Vectorizer* Get(const string& op_type);
+
+  // Registers a vectorizer that can vectorize an op for the given op type.
+  // The registry takes ownership of `vectorizer`.
+  void Register(const string& op_type, std::unique_ptr<Vectorizer> vectorizer);
+
+ private:
+  // Registered vectorizers, keyed by op type.
+  std::map<string, std::unique_ptr<Vectorizer>> vectorizers_;
+};
+
+namespace vectorizer_registration {
+
+// Helper whose constructor performs the registration. REGISTER_VECTORIZER
+// declares one static instance of it per use, so registration runs during
+// static initialization of the translation unit that uses the macro.
+class VectorizerRegistration {
+ public:
+  VectorizerRegistration(const string& op_type,
+                         std::unique_ptr<Vectorizer> vectorizer) {
+    VectorizerRegistry::Global()->Register(op_type, std::move(vectorizer));
+  }
+};
+
+}  // namespace vectorizer_registration
+
+// Registers `vectorizer` (a default-constructible Vectorizer subclass) for
+// `op_type`. Usage, at namespace scope in a .cc file:
+//
+//   REGISTER_VECTORIZER("Cast", CastVectorizer);
+//
+// The two helper macros exist only so that __COUNTER__ is expanded before
+// being pasted into the variable name.
+#define REGISTER_VECTORIZER(op_type, vectorizer) \
+  REGISTER_VECTORIZER_UNIQ_HELPER(__COUNTER__, op_type, vectorizer)
+
+#define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \
+  REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer)
+
+#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer)                 \
+  static ::tensorflow::grappler::vectorization_utils::                     \
+      vectorizer_registration::VectorizerRegistration                      \
+          vectorizer_registration_##ctr(                                   \
+              op_type,                                                     \
+              ::std::unique_ptr<                                           \
+                  ::tensorflow::grappler::vectorization_utils::Vectorizer>( \
+                  new vectorizer()))
+
+}  // namespace vectorization_utils
+}  // namespace grappler
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
new file mode 100644
index 0000000000..86e303564b
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -0,0 +1,50 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Minimal vectorizer that accepts every node and changes nothing; it exists
+// only so the test can exercise registration and lookup.
+class TestVectorizer : public Vectorizer {
+ public:
+  Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
+                   FunctionDef* outer_scope,
+                   std::map<string, string>* conversion_map) override {
+    return Status::OK();
+  }
+};
+
+REGISTER_VECTORIZER("test_op", TestVectorizer);
+
+// Checks that an unregistered op type yields nullptr and that the vectorizer
+// registered above can be fetched and invoked.
+TEST(TestVectorizer, TestTestVectorizer) {
+  EXPECT_EQ(VectorizerRegistry::Global()->Get("nonexistent"), nullptr);
+
+  auto vectorizer = VectorizerRegistry::Global()->Get("test_op");
+  // Fatal assertion: the pointer is dereferenced below, so continuing after a
+  // failed lookup would crash the test binary instead of reporting a failure.
+  ASSERT_NE(vectorizer, nullptr);
+
+  FunctionDef function;
+  NodeDef node;
+  std::map<string, string> conversion_map;
+  EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok());
+}
+
+}  // namespace vectorization_utils
+}  // namespace grappler
+}  // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
new file mode 100644
index 0000000000..cb56b65985
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -0,0 +1,292 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/functions.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/scanner.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+using function_utils::FunctionDefTensorDesc;
+
+namespace {
+
+// Appends a new output of type `t`, fed by `output_retval`, to `map_defun_fn`
+// and extends the "output_shapes" / "output_types" attrs of `map_defun_node`
+// to match. The new output's shape is recorded as unknown.
+void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
+                       const string& output_retval, const DataType t) {
+  // A default-constructed PartialTensorShape stands for "unknown shape".
+  TensorShapeProto unknown_shape;
+  PartialTensorShape().AsProto(&unknown_shape);
+
+  function_utils::AddFunctionOutputWithUniqueName(
+      "vectorized_out", output_retval, map_defun_fn, t);
+
+  auto& node_attrs = *map_defun_node->mutable_attr();
+  *node_attrs["output_shapes"].mutable_list()->add_shape() = unknown_shape;
+  node_attrs["output_types"].mutable_list()->add_type(t);
+}
+
+// Drops output number `output_position` from `map_defun_fn` (signature entry
+// and ret mapping) and from the attrs of `map_defun_node`, then renumbers the
+// references in `outer_scope` to all later outputs so they stay contiguous.
+void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+                          NodeDef* map_defun_node, int output_position) {
+  DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
+      << "Trying to remove output that doesn't exist. Output number: "
+      << output_position;
+
+  // Remember the output count before anything is deleted; it bounds the
+  // renumbering loop below.
+  const int original_output_count =
+      map_defun_fn->signature().output_arg_size();
+
+  // Remove from map_defun_fn's ret dict and output args
+  map_defun_fn->mutable_ret()->erase(
+      map_defun_fn->signature().output_arg(output_position).name());
+  map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
+      output_position, 1);
+
+  // Renumber outputs that come after: every later output k becomes k - 1.
+  for (int later = output_position + 1; later < original_output_count;
+       ++later) {
+    const string stale_ref =
+        strings::StrCat(map_defun_node->name(), ":output:", later);
+    const string shifted_ref =
+        strings::StrCat(map_defun_node->name(), ":output:", later - 1);
+    function_utils::ReplaceReferences(stale_ref, shifted_ref, outer_scope);
+  }
+
+  auto* node_attrs = map_defun_node->mutable_attr();
+  node_attrs->at("output_shapes")
+      .mutable_list()
+      ->mutable_shape()
+      ->DeleteSubrange(output_position, 1);
+  node_attrs->at("output_types")
+      .mutable_list()
+      ->mutable_type()
+      ->ExtractSubrange(output_position, 1, nullptr);
+}
+
+// Scans the outputs of `function` from last to first and returns the position
+// of the first one whose producing node is not listed in `unconvertible`,
+// leaving that output's tensor descriptor in `*f`. Returns -1 when every
+// output comes from an unconvertible node (`*f` then describes the last
+// output examined).
+int FindOutputToConvert(const FunctionDef& function,
+                        const std::set<string>& unconvertible,
+                        FunctionDefTensorDesc* f) {
+  const auto& signature = function.signature();
+  for (int position = signature.output_arg_size() - 1; position >= 0;
+       --position) {
+    const string& ret_key = signature.output_arg(position).name();
+    *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+
+    if (unconvertible.count(f->node_name) == 0) {
+      return position;
+    }
+  }
+  return -1;
+}
+
+// Helper class that vectorizes the body of a MapDefun node, adding new
+// operations to the graph that collectively compute the same value as what
+// running the MapDefun function on slices of the input would produce.
+// Each instance of the class encapsulates all the data necessary to vectorize a
+// MapDefun op in place.
+class Vectorization {
+ public:
+  // None of the pointers are owned; all must outlive this object.
+  // `map_defun_node` is expected to be a node of `outer_scope`.
+  Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+                NodeDef* map_defun_node)
+      : outer_scope_(outer_scope),
+        map_defun_fn_(map_defun_fn),
+        map_defun_node_(map_defun_node) {}
+
+  // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
+  // the outer_scope_, until there are no convertible outputs remaining.
+  // This method is idempotent: once the MapDefun node has been removed from
+  // outer_scope_, further calls do nothing.
+  void Vectorize();
+
+ private:
+  // Vectorizes the map defun function's output at output_position
+  Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
+  // Given a descriptor of the original output tensor, gets a string
+  // corresponding to the converted output tensor.
+  Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
+                             string* converted);
+  // Maps an output that is fed directly by an input of map_defun_fn_ to the
+  // corresponding input of map_defun_node_.
+  Status AddConversionMappingFromInput(
+      const FunctionDefTensorDesc& output_desc);
+
+  // Adds mappings from node's outputs tensors to converted output tensors,
+  // creating the necessary new node(s). Generally, the steps to convert an op
+  // are:
+  // 1) Promote the inputs of the op to outputs of the map_defun_fn_,
+  //    and modify map_defun_node_ attrs accordingly
+  // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+  //    These operations collectively compute the same value as what running
+  //    the original operation on slices of the input tensors would produce.
+  //    For example, a Cast op in MapDefun translates to a Cast op in
+  //    outer_scope_, since the vectorized version of Cast is itself.
+  // 3) Set inputs of new node(s) to the corresponding converted inputs (that
+  //    are now outputs of map_defun_node_)
+  // 4) For each output of the old node, add the mapping of output strings to
+  //    the conversion map (eg "Cast:y:0" -> "vectorized/Cast:y:0")
+  Status AddConversionMappingFromOp(const NodeDef& node,
+                                    const FunctionDefTensorDesc& output_desc);
+
+  // Maps a tensor name to the name of the corresponding vectorized tensor. For
+  // example, "Cast:y:0" -> "vectorized/Cast:y:0"
+  std::map<string, string> conversion_map_;
+  // Unconvertible node names
+  std::set<string> unconvertible_;
+
+  FunctionDef* outer_scope_;
+  FunctionDef* map_defun_fn_;
+  // Reset to nullptr once the node has been deleted from outer_scope_.
+  NodeDef* map_defun_node_;
+};
+
+Status Vectorization::AddConversionMappingFromOp(
+    const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
+  for (const string& input_name : node.input()) {
+    if (IsControlInput(input_name)) {
+      return errors::InvalidArgument(
+          "Vectorizing outputs with control inputs is currently not "
+          "supported.");
+    }
+  }
+
+  DataTypeVector types;
+  const OpDef* op_def = nullptr;
+  TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
+  TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
+
+  // Every input is promoted to exactly one MapDefun output below, so the type
+  // list derived from the OpDef must line up with the node's input list.
+  // Bail out before touching outer_scope_ if the NodeDef is inconsistent.
+  const int num_inputs = node.input_size();
+  if (types.size() != static_cast<size_t>(num_inputs)) {
+    return errors::Internal("Node ", node.name(), " has ", num_inputs,
+                            " inputs but its op expects ", types.size(), ".");
+  }
+
+  // Names the promoted inputs will have once they are appended to the
+  // MapDefun outputs (they are appended only after the vectorizer succeeds).
+  std::vector<string> promoted_inputs;
+  promoted_inputs.reserve(num_inputs);
+  for (int i = 0; i < num_inputs; ++i) {
+    promoted_inputs.push_back(strings::StrCat(
+        map_defun_node_->name(),
+        ":output:", map_defun_fn_->signature().output_arg_size() + i));
+  }
+
+  auto vectorizer = VectorizerRegistry::Global()->Get(node.op());
+  if (vectorizer == nullptr) {
+    return errors::Unimplemented("No vectorizer registered for op: ",
+                                 node.op());
+  }
+
+  TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_,
+                                           &conversion_map_));
+
+  // If we get here, the conversion was successful, so we promote the inputs
+  // of the ops to MapDefun outputs.
+  for (int i = 0; i < num_inputs; ++i) {
+    AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+  }
+
+  return Status::OK();
+}
+
+Status Vectorization::AddConversionMappingFromInput(
+    const FunctionDefTensorDesc& output_desc) {
+  int input_index = function_utils::FindFunctionInputWithName(
+      output_desc.node_name, *map_defun_fn_);
+  if (input_index == -1) {
+    return errors::Internal("Cannot convert non-existent input.");
+  }
+
+  conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
+  return Status::OK();
+}
+
+Status Vectorization::ConvertOutputHelper(
+    const FunctionDefTensorDesc& output_desc, string* converted) {
+  // It's possible the output already has a mapping, if it comes from a node
+  // that has already been converted.
+  if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
+    *converted = *found;
+    return Status::OK();
+  }
+
+  int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
+                                                       *map_defun_fn_);
+  if (index == -1) {  // The output comes from an input
+    TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
+  } else {
+    TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
+        map_defun_fn_->node_def(index), output_desc));
+  }
+
+  // A vectorizer is free to register whichever outputs it likes; do not assume
+  // it registered this one (std::map::at would terminate the process).
+  const string* mapped =
+      gtl::FindOrNull(conversion_map_, output_desc.full_str);
+  if (mapped == nullptr) {
+    return errors::Internal("No conversion mapping was produced for tensor: ",
+                            output_desc.full_str);
+  }
+  *converted = *mapped;
+  return Status::OK();
+}
+
+Status Vectorization::ConvertOutput(int output_position,
+                                    const FunctionDefTensorDesc& output_desc) {
+  string converted_output_name;
+  TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+
+  // Remove the old output and make everything that referenced it point
+  // to the new string
+  function_utils::ReplaceReferences(
+      strings::StrCat(map_defun_node_->name(), ":output:", output_position),
+      converted_output_name, outer_scope_);
+  RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
+                       output_position);
+
+  return Status::OK();
+}
+
+void Vectorization::Vectorize() {
+  // The MapDefun node was already removed by an earlier call; nothing is left
+  // to convert and the node pointer must not be touched again.
+  if (map_defun_node_ == nullptr) return;
+
+  while (true) {
+    FunctionDefTensorDesc desc;
+    int output_position =
+        FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
+    if (output_position == -1) break;
+
+    if (!ConvertOutput(output_position, desc).ok()) {
+      unconvertible_.insert(desc.node_name);
+    }
+  }
+
+  // If we've converted all the outputs of the MapDefun function, we no longer
+  // need the MapDefun node and can delete it.
+  if (map_defun_fn_->signature().output_arg_size() == 0) {
+    const int node_index = function_utils::FindFunctionNodeWithName(
+        map_defun_node_->name(), *outer_scope_);
+    // -1 means the node is not part of outer_scope_; DeleteSubrange must not
+    // be called with a negative start.
+    if (node_index != -1) {
+      outer_scope_->mutable_node_def()->DeleteSubrange(node_index, 1);
+    }
+    // DeleteSubrange destroys the NodeDef, so drop our pointer to it.
+    map_defun_node_ = nullptr;
+  }
+
+  if (!unconvertible_.empty()) {
+    VLOG(2) << "The following nodes could not be converted: ["
+            << absl::StrJoin(unconvertible_, ", ") << "].";
+  }
+}
+}  // namespace
+
+void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
+                       NodeDef* map_defun_node) {
+  Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+}
+
+}  // end namespace vectorization_utils
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
new file mode 100644
index 0000000000..bb405faa77
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -0,0 +1,90 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_ + +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { +namespace vectorization_utils { + +// Given a function, `map_defun_fn`, that is mapped across some input vector +// elements via a MapDefun operation, `VectorizeMapDefun` attempts to +// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the +// `outer_scope`; that is, replacing `map_defun_fn` operations with new +// `outer_scope` operations that produce the same vector output(s) as executing +// the `map_defun_fn` operations on elements of vector input(s) would. If all +// `map_defun_fn` operations are successfully lifted, `map_defun_node` is +// eliminated from `outer_scope` altogether. However, if some operations cannot +// be lifted, and this vectorization only succeeds partially, `map_defun_node` +// remains to be used for operations that were not lifted. +// +// Example: +// If the input to the `VectorizeMapDefun` function is a MapDefun +// whose `map_defun_fn` performs the Cast operation, the vectorization will +// eliminate the MapDefun. This is because the Cast operation supports +// any tensor shape and can thus be lifted to the `outer_scope`. 
+// +// Before: +// +// +// outer_scope +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | map_defun_fn +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +---+--+ | | +// | | | | | +// | | | | | +// | | +---v--+ | | +// | | | Cast | | | +// | | +---+--+ | | +// | | | | | +// | | +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// outer_scope +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | | Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn, + NodeDef* map_defun_node); + +} // end namespace vectorization_utils +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_ diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc new file mode 100644 index 0000000000..e129fa9237 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc @@ -0,0 +1,600 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+namespace {
+
+// Adds a "Cast" node called `name` to `fn`, reading from `inputs`, with its
+// SrcT / DstT / Truncate attrs set from `src`, `dst` and `truncate`.
+NodeDef* AddCastNode(const string& name, const std::vector<string>& inputs,
+                     DataType src, DataType dst, bool truncate,
+                     FunctionDef* fn) {
+  NodeDef* cast = function_utils::AddNode(name, "Cast", inputs, {}, fn);
+  graph_transforms::SetNodeAttr("SrcT", src, cast);
+  graph_transforms::SetNodeAttr("DstT", dst, cast);
+  graph_transforms::SetNodeAttr("Truncate", truncate, cast);
+  return cast;
+}
+
+// Adds an "Unpack" node called `name` to `fn`, reading from `inputs`, with its
+// T / axis / num attrs set from `t`, `axis` and `num`.
+NodeDef* AddUnstackNode(const string& name, const std::vector<string>& inputs,
+                        DataType t, int axis, int num, FunctionDef* fn) {
+  NodeDef* unpack = function_utils::AddNode(name, "Unpack", inputs, {}, fn);
+  graph_transforms::SetNodeAttr("T", t, unpack);
+  graph_transforms::SetNodeAttr("axis", axis, unpack);
+  graph_transforms::SetNodeAttr("num", num, unpack);
+  return unpack;
+}
+
+// Adds a "MapDefun" node called `name` to `fn` that maps the function named
+// `function_name` over `inputs`, with the given argument types, output types
+// and output shapes recorded as attrs.
+NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
+                         const std::vector<DataType>& t_arguments,
+                         const std::vector<DataType>& output_types,
+                         const std::vector<TensorShape>& output_shapes,
+                         const string& function_name, FunctionDef* fn) {
+  NameAttrList mapped_function;
+  mapped_function.set_name(function_name);
+
+  NodeDef* map_defun =
+      function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
+  graph_transforms::SetNodeAttr("Targuments", t_arguments, map_defun);
+  graph_transforms::SetNodeAttr("output_types", output_types, map_defun);
+  graph_transforms::SetNodeAttr("output_shapes", output_shapes, map_defun);
+  graph_transforms::SetNodeAttr("f", mapped_function, map_defun);
+  return map_defun;
+}
+
+//
TODO(rachelim): Use FunctionDefHelper::Create instead +FunctionDef CreateFunction( + StringPiece name, const std::vector<std::pair<string, DataType>>& inputs, + const std::vector<std::pair<string, DataType>>& outputs, + const std::map<string, string>& rets) { + FunctionDef func; + auto* signature = func.mutable_signature(); + signature->set_name(string(name)); + for (const auto& x : inputs) { + auto* arg_def = signature->add_input_arg(); + arg_def->set_name(x.first); + arg_def->set_type(x.second); + } + for (const auto& x : outputs) { + auto* arg_def = signature->add_output_arg(); + arg_def->set_name(x.first); + arg_def->set_type(x.second); + } + for (const auto& x : rets) { + (*func.mutable_ret())[x.first] = x.second; + } + + return func; +} + +TEST(FunctionDefInputDescTest, ConstructedCorrectly) {} + +// Before: +// +// +------+ +------+ +// +---------------+ Arg0 +---+ Arg1 +--------+ +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ +---v--+ | +// | +-----------+ Arg0 +---+ Arg1 +----+ | +// | | +---+--+ +---+--+ | | +// | | | | | | +// | | MapDefun +---v--+ +---v--+ | | +// | +-----------+ Ret0 +---+ Ret1 +----+ | +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ +// +// +// After: +// +// +------+ +------+ +// +---------------+ Arg0 +---+ Arg1 +--------+ +// | +---+--+ +---+--+ | +// | | | | +// | | | | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ +// +TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { + FunctionDef inner = + CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}}, + {{"ret0", DT_INT32}, {"ret1", DT_INT32}}, + {{"ret0", "arg0"}, {"ret1", "arg1"}}); + FunctionDef outer = CreateFunction( + "outer_function", {{"ret0", DT_INT32}, {"ret1", DT_INT32}}, + {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}}, + {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}}); 
+ + NodeDef* map_defun = AddMapDefunNode( + "MapDefun", {"ret0", "ret1"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32}, + {{}, {}}, inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + VectorizeMapDefun(&outer, &inner, map_defun); + EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); + EXPECT_EQ(outer.ret().at("mapdefun"), "ret0"); + EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1"); +} + +// Before: +// +// +------+ +------+ +// +---------------+ Arg0 +---+ Arg1 +--------+ +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ +---v--+ | +// | +-----------+ Arg0 +---+ Arg1 +----+ | +// | | +---+--+ +---+--+ | | +// | | | | | | +// | | +------+ | +---v--+ | | +// | | |Const | | | Op0 | | | +// | | +---v--+ | +---+--+ | | +// | | | | | | | +// | | | +---v--+ +---v--+ | | +// | | +---| XOp1 | | XOp2 | | | +// | | +---+--+ +---+--+ | | +// | | | | | | +// | | MapDefun +---v--+ +---v--+ | | +// | +-----------+ Ret0 +---+ Ret1 +----+ | +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ +// +// where XOp1 and XOp2 are not convertible. +// +// After: +// +// No change because the ops are not convertible. 
+// +TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) { + FunctionDef inner = + CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}}, + {{"ret0", DT_INT32}, {"ret1", DT_INT32}}, + {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}}); + NodeDef* x_op1 = + function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner); + CHECK_NOTNULL(x_op1); + + NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner); + CHECK_NOTNULL(x_op2); + + FunctionDef outer = CreateFunction( + "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}}, + {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}}, + {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}}); + + NodeDef* map_defun = AddMapDefunNode( + "MapDefun", {"x", "y"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32}, + {{}, {}}, inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + FunctionDef outer_copy(outer); + FunctionDef inner_copy(inner); + VectorizeMapDefun(&outer, &inner, map_defun); + // They should be unchanged + EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer)); + EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner)); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +---+--+ | | +// | | | | | +// | | | | | +// | | +---v--+ | | +// | | | Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | | Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) { + FunctionDef inner = + CreateFunction("inner_function", {{"arg0", DT_INT32}}, + {{"ret0", DT_INT64}}, 
{{"ret0", "Cast:y:0"}}); + NodeDef* cast_op = + AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner); + CHECK_NOTNULL(cast_op); + + FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}}, + {{"mapdefun", DT_INT64}}, + {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + VectorizeMapDefun(&outer, &inner, map_defun); + EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); + const NodeDef& cast_node = + outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + EXPECT_EQ(cast_node.input(0), "x"); + EXPECT_EQ(outer.ret().at("mapdefun"), + strings::StrCat(cast_node.name(), ":y:0")); + EXPECT_EQ(outer.node_def_size(), 1); +} + +// Before: +// +// +------+ +// +---------------+ Arg0 +-------------------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +---------------+ | +// | | +---+--+ | | +// | | | | | +// | | | | | +// | | +---v--+ | | +// | | | Cast | | | +// | | +---+--+ | | +// | | | | | +// | | +----------+ | | +// | | | | | | +// | | MapDefun +---v--+ +---v--+ | | +// | +-----------+ Ret0 +---+ Ret1 +----+ | +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +-------------------+ +// | +---+--+ | +// | | | +// | | | +// | +---v--+ | +// | | Cast | | +// | +---+--+ | +// | | | +// | +----------+ | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ +// +TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) { + // Tests that behavior is correct when an output is used more than once. 
+ FunctionDef inner = + CreateFunction("inner_function", {{"arg0", DT_INT32}}, + {{"ret0", DT_INT64}, {"ret1", DT_INT64}}, + {{"ret0", "Cast:y:0"}, {"ret1", "Cast:y:0"}}); + NodeDef* cast_op = + AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner); + CHECK_NOTNULL(cast_op); + + FunctionDef outer = CreateFunction( + "outer_function", {{"x", DT_INT32}}, + {{"mapdefun", DT_INT64}, {"mapdefun_0", DT_INT64}}, + {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64, DT_INT64}, + {{}, {}}, inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + VectorizeMapDefun(&outer, &inner, map_defun); + EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); + const NodeDef& cast_node = + outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + EXPECT_EQ(cast_node.input(0), "x"); + EXPECT_EQ(outer.ret().at("mapdefun"), + strings::StrCat(cast_node.name(), ":y:0")); + EXPECT_EQ(outer.ret().at("mapdefun_0"), + strings::StrCat(cast_node.name(), ":y:0")); + EXPECT_EQ(outer.node_def_size(), 1); +} + +// Before: +// +// +------+ +// +----------------------+ Arg0 +----------------------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +------------------+ Arg0 +------------------+ | +// | | +---+--+ | | +// | | | | | +// | | | | | +// | | +---v---+ num=3 | | +// | | |Unstack| axis=0 | | +// | | ++--+--++ | | +// | | | | | | | +// | | +----+ | +-------+ | | +// | | | | | | | +// | | MapDefun +---v--+ +-v----+ +--v---+ | | +// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ | +// | +---+--+ +--+---+ +--+---+ | +// | | | | | +// | +---v--+ +--v---+ +--v---+ | +// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+ +// +------+ +------+ +------+ +// +// +// After: +// +// +------+ +// +----------------------+ Arg0 +----------------------+ +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v---+ num=3 | +// | |Unstack| 
axis=1 | +// | ++--+--++ | +// | | | | | +// | +----+ | +-------+ | +// | | | | | +// | | | | | +// | +---v--+ +-v----+ +--v---+ | +// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+ +// +------+ +------+ +------+ +// +TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) { + FunctionDef inner = CreateFunction( + "inner_function", {{"arg0", DT_INT32}}, + {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}}, + {{"ret0", "MyUnstack:output:0"}, + {"ret1", "MyUnstack:output:1"}, + {"ret2", "MyUnstack:output:2"}}); + NodeDef* unstack_op = + AddUnstackNode("MyUnstack", {"arg0"}, DT_INT32, 0, 3, &inner); + CHECK_NOTNULL(unstack_op); + + FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}}, + {{"mapdefun", DT_INT32}, + {"mapdefun_0", DT_INT32}, + {"mapdefun_1", DT_INT32}}, + {{"mapdefun", "MapDefun:output:0"}, + {"mapdefun_0", "MapDefun:output:1"}, + {"mapdefun_1", "MapDefun:output:2"}}); + + NodeDef* map_defun = AddMapDefunNode( + "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32}, + {{1}, {1}, {1}}, inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + VectorizeMapDefun(&outer, &inner, map_defun); + EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); + const NodeDef& unpack_node = + outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer)); + EXPECT_EQ(unpack_node.input(0), "x"); + EXPECT_EQ(unpack_node.attr().at("axis").i(), 1); + EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32); + EXPECT_EQ(unpack_node.attr().at("num").i(), 3); + EXPECT_EQ(outer.ret().at("mapdefun"), + strings::StrCat(unpack_node.name(), ":output:0")); + EXPECT_EQ(outer.ret().at("mapdefun_0"), + strings::StrCat(unpack_node.name(), ":output:1")); + EXPECT_EQ(outer.ret().at("mapdefun_1"), + strings::StrCat(unpack_node.name(), ":output:2")); + EXPECT_EQ(outer.node_def_size(), 1); +} + +// Before: +// +// +------+ +// +----------------------+ Arg0 +----------------------+ +// | +---+--+ | +// | 
| | +// | +---v--+ | +// | +------------------+ Arg0 +------------------+ | +// | | +---+--+ | | +// | | | | | +// | | +---+--+ | | +// | | | Cast | | | +// | | +---+--+ | | +// | | | | | +// | | +---v---+ num=3 | | +// | | |Unstack| axis=0 | | +// | | ++--+--++ | | +// | | | | | | | +// | | +----+ | +-------+ | | +// | | | | | | | +// | | MapDefun +---v--+ +-v----+ +--v---+ | | +// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ | +// | +---+--+ +--+---+ +--+---+ | +// | | | | | +// | +---v--+ +--v---+ +--v---+ | +// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+ +// +------+ +------+ +------+ +// +// +// After: +// +// +------+ +// +----------------------+ Arg0 +----------------------+ +// | +---+--+ | +// | | | +// | +---+--+ | +// | | Cast | | +// | +---+--+ | +// | | | +// | +---v---+ num=3 | +// | |Unstack| axis=1 | +// | ++--+--++ | +// | | | | | +// | +----+ | +-------+ | +// | | | | | +// | | | | | +// | +---v--+ +-v----+ +--v---+ | +// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+ +// +------+ +------+ +------+ +// +TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) { + FunctionDef inner = CreateFunction( + "inner_function", {{"arg0", DT_INT32}}, + {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}}, + {{"ret0", "MyUnstack:output:0"}, + {"ret1", "MyUnstack:output:1"}, + {"ret2", "MyUnstack:output:2"}}); + NodeDef* cast_op = + AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner); + CHECK_NOTNULL(cast_op); + NodeDef* unstack_op = + AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner); + CHECK_NOTNULL(unstack_op); + + FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}}, + {{"mapdefun", DT_INT32}, + {"mapdefun_0", DT_INT32}, + {"mapdefun_1", DT_INT32}}, + {{"mapdefun", "MapDefun:output:0"}, + {"mapdefun_0", "MapDefun:output:1"}, + {"mapdefun_1", "MapDefun:output:2"}}); + + NodeDef* map_defun = AddMapDefunNode( + "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32}, + 
{{1}, {1}, {1}}, inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + VectorizeMapDefun(&outer, &inner, map_defun); + EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); + const NodeDef& cast_node = + outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + EXPECT_EQ(cast_node.input(0), "x"); + const NodeDef& unpack_node = + outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer)); + EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0")); + EXPECT_EQ(unpack_node.attr().at("axis").i(), 1); + EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32); + EXPECT_EQ(unpack_node.attr().at("num").i(), 3); + + EXPECT_EQ(outer.ret().at("mapdefun"), + strings::StrCat(unpack_node.name(), ":output:0")); + EXPECT_EQ(outer.ret().at("mapdefun_0"), + strings::StrCat(unpack_node.name(), ":output:1")); + EXPECT_EQ(outer.ret().at("mapdefun_1"), + strings::StrCat(unpack_node.name(), ":output:2")); + EXPECT_EQ(outer.node_def_size(), 2); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +---+--+ | | +// | | +---------+ | | +// | | +---v--+ | | | +// | | |Print | | | | +// | | +---+--+ | | | +// | | : +---v--+ | | +// | | ::::::> Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// No change because we don't deal with control inputs for now. 
+// +TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { + FunctionDef inner = + CreateFunction("inner_function", {{"arg0", DT_INT32}}, + {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}}); + // The attrs aren't relevant + NodeDef* print_op = + function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner); + CHECK_NOTNULL(print_op); + NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64, + false, &inner); + CHECK_NOTNULL(cast_op); + + FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}}, + {{"mapdefun", DT_INT64}}, + {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + CHECK_NOTNULL(map_defun); + + FunctionDef outer_copy(outer); + FunctionDef inner_copy(inner); + VectorizeMapDefun(&outer, &inner, map_defun); + // They should be unchanged + EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer)); +} + +// TODO(rachelim): More test cases when we get around to implementing them: +// [] A badly defined converter, e.g. doesn't produce nodes that have the +// same number of outputs/inputs as the nodes to be converted +// [] Converter where the 'converted' form has multiple nodes. +// [] Case with dependent nodes, e.g. ops with const inputs that are +// broadcasted. +// [] Python-side tests to actually run the functions to make sure +// they work. + +} // namespace +} // namespace vectorization_utils +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc index 9701a038d0..800160e649 100644 --- a/tensorflow/core/grappler/optimizers/debug_stripper.cc +++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc @@ -38,7 +38,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item, // be optimized away by dependency optimizer. 
for (string& inp : *node.mutable_input()) { if (!IsControlInput(inp)) { - inp = AsControlDependency(inp); + inp = AsControlDependency(NodeName(inp)); } } } else if (IsCheckNumerics(node) || IsPrint(node)) { @@ -54,7 +54,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item, // input. for (size_t i = 1; i < node.input_size(); ++i) { if (!IsControlInput(node.input(i))) { - *node.mutable_input(i) = AsControlDependency(node.input(i)); + *node.mutable_input(i) = AsControlDependency(NodeName(node.input(i))); } } } diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc index 96ceee791f..affd2d51c2 100644 --- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc +++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc @@ -43,6 +43,35 @@ TEST_F(DebugStripperTest, OutputEqualToInput) { CompareGraphs(item.graph, output); } +TEST_F(DebugStripperTest, StripAssertOnTwoOutputs) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape({6})); + auto split = + ops::Split(s.WithOpName("split"), /*axis=*/0, input, /*num_split=*/2); + Output x = split[0]; + Output y = split[1]; + Output ge = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y); + auto assert = ops::Assert(s.WithOpName("Assert"), ge, {x, y}); + Output add = ops::Add( + s.WithOpName("add").WithControlDependencies({assert.operation}), x, y); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + DebugStripper optimizer; + GraphDef output; + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + for (const NodeDef& node : output.node()) { + for (const string& input : node.input()) { + if (IsControlInput(input)) { + EXPECT_EQ(input.find(':'), -1); + } + } + } +} + TEST_F(DebugStripperTest, StripAssertFromGraph) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output x = 
ops::Placeholder(s.WithOpName("x"), DT_FLOAT, diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc new file mode 100644 index 0000000000..2c36c9b7b3 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc @@ -0,0 +1,111 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h" + +#include <string> + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/function_api_info.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace grappler { + +Status ExperimentalImplementationSelector::LoadFunctions( + const GraphDef& graph) { + lib_info_.reset(new FunctionLibraryApiInfo); + 
TF_RETURN_IF_ERROR(lib_info_->Init(graph.library()));
+  return Status::OK();
+}
+
+// Rewrites `node_def` in place so that any function it calls -- through its
+// op name or through a func-typed attribute -- is replaced by the
+// implementation that lib_info_ reports as best for the node's device type.
+// Nodes that call no function known to lib_info_ are left untouched.
+// Returns an Internal error if the node's device name cannot be split.
+Status ExperimentalImplementationSelector::MaybeOptimizeFunctionCall(
+    NodeDef* node_def) const {
+  // There are two ways of calling functions:
+  //  1. By specifying an op name as a function name, or
+  //  2. Via the @defun functional interface, where the real function name
+  //     appears as the attribute with type func.
+  std::vector<string> function_attribute_names;
+  for (const auto& attr : node_def->attr()) {
+    if (attr.second.has_func() &&
+        lib_info_->GetApiInfo(attr.second.func().name()) != nullptr) {
+      function_attribute_names.emplace_back(attr.first);
+    }
+  }
+
+  if (function_attribute_names.empty() &&
+      lib_info_->GetApiInfo(node_def->op()) == nullptr) {
+    // A regular op, or a function which has no interface.
+    return Status::OK();
+  }
+
+  string task, device;
+  if (!DeviceNameUtils::SplitDeviceName(node_def->device(), &task, &device)) {
+    return errors::Internal("Could not split device name: ",
+                            node_def->device());
+  }
+  VLOG(2) << "Op " << node_def->name() << " runs on " << node_def->device()
+          << " = (" << task << ", " << device << ")";
+  DeviceNameUtils::ParsedName parsed_name;
+  DeviceNameUtils::ParseLocalName(device, &parsed_name);
+
+  // Case 2: swap the function named by each func-typed attribute.
+  for (const auto& attr_name : function_attribute_names) {
+    string function_name = node_def->attr().at(attr_name).func().name();
+    string best_function_name;
+    lib_info_->GetBestImplementation(function_name, parsed_name.type,
+                                     &best_function_name);
+    if (function_name != best_function_name) {
+      node_def->mutable_attr()
+          ->find(attr_name)
+          ->second.mutable_func()
+          ->set_name(best_function_name);
+    }
+  }
+  // Case 1: the op name itself is a function with a registered interface.
+  if (lib_info_->GetApiInfo(node_def->op()) != nullptr) {
+    string best_function_name;
+    lib_info_->GetBestImplementation(node_def->op(), parsed_name.type,
+                                     &best_function_name);
+    if (node_def->op() != best_function_name) {
+      node_def->set_op(best_function_name);
+    }
+  }
+  return Status::OK();
+}
+
+Status 
ExperimentalImplementationSelector::SelectImplementation( + GraphDef* graph) const { + for (int k = 0; k < graph->node_size(); ++k) + TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph->mutable_node(k))); + + return Status::OK(); +} + +Status ExperimentalImplementationSelector::Optimize(Cluster* cluster, + const GrapplerItem& item, + GraphDef* optimized_graph) { + *optimized_graph = item.graph; + TF_RETURN_IF_ERROR(LoadFunctions(*optimized_graph)); + return SelectImplementation(optimized_graph); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h new file mode 100644 index 0000000000..82f7473a14 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h @@ -0,0 +1,115 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_ + +#include <string> + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/function_api_info.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace grappler { + +// -- EXPERIMENTAL -- +// This transformation replaces function calls by the appropriate function +// definition based on properties of the runtime system. For instance, +// we may choose one implementation over another if we have a GPU with +// enough memory available. +// +// It is a way for the programmer to specify alternative implementations +// of the same functionality in the graph, and let TensorFlow pick the +// most appropriate one at runtime. 
+// +// For instance, the python code might specify: +// @Defun(tf.float32, +// experimental_api_implements='plus_one', +// experimental_api_preferred_device='GPU') +// def plus_one_gpu(x): return x + 1.0 +// +// @Defun(tf.float32, +// experimental_api_implements='plus_one') +// def plus_one_reference_implementation(x): return x + 1.0 +// input = tf.constant(2.0, dtype=tf.float32) +// +// z = plus_one_reference_implementation(input) +// z = plus_one_gpu(input) +// print(sess.run(z)) +// +// At runtime, we will trim either `plus_one_gpu` or +// `plus_one_reference_implementation` based on the availability of the GPU. +// +// Available annotations: +// - experimental_api_implements(string): all functions mapping to the same +// string can be interchanged. For now, all functions must have the same +// signature and overloads are not allowed. Defuns within defuns are +// allowed. +// - experimental_api_preferred_device(string): sets which device is preferred. +class ExperimentalImplementationSelector : public CustomGraphOptimizer { + public: + ExperimentalImplementationSelector() = default; + ~ExperimentalImplementationSelector() override = default; + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + string name() const override { + return "experimental_implementation_selector"; + } + + // This call is not thread-safe. + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + // Does not take any feedback. + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} + + private: + Status LoadFunctions(const GraphDef& graph); + Status MaybeOptimizeFunctionCall(NodeDef* node_def) const; + + // Finds all call sites for functions, then replace with the appropriate + // implementation. + // There are two ways of calling functions: + // 1. By specifying an op name as a function name, and + // 2. 
Via the functional interface, where the function name appears as an + // Attr. + // + // There may be multiple call sites for a given function. The function body + // may call into another function, so a function might have to be duplicated. + // For simplicity, we do not change function bodies. Also, we do not change + // gradients. + Status SelectImplementation(GraphDef* graph) const; + + std::unique_ptr<FunctionLibraryApiInfo> lib_info_; + + TF_DISALLOW_COPY_AND_ASSIGN(ExperimentalImplementationSelector); +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_ diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc new file mode 100644 index 0000000000..3f1ebefac6 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc @@ -0,0 +1,138 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h" + +#include <algorithm> +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +constexpr char CpuDevice[] = "/device:CPU:0"; +constexpr char GpuDevice[] = "/device:GPU:0"; + +class ExperimentalImplementationSelectorTest : public GrapplerTest {}; + +TEST_F(ExperimentalImplementationSelectorTest, NoUpdate) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {CpuDevice}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + std::unique_ptr<CustomGraphOptimizer> optimizer( + new ExperimentalImplementationSelector); + ASSERT_NE(nullptr, optimizer); + TF_ASSERT_OK(optimizer->Init()); + + GraphDef output; + const Status status = optimizer->Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + // This is a trivial graph so there is nothing to update. 
+ EXPECT_EQ(item.graph.node_size(), output.node_size()); +} + +TEST_F(ExperimentalImplementationSelectorTest, SwapImplementation) { + using test::function::NDef; + auto cpu_def = test::function::XTimesTwo(); + auto* func_attr = cpu_def.mutable_attr(); + (*func_attr)["experimental_api_implements"].set_s("times_two"); + (*func_attr)["experimental_api_preferred_device"].set_s("CPU"); + + auto gpu_def = test::function::XAddX(); + auto* func2_attr = gpu_def.mutable_attr(); + (*func2_attr)["experimental_api_implements"].set_s("times_two"); + (*func2_attr)["experimental_api_preferred_device"].set_s("GPU"); + + ExperimentalImplementationSelector optimizer; + GraphDef output; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, GpuDevice), + NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice), + NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice), + NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice), + NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)}, + // FunctionLib + {cpu_def, gpu_def}); + + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_EQ(output.node_size(), 5); + for (const NodeDef& node : output.node()) { + if (node.name() == "y1") { + // Make sure the implementation has been swapped to use the GPU version. + EXPECT_EQ("XAddX", node.op()); + } else if (node.name() == "y2") { + // Make sure the implementation is not changed. 
+ EXPECT_EQ("XTimesTwo", node.op()); + } + } +} + +TEST_F(ExperimentalImplementationSelectorTest, SwapImplementationEval) { + using test::function::NDef; + auto cpu_def = test::function::XTimesTwo(); + auto* func_attr = cpu_def.mutable_attr(); + (*func_attr)["experimental_api_implements"].set_s("random_boost"); + (*func_attr)["experimental_api_preferred_device"].set_s("CPU"); + + auto gpu_def = test::function::XTimesFour(); + auto* func2_attr = gpu_def.mutable_attr(); + (*func2_attr)["experimental_api_implements"].set_s("random_boost"); + (*func2_attr)["experimental_api_preferred_device"].set_s("GPU"); + + ExperimentalImplementationSelector optimizer; + GraphDef output; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, CpuDevice), + NDef("y", "XTimesFour", {"x"}, {{"T", DT_FLOAT}}, CpuDevice), + NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, CpuDevice)}, + // FunctionLib + {cpu_def, gpu_def}); + + const Tensor input = test::AsScalar<float>(1.0f); + item.fetch = {"z"}; + item.feed.emplace_back("x", input); + + const auto four_times_boosted_tensor = EvaluateFetchNodes(item); + test::ExpectTensorEqual<float>(four_times_boosted_tensor[0], + test::AsScalar<float>(4.0f)); + + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + GrapplerItem optimized(item, std::move(output)); + const auto twice_boosted_tensor = EvaluateFetchNodes(optimized); + test::ExpectTensorEqual<float>(twice_boosted_tensor[0], + test::AsScalar<float>(2.0f)); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/function_api_info.cc b/tensorflow/core/grappler/optimizers/function_api_info.cc new file mode 100644 index 0000000000..798e0f6fd5 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_api_info.cc @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/function_api_info.h" + +#include <string> +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { +FunctionApiInfo::FunctionApiInfo() {} +FunctionApiInfo::~FunctionApiInfo() {} + +Status FunctionApiInfo::Init(const FunctionDef& function_def) { + for (const auto& attr : function_def.attr()) { + if (attr.first == "experimental_api_preferred_device") { + preferred_device_ = attr.second.s(); + } + if (attr.first == "experimental_api_implements") { + interface_name_ = attr.second.s(); + } + } + if (interface_name_.empty() && !preferred_device_.empty()) { + return errors::InvalidArgument( + "Function '", function_def.signature().name(), + "' has a preferred device, but does not implement an interface"); + } + return Status::OK(); +} + +const string& FunctionApiInfo::preferred_device() const { + return preferred_device_; +} + +const string& FunctionApiInfo::interface_name() const { + return interface_name_; +} + +FunctionLibraryApiInfo::FunctionLibraryApiInfo() {} +FunctionLibraryApiInfo::~FunctionLibraryApiInfo() {} + +namespace { +bool IsSameSignature(const FunctionDef& f1, const FunctionDef& f2) { + if (f1.ret().size() != 
f2.ret().size()) return false; + const auto& sig1 = f1.signature(); + const auto& sig2 = f2.signature(); + // Functions have positional semantics, so we don't check for names. + if (sig1.input_arg_size() != sig2.input_arg_size()) return false; + for (int k = 0; k < sig1.input_arg_size(); ++k) { + const OpDef::ArgDef& arg1 = sig1.input_arg(k); + const OpDef::ArgDef& arg2 = sig2.input_arg(k); + if (arg1.type() != arg2.type()) return false; + if (arg1.type_attr() != arg2.type_attr()) return false; + if (arg1.number_attr() != arg2.number_attr()) return false; + if (arg1.type_list_attr() != arg2.type_list_attr()) return false; + if (arg1.is_ref() != arg2.is_ref()) return false; + } + return true; +} + +Status ValidateSignature(const string& interface_name, + const std::vector<const FunctionDef*>& equiv_funcs) { + if (equiv_funcs.size() < 2) return Status::OK(); + for (size_t k = 1; k < equiv_funcs.size(); ++k) { + if (!IsSameSignature(*equiv_funcs[0], *equiv_funcs[k])) + return errors::InvalidArgument( + "Functions '", equiv_funcs[0]->signature().name(), "' and '", + equiv_funcs[k]->signature().name(), "' both implement '", + interface_name, "' but their signatures do not match."); + } + return Status::OK(); +} + +Status ValidateSignatures( + const std::unordered_map<string, std::vector<const FunctionDef*>>& + intf_to_func) { + for (const auto& item : intf_to_func) + TF_RETURN_IF_ERROR(ValidateSignature(item.first, item.second)); + return Status::OK(); +} +} // namespace + +Status FunctionLibraryApiInfo::Init( + const FunctionDefLibrary& function_library) { + std::unordered_map<string, std::vector<const FunctionDef*>> intf_to_func; + for (const auto& function : function_library.function()) { + std::unique_ptr<FunctionApiInfo> func_info(new FunctionApiInfo); + TF_RETURN_IF_ERROR(func_info->Init(function)); + // Ignore the function if it does not implement any interface. 
+ if (func_info->interface_name().empty()) continue; + + const string& function_name = function.signature().name(); + const string& interface_name = func_info->interface_name(); + func_to_intf_[function_name] = interface_name; + intf_to_funcs_[interface_name].emplace_back(function_name); + intf_to_func[interface_name].emplace_back(&function); + func_info_[function_name] = std::move(func_info); + } + TF_RETURN_IF_ERROR(ValidateSignatures(intf_to_func)); + return Status::OK(); +} + +void FunctionLibraryApiInfo::GetEquivalentImplementations( + const string& function_name, std::vector<string>* other_names) const { + const auto intf_it = func_to_intf_.find(function_name); + // The function does not implement any interface. + if (intf_it == func_to_intf_.end()) return; + CHECK(!intf_it->second.empty()) << "Function " << function_name + << " should at least implement 1 interface."; + const auto it = intf_to_funcs_.find(intf_it->second); + CHECK(it != intf_to_funcs_.end()) + << "Function " << function_name << " maps to " << intf_it->second + << " but no reverse mapping was found"; + CHECK_GE(it->second.size(), 1) << "Class " << it->first << " is empty"; + other_names->reserve(it->second.size() - 1); + for (const auto& other_name : it->second) { + if (other_name == function_name) continue; + other_names->emplace_back(other_name); + } +} + +void FunctionLibraryApiInfo::GetBestImplementation( + const string& function_name, const string& device, + string* best_func_name) const { + CHECK(best_func_name != nullptr); + const auto func_it = func_to_intf_.find(function_name); + if (func_it == func_to_intf_.end()) return; + + const auto it = intf_to_funcs_.find(func_it->second); + // No function found for the given interface. 
+ if (it == intf_to_funcs_.end()) return; + for (const auto& func_name : it->second) { + const auto func_api_info = func_info_.find(func_name)->second.get(); + if (func_api_info->preferred_device() == device) { + best_func_name->assign(func_name); + return; + } + } + // Didn't find a function with the match device name, choose the first one + // among all the available functions. + best_func_name->assign(it->second.front()); +} + +const FunctionApiInfo* FunctionLibraryApiInfo::GetApiInfo( + const string& function_name) const { + const auto it = func_info_.find(function_name); + if (it == func_info_.end()) return nullptr; + return it->second.get(); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/function_api_info.h b/tensorflow/core/grappler/optimizers/function_api_info.h new file mode 100644 index 0000000000..412687c58c --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_api_info.h @@ -0,0 +1,80 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_ + +#include <memory> +#include <string> +#include <unordered_map> +#include <vector> + +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { +class FunctionApiInfo { + public: + FunctionApiInfo(); + virtual ~FunctionApiInfo(); + + Status Init(const FunctionDef& function_def); + + const string& interface_name() const; + const string& preferred_device() const; + + private: + string interface_name_; + string preferred_device_; + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionApiInfo); +}; + +// A collection of information for function and the interface it implements. +// A interface is a well defined math operation, eg I1 = 2 * x + y. Multiple +// functions could implement the same interface with different behavior based on +// different hardware condition and limits, +// eg F1 = math_ops.add(math_ops.add(x, x), y), or +// F2 = math_ops.add(math_ops.matmul(x, 2), y). +class FunctionLibraryApiInfo { + public: + FunctionLibraryApiInfo(); + virtual ~FunctionLibraryApiInfo(); + // Populate the internal field for the functions within the function_library. + Status Init(const FunctionDefLibrary& function_library); + + void GetEquivalentImplementations(const string& function_name, + std::vector<string>* other_names) const; + + void GetBestImplementation(const string& function_name, const string& device, + string* best_func_name) const; + + const FunctionApiInfo* GetApiInfo(const string& function_name) const; + + private: + // Map between function name to function details. + std::unordered_map<string, std::unique_ptr<FunctionApiInfo>> func_info_; + // Map between function name to interface name. 
+ std::unordered_map<string, string> func_to_intf_; + // Map between interface name to function names. + std::unordered_map<string, std::vector<string>> intf_to_funcs_; + + TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryApiInfo); +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_ diff --git a/tensorflow/core/grappler/optimizers/function_api_info_test.cc b/tensorflow/core/grappler/optimizers/function_api_info_test.cc new file mode 100644 index 0000000000..582890d3e3 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_api_info_test.cc @@ -0,0 +1,160 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/function_api_info.h" + +#include <string> +#include <unordered_set> +#include <vector> + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { +void SetArg(const string& name, const string& type_name, + OpDef::ArgDef* arg_def) { + arg_def->set_name(name); + arg_def->set_type_attr(type_name); +} + +typedef std::pair<string, string> ArgSpec; // name, type. 
+ +void SetArgs(const std::vector<ArgSpec>& args_spec, OpDef* sig) { + for (const auto& arg_spec : args_spec) + SetArg(arg_spec.first, arg_spec.second, sig->add_input_arg()); + SetArg("output", "float32", sig->add_output_arg()); +} + +void PopulateFunction(const string& name, const string& api_interface_name, + const string& preferred_device, + const std::vector<ArgSpec>& input_args, + FunctionDef* func_def) { + OpDef* sig = func_def->mutable_signature(); + sig->set_name(name); + + SetArgs(input_args, sig); + + if (!api_interface_name.empty() || !preferred_device.empty()) { + auto* func_attr = func_def->mutable_attr(); + if (!api_interface_name.empty()) + (*func_attr)["experimental_api_implements"].set_s(api_interface_name); + if (!preferred_device.empty()) + (*func_attr)["experimental_api_preferred_device"].set_s(preferred_device); + } +} + +void PopulateSampleLibrary(const bool mismatch_args, + FunctionDefLibrary* func_lib) { + const std::vector<ArgSpec> func_args{{"in1", "float32"}, {"in2", "int32"}}; + const std::vector<ArgSpec> func_wrong_args{{"in1", "int32"}, + {"in2", "int32"}}; + PopulateFunction("DoStuffCpu", "DoStuff", "CPU", func_args, + func_lib->add_function()); + PopulateFunction("DoStuffGpu", "DoStuff", "GPU", + mismatch_args ? 
func_wrong_args : func_args, + func_lib->add_function()); + PopulateFunction("DoThings", "DoThings", "", func_args, + func_lib->add_function()); + PopulateFunction("OneOff", "", "", func_args, func_lib->add_function()); + PopulateFunction("AnotherOneOff", "", "", func_args, + func_lib->add_function()); +} + +bool CheckEquivImpl(const FunctionLibraryApiInfo& lib_api_info, + const string& func_name, + const std::vector<string>& expected_other) { + std::vector<string> other_impl; + lib_api_info.GetEquivalentImplementations(func_name, &other_impl); + const std::unordered_set<string> actual(other_impl.begin(), other_impl.end()); + const std::unordered_set<string> expected(expected_other.begin(), + expected_other.end()); + return actual == expected; +} + +bool CheckGetBestImpl(const FunctionLibraryApiInfo& lib_api_info, + const string& function_name, const string& device, + const string& expected_function_name) { + string best_function_name; + lib_api_info.GetBestImplementation(function_name, device, + &best_function_name); + + return best_function_name == expected_function_name; +} + +string GetInterfaceName(const FunctionLibraryApiInfo& lib_api_info, + const string& func_name) { + auto* info = lib_api_info.GetApiInfo(func_name); + CHECK_NOTNULL(info); + return info->interface_name(); +} + +string GetPreferredDevice(const FunctionLibraryApiInfo& lib_api_info, + const string& func_name) { + auto* info = lib_api_info.GetApiInfo(func_name); + CHECK_NOTNULL(info); + return info->preferred_device(); +} + +TEST(FunctionApiInfoTest, ParseTags) { + FunctionDefLibrary func_lib; + PopulateSampleLibrary(/* mismatch_args */ false, &func_lib); + FunctionLibraryApiInfo lib_api_info; + TF_ASSERT_OK(lib_api_info.Init(func_lib)); + EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffCpu", {"DoStuffGpu"})); + EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffGpu", {"DoStuffCpu"})); + EXPECT_TRUE(CheckEquivImpl(lib_api_info, "Undefined", {})); + EXPECT_TRUE(CheckEquivImpl(lib_api_info, 
"OneOff", {})); + EXPECT_TRUE(CheckEquivImpl(lib_api_info, "AnotherOneOff", {})); + EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoThings", {})); + + EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffCpu")); + EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffGpu")); + EXPECT_EQ("DoThings", GetInterfaceName(lib_api_info, "DoThings")); + + EXPECT_EQ("CPU", GetPreferredDevice(lib_api_info, "DoStuffCpu")); + EXPECT_EQ("GPU", GetPreferredDevice(lib_api_info, "DoStuffGpu")); + EXPECT_EQ("", GetPreferredDevice(lib_api_info, "DoThings")); + + EXPECT_TRUE( + CheckGetBestImpl(lib_api_info, "DoStuffCpu", "CPU", "DoStuffCpu")); + EXPECT_TRUE( + CheckGetBestImpl(lib_api_info, "DoStuffCpu", "GPU", "DoStuffGpu")); + EXPECT_TRUE( + CheckGetBestImpl(lib_api_info, "DoStuffGpu", "CPU", "DoStuffCpu")); + EXPECT_TRUE( + CheckGetBestImpl(lib_api_info, "DoStuffGpu", "GPU", "DoStuffGpu")); + + EXPECT_TRUE(CheckGetBestImpl(lib_api_info, "DoThings", "GPU", "DoThings")); + // TPU impl is not available, choose the first one available which is the CPU. + EXPECT_TRUE( + CheckGetBestImpl(lib_api_info, "DoStuffGpu", "TPU", "DoStuffCpu")); +} + +TEST(FunctionApiInfoTest, MismatchedArguments) { + FunctionDefLibrary func_lib; + PopulateSampleLibrary(/* mismatch_args */ true, &func_lib); + FunctionLibraryApiInfo lib_api_info; + const Status ret = lib_api_info.Init(func_lib); + EXPECT_FALSE(ret.ok()); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 5fd34efeb1..c59645e5f2 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -23,11 +23,13 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" #include "tensorflow/core/grappler/optimizers/debug_stripper.h" #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" +#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h" #include "tensorflow/core/grappler/optimizers/function_optimizer.h" #include "tensorflow/core/grappler/optimizers/layout_optimizer.h" #include "tensorflow/core/grappler/optimizers/loop_optimizer.h" #include "tensorflow/core/grappler/optimizers/memory_optimizer.h" #include "tensorflow/core/grappler/optimizers/model_pruner.h" +#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h" #include "tensorflow/core/grappler/optimizers/remapper.h" #include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h" #include "tensorflow/core/grappler/optimizers/shape_optimizer.h" @@ -72,6 +74,16 @@ bool IsRunOnceOptimizer(const string& name) { name == "loop_optimizer"; } +// Check if the graphdef contains nodes that indicate TPU execution. 
+bool IsTPUGraphDef(const GraphDef& def) { + for (const auto& node : def.node()) { + if (node.op() == "TPUCompile" || node.op() == "TPUPartitionedCall") { + return true; + } + } + return false; +} + } // namespace #define MK_OPT(NAME, VALUE) \ @@ -94,6 +106,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer( MK_OPT("scoped_allocator", new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts())); + MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization())); return std::unique_ptr<GraphOptimizer>(); } @@ -122,6 +135,9 @@ Status MetaOptimizer::InitializeOptimizers( if (cfg_.remapping() != RewriterConfig::OFF) { optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping())); } + if (cfg_.pin_to_host_optimization() != RewriterConfig::OFF) { + optimizers->push_back(MakeUnique<PinToHostOptimizer>()); + } if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) { optimizers->push_back( MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization())); @@ -156,7 +172,7 @@ Status MetaOptimizer::InitializeOptimizers( optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>( cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts())); } - return Status::OK(); + return InitializeCustomGraphOptimizers(optimizers); } Status MetaOptimizer::InitializeOptimizersByName( @@ -180,15 +196,40 @@ Status MetaOptimizer::InitializeOptimizersByName( VLOG(2) << "Can't register an optimizer by name: " << optimizer_name; } } + return InitializeCustomGraphOptimizers(optimizers); +} + +Status MetaOptimizer::InitializeCustomGraphOptimizers( + std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const { for (const auto& optimizer_config : cfg_.custom_optimizers()) { - auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull( - optimizer_config.name()); + // Initialize the ExperimentalImplementationSelector here instead of + // CustomizeOptimizer registry, due to the static link issue in TensorRT for + // 
double registry. + // TODO(laigd): Remove this hack and change it back to use the registry once + // the duplicate static import issue is fixed. + std::unique_ptr<CustomGraphOptimizer> custom_optimizer; + if (optimizer_config.name() == "ExperimentalImplementationSelector") { + custom_optimizer.reset(new ExperimentalImplementationSelector()); + } else { + custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull( + optimizer_config.name()); + } if (custom_optimizer) { VLOG(2) << "Registered custom configurable graph optimizer: " << optimizer_config.name(); TF_RETURN_IF_ERROR(custom_optimizer->Init(&optimizer_config)); optimizers->push_back(std::move(custom_optimizer)); } else { + // If there are no custom optimizers with given name, try to initialize a + // default optimizer. This way, custom configurable optimizers can be + // mixed with default optimizers in any order. + auto optimizer = MakeNewOptimizer(optimizer_config.name()); + if (optimizer) { + VLOG(2) << "Registered default graph optimizer: " + << optimizer_config.name(); + optimizers->push_back(std::move(optimizer)); + continue; + } VLOG(2) << "Can't register an optimizer by name: " << optimizer_config.name(); } @@ -208,7 +249,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, } std::vector<std::unique_ptr<GraphOptimizer>> optimizers; - if (cfg_.optimizers().empty() && cfg_.custom_optimizers().empty()) { + if (cfg_.optimizers().empty()) { TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers)); } else { TF_RETURN_IF_ERROR(InitializeOptimizersByName(&optimizers)); @@ -326,10 +367,25 @@ Status MetaOptimizer::RunOptimizer( Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { + VLOG(1) << "Starting optimization for grappler item: " << item.id; optimization_results_.clear(); // 1. 
Optimize main graph TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph)); + VLOG(1) << "Optimized main graph."; + + // Skip optimizing functions if this is a TPU graph. Currently, Grappler + // passes do not handle TPU functions correctly in a variety of ways (Note + // that due to the pre-placement TPU graph rewriting passes, the TPU-related + // ops are encapsulated away into functions). For example, TPU graphs contain + // TPUReplicateMetadata node that carries relevant TPU metadata and Grappler + // passes could prune that away. Grappler passes could also cause issues + // around shape inference. Since the desired and existing behavior is to not + // optimize TPU functions with Grappler, this check preserves that. + if (IsTPUGraphDef(*optimized_graph)) { + VLOG(2) << "Skipping optimizing funcs for TPU graphs"; + return Status::OK(); + } // 2. Optimize function library FunctionLibraryDefinition flib(OpRegistry::Global(), @@ -393,7 +449,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } } - VLOG(3) << "Optimized " << optimized_funcs.size() + VLOG(1) << "Optimized " << optimized_funcs.size() << " functions: " << str_util::Join(optimized_funcs, ", "); return Status::OK(); @@ -427,6 +483,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) { cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT || cfg.debug_stripper() == RewriterConfig::ON || cfg.scoped_allocator_optimization() == RewriterConfig::ON || + cfg.pin_to_host_optimization() != RewriterConfig::OFF || !cfg.optimizers().empty() || !cfg.custom_optimizers().empty(); } diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h index 151a54cbdf..831c5e37c0 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.h +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h @@ -52,6 +52,9 @@ class MetaOptimizer : public GraphOptimizer { // Initialize active optimizers from RewriterConfig 
optimizer names. Status InitializeOptimizersByName( std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; + // Initialize active optimizers from RewriterConfig.custom_optimizers. + Status InitializeCustomGraphOptimizers( + std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; // Run optimization pass over a single GrapplerItem. Meta optimizer might run // multiple such passes: 1) for the main graph 2) for the function library diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index 9a03c7dfef..e74e0f7501 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -64,6 +64,13 @@ bool TestOptimizer::optimized_; REGISTER_GRAPH_OPTIMIZER(TestOptimizer); +class TestGraphOptimizer : public TestOptimizer { + public: + string name() const override { return "test_graph_optimizer"; } +}; + +REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer); + class MetaOptimizerTest : public GrapplerTest {}; TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { @@ -83,6 +90,27 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { EXPECT_TRUE(TestOptimizer::IsOptimized()); } +TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + TestOptimizer::SetOptimized(false); + TestGraphOptimizer::SetOptimized(false); + RewriterConfig rewriter_config; + rewriter_config.add_optimizers("TestOptimizer"); + auto customGraphOptimizer = rewriter_config.add_custom_optimizers(); + customGraphOptimizer->set_name("TestGraphOptimizer"); + rewriter_config.set_min_graph_nodes(-1); + + MetaOptimizer optimizer(nullptr, rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_TRUE(TestOptimizer::IsOptimized()); + 
EXPECT_TRUE(TestGraphOptimizer::IsOptimized()); +} + TEST_F(MetaOptimizerTest, RunOptimizersTwice) { TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); GrapplerItem item; @@ -98,6 +126,24 @@ TEST_F(MetaOptimizerTest, RunOptimizersTwice) { TF_EXPECT_OK(status); } +TEST_F(MetaOptimizerTest, RunToggleOptimizersAndCustomGraphOptimizerTwice) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + RewriterConfig rewriter_config; + auto customGraphOptimizer = rewriter_config.add_custom_optimizers(); + customGraphOptimizer->set_name("TestGraphOptimizer"); + rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO); + rewriter_config.set_min_graph_nodes(-1); + + MetaOptimizer optimizer(nullptr, rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_TRUE(TestGraphOptimizer::IsOptimized()); +} + TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) { using test::function::NDef; diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc new file mode 100644 index 0000000000..2190d38937 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc @@ -0,0 +1,264 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h" + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/tensor_shape.pb.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/grappler/graph_view.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/utils/symbolic_shapes.h" +#include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/str_util.h" + +namespace tensorflow { +namespace grappler { +namespace internal { + +// TODO(williamchan): Change this constant to be something smarter, maybe +// dynamically determined. +constexpr int64 kTensorMaxSize = 64; + +// Find KernelDef for `node`. +Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) { + // Try find KernelDef for node.device, else GPU or CPU. + for (const DeviceType& device : + {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) { + Status s = FindKernelDef(device, node, kdef, nullptr); + if (s.ok()) { + return Status::OK(); + } + } + + return errors::NotFound("Could not find KernelDef for op: ", node.op()); +} + +// Check if all node's inputs are pinned to CPU memory. +bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) { + // Loop through all the inputs excluding the controlling nodes. + for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) { + // Check if (the fanin) op's device is on CPU. + if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) { + continue; + } + + // Check if (the fanin) op's output port is pinned to HostMemory. 
+ const OpDef* fanin_odef = nullptr; + Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef); + if (!s.ok()) { + LOG(INFO) << "Could not find OpDef for : " << fanin.node->op(); + return false; + } + + const int output_arg_id = + OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id); + if (output_arg_id < 0) { + LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n" + << node.DebugString() << "\n" + << fanin.node->DebugString() << "\n" + << fanin_odef->DebugString(); + return false; + } + + const KernelDef* fanin_kdef = nullptr; + s = TryFindKernelDef(*fanin.node, &fanin_kdef); + if (!s.ok()) { + LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op(); + return false; + } + + bool fanin_pinned = false; + for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) { + if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) { + fanin_pinned = true; + break; + } + } + + if (!fanin_pinned) { + return false; + } + } + + return true; +} + +bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) { + // Check if Tensor is integer and small size. + + // Check type to be int32 or int64. + if (prop.dtype() != DataType::DT_INT32 && + prop.dtype() != DataType::DT_INT64) { + return false; + } + + // Check size known and small. + const int64 size = NumCoefficients(prop.shape()); + if (size < 0 || size > kTensorMaxSize) { + return false; + } + + return true; +} + +bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties, + const NodeDef& node) { + for (const auto& prop : properties.GetInputProperties(node.name())) { + if (!IsTensorIntegerAndSmall(prop)) { + return false; + } + } + + for (const auto& prop : properties.GetOutputProperties(node.name())) { + if (!IsTensorIntegerAndSmall(prop)) { + return false; + } + } + return true; +} + +string TryFindHostDevice(const gtl::FlatSet<string>& devices, + bool has_device_cpu, const string& device) { + // Force this node onto the CPU. 
+ if (device.empty() && has_device_cpu) { + return "/device:CPU:0"; + } else if (str_util::StrContains(device, DEVICE_GPU)) { + // Sometimes the cluster can have: + // devices = {"/device:CPU:0", "/device:XLA_GPU:0"} + // and we need to handle them properly. + for (const auto& device_match : + {std::pair<string, string>("GPU", "CPU:0"), + std::pair<string, string>("/device", "/device:CPU:0")}) { + const string device_host = + strings::StrCat(device.substr(0, device.rfind(device_match.first)), + device_match.second); + if (devices.find(device_host) != devices.end()) { + return device_host; + } + } + } + + // We couldn't find an appropriate Host device, return original device. + return device; +} + +bool IsTPUGraphDef(const GraphDef& def) { + for (const auto& node : def.node()) { + if (node.op() == "TPUCompile" || node.op() == "TPUExecute" || + node.op() == "TPUPartitionedCall") { + return true; + } + } + return false; +} + +// All the nodes that should be blacklisted and not swapped. +bool IsBlacklisted(const NodeDef& node) { return IsCollective(node); } +} // end namespace internal + +Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) { + *optimized_graph = item.graph; + + // Skip all TPU graphs. + if (internal::IsTPUGraphDef(*optimized_graph)) { + return Status::OK(); + } + + GraphProperties properties(item); + bool has_properties = false; + GraphView graph(optimized_graph); + + gtl::FlatSet<string> devices; + if (cluster) { + const std::vector<string> device_names = cluster->GetDeviceNames(); + devices.insert(device_names.begin(), device_names.end()); + } else { + devices = {"/device:CPU:0"}; + } + + const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end(); + + // Topologically sort the graph, so that we traverse the nodes in order. This + // will help us discover producer->consumer chains of Host ops. 
+ TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph)); + + // All the Const nodes, and their original devices in topological order. + std::vector<std::pair<NodeDef*, string>> const_nodes; + + for (auto& node : *optimized_graph->mutable_node()) { + // Check if node already on CPU. + if (str_util::StrContains(node.device(), DEVICE_CPU)) { + continue; + } + + // Skip these node types. + if (internal::IsBlacklisted(node)) { + continue; + } + + // Check the node can be run on CPU. + Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr); + if (!s.ok()) { + continue; + } + + // Check all input's are pinned to CPU. + if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) { + continue; + } + + if (!has_properties) { + // This is an expensive call, call it lazily. + TF_RETURN_IF_ERROR(properties.InferStatically(false)); + has_properties = true; + } + + // Check all inputs and outputs are integers and small. + if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) { + continue; + } + + if (IsConstant(node)) { + const_nodes.emplace_back(&node, node.device()); + } + // Try and swap the device to Host. + node.set_device( + internal::TryFindHostDevice(devices, has_device_cpu, node.device())); + } + + // Traverse all `const_nodes`, and map them back to GPU greedily. + for (auto& it : const_nodes) { + NodeDef* node = it.first; + const string& device = it.second; + + // Check all the consumers of this node, if any of them are on the original + // device, swap this node back onto the original device. 
+ for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) { + if (fanout.node->device() == device) { + node->set_device(device); + break; + } + } + } + return Status::OK(); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h new file mode 100644 index 0000000000..d557a03463 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_ + +#include <unordered_set> +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/optimizers/graph_optimizer.h" +#include "tensorflow/core/lib/gtl/flatset.h" +#include "tensorflow/core/protobuf/rewriter_config.pb.h" + +namespace tensorflow { +namespace grappler { +namespace internal { +// Try and find an appropriate Host device in `devices` given `device`. 
+string TryFindHostDevice(const gtl::FlatSet<string>& devices, + bool has_device_cpu, const string& device); +} // end namespace internal + +// Optimize TensorFlow ops that should be swapped into the CPU to avoid +// excessive cpu<->gpu memcpy/sync. +// +// TODO(williamchan): The current heuristic will swap any small integer Const to +// CPU. This may cause a problem cpu->cpu->gpu wherein the original behaviour of +// gpu->gpu->gpu may have been better/faster. We should probably fix this. +class PinToHostOptimizer : public GraphOptimizer { + public: + PinToHostOptimizer() : opt_level_(RewriterConfig::DEFAULT) {} + explicit PinToHostOptimizer(RewriterConfig::Toggle opt_level) + : opt_level_(opt_level) {} + + ~PinToHostOptimizer() override {} + + string name() const override { return "pin_to_host_optimizer"; }; + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} + + private: + RewriterConfig::Toggle opt_level_; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_ diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc new file mode 100644 index 0000000000..173cb3fe3c --- /dev/null +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc @@ -0,0 +1,194 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +class PinToHostOptimizerTest : public GrapplerTest {}; + +TEST_F(PinToHostOptimizerTest, TryFindHostDevice) { + gtl::FlatSet<string> devices = {}; + EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC")); + + devices = {"/device:CPU:0", "/device:XLA_GPU:0"}; + EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0"); + EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"), + "/device:CPU:0"); + EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"), + "/device:CPU:0"); + + devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"}; + EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), ""); + EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"), + "/device:XLA_CPU:0"); + EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"), + "/device:XLA_CPU:0"); + + devices = {"/device:XLA_GPU:0"}; + EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), ""); + 
EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"), + "/device:XLA_GPU:0"); + EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"), + "/device:XLA_GPU:*"); +} + +TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024}); + Output c = ops::Shape(s.WithOpName("c"), a); + Output d = ops::Const(s.WithOpName("d"), 0, {1}); + Output e = ops::ReduceProd(s.WithOpName("e"), c, d); + + GrapplerItem item; + item.fetch = {"a", "c", "d", "e"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + + GraphDef output; + PinToHostOptimizer optimizer(RewriterConfig::ON); + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + auto tensors = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(tensors_expected.size(), tensors.size()); + for (int i = 0; i < tensors.size(); ++i) { + test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]); + } + + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "a" || node.name() == "c") { + EXPECT_TRUE(node.device().empty()); + } else if (node.name() == "d" || node.name() == "e") { + EXPECT_EQ(node.device(), "/device:CPU:0"); + } + ++found; + } + EXPECT_EQ(found, 4); +} + +TEST_F(PinToHostOptimizerTest, TopologicalSort) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024}); + Output c = ops::Shape(s.WithOpName("c"), a); + Output d = ops::Const(s.WithOpName("d"), 0, {1}); + Output e = ops::ReduceProd(s.WithOpName("e"), c, d); + + GrapplerItem item; + item.fetch = {"a", "c", "d", "e"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + + // Reverse the graph, and hence rely on the optimizer to sort it. 
+ std::reverse(item.graph.mutable_node()->begin(), + item.graph.mutable_node()->end()); + + GraphDef output; + PinToHostOptimizer optimizer(RewriterConfig::ON); + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + auto tensors = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(tensors_expected.size(), tensors.size()); + for (int i = 0; i < tensors.size(); ++i) { + test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]); + } + + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "a" || node.name() == "c") { + EXPECT_TRUE(node.device().empty()); + } else if (node.name() == "d" || node.name() == "e") { + EXPECT_EQ(node.device(), "/device:CPU:0"); + } + ++found; + } + EXPECT_EQ(found, 4); +} + +TEST_F(PinToHostOptimizerTest, NoSwap) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + // `b` should be too big to swap, consequently `c` should not be swapped. + // PinToHostOptimizer should then detect that `a` should not be swapped. 
+ Output a = ops::Const(s.WithOpName("a"), 1, {1, 1}); + Output b = ops::Const(s.WithOpName("b"), 1, {1, 1024 * 1024}); + Output c = ops::MatMul(s.WithOpName("c"), a, b); + + GrapplerItem item; + item.fetch = {"a", "b", "c"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + + GraphDef output; + PinToHostOptimizer optimizer(RewriterConfig::ON); + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + auto tensors = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(tensors_expected.size(), tensors.size()); + for (int i = 0; i < tensors.size(); ++i) { + test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]); + } + + int found = 0; + for (const NodeDef& node : output.node()) { + EXPECT_TRUE(node.device().empty()); + ++found; + } + EXPECT_EQ(found, 3); +} + +TEST_F(PinToHostOptimizerTest, PortIdToArgId) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3}); + ops::ShapeN b(s.WithOpName("b"), {a, a, a}); + + GrapplerItem item; + item.fetch = {"a", "b"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + + GraphDef output; + PinToHostOptimizer optimizer(RewriterConfig::ON); + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + auto tensors = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(tensors_expected.size(), tensors.size()); + for (int i = 0; i < tensors.size(); ++i) { + test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]); + } + + int found = 0; + for (const NodeDef& node : output.node()) { + EXPECT_EQ(node.device(), "/device:CPU:0"); + ++found; + } + EXPECT_EQ(found, 2); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 03e36a7b9c..9ada8b7ff9 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc 
+++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -168,11 +168,12 @@ void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) { Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, GraphDef* optimized_graph) { GraphProperties properties(item); - TF_RETURN_IF_ERROR(properties.InferStatically(false)); + bool inferred_properties = false; GraphView graph(const_cast<GraphDef*>(&item.graph)); // During inference, most of the inputs to FusedBatchNorm are constant, and we // can therefore replace the op with a much cheaper set of primitives. + optimized_graph->mutable_node()->Reserve(item.graph.node_size()); for (const NodeDef& node : item.graph.node()) { if (node.op() == "FusedBatchNorm" || node.op() == "FusedBatchNormV2") { bool optimizable = (node.attr().count("T") == 0 || @@ -181,6 +182,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, !node.attr().at("is_training").b()); if (optimizable) { int const_inputs = 0; + if (!inferred_properties) { + // Infer properties lazily in case they are not needed. + TF_RETURN_IF_ERROR(properties.InferStatically(false)); + inferred_properties = true; + } const auto& props = properties.GetInputProperties(node.name()); for (const auto& prop : props) { if (prop.has_value()) { @@ -218,7 +224,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, void Remapper::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/, const GraphDef& /*optimized_graph*/, double /*result*/) { - // Nothing to do for ArithmeticOptimizer. + // Nothing to do for RemapperOptimizer. } } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc index caa0b7b0cb..6ccb1cd783 100644 --- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc @@ -20,10 +20,9 @@ limitations under the License. 
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/grappler/graph_view.h" #include "tensorflow/core/grappler/grappler_item.h" -#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h" - #include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -34,7 +33,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, *optimized_graph = item.graph; GraphProperties properties(item); - TF_RETURN_IF_ERROR(properties.InferStatically(false)); + bool inferred_properties = false; GraphView graph(optimized_graph); // The product of all the dimensions in a tensor shape can be expressed more @@ -56,6 +55,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } const GraphView::OutputPort reduce_indices = graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1)); + if (!inferred_properties) { + // Infer properties lazily in case they are not needed. + TF_RETURN_IF_ERROR(properties.InferStatically(false)); + inferred_properties = true; + } const auto& prop = properties.GetOutputProperties(reduce_indices.node->name()); if (prop.size() < reduce_indices.port_id) { @@ -93,6 +97,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (!IsSize(*input1.node) || !IsSize(*input2.node)) { continue; } + if (!inferred_properties) { + // Infer properties lazily in case they are not needed. 
+ TF_RETURN_IF_ERROR(properties.InferStatically(false)); + inferred_properties = true; + } const auto& prop1 = properties.GetInputProperties(input1.node->name()); const auto& prop2 = properties.GetInputProperties(input2.node->name()); if (prop1.size() != 1 || prop2.size() != 1) { diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index 153785d3b4..db6e4e6852 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/grappler/utils.h" +#include <iterator> #include <memory> #include <queue> #include <vector> @@ -24,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -165,6 +167,34 @@ int NodePosition(const string& name) { return position; } +int NodePositionIfSameNode(const string& input_name, const string& node_name) { + const bool is_ctrl = input_name[0] == '^'; + auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin(); + auto node_it = node_name.begin(); + if (node_name.empty() || + std::distance(input_it, input_name.end()) < node_name.size()) { + return -2; + } + while (node_it != node_name.end()) { + if (*input_it++ != *node_it++) { + return -2; + } + } + if (input_it == input_name.end()) { + return is_ctrl ? -1 : 0; + } else if (*input_it++ == ':') { + StringPiece remaining(&(*input_it), + std::distance(input_it, input_name.end())); + int position; + if (!strings::safe_strto32(remaining, &position)) { + return -2; + } + return is_ctrl ? 
-1 : position; + } else { + return -2; + } +} + string AddPrefixToNodeName(const string& name, const string& prefix, const string& delimiter) { if (!name.empty()) { diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 20dbeea2cf..296ee1678e 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -107,6 +107,7 @@ bool IsSameInput(const string& name1, const string& name2); string NodeName(const string& name); // Get the trailing position number ":{digits}" (if any) of a node name. +// Returns -1 for control inputs. int NodePosition(const string& name); inline StringPiece ParseNodeNameAsStringPiece(const string& name, @@ -142,6 +143,11 @@ inline string ParseNodeName(const string& name, int* position) { return string(ParseNodeNameAsStringPiece(name, position)); } +// Returns NodePosition(input_name) if NodeName(input_name) == node_name. +// Otherwise returns -2; +// REQUIRES: inputs_name.size() > 0 && node_name.size() > 0. +int NodePositionIfSameNode(const string& input_name, const string& node_name); + // Add a prefix to a node name with a custom delimiter. 
string AddPrefixToNodeName(const string& name, const string& prefix, const string& delimiter); diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD index e540cc0476..bdbb8836e1 100644 --- a/tensorflow/core/grappler/utils/BUILD +++ b/tensorflow/core/grappler/utils/BUILD @@ -1,6 +1,10 @@ licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load( + "//tensorflow/core:platform/default/build_config.bzl", + "tf_protos_grappler", +) cc_library( name = "scc", @@ -210,3 +214,28 @@ tf_cc_test( "//tensorflow/core:testlib", ], ) + +cc_library( + name = "symbolic_shapes", + srcs = ["symbolic_shapes.cc"], + hdrs = ["symbolic_shapes.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ] + tf_protos_grappler(), +) + +tf_cc_test( + name = "symbolic_shapes_test", + srcs = ["symbolic_shapes_test.cc"], + deps = [ + ":symbolic_shapes", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index a2c363ea6e..a428aea7f5 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -304,21 +304,21 @@ Status GrapplerFunctionItemInstantiation::GetArgType( } GrapplerFunctionItem::GrapplerFunctionItem( - const string& func_name, const string& description, - const AttrValueMap& func_attr, - const std::vector<InputArgExpansion>& input_arg_expansions, - const std::vector<OutputArgExpansion>& output_arg_expansions, - const std::vector<string>& keep_nodes, const int graph_def_version, - bool is_stateful, GraphDef&& function_body) - : description_(description), - func_attr_(func_attr), - input_arg_expansions_(input_arg_expansions), - 
output_arg_expansions_(output_arg_expansions), + string func_name, string description, AttrValueMap func_attr, + std::vector<InputArgExpansion> input_arg_expansions, + std::vector<OutputArgExpansion> output_arg_expansions, + std::vector<string> keep_nodes, const int graph_def_version, + const bool is_stateful, GraphDef&& function_body) + : description_(std::move(description)), + func_attr_(std::move(func_attr)), + input_arg_expansions_(std::move(input_arg_expansions)), + output_arg_expansions_(std::move(output_arg_expansions)), is_stateful_(is_stateful) { - id = func_name; - keep_ops = keep_nodes; - // Swap the graph body. - graph.Swap(&function_body); + // Move assign GrapplerItem members. + keep_ops = std::move(keep_nodes); + id = std::move(func_name); + graph = std::move(function_body); + graph.mutable_versions()->set_producer(graph_def_version); // Fill the feed nodes with input placeholders. for (const InputArgExpansion& input_arg : input_arg_expansions_) { @@ -598,8 +598,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, *item = GrapplerFunctionItem( /*func_name=*/signature.name(), /*description=*/signature.description(), /*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()), - inputs, outputs, keep_nodes, graph_def_version, is_stateful, - std::move(function_body)); + std::move(inputs), std::move(outputs), std::move(keep_nodes), + graph_def_version, is_stateful, std::move(function_body)); return Status::OK(); } diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index 61588ceb83..733caf325f 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -136,13 +136,12 @@ class GrapplerFunctionItemInstantiation { class GrapplerFunctionItem : public GrapplerItem { public: GrapplerFunctionItem() = default; - GrapplerFunctionItem( - const string& func_name, const string& description, - const AttrValueMap& func_attr, - const 
std::vector<InputArgExpansion>& input_arg_expansions, - const std::vector<OutputArgExpansion>& output_arg_expansions, - const std::vector<string>& keep_nodes, const int versions, - bool is_stateful, GraphDef&& function_body); + GrapplerFunctionItem(string func_name, string description, + AttrValueMap func_attr, + std::vector<InputArgExpansion> input_arg_expansions, + std::vector<OutputArgExpansion> output_arg_expansions, + std::vector<string> keep_nodes, int graph_def_version, + bool is_stateful, GraphDef&& function_body); const string& description() const; diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc index 910b0acaef..6266733f3e 100644 --- a/tensorflow/core/grappler/utils/grappler_test.cc +++ b/tensorflow/core/grappler/utils/grappler_test.cc @@ -30,13 +30,16 @@ GrapplerTest::GrapplerTest() { // optimizations interfering in the comparison. RewriterConfig* cfg = options_.config.mutable_graph_options()->mutable_rewrite_options(); - cfg->set_constant_folding(RewriterConfig::OFF); + // TODO(rmlarsen): Add utility to generate config w/ all optimizers turned + // off. cfg->set_arithmetic_optimization(RewriterConfig::OFF); + cfg->set_constant_folding(RewriterConfig::OFF); + cfg->set_debug_stripper(RewriterConfig::OFF); cfg->set_dependency_optimization(RewriterConfig::OFF); - cfg->set_loop_optimization(RewriterConfig::OFF); cfg->set_function_optimization(RewriterConfig::OFF); cfg->set_layout_optimizer(RewriterConfig::OFF); - cfg->set_debug_stripper(RewriterConfig::OFF); + cfg->set_loop_optimization(RewriterConfig::OFF); + cfg->set_pin_to_host_optimization(RewriterConfig::OFF); } std::vector<Tensor> GrapplerTest::EvaluateNodes( diff --git a/tensorflow/core/grappler/utils/scc.h b/tensorflow/core/grappler/utils/scc.h index 4fb7aab647..ceb9f5dbf2 100644 --- a/tensorflow/core/grappler/utils/scc.h +++ b/tensorflow/core/grappler/utils/scc.h @@ -24,15 +24,16 @@ limitations under the License. 
namespace tensorflow { namespace grappler { -// Compute modified strongly connected components: +// Computes modified strongly connected components: // All nodes that are not part of a loop are assigned the special -1 id // All nodes that are part of at least one loop are assigned a positive // component id: if 2 nodes v and w are reachable from one another (i.e. if they // belong to the same scc), they'll be assigned the same id, otherwise they'll -// be assigned distinct ids. Returns the number of distinct ids. +// be assigned distinct ids. *num_components is set to the number of distinct +// ids. void StronglyConnectedComponents( const GraphDef& graph, std::unordered_map<const NodeDef*, int>* components, - int* num_ids); + int* num_components); // Returns the number of individual loops present in the graph, and populate the // 'loops' argument with the collection of loops (denoted by their loop ids) a diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc b/tensorflow/core/grappler/utils/symbolic_shapes.cc index 155843a744..1666de4b80 100644 --- a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc +++ b/tensorflow/core/grappler/utils/symbolic_shapes.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h" +#include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/util/bcast.h" namespace tensorflow { diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.h b/tensorflow/core/grappler/utils/symbolic_shapes.h index ace7bd1fe7..0a7d8ac82b 100644 --- a/tensorflow/core/grappler/optimizers/symbolic_shapes.h +++ b/tensorflow/core/grappler/utils/symbolic_shapes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_ -#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_ +#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_ +#define TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_ #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/grappler/costs/op_performance_data.pb.h" @@ -74,4 +74,4 @@ int64 ComputeSizeRatio(const TensorShapeProto& numerator, } // namespace grappler } // end namespace tensorflow -#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_ +#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_ diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc index 7ce995d1c5..6ac644cdb1 100644 --- a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc +++ b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h" +#include "tensorflow/core/grappler/utils/symbolic_shapes.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index c6e035834c..6b787a6910 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/notification.h" #include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" namespace tensorflow { namespace grappler { @@ -147,6 +148,21 @@ TEST_F(UtilsTest, NodePosition) { EXPECT_EQ(0, NodePosition("")); } +TEST_F(UtilsTest, NodePositionIfSameNode) { + EXPECT_EQ(-2, NodePositionIfSameNode(":123", "")); + EXPECT_EQ(-2, NodePositionIfSameNode(":", "")); + EXPECT_EQ(-2, NodePositionIfSameNode("", "")); + EXPECT_EQ(123, NodePositionIfSameNode("abc:123", "abc")); + EXPECT_EQ(-1, NodePositionIfSameNode("^abc", "abc")); + EXPECT_EQ(-1, NodePositionIfSameNode("^abc:123", "abc")); + EXPECT_EQ(-2, NodePositionIfSameNode("abc", "xyz")); + EXPECT_EQ(-2, NodePositionIfSameNode("abc", "abc/xyz")); + EXPECT_EQ(-2, NodePositionIfSameNode("abc/xyz", "abc")); + EXPECT_EQ(-2, NodePositionIfSameNode("abc:123", "xyz")); + EXPECT_EQ(-2, NodePositionIfSameNode("^abc", "xyz")); + EXPECT_EQ(-2, NodePositionIfSameNode("^abc:123", "xyz")); +} + TEST_F(UtilsTest, AddNodeNamePrefix) { EXPECT_EQ("OPTIMIZED/abc", AddPrefixToNodeName("abc", "OPTIMIZED")); EXPECT_EQ("^OPTIMIZED/abc", AddPrefixToNodeName("^abc", "OPTIMIZED")); @@ -209,7 +225,6 @@ TEST_F(UtilsTest, GetTailOfChain) { auto noop = ops::NoOp(s.WithControlDependencies(neg0).WithOpName("noop")); GraphDef graph; TF_CHECK_OK(s.ToGraphDef(&graph)); - LOG(INFO) << graph.DebugString(); ASSERT_EQ("c0", graph.node(0).name()); ASSERT_EQ("c1", graph.node(1).name()); @@ -336,9 +351,26 @@ TEST_F(UtilsTest, NumNonControlOutputs) { } TEST_F(UtilsTest, DeleteNodes) { - // TODO(rmlarsen): write forgtten test. + // TODO(rmlarsen): write forgotten test. 
} +#define BM_NodePositionIfSameNode(I, N, NAME) \ + static void BM_NodePositionIfSameNode_##NAME(int iters) { \ + string input = I; \ + string node = N; \ + for (int i = 0; i < iters; ++i) { \ + const int pos = NodePositionIfSameNode(input, node); \ + CHECK_GT(pos, -3); \ + } \ + } \ + BENCHMARK(BM_NodePositionIfSameNode_##NAME) + +BM_NodePositionIfSameNode("foo/bar/baz:7", "foo/bar/baz", Match_7); +BM_NodePositionIfSameNode("foo/bar/baz", "foo/bar/baz", Match_0); +BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl); +BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0); +BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end); + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index f5682b6e13..0534b1829d 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -30,6 +30,7 @@ load( "//tensorflow:tensorflow.bzl", "if_android", "tf_cc_test", + "tf_cc_test_mkl", "tf_cc_tests", "tf_cc_binary", "tf_copts", @@ -50,6 +51,10 @@ load( "tf_kernel_tests_linkstatic", ) load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) +load( "//third_party/mkl:build_defs.bzl", "if_mkl", "if_mkl_ml", @@ -212,6 +217,19 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "extract_volume_patches_op", + prefix = "extract_volume_patches_op", + deps = [ + ":bounds_check", + ":eigen_helpers", + ":ops_util", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//third_party/eigen3", + ], +) + cc_library( name = "conv_3d", hdrs = ["conv_3d.h"], @@ -617,6 +635,7 @@ cc_library( ":diag_op", ":edit_distance_op", ":extract_image_patches_op", + ":extract_volume_patches_op", ":gather_nd_op", ":gather_op", ":guarantee_const_op", @@ -636,6 +655,7 @@ cc_library( ":reshape_op", ":reverse_op", ":reverse_sequence_op", + ":searchsorted_op", ":shape_ops", ":slice_op", ":snapshot_op", @@ -643,14 +663,7 @@ 
cc_library( ":split_v_op", ":strided_slice_op", ":tile_ops", - ] + if_mkl( - [ - ":mkl_transpose_op", - ], - [ - ":transpose_op", - ], - ) + [ + ":transpose_op", ":unique_op", ":unpack_op", ":unravel_index_op", @@ -876,6 +889,12 @@ tf_kernel_library( ) tf_kernel_library( + name = "searchsorted_op", + prefix = "searchsorted_op", + deps = ARRAY_DEPS, +) + +tf_kernel_library( name = "inplace_ops", prefix = "inplace_ops", deps = ARRAY_DEPS, @@ -893,24 +912,13 @@ tf_kernel_library( deps = ARRAY_DEPS, ) -if_mkl( - [tf_mkl_kernel_library( - name = "mkl_transpose_op", - srcs = [ - "mkl_transpose_op.cc", - "transpose_op.cc", - ], - hdrs = ["transpose_op.h"], - deps = ARRAY_DEPS + mkl_deps(), - )], - [tf_kernel_library( - name = "transpose_op", - srcs = [ - "transpose_op.cc", - ], - hdrs = ["transpose_op.h"], - deps = ARRAY_DEPS, - )], +tf_kernel_library( + name = "transpose_op", + srcs = [ + "transpose_op.cc", + ], + hdrs = ["transpose_op.h"], + deps = ARRAY_DEPS + if_mkl([":mkl_transpose_op"]), ) tf_kernel_library( @@ -1123,7 +1131,7 @@ tf_cuda_cc_test( name = "depthwise_conv_ops_test", size = "small", srcs = ["depthwise_conv_ops_test.cc"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ ":conv_ops", ":image", @@ -2020,8 +2028,8 @@ tf_kernel_library( ":variable_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:resource_variable_ops_op_lib", - "//third_party/eigen3", ], ) @@ -2720,6 +2728,7 @@ cc_library( ) LOGGING_DEPS = [ + "@com_google_absl//absl/strings", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -2777,6 +2786,7 @@ tf_cc_tests( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/strings", ], ) @@ -4414,6 +4424,7 @@ cc_library( ":reduce_join_op", ":regex_full_match_op", ":regex_replace_op", + ":string_format_op", ":string_join_op", ":string_length_op", ":string_split_op", 
@@ -4423,8 +4434,16 @@ cc_library( ], ) +cc_library( + name = "string_util", + srcs = ["string_util.cc"], + hdrs = ["string_util.h"], + deps = ["//tensorflow/core:lib"], +) + STRING_DEPS = [ ":bounds_check", + ":string_util", "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -4445,6 +4464,30 @@ tf_kernel_library( ) tf_kernel_library( + name = "string_format_op", + prefix = "string_format_op", + deps = STRING_DEPS + ["@com_google_absl//absl/strings"], +) + +tf_cc_test( + name = "string_format_op_test", + size = "small", + srcs = ["string_format_op_test.cc"], + deps = [ + ":string_format_op", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:ops_util", + ], +) + +tf_kernel_library( name = "string_join_op", prefix = "string_join_op", deps = STRING_DEPS, @@ -4522,6 +4565,25 @@ tf_kernel_library( deps = STRING_DEPS, ) +tf_cc_test( + name = "substr_op_test", + size = "small", + srcs = ["substr_op_test.cc"], + deps = [ + ":substr_op", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:ops_util", + ], +) + tf_kernel_library( name = "as_string_op", prefix = "as_string_op", @@ -5112,6 +5174,7 @@ filegroup( "spacetobatch_functor.h", "spacetodepth_op.h", "spectrogram.h", + "string_util.h", "tensor_array.h", "tile_functor.h", "tile_ops_cpu_impl.h", @@ -5191,6 +5254,8 @@ filegroup( "cwise_op_squared_difference.cc", "cwise_op_sub.cc", "cwise_op_tanh.cc", + "cwise_op_xlogy.cc", + "cwise_op_xdivy.cc", "data_format_ops.cc", 
"decode_wav_op.cc", "deep_conv2d.cc", @@ -5202,6 +5267,7 @@ filegroup( "fifo_queue.cc", "fifo_queue_op.cc", "fused_batch_norm_op.cc", + "listdiff_op.cc", "population_count_op.cc", "population_count_op.h", "winograd_transform.h", @@ -5279,6 +5345,7 @@ filegroup( "spectrogram_op.cc", "stack_ops.cc", "string_join_op.cc", + "string_util.cc", "summary_op.cc", "tensor_array.cc", "tensor_array_ops.cc", @@ -6226,6 +6293,26 @@ tf_mkl_kernel_library( ] + mkl_deps(), ) +tf_cc_test_mkl( + name = "mkl_conv_ops_test", + size = "small", + srcs = ["mkl_conv_ops_test.cc"], + deps = [ + ":ops_testutil", + ":ops_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_mkl_kernel_library( name = "mkl_tfconv_op", prefix = "mkl_tfconv", @@ -6357,6 +6444,15 @@ tf_mkl_kernel_library( deps = NN_DEPS + mkl_deps() + [":cwise_op"], ) +tf_mkl_kernel_library( + name = "mkl_transpose_op", + srcs = [ + "mkl_transpose_op.cc", + ], + hdrs = ["transpose_op.h"], + deps = ARRAY_DEPS + mkl_deps(), +) + # NOTE(lespeholt): This rule is deprecated, please use: # tensorflow/core/util/batch_util.h cc_library( diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc index 54c45bfe63..f48bd0c318 100644 --- a/tensorflow/core/kernels/batch_matmul_op_complex.cc +++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc @@ -17,14 +17,18 @@ limitations under the License. namespace tensorflow { -#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) +// MKL_ML registers its own complex64/128 kernels in mkl_batch_matmul_op.cc +// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL). +// Anything else (the complement) should register the TF ones. 
+// (MKL-DNN doesn't implement these kernels either.) +#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL) TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU); TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU); -#endif +#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL #if GOOGLE_CUDA TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU); TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU); -#endif +#endif // GOOGLE_CUDA } // namespace tensorflow diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc index 584b507c70..25ae795d8e 100644 --- a/tensorflow/core/kernels/batch_matmul_op_real.cc +++ b/tensorflow/core/kernels/batch_matmul_op_real.cc @@ -21,10 +21,15 @@ limitations under the License. namespace tensorflow { -#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) +// MKL_ML registers its own float and double kernels in mkl_batch_matmul_op.cc +// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL). +// Anything else (the complement) should register the TF ones. +// (MKL-DNN doesn't implement these kernels either.) +#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL) TF_CALL_float(REGISTER_BATCH_MATMUL_CPU); TF_CALL_double(REGISTER_BATCH_MATMUL_CPU); -#endif +#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL + TF_CALL_half(REGISTER_BATCH_MATMUL_CPU); TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU); diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index 792eb74e31..039b0db144 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -1,7 +1,7 @@ # Description: Utilities. 
package( - default_visibility = ["//tensorflow:internal"], + default_visibility = ["//visibility:public"], ) licenses(["notice"]) # Apache 2.0 @@ -12,7 +12,11 @@ cc_library( name = "periodic_function_dynamic", srcs = ["periodic_function.cc"], hdrs = ["periodic_function.h"], - visibility = ["//visibility:public"], + visibility = [ + "//learning/serving:__subpackages__", + "//tensorflow:internal", + "//tensorflow_serving:__subpackages__", + ], deps = [ "//tensorflow/core:framework_headers_lib", "//tensorflow/core:protos_all_cc", @@ -21,7 +25,11 @@ cc_library( cc_library( name = "periodic_function", - visibility = ["//visibility:public"], + visibility = [ + "//learning/serving:__subpackages__", + "//tensorflow:internal", + "//tensorflow_serving:__subpackages__", + ], deps = [ ":periodic_function_dynamic", "//tensorflow/core:lib", @@ -190,7 +198,11 @@ cc_library( testonly = 1, srcs = ["fake_clock_env.cc"], hdrs = ["fake_clock_env.h"], - visibility = ["//visibility:public"], + visibility = [ + "//learning/serving:__subpackages__", + "//tensorflow:internal", + "//tensorflow_serving:__subpackages__", + ], deps = [ "//tensorflow/core:lib", "//tensorflow/core:tensorflow", diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc index 7b28c8e91f..e15ea82e7d 100644 --- a/tensorflow/core/kernels/bias_op.cc +++ b/tensorflow/core/kernels/bias_op.cc @@ -134,8 +134,8 @@ class BiasOp : public BinaryOp<T> { if (data_format_ == FORMAT_NCHW) { int32 batch, height, width, channel; GetBiasValueDims(input, data_format_, &batch, &height, &width, &channel); - Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1); - Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width); + Eigen::DSizes<Eigen::Index, 4> four_dims(1, channel, 1, 1); + Eigen::DSizes<Eigen::Index, 4> broad_cast_dims(batch, 1, height, width); const Device& d = context->eigen_device<Device>(); output->tensor<T, 4>().device(d) = input.tensor<T, 4>() + @@ -247,14 +247,14 @@ class BiasGradOp : 
public OpKernel { OP_REQUIRES(context, output_backprop.dims() == 4, errors::InvalidArgument( "NCHW format supports only 4D input/output tensor.")); - Eigen::DSizes<int, 4> four_dims(batch, channel, height, width); + Eigen::DSizes<Eigen::Index, 4> four_dims(batch, channel, height, width); #ifdef EIGEN_HAS_INDEX_LIST using idx0 = Eigen::type2index<0>; using idx2 = Eigen::type2index<2>; using idx3 = Eigen::type2index<3>; Eigen::IndexList<idx0, idx2, idx3> reduction_axes; #else - Eigen::array<int, 3> reduction_axes = {0, 2, 3}; + Eigen::array<Eigen::Index, 3> reduction_axes = {0, 2, 3}; #endif output->template flat<T>().device(context->eigen_device<Device>()) = output_backprop.flat<T>() @@ -263,11 +263,12 @@ class BiasGradOp : public OpKernel { .sum(reduction_axes) .template cast<T>(); // End of code by intel_tf. } else { - Eigen::DSizes<int, 2> two_dims(batch * height * width, channel); + Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width, + channel); #ifdef EIGEN_HAS_INDEX_LIST Eigen::IndexList<Eigen::type2index<0> > reduction_axis; #else - Eigen::array<int, 1> reduction_axis = {0}; + Eigen::array<Eigen::Index, 1> reduction_axis = {0}; #endif output->template flat<T>().device(context->eigen_device<Device>()) = output_backprop.flat<T>() diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc index 6074b3e1f6..7d09e9b820 100644 --- a/tensorflow/core/kernels/bincount_op_gpu.cu.cc +++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc @@ -17,7 +17,7 @@ limitations under the License. 
#define EIGEN_USE_GPU -#include "external/cub_archive/cub/device/device_histogram.cuh" +#include "third_party/cub/device/device_histogram.cuh" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD index 4910021c63..4e8bfa02fc 100644 --- a/tensorflow/core/kernels/boosted_trees/BUILD +++ b/tensorflow/core/kernels/boosted_trees/BUILD @@ -15,7 +15,9 @@ load( tf_proto_library( name = "boosted_trees_proto", - srcs = ["boosted_trees.proto"], + srcs = [ + "boosted_trees.proto", + ], cc_api_version = 2, visibility = ["//visibility:public"], ) @@ -87,9 +89,21 @@ tf_kernel_library( ) tf_kernel_library( + name = "quantile_ops", + srcs = ["quantile_ops.cc"], + deps = [ + "//tensorflow/core:boosted_trees_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles", + ], +) + +tf_kernel_library( name = "boosted_trees_ops", deps = [ ":prediction_ops", + ":quantile_ops", ":resource_ops", ":stats_ops", ":training_ops", diff --git a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto index c9664f0c1c..1ab72af059 100644 --- a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto +++ b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto @@ -11,6 +11,7 @@ message Node { oneof node { Leaf leaf = 1; BucketizedSplit bucketized_split = 2; + CategoricalSplit categorical_split = 3; } NodeMetadata metadata = 777; } @@ -57,6 +58,18 @@ message BucketizedSplit { int32 right_id = 4; } +message CategoricalSplit { + // Categorical feature column and split describing the rule feature value == + // value. + int32 feature_id = 1; + int32 value = 2; + + // Node children indexing into a contiguous + // vector of nodes starting from the root. 
+ int32 left_id = 3; + int32 right_id = 4; +} + // Tree describes a list of connected nodes. // Node 0 must be the root and can carry any payload including a leaf // in the case of representing the bias. diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc index b2efa06941..4ae26fb95b 100644 --- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc +++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc @@ -334,30 +334,34 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel { // Proto to store debug outputs, per example. boosted_trees::DebugOutput example_debug_info; // Initial bias prediction. E.g., prediction based off training mean. - example_debug_info.add_logits_path(resource->GetTreeWeight(0) * - resource->node_value(0, 0)); + float tree_logit = + resource->GetTreeWeight(0) * resource->node_value(0, 0); + example_debug_info.add_logits_path(tree_logit); int32 node_id = 0; int32 tree_id = 0; int32 feature_id; - float tree_logit; float past_trees_logit = 0; // Sum of leaf logits from prior trees. - // Populate proto. + // Go through each tree and populate proto. while (tree_id <= last_tree) { - // Feature id used to split. - feature_id = resource->feature_id(tree_id, node_id); - example_debug_info.add_feature_ids(feature_id); - // Get logit after split. - node_id = resource->next_node(tree_id, node_id, i, - batch_bucketized_features); - tree_logit = resource->GetTreeWeight(tree_id) * - resource->node_value(tree_id, node_id); - // Output logit incorporates sum of leaf logits from prior trees. - example_debug_info.add_logits_path(tree_logit + past_trees_logit); - if (resource->is_leaf(tree_id, node_id)) { - // Move onto other trees. - past_trees_logit += tree_logit; + if (resource->is_leaf(tree_id, node_id)) { // Move onto other trees. + // Accumulate tree_logits only if the leaf is non-root, but do so + // for bias tree. 
+ if (tree_id == 0 || node_id > 0) { + past_trees_logit += tree_logit; + } ++tree_id; node_id = 0; + } else { // Add to proto. + // Feature id used to split. + feature_id = resource->feature_id(tree_id, node_id); + example_debug_info.add_feature_ids(feature_id); + // Get logit after split. + node_id = resource->next_node(tree_id, node_id, i, + batch_bucketized_features); + tree_logit = resource->GetTreeWeight(tree_id) * + resource->node_value(tree_id, node_id); + // Output logit incorporates sum of leaf logits from prior trees. + example_debug_info.add_logits_path(tree_logit + past_trees_logit); } } // Set output as serialized proto containing debug info. diff --git a/tensorflow/core/kernels/boosted_trees/quantile_ops.cc b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc new file mode 100644 index 0000000000..d1840941c1 --- /dev/null +++ b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc @@ -0,0 +1,453 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// ============================================================================= +#include <algorithm> +#include <iterator> +#include <string> +#include <vector> + +#include "tensorflow/core/framework/device_base.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h" +#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h" +#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { + +const char* const kExampleWeightsName = "example_weights"; +const char* const kMaxElementsName = "max_elements"; +const char* const kGenerateQuantiles = "generate_quantiles"; +const char* const kNumBucketsName = "num_buckets"; +const char* const kEpsilonName = "epsilon"; +const char* const kBucketBoundariesName = "bucket_boundaries"; +const char* const kBucketsName = "buckets"; +const char* const kSummariesName = "summaries"; +const char* const kNumStreamsName = "num_streams"; +const char* const kNumFeaturesName = "num_features"; +const char* const kFloatFeaturesName = "float_values"; +const char* const kResourceHandleName = "quantile_stream_resource_handle"; + +using QuantileStreamResource = BoostedTreesQuantileStreamResource; +using QuantileStream = + boosted_trees::quantiles::WeightedQuantilesStream<float, float>; +using QuantileSummary = + boosted_trees::quantiles::WeightedQuantilesSummary<float, float>; +using QuantileSummaryEntry = + 
boosted_trees::quantiles::WeightedQuantilesSummary<float, + float>::SummaryEntry; + +// Generates quantiles on a finalized QuantileStream. +std::vector<float> GenerateBoundaries(const QuantileStream& stream, + const int64 num_boundaries) { + std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries); + + // Uniquify elements as we may get dupes. + auto end_it = std::unique(boundaries.begin(), boundaries.end()); + boundaries.resize(std::distance(boundaries.begin(), end_it)); + return boundaries; +} + +// Generates quantiles on a finalized QuantileStream. +std::vector<float> GenerateQuantiles(const QuantileStream& stream, + const int64 num_quantiles) { + // Do not de-dup boundaries. Exactly num_quantiles+1 boundary values + // will be returned. + std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles - 1); + CHECK_EQ(boundaries.size(), num_quantiles); + return boundaries; +} + +std::vector<float> GetBuckets(const int32 feature, + const OpInputList& buckets_list) { + const auto& buckets = buckets_list[feature].flat<float>(); + std::vector<float> buckets_vector(buckets.data(), + buckets.data() + buckets.size()); + return buckets_vector; +} + +REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesQuantileStreamResource); + +REGISTER_KERNEL_BUILDER( + Name("IsBoostedTreesQuantileStreamResourceInitialized").Device(DEVICE_CPU), + IsResourceInitialized<BoostedTreesQuantileStreamResource>); + +class BoostedTreesCreateQuantileStreamResourceOp : public OpKernel { + public: + explicit BoostedTreesCreateQuantileStreamResourceOp( + OpKernelConstruction* const context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_)); + } + + void Compute(OpKernelContext* context) override { + // Only create one, if one does not exist already. Report status for all + // other exceptions. If one already exists, it unrefs the new one. 
+ // An epsilon value of zero could cause perfoamance issues and is therefore, + // disallowed. + const Tensor* epsilon_t; + OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t)); + float epsilon = epsilon_t->scalar<float>()(); + OP_REQUIRES( + context, epsilon > 0, + errors::InvalidArgument("An epsilon value of zero is not allowed.")); + + const Tensor* num_streams_t; + OP_REQUIRES_OK(context, context->input(kNumStreamsName, &num_streams_t)); + int64 num_streams = num_streams_t->scalar<int64>()(); + + auto result = + new QuantileStreamResource(epsilon, max_elements_, num_streams); + auto status = CreateResource(context, HandleFromInput(context, 0), result); + if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) { + OP_REQUIRES(context, false, status); + } + } + + private: + // An upper bound on the number of entries that the summaries might have + // for a feature. + int64 max_elements_; +}; + +REGISTER_KERNEL_BUILDER( + Name("BoostedTreesCreateQuantileStreamResource").Device(DEVICE_CPU), + BoostedTreesCreateQuantileStreamResourceOp); + +class BoostedTreesMakeQuantileSummariesOp : public OpKernel { + public: + explicit BoostedTreesMakeQuantileSummariesOp( + OpKernelConstruction* const context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_)); + } + + void Compute(OpKernelContext* const context) override { + // Read float features list; + OpInputList float_features_list; + OP_REQUIRES_OK( + context, context->input_list(kFloatFeaturesName, &float_features_list)); + + // Parse example weights and get batch size. 
+ const Tensor* example_weights_t; + OP_REQUIRES_OK(context, + context->input(kExampleWeightsName, &example_weights_t)); + auto example_weights = example_weights_t->flat<float>(); + const int64 batch_size = example_weights.size(); + const Tensor* epsilon_t; + OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t)); + float epsilon = epsilon_t->scalar<float>()(); + + OpOutputList summaries_output_list; + OP_REQUIRES_OK( + context, context->output_list(kSummariesName, &summaries_output_list)); + + auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) { + // Iterating features. + for (int64 index = begin; index < end; index++) { + const auto feature_values = float_features_list[index].flat<float>(); + QuantileStream stream(epsilon, batch_size + 1); + // Run quantile summary generation. + for (int64 j = 0; j < batch_size; j++) { + stream.PushEntry(feature_values(j), example_weights(j)); + } + stream.Finalize(); + const auto summary_entry_list = stream.GetFinalSummary().GetEntryList(); + Tensor* output_t; + OP_REQUIRES_OK( + context, + summaries_output_list.allocate( + index, + TensorShape({static_cast<int64>(summary_entry_list.size()), 4}), + &output_t)); + auto output = output_t->matrix<float>(); + for (auto row = 0; row < summary_entry_list.size(); row++) { + const auto& entry = summary_entry_list[row]; + output(row, 0) = entry.value; + output(row, 1) = entry.weight; + output(row, 2) = entry.min_rank; + output(row, 3) = entry.max_rank; + } + } + }; + // TODO(tanzheny): comment on the magic number. 
+ const int64 kCostPerUnit = 500 * batch_size; + const DeviceBase::CpuWorkerThreads& worker_threads = + *context->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, num_features_, + kCostPerUnit, do_quantile_summary_gen); + } + + private: + int64 num_features_; +}; + +REGISTER_KERNEL_BUILDER( + Name("BoostedTreesMakeQuantileSummaries").Device(DEVICE_CPU), + BoostedTreesMakeQuantileSummariesOp); + +class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel { + public: + explicit BoostedTreesQuantileStreamResourceAddSummariesOp( + OpKernelConstruction* const context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + ResourceHandle handle; + OP_REQUIRES_OK(context, + HandleFromInput(context, kResourceHandleName, &handle)); + QuantileStreamResource* stream_resource; + // Create a reference to the underlying resource using the handle. + OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource)); + // Remove the reference at the end of this scope. + mutex_lock l(*stream_resource->mutex()); + core::ScopedUnref unref_me(stream_resource); + + OpInputList summaries_list; + OP_REQUIRES_OK(context, + context->input_list(kSummariesName, &summaries_list)); + int32 num_streams = stream_resource->num_streams(); + CHECK_EQ(static_cast<int>(num_streams), summaries_list.size()); + + auto do_quantile_add_summary = [&](const int64 begin, const int64 end) { + // Iterating all features. 
+ for (int64 feature_idx = begin; feature_idx < end; ++feature_idx) { + const Tensor& summaries = summaries_list[feature_idx]; + const auto summary_values = summaries.matrix<float>(); + const auto& tensor_shape = summaries.shape(); + const int64 entries_size = tensor_shape.dim_size(0); + CHECK_EQ(tensor_shape.dim_size(1), 4); + std::vector<QuantileSummaryEntry> summary_entries; + summary_entries.reserve(entries_size); + for (int64 i = 0; i < entries_size; i++) { + float value = summary_values(i, 0); + float weight = summary_values(i, 1); + float min_rank = summary_values(i, 2); + float max_rank = summary_values(i, 3); + QuantileSummaryEntry entry(value, weight, min_rank, max_rank); + summary_entries.push_back(entry); + } + stream_resource->stream(feature_idx)->PushSummary(summary_entries); + } + }; + + // TODO(tanzheny): comment on the magic number. + const int64 kCostPerUnit = 500 * num_streams; + const DeviceBase::CpuWorkerThreads& worker_threads = + *context->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, num_streams, + kCostPerUnit, do_quantile_add_summary); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("BoostedTreesQuantileStreamResourceAddSummaries").Device(DEVICE_CPU), + BoostedTreesQuantileStreamResourceAddSummariesOp); + +class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel { + public: + explicit BoostedTreesQuantileStreamResourceFlushOp( + OpKernelConstruction* const context) + : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr(kGenerateQuantiles, &generate_quantiles_)); + } + + void Compute(OpKernelContext* context) override { + ResourceHandle handle; + OP_REQUIRES_OK(context, + HandleFromInput(context, kResourceHandleName, &handle)); + QuantileStreamResource* stream_resource; + // Create a reference to the underlying resource using the handle. + OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource)); + // Remove the reference at the end of this scope. 
+ mutex_lock l(*stream_resource->mutex()); + core::ScopedUnref unref_me(stream_resource); + + const Tensor* num_buckets_t; + OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t)); + const int64 num_buckets = num_buckets_t->scalar<int64>()(); + const int64 num_streams = stream_resource->num_streams(); + + auto do_quantile_flush = [&](const int64 begin, const int64 end) { + // Iterating over all streams. + for (int64 stream_idx = begin; stream_idx < end; ++stream_idx) { + QuantileStream* stream = stream_resource->stream(stream_idx); + stream->Finalize(); + stream_resource->set_boundaries( + generate_quantiles_ ? GenerateQuantiles(*stream, num_buckets) + : GenerateBoundaries(*stream, num_buckets), + stream_idx); + } + }; + + // TODO(tanzheny): comment on the magic number. + const int64 kCostPerUnit = 500 * num_streams; + const DeviceBase::CpuWorkerThreads& worker_threads = + *context->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, num_streams, + kCostPerUnit, do_quantile_flush); + + stream_resource->set_buckets_ready(true); + } + + private: + bool generate_quantiles_; +}; + +REGISTER_KERNEL_BUILDER( + Name("BoostedTreesQuantileStreamResourceFlush").Device(DEVICE_CPU), + BoostedTreesQuantileStreamResourceFlushOp); + +class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp + : public OpKernel { + public: + explicit BoostedTreesQuantileStreamResourceGetBucketBoundariesOp( + OpKernelConstruction* const context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_)); + } + + void Compute(OpKernelContext* const context) override { + ResourceHandle handle; + OP_REQUIRES_OK(context, + HandleFromInput(context, kResourceHandleName, &handle)); + QuantileStreamResource* stream_resource; + // Create a reference to the underlying resource using the handle. 
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource)); + // Remove the reference at the end of this scope. + mutex_lock l(*stream_resource->mutex()); + core::ScopedUnref unref_me(stream_resource); + + const int64 num_streams = stream_resource->num_streams(); + CHECK_EQ(num_features_, num_streams); + OpOutputList bucket_boundaries_list; + OP_REQUIRES_OK(context, context->output_list(kBucketBoundariesName, + &bucket_boundaries_list)); + + auto do_quantile_get_buckets = [&](const int64 begin, const int64 end) { + // Iterating over all streams. + for (int64 stream_idx = begin; stream_idx < end; stream_idx++) { + const auto& boundaries = stream_resource->boundaries(stream_idx); + Tensor* bucket_boundaries_t = nullptr; + OP_REQUIRES_OK(context, + bucket_boundaries_list.allocate( + stream_idx, {static_cast<int64>(boundaries.size())}, + &bucket_boundaries_t)); + auto* quantiles_flat = bucket_boundaries_t->flat<float>().data(); + memcpy(quantiles_flat, boundaries.data(), + sizeof(float) * boundaries.size()); + } + }; + + // TODO(tanzheny): comment on the magic number. + const int64 kCostPerUnit = 500 * num_streams; + const DeviceBase::CpuWorkerThreads& worker_threads = + *context->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, num_streams, + kCostPerUnit, do_quantile_get_buckets); + } + + private: + int64 num_features_; +}; + +REGISTER_KERNEL_BUILDER( + Name("BoostedTreesQuantileStreamResourceGetBucketBoundaries") + .Device(DEVICE_CPU), + BoostedTreesQuantileStreamResourceGetBucketBoundariesOp); + +// Given the calculated quantiles thresholds and input data, this operation +// converts the input features into the buckets (categorical values), depending +// on which quantile they fall into. 
+class BoostedTreesBucketizeOp : public OpKernel { + public: + explicit BoostedTreesBucketizeOp(OpKernelConstruction* const context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_)); + } + + void Compute(OpKernelContext* const context) override { + // Read float features list; + OpInputList float_features_list; + OP_REQUIRES_OK( + context, context->input_list(kFloatFeaturesName, &float_features_list)); + OpInputList bucket_boundaries_list; + OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName, + &bucket_boundaries_list)); + OP_REQUIRES(context, + tensorflow::TensorShapeUtils::IsVector( + bucket_boundaries_list[0].shape()), + errors::InvalidArgument( + strings::Printf("Buckets should be flat vectors."))); + OpOutputList buckets_list; + OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list)); + + auto do_quantile_get_quantiles = [&](const int64 begin, const int64 end) { + // Iterating over all resources + for (int64 feature_idx = begin; feature_idx < end; feature_idx++) { + const Tensor& values_tensor = float_features_list[feature_idx]; + const int64 num_values = values_tensor.dim_size(0); + + Tensor* output_t = nullptr; + OP_REQUIRES_OK( + context, buckets_list.allocate( + feature_idx, TensorShape({num_values, 1}), &output_t)); + auto output = output_t->matrix<int32>(); + + const std::vector<float>& bucket_boundaries_vector = + GetBuckets(feature_idx, bucket_boundaries_list); + CHECK(!bucket_boundaries_vector.empty()) + << "Got empty buckets for feature " << feature_idx; + auto flat_values = values_tensor.flat<float>(); + for (int64 instance = 0; instance < num_values; instance++) { + const float value = flat_values(instance); + auto bucket_iter = + std::lower_bound(bucket_boundaries_vector.begin(), + bucket_boundaries_vector.end(), value); + if (bucket_iter == bucket_boundaries_vector.end()) { + --bucket_iter; + } + const int32 bucket = static_cast<int32>( + bucket_iter - 
bucket_boundaries_vector.begin()); + // Bucket id. + output(instance, 0) = bucket; + } + } + }; + + // TODO(tanzheny): comment on the magic number. + const int64 kCostPerUnit = 500 * num_features_; + const DeviceBase::CpuWorkerThreads& worker_threads = + *context->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, num_features_, + kCostPerUnit, do_quantile_get_quantiles); + } + + private: + int64 num_features_; +}; + +REGISTER_KERNEL_BUILDER(Name("BoostedTreesBucketize").Device(DEVICE_CPU), + BoostedTreesBucketizeOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD index 3163c63949..12d9473776 100644 --- a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD +++ b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD @@ -1,5 +1,5 @@ # Description: -# This directory contains common utilities used in boosted_trees. +# This directory contains common quantile utilities used in boosted_trees. package( default_visibility = ["//tensorflow:internal"], ) @@ -16,6 +16,7 @@ cc_library( name = "weighted_quantiles", srcs = [], hdrs = [ + "quantile_stream_resource.h", "weighted_quantiles_buffer.h", "weighted_quantiles_stream.h", "weighted_quantiles_summary.h", @@ -23,6 +24,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", ], ) diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h new file mode 100644 index 0000000000..1c31724272 --- /dev/null +++ b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h @@ -0,0 +1,96 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_ +#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_ + +#include <vector> +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +using QuantileStream = + boosted_trees::quantiles::WeightedQuantilesStream<float, float>; + +// Quantile Stream Resource for a list of streams sharing the same number of +// quantiles, maximum elements, and epsilon. 
+class BoostedTreesQuantileStreamResource : public ResourceBase { + public: + BoostedTreesQuantileStreamResource(const float epsilon, + const int64 max_elements, + const int64 num_streams) + : are_buckets_ready_(false), + epsilon_(epsilon), + num_streams_(num_streams), + max_elements_(max_elements) { + streams_.reserve(num_streams_); + boundaries_.reserve(num_streams_); + for (int64 idx = 0; idx < num_streams; ++idx) { + streams_.push_back(QuantileStream(epsilon, max_elements)); + boundaries_.push_back(std::vector<float>()); + } + } + + string DebugString() override { return "QuantileStreamResource"; } + + tensorflow::mutex* mutex() { return &mu_; } + + QuantileStream* stream(const int64 index) { return &streams_[index]; } + + const std::vector<float>& boundaries(const int64 index) { + return boundaries_[index]; + } + + void set_boundaries(const std::vector<float>& boundaries, const int64 index) { + boundaries_[index] = boundaries; + } + + float epsilon() const { return epsilon_; } + int64 num_streams() const { return num_streams_; } + + bool are_buckets_ready() const { return are_buckets_ready_; } + void set_buckets_ready(const bool are_buckets_ready) { + are_buckets_ready_ = are_buckets_ready; + } + + private: + ~BoostedTreesQuantileStreamResource() override {} + + // Mutex for the whole resource. + tensorflow::mutex mu_; + + // Quantile streams. + std::vector<QuantileStream> streams_; + + // Stores the boundaries. Same size as streams_. + std::vector<std::vector<float>> boundaries_; + + // Whether boundaries are created. Initially boundaries are empty until + // set_boundaries are called. + bool are_buckets_ready_; + + const float epsilon_; + const int64 num_streams_; + // An upper-bound for the number of elements. 
+ int64 max_elements_; + + TF_DISALLOW_COPY_AND_ASSIGN(BoostedTreesQuantileStreamResource); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_ diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc index cc90bb2f45..2798722536 100644 --- a/tensorflow/core/kernels/boosted_trees/resources.cc +++ b/tensorflow/core/kernels/boosted_trees/resources.cc @@ -60,14 +60,26 @@ int32 BoostedTreesEnsembleResource::next_node( DCHECK_LT(tree_id, tree_ensemble_->trees_size()); DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size()); const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id); - DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit); - const auto& split = node.bucketized_split(); - if (bucketized_features[split.feature_id()](index_in_batch) <= - split.threshold()) { - return split.left_id(); - } else { - return split.right_id(); + + switch (node.node_case()) { + case boosted_trees::Node::kBucketizedSplit: { + const auto& split = node.bucketized_split(); + return (bucketized_features[split.feature_id()](index_in_batch) <= + split.threshold()) + ? split.left_id() + : split.right_id(); + } + case boosted_trees::Node::kCategoricalSplit: { + const auto& split = node.categorical_split(); + return (bucketized_features[split.feature_id()](index_in_batch) == + split.value()) + ? 
split.left_id() + : split.right_id(); + } + default: + DCHECK(false) << "Node type " << node.node_case() << " not supported."; } + return -1; } float BoostedTreesEnsembleResource::node_value(const int32 tree_id, diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h index a7836896c7..390db8fe5a 100644 --- a/tensorflow/core/kernels/conditional_accumulator.h +++ b/tensorflow/core/kernels/conditional_accumulator.h @@ -51,9 +51,11 @@ class ConditionalAccumulator // dtype: The datatype of the gradients to be accumulated. // shape: The shape of the accumulated gradients. // name: A name to use for the ConditionalAccumulator. + // reduction_type: The reduction type, i.e., MEAN or SUM ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape, - const string& name) - : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {} + const string& name, const string& reduction_type) + : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name, + reduction_type) {} ~ConditionalAccumulator() override{}; protected: diff --git a/tensorflow/core/kernels/conditional_accumulator_base.cc b/tensorflow/core/kernels/conditional_accumulator_base.cc index 90593c56b8..292cf0cd64 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base.cc +++ b/tensorflow/core/kernels/conditional_accumulator_base.cc @@ -14,12 +14,17 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/kernels/conditional_accumulator_base.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { ConditionalAccumulatorBase::ConditionalAccumulatorBase( - const DataType& dtype, const PartialTensorShape& shape, const string& name) - : dtype_(dtype), shape_(shape), name_(name) { + const DataType& dtype, const PartialTensorShape& shape, const string& name, + const string& reduction_type) + : dtype_(dtype), + shape_(shape), + name_(name), + reduction_type_(reduction_type) { counter_ = 0; current_global_step_ = 0; } @@ -190,7 +195,9 @@ bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx, current_global_step_++; // Average the accumulated gradient - DivideAccumGradByCounter(ctx); + if (reduction_type_ == "MEAN") { + DivideAccumGradByCounter(ctx); + } // Set output for accumulated gradient tensor bool successful_set_output = SetOutput(ctx); diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h index b7b7482a00..4a5ec6f0fb 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base.h +++ b/tensorflow/core/kernels/conditional_accumulator_base.h @@ -52,7 +52,7 @@ class ConditionalAccumulatorBase : public ResourceBase { // name: A name to use for the ConditionalAccumulator. 
ConditionalAccumulatorBase(const DataType& dtype, const PartialTensorShape& shape, - const string& name); + const string& name, const string& reduction_type); typedef AsyncOpKernel::DoneCallback DoneCallback; @@ -125,6 +125,7 @@ class ConditionalAccumulatorBase : public ResourceBase { const DataType dtype_; const PartialTensorShape shape_; const string name_; + const string reduction_type_; mutex mu_; int counter_ GUARDED_BY(mu_); int64 current_global_step_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h index 012a0dcc12..ca24d690f8 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base_op.h +++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h @@ -51,6 +51,8 @@ class ConditionalAccumulatorBaseOp : public OpKernel { &accumulator_handle_, nullptr)); OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); + OP_REQUIRES_OK(context, + context->GetAttr("reduction_type", &reduction_type_)); } void Compute(OpKernelContext* ctx) override { @@ -81,6 +83,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel { DataType dtype_; PartialTensorShape shape_; ContainerInfo cinfo_; + string reduction_type_; private: Status SetAccumulatorHandle(OpKernelContext* ctx) diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc index e13bf8a4c6..52ac51a9b6 100644 --- a/tensorflow/core/kernels/conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/conditional_accumulator_op.cc @@ -34,7 +34,8 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { Creator GetCreator() const override { return [this](ConditionalAccumulatorBase** ret) { ConditionalAccumulator<Device, T>* accumulator = - new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name()); + new ConditionalAccumulator<Device, T>(dtype_, shape_, 
cinfo_.name(), + reduction_type_); *ret = accumulator; return Status::OK(); }; diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h index de9b69828e..639c3062cc 100644 --- a/tensorflow/core/kernels/conv_2d.h +++ b/tensorflow/core/kernels/conv_2d.h @@ -137,17 +137,16 @@ struct MatMulConvFunctor { } }; -// Shuffles a filter tensor from: -// [<spatial_dims>, in, out] -// to: -// [out, in, <spatial_dims>] +// Shuffles a filter tensor from TensorFlow format HWIO to dst_filter_format. +// +// Note: Currently OIHW is the only supported destination format. Support for +// OHWI format will be added in a follow-up change. template <typename Device, typename T, typename IndexType, int NDIMS> struct TransformFilter { - void operator()(const Device& d, + void operator()(const Device& d, FilterTensorFormat dst_filter_format, typename TTypes<T, NDIMS, IndexType>::ConstTensor in, typename TTypes<T, NDIMS, IndexType>::Tensor out) { - // We want a 3, 2, 0, 1 shuffle. Merge the spatial dimensions together - // to speed up the shuffle operation. + // Merge the spatial dimensions together to speed up the shuffle operation. Eigen::DSizes<IndexType, 3> merged_dims; merged_dims[0] = in.dimension(0); // spatial dimensions for (int i = 1; i < NDIMS - 2; ++i) { @@ -156,16 +155,30 @@ struct TransformFilter { merged_dims[1] = in.dimension(NDIMS - 2); // input filters merged_dims[2] = in.dimension(NDIMS - 1); // output filters + CHECK(dst_filter_format == FORMAT_OIHW) + << "Unsupported destination filter format: " + << ToString(dst_filter_format); + // Source filter format is FORMAT_HWIO and spatial dimensions HW are merged + // in the beginning. 
+ Eigen::DSizes<IndexType, 3> shuffling_perm = + Eigen::DSizes<IndexType, 3>(2, 1, 0); + Eigen::DSizes<IndexType, NDIMS> expanded_dims; - expanded_dims[0] = in.dimension(NDIMS - 1); // output filters - expanded_dims[1] = in.dimension(NDIMS - 2); // input filters - for (int i = 0; i < NDIMS - 2; ++i) { // spatial dimensions - expanded_dims[i + 2] = in.dimension(i); + int out_index = 0; + for (int merged_dim = 0; merged_dim < merged_dims.rank(); ++merged_dim) { + if (shuffling_perm[merged_dim] == 0) { + for (int spatial_dim = 0; spatial_dim < NDIMS - 2; ++spatial_dim) { + expanded_dims[out_index++] = in.dimension(spatial_dim); + } + } else { + constexpr int kLastSpatialDim = NDIMS - 3; + expanded_dims[out_index++] = + in.dimension(kLastSpatialDim + shuffling_perm[merged_dim]); + } } - out.device(d) = in.reshape(merged_dims) - .shuffle(Eigen::DSizes<IndexType, 3>(2, 1, 0)) - .reshape(expanded_dims); + out.device(d) = + in.reshape(merged_dims).shuffle(shuffling_perm).reshape(expanded_dims); } }; @@ -282,7 +295,9 @@ struct SwapDimension0And2InTensor3 { const gtl::ArraySlice<int64>& input_dims, T* out); }; -// Reverses the effect of TransformFilter above. +// Transforms back filter from OIHW to HWOI format to reverse effect of +// TransformFilter above. +// TODO(hinsu): Support reverse transformation from filter format OHWI as well. template <typename Device, typename T, int NDIMS> struct ReverseTransformFilter { void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in, diff --git a/tensorflow/core/kernels/conv_3d.h b/tensorflow/core/kernels/conv_3d.h index 02e3655ad1..b819c6f910 100644 --- a/tensorflow/core/kernels/conv_3d.h +++ b/tensorflow/core/kernels/conv_3d.h @@ -19,6 +19,7 @@ limitations under the License. 
#define TENSORFLOW_CORE_KERNELS_CONV_3D_H_ #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h" #include "tensorflow/core/kernels/eigen_cuboid_convolution.h" namespace tensorflow { @@ -28,6 +29,14 @@ namespace functor { template <typename Device, typename T> struct CuboidConvolution; +// Backward input pass for the cuboid convolution. +template <typename Device, typename T> +struct CuboidConvolutionBackwardInput; + +// Backward filter pass for the cuboid convolution. +template <typename Device, typename T> +struct CuboidConvolutionBackwardFilter; + typedef Eigen::ThreadPoolDevice CPUDevice; template <typename T> @@ -42,6 +51,40 @@ struct CuboidConvolution<CPUDevice, T> { } }; +template <typename T> +struct CuboidConvolutionBackwardInput<CPUDevice, T> { + void operator()(const CPUDevice& d, + typename TTypes<T, 5>::Tensor input_backward, + typename TTypes<T, 5>::ConstTensor filter, + typename TTypes<T, 5>::ConstTensor output_backward, + int stride_planes, int stride_rows, int stride_cols) { + // Need to swap the order of plane/row/col strides when calling Eigen. + input_backward.device(d) = Eigen::CuboidConvolutionBackwardInput( + filter, output_backward, + input_backward.dimension(3), // input_planes + input_backward.dimension(2), // input_rows + input_backward.dimension(1), // input_cols + stride_cols, stride_rows, stride_planes); + } +}; + +template <typename T> +struct CuboidConvolutionBackwardFilter<CPUDevice, T> { + void operator()(const CPUDevice& d, + typename TTypes<T, 5>::Tensor filter_backward, + typename TTypes<T, 5>::ConstTensor input, + typename TTypes<T, 5>::ConstTensor output_backward, + int stride_planes, int stride_rows, int stride_cols) { + // Need to swap the order of plane/row/col strides when calling Eigen. 
+ filter_backward.device(d) = Eigen::CuboidConvolutionBackwardKernel( + input, output_backward, + filter_backward.dimension(2), // kernel_planes + filter_backward.dimension(1), // kernel_rows + filter_backward.dimension(0), // kernel_cols + stride_cols, stride_rows, stride_planes); + } +}; + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 63b1bcda43..9e86a16b66 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -1018,7 +1018,8 @@ namespace functor { extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \ template <> \ void TransformFilter<GPUDevice, T, int, 4>::operator()( \ - const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \ + const GPUDevice& d, FilterTensorFormat dst_filter_format, \ + typename TTypes<T, 4, int>::ConstTensor in, \ typename TTypes<T, 4, int>::Tensor out); \ extern template struct TransformFilter<GPUDevice, T, int, 4>; \ template <> \ diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index d664a11e73..43bb5ea56c 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -901,7 +901,8 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()( &transformed_filter)); functor::TransformFilter<GPUDevice, T, int, 4>()( - ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()), + ctx->eigen_device<GPUDevice>(), FORMAT_OIHW, + To32Bit(filter.tensor<T, 4>()), To32Bit(transformed_filter.tensor<T, 4>())); Tensor transformed_out_backprop; @@ -1090,7 +1091,8 @@ namespace functor { extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \ template <> \ void TransformFilter<GPUDevice, T, int, 4>::operator()( \ - const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \ + const GPUDevice& d, FilterTensorFormat 
dst_filter_format, \ + typename TTypes<T, 4, int>::ConstTensor in, \ typename TTypes<T, 4, int>::Tensor out); \ extern template struct TransformFilter<GPUDevice, T, int, 4>; \ template <> \ diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc index fc0a2f123f..507720c998 100644 --- a/tensorflow/core/kernels/conv_grad_ops.cc +++ b/tensorflow/core/kernels/conv_grad_ops.cc @@ -41,6 +41,17 @@ limitations under the License. namespace tensorflow { +// Compute padding for the given spatial dimension. +int ConvBackpropDimensions::SpatialPadding(const Padding& padding, + int dim) const { + return (padding == VALID) + ? 0 + : std::max<int>( + 0, static_cast<int>((output_size(dim) - 1) * stride(dim) + + (filter_size(dim) - 1) * dilation(dim) + + 1 - input_size(dim))); +} + // The V2 version computes windowed output size with arbitrary dilation_rate, // while the original version only handles the cases where dilation_rates equal // to 1. diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h index 535586d53a..9551959463 100644 --- a/tensorflow/core/kernels/conv_grad_ops.h +++ b/tensorflow/core/kernels/conv_grad_ops.h @@ -234,6 +234,16 @@ struct ConvBackpropDimensions { // Input and output feature depth. int64 in_depth, out_depth; + + // Convenience access methods for spatial dimensions properties. + int64 input_size(int dim) const { return spatial_dims[dim].input_size; } + int64 filter_size(int dim) const { return spatial_dims[dim].filter_size; } + int64 output_size(int dim) const { return spatial_dims[dim].output_size; } + int64 stride(int dim) const { return spatial_dims[dim].stride; } + int64 dilation(int dim) const { return spatial_dims[dim].dilation; } + + // Compute padding for the given spatial dimension. 
+ int SpatialPadding(const Padding& padding, int dim) const; }; // Common code between implementations of Conv?DBackpropInput and diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc index 15f1bf9aba..bab91f5e86 100644 --- a/tensorflow/core/kernels/conv_grad_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/conv_ops_gpu.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" @@ -32,111 +33,130 @@ limitations under the License. #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/use_cudnn.h" +#include "tensorflow/core/util/work_sharder.h" #if GOOGLE_CUDA #include "tensorflow/core/platform/stream_executor.h" using stream_executor::dnn::DimIndex; #endif +namespace { + +// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and +// conv_grad_input_ops_3d.cc. + +// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels. + +// "Depth" is already used for the channel dimension, so for the third spatial +// dimension in this file we use "plane", although in NDHWC layout it's +// indicated with a "D". + +// Returns in 'im_data' (assumed to be zero-initialized) image patch in storage +// order (planes, height, width, depth), constructed from patches in 'col_data', +// which is required to be in storage order (out_planes * out_height * +// out_width, filter_planes, filter_height, filter_width, in_depth). +// +// Based on 2-dimensional implementation written by Yangqing Jia (jiayq). 
+template <typename T> +void Col2im(const T* col_data, const int depth, const int planes, + const int height, const int width, const int filter_p, + const int filter_h, const int filter_w, const int pad_pt, + const int pad_t, const int pad_l, const int pad_pb, const int pad_b, + const int pad_r, const int stride_p, const int stride_h, + const int stride_w, T* im_data) { + const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1; + const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1; + const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1; + int p_pad = -pad_pt; + for (int p = 0; p < planes_col; ++p) { + int h_pad = -pad_t; + for (int h = 0; h < height_col; ++h) { + int w_pad = -pad_l; + for (int w = 0; w < width_col; ++w) { + T* im_patch_data = + im_data + (p_pad * height * width + h_pad * width + w_pad) * depth; + for (int ip = p_pad; ip < p_pad + filter_p; ++ip) { + for (int ih = h_pad; ih < h_pad + filter_h; ++ih) { + for (int iw = w_pad; iw < w_pad + filter_w; ++iw) { + if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 && + iw < width) { + for (int i = 0; i < depth; ++i) { + im_patch_data[i] += col_data[i]; + } + } + im_patch_data += depth; + col_data += depth; + } + // Jump over remaining number of depth. + im_patch_data += depth * (width - filter_w); + } + // Jump over remaining number of (depth * width). + im_patch_data += (depth * width) * (height - filter_h); + } + w_pad += stride_w; + } + h_pad += stride_h; + } + p_pad += stride_p; + } +} + +// Returns in 'col_data', image patches in storage order (planes, height, width, +// depth) extracted from image at 'input_data', which is required to be in +// storage order (batch, planes, height, width, depth). +// +// Based on 2-dimensional implementation written by Yangqing Jia (jiayq). 
+template <typename T> +void Im2col(const T* input_data, const int depth, const int planes, + const int height, const int width, const int filter_p, + const int filter_h, const int filter_w, const int pad_pt, + const int pad_t, const int pad_l, const int pad_pb, const int pad_b, + const int pad_r, const int stride_p, const int stride_h, + const int stride_w, T* col_data) { + const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1; + const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1; + const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1; + + int p_pad = -pad_pt; + for (int p = 0; p < planes_col; ++p) { + int h_pad = -pad_t; + for (int h = 0; h < height_col; ++h) { + int w_pad = -pad_l; + for (int w = 0; w < width_col; ++w) { + for (int ip = p_pad; ip < p_pad + filter_p; ++ip) { + for (int ih = h_pad; ih < h_pad + filter_h; ++ih) { + for (int iw = w_pad; iw < w_pad + filter_w; ++iw) { + if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 && + iw < width) { + memcpy(col_data, + input_data + + (ip * height * width + ih * width + iw) * depth, + sizeof(T) * depth); + } else { + // This should be simply padded with zero. + memset(col_data, 0, sizeof(T) * depth); + } + col_data += depth; + } + } + } + w_pad += stride_w; + } + h_pad += stride_h; + } + p_pad += stride_p; + } +} + +} // namespace + namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -// TODO(mjanusz): Get rid of the macro and return shapes directly. 
-#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \ - const Tensor& out_backprop = context->input(2); \ - OP_REQUIRES( \ - context, input_shape.dims() == 5, \ - errors::InvalidArgument(label, ": input must be 5-dimensional")); \ - OP_REQUIRES( \ - context, filter_shape.dims() == 5, \ - errors::InvalidArgument(label, ": filter must be 5-dimensional")); \ - OP_REQUIRES( \ - context, out_backprop.dims() == 5, \ - errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \ - const int64 batch = input_shape.dim_size(0); \ - OP_REQUIRES( \ - context, batch == out_backprop.dim_size(0), \ - errors::InvalidArgument( \ - label, ": input and out_backprop must have the same batch size")); \ - const std::array<int64, 3> input_size = { \ - {GetTensorDim(input_shape, data_format_, '0'), \ - GetTensorDim(input_shape, data_format_, '1'), \ - GetTensorDim(input_shape, data_format_, '2')}}; \ - const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C'); \ - const std::array<int64, 3> filter_size = {{filter_shape.dim_size(0), \ - filter_shape.dim_size(1), \ - filter_shape.dim_size(2)}}; \ - const int64 output_cols = GetTensorDim(out_backprop, data_format_, '2'); \ - const int64 output_rows = GetTensorDim(out_backprop, data_format_, '1'); \ - const int64 output_planes = GetTensorDim(out_backprop, data_format_, '0'); \ - OP_REQUIRES(context, in_depth == filter_shape.dim_size(3), \ - errors::InvalidArgument( \ - label, ": input and filter must have the same depth")); \ - const int64 out_depth = filter_shape.dim_size(4); \ - OP_REQUIRES( \ - context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'), \ - errors::InvalidArgument( \ - label, ": filter and out_backprop must have the same out_depth")); \ - const std::array<int64, 3> dilations = { \ - {GetTensorDim(dilation_, data_format_, '0'), \ - GetTensorDim(dilation_, data_format_, '1'), \ - GetTensorDim(dilation_, data_format_, '2')}}; \ - const std::array<int64, 3> strides = { \ - {GetTensorDim(stride_, 
data_format_, '0'), \ - GetTensorDim(stride_, data_format_, '1'), \ - GetTensorDim(stride_, data_format_, '2')}}; \ - std::array<int64, 3> out, padding; \ - OP_REQUIRES_OK( \ - context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides, \ - padding_, &out, &padding)); \ - OP_REQUIRES(context, output_planes == out[0], \ - errors::InvalidArgument( \ - label, \ - ": Number of planes of out_backprop doesn't match " \ - "computed: actual = ", \ - output_planes, ", computed = ", out[0])); \ - OP_REQUIRES( \ - context, output_rows == out[1], \ - errors::InvalidArgument( \ - label, ": Number of rows of out_backprop doesn't match computed: ", \ - "actual = ", output_rows, ", computed = ", out[1])); \ - OP_REQUIRES( \ - context, output_cols == out[2], \ - errors::InvalidArgument( \ - label, ": Number of cols of out_backprop doesn't match computed: ", \ - "actual = ", output_cols, ", computed = ", out[2])); \ - const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1; \ - const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1; \ - const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1; \ - const auto padded_out_planes = input_size[0] + filter_size[0] - 1; \ - const auto padded_out_rows = input_size[1] + filter_size[1] - 1; \ - const auto padded_out_cols = input_size[2] + filter_size[2] - 1; \ - const auto top_pad_planes = filter_size[0] - 1 - padding[0]; \ - const auto top_pad_rows = filter_size[1] - 1 - padding[1]; \ - const auto left_pad_cols = filter_size[2] - 1 - padding[2]; \ - const auto bottom_pad_planes = \ - padded_out_planes - expanded_out_planes - top_pad_planes; \ - const auto bottom_pad_rows = \ - padded_out_rows - expanded_out_rows - top_pad_rows; \ - const auto right_pad_cols = \ - padded_out_cols - expanded_out_cols - left_pad_cols; \ - VLOG(2) << "Conv3d: " << label \ - << ": expanded_out_planes = " << expanded_out_planes \ - << ": expanded_out_rows = " << expanded_out_rows \ - << ", expanded_out_cols = " << 
expanded_out_cols \ - << ", padded_out_planes = " << padded_out_planes \ - << ", padded_out_rows = " << padded_out_rows \ - << ", padded_out_cols = " << padded_out_cols \ - << ", top_pad_planes = " << top_pad_planes \ - << ", top_pad_rows = " << top_pad_rows \ - << ", left_pad_cols = " << left_pad_cols \ - << ", bottom_pad_planes = " << bottom_pad_planes \ - << ", bottom_pad_rows = " << bottom_pad_rows \ - << ", right_pad_cols = " << right_pad_cols - -// Backprop for input. +// Backprop for input that offloads computation to +// Eigen::CuboidConvolutionBackwardInput. template <typename Device, class T> class Conv3DBackpropInputOp : public OpKernel { public: @@ -192,6 +212,116 @@ class Conv3DBackpropInputOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& filter = context->input(1); const TensorShape& filter_shape = filter.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + + TensorShape input_shape; + if (takes_shape_) { + const Tensor& input_sizes = context->input(0); + // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes. 
+ OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape)); + } else { + input_shape = context->input(0).shape(); + } + + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, ConvBackpropComputeDimensions( + "Conv3DBackpropInputOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, + stride_, padding_, data_format_, &dims)); + + Tensor* in_backprop; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_shape, &in_backprop)); + + functor::CuboidConvolutionBackwardInput<Device, T>()( + context->eigen_device<Device>(), + in_backprop->tensor<T, 5>(), // input_backward + filter.tensor<T, 5>(), // filter + out_backprop.tensor<T, 5>(), // output_backward + static_cast<int>(dims.spatial_dims[0].stride), // stride_planes + static_cast<int>(dims.spatial_dims[1].stride), // stride_rows + static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols + } + + private: + std::vector<int32> dilation_; + std::vector<int32> stride_; + Padding padding_; + TensorFormat data_format_; + bool takes_shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp); +}; + +// Custom backprop for input that explicitly does the work sharding and calls +// Eigen only to multiply matrices. +template <typename Device, class T> +class Conv3DCustomBackpropInputOp : public OpKernel { + // Limit the maximum size of allocated temporary buffer to + // kMaxTempAllocationOverhead times the size of the input tensors (input, + // filter, out_backprop). If the size of the temporary buffer exceeds this + // limit, fallback on Eigen implementation. + static constexpr int kMaxTempAllocationOverhead = 25; + + public: + explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context) + : OpKernel(context), + data_format_(FORMAT_NHWC), + takes_shape_(type_string().find("V2") != std::string::npos) { + // data_format is only available in V2. 
+ if (takes_shape_) { + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES( + context, data_format_ == FORMAT_NHWC, + errors::InvalidArgument( + "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU.")); + } + + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_)); + OP_REQUIRES(context, dilation_.size() == 5, + errors::InvalidArgument("Dilation rates field must " + "specify 5 dimensions")); + OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, 'C') == 1 && + GetTensorDim(dilation_, data_format_, 'N') == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilation rates in the batch and depth dimensions.")); + + // TODO(yangzihao): Add CPU version of dilated conv 3D. + OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, '0') == 1 && + GetTensorDim(dilation_, data_format_, '1') == 1 && + GetTensorDim(dilation_, data_format_, '2') == 1), + errors::InvalidArgument( + "Current CPU implementation does not yet support " + "dilation rates larger than 1.")); + + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES( + context, + (GetTensorDim(stride_, data_format_, 'C') == 1 && + GetTensorDim(stride_, data_format_, 'N') == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& filter = context->input(1); + const TensorShape& filter_shape = filter.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); 
+ TensorShape input_shape; if (takes_shape_) { const Tensor& input_sizes = context->input(0); @@ -200,51 +330,239 @@ class Conv3DBackpropInputOp : public OpKernel { } else { input_shape = context->input(0).shape(); } - EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput"); - Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{ - {0, 0}, - {top_pad_planes, bottom_pad_planes}, - {top_pad_rows, bottom_pad_rows}, - {left_pad_cols, right_pad_cols}, - {0, 0}}; + + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, ConvBackpropComputeDimensions( + "Conv3DBackpropInputOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, + stride_, padding_, data_format_, &dims)); + Tensor* in_backprop; OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); - // Fill out a padded out_backprop. - TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows, - padded_out_cols, out_depth}); - Tensor padded_output; + int64 top_pad_planes, bottom_pad_planes; + int64 top_pad_rows, bottom_pad_rows; + int64 left_pad_cols, right_pad_cols; + + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, + dims.spatial_dims[0].filter_size, + dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, + &top_pad_planes, &bottom_pad_planes)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, + dims.spatial_dims[1].filter_size, + dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, + &top_pad_rows, &bottom_pad_rows)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[2].input_size, + dims.spatial_dims[2].filter_size, + dims.spatial_dims[2].stride, padding_, + &dims.spatial_dims[2].output_size, + &left_pad_cols, &right_pad_cols)); + + // TODO(ezhulenev): Extract work size and shard estimation to shared + // functions in conv_grad_ops, and update 2d convolution backprop. 
+ + // The total dimension size of each kernel. + const int64 filter_total_size = + dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size * + dims.spatial_dims[2].filter_size * dims.in_depth; + + // The output image size is the spatial size of the output. + const int64 output_image_size = dims.spatial_dims[0].output_size * + dims.spatial_dims[1].output_size * + dims.spatial_dims[2].output_size; + + const auto cache_sizes = Eigen::internal::CacheSizes(); + const ptrdiff_t l3_cache_size = cache_sizes.m_l3; + + // Use L3 cache size as target working set size. + const size_t target_working_set_size = l3_cache_size / sizeof(T); + + // Calculate size of matrices involved in MatMul: C = A x B. + const int64 size_A = output_image_size * dims.out_depth; + + const int64 size_B = filter_total_size * dims.out_depth; + + const int64 size_C = output_image_size * filter_total_size; + + const int64 work_unit_size = size_A + size_B + size_C; + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + // Use parallel tensor contractions if there is no batching. + // + // Compared to Conv2D code, this version is missing work size estimation. In + // benchmarks I didn't find a case when it's beneficial to run parallel + // contraction compared to sharding and matmuls. + const bool use_parallel_contraction = dims.batch_size == 1; + + const size_t shard_size = + use_parallel_contraction + ? 1 + : (target_working_set_size + work_unit_size - 1) / work_unit_size; + + // Total number of elements in all the tensors used by this kernel. + int64 total_tensor_elements = input_shape.num_elements() + + filter_shape.num_elements() + + out_backprop_shape.num_elements(); + + // Shape of the temporary workspace buffer. 
+ TensorShape col_buffer_shape = {static_cast<int64>(shard_size), + static_cast<int64>(output_image_size), + static_cast<int64>(filter_total_size)}; + int64 col_buffer_elements = col_buffer_shape.num_elements(); + + // If the temporary allocation overhead is too large, fall back on the Eigen + // implementation, which requires much less memory. + // NOTE(review): total_tensor_elements (the divisor below) is a sum of three + // num_elements() counts and is zero if all three tensors are empty; confirm + // that earlier shape validation rules this out. + int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements; + if (col_buffer_overhead > kMaxTempAllocationOverhead) { + VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: " + "col_buffer_overhead=" + << col_buffer_overhead; + + functor::CuboidConvolutionBackwardInput<Device, T>()( + context->eigen_device<Device>(), + in_backprop->tensor<T, 5>(), // input_backward + filter.tensor<T, 5>(), // filter + out_backprop.tensor<T, 5>(), // output_backward + static_cast<int>(dims.spatial_dims[0].stride), // stride_planes + static_cast<int>(dims.spatial_dims[1].stride), // stride_rows + static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols + + return; + } + + Tensor col_buffer; OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::v(), - padded_out_shape, &padded_output)); - Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4}; - Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1], - strides[2], 1}; - functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), out_backprop.tensor<T, 5>(), - eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>()); - const Tensor& padded_output_cref = padded_output; - - // Fill a new "reverted" filter. We need to transpose the in_depth and - // out_depth for the filter and reverse the planes, rows and cols. 
- TensorShape r_filter_shape( - {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth}); - Tensor r_filter; - OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(), - r_filter_shape, &r_filter)); - Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3}; - Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false}; - functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order, - filter_rev_dims, r_filter.tensor<T, 5>()); - const Tensor& r_filter_cref = r_filter; - - // Now we can call conv_3d directly. - functor::CuboidConvolution<Device, T>()( - context->eigen_device<Device>(), in_backprop->tensor<T, 5>(), - padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1, - 1, BrainPadding2EigenPadding(VALID)); + context->allocate_temp(DataTypeToEnum<T>::value, + col_buffer_shape, &col_buffer)); + + // The input offset corresponding to a single input image. + const int64 input_offset = dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size * + dims.spatial_dims[2].input_size * dims.in_depth; + + // The output offset corresponding to a single output image. 
+ const int64 output_offset = + dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size * + dims.spatial_dims[2].output_size * dims.out_depth; + + const T* filter_data = filter.template flat<T>().data(); + T* col_buffer_data = col_buffer.template flat<T>().data(); + const T* out_backprop_data = out_backprop.template flat<T>().data(); + + auto in_backprop_flat = in_backprop->template flat<T>(); + T* input_backprop_data = in_backprop_flat.data(); + in_backprop_flat.device(context->eigen_device<Device>()) = + in_backprop_flat.constant(T(0)); + + if (use_parallel_contraction) { + typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, + Eigen::Unaligned> + TensorMap; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, + Eigen::Unaligned> + ConstTensorMap; + + // Initialize contraction dims (we need to transpose 'B' below). + Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims; + contract_dims[0].first = 1; + contract_dims[0].second = 1; + + for (int image_id = 0; image_id < dims.batch_size; ++image_id) { + // Compute gradient into col_buffer. + TensorMap C(col_buffer_data, output_image_size, filter_total_size); + + ConstTensorMap A(out_backprop_data + output_offset * image_id, + output_image_size, dims.out_depth); + ConstTensorMap B(filter_data, filter_total_size, dims.out_depth); + + C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims); + + Col2im<T>(col_buffer_data, dims.in_depth, + // Input spatial dimensions. + dims.spatial_dims[0].input_size, // input planes + dims.spatial_dims[1].input_size, // input rows + dims.spatial_dims[2].input_size, // input cols + // Filter spatial dimensions. + dims.spatial_dims[0].filter_size, // filter planes + dims.spatial_dims[1].filter_size, // filter rows + dims.spatial_dims[2].filter_size, // filter cols + // Spatial padding. + top_pad_planes, top_pad_rows, left_pad_cols, + bottom_pad_planes, bottom_pad_rows, right_pad_cols, + // Spatial striding. 
+ dims.spatial_dims[0].stride, // stride planes + dims.spatial_dims[1].stride, // stride rows + dims.spatial_dims[2].stride, // stride cols + input_backprop_data); + + input_backprop_data += input_offset; + } + } else { + typedef Eigen::Map< + Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> + MatrixMap; + typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, + Eigen::RowMajor>> + ConstMatrixMap; + + for (int image_id = 0; image_id < dims.batch_size; + image_id += shard_size) { + const int shard_limit = + std::min(static_cast<int>(shard_size), + static_cast<int>(dims.batch_size) - image_id); + + auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols, + &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols, + &output_image_size, &filter_total_size, + &input_backprop_data, &col_buffer_data, + &out_backprop_data, &filter_data, &input_offset, + &output_offset, &size_C](int64 start, int64 limit) { + for (int shard_id = start; shard_id < limit; ++shard_id) { + T* im2col_buf = col_buffer_data + shard_id * size_C; + T* input_data = input_backprop_data + shard_id * input_offset; + const T* out_data = out_backprop_data + shard_id * output_offset; + + // Compute gradient into 'im2col_buf'. + MatrixMap C(im2col_buf, output_image_size, filter_total_size); + + ConstMatrixMap A(out_data, output_image_size, dims.out_depth); + ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth); + + C.noalias() = A * B.transpose(); + + Col2im<T>(im2col_buf, dims.in_depth, + // Input spatial dimensions. + dims.spatial_dims[0].input_size, // input planes + dims.spatial_dims[1].input_size, // input rows + dims.spatial_dims[2].input_size, // input cols + // Filter spatial dimensions. + dims.spatial_dims[0].filter_size, // filter planes + dims.spatial_dims[1].filter_size, // filter rows + dims.spatial_dims[2].filter_size, // filter cols + // Spatial padding. 
+ top_pad_planes, top_pad_rows, left_pad_cols, + bottom_pad_planes, bottom_pad_rows, right_pad_cols, + // Spatial striding. + dims.spatial_dims[0].stride, // stride planes + dims.spatial_dims[1].stride, // stride rows + dims.spatial_dims[2].stride, // stride cols + input_data); + } + }; + Shard(worker_threads.num_threads, worker_threads.workers, shard_limit, + work_unit_size, shard); + + input_backprop_data += input_offset * shard_limit; + out_backprop_data += output_offset * shard_limit; + } + } } private: @@ -253,21 +571,48 @@ class Conv3DBackpropInputOp : public OpKernel { Padding padding_; TensorFormat data_format_; bool takes_shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp); }; +// Custom backprop input kernel is 30% - 4x faster when compiled with AVX2 than +// default Eigen implementation (at the cost of ~2x-8x peak memory usage). + #define REGISTER_CPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - Conv3DBackpropInputOp<CPUDevice, T>); \ + Conv3DCustomBackpropInputOp<CPUDevice, T>); \ REGISTER_KERNEL_BUILDER( \ Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - Conv3DBackpropInputOp<CPUDevice, T>); + Conv3DCustomBackpropInputOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \ + .Device(DEVICE_CPU) \ + .Label("custom") \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropInputOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ + .Device(DEVICE_CPU) \ + .Label("custom") \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropInputOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \ + .Device(DEVICE_CPU) \ + .Label("eigen_tensor") \ + .TypeConstraint<T>("T"), \ + Conv3DBackpropInputOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ + .Device(DEVICE_CPU) \ + .Label("eigen_tensor") \ + .TypeConstraint<T>("T"), \ + Conv3DBackpropInputOp<CPUDevice, T>); + 
TF_CALL_half(REGISTER_CPU_KERNEL); TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL -// Backprop for filter. +// Backprop for filter that offloads computation to +// Eigen::CuboidConvolutionBackwardFilter. template <typename Device, class T> class Conv3DBackpropFilterOp : public OpKernel { public: @@ -323,8 +668,11 @@ class Conv3DBackpropFilterOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); const TensorShape& input_shape = input.shape(); - TensorShape filter_shape; + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + + TensorShape filter_shape; if (takes_shape_) { const Tensor& filter_sizes = context->input(1); OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( @@ -333,13 +681,13 @@ class Conv3DBackpropFilterOp : public OpKernel { filter_shape = context->input(1).shape(); } - EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter"); - Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{ - {0, 0}, - {top_pad_planes, bottom_pad_planes}, - {top_pad_rows, bottom_pad_rows}, - {left_pad_cols, right_pad_cols}, - {0, 0}}; + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, + ConvBackpropComputeDimensions( + "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, stride_, + padding_, data_format_, &dims)); + Tensor* filter_backprop; OP_REQUIRES_OK(context, context->allocate_output(0, filter_shape, &filter_backprop)); @@ -349,70 +697,292 @@ class Conv3DBackpropFilterOp : public OpKernel { return; } - // For the backprop of the filter, we need to also transpose the - // out_backprop. 
- // The shape of backprop is - // [batch, out_z, out_y, out_x, out_depth] - // And we need to change it to - // [out_depth, out_x, out_y, out_z, batch] - Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0}; - TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows, - padded_out_cols, batch}); - Tensor padded_output; + functor::CuboidConvolutionBackwardFilter<Device, T>()( + context->eigen_device<Device>(), + filter_backprop->tensor<T, 5>(), // filter_backward + input.tensor<T, 5>(), // input + out_backprop.tensor<T, 5>(), // output_backward + static_cast<int>(dims.spatial_dims[0].stride), // stride_planes + static_cast<int>(dims.spatial_dims[1].stride), // stride_rows + static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols + } + + private: + std::vector<int32> dilation_; + std::vector<int32> stride_; + Padding padding_; + TensorFormat data_format_; + bool takes_shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp); +}; + +// Custom backprop for filter that explicitly does the work sharding and calls +// Eigen only to multiply matrices. +template <typename Device, class T> +class Conv3DCustomBackpropFilterOp : public OpKernel { + // Limit the maximum size of allocated temporary buffer to + // kMaxTempAllocationOverhead times the size of the input tensors (input, + // filter, out_backprop). If the size of the temporary buffer exceeds this + // limit, fallback on Eigen implementation. + static constexpr int kMaxTempAllocationOverhead = 25; + + public: + explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context) + : OpKernel(context), + data_format_(FORMAT_NHWC), + takes_shape_(type_string().find("V2") != std::string::npos) { + // data_format is only available in V2. 
+ if (takes_shape_) { + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES( + context, data_format_ == FORMAT_NHWC, + errors::InvalidArgument( + "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU.")); + } + + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_)); + OP_REQUIRES(context, dilation_.size() == 5, + errors::InvalidArgument("Dilation rates field must " + "specify 5 dimensions")); + OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, 'C') == 1 && + GetTensorDim(dilation_, data_format_, 'N') == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilation rates in the batch and depth dimensions.")); + + // TODO(yangzihao): Add CPU version of dilated conv 3D. + OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, '0') == 1 && + GetTensorDim(dilation_, data_format_, '1') == 1 && + GetTensorDim(dilation_, data_format_, '2') == 1), + errors::InvalidArgument( + "Current CPU implementation does not yet support " + "dilation rates larger than 1.")); + + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES( + context, + (GetTensorDim(stride_, data_format_, 'C') == 1 && + GetTensorDim(stride_, data_format_, 'N') == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const TensorShape& input_shape = input.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + 
+ TensorShape filter_shape; + if (takes_shape_) { + const Tensor& filter_sizes = context->input(1); + OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( + filter_sizes.vec<int32>(), &filter_shape)); + } else { + filter_shape = context->input(1).shape(); + } + + ConvBackpropDimensions dims; OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::v(), - padded_out_shape, &padded_output)); - Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1], - strides[2], 1}; - functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), out_backprop.tensor<T, 5>(), - eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>()); - const Tensor& padded_output_cref = padded_output; - - // For the backprop of the filter, we need to transpose the input. - // The shape of input is - // [batch, in_z, in_y, in_x, in_depth] - // And we need to change it to - // [in_z, in_y, in_x, batch, in_depth] - Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4}; - TensorShape in_shuffle_shape( - {input_size[0], input_size[1], input_size[2], batch, in_depth}); - Tensor in_shuffle; + ConvBackpropComputeDimensions( + "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, stride_, + padding_, data_format_, &dims)); + + Tensor* filter_backprop; OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::v(), - in_shuffle_shape, &in_shuffle)); - // No need for reversing this time. 
- Eigen::array<bool, 5> no_reverse{false, false, false, false, false}; - functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), input.tensor<T, 5>(), in_order, - no_reverse, in_shuffle.tensor<T, 5>()); - const Tensor& in_shuffle_cref = in_shuffle; - - // The output of the conv_3d would be - // [out_depth, filter_size[2], filter_size[1], filter_size[0], in_depth] - // and we need to shuffle it back to - // [filter_size[2], filter_size[1], filter_size[0], in_depth, out_depth]; - // And we need to reverse the filter backprops. - // So we need to allocate (sigh) yet another piece of memory to hold the - // output. - TensorShape filter_shuffle_shape( - {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth}); - Tensor filter_shuffle; - OP_REQUIRES_OK( - context, context->allocate_temp(DataTypeToEnum<T>::v(), - filter_shuffle_shape, &filter_shuffle)); - functor::CuboidConvolution<Device, T>()( - context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(), - padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1, - 1, BrainPadding2EigenPadding(VALID)); - - // Now copy the filter_backprop back to the destination. 
- Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0}; - Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false}; - const Tensor& filter_shuffle_cref = filter_shuffle; - functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(), - filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>()); + context->allocate_output(0, filter_shape, &filter_backprop)); + + if (input_shape.num_elements() == 0) { + filter_backprop->template flat<T>().setZero(); + return; + } + + int64 top_pad_planes, bottom_pad_planes; + int64 top_pad_rows, bottom_pad_rows; + int64 left_pad_cols, right_pad_cols; + + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, + dims.spatial_dims[0].filter_size, + dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, + &top_pad_planes, &bottom_pad_planes)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, + dims.spatial_dims[1].filter_size, + dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, + &top_pad_rows, &bottom_pad_rows)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[2].input_size, + dims.spatial_dims[2].filter_size, + dims.spatial_dims[2].stride, padding_, + &dims.spatial_dims[2].output_size, + &left_pad_cols, &right_pad_cols)); + + // TODO(ezhulenev): Extract work size and shard estimation to shared + // functions in conv_grad_ops, and update 2d convolution backprop. + + // The total dimension size of each kernel. + const int64 filter_total_size = + dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size * + dims.spatial_dims[2].filter_size * dims.in_depth; + // The output image size is the spatial size of the output. 
+ const int64 output_image_size = dims.spatial_dims[0].output_size * + dims.spatial_dims[1].output_size * + dims.spatial_dims[2].output_size; + + // Shard 'batch' images (volumes) into 'shard_size' groups of images + // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by + // dividing the L3 cache capacity in elements of T + // ('target_working_set_size') by the matmul size of an individual image + // ('work_unit_size'), rounding up. + + const auto cache_sizes = Eigen::internal::CacheSizes(); + const ptrdiff_t l3_cache_size = cache_sizes.m_l3; + + // TODO(andydavis) + // *) Consider reducing 'target_working_set_size' if L3 is shared by + // other concurrently running tensorflow ops. + const size_t target_working_set_size = l3_cache_size / sizeof(T); + + const int64 size_A = output_image_size * filter_total_size; + + const int64 size_B = output_image_size * dims.out_depth; + + const int64 size_C = filter_total_size * dims.out_depth; + + const int64 work_unit_size = size_A + size_B + size_C; + + // NOTE(review): work_unit_size is the divisor below; it is zero when at + // least two of output_image_size, filter_total_size and dims.out_depth are + // zero. Confirm that earlier validation rules this out. + const size_t shard_size = + (target_working_set_size + work_unit_size - 1) / work_unit_size; + + // Total number of elements in all the tensors used by this kernel. + int64 total_tensor_elements = input_shape.num_elements() + + filter_shape.num_elements() + + out_backprop_shape.num_elements(); + + // Shape of the temporary workspace buffer. + TensorShape col_buffer_shape = {static_cast<int64>(shard_size), + static_cast<int64>(output_image_size), + static_cast<int64>(filter_total_size)}; + int64 col_buffer_elements = col_buffer_shape.num_elements(); + + // If the temporary allocation overhead is too large, fall back on the Eigen + // implementation, which requires much less memory. 
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements; + if (col_buffer_overhead > kMaxTempAllocationOverhead) { + VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: " + "col_buffer_overhead=" + << col_buffer_overhead; + + functor::CuboidConvolutionBackwardFilter<Device, T>()( + context->eigen_device<Device>(), + filter_backprop->tensor<T, 5>(), // filter_backward + input.tensor<T, 5>(), // input + out_backprop.tensor<T, 5>(), // output_backward + static_cast<int>(dims.spatial_dims[0].stride), // stride_planes + static_cast<int>(dims.spatial_dims[1].stride), // stride_rows + static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols + + return; + } + + Tensor col_buffer; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum<T>::value, + col_buffer_shape, &col_buffer)); + + // The input offset corresponding to a single input image. + const int64 input_offset = dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size * + dims.spatial_dims[2].input_size * dims.in_depth; + // The output offset corresponding to a single output image. + const int64 output_offset = + dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size * + dims.spatial_dims[2].output_size * dims.out_depth; + + const T* input_data = input.template flat<T>().data(); + T* col_buffer_data = col_buffer.template flat<T>().data(); + const T* out_backprop_data = out_backprop.template flat<T>().data(); + T* filter_backprop_data = filter_backprop->template flat<T>().data(); + + typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, + Eigen::Unaligned> + TensorMap; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, + Eigen::Unaligned> + ConstTensorMap; + + TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth); + C.setZero(); + + // Initialize contraction dims (we need to transpose 'A' below). 
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims; + contract_dims[0].first = 0; + contract_dims[0].second = 0; + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) { + const int shard_limit = + std::min(static_cast<int>(shard_size), + static_cast<int>(dims.batch_size) - image_id); + + auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes, + &top_pad_rows, &left_pad_cols, &bottom_pad_planes, + &bottom_pad_rows, &right_pad_cols, &input_offset, + &size_A](int64 start, int64 limit) { + for (int shard_id = start; shard_id < limit; ++shard_id) { + const T* input_data_shard = input_data + shard_id * input_offset; + T* col_data_shard = col_buffer_data + shard_id * size_A; + + // When we compute the gradient with respect to the filters, we need + // to do im2col to allow gemm-type computation. + Im2col<T>(input_data_shard, dims.in_depth, + // Input spatial dimensions. + dims.spatial_dims[0].input_size, // input planes + dims.spatial_dims[1].input_size, // input rows + dims.spatial_dims[2].input_size, // input cols + // Filter spatial dimensions. + dims.spatial_dims[0].filter_size, // filter planes + dims.spatial_dims[1].filter_size, // filter rows + dims.spatial_dims[2].filter_size, // filter cols + // Spatial padding. + top_pad_planes, top_pad_rows, left_pad_cols, + bottom_pad_planes, bottom_pad_rows, right_pad_cols, + // Spatial striding. + dims.spatial_dims[0].stride, // stride planes + dims.spatial_dims[1].stride, // stride rows + dims.spatial_dims[2].stride, // stride cols + col_data_shard); + } + }; + Shard(worker_threads.num_threads, worker_threads.workers, shard_limit, + size_A, shard); + + ConstTensorMap A(col_buffer_data, output_image_size * shard_limit, + filter_total_size); + ConstTensorMap B(out_backprop_data, output_image_size * shard_limit, + dims.out_depth); + + // Gradient with respect to filter. 
+ C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims); + + input_data += input_offset * shard_limit; + out_backprop_data += output_offset * shard_limit; + } } private: @@ -421,21 +991,60 @@ class Conv3DBackpropFilterOp : public OpKernel { Padding padding_; TensorFormat data_format_; bool takes_shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp); }; +// Custom backprop input kernel is 30% - 4x faster when compiled with AVX2 than +// default Eigen implementation (at the cost of ~2x-8x peak memory usage). +// NOTE(review): the registrations below are for the filter kernel, yet this +// comment says "input"; confirm the figures apply to the filter kernel. + #define REGISTER_CPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - Conv3DBackpropFilterOp<CPUDevice, T>); \ + Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \ + .Device(DEVICE_CPU) \ + .Label("custom") \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ + .Device(DEVICE_CPU) \ + .Label("custom") \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \ + .Device(DEVICE_CPU) \ + .Label("eigen_tensor") \ + .TypeConstraint<T>("T"), \ + Conv3DBackpropFilterOp<CPUDevice, T>); \ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ .Device(DEVICE_CPU) \ + .Label("eigen_tensor") \ .TypeConstraint<T>("T"), \ Conv3DBackpropFilterOp<CPUDevice, T>); -TF_CALL_half(REGISTER_CPU_KERNEL); + TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL +// WARNING: Eigen::half is not trivially copyable and can't be used in +// custom backprop filter kernel because of memcpy and memset in Im2col. 
+#define REGISTER_CPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + Conv3DBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ + Conv3DBackpropFilterOp<CPUDevice, T>); + +TF_CALL_half(REGISTER_CPU_KERNEL); +#undef REGISTER_CPU_KERNEL + // GPU definitions of both ops. #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. @@ -445,7 +1054,8 @@ namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ void TransformFilter<GPUDevice, T, int, 5>::operator()( \ - const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \ + const GPUDevice& d, FilterTensorFormat dst_filter_format, \ + typename TTypes<T, 5, int>::ConstTensor in, \ typename TTypes<T, 5, int>::Tensor out); \ template <> \ void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \ @@ -523,6 +1133,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& filter = context->input(1); const TensorShape& filter_shape = filter.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + TensorShape input_shape; if (takes_shape_) { const Tensor& input_sizes = context->input(0); @@ -531,7 +1145,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { } else { input_shape = context->input(0).shape(); } - EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput"); + + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, + ConvBackpropComputeDimensionsV2( + "Conv3DBackpropInputOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, dilation_, + stride_, padding_, data_format_, &dims)); + Tensor* in_backprop; OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); @@ -539,13 +1160,15 @@ class Conv3DBackpropInputOp<GPUDevice, T> 
: public OpKernel { auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 && - dilation_[0] == 1 && dilation_[1] == 1 && dilation_[2] == 1 && - stride_[0] == 1 && stride_[1] == 1 && stride_[2] == 1 && + if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 && + dims.filter_size(2) == 1 && dims.dilation(0) == 1 && + dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 && + dims.stride(1) == 1 && dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) { - const uint64 m = batch * input_size[0] * input_size[1] * input_size[2]; - const uint64 k = out_depth; - const uint64 n = in_depth; + const uint64 m = dims.batch_size * dims.input_size(0) * + dims.input_size(1) * dims.input_size(2); + const uint64 k = dims.out_depth; + const uint64 n = dims.in_depth; auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), out_backprop.template flat<T>().size()); @@ -567,13 +1190,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { ", n=", n, ", k=", k)); } return; - } else if (filter_size[0] == input_size[0] && - filter_size[1] == input_size[1] && - filter_size[2] == input_size[2] && padding_ == Padding::VALID && - data_format_ == FORMAT_NHWC) { - const uint64 m = batch; - const uint64 k = out_depth; - const uint64 n = input_size[0] * input_size[1] * input_size[2] * in_depth; + } else if (dims.filter_size(0) == dims.input_size(0) && + dims.filter_size(1) == dims.input_size(1) && + dims.filter_size(2) == dims.input_size(2) && + padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) { + const uint64 m = dims.batch_size; + const uint64 k = dims.out_depth; + const uint64 n = dims.input_size(0) * dims.input_size(1) * + dims.input_size(2) * dims.in_depth; auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), out_backprop.template flat<T>().size()); @@ -597,65 +1221,59 @@ class 
Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { return; } - int padding_rows = 0, padding_cols = 0, padding_planes = 0; - - if (padding_ == Padding::SAME) { - padding_planes = std::max<int>( - 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]); - padding_cols = std::max<int>( - 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]); - padding_rows = std::max<int>( - 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]); - } + int padding_planes = dims.SpatialPadding(padding_, 0); + int padding_rows = dims.SpatialPadding(padding_, 1); + int padding_cols = dims.SpatialPadding(padding_, 2); + const bool planes_odd = (padding_planes % 2 != 0); const bool rows_odd = (padding_rows % 2 != 0); const bool cols_odd = (padding_cols % 2 != 0); - const bool planes_odd = (padding_planes % 2 != 0); TensorShape compatible_input_shape; if (rows_odd || cols_odd || planes_odd) { // cuDNN only supports the same amount of padding on both sides. compatible_input_shape = { - batch, - in_depth, - input_size[0] + planes_odd, - input_size[1] + rows_odd, - input_size[2] + cols_odd, + dims.batch_size, + dims.in_depth, + dims.input_size(0) + planes_odd, + dims.input_size(1) + rows_odd, + dims.input_size(2) + cols_odd, }; } else { - compatible_input_shape = {batch, in_depth, input_size[0], input_size[1], - input_size[2]}; + compatible_input_shape = {dims.batch_size, dims.in_depth, + dims.input_size(0), dims.input_size(1), + dims.input_size(2)}; } CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0) << "Negative paddings: (" << padding_rows << ", " << padding_cols << ", " << padding_planes << ")"; se::dnn::BatchDescriptor input_desc(3); - input_desc.set_count(batch) + input_desc.set_count(dims.batch_size) .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4)) .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3)) .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2)) - 
.set_feature_map_count(in_depth) + .set_feature_map_count(dims.in_depth) .set_layout(se::dnn::DataLayout::kBatchDepthYX); se::dnn::BatchDescriptor output_desc(3); - output_desc.set_count(batch) - .set_spatial_dim(DimIndex::X, output_cols) - .set_spatial_dim(DimIndex::Y, output_rows) - .set_spatial_dim(DimIndex::Z, output_planes) - .set_feature_map_count(out_depth) + output_desc.set_count(dims.batch_size) + .set_spatial_dim(DimIndex::X, dims.output_size(2)) + .set_spatial_dim(DimIndex::Y, dims.output_size(1)) + .set_spatial_dim(DimIndex::Z, dims.output_size(0)) + .set_feature_map_count(dims.out_depth) .set_layout(se::dnn::DataLayout::kBatchDepthYX); se::dnn::FilterDescriptor filter_desc(3); - filter_desc.set_spatial_dim(DimIndex::X, filter_size[2]) - .set_spatial_dim(DimIndex::Y, filter_size[1]) - .set_spatial_dim(DimIndex::Z, filter_size[0]) - .set_input_feature_map_count(in_depth) - .set_output_feature_map_count(out_depth); + filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2)) + .set_spatial_dim(DimIndex::Y, dims.filter_size(1)) + .set_spatial_dim(DimIndex::Z, dims.filter_size(0)) + .set_input_feature_map_count(dims.in_depth) + .set_output_feature_map_count(dims.out_depth); se::dnn::ConvolutionDescriptor conv_desc(3); - conv_desc.set_dilation_rate(DimIndex::X, dilations[2]) - .set_dilation_rate(DimIndex::Y, dilations[1]) - .set_dilation_rate(DimIndex::Z, dilations[0]) - .set_filter_stride(DimIndex::X, strides[2]) - .set_filter_stride(DimIndex::Y, strides[1]) - .set_filter_stride(DimIndex::Z, strides[0]) + conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2)) + .set_dilation_rate(DimIndex::Y, dims.dilation(1)) + .set_dilation_rate(DimIndex::Z, dims.dilation(0)) + .set_filter_stride(DimIndex::X, dims.stride(2)) + .set_filter_stride(DimIndex::Y, dims.stride(1)) + .set_filter_stride(DimIndex::Z, dims.stride(0)) .set_zero_padding(DimIndex::X, padding_cols / 2) .set_zero_padding(DimIndex::Y, padding_rows / 2) .set_zero_padding(DimIndex::Z, padding_planes 
/ 2); @@ -664,20 +1282,23 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { Tensor transformed_filter; OP_REQUIRES_OK( context, - context->allocate_temp(DataTypeToEnum<T>::value, - TensorShape({out_depth, in_depth, filter_size[0], - filter_size[1], filter_size[2]}), - &transformed_filter)); + context->allocate_temp( + DataTypeToEnum<T>::value, + TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0), + dims.filter_size(1), dims.filter_size(2)}), + &transformed_filter)); functor::TransformFilter<GPUDevice, T, int, 5>()( - context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()), + context->eigen_device<GPUDevice>(), FORMAT_OIHW, + To32Bit(filter.tensor<T, 5>()), To32Bit(transformed_filter.tensor<T, 5>())); // Shape: batch, filters, z, y, x. Tensor transformed_out_backprop; if (data_format_ == FORMAT_NHWC) { - TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows, - output_cols}; - if (out_depth > 1) { + TensorShape nchw_shape = {dims.batch_size, dims.out_depth, + dims.output_size(0), dims.output_size(1), + dims.output_size(2)}; + if (dims.out_depth > 1) { OP_REQUIRES_OK(context, context->allocate_temp( DataTypeToEnum<T>::value, nchw_shape, &transformed_out_backprop)); @@ -713,14 +1334,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { const int device_id = stream->parent()->device_ordinal(); DataType dtype = context->input(0).dtype(); const ConvParameters conv_parameters = { - batch, - in_depth, - {{input_size[0], input_size[1], input_size[2]}}, + dims.batch_size, + dims.in_depth, + {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}}, FORMAT_NCHW, - out_depth, - {{filter_size[0], filter_size[1], filter_size[2]}}, - {{dilations[0], dilations[1], dilations[2]}}, - {{strides[0], strides[1], strides[2]}}, + dims.out_depth, + {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}}, + {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}}, + {{dims.stride(0), dims.stride(1), 
dims.stride(2)}}, {{padding_planes, padding_rows, padding_cols}}, dtype, device_id, @@ -799,10 +1420,11 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { if (rows_odd || cols_odd || planes_odd) { Tensor in_backprop_remove_padding; OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::value, - {batch, in_depth, input_size[0], - input_size[1], input_size[2]}, - &in_backprop_remove_padding)); + context->allocate_temp( + DataTypeToEnum<T>::value, + {dims.batch_size, dims.in_depth, dims.input_size(0), + dims.input_size(1), dims.input_size(2)}, + &in_backprop_remove_padding)); // Remove the padding for odd spatial dimensions. functor::PadInput<GPUDevice, T, int, 5>()( @@ -896,6 +1518,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); const TensorShape& input_shape = input.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + TensorShape filter_shape; if (takes_shape_) { const Tensor& filter_sizes = context->input(1); @@ -905,7 +1531,12 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { filter_shape = context->input(1).shape(); } - EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter"); + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, + ConvBackpropComputeDimensionsV2( + "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, dilation_, + stride_, padding_, data_format_, &dims)); Tensor* filter_backprop; OP_REQUIRES_OK(context, @@ -914,13 +1545,15 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 && - dilations[2] == 1 && dilations[1] == 1 && dilations[0] == 1 && - strides[2] == 1 && 
strides[1] == 1 && strides[0] == 1 && + if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 && + dims.filter_size(0) == 1 && dims.dilation(2) == 1 && + dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 && + dims.stride(1) == 1 && dims.stride(0) == 1 && data_format_ == FORMAT_NHWC) { - const uint64 m = in_depth; - const uint64 k = batch * input_size[1] * input_size[2] * input_size[0]; - const uint64 n = out_depth; + const uint64 m = dims.in_depth; + const uint64 k = dims.batch_size * dims.input_size(1) * + dims.input_size(2) * dims.input_size(0); + const uint64 n = dims.out_depth; // The shape of output backprop is // [batch, out_z, out_y, out_x, out_depth] @@ -951,13 +1584,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { ", n=", n, ", k=", k)); } return; - } else if (filter_size[0] == input_size[0] && - filter_size[1] == input_size[1] && - filter_size[2] == input_size[2] && padding_ == Padding::VALID && - data_format_ == FORMAT_NHWC) { - const uint64 m = input_size[0] * input_size[1] * input_size[2] * in_depth; - const uint64 k = batch; - const uint64 n = out_depth; + } else if (dims.filter_size(0) == dims.input_size(0) && + dims.filter_size(1) == dims.input_size(1) && + dims.filter_size(2) == dims.input_size(2) && + padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) { + const uint64 m = dims.input_size(0) * dims.input_size(1) * + dims.input_size(2) * dims.in_depth; + const uint64 k = dims.batch_size; + const uint64 n = dims.out_depth; auto a_ptr = AsDeviceMemory(input.template flat<T>().data(), input.template flat<T>().size()); @@ -979,30 +1613,24 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { return; } - int padding_rows = 0, padding_cols = 0, padding_planes = 0; - - if (padding_ == Padding::SAME) { - padding_planes = std::max<int>( - 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]); - padding_cols = std::max<int>( - 0, (output_cols - 1) * strides[2] + filter_size[2] - 
input_size[2]); - padding_rows = std::max<int>( - 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]); - } - bool rows_odd = (padding_rows % 2 != 0); - bool cols_odd = (padding_cols % 2 != 0); - bool planes_odd = (padding_planes % 2 != 0); + int padding_planes = dims.SpatialPadding(padding_, 0); + int padding_rows = dims.SpatialPadding(padding_, 1); + int padding_cols = dims.SpatialPadding(padding_, 2); + const bool planes_odd = (padding_planes % 2 != 0); + const bool rows_odd = (padding_rows % 2 != 0); + const bool cols_odd = (padding_cols % 2 != 0); Tensor compatible_input; if (rows_odd || cols_odd || planes_odd) { - OP_REQUIRES_OK(context, context->allocate_temp( - DataTypeToEnum<T>::value, - ShapeFromFormat(data_format_, batch, - {{input_size[0] + planes_odd, - input_size[1] + rows_odd, - input_size[2] + cols_odd}}, - in_depth), - &compatible_input)); + OP_REQUIRES_OK(context, + context->allocate_temp( + DataTypeToEnum<T>::value, + ShapeFromFormat(data_format_, dims.batch_size, + {{dims.input_size(0) + planes_odd, + dims.input_size(1) + rows_odd, + dims.input_size(2) + cols_odd}}, + dims.in_depth), + &compatible_input)); functor::PadInput<GPUDevice, T, int, 5>()( context->template eigen_device<GPUDevice>(), To32Bit(input.tensor<T, 5>()), {{0, 0, 0}}, @@ -1016,35 +1644,35 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { << "Negative paddings: (" << padding_rows << ", " << padding_cols << ", " << padding_planes << ")"; se::dnn::BatchDescriptor input_desc(3); - input_desc.set_count(batch) + input_desc.set_count(dims.batch_size) .set_spatial_dim(DimIndex::X, GetTensorDim(compatible_input, data_format_, '2')) .set_spatial_dim(DimIndex::Y, GetTensorDim(compatible_input, data_format_, '1')) .set_spatial_dim(DimIndex::Z, GetTensorDim(compatible_input, data_format_, '0')) - .set_feature_map_count(in_depth) + .set_feature_map_count(dims.in_depth) .set_layout(se::dnn::DataLayout::kBatchDepthYX); se::dnn::BatchDescriptor output_desc(3); - 
output_desc.set_count(batch) - .set_spatial_dim(DimIndex::X, output_cols) - .set_spatial_dim(DimIndex::Y, output_rows) - .set_spatial_dim(DimIndex::Z, output_planes) - .set_feature_map_count(out_depth) + output_desc.set_count(dims.batch_size) + .set_spatial_dim(DimIndex::X, dims.output_size(2)) + .set_spatial_dim(DimIndex::Y, dims.output_size(1)) + .set_spatial_dim(DimIndex::Z, dims.output_size(0)) + .set_feature_map_count(dims.out_depth) .set_layout(se::dnn::DataLayout::kBatchDepthYX); se::dnn::FilterDescriptor filter_desc(3); - filter_desc.set_spatial_dim(DimIndex::X, filter_size[2]) - .set_spatial_dim(DimIndex::Y, filter_size[1]) - .set_spatial_dim(DimIndex::Z, filter_size[0]) - .set_input_feature_map_count(in_depth) - .set_output_feature_map_count(out_depth); + filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2)) + .set_spatial_dim(DimIndex::Y, dims.filter_size(1)) + .set_spatial_dim(DimIndex::Z, dims.filter_size(0)) + .set_input_feature_map_count(dims.in_depth) + .set_output_feature_map_count(dims.out_depth); se::dnn::ConvolutionDescriptor conv_desc(3); - conv_desc.set_dilation_rate(DimIndex::X, dilations[2]) - .set_dilation_rate(DimIndex::Y, dilations[1]) - .set_dilation_rate(DimIndex::Z, dilations[0]) - .set_filter_stride(DimIndex::X, strides[2]) - .set_filter_stride(DimIndex::Y, strides[1]) - .set_filter_stride(DimIndex::Z, strides[0]) + conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2)) + .set_dilation_rate(DimIndex::Y, dims.dilation(1)) + .set_dilation_rate(DimIndex::Z, dims.dilation(0)) + .set_filter_stride(DimIndex::X, dims.stride(2)) + .set_filter_stride(DimIndex::Y, dims.stride(1)) + .set_filter_stride(DimIndex::Z, dims.stride(0)) .set_zero_padding(DimIndex::X, padding_cols / 2) .set_zero_padding(DimIndex::Y, padding_rows / 2) .set_zero_padding(DimIndex::Z, padding_planes / 2); @@ -1052,19 +1680,21 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { Tensor pre_transformed_filter_backprop; OP_REQUIRES_OK( context, - 
context->allocate_temp(DataTypeToEnum<T>::value, - TensorShape({out_depth, in_depth, filter_size[0], - filter_size[1], filter_size[2]}), - &pre_transformed_filter_backprop)); + context->allocate_temp( + DataTypeToEnum<T>::value, + TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0), + dims.filter_size(1), dims.filter_size(2)}), + &pre_transformed_filter_backprop)); Tensor transformed_out_backprop; if (data_format_ == FORMAT_NHWC) { - TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows, - output_cols}; + TensorShape nchw_shape = {dims.batch_size, dims.out_depth, + dims.output_size(0), dims.output_size(1), + dims.output_size(2)}; OP_REQUIRES_OK( context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape, &transformed_out_backprop)); - if (out_depth > 1) { + if (dims.out_depth > 1) { functor::NHWCToNCHW<GPUDevice, T, 5>()( context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(), transformed_out_backprop.tensor<T, 5>()); @@ -1076,10 +1706,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { } Tensor transformed_input; if (data_format_ == FORMAT_NHWC) { - TensorShape nchw_shape = {batch, in_depth, compatible_input.dim_size(1), - compatible_input.dim_size(2), - compatible_input.dim_size(3)}; - if (in_depth > 1) { + TensorShape nchw_shape = { + dims.batch_size, dims.in_depth, compatible_input.dim_size(1), + compatible_input.dim_size(2), compatible_input.dim_size(3)}; + if (dims.in_depth > 1) { OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape, &transformed_input)); @@ -1110,14 +1740,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { const int device_id = stream->parent()->device_ordinal(); DataType dtype = input.dtype(); const ConvParameters conv_parameters = { - batch, - in_depth, - {{input_size[0], input_size[1], input_size[2]}}, + dims.batch_size, + dims.in_depth, + {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}}, FORMAT_NCHW, - out_depth, 
- {{filter_size[0], filter_size[1], filter_size[2]}}, - {{dilations[0], dilations[1], dilations[2]}}, - {{strides[0], strides[1], strides[2]}}, + dims.out_depth, + {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}}, + {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}}, + {{dims.stride(0), dims.stride(1), dims.stride(2)}}, {{padding_planes, padding_rows, padding_cols}}, dtype, device_id, diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index ef692418d6..78856c4a99 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -264,150 +264,198 @@ class LaunchXsmmConvOp<CPUDevice, float> { }; #endif +#define TF_REQUIRES(EXP, STATUS) \ + do { \ + if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \ + } while (false) + +Status InitConv2DParameters(const OpKernelConstruction* context, + Conv2DParameters* params) { + TF_RETURN_IF_ERROR(context->GetAttr("dilations", &params->dilations)); + TF_RETURN_IF_ERROR(context->GetAttr("strides", &params->strides)); + TF_RETURN_IF_ERROR(context->GetAttr("padding", &params->padding)); + string data_format_string; + TF_RETURN_IF_ERROR(context->GetAttr("data_format", &data_format_string)); + TF_REQUIRES(FormatFromString(data_format_string, &params->data_format), + errors::InvalidArgument("Invalid data format")); + + const auto& strides = params->strides; + const auto& dilations = params->dilations; + const auto& data_format = params->data_format; + + TF_REQUIRES(dilations.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + TF_REQUIRES(strides.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + const int64 stride_n = GetTensorDim(strides, data_format, 'N'); + const int64 stride_c = GetTensorDim(strides, data_format, 'C'); + const int64 stride_h = GetTensorDim(strides, data_format, 'H'); + const int64 stride_w = GetTensorDim(strides, data_format, 'W'); + 
TF_REQUIRES( + stride_n == 1 && stride_c == 1, + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + TF_REQUIRES(stride_h > 0 && stride_w > 0, + errors::InvalidArgument( + "Row and column strides should be larger than 0.")); + + const int64 dilation_n = GetTensorDim(dilations, data_format, 'N'); + const int64 dilation_c = GetTensorDim(dilations, data_format, 'C'); + const int64 dilation_h = GetTensorDim(dilations, data_format, 'H'); + const int64 dilation_w = GetTensorDim(dilations, data_format, 'W'); + TF_REQUIRES( + dilation_n == 1 && dilation_c == 1, + errors::InvalidArgument("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + TF_REQUIRES( + dilation_h > 0 && dilation_w > 0, + errors::InvalidArgument("Dilated rates should be larger than 0.")); + + return Status::OK(); +} + +Status ComputeConv2DDimension(const Conv2DParameters& params, + const Tensor& input, const Tensor& filter, + Conv2DDimensions* dimensions) { + // Check that 2D convolution input and filter have exactly 4 dimensions. + TF_REQUIRES(input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().DebugString())); + TF_REQUIRES(filter.dims() == 4, + errors::InvalidArgument("filter must be 4-dimensional: ", + filter.shape().DebugString())); + for (int i = 0; i < 3; i++) { + TF_REQUIRES( + FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()), + errors::InvalidArgument("filter too large")); + } + + // The last dimension for input is in_depth. Check that it is the same as the + // filter's in_depth or it is evenly divisible by filter's in_depth. 
+ const int64 in_depth_raw = GetTensorDim(input, params.data_format, 'C'); + const int64 patch_depth_raw = filter.dim_size(2); + TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("Input depth too large")); + TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("Patch depth too large")); + const int in_depth = static_cast<int>(in_depth_raw); + const int patch_depth = static_cast<int>(patch_depth_raw); + TF_REQUIRES(in_depth % patch_depth == 0, + errors::InvalidArgument( + "input depth must be evenly divisible by filter depth: ", + in_depth, " vs ", patch_depth)); + + // The last dimension for filter is out_depth. + const int out_depth = static_cast<int>(filter.dim_size(3)); + + // The second dimension for input is rows/height. + // The first dimension for filter is rows/height. + const int64 input_rows_raw = GetTensorDim(input, params.data_format, 'H'); + TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("Input rows too large")); + const int input_rows = static_cast<int>(input_rows_raw); + const int filter_rows = static_cast<int>(filter.dim_size(0)); + + // The third dimension for input is columns/width. + // The second dimension for filter is columns/width. + const int64 input_cols_raw = GetTensorDim(input, params.data_format, 'W'); + TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("Input cols too large")); + const int input_cols = static_cast<int>(input_cols_raw); + const int filter_cols = static_cast<int>(filter.dim_size(1)); + + // The first dimension for input is batch. 
+ const int64 batch_raw = GetTensorDim(input, params.data_format, 'N'); + TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("batch is too large")); + const int batch = static_cast<int>(batch_raw); + + // Take the stride and dilation from the second and third dimensions only (we + // do not support striding or dilation on the batch or depth dimension). + const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H'); + const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W'); + const int dilation_rows = + GetTensorDim(params.dilations, params.data_format, 'H'); + const int dilation_cols = + GetTensorDim(params.dilations, params.data_format, 'W'); + + // Compute windowed output sizes for rows and columns. + int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2( + input_rows, filter_rows, dilation_rows, stride_rows, params.padding, + &out_rows, &pad_rows)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2( + input_cols, filter_cols, dilation_cols, stride_cols, params.padding, + &out_cols, &pad_cols)); + + dimensions->batch = batch; + dimensions->input_rows = input_rows; + dimensions->input_cols = input_cols; + dimensions->in_depth = in_depth; + dimensions->filter_rows = filter_rows; + dimensions->filter_cols = filter_cols; + dimensions->patch_depth = patch_depth; + dimensions->out_depth = out_depth; + dimensions->stride_rows = stride_rows; + dimensions->stride_cols = stride_cols; + dimensions->dilation_rows = dilation_rows; + dimensions->dilation_cols = dilation_cols; + dimensions->out_rows = out_rows; + dimensions->out_cols = out_cols; + dimensions->pad_rows = pad_rows; + dimensions->pad_cols = pad_cols; + + return Status::OK(); +} + +#undef TF_REQUIRES + template <typename Device, typename T> class Conv2DOp : public BinaryOp<T> { public: explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) { - 
OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); - string data_format; - OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); - OP_REQUIRES(context, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_)); + OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); use_cudnn_ &= CanUseCudnn(); cudnn_use_autotune_ = CudnnUseAutotune(); - OP_REQUIRES(context, dilations_.size() == 4, - errors::InvalidArgument("Sliding window dilations field must " - "specify 4 dimensions")); - OP_REQUIRES(context, strides_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); - const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); - const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); - const int64 stride_h = GetTensorDim(strides_, data_format_, 'H'); - const int64 stride_w = GetTensorDim(strides_, data_format_, 'W'); - OP_REQUIRES( - context, stride_n == 1 && stride_c == 1, - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - OP_REQUIRES(context, stride_h > 0 && stride_w > 0, - errors::InvalidArgument( - "Row and column strides should be larger than 0.")); - - const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N'); - const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C'); - const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H'); - const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W'); - OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1, - errors::InvalidArgument( - "Current implementation does not yet support " - "dilations in the batch and depth dimensions.")); - OP_REQUIRES( - context, dilation_h > 0 && dilation_w > 0, - 
errors::InvalidArgument("Dilated rates should be larger than 0.")); - OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); } void Compute(OpKernelContext* context) override { // Input tensor is of the following dimensions: // [ batch, in_rows, in_cols, in_depth ] - const Tensor& input = context->input(0); // Input filter is of the following dimensions: // [ filter_rows, filter_cols, in_depth, out_depth] const Tensor& filter = context->input(1); - // For 2D convolution, there should be 4 dimensions. - OP_REQUIRES(context, input.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input.shape().DebugString())); - OP_REQUIRES(context, filter.dims() == 4, - errors::InvalidArgument("filter must be 4-dimensional: ", - filter.shape().DebugString())); - - for (int i = 0; i < 3; i++) { - OP_REQUIRES( - context, - FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()), - errors::InvalidArgument("filter too large")); - } + Conv2DDimensions dimensions; + OP_REQUIRES_OK(context, + ComputeConv2DDimension(params_, input, filter, &dimensions)); - // The last dimension for input is in_depth. It must be the same as the - // filter's in_depth or be evenly divisible by filter's in_depth. - const int64 in_depth = GetTensorDim(input, data_format_, 'C'); - const int64 patch_depth = filter.dim_size(2); - OP_REQUIRES(context, in_depth % patch_depth == 0, - errors::InvalidArgument( - "input depth must be evenly divisible by filter depth: ", - in_depth, " vs ", patch_depth)); - - // The last dimension for filter is out_depth. - const int out_depth = static_cast<int>(filter.dim_size(3)); - - // The second dimension for input is rows/height. - // The first dimension for filter is rows/height. 
- const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H'); - OP_REQUIRES( - context, - FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()), - errors::InvalidArgument("Input rows too large")); - const int input_rows = static_cast<int>(input_rows_raw); - const int filter_rows = static_cast<int>(filter.dim_size(0)); - - // The third dimension for input is columns/width. - // The second dimension for filter is columns/width. - const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W'); - OP_REQUIRES( - context, - FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()), - errors::InvalidArgument("Input cols too large")); - const int input_cols = static_cast<int>(input_cols_raw); - const int filter_cols = static_cast<int>(filter.dim_size(1)); - - // The first dimension for input is batch. - const int64 batch_raw = GetTensorDim(input, data_format_, 'N'); - OP_REQUIRES(context, - FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()), - errors::InvalidArgument("batch is too large")); - const int batch = static_cast<int>(batch_raw); - - // For now we take the stride and dilation from the second and third - // dimensions only (we do not support striding or dilation on the batch or - // depth dimension). 
- const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); - const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); - - const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H'); - const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W'); - - int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; - OP_REQUIRES_OK(context, GetWindowedOutputSizeV2( - input_rows, filter_rows, dilation_rows, - stride_rows, padding_, &out_rows, &pad_rows)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeV2( - input_cols, filter_cols, dilation_cols, - stride_cols, padding_, &out_cols, &pad_cols)); - TensorShape out_shape = - ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth); + TensorShape out_shape = ShapeFromFormat( + params_.data_format, dimensions.batch, dimensions.out_rows, + dimensions.out_cols, dimensions.out_depth); // Output tensor is of the following dimensions: // [ in_batch, out_rows, out_cols, out_depth ] Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); - VLOG(2) << "Conv2D: in_depth = " << in_depth - << ", patch_depth = " << patch_depth - << ", input_cols = " << input_cols - << ", filter_cols = " << filter_cols - << ", input_rows = " << input_rows - << ", filter_rows = " << filter_rows - << ", stride_rows = " << stride_rows - << ", stride_cols = " << stride_cols - << ", dilation_rows = " << dilation_rows - << ", dilation_cols = " << dilation_cols - << ", out_depth = " << out_depth; + VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth + << ", patch_depth = " << dimensions.patch_depth + << ", input_cols = " << dimensions.input_cols + << ", filter_cols = " << dimensions.filter_cols + << ", input_rows = " << dimensions.input_rows + << ", filter_rows = " << dimensions.filter_rows + << ", stride_rows = " << dimensions.stride_rows + << ", stride_cols = " << dimensions.stride_cols + << ", dilation_rows = " << dimensions.dilation_rows + << ", dilation_cols = 
" << dimensions.dilation_cols + << ", out_depth = " << dimensions.out_depth; // If there is nothing to compute, return. if (out_shape.num_elements() == 0) { @@ -416,36 +464,41 @@ class Conv2DOp : public BinaryOp<T> { #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS if (LaunchXsmmConvOp<Device, T>::Run( - context, input, filter, batch, input_rows, input_cols, in_depth, - filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols, - out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols, - output, data_format_)) { + context, input, filter, dimensions.batch, dimensions.input_rows, + dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows, + dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols, + dimensions.out_rows, dimensions.out_cols, dimensions.out_depth, + dimensions.dilation_rows, dimensions.dilation_cols, + dimensions.stride_rows, dimensions.stride_cols, output, + params_.data_format)) { return; } #endif if (LaunchDeepConvOp<Device, T>::Run( - context, input, filter, batch, input_rows, input_cols, in_depth, - filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols, - out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols, - output, data_format_)) { + context, input, filter, dimensions.batch, dimensions.input_rows, + dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows, + dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols, + dimensions.out_rows, dimensions.out_cols, dimensions.out_depth, + dimensions.dilation_rows, dimensions.dilation_cols, + dimensions.stride_rows, dimensions.stride_cols, output, + params_.data_format)) { return; } launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter, - dilation_rows, dilation_cols, stride_rows, stride_cols, padding_, - output, data_format_); + dimensions.dilation_rows, dimensions.dilation_cols, + dimensions.stride_rows, dimensions.stride_cols, params_.padding, + output, params_.data_format); } private: - std::vector<int32> dilations_; 
- std::vector<int32> strides_; + Conv2DParameters params_; bool use_cudnn_; - Padding padding_; - TensorFormat data_format_; - LaunchConv2DOp<Device, T> launcher_; bool cudnn_use_autotune_; + LaunchConv2DOp<Device, T> launcher_; + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp); }; @@ -680,9 +733,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()( TensorShape({filter.dim_size(3), filter.dim_size(2), filter.dim_size(0), filter.dim_size(1)}), &transformed_filter)); - functor::TransformFilter<GPUDevice, T, int, 4>()( - ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()), + ctx->eigen_device<GPUDevice>(), FORMAT_OIHW, + To32Bit(filter.tensor<T, 4>()), To32Bit(transformed_filter.tensor<T, 4>())); Tensor transformed_output; @@ -731,9 +784,15 @@ void LaunchConv2DOp<GPUDevice, T>::operator()( if (cudnn_use_autotune && !AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) { std::vector<AlgorithmDesc> algorithms; - CHECK(stream->parent()->GetConvolveAlgorithms( - conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()), - &algorithms)); + OP_REQUIRES( + ctx, + stream->parent()->GetConvolveAlgorithms( + conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>( + stream->parent()), + &algorithms), + errors::Unknown("Failed to get convolution algorithm. 
This is probably " + "because cuDNN failed to initialize, so try looking to " + "see if a warning log message was printed above.")); ProfileResult best_result; ProfileResult best_result_no_scratch; for (auto profile_algorithm : algorithms) { @@ -823,7 +882,8 @@ namespace functor { extern template struct MatMulConvFunctor<GPUDevice, T>; \ template <> \ void TransformFilter<GPUDevice, T, int, 4>::operator()( \ - const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \ + const GPUDevice& d, FilterTensorFormat dst_filter_format, \ + typename TTypes<T, 4, int>::ConstTensor in, \ typename TTypes<T, 4, int>::Tensor out); \ extern template struct TransformFilter<GPUDevice, T, int, 4>; \ template <> \ diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h index adf4601b43..7ec878e0b2 100644 --- a/tensorflow/core/kernels/conv_ops.h +++ b/tensorflow/core/kernels/conv_ops.h @@ -66,6 +66,50 @@ struct Im2ColBufferResource : public ResourceBase { string DebugString() { return "Im2ColBufferResource"; } }; +// Convolution parameters specified by Op attributes. +struct Conv2DParameters { + std::vector<int32> dilations; + std::vector<int32> strides; + Padding padding; + TensorFormat data_format; +}; + +// Convolution dimensions inferred from parameters, input and filter tensors. +struct Conv2DDimensions { + int batch; + int input_rows; + int input_cols; + int in_depth; + + int filter_rows; + int filter_cols; + int patch_depth; + int out_depth; + + int stride_rows; + int stride_cols; + + int dilation_rows; + int dilation_cols; + + int64 out_rows; + int64 out_cols; + int64 pad_rows; + int64 pad_cols; +}; + +// Initializes and validates Conv2D parameters configured by OpKernel +// attributes. +Status InitConv2DParameters(const OpKernelConstruction* context, + Conv2DParameters* params); + +// Computes and validates convolutions dimensions from Conv2D parameters. 
If +// parameters are valid, dimensions will be updated with derived convolution +// dimensions, otherwise error will be returned. +Status ComputeConv2DDimension(const Conv2DParameters& params, + const Tensor& input, const Tensor& filter, + Conv2DDimensions* dimensions); + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_H_ diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc index a1eed4e68c..83df4dce38 100644 --- a/tensorflow/core/kernels/conv_ops_3d.cc +++ b/tensorflow/core/kernels/conv_ops_3d.cc @@ -386,7 +386,8 @@ struct LaunchConvOp<GPUDevice, T> { // filter: [x, y, z, in, out] // t_filter: [out, in, x, y, z] functor::TransformFilter<GPUDevice, T, int, 5>()( - ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()), + ctx->eigen_device<GPUDevice>(), FORMAT_OIHW, + To32Bit(filter.tensor<T, 5>()), To32Bit(transformed_filter.tensor<T, 5>())); Tensor transformed_output; @@ -434,10 +435,16 @@ struct LaunchConvOp<GPUDevice, T> { if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find( conv_parameters, &algorithm_config)) { std::vector<AlgorithmDesc> algorithms; - CHECK(stream->parent()->GetConvolveAlgorithms( - conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>( - stream->parent()), - &algorithms)); + OP_REQUIRES(ctx, + stream->parent()->GetConvolveAlgorithms( + conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>( + stream->parent()), + &algorithms), + errors::Unknown( + "Failed to get convolution algorithm. 
This is probably " + "because cuDNN failed to initialize, so try looking to " + "see if a warning log message was printed above.")); + ProfileResult best_result; ProfileResult best_result_no_scratch; for (auto profile_algorithm : algorithms) { @@ -514,7 +521,8 @@ namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ void TransformFilter<GPUDevice, T, int, 5>::operator()( \ - const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \ + const GPUDevice& d, FilterTensorFormat dst_filter_format, \ + typename TTypes<T, 5, int>::ConstTensor in, \ typename TTypes<T, 5, int>::Tensor out); \ template <> \ void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \ diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index afc611f277..21d135decd 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -142,8 +142,12 @@ class ConvParameters { template <typename T> bool ShouldIncludeWinogradNonfusedAlgo( se::StreamExecutor* stream_exec) const { + auto* dnn_support = stream_exec->AsDnn(); + if (!dnn_support) { + return false; + } // Skip this check for cuDNN 7 and newer. - auto version = stream_exec->AsDnn()->GetVersion(); + auto version = dnn_support->GetVersion(); if (version.ok() && version.ValueOrDie().major_version() >= 7) { return true; } diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index a5fa48f85e..46167db3a2 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -170,51 +170,33 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex( return tensor_index; } -// A Cuda custom kernel that swaps dimension-0 and dimension-2 of a 3D tensor. 
-template <typename T, bool conjugate = false> -__global__ void SwapDimension0And2InTensor3Simple(int nthreads, const T* input, - Dimension<3> input_dims, - T* output) { - Dimension<3> output_dims; - output_dims[0] = input_dims[2]; - output_dims[1] = input_dims[1]; - output_dims[2] = input_dims[0]; - - CUDA_1D_KERNEL_LOOP(index, nthreads) { - int output_index = index; - - Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims); - - Index<3> input_tensor_index; - input_tensor_index[0] = output_tensor_index[2]; - input_tensor_index[1] = output_tensor_index[1]; - input_tensor_index[2] = output_tensor_index[0]; - - int input_index = TensorIndexToFlat(input_tensor_index, input_dims); - - output[output_index] = - maybe_conj<T, conjugate>::run(ldg(input + input_index)); - } -} - -// A Cuda custom kernel that swaps dimension-1 and dimension-2 of a 3D tensor. -template <typename T, bool conjugate = false> -__global__ void SwapDimension1And2InTensor3Simple(int nthreads, const T* input, - Dimension<3> input_dims, - T* output) { +// A simple CUDA custom kernel to shuffle dimensions of a 3D tensor according to +// the given shuffle permutation in template parameters. Shuffle permutation +// <sp0, sp1, sp2> shuffles dimensions such that input dimension 0 goes to sp0, +// 1 goes to sp1 and 2 goes to sp2. For example, shuffle permutation <2, 0, 1> +// will populate output so that input[x][y][z] is equal to (*output)[y][z][x]. +// +// Requires that nthreads is equal to the total number of elements in the input +// tensor. 
+template <typename T, int sp0, int sp1, int sp2, bool conjugate = false> +__global__ void ShuffleInTensor3Simple(int nthreads, const T* input, + Dimension<3> input_dims, T* output) { Dimension<3> output_dims; - output_dims[0] = input_dims[0]; - output_dims[1] = input_dims[2]; - output_dims[2] = input_dims[1]; - - CUDA_1D_KERNEL_LOOP(index, nthreads) { - int output_index = index; + output_dims[sp0] = input_dims[0]; + output_dims[sp1] = input_dims[1]; + output_dims[sp2] = input_dims[2]; + + // Iterate over output as opposed to iterating over input for better + // performance. Iterating over output will generate sequential writes and + // random reads that performs better compared to sequential reads and random + // writes. + CUDA_1D_KERNEL_LOOP(output_index, nthreads) { Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims); Index<3> input_tensor_index; - input_tensor_index[0] = output_tensor_index[0]; - input_tensor_index[1] = output_tensor_index[2]; - input_tensor_index[2] = output_tensor_index[1]; + input_tensor_index[0] = output_tensor_index[sp0]; + input_tensor_index[1] = output_tensor_index[sp1]; + input_tensor_index[2] = output_tensor_index[sp2]; int input_index = TensorIndexToFlat(input_tensor_index, input_dims); @@ -439,7 +421,7 @@ __global__ void PadInputCustomKernelNCHW(int nthreads, const T* input, template <typename T, int NDIMS> struct TransformFilter<GPUDevice, T, int, NDIMS> { typedef GPUDevice Device; - void operator()(const Device& d, + void operator()(const Device& d, FilterTensorFormat dst_filter_format, typename TTypes<T, NDIMS, int>::ConstTensor in, typename TTypes<T, NDIMS, int>::Tensor out) { Dimension<3> combined_dims; @@ -450,13 +432,18 @@ struct TransformFilter<GPUDevice, T, int, NDIMS> { combined_dims[1] = in.dimension(NDIMS - 2); // input filters combined_dims[2] = in.dimension(NDIMS - 1); // output filters CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d); - SwapDimension0And2InTensor3Simple<T> + + 
CHECK(dst_filter_format == FORMAT_OIHW) + << "Unsupported output layout: " << ToString(dst_filter_format); + + ShuffleInTensor3Simple<T, 2, 1, 0> <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( config.virtual_thread_count, in.data(), combined_dims, out.data()); } }; -// Converts Cudnn filter format back to TensorFlow filter format. +// Converts Cudnn filter format OIHW back to TensorFlow filter format HWIO. +// TODO(hinsu): Support reverse transformation from filter format OHWI as well. template <typename T, int NDIMS> struct ReverseTransformFilter<GPUDevice, T, NDIMS> { typedef GPUDevice Device; @@ -470,7 +457,7 @@ struct ReverseTransformFilter<GPUDevice, T, NDIMS> { combined_dims[2] *= in.dimension(i); } CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d); - SwapDimension0And2InTensor3Simple<T> + ShuffleInTensor3Simple<T, 2, 1, 0> <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( config.virtual_thread_count, in.data(), combined_dims, out.data()); } @@ -937,7 +924,7 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input, } else { int total_element_count = input_dims[0] * input_dims[1] * input_dims[2]; CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d); - SwapDimension1And2InTensor3Simple<T, conjugate> + ShuffleInTensor3Simple<T, 0, 2, 1, conjugate> <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( config.virtual_thread_count, input, input_dims, output); } @@ -969,7 +956,7 @@ struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> { static_cast<int>(combined_dims[2])}; size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2]; CudaLaunchConfig config = GetCudaLaunchConfig(total_size, d); - SwapDimension0And2InTensor3Simple<T, conjugate> + ShuffleInTensor3Simple<T, 2, 1, 0, conjugate> <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( config.virtual_thread_count, in, input_dims, out); } diff --git 
a/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc new file mode 100644 index 0000000000..e4b21a66c6 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc @@ -0,0 +1,26 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY5(xdivy, Eigen::half, float, double, complex64, complex128); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc new file mode 100644 index 0000000000..1e1b5a426e --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc @@ -0,0 +1,26 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY5(xlogy, Eigen::half, float, double, complex64, complex128); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_xdivy.cc b/tensorflow/core/kernels/cwise_op_xdivy.cc new file mode 100644 index 0000000000..6a6aec5e86 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_xdivy.cc @@ -0,0 +1,38 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER5(BinaryOp, CPU, "Xdivy", functor::xdivy, float, Eigen::half, double, + complex64, complex128); + +#if TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Xdivy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \ + BinaryOp<SYCLDevice, functor::xdivy<TYPE>>); +REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); +#undef REGISTER_SYCL_KERNEL + +#endif // TENSORFLOW_USE_SYCL + +#if GOOGLE_CUDA +REGISTER5(BinaryOp, GPU, "Xdivy", functor::xdivy, float, Eigen::half, double, + complex64, complex128); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_xlogy.cc b/tensorflow/core/kernels/cwise_op_xlogy.cc new file mode 100644 index 0000000000..e71a9109b2 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_xlogy.cc @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER5(BinaryOp, CPU, "Xlogy", functor::xlogy, float, Eigen::half, double, + complex64, complex128); + +#if TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Xlogy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \ + BinaryOp<SYCLDevice, functor::xlogy<TYPE>>); +REGISTER_SYCL_KERNEL(Eigen::half); +REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); +REGISTER_SYCL_KERNEL(complex64); +REGISTER_SYCL_KERNEL(complex128); +#undef REGISTER_SYCL_KERNEL + +#endif // TENSORFLOW_USE_SYCL + +#if GOOGLE_CUDA +REGISTER5(BinaryOp, GPU, "Xlogy", functor::xlogy, float, Eigen::half, double, + complex64, complex128); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 22eb66e979..66ba827a90 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -471,6 +471,45 @@ struct functor_traits<bitwise_xor_op<Scalar>> { enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true }; }; +// TODO(srvasude): Add packet versions of this operation. +template <typename Scalar> +struct xlogy_op { + EIGEN_EMPTY_STRUCT_CTOR(xlogy_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar + operator()(const Scalar& x, const Scalar& y) const { + if (x == Scalar(0.)) { + return Scalar(0.); + } + return x * numext::log(y); + } +}; + +template <typename Scalar> +struct functor_traits<xlogy_op<Scalar>> { + enum { + Cost = (sizeof(Scalar) == 4 ? 40 : 85) + Eigen::NumTraits<Scalar>::MulCost, + PacketAccess = false + }; +}; + +template <typename Scalar> +// TODO(srvasude): Add packet versions of this operation. 
+struct xdivy_op { + EIGEN_EMPTY_STRUCT_CTOR(xdivy_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar + operator()(const Scalar& x, const Scalar& y) const { + if (x == Scalar(0.)) { + return Scalar(0.); + } + return x / y; + } +}; + +template <typename Scalar> +struct functor_traits<xdivy_op<Scalar>> { + enum { Cost = Eigen::NumTraits<Scalar>::MulCost, PacketAccess = false }; +}; + } // end namespace internal } // end namespace Eigen @@ -830,6 +869,12 @@ struct squared_difference Eigen::internal::scalar_difference_op<T>>> {}; template <typename T> +struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {}; + +template <typename T> +struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {}; + +template <typename T> struct less : base<T, Eigen::internal::less<T>, bool> {}; template <typename T> diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc index 980edffceb..8ad3b4d1fc 100644 --- a/tensorflow/core/kernels/cwise_ops_common.cc +++ b/tensorflow/core/kernels/cwise_ops_common.cc @@ -20,9 +20,9 @@ namespace tensorflow { BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in) : OpKernel(ctx) { -#ifndef INTEL_MKL +#if !defined(INTEL_MKL) || !defined(ENABLE_MKL) OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out})); -#endif +#endif // !INTEL_MKL || !ENABLE_MKL } void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) { diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index e7b3d0c92f..87efdff789 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -51,6 +51,7 @@ cc_library( hdrs = ["captured_function.h"], deps = [ ":dataset", + ":single_threaded_executor", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -61,6 +62,42 @@ cc_library( ) cc_library( + name = "single_threaded_executor", + srcs = ["single_threaded_executor.cc"], + hdrs = 
["single_threaded_executor.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "single_threaded_executor_test", + srcs = ["single_threaded_executor_test.cc"], + deps = [ + ":single_threaded_executor", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:math", + "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:state", + ], +) + +cc_library( name = "window_dataset", srcs = ["window_dataset.cc"], hdrs = ["window_dataset.h"], @@ -591,6 +628,20 @@ tf_kernel_library( ) tf_kernel_library( + name = "multi_device_iterator_ops", + srcs = ["multi_device_iterator_ops.cc"], + deps = [ + ":dataset", + ":dataset_utils", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core/kernels:ops_util", + ], +) + +tf_kernel_library( name = "optional_ops", srcs = ["optional_ops.cc"], hdrs = ["optional_ops.h"], @@ -638,6 +689,19 @@ tf_kernel_library( ) tf_kernel_library( + name = "model_dataset_op", + srcs = ["model_dataset_op.cc"], + deps = [ + ":dataset", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_kernel_library( name = "dataset_ops", srcs = ["dataset_ops.cc"], deps = [ @@ -671,6 +735,8 @@ tf_kernel_library( ":map_and_batch_dataset_op", ":map_dataset_op", ":map_defun_op", + 
":model_dataset_op", + ":multi_device_iterator_ops", ":optimize_dataset_op", ":optional_ops", ":padded_batch_dataset_op", diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index f9b5353724..d1db1d7bec 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -117,6 +117,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator<Dataset>(params) {} Status Initialize(IteratorContext* ctx) override { + AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); } @@ -241,5 +242,5 @@ REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU), BatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index 6ca0bcd37d..34c6c86538 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level description of @@ -69,7 +69,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr<IteratorBase>( - new FileIterator({this, strings::StrCat(prefix, "::FileIterator")})); + new FileIterator({this, strings::StrCat(prefix, "::FileCache")})); } const DataTypeVector& output_dtypes() const override { @@ -553,7 +553,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr<IteratorBase>(new MemoryIterator( - {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_)); + {this, strings::StrCat(prefix, "::MemoryCache")}, cache_)); } const DataTypeVector& output_dtypes() const override { @@ -891,5 +891,5 @@ REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU), CacheDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index abdf6ee4e8..0bb929b3ce 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -17,33 +17,101 @@ limitations under the License. 
#include <utility> #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { +namespace data { + +namespace { + +// Simplistic implementation of the `StepStatsCollectorInterface` that only +// cares about collecting the CPU time needed to execute a captured function. +class SimpleStepStatsCollector : public StepStatsCollectorInterface { + public: + void IncrementProcessingTime(int64 delta) { + mutex_lock l(mu_); + processing_time_ += delta; + } + + NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override { + return new SimpleNodeExecStats(this); + } + + string ReportAllocsOnResourceExhausted(const string& err) override { + return ""; + } + + int64 processing_time() { + tf_shared_lock l(mu_); + return processing_time_; + } + + private: + class SimpleNodeExecStats : public NodeExecStatsInterface { + public: + explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector) + : step_stats_collector_(step_stats_collector) {} + + void Done(const string& device) override { + step_stats_collector_->IncrementProcessingTime(end_time_ns_ - + start_time_ns_); + delete this; + } + + void RecordExecutorStarted() override { + start_time_ns_ = Env::Default()->NowNanos(); + } + + void RecordComputeStarted() override {} + + void RecordComputeEnded() override {} + + void RecordExecutorEnded() override { + end_time_ns_ = Env::Default()->NowNanos(); + } + + void SetMemory(OpKernelContext* ctx) override {} + + void SetOutput(int slot, const Tensor* tensor) override {} + + void SetReferencedTensors(const TensorReferenceVector& tensors) override {} + + void SetScheduled(int64 nanos) override {} + + private: + int64 start_time_ns_ = 0; + int64 
end_time_ns_ = 0; + SimpleStepStatsCollector* step_stats_collector_; // Not owned. + }; + + mutex mu_; + int64 processing_time_ GUARDED_BY(mu_) = 0; +}; + +} // namespace /* static */ Status CapturedFunction::Create( - const NameAttrList& func, std::vector<Tensor> captured_inputs, + const NameAttrList& func, OpKernelContext* ctx, const string& argument, std::unique_ptr<CapturedFunction>* out_function) { - out_function->reset(new CapturedFunction(func, std::move(captured_inputs))); - return Status::OK(); + return CapturedFunction::Create(func, ctx, argument, true, out_function); } -/* static */ Status CapturedFunction::Create( const NameAttrList& func, OpKernelContext* ctx, const string& argument, + bool use_inter_op_parallelism, std::unique_ptr<CapturedFunction>* out_function) { - OpInputList argument_inputs; - TF_RETURN_IF_ERROR(ctx->input_list(argument, &argument_inputs)); - std::vector<Tensor> arguments_t; - arguments_t.reserve(argument_inputs.size()); - for (const Tensor& t : argument_inputs) { - arguments_t.push_back(t); - } - return CapturedFunction::Create(func, std::move(arguments_t), out_function); + OpInputList inputs; + TF_RETURN_IF_ERROR(ctx->input_list(argument, &inputs)); + std::vector<Tensor> arguments(inputs.begin(), inputs.end()); + *out_function = WrapUnique(new CapturedFunction(func, std::move(arguments), + use_inter_op_parallelism)); + return Status::OK(); } CapturedFunction::~CapturedFunction() { @@ -272,6 +340,9 @@ Status CapturedFunction::Instantiate(IteratorContext* ctx) { inst_opts.overlay_lib = ctx->function_library().get(); inst_opts.state_handle = std::to_string(random::New64()); inst_opts.create_kernels_eagerly = true; + if (!use_inter_op_parallelism_) { + inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR"; + } Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_opts, &f_handle_)); TF_RETURN_IF_ERROR(s); @@ -345,7 +416,8 @@ Status CapturedFunction::RunInstantiated(const std::vector<Tensor>& args, void 
CapturedFunction::RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets, - FunctionLibraryRuntime::DoneCallback done) { + FunctionLibraryRuntime::DoneCallback done, + const string& prefix) { // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may // be deleted before `done` is called. Take care not to capture `ctx` in any // code that may execute asynchronously in this function. @@ -355,13 +427,13 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, done(s); return; } - auto frame = + OwnedArgsCallFrame* frame = new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_); FunctionLibraryRuntime::Options f_opts; f_opts.step_id = CapturedFunction::generate_step_id(); ResourceMgr* resource_mgr = ctx->lib()->device()->resource_manager(); - auto step_container = new ScopedStepContainer( + ScopedStepContainer* step_container = new ScopedStepContainer( f_opts.step_id, [resource_mgr](const string& name) { resource_mgr->Cleanup(name).IgnoreError(); }); @@ -376,32 +448,50 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, // (such as queue kernels) that depend on the non-nullness of // `OpKernelContext::cancellation_manager()`, but additional effort // will be required to plumb it through the `IteratorContext`. - auto c_mgr = new CancellationManager; + CancellationManager* c_mgr = new CancellationManager; f_opts.cancellation_manager = c_mgr; - - tf_shared_lock l(mu_); - ctx->lib()->Run(f_opts, handle, frame, - std::bind( - [rets, step_container, c_mgr, frame]( - FunctionLibraryRuntime::DoneCallback done, - // Begin unbound arguments. 
- Status s) { - delete step_container; - delete c_mgr; - if (s.ok()) { - s = frame->ConsumeRetvals(rets); - } - delete frame; - done(s); - }, - std::move(done), std::placeholders::_1)); + std::shared_ptr<SimpleStepStatsCollector> stats_collector; + if (ctx->model()) { + stats_collector = MakeUnique<SimpleStepStatsCollector>(); + } + f_opts.stats_collector = stats_collector.get(); + + auto callback = std::bind( + [rets, step_container, c_mgr, frame]( + const FunctionLibraryRuntime::DoneCallback& done, + const std::shared_ptr<model::Model>& model, const string& prefix, + const std::shared_ptr<SimpleStepStatsCollector>& stats_collector, + // Begin unbound arguments. + Status s) { + delete step_container; + delete c_mgr; + if (s.ok()) { + s = frame->ConsumeRetvals(rets); + } + delete frame; + if (model) { + model->AddProcessingTime(prefix, stats_collector->processing_time()); + model->RecordStart(prefix, false /* stop_output */); + } + done(s); + if (model) { + model->RecordStop(prefix, false /* start_output */); + } + }, + std::move(done), ctx->model(), prefix, std::move(stats_collector), + std::placeholders::_1); + + ctx->lib()->Run(f_opts, handle, frame, std::move(callback)); } CapturedFunction::CapturedFunction(const NameAttrList& func, - std::vector<Tensor> captured_inputs) + std::vector<Tensor> captured_inputs, + bool use_inter_op_parallelism) : func_(func), lib_(nullptr), f_handle_(kInvalidHandle), - captured_inputs_(std::move(captured_inputs)) {} + captured_inputs_(std::move(captured_inputs)), + use_inter_op_parallelism_(use_inter_op_parallelism) {} +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index c95f2b1c01..a10376bf97 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -32,6 +32,8 @@ class Device; class OpKernelContext; class ResourceMgr; +namespace data { + // A 
`CapturedFunction` encapsulates a TensorFlow function and all of // the runtime support required to execute it. // @@ -40,18 +42,19 @@ class ResourceMgr; // context. class CapturedFunction { public: - // Creates a new instance from a list of named attributes and captured inputs. - // - // NOTE(mrry): The `captured_inputs` are passed by value. For - // efficiency, you are recommended to move this argument into the call. - static Status Create(const NameAttrList& func, - std::vector<Tensor> captured_inputs, + // Creates a new instance using a list of named attributes, fetching captured + // inputs from a context argument. + static Status Create(const NameAttrList& func, OpKernelContext* ctx, + const string& argument, std::unique_ptr<CapturedFunction>* out_function); // Creates a new instance using a list of named attributes, fetching captured // inputs from a context argument. + // + // If `use_inter_op_parallelism` is false, the runtime may use an executor + // that is optimized for small functions. static Status Create(const NameAttrList& func, OpKernelContext* ctx, - const string& argument, + const string& argument, bool use_inter_op_parallelism, std::unique_ptr<CapturedFunction>* out_function); ~CapturedFunction(); @@ -93,7 +96,8 @@ class CapturedFunction { // in order to be able to deallocate them as early as possible. void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets, - FunctionLibraryRuntime::DoneCallback done); + FunctionLibraryRuntime::DoneCallback done, + const string& prefix); // Returns the named list of function arguments. 
const NameAttrList& func() { return func_; } @@ -114,7 +118,8 @@ class CapturedFunction { private: CapturedFunction(const NameAttrList& func, - std::vector<Tensor> captured_inputs); + std::vector<Tensor> captured_inputs, + bool use_inter_op_parallelism); Status GetHandle(IteratorContext* ctx, FunctionLibraryRuntime::Handle* out_handle); @@ -126,10 +131,17 @@ class CapturedFunction { const std::vector<Tensor> captured_inputs_; DataTypeSlice ret_types_; std::function<void(std::function<void()>)> captured_runner_ = nullptr; + const bool use_inter_op_parallelism_; TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction); }; +} // namespace data + +// TODO(b/114112161): Remove these aliases when all users have moved over to the +// `tensorflow::data` namespace. +using data::CapturedFunction; + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index c361a9adcb..a04f150e71 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -195,5 +195,5 @@ REGISTER_KERNEL_BUILDER(Name("ConcatenateDataset").Device(DEVICE_CPU), ConcatenateDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc index c71d027f23..bd1ccd5b5d 100644 --- a/tensorflow/core/kernels/data/dataset_ops.cc +++ b/tensorflow/core/kernels/data/dataset_ops.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { +namespace data { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. @@ -48,4 +49,5 @@ class DatasetToGraphOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU), DatasetToGraphOp); +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index d85ef1cbab..e10833f525 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -17,8 +17,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" namespace tensorflow { - -namespace dataset { +namespace data { Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector<Tensor>& input_element, @@ -45,6 +44,42 @@ Status MakeIteratorFromInputElement( ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator); } -} // namespace dataset +Status VerifyTypesMatch(const DataTypeVector& expected, + const DataTypeVector& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " types but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); ++i) { + if (expected[i] != received[i]) { + return errors::InvalidArgument("Data type mismatch at component ", i, + ": expected ", DataTypeString(expected[i]), + " but got ", DataTypeString(received[i]), + "."); + } + } + return Status::OK(); +} + +Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, + const std::vector<PartialTensorShape>& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " shapes but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); 
++i) { + if (!expected[i].IsCompatibleWith(received[i])) { + return errors::InvalidArgument("Incompatible shapes at component ", i, + ": expected ", expected[i].DebugString(), + " but got ", received[i].DebugString(), + "."); + } + } + + return Status::OK(); +} +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index 6c4191c2be..6ec1350cd4 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -20,16 +20,24 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - -namespace dataset { +namespace data { Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector<Tensor>& input_element, int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator); -} // namespace dataset +// Returns Status::OK() if `expected` and `received` types match, +// errors::InvalidArgument otherwise. +Status VerifyTypesMatch(const DataTypeVector& expected, + const DataTypeVector& received); + +// Returns Status::OK() if `expected` and `received` shapes are compatible, +// errors::InvalidArgument otherwise. +Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, + const std::vector<PartialTensorShape>& received); +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_ diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc index 9770bc025d..237511a07d 100644 --- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -301,5 +301,5 @@ REGISTER_KERNEL_BUILDER(Name("DenseToSparseBatchDataset").Device(DEVICE_CPU), DenseToSparseBatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc index ce577397c5..a7e3a56727 100644 --- a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -166,5 +166,5 @@ REGISTER_KERNEL_BUILDER(Name("FilterByLastComponentDataset").Device(DEVICE_CPU), FilterByLastComponentDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index bbce001eaf..00884314a9 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -14,14 +14,16 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -37,14 +39,6 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - FunctionLibraryRuntime::Handle pred_handle; OP_REQUIRES_OK(ctx, ctx->function_library()->Instantiate( @@ -61,9 +55,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { Node* ret_node = pred_body->ret_nodes[0]; Node* ret_input_node; OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node)); + std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); if (ret_input_node->def().op() == "_Arg") { int32 index = -1; @@ -146,7 +141,13 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator<FilterDatasetBase> { public: explicit Iterator(const Params& params) - : DatasetIterator<FilterDatasetBase>(params) {} + : DatasetIterator<FilterDatasetBase>(params), + filtered_elements_(0), 
+ dropped_elements_(0) { + std::vector<string> components = + str_util::Split(params.prefix, "::", str_util::SkipEmpty()); + prefix_end_ = components.back(); + } Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( @@ -161,6 +162,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { // `input_impl_` and `f` are thread-safe. However, if multiple // threads enter this method, outputs may be observed in a // non-deterministic order. + auto stats_aggregator = ctx->stats_aggregator(); bool matched; do { { @@ -183,8 +185,34 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { if (!matched) { // Clear the output tensor list since it didn't match. out_tensors->clear(); + if (stats_aggregator) { + mutex_lock l(mu_); + dropped_elements_++; + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::dropped_elements"), + static_cast<float>((dropped_elements_))); + // TODO(shivaniagrawal): multiple pipelines would collect + // aggregated number of dropped elements for all the pipelines, + // exploit tagged_context here. + stats_aggregator->IncrementCounter( + prefix_end_, "dropped_elements", static_cast<float>(1)); + } } } while (!matched); + // TODO(shivaniagrawal): add ratio of dropped_elements and + // filtered_elements as a histogram. + if (stats_aggregator) { + mutex_lock l(mu_); + filtered_elements_++; + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::filtered_elements"), + static_cast<float>((filtered_elements_))); + // TODO(shivaniagrawal): multiple pipelines would collect aggregated + // number of filtered elements for all the pipelines, exploit + // tagged_context here. 
+ stats_aggregator->IncrementCounter(prefix_end_, "filtered_elements", + static_cast<float>(1)); + } *end_of_sequence = false; return Status::OK(); } @@ -197,6 +225,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { else TF_RETURN_IF_ERROR( writer->WriteScalar(full_name("input_impls_empty"), "")); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("filtered_elements"), + filtered_elements_)); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("dropped_elements"), + dropped_elements_)); return Status::OK(); } @@ -207,12 +239,19 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { input_impl_.reset(); else TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("filtered_elements"), + &filtered_elements_)); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("dropped_elements"), + &dropped_elements_)); return Status::OK(); } private: mutex mu_; std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + int64 filtered_elements_ GUARDED_BY(mu_); + int64 dropped_elements_ GUARDED_BY(mu_); + string prefix_end_; }; const DatasetBase* const input_; @@ -280,5 +319,5 @@ REGISTER_KERNEL_BUILDER(Name("FilterDataset").Device(DEVICE_CPU), FilterDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index b1eb2fd849..2fada22a21 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -39,18 +39,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); - + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, func_, std::move(captured_func), output_types_, output_shapes_); } @@ -245,7 +236,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { private: Status BuildCurrentElementIteratorLocked(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return dataset::MakeIteratorFromInputElement( + return MakeIteratorFromInputElement( ctx, captured_func_inputs_, element_index_++, dataset()->captured_func_.get(), prefix(), &current_element_iterator_); @@ -285,5 +276,5 @@ REGISTER_KERNEL_BUILDER(Name("FlatMapDataset").Device(DEVICE_CPU), FlatMapDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index ccee690d7e..b4367d5a11 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. 
@@ -85,8 +86,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx)); TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx)); TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx)); - TF_RETURN_IF_ERROR( - dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); return Status::OK(); } @@ -95,6 +94,12 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { bool* end_of_sequence) override { mutex_lock l(mu_); + if (!initialized_) { + TF_RETURN_IF_ERROR( + dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); + initialized_ = true; + } + if (finalized_) { *end_of_sequence = true; return Status::OK(); @@ -122,6 +127,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { private: mutex mu_; + bool initialized_ GUARDED_BY(mu_) = false; bool finalized_ GUARDED_BY(mu_) = false; std::vector<Tensor> state_ GUARDED_BY(mu_); }; @@ -144,54 +150,31 @@ GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx) void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) { - OpInputList init_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args", - &init_func_other_args_input)); - std::vector<Tensor> init_func_other_args; - init_func_other_args.reserve(init_func_other_args_input.size()); - for (const Tensor& t : init_func_other_args_input) { - init_func_other_args.push_back(t); - } std::unique_ptr<CapturedFunction> init_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args), - &init_func)); - - OpInputList next_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args", - &next_func_other_args_input)); - std::vector<Tensor> next_func_other_args; - next_func_other_args.reserve(next_func_other_args_input.size()); - for (const Tensor& t : next_func_other_args_input) { - next_func_other_args.push_back(t); - } + OP_REQUIRES_OK(ctx, CapturedFunction::Create( 
+ init_func_, ctx, "init_func_other_args", &init_func)); + std::unique_ptr<CapturedFunction> next_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args), - &next_func)); - - OpInputList finalize_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args", - &finalize_func_other_args_input)); - std::vector<Tensor> finalize_func_other_args; - finalize_func_other_args.reserve(finalize_func_other_args_input.size()); - for (const Tensor& t : finalize_func_other_args_input) { - finalize_func_other_args.push_back(t); - } - std::unique_ptr<CapturedFunction> finalize_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create( - finalize_func_, std::move(finalize_func_other_args), - &finalize_func)); + next_func_, ctx, "next_func_other_args", &next_func)); + + std::unique_ptr<CapturedFunction> finalize_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx, + "finalize_func_other_args", + &finalize_func)); *output = new Dataset(ctx, std::move(init_func), std::move(next_func), std::move(finalize_func), output_types_, output_shapes_); } +namespace { REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU), GeneratorDatasetOp); REGISTER_KERNEL_BUILDER( Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"), GeneratorDatasetOp); +} // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/generator_dataset_op.h b/tensorflow/core/kernels/data/generator_dataset_op.h index 8407543136..d23ed97ec3 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.h +++ b/tensorflow/core/kernels/data/generator_dataset_op.h @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/dataset.h" namespace tensorflow { +namespace data { class GeneratorDatasetOp : public DatasetOpKernel { public: @@ -36,5 +37,6 @@ class GeneratorDatasetOp : public DatasetOpKernel { NameAttrList finalize_func_; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc index 130f04da3e..d6ee42a7c6 100644 --- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -433,4 +434,5 @@ REGISTER_KERNEL_BUILDER(Name("GroupByReducerDataset").Device(DEVICE_CPU), GroupByReducerDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index 46a3185b49..8b417bb1c2 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -41,50 +42,19 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - // Get captured inputs for the key, reduce, and window_size functions. 
- OpInputList key_func_other_argument_inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("key_func_other_arguments", - &key_func_other_argument_inputs)); - std::vector<Tensor> key_func_other_arguments; - key_func_other_arguments.reserve(key_func_other_argument_inputs.size()); - for (const Tensor& t : key_func_other_argument_inputs) { - key_func_other_arguments.push_back(t); - } - OpInputList reduce_func_other_argument_inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("reduce_func_other_arguments", - &reduce_func_other_argument_inputs)); - std::vector<Tensor> reduce_func_other_arguments; - reduce_func_other_arguments.reserve( - reduce_func_other_argument_inputs.size()); - for (const Tensor& t : reduce_func_other_argument_inputs) { - reduce_func_other_arguments.push_back(t); - } - OpInputList window_size_func_other_argument_inputs; - OP_REQUIRES_OK(ctx, - ctx->input_list("window_size_func_other_arguments", - &window_size_func_other_argument_inputs)); - std::vector<Tensor> window_size_func_other_arguments; - window_size_func_other_arguments.reserve( - window_size_func_other_argument_inputs.size()); - for (const Tensor& t : window_size_func_other_argument_inputs) { - window_size_func_other_arguments.push_back(t); - } - // TODO(mrry): Refactor CapturedFunction to share the runtime - // state between multiple functions? 
std::unique_ptr<CapturedFunction> captured_key_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - key_func_, std::move(key_func_other_arguments), - &captured_key_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx, + "key_func_other_arguments", + &captured_key_func)); std::unique_ptr<CapturedFunction> captured_reduce_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(reduce_func_, - std::move(reduce_func_other_arguments), - &captured_reduce_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx, + "reduce_func_other_arguments", + &captured_reduce_func)); std::unique_ptr<CapturedFunction> captured_window_size_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create( - window_size_func_, std::move(window_size_func_other_arguments), - &captured_window_size_func)); + OP_REQUIRES_OK(ctx, + CapturedFunction::Create(window_size_func_, ctx, + "window_size_func_other_arguments", + &captured_window_size_func)); *output = new Dataset( ctx, input, key_func_, reduce_func_, window_size_func_, @@ -549,4 +519,5 @@ REGISTER_KERNEL_BUILDER(Name("GroupByWindowDataset").Device(DEVICE_CPU), GroupByWindowDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 716e040277..0aa802b874 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -39,14 +39,6 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - const Tensor* cycle_length_t; OP_REQUIRES_OK(ctx, ctx->input("cycle_length", &cycle_length_t)); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(cycle_length_t->shape()), @@ -66,8 +58,8 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { errors::InvalidArgument("block_length must be greater than zero.")); std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, func_, std::move(captured_func), cycle_length, @@ -201,7 +193,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(input_impl_->GetNext( ctx, &args_list_[cycle_index_], &end_of_input_)); if (!end_of_input_) { - TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( + TF_RETURN_IF_ERROR(MakeIteratorFromInputElement( ctx, args_list_[cycle_index_], cycle_index_, dataset()->captured_func_.get(), prefix(), &current_elements_[cycle_index_])); @@ -288,7 +280,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { full_name(strings::StrCat("args_list_[", idx, "][", i, "]")), &args_list_[idx][i])); } - TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( + TF_RETURN_IF_ERROR(MakeIteratorFromInputElement( ctx, args_list_[idx], idx, dataset()->captured_func_.get(), prefix(), 
&current_elements_[idx])); TF_RETURN_IF_ERROR( @@ -330,5 +322,5 @@ REGISTER_KERNEL_BUILDER(Name("InterleaveDataset").Device(DEVICE_CPU), InterleaveDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 4e9b280968..c0bc507ec0 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -36,7 +36,7 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -44,43 +44,6 @@ namespace { const char kIteratorVariantTypeName[] = "tensorflow::Iterator"; -Status VerifyTypesMatch(const DataTypeVector& expected, - const DataTypeVector& received) { - if (expected.size() != received.size()) { - return errors::InvalidArgument( - "Number of components does not match: expected ", expected.size(), - " types but got ", received.size(), "."); - } - for (size_t i = 0; i < expected.size(); ++i) { - if (expected[i] != received[i]) { - return errors::InvalidArgument("Data type mismatch at component ", i, - ": expected ", DataTypeString(expected[i]), - " but got ", DataTypeString(received[i]), - "."); - } - } - return Status::OK(); -} - -Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, - const std::vector<PartialTensorShape>& received) { - if (expected.size() != received.size()) { - return errors::InvalidArgument( - "Number of components does not match: expected ", expected.size(), - " shapes but got ", received.size(), "."); - } - for (size_t i = 0; i < expected.size(); ++i) { - if (!expected[i].IsCompatibleWith(received[i])) { - return errors::InvalidArgument("Incompatible shapes at component ", i, - ": expected ", expected[i].DebugString(), - " but got ", received[i].DebugString(), - "."); - } - } - - return Status::OK(); -} - } // 
namespace class IteratorResource : public ResourceBase { @@ -236,6 +199,8 @@ class IteratorResource : public ResourceBase { const std::vector<PartialTensorShape> output_shapes_; }; +namespace { + // Helper class for reading data from a VariantTensorData object. class VariantTensorDataReader : public IteratorStateReader { public: @@ -401,12 +366,12 @@ class IteratorStateVariant { } string TypeName() const { return kIteratorVariantTypeName; } void Encode(VariantTensorData* data) const { *data = *data_; } - bool Decode(const VariantTensorData& data) { + bool Decode(VariantTensorData data) { if (data.type_name() != TypeName()) { return false; } std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData); - *tensor_data = data; + std::swap(*tensor_data, data); std::unique_ptr<VariantTensorDataReader> reader( new VariantTensorDataReader(tensor_data.get())); status_ = reader->status(); @@ -443,6 +408,8 @@ class IteratorStateVariant { REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant, kIteratorVariantTypeName); +} // namespace + // Note that IteratorHandleOp holds a reference to the resource it creates. If // cleaning up resources with DestroyResourceOp is important, consider creating // resource containers with AnonymousIteratorHandleOp instead. 
@@ -622,6 +589,8 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator))); } +namespace { + class ToSingleElementOp : public AsyncOpKernel { public: explicit ToSingleElementOp(OpKernelConstruction* ctx) @@ -887,6 +856,8 @@ class OneShotIteratorOp : public AsyncOpKernel { const int graph_def_version_; }; +} // namespace + void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { IteratorResource* iterator; OP_REQUIRES_OK_ASYNC( @@ -957,6 +928,8 @@ void IteratorGetNextSyncOp::Compute(OpKernelContext* ctx) { } } +namespace { + class IteratorGetNextAsOptionalOp : public AsyncOpKernel { public: explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx) @@ -1037,6 +1010,8 @@ class IteratorGetNextAsOptionalOp : public AsyncOpKernel { std::vector<PartialTensorShape> output_shapes_; }; +} // namespace + void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) { const Tensor& resource_handle_t = ctx->input(0); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), @@ -1108,6 +1083,8 @@ void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) { resource_handle_t->scalar<ResourceHandle>()() = resource_handle; } +namespace { + class SerializeIteratorOp : public OpKernel { public: explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -1202,4 +1179,7 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU), DeserializeIteratorOp); +} // namespace + +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h index 723564286c..8a2b2639a7 100644 --- a/tensorflow/core/kernels/data/iterator_ops.h +++ b/tensorflow/core/kernels/data/iterator_ops.h @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/ops_util.h" namespace tensorflow { +namespace data { class IteratorResource; @@ -142,6 +143,7 @@ class IteratorFromStringHandleOp : public OpKernel { std::vector<PartialTensorShape> output_shapes_; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 8b0c9ad6b2..2bbf4af664 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #define EIGEN_USE_THREADS +#include <atomic> #include <utility> #include "tensorflow/core/common_runtime/function.h" @@ -26,10 +27,11 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/tracing.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -39,7 +41,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { public: explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()), op_version_(ctx->def().op() == "MapAndBatchDataset" ? 
1 : 2) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); @@ -49,14 +50,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - int64 batch_size; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size)); OP_REQUIRES( @@ -77,7 +70,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { case 2: OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", &num_parallel_calls)); - OP_REQUIRES(ctx, num_parallel_calls > 0, + OP_REQUIRES(ctx, + num_parallel_calls > 0 || num_parallel_calls == kAutoTune, errors::InvalidArgument( "num_parallel_calls must be greater than zero.")); break; @@ -92,8 +86,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ParseScalarArgument(ctx, "drop_remainder", &drop_remainder)); std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, batch_size, num_parallel_calls, drop_remainder, output_types_, output_shapes_, func_, @@ -190,7 +184,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator<Dataset> { public: explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} + : DatasetIterator<Dataset>(params), + num_parallel_calls_(params.dataset->num_parallel_calls_) {} ~Iterator() override { mutex_lock l(mu_); @@ -204,6 +199,16 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } Status 
Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); + if (num_parallel_calls_ == kAutoTune) { + num_parallel_calls_ = 1; + AddTunableParameter(ctx, "parallelism", + &num_parallel_calls_ /* value */, 1 /* min */, + port::NumSchedulableCPUs() /* max */, &cond_var_); + } else { + AddConstantParameter(ctx, "parallelism", num_parallel_calls_); + } TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate(ctx); @@ -218,12 +223,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { EnsureRunnerThreadStarted(ctx); while (batch_results_.empty() || batch_results_.front()->num_calls > 0) { + RecordStop(ctx); cond_var_.wait(l); + RecordStart(ctx); } std::swap(result, batch_results_.front()); batch_results_.pop_front(); + cond_var_.notify_all(); } - cond_var_.notify_all(); return ProcessResult(ctx, result, out_tensors, end_of_sequence); } @@ -326,11 +333,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { void CallCompleted(const std::shared_ptr<BatchResult>& result) LOCKS_EXCLUDED(mu_) { - { - mutex_lock l(mu_); - num_calls_--; - result->num_calls--; - } + mutex_lock l(mu_); + num_calls_--; + result->num_calls--; cond_var_.notify_all(); } @@ -365,7 +370,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ctx.get(), std::move(input_element), return_values.get(), [this, ctx, result, return_values, offset](Status status) { Callback(ctx, result, return_values, offset, status); - }); + }, + prefix()); }, ctx, std::move(input_element))); } @@ -422,11 +428,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { result->output_allocated = true; } - int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return (dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) / - dataset()->batch_size_; - } - Status ProcessResult(IteratorContext* ctx, const std::shared_ptr<BatchResult>& result, 
std::vector<Tensor>* out_tensors, @@ -475,26 +476,34 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) LOCKS_EXCLUDED(mu_) { std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls; - new_calls.reserve(dataset()->num_parallel_calls_); + RecordStart(ctx.get()); + auto stop_cleanup = + gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); }); + new_calls.reserve(num_parallel_calls_); + auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool { + int64 num_parallel_calls = num_parallel_calls_; + int64 max_batch_results = + (num_parallel_calls + dataset()->batch_size_ - 1) / + dataset()->batch_size_; + return num_calls_ >= num_parallel_calls || + (batch_results_.size() > max_batch_results || + (batch_results_.size() == max_batch_results && + call_counter_ % dataset()->batch_size_ == 0)); + }; while (true) { { mutex_lock l(mu_); - while (!cancelled_ && - (num_calls_ >= dataset()->num_parallel_calls_ || - batch_results_.size() > MaxBatchResults() || - (batch_results_.size() == MaxBatchResults() && - call_counter_ % dataset()->batch_size_ == 0))) { + while (!cancelled_ && busy()) { + RecordStop(ctx.get()); cond_var_.wait(l); + RecordStart(ctx.get()); } if (cancelled_) { return; } - while (num_calls_ < dataset()->num_parallel_calls_ && - (batch_results_.size() < MaxBatchResults() || - (batch_results_.size() == MaxBatchResults() && - call_counter_ % dataset()->batch_size_ != 0))) { + while (!busy()) { if (call_counter_ % dataset()->batch_size_ == 0) { batch_results_.emplace_back( new BatchResult(dataset()->batch_size_)); @@ -638,6 +647,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { // user specified level of parallelism and there are slots available in // the `batch_results_` buffer. condition_variable cond_var_; + // Identifies the maximum number of parallel calls. + std::atomic<int64> num_parallel_calls_; // Counts the number of outstanding calls for this batch. 
int64 num_calls_ GUARDED_BY(mu_) = 0; // Counts the total number of calls. @@ -661,7 +672,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const Eigen::ThreadPoolDevice* device_; // not owned }; - const int graph_def_version_; const int op_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; @@ -675,5 +685,5 @@ REGISTER_KERNEL_BUILDER(Name("MapAndBatchDatasetV2").Device(DEVICE_CPU), MapAndBatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 7f8182d917..f112e1dc43 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -28,27 +28,20 @@ namespace { class MapDatasetOp : public UnaryDatasetOpKernel { public: - explicit MapDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", + &use_inter_op_parallelism_)); } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, 
CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + use_inter_op_parallelism_, + &captured_func)); *output = new Dataset(ctx, input, func_, std::move(captured_func), output_types_, output_shapes_); @@ -183,14 +176,14 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const std::vector<PartialTensorShape> output_shapes_; }; - const int graph_def_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; NameAttrList func_; + bool use_inter_op_parallelism_; }; REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index 607d0ca028..6657f2b2b3 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -18,18 +18,20 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/util/batch_util.h" #include "tensorflow/core/util/reffed_status_callback.h" namespace tensorflow { +namespace data { namespace { void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts, bool always_collect_stats) { opts->step_id = ctx->step_id(); opts->rendezvous = ctx->rendezvous(); - opts->cancellation_manager = ctx->cancellation_manager(); if (always_collect_stats) { opts->stats_collector = ctx->stats_collector(); } @@ -60,103 +62,186 @@ class MapDefunOp : public AsyncOpKernel { ~MapDefunOp() override {} - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - int64 batch_size = ctx->input(0).dim_size(0); - // Inputs - auto* args = new std::vector<Tensor>; - auto* arg_shapes = new std::vector<TensorShape>; - arg_shapes->reserve(ctx->num_inputs()); - args->reserve(ctx->num_inputs()); - + Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) { + // Validates inputs and gets the size of their leading dimension. + *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1; for (size_t i = 0; i < ctx->num_inputs(); ++i) { - args->push_back(ctx->input(i)); - arg_shapes->push_back(ctx->input(i).shape()); - arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension - OP_REQUIRES_ASYNC( - ctx, batch_size == ctx->input(i).dim_size(0), - errors::InvalidArgument( - "All inputs must have the same dimension 0. Input ", i, - " has leading dimension ", ctx->input(i).dim_size(0), - ", while all previous inputs have leading dimension ", batch_size, - "."), - done); + if (ctx->input(i).dims() == 0) { + return errors::InvalidArgument( + "All inputs must have rank at least 1. 
Input ", i, + " has a rank of 0."); + } else if (ctx->input(i).dim_size(0) != *batch_size) { + return errors::InvalidArgument( + "All inputs must have the same dimension 0. Input ", i, + " has leading dimension ", ctx->input(i).dim_size(0), + ", while all previous inputs have leading dimension ", batch_size); + } } + return Status::OK(); + } - // Outputs - auto* output = new OpOutputList; - OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done); + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + ComputeOptions* compute_opts = nullptr; - for (size_t i = 0; i < output_types().size(); ++i) { - Tensor* out = nullptr; - TensorShape output_shape = output_shapes_.at(i); - output_shape.InsertDim(0, batch_size); - OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), done); - } + OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done); - SetRunOptions(ctx, &opts_, false); + Status s = SetupOutputs(ctx, compute_opts); + if (!s.ok()) delete compute_opts; + OP_REQUIRES_OK_ASYNC(ctx, s, done); + + FunctionLibraryRuntime::Options opts; + SetRunOptions(ctx, &opts, false); // Run loop StatusCallback callback = std::bind( - [](OpKernelContext* ctx, std::vector<Tensor>* args, - std::vector<TensorShape>* arg_shapes, OpOutputList* output, + [](OpKernelContext* ctx, ComputeOptions* compute_opts, DoneCallback& done, const Status& status) { - delete args; - delete arg_shapes; - delete output; + delete compute_opts; ctx->SetStatus(status); done(); }, - ctx, args, arg_shapes, output, std::move(done), std::placeholders::_1); + ctx, compute_opts, std::move(done), std::placeholders::_1); auto* refcounted = new ReffedStatusCallback(std::move(callback)); - for (size_t i = 1; i < static_cast<size_t>(batch_size); ++i) { - // Start from i = 1 because refcounted is initialized with refcount = 1 + CancellationManager* parent_mgr = ctx->cancellation_manager(); + + for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) { + // 
We use a different cancellation manager each time the function is run + // to avoid the race condition between a function run error and other + // functions being cancelled as a result. + CancellationManager* c_mgr = new CancellationManager; + CancellationToken token = parent_mgr->get_cancellation_token(); + const bool success = parent_mgr->RegisterCallback( + token, [c_mgr]() { c_mgr->StartCancel(); }); + + opts.cancellation_manager = c_mgr; + if (!success) { + delete c_mgr; + refcounted->UpdateStatus(errors::Cancelled( + "MapDefunOp functions cancelled because parent graph cancelled")); + break; + } + + auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i); + refcounted->Ref(); + ctx->function_library()->Run(opts, func_handle_, call_frame, + [call_frame, refcounted, c_mgr, parent_mgr, + token](const Status& func_status) { + parent_mgr->DeregisterCallback(token); + delete c_mgr; + delete call_frame; + refcounted->UpdateStatus(func_status); + refcounted->Unref(); + }); } - for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) { - auto* call_frame = - new MapFunctionCallFrame(*args, *arg_shapes, output, this, i); - ctx->function_library()->Run( - opts_, func_handle_, call_frame, - [call_frame, refcounted](const Status& func_status) { - delete call_frame; - refcounted->UpdateStatus(func_status); - refcounted->Unref(); - }); - } + + // Unref 1 because refcounted is initialized with refcount = 1 + refcounted->Unref(); } private: FunctionLibraryRuntime::Handle func_handle_; - FunctionLibraryRuntime::Options opts_; - std::vector<TensorShape> output_shapes_; + std::vector<PartialTensorShape> output_shapes_; + + struct ComputeOptions { + // These vary per MapDefunOp::ComputeAsync call, but must persist until + // all calls to the function are complete. This struct also encapsulates + // all the components that need to be passed to each MapFunctionCallFrame. 
+ + const std::vector<Tensor> args; + const std::vector<TensorShape> arg_shapes; + const int64 batch_size; + + // Output of a compute call + std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu); + OpOutputList output GUARDED_BY(mu); + mutex mu; + + // Create a copy of output_shapes because every `Compute` may expect a + // different output shape. + ComputeOptions(std::vector<Tensor> args, + std::vector<TensorShape> arg_shapes, int64 batch_size, + const std::vector<PartialTensorShape>& output_shapes_attr) + : args(std::move(args)), + arg_shapes(std::move(arg_shapes)), + batch_size(batch_size), + output_shapes(output_shapes_attr) {} + }; + + // Get inputs to Compute and check that they are valid. + Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) { + int64 batch_size = + ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1; + + for (size_t i = 0; i < ctx->num_inputs(); ++i) { + if (ctx->input(i).dims() == 0) { + return errors::InvalidArgument( + "All inputs must have rank at least 1. Input ", i, + " has a rank of 0."); + } else if (ctx->input(i).dim_size(0) != batch_size) { + return errors::InvalidArgument( + "All inputs must have the same dimension 0. 
Input ", i, + " has leading dimension ", ctx->input(i).dim_size(0), + ", while all previous inputs have leading dimension ", batch_size); + } + } + + std::vector<Tensor> args; + std::vector<TensorShape> arg_shapes; + args.reserve(ctx->num_inputs()); + arg_shapes.reserve(ctx->num_inputs()); + + for (size_t i = 0; i < ctx->num_inputs(); ++i) { + args.push_back(ctx->input(i)); + arg_shapes.push_back(ctx->input(i).shape()); + arg_shapes.at(i).RemoveDim(0); + } + + *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes), + batch_size, output_shapes_); + return Status::OK(); + } + + Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) { + mutex_lock l(opts->mu); + TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output)); + + for (size_t i = 0; i < output_types().size(); ++i) { + if (output_shapes_.at(i).IsFullyDefined()) { + Tensor* out = nullptr; + TensorShape output_shape; + output_shapes_.at(i).AsTensorShape(&output_shape); + output_shape.InsertDim(0, opts->batch_size); + TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out)); + } + } + return Status::OK(); + } class MapFunctionCallFrame : public CallFrameInterface { public: - MapFunctionCallFrame(const std::vector<Tensor>& args, - const std::vector<TensorShape>& arg_shapes, - OpOutputList* output, OpKernel* kernel, size_t iter) - : args_(args), - arg_shapes_(arg_shapes), - output_(output), - kernel_(kernel), - iter_(iter) {} + MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel, + size_t iter) + : compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {} ~MapFunctionCallFrame() override {} - size_t num_args() const override { return args_.size(); } + size_t num_args() const override { return compute_opts_->args.size(); } + size_t num_retvals() const override { return static_cast<size_t>(kernel_->num_outputs()); } Status GetArg(int index, Tensor* val) const override { - if (index < 0 || index >= args_.size()) { + if (index < 0 || index >= 
compute_opts_->args.size()) { return errors::InvalidArgument( "Mismatch in number of function inputs."); } - bool result = val->CopyFrom(args_.at(index).Slice(iter_, iter_ + 1), - arg_shapes_.at(index)); + bool result = + val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1), + compute_opts_->arg_shapes.at(index)); if (!result) { return errors::Internal("GetArg failed."); } else if (!val->IsAligned()) { @@ -179,18 +264,39 @@ class MapDefunOp : public AsyncOpKernel { "output: ", index); } - return batch_util::CopyElementToSlice(val, (*output_)[index], iter_); + { // Locking scope + mutex_lock l(compute_opts_->mu); + if (!compute_opts_->output_shapes.at(index).IsCompatibleWith( + val.shape())) { + return errors::InvalidArgument( + "Mismatch in function retval shape, ", val.shape(), + ", and expected output shape, ", + compute_opts_->output_shapes.at(index).DebugString(), "."); + } + if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) { + // Given val, we have new information about the output shape at + // this index. Store the shape and allocate the output accordingly. 
+ compute_opts_->output_shapes.at(index) = val.shape(); + + Tensor* out = nullptr; + TensorShape actual_shape = val.shape(); + actual_shape.InsertDim(0, compute_opts_->batch_size); + TF_RETURN_IF_ERROR( + compute_opts_->output.allocate(index, actual_shape, &out)); + } + return batch_util::CopyElementToSlice( + val, (compute_opts_->output)[index], iter_); + } } private: - const std::vector<Tensor>& args_; - const std::vector<TensorShape>& arg_shapes_; - OpOutputList* output_; + ComputeOptions* const compute_opts_; // Not owned const OpKernel* kernel_; const size_t iter_; }; -}; // namespace +}; REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc new file mode 100644 index 0000000000..9aa505f4f1 --- /dev/null +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -0,0 +1,183 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/cpu_info.h" + +namespace tensorflow { +namespace data { +namespace { + +const int kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMicros; + +class ModelDatasetOp : public UnaryDatasetOpKernel { + public: + explicit ModelDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + *output = new Dataset(ctx, input); + } + + private: + class Dataset : public DatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, const DatasetBase* input) + : DatasetBase(DatasetContext(ctx)), input_(input) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Model")})); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + const std::vector<PartialTensorShape>& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { return "ModelDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params), 
+ model_(std::make_shared<model::Model>()) {} + + ~Iterator() override { + // Signal the optimize thread to terminate it. We will then join that + // thread when we delete `this->optimize_thread_`. + mutex_lock l(mu_); + cancelled_ = true; + cond_var_.notify_all(); + } + + Status Initialize(IteratorContext* ctx) override { + IteratorContext ctx_with_model(CreateParams(ctx)); + return dataset()->input_->MakeIterator(&ctx_with_model, prefix(), + &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(EnsureOptimizeThreadStarted(ctx)); + IteratorContext ctx_with_model(CreateParams(ctx)); + return input_impl_->GetNext(&ctx_with_model, out_tensors, + end_of_sequence); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + return Status::OK(); + } + + IteratorContext::Params CreateParams(IteratorContext* ctx) { + IteratorContext::Params params = ctx->params(); + params.model = model_; + return params; + } + + private: + Status EnsureOptimizeThreadStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!optimize_thread_) { + std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); + optimize_thread_.reset(ctx->env()->StartThread( + {}, "optimize_thread", + [this, new_ctx]() { OptimizeThread(new_ctx); })); + } + return Status::OK(); + } + + void OptimizeThread(const std::shared_ptr<IteratorContext>& ctx) { + int64 last_optimization_ms = 0; + int64 optimization_period_ms = 10; + while (true) { + { + mutex_lock l(mu_); + while (!cancelled_ && + last_optimization_ms + optimization_period_ms >= + ctx->env()->NowMicros() / 
EnvTime::kMillisToMicros) { + cond_var_.wait_for( + l, std::chrono::milliseconds( + last_optimization_ms + optimization_period_ms - + ctx->env()->NowMicros() / EnvTime::kMillisToMicros)); + } + if (cancelled_) return; + } + model_->Optimize(port::NumSchedulableCPUs()); + // Exponentially increase the period of running the optimization + // until a threshold is reached. + if (optimization_period_ms < kOptimizationPeriodThresholdMs) { + if (optimization_period_ms << 1 < kOptimizationPeriodThresholdMs) { + optimization_period_ms <<= 1; + } else { + optimization_period_ms = kOptimizationPeriodThresholdMs; + } + } + last_optimization_ms = + ctx->env()->NowMicros() / EnvTime::kMillisToMicros; + } + } + + mutex mu_; + condition_variable cond_var_; + std::shared_ptr<model::Model> model_; + std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_); + bool cancelled_ GUARDED_BY(mu_) = false; + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + }; + + const DatasetBase* input_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU), + ModelDatasetOp); +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc new file mode 100644 index 0000000000..5f143967d9 --- /dev/null +++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc @@ -0,0 +1,633 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <deque> + +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace data { +namespace { + +struct HostBufferElement { + Status status; + bool end_of_sequence; + std::vector<Tensor> value; +}; + +using MultiDeviceIteratorCallback = + std::function<void(const HostBufferElement&)>; + +class MultiDeviceIterator : public ResourceBase { + public: + MultiDeviceIterator(const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes, + const std::vector<string>& devices, + std::unique_ptr<FunctionLibraryDefinition> flib_def, + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, + FunctionLibraryRuntime* lib) + : output_types_(output_types), + output_shapes_(output_shapes), + devices_(devices), + flib_def_(std::move(flib_def)), + pflr_(std::move(pflr)), + lib_(lib) { + DCHECK(lib_ != nullptr); + } + + string DebugString() override { + return strings::StrCat("MultiDeviceIterator for ", devices_.size(), + " devices"); + } + + Status Init(std::unique_ptr<IteratorBase> iterator, int64 max_buffer_size, + int64* incarnation_id) { + if (iterator) { + TF_RETURN_IF_ERROR( + VerifyTypesMatch(output_types_, iterator->output_dtypes())); + TF_RETURN_IF_ERROR( + VerifyShapesCompatible(output_shapes_, iterator->output_shapes())); + } + + mutex_lock l(mu_); + if 
(multi_device_buffer_) { + multi_device_buffer_->Reset(); + } + + ++incarnation_id_; + *incarnation_id = incarnation_id_; + + multi_device_buffer_.reset( + new MultiDeviceBuffer(devices_.size(), max_buffer_size, incarnation_id_, + std::move(iterator))); + return Status::OK(); + } + + void GetNextFromShard(IteratorContext* ctx, int shard_num, + int64 incarnation_id, + MultiDeviceIteratorCallback callback) { + if (lib_ != nullptr) { + ctx->set_lib(lib_); + } + tf_shared_lock l(mu_); + multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id, + std::move(callback)); + } + + const DataTypeVector& output_types() const { return output_types_; } + + const std::vector<PartialTensorShape>& output_shapes() const { + return output_shapes_; + } + + std::shared_ptr<const FunctionLibraryDefinition> function_library() { + tf_shared_lock l(mu_); + return lib_def_; + } + + FunctionLibraryRuntime* const lib() { + tf_shared_lock l(mu_); + return lib_; + } + + private: + // A private class that uses a background thread to keep a per device buffer + // full. + class MultiDeviceBuffer { + public: + MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id, + std::unique_ptr<IteratorBase> host_iterator) + : buffer_(size), + size_(size), + max_buffer_size_(max_buffer_size), + incarnation_id_(incarnation_id), + host_iterator_(std::move(host_iterator)) {} + + ~MultiDeviceBuffer() { + { + mutex_lock l(mu_); + if (!background_thread_started_) return; + } + Reset(); + } + + void Reset() LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + if (background_thread_finished_) { + return; + } + + cancelled_ = true; + // Wake up the background thread. + for (int i = 0; i < size_; ++i) { + buffer_[i].cond_var.notify_all(); + } + + // Make sure background thread has finished first. 
+ while (!background_thread_finished_) { + shutdown_cond_var_.wait(l); + } + } + RunPendingCallbacks(); + } + + void GetNextFromShard(IteratorContext* ctx, int shard_num, + int64 incarnation_id, + MultiDeviceIteratorCallback callback) { + HostBufferElement elem; + if (incarnation_id_ != incarnation_id) { + elem.status = errors::InvalidArgument("Invalid incarnation id"); + callback(elem); + return; + } + + bool produced_output = false; + { + mutex_lock l(mu_); + if (cancelled_) { + elem.status = errors::Cancelled("Cancelled Multidevice iterator"); + callback(elem); + return; + } + + EnsureBackgroundThreadStarted(ctx); + + if (!buffer_[shard_num].data.empty()) { + produced_output = true; + std::swap(elem, buffer_[shard_num].data.front()); + buffer_[shard_num].data.pop_front(); + // Wake up background thread if it is blocked on this element. + if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) { + buffer_[shard_num].cond_var.notify_all(); + } + } else { + if (background_thread_finished_) { + produced_output = true; + elem.end_of_sequence = true; + } else { + buffer_[shard_num].callbacks.push_back(std::move(callback)); + callback = nullptr; + } + } + } + + if (produced_output) { + callback(elem); + } + } + + private: + void EnsureBackgroundThreadStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!background_thread_) { + background_thread_.reset(ctx->env()->StartThread( + {}, "multi_device_iterator_background_thread", + std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread, + this, new IteratorContext(*ctx)))); + } + } + + void RunPendingCallbacks() LOCKS_EXCLUDED(mu_) { + // Run all remaining callbacks. 
+ std::vector<MultiDeviceIteratorCallback> cancellation_callbacks; + std::vector<HostBufferElement> cancellation_elements; + { + mutex_lock l(mu_); + + for (int i = 0; i < size_; ++i) { + while (!buffer_[i].callbacks.empty()) { + if (buffer_[i].data.empty()) { + HostBufferElement elem; + elem.status = + errors::Cancelled("Cancelled and buffer not filled."); + cancellation_elements.push_back(std::move(elem)); + } else { + cancellation_elements.push_back( + std::move(buffer_[i].data.front())); + buffer_[i].data.pop_front(); + } + cancellation_callbacks.push_back( + std::move(buffer_[i].callbacks.front())); + buffer_[i].callbacks.pop_front(); + } + } + } + for (int i = 0; i < cancellation_callbacks.size(); ++i) { + cancellation_callbacks[i](cancellation_elements[i]); + } + } + + void BackgroundThread(IteratorContext* ctx) { + { + mutex_lock l(mu_); + background_thread_started_ = true; + } + std::unique_ptr<IteratorContext> cleanup(ctx); + int shard_to_fetch = 0; + while (true) { + HostBufferElement elem; + MultiDeviceIteratorCallback callback = nullptr; + bool end_of_iterator = false; + + { + mutex_lock l(mu_); + while (!cancelled_ && + buffer_[shard_to_fetch].data.size() >= max_buffer_size_) { + buffer_[shard_to_fetch].cond_var.wait(l); + } + + if (cancelled_) { + background_thread_finished_ = true; + shutdown_cond_var_.notify_all(); + return; + } + } + + elem.status = + host_iterator_->GetNext(ctx, &elem.value, &elem.end_of_sequence); + + if (elem.status.ok() && elem.end_of_sequence) { + end_of_iterator = true; + } + + { + mutex_lock l(mu_); + // Try to find a callback, else just push stuff into buffer. 
+ if (!buffer_[shard_to_fetch].callbacks.empty()) { + callback = buffer_[shard_to_fetch].callbacks.front(); + buffer_[shard_to_fetch].callbacks.pop_front(); + } else { + buffer_[shard_to_fetch].data.push_back(std::move(elem)); + elem = HostBufferElement(); + } + } + + if (callback) { + (*ctx->runner())(std::bind(std::move(callback), std::move(elem))); + } + + // Finish off the thread if we reach the end of the iterator. Runs + // pending callbacks. + if (end_of_iterator) { + { + mutex_lock l(mu_); + background_thread_finished_ = true; + shutdown_cond_var_.notify_all(); + } + RunPendingCallbacks(); + return; + } + shard_to_fetch = (shard_to_fetch + 1) % size_; + } + } + + struct HostBuffer { + condition_variable cond_var; + std::deque<HostBufferElement> data; + std::deque<MultiDeviceIteratorCallback> callbacks; + }; + + mutex mu_; + std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_); + bool background_thread_finished_ GUARDED_BY(mu_) = false; + bool background_thread_started_ GUARDED_BY(mu_) = false; + bool cancelled_ GUARDED_BY(mu_) = false; + condition_variable shutdown_cond_var_ GUARDED_BY(mu_); + + std::vector<HostBuffer> buffer_; + + const size_t size_; + const int64 max_buffer_size_; + const int64 incarnation_id_; + const std::unique_ptr<IteratorBase> host_iterator_; + }; + + mutex mu_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; + const std::vector<string> devices_; + const std::unique_ptr<FunctionLibraryDefinition> flib_def_; + const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; + FunctionLibraryRuntime* const lib_ = nullptr; // not owned. + std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_); + + int64 incarnation_id_ GUARDED_BY(mu_) = 0; + std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_); +}; + +// Just creates a MultiDeviceIterator and returns it. 
+class MultiDeviceIteratorHandleOp : public OpKernel { + public: + explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_)); + } + + // The resource is deleted from the resource manager only when it is private + // to kernel. + ~MultiDeviceIteratorHandleOp() override { + if (resource_ != nullptr) { + resource_->Unref(); + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->template Delete<MultiDeviceIterator>(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + } + } + } + } + + void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + if (resource_ == nullptr) { + FunctionLibraryRuntime* lib; + std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr); + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr); + OP_REQUIRES_OK(context, context->function_library()->Clone( + &flib_def, &pflr, &lib)); + ResourceMgr* mgr = context->resource_manager(); + OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); + + MultiDeviceIterator* resource; + OP_REQUIRES_OK( + context, + mgr->LookupOrCreate<MultiDeviceIterator>( + cinfo_.container(), cinfo_.name(), &resource, + [this, lib, &flib_def, &pflr](MultiDeviceIterator** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new MultiDeviceIterator( + output_types_, output_shapes_, devices_, + std::move(flib_def), std::move(pflr), lib); + return Status::OK(); + })); + + Status s = VerifyResource(resource); + if (TF_PREDICT_FALSE(!s.ok())) { + resource->Unref(); + context->SetStatus(s); + return; + } + + 
resource_ = resource; + } + } + OP_REQUIRES_OK(context, MakeResourceHandleToOutput( + context, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex<MultiDeviceIterator>())); + } + + private: + // During the first Compute(), resource is either created or looked up using + // shared_name. In the latter case, the resource found should be verified if + // it is compatible with this op's configuration. The verification may fail in + // cases such as two graphs asking queues of the same shared name to have + // inconsistent capacities. + Status VerifyResource(MultiDeviceIterator* resource) { + TF_RETURN_IF_ERROR( + VerifyTypesMatch(output_types_, resource->output_types())); + TF_RETURN_IF_ERROR( + VerifyShapesCompatible(output_shapes_, resource->output_shapes())); + return Status::OK(); + } + + mutex mu_; + ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. + MultiDeviceIterator* resource_ GUARDED_BY(mu_) = nullptr; + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; + const int graph_def_version_; + string name_; + string container_; + std::vector<string> devices_; +}; + +REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU), + MultiDeviceIteratorHandleOp); + +// Calls init on the MultiDeviceIterator. 
+class MultiDeviceIteratorInitOp : public OpKernel { + public: + explicit MultiDeviceIteratorInitOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor* tensor_max_buffer_size; + OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size)); + int64 max_buffer_size = tensor_max_buffer_size->scalar<int64>()(); + + DatasetBase* dataset; + OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset)); + MultiDeviceIterator* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 1), &resource)); + core::ScopedUnref unref(resource); + + std::unique_ptr<IteratorBase> iterator; + IteratorContext iter_ctx(ctx); + iter_ctx.set_lib(resource->lib()); + OP_REQUIRES_OK( + ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator)); + int64 incarnation_id; + OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size, + &incarnation_id)); + Tensor tensor_incarnation_id(DT_INT64, TensorShape({})); + tensor_incarnation_id.scalar<int64>()() = incarnation_id; + OP_REQUIRES_OK(ctx, + ctx->set_output("incarnation_id", tensor_incarnation_id)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("MultiDeviceIteratorInit").Device(DEVICE_CPU), + MultiDeviceIteratorInitOp); + +// Calls GetNextFromShard(shard) and returns a vector of Tensors as output. +// TODO(rohanj): Implement using BackgroundWorker that Derek built? 
+class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel { + public: + explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + thread_pool_(new thread::ThreadPool( + ctx->env(), ThreadOptions(), + strings::StrCat("multi_device_iterator_get_next_thread_", + SanitizeThreadSuffix(name())), + 1 /* num_threads */, false /* low_latency_hint */)) {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + const Tensor* tensor_shard_num; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done); + int32 shard_num = tensor_shard_num->scalar<int32>()(); + + const Tensor* tensor_incarnation_id; + OP_REQUIRES_OK_ASYNC( + ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done); + int64 incarnation_id = tensor_incarnation_id->scalar<int64>()(); + + MultiDeviceIterator* iterator; + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done); + thread_pool_->Schedule(std::bind( + [ctx, iterator, shard_num, incarnation_id](DoneCallback done) { + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = *(ctx->runner()); + params.function_library = iterator->function_library(); + DeviceBase* device = ctx->function_library()->device(); + params.allocator_getter = [device](AllocatorAttributes attrs) { + return device->GetAllocator(attrs); + }; + IteratorContext iter_ctx(std::move(params)); + + MultiDeviceIteratorCallback callback = std::bind( + [ctx](const HostBufferElement& elem, DoneCallback done) { + // iterator->Unref(); + Status s = elem.status; + if (!s.ok()) { + ctx->SetStatus(s); + } else if (elem.end_of_sequence) { + ctx->SetStatus(errors::OutOfRange("End of sequence")); + } else { + for (int i = 0; i < elem.value.size(); ++i) { + ctx->set_output(i, elem.value[i]); + } + } + done(); + }, + std::placeholders::_1, std::move(done)); + + iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id, + callback); + 
iterator->Unref(); + }, + std::move(done))); + } + + private: + std::unique_ptr<thread::ThreadPool> thread_pool_; +}; + +REGISTER_KERNEL_BUILDER( + Name("MultiDeviceIteratorGetNextFromShard").Device(DEVICE_CPU), + MultiDeviceIteratorGetNextFromShardOp); + +class MultiDeviceIteratorToStringHandleOp : public OpKernel { + public: + explicit MultiDeviceIteratorToStringHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& resource_handle_t = ctx->input(0); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), + errors::InvalidArgument("resource_handle must be a scalar")); + + // Validate that the handle corresponds to a real resource, and + // that it is an MultiDeviceIterator. + MultiDeviceIterator* resource; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &resource)); + resource->Unref(); + + Tensor* string_handle_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &string_handle_t)); + string_handle_t->scalar<string>()() = + resource_handle_t.scalar<ResourceHandle>()().SerializeAsString(); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("MultiDeviceIteratorToStringHandle").Device(DEVICE_CPU), + MultiDeviceIteratorToStringHandleOp); + +class MultiDeviceIteratorFromStringHandleOp : public OpKernel { + public: + explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES( + ctx, + output_types_.empty() || output_shapes_.empty() || + output_types_.size() == output_shapes_.size(), + errors::InvalidArgument("If both 'output_types' and 'output_shapes' " + "are set, they must have the same length.")); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor& string_handle_t = ctx->input(0); + OP_REQUIRES(ctx, 
TensorShapeUtils::IsScalar(string_handle_t.shape()), + errors::InvalidArgument("string_handle must be a scalar")); + + ResourceHandle resource_handle; + OP_REQUIRES( + ctx, + resource_handle.ParseFromString(string_handle_t.scalar<string>()()), + errors::InvalidArgument( + "Could not parse string_handle as a valid ResourceHandle")); + + OP_REQUIRES( + ctx, resource_handle.device() == ctx->device()->attributes().name(), + errors::InvalidArgument("Attempted create an iterator on device \"", + ctx->device()->attributes().name(), + "\" from handle defined on device \"", + resource_handle.device(), "\"")); + + // Validate that the handle corresponds to a real resource, and + // that it is an MultiDeviceIterator. + MultiDeviceIterator* resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource)); + core::ScopedUnref unref_iterator(resource); + if (!output_types_.empty()) { + OP_REQUIRES_OK(ctx, + VerifyTypesMatch(output_types_, resource->output_types())); + } + if (!output_shapes_.empty()) { + OP_REQUIRES_OK(ctx, VerifyShapesCompatible(output_shapes_, + resource->output_shapes())); + } + + Tensor* resource_handle_t; + OP_REQUIRES_OK( + ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t)); + resource_handle_t->scalar<ResourceHandle>()() = resource_handle; + } + + private: + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; +}; + +REGISTER_KERNEL_BUILDER( + Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU), + MultiDeviceIteratorFromStringHandleOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 6263dc3cf8..d5b725eac9 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -270,4 +271,5 @@ REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU), OptimizeDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc index cfac45dbc7..2ab5c83082 100644 --- a/tensorflow/core/kernels/data/optional_ops.cc +++ b/tensorflow/core/kernels/data/optional_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_op_registry.h" namespace tensorflow { +namespace data { namespace { const char kOptionalVariantTypeName[] = "tensorflow::data::Optional"; @@ -107,11 +108,8 @@ class OptionalFromValueOp : public OpKernel { void Compute(OpKernelContext* ctx) override { OpInputList components_input; OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input)); - std::vector<Tensor> components; - components.reserve(components_input.size()); - for (const Tensor& component_t : components_input) { - components.push_back(component_t); - } + std::vector<Tensor> components(components_input.begin(), + components_input.end()); OP_REQUIRES_OK( ctx, WriteOptionalWithValueToOutput(ctx, 0, std::move(components))); } @@ -215,6 +213,14 @@ static Status OptionalDeviceCopy( std::vector<Tensor> to_values; to_values.reserve(from_values.size()); for (const Tensor& t : from_values) { + if (t.dtype() == DT_VARIANT) { + // TODO(b/116349787): Implement support for nested variants. 
+ return errors::Unimplemented( + "Support for copying nested variants to device has not yet been " + "implemented."); + } + } + for (const Tensor& t : from_values) { if (DMAHelper::CanUseDMA(&t)) { Tensor tmp(t.dtype()); TF_RETURN_IF_ERROR(copy(t, &tmp)); @@ -230,10 +236,9 @@ static Status OptionalDeviceCopy( return Status::OK(); } -#define REGISTER_OPTIONAL_COPY(DIRECTION) \ - INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ - OptionalVariant, DIRECTION, kOptionalVariantTypeName, \ - OptionalDeviceCopy) +#define REGISTER_OPTIONAL_COPY(DIRECTION) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ + OptionalVariant, DIRECTION, OptionalDeviceCopy) REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); @@ -267,4 +272,5 @@ Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) { return Status::OK(); } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optional_ops.h b/tensorflow/core/kernels/data/optional_ops.h index 6f25567678..2cbf2933f5 100644 --- a/tensorflow/core/kernels/data/optional_ops.h +++ b/tensorflow/core/kernels/data/optional_ops.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_tensor_data.h" namespace tensorflow { +namespace data { // Stores a DT_VARIANT value representing an Optional with the given value // in the `output_index`^th output of the given kernel execution context. @@ -31,6 +32,7 @@ Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index, // in the `output_index`^th output of the given kernel execution context. 
Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index); +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_ diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index be45eac46e..7b01c3b4e0 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -207,6 +207,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator<Dataset>(params) {} Status Initialize(IteratorContext* ctx) override { + AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); } @@ -382,5 +383,5 @@ REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU), PaddedBatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index f6b3fd97e3..2e6e0465f7 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <atomic> #include <deque> +#include <utility> #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -21,11 +23,12 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -34,8 +37,7 @@ namespace { class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { public: explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); @@ -43,14 +45,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - int64 cycle_length = 0; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "cycle_length", &cycle_length)); @@ -82,8 +76,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<CapturedFunction> captured_func; OP_REQUIRES_OK( - ctx, CapturedFunction::Create( - interleave_func_, std::move(other_arguments), &captured_func)); + ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, interleave_func_, std::move(captured_func), @@ -125,6 +119,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { const DataTypeVector& output_dtypes() const override { return output_types_; 
} + const std::vector<PartialTensorShape>& output_shapes() const override { return output_shapes_; } @@ -250,6 +245,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { + AddConstantParameter(ctx, "parallelism", dataset()->cycle_length_); TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate(ctx); @@ -349,11 +345,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { if (must_wait_for_input) { // Wait for elements to become available. + RecordStop(ctx); if (dataset()->sloppy_) { sloppy_cond_var_.wait(l); } else { workers_[interleave_indices_[next_index_]].cond_var.wait(l); } + RecordStart(ctx); } } return errors::Cancelled( @@ -482,10 +480,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { if (reader->Contains(full_name("worker_threads_running"))) { worker_threads_.reserve(dataset()->num_threads()); for (size_t i = 0; i < dataset()->num_threads(); ++i) { + std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); worker_threads_.emplace_back(ctx->env()->StartThread( {}, "worker_thread", - std::bind(&Iterator::WorkerThread, this, - new IteratorContext(*ctx), i))); + [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); } } return Status::OK(); @@ -581,10 +579,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } workers_[i].SetInputs(s, std::move(args)); + std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); worker_threads_.emplace_back(ctx->env()->StartThread( {}, "worker_thread", - std::bind(&Iterator::WorkerThread, this, - new IteratorContext(*ctx), i))); + [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); if (i < dataset()->cycle_length_) { interleave_indices_.push_back(i); } else { @@ -599,7 +597,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } // Produces elements into the worker's 
output buffers. - void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) { + void WorkerThread(const std::shared_ptr<IteratorContext>& ctx, + const int64 thread_index) { // Notes on checkpointing thread local state, i.e., `WorkerThreadState`: // // 1. Any local state that may need to be checkpointed should be kept @@ -620,10 +619,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // std::function arguments are copy-constructable, so we pass raw // pointers, and then immediately wrap them to ensure correct ownership. - std::unique_ptr<IteratorContext> ctx(ctx_ptr); - auto cleanup = gtl::MakeCleanup([this, thread_index] { + RecordStart(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] { mutex_lock l(mu_); workers_[thread_index].cond_var.notify_all(); + RecordStop(ctx.get()); }); bool make_new_iterator; { @@ -649,9 +649,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // 1. Build a new iterator or use the existing one. if (make_new_iterator) { // 1a. Get new input tensors or use the exiting ones. 
- bool read_new_input; - { tf_shared_lock l(ckpt_mu_); // worker_thread_states_[thread_index].input will be non-empty @@ -663,7 +661,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { if (read_new_input) { mutex_lock l(mu_); while (!cancelled_ && !workers_[thread_index].is_producing) { + RecordStop(ctx.get()); workers_[thread_index].cond_var.wait(l); + RecordStart(ctx.get()); } if (cancelled_) return; // Copy the input tensors so that we do not need to block on `mu_` @@ -684,7 +684,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { { tf_shared_lock l(ckpt_mu_); worker_thread_states_[thread_index].iterator_creation_status = - dataset::MakeIteratorFromInputElement( + MakeIteratorFromInputElement( ctx.get(), worker_thread_states_[thread_index].input, thread_index, dataset()->captured_func_.get(), prefix(), &worker_thread_states_[thread_index].iterator); @@ -713,7 +713,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // Wait for space in the prefetch queue. while (!cancelled_ && workers_[thread_index].outputs.size() == dataset()->buffer_output_elements_) { + RecordStop(ctx.get()); workers_[thread_index].cond_var.wait(l); + RecordStart(ctx.get()); } if (cancelled_) return; tf_shared_lock ckpt_l(ckpt_mu_); @@ -762,7 +764,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // Wait for space in the prefetch queue. 
while (!cancelled_ && workers_[thread_index].outputs.size() == dataset()->buffer_output_elements_) { + RecordStop(ctx.get()); workers_[thread_index].cond_var.wait(l); + RecordStart(ctx.get()); } if (cancelled_) return; @@ -914,7 +918,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { worker_thread_states_[index].iterator.reset(); } else { std::unique_ptr<IteratorBase> iterator; - Status s = dataset::MakeIteratorFromInputElement( + Status s = MakeIteratorFromInputElement( ctx, worker_thread_states_[index].input, index, dataset()->captured_func_.get(), prefix(), &iterator); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator)); @@ -1058,7 +1062,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { const std::vector<PartialTensorShape> output_shapes_; }; - const int graph_def_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; NameAttrList interleave_func_; @@ -1067,6 +1070,600 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU), ParallelInterleaveDatasetOp); -} // namespace +// The motivation for creating an alternative implementation of parallel +// interleave is to decouple the degree of parallelism from the cycle length. +// This makes it possible to change the degree of parallelism (e.g. through +// auto-tuning) without changing the cycle length (which would change the order +// in which elements are produced). +// +// Furthermore, this class favors modularity over extended functionality. In +// particular, it refrains from implementing configurable buffering of output +// elements and prefetching of input iterators, relying on other parts of +// tf.data to provide this functionality if necessary. +// +// The above design choices were made with automated optimizations in mind, +// isolating the degree of parallelism as the single tunable knob of this +// implementation. 
+class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { + public: + explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + int64 cycle_length = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "cycle_length", &cycle_length)); + OP_REQUIRES(ctx, cycle_length > 0, + errors::InvalidArgument("`cycle_length` must be > 0")); + + int64 block_length = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "block_length", &block_length)); + OP_REQUIRES(ctx, block_length > 0, + errors::InvalidArgument("`block_length` must be > 0")); + + int64 num_parallel_calls; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", + &num_parallel_calls)); + OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune, + errors::InvalidArgument( + "num_parallel_calls must be greater than zero.")); + OP_REQUIRES( + ctx, num_parallel_calls <= cycle_length, + errors::InvalidArgument( + "num_parallel_calls must less than or equal to cycle_length.")); + + std::unique_ptr<CapturedFunction> captured_func; + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments", + &captured_func)); + + *output = new Dataset(ctx, input, interleave_func_, + std::move(captured_func), cycle_length, block_length, + num_parallel_calls, output_types_, output_shapes_); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, + std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length, + int64 block_length, int64 num_parallel_calls, + const DataTypeVector& output_types, + const 
std::vector<PartialTensorShape>& output_shapes) + : DatasetBase(DatasetContext(ctx)), + input_(input), + interleave_func_(func), + captured_func_(std::move(captured_func)), + cycle_length_(cycle_length), + block_length_(block_length), + num_parallel_calls_(num_parallel_calls), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>(new Iterator( + {this, strings::StrCat(prefix, "::ParallelInterleaveV2")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "ParallelInterleaveDatasetV2Op::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name())); + Node* input_node; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); + Node* cycle_length_node; + TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node)); + Node* block_length_node; + TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node)); + Node* num_parallel_calls_node; + TF_RETURN_IF_ERROR( + b->AddScalar(num_parallel_calls_, &num_parallel_calls_node)); + DataTypeVector other_arguments_types; + other_arguments_types.reserve(captured_func_->captured_inputs().size()); + std::vector<Node*> other_arguments; + other_arguments.reserve(captured_func_->captured_inputs().size()); + for (const Tensor& t : captured_func_->captured_inputs()) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + other_arguments.emplace_back(node); + other_arguments_types.emplace_back(t.dtype()); + } + AttrValue f; + 
b->BuildAttrValue(interleave_func_, &f); + AttrValue other_arguments_types_attr; + b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, + {{0, input_node}, + {2, cycle_length_node}, + {3, block_length_node}, + {4, num_parallel_calls_node}}, + {{1, other_arguments}}, + {{"f", f}, {"Targuments", other_arguments_types_attr}}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params), + num_parallel_calls_(params.dataset->num_parallel_calls_), + args_list_(params.dataset->cycle_length_), + current_elements_(params.dataset->cycle_length_), + element_in_use_(params.dataset->cycle_length_, false), + thread_pool_(new thread::ThreadPool( + Env::Default(), ThreadOptions(), "parallel_interleave", + dataset()->cycle_length_ /* num_threads */, + false /* low_latency_hint */)) {} + + ~Iterator() override { + mutex_lock l(mu_); + // Cancel the runner thread. + cancelled_ = true; + cond_var_.notify_all(); + // Wait for all in-flight calls to complete. 
+ while (num_calls_ > 0) { + cond_var_.wait(l); + } + } + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + if (num_parallel_calls_ == kAutoTune) { + num_parallel_calls_ = 1; + AddTunableParameter(ctx, "parallelism", + &num_parallel_calls_ /* value */, 1 /* min */, + dataset()->cycle_length_ /* max */, &cond_var_); + } else { + AddConstantParameter(ctx, "parallelism", num_parallel_calls_); + } + AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + std::shared_ptr<InvocationResult> result; + do { + { + mutex_lock l(mu_); + EnsureRunnerThreadStarted(ctx); + while (invocation_results_.empty() && + (!end_of_input_ || num_open_ > 0)) { + RecordStop(ctx); + cond_var_.wait(l); + RecordStart(ctx); + } + if (!invocation_results_.empty()) { + std::swap(result, invocation_results_.front()); + invocation_results_.pop_front(); + } else { + *end_of_sequence = true; + return Status::OK(); + } + cond_var_.notify_all(); + } + RecordStop(ctx); + result->notification.WaitForNotification(); + RecordStart(ctx); + } while (result->skip); + + if (result->status.ok()) { + *out_tensors = std::move(result->return_values); + } + *end_of_sequence = false; + return result->status; + } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + // Wait for all in-flight calls to complete. 
+ while (num_calls_ > 0) { + cond_var_.wait(l); + } + CHECK_EQ(num_calls_, 0); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name("invocation_results.size"), invocation_results_.size())); + for (size_t i = 0; i < invocation_results_.size(); i++) { + std::shared_ptr<InvocationResult> result = invocation_results_[i]; + TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("invocation_results[", i, "].size")), + result->return_values.size())); + for (size_t j = 0; j < result->return_values.size(); j++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name( + strings::StrCat("invocation_results[", i, "][", j, "]")), + result->return_values[j])); + } + if (result->skip) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("invocation_results[", i, "].skip")), + "")); + } + } + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("cycle_index"), cycle_index_)); + if (end_of_input_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("end_of_input"), "")); + } + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("num_open"), num_open_)); + TF_RETURN_IF_ERROR(WriteCurrentElements(writer)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + int64 invocation_results_size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name("invocation_results.size"), &invocation_results_size)); + for (size_t i = 0; i < invocation_results_size; i++) { + std::shared_ptr<InvocationResult> result(new InvocationResult()); + invocation_results_.push_back(result); + TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status)); + size_t num_return_values; + { + int64 size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("invocation_results[", i, "].size")), + 
&size)); + num_return_values = static_cast<size_t>(size); + if (num_return_values != size) { + return errors::InvalidArgument(strings::StrCat( + full_name( + strings::StrCat("invocation_results[", i, "].size")), + ": ", size, " is not a valid value of type size_t.")); + } + } + result->return_values.reserve(num_return_values); + for (size_t j = 0; j < num_return_values; j++) { + result->return_values.emplace_back(); + TF_RETURN_IF_ERROR( + reader->ReadTensor(full_name(strings::StrCat( + "invocation_results[", i, "][", j, "]")), + &result->return_values.back())); + } + result->skip = reader->Contains( + full_name(strings::StrCat("invocation_results[", i, "].skip"))); + result->notification.Notify(); + } + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("cycle_index"), &cycle_index_)); + if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("num_open"), &num_open_)); + TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader)); + return Status::OK(); + } + + private: + struct InvocationResult { + Notification notification; // used for coordination with the consumer + Status status; // the invocation status + std::vector<Tensor> return_values; // the invocation result values + bool skip; // if set the result should be skipped + }; + + void EnsureRunnerThreadStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!runner_thread_) { + std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); + runner_thread_.reset(ctx->env()->StartThread( + {}, "runner_thread", + [this, new_ctx]() { RunnerThread(new_ctx); })); + } + } + + // Fetches up to `results.size()` outputs from the cycle element at + // position `cycle_index`. + // + // If end of input is encountered, the `skip` field of the invocation + // result is used to identify results that should be skipped. 
+ void FetchOutputs( + const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index, + const std::vector<std::shared_ptr<InvocationResult>>& results) + LOCKS_EXCLUDED(mu_) { + RecordStart(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); + bool end_of_input = false; + for (auto& result : results) { + if (!end_of_input) { + result->status = current_elements_[cycle_index]->GetNext( + ctx.get(), &result->return_values, &end_of_input); + } + if (end_of_input) { + result->skip = true; + } + result->notification.Notify(); + if (!result->status.ok()) { + break; + } + } + + // Release the ownership of the cycle element iterator, closing the + // iterator if end of input was encountered. + if (end_of_input) { + current_elements_[cycle_index].reset(); + } + mutex_lock l(mu_); + element_in_use_[cycle_index] = false; + num_calls_--; + if (end_of_input) { + args_list_[cycle_index].clear(); + num_open_--; + } + cond_var_.notify_all(); + } + + // Method responsible for 1) creating iterators out of input elements, 2) + // determining the order in which elements are fetched from the iterators, + // and 3) scheduling the fetching of the elements to a threadpool. + // + // This method runs in the `runner_thread` background thread. + void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { + RecordStart(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); + auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool { + return element_in_use_[cycle_index_] || + num_calls_ >= num_parallel_calls_ || + invocation_results_.size() >= + dataset()->cycle_length_ * dataset()->block_length_; + }; + while (true) { + mutex_lock l(mu_); + // Wait until this thread is cancelled, the end of input has been + // reached, or the cycle element at the `cycle_index_` position is + // not in use and there is space in the `invocation_results_` queue. 
+ while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) { + RecordStop(ctx.get()); + cond_var_.wait(l); + RecordStart(ctx.get()); + } + + if (cancelled_ || (end_of_input_ && num_open_ == 0)) { + return; + } + + while ((!end_of_input_ || num_open_ > 0) && !busy()) { + if (!current_elements_[cycle_index_]) { + // Try to create a new iterator from the next input element. + Status status = input_impl_->GetNext( + ctx.get(), &args_list_[cycle_index_], &end_of_input_); + if (!status.ok()) { + invocation_results_.emplace_back(new InvocationResult()); + std::shared_ptr<InvocationResult>& result = + invocation_results_.back(); + result->status.Update(status); + result->notification.Notify(); + break; + } + if (!end_of_input_) { + Status status = MakeIteratorFromInputElement( + ctx.get(), args_list_[cycle_index_], cycle_index_, + dataset()->captured_func_.get(), prefix(), + ¤t_elements_[cycle_index_]); + if (!status.ok()) { + invocation_results_.emplace_back(new InvocationResult()); + std::shared_ptr<InvocationResult>& result = + invocation_results_.back(); + result->status.Update(status); + result->notification.Notify(); + break; + } + ++num_open_; + } + } + if (current_elements_[cycle_index_]) { + // Pre-allocate invocation results for outputs to be fetched + // and then fetch the outputs asynchronously. 
+ std::vector<std::shared_ptr<InvocationResult>> results; + results.reserve(dataset()->block_length_); + for (int i = 0; i < dataset()->block_length_; ++i) { + invocation_results_.emplace_back(new InvocationResult()); + results.push_back(invocation_results_.back()); + } + num_calls_++; + element_in_use_[cycle_index_] = true; + thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this, + ctx, cycle_index_, + std::move(results))); + } + cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_; + } + cond_var_.notify_all(); + } + } + + Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, + const Status& status) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + CodeKey(index), static_cast<int64>(status.code()))); + if (!status.ok()) { + TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index), + status.error_message())); + } + return Status::OK(); + } + + Status ReadStatusLocked(IteratorStateReader* reader, size_t index, + Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 code_int; + TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); + error::Code code = static_cast<error::Code>(code_int); + + if (code != error::Code::OK) { + string error_message; + TF_RETURN_IF_ERROR( + reader->ReadScalar(ErrorMessageKey(index), &error_message)); + *status = Status(code, error_message); + } else { + *status = Status::OK(); + } + return Status::OK(); + } + + string CodeKey(size_t index) { + return full_name( + strings::StrCat("invocation_results[", index, "].code")); + } + + string ErrorMessageKey(size_t index) { + return full_name( + strings::StrCat("invocation_results[", index, "].error_message")); + } + + Status WriteCurrentElements(IteratorStateWriter* writer) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + for (int idx = 0; idx < current_elements_.size(); idx++) { + if (current_elements_[idx]) { + TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx])); + TF_RETURN_IF_ERROR(writer->WriteScalar( + 
full_name(strings::StrCat("args_size[", idx, "]")), + args_list_[idx].size())); + for (int i = 0; i < args_list_[idx].size(); i++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("args_list_[", idx, "][", i, "]")), + args_list_[idx][i])); + } + } + } + return Status::OK(); + } + + Status ReadCurrentElements(IteratorContext* ctx, + IteratorStateReader* reader) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + for (int idx = 0; idx < current_elements_.size(); idx++) { + if (reader->Contains( + full_name(strings::StrCat("args_size[", idx, "]")))) { + int64 args_size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("args_size[", idx, "]")), + &args_size)); + args_list_[idx].resize(args_size); + for (int i = 0; i < args_size; i++) { + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat("args_list_[", idx, "][", i, "]")), + &args_list_[idx][i])); + } + TF_RETURN_IF_ERROR(MakeIteratorFromInputElement( + ctx, args_list_[idx], idx, dataset()->captured_func_.get(), + prefix(), ¤t_elements_[idx])); + TF_RETURN_IF_ERROR( + RestoreInput(ctx, reader, current_elements_[idx])); + } else { + current_elements_[idx].reset(); + } + } + return Status::OK(); + } + + // Used for coordination between the main thread, the runner thread, and + // the worker threads. + mutex mu_; + + // Used for coordination between the main thread, the runner thread, and + // the worker threads. In particular, the runner thread should only + // schedule new calls when the number of in-flight calls is less than the + // user specified level of parallelism, there are slots available in the + // `invocation_results_` buffer, the current cycle element is not in use, + // and there are elements left to be fetched. + condition_variable cond_var_; + + // Identifies the maximum number of parallel calls. + std::atomic<int64> num_parallel_calls_; + + // Iterator for input elements. 
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + + // Identifies current cycle element. + int64 cycle_index_ = 0; + + // Arguments for creating an iterator for cycle elements. + std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_); + + // Iterators for the current cycle elements. Concurrent access is + // protected by `element_in_use_`. + std::vector<std::unique_ptr<IteratorBase>> current_elements_; + + // Identifies cycle elements that are in use by worker threads. + std::vector<bool> element_in_use_ GUARDED_BY(mu_); + + // Buffer for storing the invocation results. + std::deque<std::shared_ptr<InvocationResult>> invocation_results_ + GUARDED_BY(mu_); + + // Identifies whether end of input has been reached. + bool end_of_input_ GUARDED_BY(mu_) = false; + + // Identifies the number of open iterators. + int64 num_open_ GUARDED_BY(mu_) = 0; + + // Identifies the number of outstanding calls. + int64 num_calls_ GUARDED_BY(mu_) = 0; + + std::unique_ptr<thread::ThreadPool> thread_pool_; + std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_); + + // Identifies whether background activity should be cancelled. 
+ bool cancelled_ GUARDED_BY(mu_) = false; + }; + + const DatasetBase* const input_; + const NameAttrList interleave_func_; + const std::unique_ptr<CapturedFunction> captured_func_; + const int64 cycle_length_; + const int64 block_length_; + const int64 num_parallel_calls_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; + }; + + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; + NameAttrList interleave_func_; +}; + +REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU), + ParallelInterleaveDatasetV2Op); + +} // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index bff54813d6..6abe6c8338 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -24,7 +24,7 @@ limitations under the License. 
#include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -33,37 +33,32 @@ namespace { class ParallelMapDatasetOp : public UnaryDatasetOpKernel { public: explicit ParallelMapDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", + &use_inter_op_parallelism_)); } protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - int32 num_parallel_calls; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", &num_parallel_calls)); - OP_REQUIRES(ctx, num_parallel_calls > 0, + OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune, errors::InvalidArgument( "num_parallel_calls must be greater than zero.")); std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + use_inter_op_parallelism_, + &captured_func)); *output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_, - output_shapes_, std::move(captured_func)); + output_shapes_, use_inter_op_parallelism_, + std::move(captured_func)); } private: @@ -73,6 +68,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const NameAttrList& func, int32 num_parallel_calls, const 
DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes, + bool use_inter_op_parallelism, std::unique_ptr<CapturedFunction> captured_func) : DatasetBase(DatasetContext(ctx)), input_(input), @@ -80,6 +76,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { num_parallel_calls_(num_parallel_calls), output_types_(output_types), output_shapes_(output_shapes), + use_inter_op_parallelism_(use_inter_op_parallelism), captured_func_(std::move(captured_func)) { input_->Ref(); } @@ -92,16 +89,26 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { return captured_func_->Instantiate(ctx); }; - auto map_func = [this](IteratorContext* ctx, + const string& new_prefix = strings::StrCat(prefix, "::ParallelMap"); + ParallelMapIteratorFunction map_func = + [this, new_prefix](IteratorContext* ctx, std::vector<Tensor> input_element, std::vector<Tensor>* result, StatusCallback done) { - captured_func_->RunAsync(ctx, std::move(input_element), result, - std::move(done)); - }; + captured_func_->RunAsync(ctx, std::move(input_element), result, + std::move(done), new_prefix); + }; + if (!use_inter_op_parallelism_) { + map_func = [map_func]( + IteratorContext* ctx, std::vector<Tensor> input_element, + std::vector<Tensor>* result, StatusCallback done) { + (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element), + result, std::move(done))); + }; + } - return NewParallelMapIterator( - {this, strings::StrCat(prefix, "::ParallelMap")}, input_, - std::move(init_func), std::move(map_func), num_parallel_calls_); + return NewParallelMapIterator({this, new_prefix}, input_, + std::move(init_func), std::move(map_func), + num_parallel_calls_); } const DataTypeVector& output_dtypes() const override { @@ -167,12 +174,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const int32 num_parallel_calls_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; + const bool use_inter_op_parallelism_; const 
std::unique_ptr<CapturedFunction> captured_func_; }; - const int graph_def_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; + bool use_inter_op_parallelism_; NameAttrList func_; }; @@ -180,5 +188,5 @@ REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU), ParallelMapDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 61f8139b9e..ee20249bfe 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -14,12 +14,17 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/kernels/data/parallel_map_iterator.h" +#include <atomic> #include <deque> #include <functional> #include <utility> #include <vector> +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/cpu_info.h" + namespace tensorflow { +namespace data { namespace { class ParallelMapIterator : public DatasetBaseIterator { @@ -36,11 +41,6 @@ class ParallelMapIterator : public DatasetBaseIterator { num_parallel_calls_(num_parallel_calls) {} ~ParallelMapIterator() override { - // TODO(mrry): Replace this cancellation logic with a - // CancellationManager. The syntax would be more heavyweight, - // but it would be possible to thread a cancellation manager - // through the IteratorContext to upstream, - // potentially-blocking iterators, when we add these. mutex_lock l(mu_); // Cancel the runner thread. 
cancelled_ = true; @@ -52,6 +52,17 @@ class ParallelMapIterator : public DatasetBaseIterator { } Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + if (num_parallel_calls_ == kAutoTune) { + num_parallel_calls_ = 1; + // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and + // use it here for the maximum. + AddTunableParameter(ctx, "parallelism", &num_parallel_calls_ /* value */, + 1 /* min */, port::NumSchedulableCPUs() /* max */, + &cond_var_); + } else { + AddConstantParameter(ctx, "parallelism", num_parallel_calls_); + } TF_RETURN_IF_ERROR( input_dataset_->MakeIterator(ctx, prefix(), &input_impl_)); if (init_func_) { @@ -67,13 +78,17 @@ class ParallelMapIterator : public DatasetBaseIterator { mutex_lock l(mu_); EnsureRunnerThreadStarted(ctx); while (invocation_results_.empty()) { + RecordStop(ctx); cond_var_.wait(l); + RecordStart(ctx); } std::swap(result, invocation_results_.front()); invocation_results_.pop_front(); + cond_var_.notify_all(); } - cond_var_.notify_all(); + RecordStop(ctx); result->notification.WaitForNotification(); + RecordStart(ctx); return ProcessResult(result, out_tensors, end_of_sequence); } @@ -86,9 +101,8 @@ class ParallelMapIterator : public DatasetBaseIterator { } CHECK_EQ(num_calls_, 0); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("invocation_results.size"), - invocation_results_.size())); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"), + invocation_results_.size())); for (size_t i = 0; i < invocation_results_.size(); i++) { std::shared_ptr<InvocationResult> result = invocation_results_[i]; TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status)); @@ -175,9 +189,9 @@ class ParallelMapIterator : public DatasetBaseIterator { { mutex_lock l(mu_); num_calls_--; + cond_var_.notify_all(); } result->notification.Notify(); - cond_var_.notify_all(); } void CallFunction(const 
std::shared_ptr<IteratorContext>& ctx, @@ -192,9 +206,8 @@ class ParallelMapIterator : public DatasetBaseIterator { return; } - // Call `func_(input_element)`, store the result in - // `result->return_values`, and notify `result->notification` to unblock - // a consumer. + // Call `func_(input_element)`, store the result in `result->return_values`, + // and notify `result->notification` to unblock a consumer. auto done = [this, result](Status status) { result->status.Update(status); CallCompleted(result); @@ -204,8 +217,6 @@ class ParallelMapIterator : public DatasetBaseIterator { std::move(done)); } - int64 MaxInvocationResults() { return num_parallel_calls_; } - Status ProcessResult(const std::shared_ptr<InvocationResult>& result, std::vector<Tensor>* out_tensors, bool* end_of_sequence) { @@ -225,27 +236,33 @@ class ParallelMapIterator : public DatasetBaseIterator { } void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { + RecordStart(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); std::vector<std::shared_ptr<InvocationResult>> new_calls; new_calls.reserve(num_parallel_calls_); + auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool { + int64 num_parallel_calls = num_parallel_calls_; + return num_calls_ >= num_parallel_calls || + invocation_results_.size() >= num_parallel_calls; + }; while (true) { { mutex_lock l(mu_); - while (!cancelled_ && - (num_calls_ >= num_parallel_calls_ || - invocation_results_.size() >= MaxInvocationResults())) { + while (!cancelled_ && busy()) { + RecordStop(ctx.get()); cond_var_.wait(l); + RecordStart(ctx.get()); } if (cancelled_) { return; } - while (num_calls_ < num_parallel_calls_ && - invocation_results_.size() < MaxInvocationResults()) { + while (!busy()) { invocation_results_.emplace_back(new InvocationResult()); new_calls.push_back(invocation_results_.back()); num_calls_++; } + cond_var_.notify_all(); } - cond_var_.notify_all(); for (const auto& call : new_calls) { 
CallFunction(ctx, call); } @@ -294,7 +311,6 @@ class ParallelMapIterator : public DatasetBaseIterator { const DatasetBase* const input_dataset_; // Not owned. const std::function<Status(IteratorContext*)> init_func_; const ParallelMapIteratorFunction map_func_; - const int32 num_parallel_calls_; // Used for coordination between the main thread and the runner thread. mutex mu_; // Used for coordination between the main thread and the runner thread. In @@ -303,6 +319,8 @@ class ParallelMapIterator : public DatasetBaseIterator { // parallelism and there are slots available in the `invocation_results_` // buffer. condition_variable cond_var_; + // Identifies the maximum number of parallel calls. + std::atomic<int64> num_parallel_calls_; // Counts the number of outstanding calls. int64 num_calls_ GUARDED_BY(mu_) = 0; std::unique_ptr<IteratorBase> input_impl_; @@ -333,4 +351,5 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator( std::move(map_func), num_parallel_calls)); } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h index 7e6cc586f3..dc26c5cf25 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.h +++ b/tensorflow/core/kernels/data/parallel_map_iterator.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" namespace tensorflow { +namespace data { // A function that transforms elements of one dataset into another // asynchronously. 
The arguments are: @@ -47,6 +48,7 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator( const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func, int32 num_parallel_calls); +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_ITERATOR_H_ diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc index 9057800d94..c28c06da62 100644 --- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/util/example_proto_fast_parsing.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -87,11 +87,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { "Expected len(dense_defaults) == len(dense_keys) but got: ", dense_default_tensors.size(), " vs. ", dense_keys_.size())); - std::vector<Tensor> dense_defaults; - dense_defaults.reserve(dense_default_tensors.size()); - for (const Tensor& dense_default_t : dense_default_tensors) { - dense_defaults.push_back(dense_default_t); - } + std::vector<Tensor> dense_defaults(dense_default_tensors.begin(), + dense_default_tensors.end()); for (int d = 0; d < dense_keys_.size(); ++d) { const Tensor& def_value = dense_defaults[d]; @@ -368,5 +365,5 @@ REGISTER_KERNEL_BUILDER(Name("ParseExampleDataset").Device(DEVICE_CPU), ParseExampleDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc index b3272f6bcd..da357339c9 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner.cc +++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/prefetch_autotuner.h" namespace tensorflow { +namespace data { PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size) : buffer_limit_(initial_buffer_size) { @@ -25,6 +26,13 @@ PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size) } } +namespace { +// Determines what strategy to use for increasing the buffer size limit. For +// limits less than the threshold, an exponential increase is used, while for +// limits greater than or equal to the threshold, a linear increase is used. +size_t kBufferLimitThreshold = 2048; +} // namespace + void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) { switch (mode_) { case Mode::kDisabled: @@ -36,11 +44,16 @@ void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) { return; case Mode::kDownswing: if (current_buffer_size == 0) { - buffer_limit_ *= 2; // Increase the buffer size. + if (buffer_limit_ >= kBufferLimitThreshold) { + buffer_limit_ += kBufferLimitThreshold; + } else { + buffer_limit_ *= 2; + } mode_ = Mode::kUpswing; } return; } } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.h b/tensorflow/core/kernels/data/prefetch_autotuner.h index fa8a184072..8693205512 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner.h +++ b/tensorflow/core/kernels/data/prefetch_autotuner.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace tensorflow { +namespace data { // PrefetchAutotuner dynamically adjusts the buffer size of a prefetch iterator. 
// @@ -66,6 +67,7 @@ class PrefetchAutotuner { Mode mode_ = Mode::kDisabled; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_AUTOTUNER_H_ diff --git a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc index 29a8cc50cd..cfc324fc7e 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc +++ b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { +namespace data { namespace { TEST(PrefetchAutotuner, Disabled) { @@ -79,4 +80,5 @@ TEST(PrefetchAutotuner, EnabledSteady) { } } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 50efbcbe2a..754ed772db 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -12,15 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include <deque> - #include "tensorflow/core/kernels/data/prefetch_dataset_op.h" +#include <deque> + #include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { +namespace data { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. 
@@ -70,7 +74,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { public: explicit Iterator(const Params& params) : DatasetIterator<Dataset>(params), - auto_tuner_(params.dataset->buffer_size_) {} + auto_tuner_(params.dataset->buffer_size_) { + std::vector<string> components = + str_util::Split(params.prefix, "::", str_util::SkipEmpty()); + prefix_end_ = components.back(); + } ~Iterator() override { // Signal the prefetch thread to terminate it. We will then @@ -95,6 +103,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { + auto stats_aggregator = ctx->stats_aggregator(); { mutex_lock l(mu_); TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx)); @@ -103,7 +112,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ && auto_tuner_.buffer_limit() != 0) { auto_tuner_.RecordEmpty(); + RecordStop(ctx); cond_var_.wait(l); + RecordStart(ctx); } if (cancelled_) { @@ -112,7 +123,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } if (!buffer_.empty()) { - return Consume(out_tensors, end_of_sequence); + return Consume(out_tensors, end_of_sequence, stats_aggregator); } if (prefetch_thread_finished_) { @@ -125,6 +136,14 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { mutex_lock parent_l(parent_mu_); mutex_lock l(mu_); + if (stats_aggregator) { + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::buffer_size"), + static_cast<float>(buffer_.size())); + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::buffer_capacity"), + static_cast<float>(auto_tuner_.buffer_limit())); + } return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); } @@ -200,14 +219,28 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { std::vector<Tensor> value; }; - Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence) + Status 
Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence, + const std::shared_ptr<StatsAggregator>& stats_aggregator) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (stats_aggregator) { + stats_aggregator->AddToHistogram( + strings::StrCat(prefix_end_, "::buffer_utilization"), + {static_cast<float>(buffer_.size()) / + static_cast<float>(auto_tuner_.buffer_limit())}); + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::buffer_size"), + static_cast<float>(buffer_.size())); + stats_aggregator->AddScalar( + strings::StrCat(prefix_end_, "::buffer_capacity"), + static_cast<float>(auto_tuner_.buffer_limit())); + } // A new element is available. Forward the status from computing it, and // (if we successfully got an element) the output values. Status s = buffer_.front().status; if (s.ok()) { *out_tensors = std::move(buffer_.front().value); } + auto_tuner_.RecordConsumption(buffer_.size()); buffer_.pop_front(); *end_of_sequence = false; @@ -223,10 +256,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { Status EnsurePrefetchThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!prefetch_thread_) { - prefetch_thread_.reset( - ctx->env()->StartThread({}, "prefetch_thread", - std::bind(&Iterator::PrefetchThread, this, - new IteratorContext(*ctx)))); + std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); + prefetch_thread_.reset(ctx->env()->StartThread( + {}, "prefetch_thread", + [this, new_ctx]() { PrefetchThread(new_ctx); })); } return Status::OK(); } @@ -235,8 +268,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { // buffer. // // It owns the iterator context passed to it. 
- void PrefetchThread(IteratorContext* ctx) { - std::unique_ptr<IteratorContext> cleanup(ctx); + void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) { + RecordStart(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); while (true) { std::vector<Tensor> value; @@ -244,7 +278,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { { mutex_lock l(mu_); while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) { + RecordStop(ctx.get()); cond_var_.wait(l); + RecordStart(ctx.get()); } if (cancelled_) { @@ -261,8 +297,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { mutex_lock parent_l(parent_mu_); bool end_of_sequence; BufferElement buffer_element; - buffer_element.status = - input_impl_->GetNext(ctx, &buffer_element.value, &end_of_sequence); + buffer_element.status = input_impl_->GetNext( + ctx.get(), &buffer_element.value, &end_of_sequence); if (buffer_element.status.ok() && end_of_sequence) { mutex_lock l(mu_); prefetch_thread_finished_ = true; @@ -324,6 +360,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { mutex parent_mu_ ACQUIRED_BEFORE(mu_); std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_); condition_variable cond_var_; + string prefix_end_; PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_); std::deque<BufferElement> buffer_ GUARDED_BY(mu_); std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_); @@ -346,6 +383,7 @@ void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, *output = new Dataset(ctx, input, buffer_size); } +namespace { REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU), PrefetchDatasetOp); REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") @@ -354,4 +392,7 @@ REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") .HostMemory("input_dataset") .HostMemory("handle"), PrefetchDatasetOp); +} // namespace + +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.h 
b/tensorflow/core/kernels/data/prefetch_dataset_op.h index c40c4b00da..588fb25a06 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.h +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/prefetch_autotuner.h" namespace tensorflow { +namespace data { class PrefetchDatasetOp : public UnaryDatasetOpKernel { public: @@ -34,6 +35,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { class Dataset; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc index 7817170e73..044a791a3f 100644 --- a/tensorflow/core/kernels/data/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/random_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random_distributions.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -151,5 +151,5 @@ REGISTER_KERNEL_BUILDER(Name("RandomDataset").Device(DEVICE_CPU), RandomDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index aa38775125..89fbaae369 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -142,5 +142,5 @@ REGISTER_KERNEL_BUILDER(Name("RangeDataset").Device(DEVICE_CPU), RangeDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc index 086b552936..c474cb4773 100644 --- a/tensorflow/core/kernels/data/reader_dataset_ops.cc +++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/io/zlib_inputstream.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -691,5 +691,5 @@ REGISTER_KERNEL_BUILDER(Name("TFRecordDataset").Device(DEVICE_CPU), TFRecordDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index 299949b99f..94e96635ab 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -250,5 +250,5 @@ REGISTER_KERNEL_BUILDER(Name("RepeatDataset").Device(DEVICE_CPU), RepeatDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index fccad933d0..dbe31f37b8 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ -23,7 +23,7 @@ limitations under the License. 
#include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -45,23 +45,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { OpInputList initial_state_inputs; OP_REQUIRES_OK(ctx, ctx->input_list("initial_state", &initial_state_inputs)); - std::vector<Tensor> initial_state; - initial_state.reserve(initial_state_inputs.size()); - for (const Tensor& t : initial_state_inputs) { - initial_state.push_back(t); - } - - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } + std::vector<Tensor> initial_state(initial_state_inputs.begin(), + initial_state_inputs.end()); std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, func_, std::move(initial_state), std::move(captured_func), state_types_, output_types_, @@ -279,5 +268,5 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("ScanDataset").Device(DEVICE_CPU), ScanDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 93a4376836..66466d6a36 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { - +namespace data { namespace { const int64 kLogIntervalMicros = 10 * 1000000; // 10 seconds. 
@@ -620,5 +620,5 @@ REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU), ShuffleAndRepeatDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc new file mode 100644 index 0000000000..5b084a16f0 --- /dev/null +++ b/tensorflow/core/kernels/data/single_threaded_executor.cc @@ -0,0 +1,380 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/data/single_threaded_executor.h" + +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/executor_factory.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace data { +namespace { + +typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec; +typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec; +typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec; + +class SingleThreadedExecutorImpl : public Executor { + public: + explicit SingleThreadedExecutorImpl(const LocalExecutorParams& params) + : params_(params) {} + + ~SingleThreadedExecutorImpl() override { + for (const KernelState& kernel_state : kernels_) { + params_.delete_kernel(kernel_state.kernel); + } + } + + Status Initialize(const Graph& graph) { + // Topologicially sort `graph` to get a sequence of OpKernels. + std::vector<Node*> ordered_nodes; + ordered_nodes.reserve(graph.num_nodes()); + GetReversePostOrder(graph, &ordered_nodes); + + if (ordered_nodes.size() != graph.num_nodes()) { + return errors::InvalidArgument("Graph had ", graph.num_nodes(), + " but reverse post-order had ", + ordered_nodes.size()); + } + + kernels_.resize(ordered_nodes.size()); + + std::unordered_map<Node*, size_t> node_to_index_map; + + // Create the kernel and input-related structures for each node in `graph`. 
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) { + Node* n = ordered_nodes[i]; + node_to_index_map[n] = i; + + for (DataType dt : n->output_types()) { + if (IsRefType(dt)) { + return errors::Unimplemented( + "Single-threaded executor does not support reference-typed " + "edges."); + } + } + + if (n->IsControlFlow()) { + return errors::Unimplemented( + "Single-threaded executor does not support control flow."); + } + if (n->IsSend() || n->IsHostSend() || n->IsRecv() || n->IsHostRecv()) { + return errors::Unimplemented( + "Single-threaded executor does not support partitioned graphs."); + } + if (n->IsCollective()) { + return errors::Unimplemented( + "Single-threaded executor does not support collective ops."); + } + + KernelState& kernel_state = kernels_[i]; + TF_RETURN_IF_ERROR(params_.create_kernel(n->def(), &kernel_state.kernel)); + kernel_state.num_inputs = n->num_inputs(); + kernel_state.num_outputs = n->num_outputs(); + + if (i == 0) { + kernel_state.input_start_index = 0; + } else { + const KernelState& previous_kernel_state = kernels_[i - 1]; + kernel_state.input_start_index = + previous_kernel_state.input_start_index + + previous_kernel_state.num_inputs; + } + } + + // Build the mapping from each node output to the input slot for the + // corresponding destination node. + for (size_t i = 0; i < ordered_nodes.size(); ++i) { + Node* n = ordered_nodes[i]; + KernelState& kernel_state = kernels_[i]; + kernel_state.output_locations.resize(kernel_state.num_outputs); + for (const Edge* e : n->out_edges()) { + if (!e->IsControlEdge()) { + kernel_state.output_locations[e->src_output()].push_back( + kernels_[node_to_index_map[e->dst()]].input_start_index + + e->dst_input()); + } + } + + // Compute allocator attributes for each node output, and corresponding + // node input. 
+ kernel_state.output_alloc_attrs.resize(kernel_state.num_outputs); + AllocatorAttributes* attrs = kernel_state.output_alloc_attrs.data(); + + OpKernel* op_kernel = kernel_state.kernel; + for (int out = 0; out < n->num_outputs(); out++) { + DCHECK_LT(out, op_kernel->output_memory_types().size()); + bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY; + if (on_host) { + AllocatorAttributes h; + h.set_on_host(on_host); + attrs[out].Merge(h); + } + } + } + + if (!kernels_.empty()) { + const KernelState& last_kernel_state = kernels_.back(); + total_num_inputs_ = + last_kernel_state.input_start_index + last_kernel_state.num_inputs; + input_alloc_attrs_.resize(total_num_inputs_); + for (size_t i = 0; i < ordered_nodes.size(); ++i) { + for (size_t j = 0; j < kernels_[i].output_locations.size(); ++j) { + for (size_t output_location : kernels_[i].output_locations[j]) { + input_alloc_attrs_[output_location] = + kernels_[i].output_alloc_attrs[j]; + } + } + } + } else { + total_num_inputs_ = 0; + } + return Status::OK(); + } + + // TODO(mrry): Consider specializing the implementation of Executor::Run() + // instead, to avoid unnecessary atomic operations in the callback when + // running synchronously. + void RunAsync(const Args& args, DoneCallback done) override { + // The inputs to each kernel are stored contiguously in `inputs`. + // + // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to + // determine the range of elements in this vector that correspond to + // the inputs of `kernels_[i]`. + // + // This vector has the following layout: + // + // * Kernel 0, input 0. + // * Kernel 0, input 1. + // * ... + // * Kernel 0, input `kernels_[0].num_inputs - 1`. + // * Kernel 1, input 0. + // * ... + // * Kernel 1, input `kernels_[1].num_inputs - 1`. + // * ... + // * Kernel `kernels_.size() - 1`, input 0. + // * ... + // * Kernel `kernels_.size() - 1`, input `kernels_.back().num_inputs - 1`. 
+ // + // Note that kernels with zero inputs do not correspond to any elements in + // this vector. + // + // We use `ManualConstructor<Tensor>` to avoid the overhead of + // default-constructing an invalid `Tensor` for each slot at the beginning + // of execution: + // * Elements are initialized when the outputs of a kernel execution are + // propagated to the inputs of kernels that depend on them. + // * The elements corresponding to the inputs for kernel `i` are destroyed + // after kernel `i` executes. + // * In an error case (see below), we use the connectivity information in + // `KernelState::output_locations` to determine which locations have been + // initialized, and manually destroy them. + std::vector<ManualConstructor<Tensor>> inputs(total_num_inputs_); + + // TODO(mrry): Can we avoid copying into these vectors? Consider modifying + // OpKernelContext to take the TensorValueVec as a pointer into `inputs`. + TensorValueVec node_inputs; + DeviceContextVec input_device_contexts; + AllocatorAttributeVec input_alloc_attrs; + + // Prepare the parameters that will be the same for all kernels. + OpKernelContext::Params params; + params.step_id = args.step_id; + Device* device = params_.device; + params.device = device; + params.log_memory = false; // TODO(mrry): Too severe? + params.record_tensor_accesses = false; // TODO(mrry): Too severe? + params.rendezvous = args.rendezvous; + params.session_state = args.session_state; + params.tensor_store = args.tensor_store; + params.cancellation_manager = args.cancellation_manager; + // TODO(mrry): ArgOp is a relatively expensive OpKernel due to the Tensor + // allocations that it performs. Consider specializing its handling in the + // executor. + params.call_frame = args.call_frame; + params.function_library = params_.function_library; + params.resource_manager = device->resource_manager(); + params.step_container = args.step_container; + params.slice_reader_cache = nullptr; // TODO(mrry): Too severe? 
+ params.inputs = &node_inputs; + params.input_device_contexts = &input_device_contexts; + params.input_alloc_attrs = &input_alloc_attrs; + + Args::Runner runner_copy = args.runner; + params.runner = &runner_copy; + params.stats_collector = args.stats_collector; + + // NOTE(mrry): We are assuming that the graph is loopless and condless. + params.frame_iter = FrameAndIter(0, 0); + params.is_input_dead = false; + + // TODO(mrry): Add non-default device context inference. + params.op_device_context = nullptr; + // TODO(mrry): Consider implementing forwarding. + params.forward_from_array = nullptr; + + // Execute the kernels one-at-a-time in topological order. + for (size_t i = 0; i < kernels_.size(); ++i) { + const KernelState& kernel_state = kernels_[i]; + + // Prepare the per-kernel parameters. + const size_t input_start_index = kernel_state.input_start_index; + const size_t num_inputs = kernel_state.num_inputs; + const size_t num_outputs = kernel_state.num_outputs; + + node_inputs.clear(); + node_inputs.resize(num_inputs); + input_alloc_attrs.clear(); + input_alloc_attrs.resize(num_inputs); + for (size_t j = 0; j < num_inputs; ++j) { + auto t = inputs[input_start_index + j].get(); + node_inputs[j].tensor = t; + input_alloc_attrs[j] = input_alloc_attrs_[input_start_index + j]; + } + params.op_kernel = kernel_state.kernel; + input_device_contexts.clear(); + input_device_contexts.resize(num_inputs); + params.output_attr_array = kernel_state.output_alloc_attrs.data(); + OpKernelContext ctx(¶ms, num_outputs); + + // Actually execute the kernel. + device->Compute(kernel_state.kernel, &ctx); + + if (!ctx.status().ok()) { + // On failure, we must manually free all intermediate tensors. We have + // already freed all the inputs for kernels up to (but not including) + // the `i`th kernel. We scan through the previously executed kernels and + // destroy any tensors that were destined to be the input for a kernel + // that has not yet executed. 
+ for (size_t j = 0; j < i; ++j) { + const KernelState& executed_kernel_state = kernels_[j]; + for (size_t k = 0; k < executed_kernel_state.num_outputs; ++k) { + for (size_t output_location : + executed_kernel_state.output_locations[k]) { + if (output_location >= input_start_index) { + // Only destroy an output location if it is an input to an + // operation that has not yet executed. + inputs[output_location].Destroy(); + } + } + } + } + done(ctx.status()); + return; + } + + // Free the inputs to the current kernel. + for (size_t j = 0; j < num_inputs; ++j) { + inputs[input_start_index + j].Destroy(); + } + + // Forward the outputs of the kernel to the inputs of subsequent kernels. + for (size_t j = 0; j < num_outputs; ++j) { + TensorValue val = ctx.release_output(j); + // TODO(mrry): Consider flattening the `output_locations` vector + // to improve the cache-friendliness of this loop. + for (size_t output_location : kernel_state.output_locations[j]) { + // TODO(mrry): Validate that the types match the expected values or + // ensure that the necessary validation has already happened. + inputs[output_location].Init(*val.tensor); + } + delete val.tensor; + } + } + done(Status::OK()); + } + + private: + const LocalExecutorParams params_; + + // All following members are read-only after Initialize(). + + // The sum of the number of inputs for each node in the graph. This determines + // the length of the flat `inputs` vector. See comment at the beginning of + // `RunAsync()` for details. + size_t total_num_inputs_; + + // Represents cached graph structure state for each kernel. + struct KernelState { + // The kernel object. Not owned. + // + // This pointer is managed by `params_.create_kernel()` and + // `params_.delete_kernel()`. + OpKernel* kernel; + + // These fields determine the range of elements in `inputs` that corresponds + // to the inputs of `kernel`. 
+ size_t input_start_index; + size_t num_inputs; + + size_t num_outputs; + + // For the `j`th output of `kernel`, `output_locations[j]` contains the + // locations in the flat `inputs` vector to which that output must be + // copied. See comment at the beginning of `RunAsync()` for details. + std::vector<std::vector<size_t>> + output_locations; // Length = `num_outputs`. + + // Memory space information for each output of `kernel`. + std::vector<AllocatorAttributes> + output_alloc_attrs; // Length = `num_outputs`. + }; + std::vector<KernelState> kernels_; + + // Memory space information for each input. This information is stored in the + // same order as the flat `inputs` vector. See comment at the beginning of + // `RunAsync()` for details. + std::vector<AllocatorAttributes> + input_alloc_attrs_; // Length = `total_num_inputs_`. +}; + +class SingleThreadedExecutorRegistrar { + public: + SingleThreadedExecutorRegistrar() { + ExecutorFactory::Register("SINGLE_THREADED_EXECUTOR", new Factory()); + } + + private: + class Factory : public ExecutorFactory { + Status NewExecutor(const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + std::unique_ptr<Executor>* out_executor) override { + Executor* ret; + TF_RETURN_IF_ERROR( + NewSingleThreadedExecutor(params, std::move(graph), &ret)); + out_executor->reset(ret); + return Status::OK(); + } + }; +}; +static SingleThreadedExecutorRegistrar registrar; + +} // namespace + +Status NewSingleThreadedExecutor(const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + Executor** executor) { + std::unique_ptr<SingleThreadedExecutorImpl> impl( + new SingleThreadedExecutorImpl(params)); + TF_RETURN_IF_ERROR(impl->Initialize(*graph)); + *executor = impl.release(); + return Status::OK(); +} + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/single_threaded_executor.h b/tensorflow/core/kernels/data/single_threaded_executor.h new file mode 100644 index 
0000000000..e934352a1d --- /dev/null +++ b/tensorflow/core/kernels/data/single_threaded_executor.h @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_ + +#include "tensorflow/core/common_runtime/executor.h" + +namespace tensorflow { +namespace data { + +// Creates a new `Executor` for executing `graph` synchronously on the caller +// thread. +// +// NOTE(mrry): The returned executor is optimized to impose low overhead on +// graphs that perform a small amount of work (e.g. <15us of work per graph on +// present architectures). It eschews concurrency, because issuing work to +// multiple threads can dominate the cost of executing small ops synchronously, +// and because contention in the executor data structures can reduce throughput +// (in terms of ops executed per unit time). +// +// However, the current implementation has the following limitations: +// +// 1. Reference-typed tensors are not supported and will not be supported in +// future. +// 2. Graphs with control flow (containing "Switch" and "Merge" nodes) are not +// currently supported. The current plan is to extend support to "functional" +// control flow after the TensorFlow APIs transition to building graphs in +// that form (e.g. 
`tf.cond_v2()`). +// 3. Partitioned graphs (containing "_Recv" nodes) are not currently supported. +// The present implementation executes kernels one at a time in topological +// order, and cannot currently distinguish between disconnected subgraphs +// that are logically connected by subgraphs on a different device. +// 4. Memory logging is not currently supported. +// 5. Allocation forwarding is not currently supported. +// 6. Non-default device contexts are not currently supported. In effect, this +// limits the executor to CPU devices. +// 7. Ops that rely on `OpKernelContext::slice_reader_cache()` being non-null +// are not currently supported. +// +// The single-threaded executor is primarily suitable for executing simple +// TensorFlow functions, such as one might find in a `tf.data` pipeline. +Status NewSingleThreadedExecutor(const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + Executor** executor); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_ diff --git a/tensorflow/core/kernels/data/single_threaded_executor_test.cc b/tensorflow/core/kernels/data/single_threaded_executor_test.cc new file mode 100644 index 0000000000..6244e287bb --- /dev/null +++ b/tensorflow/core/kernels/data/single_threaded_executor_test.cc @@ -0,0 +1,332 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/data/single_threaded_executor.h" + +#include <algorithm> + +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/common_runtime/process_util.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/framework/versions.pb.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/random/simple_philox.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace data { +namespace { + +class ExecutorTest : public ::testing::Test { + protected: + ExecutorTest() + : device_(DeviceFactory::NewDevice("CPU", {}, + "/job:localhost/replica:0/task:0")) {} + + ~ExecutorTest() override { + // There should always be exactly one Ref left on the Rendezvous + // when the test completes. + CHECK(rendez_->Unref()); + delete exec_; + delete device_; + } + + // Resets executor_ with a new executor based on a graph 'gdef'. 
+ void Create(std::unique_ptr<const Graph> graph) { + const int version = graph->versions().producer(); + LocalExecutorParams params; + params.device = device_; + params.create_kernel = [this, version](const NodeDef& ndef, + OpKernel** kernel) { + return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel); + }; + params.delete_kernel = [](OpKernel* kernel) { + DeleteNonCachedKernel(kernel); + }; + delete exec_; + TF_CHECK_OK(NewSingleThreadedExecutor(params, std::move(graph), &exec_)); + runner_ = [](std::function<void()> fn) { fn(); }; + rendez_ = NewLocalRendezvous(); + } + + Status Run(Rendezvous* rendez) { + Executor::Args args; + args.rendezvous = rendez; + args.runner = runner_; + return exec_->Run(args); + } + + Status Run(CallFrameInterface* call_frame) { + Executor::Args args; + args.call_frame = call_frame; + args.runner = runner_; + return exec_->Run(args); + } + + Device* device_ = nullptr; + Executor* exec_ = nullptr; + Executor::Args::Runner runner_; + Rendezvous* rendez_ = nullptr; +}; + +// A float val -> Tensor<float> +Tensor V(const float val) { + Tensor tensor(DT_FLOAT, TensorShape({})); + tensor.scalar<float>()() = val; + return tensor; +} + +// A int32 val -> Tensor<int32> +Tensor VI(const int32 val) { + Tensor tensor(DT_INT32, TensorShape({})); + tensor.scalar<int32>()() = val; + return tensor; +} + +// A bool val -> Tensor<bool> +Tensor VB(const bool val) { + Tensor tensor(DT_BOOL, TensorShape({})); + tensor.scalar<bool>()() = val; + return tensor; +} + +// A double val -> Tensor<double> +Tensor VD(const double val) { + Tensor tensor(DT_DOUBLE, TensorShape({})); + tensor.scalar<double>()() = val; + return tensor; +} + +// Tensor<float> -> a float val. 
+float V(const Tensor& tensor) { + CHECK_EQ(tensor.dtype(), DT_FLOAT); + CHECK(TensorShapeUtils::IsScalar(tensor.shape())); + return tensor.scalar<float>()(); +} + +// Builds the parsed rendezvous key for 'name' sent from 'sender' (with the +// given incarnation) to 'receiver', using frame 0 / iteration 0. +Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation, + const string& receiver, const string& name) { + Rendezvous::ParsedKey result; + TF_CHECK_OK( + Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver, + name, FrameAndIter(0, 0)), + &result)); + return result; +} + +TEST_F(ExecutorTest, SimpleAdd) { + // c = a + b + std::unique_ptr<Graph> g(new Graph(OpRegistry::Global())); + // Each Arg node gets its own index so that it reads a distinct argument + // of the two-argument call frame below. + auto in0 = test::graph::Arg(g.get(), 0, DT_FLOAT); + auto in1 = test::graph::Arg(g.get(), 1, DT_FLOAT); + auto tmp = test::graph::Add(g.get(), in0, in1); + test::graph::Retval(g.get(), 0, tmp); + FixupSourceAndSinkEdges(g.get()); + Create(std::move(g)); + FunctionCallFrame call_frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT}); + // Distinct inputs, so reading the same argument twice would be detected. + TF_ASSERT_OK(call_frame.SetArgs({V(1.0), V(2.0)})); + TF_ASSERT_OK(Run(&call_frame)); + std::vector<Tensor> retvals; + TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false)); + EXPECT_EQ(3.0, V(retvals[0])); // out = 1.0 + 2.0 = 3.0 +} + +TEST_F(ExecutorTest, SelfAdd) { + // v0 <- a + // v1 = v0 + v0 + // v2 = v1 + v1 + // ... ... + // v10 = v9 + v9 + // + // b <- v10 + // All nodes are executed by one thread. 
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global())); + auto v = test::graph::Arg(g.get(), 0, DT_FLOAT); + const int N = 10; + for (int i = 1; i <= N; ++i) { + v = test::graph::Add(g.get(), v, v); + } + // out <- v10 + test::graph::Retval(g.get(), 0, v); + FixupSourceAndSinkEdges(g.get()); + Create(std::move(g)); + FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT}); + // a = 1.0 + TF_ASSERT_OK(call_frame.SetArgs({V(1.0)})); + TF_ASSERT_OK(Run(&call_frame)); + std::vector<Tensor> retvals; + TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false)); + EXPECT_EQ(1024.0, V(retvals[0])); // b=v10=2*v9=4*v8=...=1024*a=1024.0 +} + +// Builds a graph which adds N copies of one variable "in". I.e., +// a + a + a + ... + a +// The returned graph is parenthesized randomly. I.e., +// a + ((a + a) + a) +// (a + a) + (a + a) +// ((a + a) + a) + a +// are all possibly generated. +void BuildTree(int N, Graph* g) { + CHECK_GT(N, 1); + // A single input node "in". + auto in = test::graph::Arg(g, 0, DT_FLOAT); + std::vector<Node*> nodes; + int i = 0; + // Duplicate "in" N times. Each copy is named l0, l1, l2, .... + for (; i < N; ++i) { + nodes.push_back(test::graph::Identity(g, in, 0)); + } + random::PhiloxRandom philox(0, 17); + random::SimplePhilox rnd(&philox); + while (nodes.size() > 1) { + // Randomly pick two from nodes and add them. The resulting node + // is named like n10, n11, .... and is put back into "nodes". + int x = rnd.Uniform(nodes.size()); + auto in0 = nodes[x]; + nodes[x] = nodes.back(); + nodes.resize(nodes.size() - 1); + x = rnd.Uniform(nodes.size()); + auto in1 = nodes[x]; + // node = in0 + in1. + nodes[x] = test::graph::Add(g, in0, in1); + } + // The final output node "out". 
+ test::graph::Retval(g, 0, nodes.back()); + FixupSourceAndSinkEdges(g); +} + +TEST_F(ExecutorTest, RandomTree) { + std::unique_ptr<Graph> g(new Graph(OpRegistry::Global())); + BuildTree(4096, g.get()); + Create(std::move(g)); + FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT}); + TF_ASSERT_OK(call_frame.SetArgs({V(1.0)})); + TF_ASSERT_OK(Run(&call_frame)); + std::vector<Tensor> retvals; + TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false)); + EXPECT_EQ(4096.0, V(retvals[0])); +} + +TEST_F(ExecutorTest, OpError) { + std::unique_ptr<Graph> g(new Graph(OpRegistry::Global())); + auto zero = test::graph::Constant(g.get(), V(0.0)); + auto inf = test::graph::Unary(g.get(), "Reciprocal", zero); + auto check = test::graph::CheckNumerics(g.get(), inf, "message"); + auto two = test::graph::Constant(g.get(), V(2.0)); + test::graph::Binary(g.get(), "Mul", check, two); + FixupSourceAndSinkEdges(g.get()); + Create(std::move(g)); + FunctionCallFrame call_frame({}, {}); + // Fails because CheckNumerics rejects the Inf produced by Reciprocal(0). 
+ EXPECT_TRUE(errors::IsInvalidArgument(Run(&call_frame))); +} + +// Builds a random graph of NoOp nodes linked by control edges ('depth' rounds, +// each adding between 1 and 'width' ready nodes) and hands it to +// test::Benchmark for 'iters' iterations on the single-threaded executor. +static void BM_executor(int iters, int width, int depth) { +#ifdef PLATFORM_GOOGLE + BenchmarkUseRealTime(); +#endif // PLATFORM_GOOGLE + Graph* g = new Graph(OpRegistry::Global()); + random::PhiloxRandom philox(1729, 17); + random::SimplePhilox rand(&philox); + // Separate fixed-seed stream for shuffling, so the graph is reproducible + // and the draws that decide node counts are unaffected. + random::PhiloxRandom shuffle_philox(1729, 18); + random::SimplePhilox shuffle_rand(&shuffle_philox); + uint64 cur = 0; + uint32 r = 1 + rand.Rand32() % width; + std::vector<Node*> ready_nodes; + for (uint32 i = 0; i < r; ++i) { + ready_nodes.push_back(test::graph::NoOp(g, {})); + ++cur; + } + for (int i = 0; i < depth; ++i) { + // Fisher-Yates shuffle; std::random_shuffle is deprecated in C++14 and + // removed in C++17. + for (size_t k = ready_nodes.size(); k > 1; --k) { + std::swap(ready_nodes[k - 1], + ready_nodes[shuffle_rand.Uniform(static_cast<uint32>(k))]); + } + r = 1 + rand.Rand32() % (ready_nodes.size()); + std::vector<Node*> control_inputs; + for (uint32 j = 0; j < r; ++j) { + control_inputs.push_back(ready_nodes.back()); + ready_nodes.pop_back(); + } + Node* n = test::graph::NoOp(g, control_inputs); + ++cur; + r = 1 + rand.Rand32() % width; + for (uint32 j = 0; j < r; ++j) { + ready_nodes.push_back(test::graph::NoOp(g, {n})); + ++cur; + } + } + FixupSourceAndSinkEdges(g); +#ifdef PLATFORM_GOOGLE + SetBenchmarkLabel(strings::StrCat("Nodes = ", cur)); + SetBenchmarkItemsProcessed(cur * static_cast<int64>(iters)); +#endif // PLATFORM_GOOGLE + test::Benchmark("cpu", g, nullptr, nullptr, nullptr, + "SINGLE_THREADED_EXECUTOR") + .Run(iters); +} + +// Tall skinny graphs +BENCHMARK(BM_executor)->ArgPair(16, 1024); +BENCHMARK(BM_executor)->ArgPair(32, 8192); + +// Short fat graphs +BENCHMARK(BM_executor)->ArgPair(1024, 16); +BENCHMARK(BM_executor)->ArgPair(8192, 32); + +// Tall fat graph +BENCHMARK(BM_executor)->ArgPair(1024, 1024); + +// TODO(mrry): This benchmark currently crashes with a use-after free, because +// test::Benchmark::RunWithArgs() assumes that the executor will take ownership +// of the given graph, *and* keep its nodes (`x`, `y` and `z`) alive for the +// duration of the benchmark. Since the single threaded executor does not retain +// a copy of the graph, this fails. 
+// +// TODO(mrry): Add support for Arg/Retval "function call convention" in +// `test::Benchmark::RunWithArgs()`. +#if 0 +#define ALICE "/job:j/replica:0/task:0/cpu:0" +#define BOB "/job:j/replica:0/task:0/gpu:0" + +static void BM_FeedInputFetchOutput(int iters) { + Graph* g = new Graph(OpRegistry::Global()); + // z = x + y: x and y are provided as benchmark inputs. z is the + // output of the benchmark. Conceptually, the caller is ALICE, the + // benchmark is BOB. + Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB); + Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB); + Node* sum = test::graph::Add(g, x, y); + Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE); + FixupSourceAndSinkEdges(g); + Tensor val(DT_FLOAT, TensorShape({})); + val.scalar<float>()() = 3.14; + SetBenchmarkItemsProcessed(static_cast<int64>(iters)); + test::Benchmark("cpu", g, nullptr, nullptr, nullptr, + "SINGLE_THREADED_EXECUTOR") + .RunWithArgs({{x, val}, {y, val}}, {z}, iters); +} +BENCHMARK(BM_FeedInputFetchOutput); +#endif + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc index fe7ef38d5f..b8c7fb15f4 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -187,5 +187,5 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), SkipDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc index 14df3a6801..1e73cfc753 100644 --- a/tensorflow/core/kernels/data/slide_dataset_op.cc +++ b/tensorflow/core/kernels/data/slide_dataset_op.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -293,5 +293,5 @@ REGISTER_KERNEL_BUILDER(Name("SlideDataset").Device(DEVICE_CPU), SlideDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc index e526578701..85b1e50695 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "tensorflow/core/util/sparse/sparse_tensor.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -274,5 +274,5 @@ TF_CALL_DATASET_TYPES(REGISTER_DATASET_KERNEL); #undef REGISTER_DATASET_KERNEL } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/sql/driver_manager.cc b/tensorflow/core/kernels/data/sql/driver_manager.cc index ffabda1a8a..783d1e6cb2 100644 --- a/tensorflow/core/kernels/data/sql/driver_manager.cc +++ b/tensorflow/core/kernels/data/sql/driver_manager.cc @@ -16,7 +16,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/sql/sqlite_query_connection.h" namespace tensorflow { - +namespace data { namespace sql { std::unique_ptr<QueryConnection> DriverManager::CreateQueryConnection( @@ -30,5 +30,5 @@ std::unique_ptr<QueryConnection> DriverManager::CreateQueryConnection( } } // namespace sql - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/sql/driver_manager.h b/tensorflow/core/kernels/data/sql/driver_manager.h index a34691b5a2..c5428f396b 100644 --- a/tensorflow/core/kernels/data/sql/driver_manager.h +++ b/tensorflow/core/kernels/data/sql/driver_manager.h @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/sql/query_connection.h" namespace tensorflow { - +namespace data { namespace sql { // A factory class for creating `QueryConnection` instances. @@ -35,7 +35,7 @@ class DriverManager { }; } // namespace sql - +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_ diff --git a/tensorflow/core/kernels/data/sql/query_connection.h b/tensorflow/core/kernels/data/sql/query_connection.h index e9ffca202f..2fd229a9bf 100644 --- a/tensorflow/core/kernels/data/sql/query_connection.h +++ b/tensorflow/core/kernels/data/sql/query_connection.h @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" namespace tensorflow { +namespace data { class IteratorContext; @@ -63,7 +64,7 @@ class QueryConnection { }; } // namespace sql - +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_ diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc index 7cd07bd8ec..5108e83976 100644 --- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc +++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" namespace tensorflow { - +namespace data { namespace sql { SqliteQueryConnection::SqliteQueryConnection() {} @@ -115,5 +115,5 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry( } } // namespace sql - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h index 81b19530b7..175492c49d 100644 --- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h +++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace tensorflow { - +namespace data { namespace sql { class SqliteQueryConnection : public QueryConnection { @@ -50,7 +50,7 @@ class SqliteQueryConnection : public QueryConnection { }; } // namespace sql - +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_ diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc index 2aa153fcfa..6bbe459332 100644 --- a/tensorflow/core/kernels/data/sql_dataset_ops.cc +++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc @@ -24,8 +24,9 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/stringprintf.h" namespace tensorflow { - +namespace data { namespace { + // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following ops. @@ -211,5 +212,5 @@ class SqlDatasetOp : public DatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc index 75af73df54..7e528a71be 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { namespace { class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { @@ -33,16 +34,18 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { &stats_aggregator_resource)); core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource); - *output = new Dataset(ctx, input, stats_aggregator_resource); + *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource); } private: class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, + const Tensor& resource_handle, StatsAggregatorResource* stats_aggregator_resource) : DatasetBase(DatasetContext(ctx)), input_(input), + resource_handle_(resource_handle), stats_aggregator_resource_(stats_aggregator_resource) { input_->Ref(); stats_aggregator_resource_->Ref(); @@ -74,8 +77,13 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - return errors::Unimplemented("%s does not support serialization", - DebugString()); + Node* input_graph_node = nullptr; + 
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* resource_handle_node = nullptr; + TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {input_graph_node, resource_handle_node}, output)); + return Status::OK(); } private: @@ -128,6 +136,7 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { }; const DatasetBase* const input_; + const Tensor resource_handle_; StatsAggregatorResource* stats_aggregator_resource_; }; }; @@ -135,4 +144,5 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU), SetStatsAggregatorDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc index b133cfab54..a7ded67876 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" namespace tensorflow { +namespace data { namespace { static mutex* get_counters_map_lock() { @@ -145,4 +146,5 @@ REGISTER_KERNEL_BUILDER(Name("StatsAggregatorSummary").Device(DEVICE_CPU), StatsAggregatorSummaryOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc index 8957f5d997..e9e42f05a1 100644 --- a/tensorflow/core/kernels/data/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc @@ -22,6 +22,7 @@ limitations under the License. 
#include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { namespace { // This op defines a `Dataset` that passes through its input elements and @@ -248,4 +249,5 @@ REGISTER_KERNEL_BUILDER(Name("BytesProducedStatsDataset").Device(DEVICE_CPU), BytesProducedStatsDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc index e5c237dfaa..e5cdfdd732 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.cc +++ b/tensorflow/core/kernels/data/take_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -174,5 +174,5 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc index 1192fafc4c..ca4ea25b89 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -33,11 +33,7 @@ class TensorDatasetOp : public DatasetOpKernel { OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs)); // TODO(mrry): Validate that the shapes of the "components" tensors match // the "shapes" attr.; - std::vector<Tensor> components; - components.reserve(inputs.size()); - for (const Tensor& t : inputs) { - components.push_back(t); - } + std::vector<Tensor> components(inputs.begin(), inputs.end()); *output = new Dataset(ctx, std::move(components)); } @@ -140,5 +136,5 @@ REGISTER_KERNEL_BUILDER(Name("TensorDataset").Device(DEVICE_CPU), TensorDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc index ccd5e60acc..2ed636a400 100644 --- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { bool IsGreaterEqualToOrCompatibleWith(const PartialTensorShape& a, @@ -648,5 +648,5 @@ REGISTER_KERNEL_BUILDER(Name("EnqueueInQueueDataset").Device(DEVICE_CPU), EnqueueInQueueDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index dc32cd23e5..7dc64b0a75 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -168,5 +168,5 @@ REGISTER_KERNEL_BUILDER(Name("TensorSliceDataset").Device(DEVICE_CPU), TensorSliceDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc index 1a79f72b28..81c432b938 100644 --- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -204,5 +204,5 @@ REGISTER_KERNEL_BUILDER(Name("UnbatchDataset").Device(DEVICE_CPU), UnbatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc index 0ab6beabfc..2ad4711aab 100644 --- a/tensorflow/core/kernels/data/window_dataset.cc +++ b/tensorflow/core/kernels/data/window_dataset.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { +namespace data { namespace { class WindowDataset : public DatasetBase { @@ -107,4 +108,5 @@ Status NewWindowDataset(std::vector<std::vector<Tensor>> elements, return Status::OK(); } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/window_dataset.h b/tensorflow/core/kernels/data/window_dataset.h index 7bd31a0bc7..84cb3c7860 100644 --- a/tensorflow/core/kernels/data/window_dataset.h +++ b/tensorflow/core/kernels/data/window_dataset.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { +namespace data { // Creates a dataset representing an eagerly-collected window of elements. // @@ -43,6 +44,7 @@ Status NewWindowDataset(std::vector<std::vector<Tensor>> elements, std::vector<PartialTensorShape> output_shapes, DatasetBase** out_dataset); +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_ diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc index 41bf9d43fe..ac44623ce2 100644 --- a/tensorflow/core/kernels/data/window_dataset_op.cc +++ b/tensorflow/core/kernels/data/window_dataset_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/window_dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -33,22 +33,44 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { int64 window_size = 0; - OP_REQUIRES_OK( - ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size)); + OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "size", &window_size)); OP_REQUIRES( ctx, window_size > 0, errors::InvalidArgument("Window size must be greater than zero.")); - *output = new Dataset(ctx, window_size, input); + int64 window_shift = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<int64>(ctx, "shift", &window_shift)); + OP_REQUIRES( + ctx, window_shift > 0, + errors::InvalidArgument("Window shift must be greater than zero.")); + + int64 window_stride = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<int64>(ctx, "stride", &window_stride)); + OP_REQUIRES( + ctx, window_stride > 0, + errors::InvalidArgument("Window stride must be greater than zero.")); + + bool drop_remainder; + OP_REQUIRES_OK( + ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder)); + + 
*output = new Dataset(ctx, input, window_size, window_shift, window_stride, + drop_remainder); } private: class Dataset : public DatasetBase { public: - Dataset(OpKernelContext* ctx, int64 window_size, const DatasetBase* input) + Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 window_size, + int64 window_shift, int64 window_stride, bool drop_remainder) : DatasetBase(DatasetContext(ctx)), + input_(input), window_size_(window_size), - input_(input) { + window_shift_(window_shift), + window_stride_(window_stride), + drop_remainder_(drop_remainder) { input_->Ref(); } @@ -72,7 +94,8 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { } string DebugString() const override { - return strings::StrCat("WindowDatasetOp(", window_size_, ")::Dataset"); + return strings::StrCat("WindowDatasetOp(", window_size_, window_shift_, + window_stride_, drop_remainder_, ")::Dataset"); } protected: @@ -81,10 +104,19 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - Node* window_size = nullptr; - TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size)); + Node* window_size_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node)); + Node* window_shift_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node)); + Node* window_stride_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node)); + Node* drop_remainder_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node)); TF_RETURN_IF_ERROR( - b->AddDataset(this, {input_graph_node, window_size}, output)); + b->AddDataset(this, + {input_graph_node, window_size_node, window_shift_node, + window_stride_node, drop_remainder_node}, + output)); return Status::OK(); } @@ -101,37 +133,79 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { Status 
GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { - // Each row of `window_elements` is a tuple of tensors from the - // input iterator. + const int64 window_size = dataset()->window_size_; + const int64 window_shift = dataset()->window_shift_; + const int64 window_stride = dataset()->window_stride_; std::vector<std::vector<Tensor>> window_elements; + Status status = Status::OK(); { mutex_lock l(mu_); - if (!input_impl_) { + if (!input_impl_ && buffer_.empty()) { *end_of_sequence = true; return Status::OK(); } - window_elements.reserve(dataset()->window_size_); - *end_of_sequence = false; - for (int i = 0; i < dataset()->window_size_ && !*end_of_sequence; - ++i) { - std::vector<Tensor> window_element_tuple; - TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &window_element_tuple, - end_of_sequence)); - if (!*end_of_sequence) { - window_elements.emplace_back(std::move(window_element_tuple)); - } else { - input_impl_.reset(); + + // Add elements to the buffer. + size_t target_size = TargetBufferSize(window_size, window_stride); + if (input_impl_) { + *end_of_sequence = false; + for (size_t i = buffer_.size(); + i < target_size && !*end_of_sequence; ++i) { + std::vector<Tensor> element; + Status status = + input_impl_->GetNext(ctx, &element, end_of_sequence); + if (!*end_of_sequence) { + buffer_.emplace_back(std::move(element), status); + } else { + input_impl_.reset(); + } + } + } + + // If there are not enough elements and `drop_remainder` is set, we do + // not wish to return a smaller window. 
+ if (buffer_.empty() || + (dataset()->drop_remainder_ && buffer_.size() < target_size)) { + // Signal end of sequence explicitly: if the input was exhausted by an + // earlier call, nothing above has written *end_of_sequence. + *end_of_sequence = true; + return Status::OK(); + } + + int num_elements = 1 + (buffer_.size() - 1) / window_stride; + window_elements.reserve(num_elements); + for (size_t i = 0; i < num_elements; ++i) { + status.Update(buffer_[window_stride * i].status); + if (!status.ok()) { + break; + } + window_elements.emplace_back(buffer_[window_stride * i].result); + } + + // Shift the window, discarding elements if necessary. + int buffer_size = buffer_.size(); + if (window_shift >= buffer_size) { + for (size_t i = buffer_size; input_impl_ && i < window_shift; ++i) { + bool end_of_input = false; + std::vector<Tensor> element; + // Errors from discarded elements are deliberately ignored; + // end_of_input starts out false in case a failed GetNext() leaves it + // unset. + input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError(); + if (end_of_input) { + input_impl_.reset(); + } } + buffer_.clear(); + } else { + buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift); } } - if (window_elements.empty()) { - DCHECK(*end_of_sequence); - return Status::OK(); + if (!status.ok()) { + return status; } + // Construct output tensors. const size_t num_tuple_components = window_elements[0].size(); const int64 num_window_elements = window_elements.size(); + *end_of_sequence = false; for (size_t idx = 0; idx < num_tuple_components; ++idx) { DatasetBase* window_dataset; std::vector<std::vector<Tensor>> window_component_elements; @@ -154,7 +228,6 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(window_dataset, &out_tensors->back())); } - *end_of_sequence = false; return Status::OK(); } @@ -167,6 +240,20 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { } else { TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); } + // Save buffer. 
+ // Keys go through full_name() like CodeKey()/ErrorMessageKey(), so they + // cannot collide with keys written by other iterators. + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("buffer_size"), + buffer_.size())); + for (int64 i = 0; i < buffer_.size(); i++) { + TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("buffer[", i, "].size")), + buffer_[i].result.size())); + for (int64 j = 0; j < buffer_[i].result.size(); j++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("buffer[", i, "][", j, "]")), + buffer_[i].result[j])); + } + } return Status::OK(); } @@ -178,22 +265,92 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { } else { input_impl_.reset(); } + // Restore buffer, reading the same full_name()-prefixed keys that + // SaveInternal() writes. + int64 buffer_size; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("buffer_size"), &buffer_size)); + buffer_.resize(buffer_size); + for (int64 i = 0; i < buffer_size; i++) { + int64 vector_size; + TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status)); + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("buffer[", i, "].size")), &vector_size)); + buffer_[i].result.resize(vector_size); + for (int64 j = 0; j < vector_size; j++) { + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat("buffer[", i, "][", j, "]")), + &buffer_[i].result[j])); + } + } return Status::OK(); } private: + // One buffered input element together with the status stored alongside it. + struct InvocationResult { + InvocationResult() = default; + // 'result' is a named rvalue reference, i.e. an lvalue here; std::move is + // required to move the tensors in instead of copying them. + InvocationResult(std::vector<Tensor>&& result, const Status& status) + : result(std::move(result)), status(status) {} + + std::vector<Tensor> result; + Status status; + }; + + Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, + const Status& status) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + CodeKey(index), static_cast<int64>(status.code()))); + if (!status.ok()) { + TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index), + status.error_message())); + } + return Status::OK(); + } + + Status ReadStatusLocked(IteratorStateReader* reader, size_t index, + Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + 
int64 code_int; + TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); + error::Code code = static_cast<error::Code>(code_int); + + if (code != error::Code::OK) { + string error_message; + TF_RETURN_IF_ERROR( + reader->ReadScalar(ErrorMessageKey(index), &error_message)); + *status = Status(code, error_message); + } else { + *status = Status::OK(); + } + return Status::OK(); + } + + string CodeKey(size_t index) { + return full_name(strings::StrCat("buffer[", index, "].code")); + } + + string ErrorMessageKey(size_t index) { + return full_name(strings::StrCat("buffer[", index, "].error_message")); + } + + size_t TargetBufferSize(int64 window_size, int64 window_stride) { + return (window_size - 1) * window_stride + 1; + } + mutex mu_; + std::deque<InvocationResult> buffer_ GUARDED_BY(mu_); std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); }; - const int64 window_size_; const DatasetBase* const input_; + const int64 window_size_; + const int64 window_shift_; + const int64 window_stride_; + const bool drop_remainder_; }; }; REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU), WindowDatasetOp); - } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc index 1c49874a6a..3f76695bb1 100644 --- a/tensorflow/core/kernels/data/writer_ops.cc +++ b/tensorflow/core/kernels/data/writer_ops.cc @@ -22,7 +22,7 @@ limitations under the License. 
#include "tensorflow/core/platform/file_system.h" namespace tensorflow { - +namespace data { namespace { class ToTFRecordOp : public AsyncOpKernel { @@ -104,4 +104,5 @@ REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU), ToTFRecordOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc index e4306579ed..61a2078f46 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -175,5 +175,5 @@ class ZipDatasetOp : public DatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("ZipDataset").Device(DEVICE_CPU), ZipDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h index 33ed5522d0..d705e82b0d 100644 --- a/tensorflow/core/kernels/debug_ops.h +++ b/tensorflow/core/kernels/debug_ops.h @@ -255,7 +255,7 @@ class DebugNanCountOp : public BaseDebugOp { TensorShape shape({1}); OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor)); output_tensor->vec<int64>()(0) = nan_count; - PublishTensor(*output_tensor); + OP_REQUIRES_OK(context, PublishTensor(*output_tensor)); } }; @@ -380,7 +380,7 @@ class DebugNumericSummaryOp : public BaseDebugOp { bool mute = mute_if_healthy_ && nan_count == 0 && negative_inf_count == 0 && positive_inf_count == 0; if (!mute) { - PublishTensor(*output_tensor); + OP_REQUIRES_OK(context, PublishTensor(*output_tensor)); } } diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc index b4dcf0a74b..ae451be7e2 100644 --- a/tensorflow/core/kernels/decode_bmp_op.cc +++ 
b/tensorflow/core/kernels/decode_bmp_op.cc @@ -91,8 +91,10 @@ class DecodeBmpOp : public OpKernel { errors::InvalidArgument( "Number of channels must be 1, 3 or 4, was ", channels_)); - OP_REQUIRES(context, width > 0 && header_size >= 0, + OP_REQUIRES(context, width > 0, errors::InvalidArgument("Width must be positive")); + OP_REQUIRES(context, height != 0, + errors::InvalidArgument("Height must be nonzero")); OP_REQUIRES(context, header_size >= 0, errors::InvalidArgument("header size must be nonnegative")); @@ -108,8 +110,7 @@ class DecodeBmpOp : public OpKernel { const int32 abs_height = abs(height); // there may be padding bytes when the width is not a multiple of 4 bytes - // 8 * channels == bits per pixel - const int row_size = (8 * channels_ * width + 31) / 32 * 4; + const int row_size = (channels_ * width + 3) / 4 * 4; const int64 last_pixel_offset = static_cast<int64>(header_size) + (abs_height - 1) * row_size + diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc index 3eed847c16..6bfb5bd5bc 100644 --- a/tensorflow/core/kernels/decode_csv_op.cc +++ b/tensorflow/core/kernels/decode_csv_op.cc @@ -61,6 +61,9 @@ class DecodeCSVOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults)); for (int i = 0; i < record_defaults.size(); ++i) { + OP_REQUIRES(ctx, record_defaults[i].dims() <= 1, + errors::InvalidArgument( + "Each record default should be at most rank 1")); OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2, errors::InvalidArgument( "There should only be 1 default per field but field ", i, diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc index 2a25459194..76afd6f18c 100644 --- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc +++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc @@ -17,7 +17,7 @@ limitations under the License. 
#define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "external/cub_archive/cub/util_ptx.cuh" +#include "third_party/cub/util_ptx.cuh" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/kernels/depthwise_conv_op.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc index 862a97723f..e7882acc80 100644 --- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc +++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc @@ -35,10 +35,10 @@ limitations under the License. #define EIGEN_USE_GPU -#include "external/cub_archive/cub/device/device_radix_sort.cuh" -#include "external/cub_archive/cub/device/device_reduce.cuh" -#include "external/cub_archive/cub/iterator/constant_input_iterator.cuh" -#include "external/cub_archive/cub/thread/thread_operators.cuh" +#include "third_party/cub/device/device_radix_sort.cuh" +#include "third_party/cub/device/device_reduce.cuh" +#include "third_party/cub/iterator/constant_input_iterator.cuh" +#include "third_party/cub/thread/thread_operators.cuh" #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc index b01db91720..fb2a4cc8ef 100644 --- a/tensorflow/core/kernels/dynamic_stitch_op.cc +++ b/tensorflow/core/kernels/dynamic_stitch_op.cc @@ -247,8 +247,8 @@ class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> { data.shaped<T, 2>({indices_vec.dimension(0), slice_size}); if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) { - T* merged_base = &merged_flat(0, 0); - const T* data_base = &data_flat(0, 0); + T* merged_base = merged_flat.data(); + const T* data_base = data_flat.data(); for (int i = 0; i < indices_vec.size(); i++) { int32 
index = internal::SubtleMustCopy(indices_vec(i)); OP_REQUIRES( diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h index e13e548f86..8edf7d4a2c 100644 --- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h @@ -51,48 +51,55 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional< internal::traits<OutputBackward>::NumDimensions>, const TensorContractionOp< const array< - IndexPair<typename internal::traits<OutputBackward>::Index>, 2>, - const TensorReshapingOp< + IndexPair<typename internal::traits<OutputBackward>::Index>, 1>, + const Eigen::TensorForcedEvalOp<const TensorReshapingOp< const DSizes<typename internal::traits<OutputBackward>::Index, - 3>, - const TensorReverseOp<const array<bool, 5>, const Kernel> >, + 2>, + const TensorShufflingOp< + const array< + typename internal::traits<OutputBackward>::Index, 5>, + const TensorReverseOp<const Eigen::array<bool, 5>, + const Kernel>>>>, const TensorReshapingOp< const DSizes<typename internal::traits<OutputBackward>::Index, - 3>, + 2>, const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, - const OutputBackward> > > >, + const OutputBackward>>>>, TensorReshapingOp< const DSizes<typename internal::traits<OutputBackward>::Index, internal::traits<OutputBackward>::NumDimensions>, const TensorContractionOp< const array< - IndexPair<typename internal::traits<OutputBackward>::Index>, 2>, + IndexPair<typename internal::traits<OutputBackward>::Index>, 1>, const TensorReshapingOp< const DSizes<typename internal::traits<OutputBackward>::Index, - 3>, + 2>, const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, - const OutputBackward> >, - const TensorReshapingOp< + const OutputBackward>>, + const Eigen::TensorForcedEvalOp<const TensorReshapingOp< const DSizes<typename internal::traits<OutputBackward>::Index, - 3>, - const TensorReverseOp<const array<bool, 5>, 
- const Kernel> > > > >::type + 2>, + const TensorShufflingOp< + const array< + typename internal::traits<OutputBackward>::Index, 5>, + const TensorReverseOp<const Eigen::array<bool, 5>, + const Kernel>>>>>>>::type CuboidConvolutionBackwardInput( const Kernel& kernel, const OutputBackward& output_backward, typename internal::traits<OutputBackward>::Index inputPlanes, typename internal::traits<OutputBackward>::Index inputRows, typename internal::traits<OutputBackward>::Index inputCols, - const DenseIndex stridePlanes = 1, const DenseIndex strideRows = 1, - const DenseIndex strideCols = 1) { + const DenseIndex plane_stride = 1, const DenseIndex row_stride = 1, + const DenseIndex col_stride = 1) { typedef typename internal::traits<OutputBackward>::Index TensorIndex; const TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar, internal::traits<Kernel>::NumDimensions, - internal::traits<Kernel>::Layout, TensorIndex> > + internal::traits<Kernel>::Layout, TensorIndex>> kern(kernel); const TensorRef< const Tensor<typename internal::traits<OutputBackward>::Scalar, internal::traits<OutputBackward>::NumDimensions, - internal::traits<OutputBackward>::Layout, TensorIndex> > + internal::traits<OutputBackward>::Layout, TensorIndex>> out(output_backward); EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout == @@ -125,58 +132,45 @@ CuboidConvolutionBackwardInput( const TensorIndex outputCols = isColMajor ? out.dimensions()[3] : out.dimensions()[NumDims - 4]; - TensorIndex forward_pad_z, forward_pad_y, forward_pad_x; - const TensorIndex size_z = - Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes)); - const TensorIndex size_y = - Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows)); - const TensorIndex size_x = - Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols)); - - // Infer padding type. - if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) { - // SAME padding. 
- const TensorIndex dz = numext::maxi<TensorIndex>( - 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes); - const TensorIndex dy = numext::maxi<TensorIndex>( - 0, (size_y - 1) * strideRows + kernelRows - inputRows); - const TensorIndex dx = numext::maxi<TensorIndex>( - 0, (size_x - 1) * strideCols + kernelCols - inputCols); - - forward_pad_z = dz / 2; - forward_pad_y = dy / 2; - forward_pad_x = dx / 2; - } else { - // VALID padding. - forward_pad_z = 0; - forward_pad_y = 0; - forward_pad_x = 0; - } - const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z; - const TensorIndex padding_top = kernelRows - 1 - forward_pad_y; - const TensorIndex padding_left = kernelCols - 1 - forward_pad_x; - - const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 - - (outputPlanes - 1) * stridePlanes - 1 - - padding_ztop; - const TensorIndex padding_bottom = inputRows + kernelRows - 1 - - (outputRows - 1) * strideRows - 1 - - padding_top; - const TensorIndex padding_right = inputCols + kernelCols - 1 - - (outputCols - 1) * strideCols - 1 - - padding_left; - - eigen_assert(padding_ztop >= 0); - eigen_assert(padding_zbottom >= 0); + // TODO(ezhulenev): Add support for inflated strides. Without inflated strides + // effective kernel planes/rows/cols are always the same as the kernel itself + // (see eigen_spatial_convolutions for details). + const TensorIndex kernelPlanesEff = kernelPlanes; + const TensorIndex kernelRowsEff = kernelRows; + const TensorIndex kernelColsEff = kernelCols; + + // Computing the forward padding. 
+ const TensorIndex forward_pad_top_z = numext::maxi<Index>( + 0, + ((outputPlanes - 1) * plane_stride + kernelPlanesEff - inputPlanes) / 2); + const TensorIndex forward_pad_top = numext::maxi<Index>( + 0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2); + const TensorIndex forward_pad_left = numext::maxi<Index>( + 0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2); + + const TensorIndex padding_top_z = kernelPlanesEff - 1 - forward_pad_top_z; + const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top; + const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left; + + const TensorIndex padding_bottom_z = inputPlanes - + (outputPlanes - 1) * plane_stride - 2 - + padding_top_z + kernelPlanesEff; + const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride - + 2 - padding_top + kernelRowsEff; + const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride - + 2 - padding_left + kernelColsEff; + + eigen_assert(padding_top_z >= 0); eigen_assert(padding_top >= 0); eigen_assert(padding_left >= 0); + eigen_assert(padding_bottom_z >= 0); eigen_assert(padding_bottom >= 0); eigen_assert(padding_right >= 0); - // The kernel has dimensions filters X channels X patch_planes X patch_rows X - // patch_cols. + // The kernel has dimensions : + // filters x channels x patch_planes x patch_rows x patch_cols. // We need to reverse the kernel along the spatial dimensions. 
- array<bool, 5> kernel_reverse; + Eigen::array<bool, 5> kernel_reverse; if (isColMajor) { kernel_reverse[0] = false; kernel_reverse[1] = false; @@ -191,15 +185,35 @@ CuboidConvolutionBackwardInput( kernel_reverse[4] = false; } - DSizes<TensorIndex, 3> kernel_dims; + // Reorder the dimensions to: + // filters x patch_planes x patch_rows x patch_cols x channels + array<TensorIndex, 5> kernel_shuffle; if (isColMajor) { - kernel_dims[0] = kernelFilters; - kernel_dims[1] = kernelChannels; - kernel_dims[2] = kernelRows * kernelCols * kernelPlanes; + // From: filters x channels x planes x rows x cols + // To: filters x planes x rows x cols x channels + kernel_shuffle[0] = 0; + kernel_shuffle[1] = 2; + kernel_shuffle[2] = 3; + kernel_shuffle[3] = 4; + kernel_shuffle[4] = 1; } else { - kernel_dims[0] = kernelRows * kernelCols * kernelPlanes; + // From: cols x rows x planes x channels x filters + // To: channels x cols x rows x planes x filters + kernel_shuffle[0] = 3; + kernel_shuffle[1] = 0; + kernel_shuffle[2] = 1; + kernel_shuffle[3] = 2; + kernel_shuffle[4] = 4; + } + + // Collapse the dims + DSizes<TensorIndex, 2> kernel_dims; + if (isColMajor) { + kernel_dims[0] = kernelFilters * kernelPlanes * kernelRows * kernelCols; kernel_dims[1] = kernelChannels; - kernel_dims[2] = kernelFilters; + } else { + kernel_dims[1] = kernelFilters * kernelPlanes * kernelRows * kernelCols; + kernel_dims[0] = kernelChannels; } // The output_backward has dimensions out_depth X out_planes X out_rows X @@ -208,36 +222,32 @@ CuboidConvolutionBackwardInput( // dimensions: // out_depth X (patch_planes * patch_rows * patch_cols) X (input_planes * // input_rows * input_cols * OTHERS) - DSizes<TensorIndex, 3> pre_contract_dims; + DSizes<TensorIndex, 2> pre_contract_dims; if (isColMajor) { - pre_contract_dims[0] = kernelFilters; - pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[2] = inputRows * inputCols * inputPlanes; + pre_contract_dims[0] = + kernelFilters * 
kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[1] = inputPlanes * inputRows * inputCols; for (int i = 4; i < NumDims; ++i) { - pre_contract_dims[2] *= out.dimension(i); + pre_contract_dims[1] *= out.dimension(i); } } else { - pre_contract_dims[2] = kernelFilters; - pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[0] = inputRows * inputCols * inputPlanes; + pre_contract_dims[1] = + kernelFilters * kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[0] = inputPlanes * inputRows * inputCols; for (int i = 0; i < NumDims - 4; ++i) { pre_contract_dims[0] *= out.dimension(i); } } - // We will contract along dimensions (0, 2) in kernel and (0, 1) in - // output_backward, if this is col-major, and - // dimensions (0, 2) in kernel and (1, 2) in output_backward, if this - // row-major. - array<IndexPair<TensorIndex>, 2> contract_dims; + // We will contract along the collapsed dimension that contains the + // kernelFilters, kernelPlanes, kernelRows and kernelCols. 
+ array<IndexPair<TensorIndex>, 1> contract_dims; if (isColMajor) { // col-major: kernel.contract(output.patches) contract_dims[0] = IndexPair<TensorIndex>(0, 0); - contract_dims[1] = IndexPair<TensorIndex>(2, 1); } else { // row-major: output.patches.contract(kernel) - contract_dims[0] = IndexPair<TensorIndex>(1, 0); - contract_dims[1] = IndexPair<TensorIndex>(2, 2); + contract_dims[0] = IndexPair<TensorIndex>(1, 1); } // Post contraction, the dimensions of the input_backprop is @@ -261,40 +271,31 @@ CuboidConvolutionBackwardInput( } } - DSizes<TensorIndex, NumDims> strides; - for (int i = 0; i < NumDims; i++) { - strides[i] = 1; - } - if (isColMajor) { - strides[1] = stridePlanes; - strides[2] = strideRows; - strides[3] = strideCols; - } else { - strides[NumDims - 2] = stridePlanes; - strides[NumDims - 3] = strideRows; - strides[NumDims - 4] = strideCols; - } - return choose( Cond<internal::traits<OutputBackward>::Layout == ColMajor>(), kernel.reverse(kernel_reverse) + .shuffle(kernel_shuffle) .reshape(kernel_dims) + .eval() .contract(output_backward .extract_volume_patches( kernelPlanes, kernelRows, kernelCols, 1, 1, 1, - stridePlanes, strideRows, strideCols, padding_ztop, - padding_zbottom, padding_top, padding_bottom, + plane_stride, row_stride, col_stride, padding_top_z, + padding_bottom_z, padding_top, padding_bottom, padding_left, padding_right) .reshape(pre_contract_dims), contract_dims) .reshape(post_contract_dims), output_backward .extract_volume_patches(kernelPlanes, kernelRows, kernelCols, 1, 1, 1, - stridePlanes, strideRows, strideCols, - padding_ztop, padding_zbottom, padding_top, + plane_stride, row_stride, col_stride, + padding_top_z, padding_bottom_z, padding_top, padding_bottom, padding_left, padding_right) .reshape(pre_contract_dims) - .contract(kernel.reverse(kernel_reverse).reshape(kernel_dims), + .contract(kernel.reverse(kernel_reverse) + .shuffle(kernel_shuffle) + .reshape(kernel_dims) + .eval(), contract_dims) .reshape(post_contract_dims)); 
} @@ -322,48 +323,69 @@ CuboidConvolutionBackwardInput( */ template <typename OutputBackward, typename Input> EIGEN_ALWAYS_INLINE static const typename internal::conditional< - internal::traits<OutputBackward>::Layout == ColMajor, - const TensorShufflingOp< - const array<typename internal::traits<OutputBackward>::Index, 5>, - const TensorReverseOp< - const array<bool, 5>, - const TensorReshapingOp< - const DSizes<typename internal::traits<OutputBackward>::Index, - 5>, + internal::traits<Input>::Layout == ColMajor, + const TensorReverseOp< + const Eigen::array<typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const Eigen::TensorShufflingOp< + const Eigen::array<typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const Eigen::TensorReshapingOp< + const Eigen::DSizes<typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, const TensorContractionOp< const array< - IndexPair<typename internal::traits<Input>::Index>, 2>, - const TensorReshapingOp< + IndexPair<typename internal::traits<Input>::Index>, 1>, + const Eigen::TensorForcedEvalOp<const TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, - 3>, - const Input>, + 2>, + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const OutputBackward>>>, const TensorReshapingOp< - const DSizes< - typename internal::traits<OutputBackward>::Index, - 4>, + const DSizes<typename internal::traits<Input>::Index, + 2>, const TensorVolumePatchOp< Dynamic, Dynamic, Dynamic, - const OutputBackward> > > > > >, - const TensorShufflingOp< - const array<typename internal::traits<OutputBackward>::Index, 5>, - const TensorReverseOp< - const array<bool, 5>, - const TensorReshapingOp< - const DSizes<typename internal::traits<OutputBackward>::Index, - 5>, + const Eigen::TensorForcedEvalOp< + const Eigen::TensorShufflingOp< + const 
Eigen::array< + typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const Input>>>>>>>>, + const TensorReverseOp< + const Eigen::array<typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const Eigen::TensorShufflingOp< + const Eigen::array<typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const Eigen::TensorReshapingOp< + const Eigen::DSizes<typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, const TensorContractionOp< const array< - IndexPair<typename internal::traits<Input>::Index>, 2>, - const TensorReshapingOp< - const DSizes< - typename internal::traits<OutputBackward>::Index, - 4>, - const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, - const OutputBackward> >, + IndexPair<typename internal::traits<Input>::Index>, 1>, const TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, - 3>, - const Input> > > > > >::type + 2>, + const TensorVolumePatchOp< + Dynamic, Dynamic, Dynamic, + const Eigen::TensorForcedEvalOp< + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const Input>>>>, + const Eigen::TensorForcedEvalOp<const TensorReshapingOp< + const DSizes<typename internal::traits<Input>::Index, + 2>, + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const OutputBackward>>>>>>>>::type CuboidConvolutionBackwardKernel( const Input& input, const OutputBackward& output_backward, typename internal::traits<Input>::Index kernelPlanes, @@ -374,11 +396,11 @@ CuboidConvolutionBackwardKernel( typedef typename internal::traits<Input>::Index TensorIndex; TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, - internal::traits<Input>::Layout, TensorIndex> > + internal::traits<Input>::Layout, 
TensorIndex>> in(input); TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar, internal::traits<OutputBackward>::NumDimensions, - internal::traits<OutputBackward>::Layout, TensorIndex> > + internal::traits<OutputBackward>::Layout, TensorIndex>> out(output_backward); EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == @@ -392,6 +414,13 @@ CuboidConvolutionBackwardKernel( internal::traits<OutputBackward>::NumDimensions, YOU_MADE_A_PROGRAMMING_MISTAKE); + // We do not support higher dimensional backward convolutions, or convolutions + // without batch dimension. + // TODO(ezhulenev): Relax this constraint, and turn on tests without batch + // dimension in eigen_backward_cuboid_convolutions_test.cc. + EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 5, + YOU_MADE_A_PROGRAMMING_MISTAKE); + const TensorIndex inputPlanes = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2); const TensorIndex inputRows = @@ -406,213 +435,174 @@ CuboidConvolutionBackwardKernel( const TensorIndex outputCols = isColMajor ? out.dimension(3) : out.dimension(NumDims - 4); + // Number of filters. This is the same as the output depth. const TensorIndex kernelFilters = isColMajor ? out.dimension(0) : out.dimension(NumDims - 1); + // Number of channels. This is the same as the input depth. const TensorIndex kernelChannels = isColMajor ? in.dimension(0) : in.dimension(NumDims - 1); - TensorIndex forward_pad_z, forward_pad_y, forward_pad_x; - const TensorIndex size_z = - Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes)); - const TensorIndex size_y = - Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows)); - const TensorIndex size_x = - Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols)); - - // Infer padding type. - if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) { - // SAME padding. 
- const TensorIndex dz = numext::maxi<TensorIndex>( - 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes); - const TensorIndex dy = numext::maxi<TensorIndex>( - 0, (size_y - 1) * strideRows + kernelRows - inputRows); - const TensorIndex dx = numext::maxi<TensorIndex>( - 0, (size_x - 1) * strideCols + kernelCols - inputCols); - - forward_pad_z = dz / 2; - forward_pad_y = dy / 2; - forward_pad_x = dx / 2; + // Number of batches in the input tensor. + const TensorIndex batch = + isColMajor ? in.dimension(4) : in.dimension(NumDims - 5); + + // TODO(ezhulenev): Add support for inflated strides. Without inflated strides + // effective kernel planes/rows/cols are always the same as the kernel itself + // (see eigen_spatial_convolutions for details). + const TensorIndex kernelPlanesEff = kernelPlanes; + const TensorIndex kernelRowsEff = kernelRows; + const TensorIndex kernelColsEff = kernelCols; + + // Compute forward padding from input and output_backward dimensions. + const TensorIndex padPlanes = numext::maxi<Index>( + 0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes); + const TensorIndex padRows = numext::maxi<Index>( + 0, (outputRows - 1) * strideRows + kernelRowsEff - inputRows); + const TensorIndex padCols = numext::maxi<Index>( + 0, (outputCols - 1) * strideCols + kernelColsEff - inputCols); + + const TensorIndex padding_top_z = padPlanes / 2; + const TensorIndex padding_top = padRows / 2; + const TensorIndex padding_left = padCols / 2; + + // Compute paddings for output_backward before extracting patches. 
+ const auto expanded_out_planes = (outputPlanes - 1) * stridePlanes + 1; + const auto expanded_out_rows = (outputRows - 1) * strideRows + 1; + const auto expanded_out_cols = (outputCols - 1) * strideCols + 1; + const auto padded_out_planes = inputPlanes + kernelPlanes - 1; + const auto padded_out_rows = inputRows + kernelRows - 1; + const auto padded_out_cols = inputCols + kernelCols - 1; + const auto top_pad_planes = kernelPlanes - 1 - padding_top_z; + const auto top_pad_rows = kernelRows - 1 - padding_top; + const auto left_pad_cols = kernelCols - 1 - padding_left; + const auto bottom_pad_planes = + padded_out_planes - expanded_out_planes - top_pad_planes; + const auto bottom_pad_rows = + padded_out_rows - expanded_out_rows - top_pad_rows; + const auto right_pad_cols = + padded_out_cols - expanded_out_cols - left_pad_cols; + + // Reorder output_backward dimensions. + array<TensorIndex, 5> output_backward_shuffle; + if (isColMajor) { + // From: [out_depth, out_planes, out_rows, out_cols, batch] + // To: [batch, out_planes, out_rows, out_cols, out_depth] + output_backward_shuffle = {4, 1, 2, 3, 0}; } else { - // VALID padding. 
- forward_pad_z = 0; - forward_pad_y = 0; - forward_pad_x = 0; + // From: [batch, out_cols, out_rows, out_planes, out_depth] + // To: [out_depth, out_cols, out_rows, out_planes, batch] + output_backward_shuffle = {4, 1, 2, 3, 0}; } - const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z; - const TensorIndex padding_top = kernelRows - 1 - forward_pad_y; - const TensorIndex padding_left = kernelCols - 1 - forward_pad_x; - - const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 - - (outputPlanes - 1) * stridePlanes - 1 - - padding_ztop; - const TensorIndex padding_bottom = inputRows + kernelRows - 1 - - (outputRows - 1) * strideRows - 1 - - padding_top; - const TensorIndex padding_right = inputCols + kernelCols - 1 - - (outputCols - 1) * strideCols - 1 - - padding_left; - - eigen_assert(padding_ztop >= 0); - eigen_assert(padding_zbottom >= 0); - eigen_assert(padding_top >= 0); - eigen_assert(padding_left >= 0); - eigen_assert(padding_bottom >= 0); - eigen_assert(padding_right >= 0); - - // The output_backward has dimensions out_depth X out_plaens X out_rows X - // out_cols X OTHERS - // When we extract the image patches from output_backward (with input as the - // kernel), it will have dimensions - // (out_depth) X (input_planes * input_rows * input_cols) X (kernel_planes * - // kernel_rows * kernel_cols) X OTHERS - DSizes<TensorIndex, 4> pre_contract_dims; + // Reorder input dimensions. 
+ array<TensorIndex, 5> input_shuffle; if (isColMajor) { - pre_contract_dims[0] = kernelFilters; - pre_contract_dims[1] = inputRows * inputCols * inputPlanes; - pre_contract_dims[2] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[3] = 1; - for (int i = 4; i < NumDims; ++i) { - pre_contract_dims[3] *= out.dimension(i); - } + // From: [in_depth, in_planes, in_rows, in_cols, batch] + // To: [in_depth, batch, in_planes, in_rows, in_cols] + input_shuffle = {0, 4, 1, 2, 3}; } else { - pre_contract_dims[3] = kernelFilters; - pre_contract_dims[2] = inputRows * inputCols * inputPlanes; - pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[0] = 1; - for (int i = 0; i < NumDims - 4; ++i) { - pre_contract_dims[0] *= out.dimension(i); - } + // From: [batch, in_cols, in_rows, in_planes, in_depth] + // To: [in_cols, in_rows, in_planes, batch, in_depth] + input_shuffle = {1, 2, 3, 0, 4}; } - // The input has dimensions in_depth X (input_planes * input_rows * - // input_cols) X OTHERS - DSizes<TensorIndex, 3> input_dims; + // Input is playing the role of a "kernel" in this convolution. + DSizes<TensorIndex, 2> input_dims; if (isColMajor) { input_dims[0] = kernelChannels; - input_dims[1] = inputRows * inputCols * inputPlanes; - input_dims[2] = 1; - for (int i = 4; i < NumDims; ++i) { - input_dims[2] *= in.dimension(i); - } - eigen_assert(input_dims[2] == pre_contract_dims[3]); + input_dims[1] = batch * inputPlanes * inputRows * inputCols; } else { - input_dims[2] = kernelChannels; - input_dims[1] = inputRows * inputCols * inputPlanes; - input_dims[0] = 1; - for (int i = 0; i < NumDims - 4; ++i) { - input_dims[0] *= in.dimension(i); - } - eigen_assert(input_dims[0] == pre_contract_dims[0]); + input_dims[1] = kernelChannels; + input_dims[0] = inputCols * inputRows * inputPlanes * batch; } - // We will contract along dimensions (1, 2) in and (1, 3) in out, if - // this is col-major. 
- // For row-major, it's dimensions (0, 1) in and (0, 2) in out. - array<IndexPair<TensorIndex>, 2> contract_dims; + // Molds the output of the patch extraction result into a 2D tensor: + // - the first dimension (dims[0]): the patch values to be multiplied with the + // kernels + // - the second dimension (dims[1]): everything else + DSizes<TensorIndex, 2> pre_contract_dims; if (isColMajor) { - // col-major: in.contract(output.patches) - contract_dims[0] = IndexPair<TensorIndex>(1, 1); - contract_dims[1] = IndexPair<TensorIndex>(2, 3); + pre_contract_dims[0] = batch * inputPlanes * inputRows * inputCols; + pre_contract_dims[1] = + kernelPlanes * kernelRows * kernelCols * kernelFilters; } else { - // row-major: output.patches.contract(in) - contract_dims[0] = IndexPair<TensorIndex>(0, 0); - contract_dims[1] = IndexPair<TensorIndex>(2, 1); + pre_contract_dims[1] = inputCols * inputRows * inputPlanes * batch; + pre_contract_dims[0] = + kernelFilters * kernelCols * kernelRows * kernelPlanes; } - // After the contraction, the kernel will have dimension - // in_depth X out_depth X kernel_patches X kernel_rows X kernel_cols - // We will need to shuffle the first two dimensions and reverse the spatial - // dimensions. - // The end shape is: - // out_depth X in_shape X kernel_planes X kernel_rows X kernel_cols + // We will contract along the collapsed dimension that contains the + // batch, inputPlanes, inputRows and inputCols. + array<IndexPair<TensorIndex>, 1> contract_dims; + contract_dims[0] = IndexPair<TensorIndex>(1, 0); - // This is the shape of the kernel *before* the shuffling. - DSizes<TensorIndex, 5> kernel_dims; + // Dimensions after contraction. 
+ DSizes<TensorIndex, NumDims> post_contract_dims; if (isColMajor) { - kernel_dims[0] = kernelChannels; - kernel_dims[1] = kernelFilters; - kernel_dims[2] = kernelPlanes; - kernel_dims[3] = kernelRows; - kernel_dims[4] = kernelCols; + post_contract_dims[0] = kernelChannels; + post_contract_dims[1] = kernelPlanes; + post_contract_dims[2] = kernelRows; + post_contract_dims[3] = kernelCols; + post_contract_dims[4] = kernelFilters; } else { - kernel_dims[0] = kernelCols; - kernel_dims[1] = kernelRows; - kernel_dims[2] = kernelPlanes; - kernel_dims[3] = kernelFilters; - kernel_dims[4] = kernelChannels; + post_contract_dims[0] = kernelFilters; + post_contract_dims[1] = kernelCols; + post_contract_dims[2] = kernelRows; + post_contract_dims[3] = kernelPlanes; + post_contract_dims[4] = kernelChannels; } - // Flip filters and channels. + // Reorder output of contraction to valid filter shape. array<TensorIndex, 5> kernel_shuffle; if (isColMajor) { - kernel_shuffle[0] = 1; - kernel_shuffle[1] = 0; - kernel_shuffle[2] = 2; - kernel_shuffle[3] = 3; - kernel_shuffle[4] = 4; + // From: [in_depth, kernel_planes, kernel_rows, kernel_cols, out_depth] + // To: [out_depth, in_depth, kernel_planes, kernel_rows, kernel_cols] + kernel_shuffle = {4, 0, 1, 2, 3}; } else { - kernel_shuffle[0] = 0; - kernel_shuffle[1] = 1; - kernel_shuffle[2] = 2; - kernel_shuffle[3] = 4; - kernel_shuffle[4] = 3; + // From: [out_depth, kernel_cols, kernel_rows, kernel_planes, in_depth] + // To: [kernel_cols, kernel_rows, kernel_planes, in_depth, out_depth] + kernel_shuffle = {1, 2, 3, 4, 0}; } - // Reverse the spatial dimensions. - array<bool, 5> kernel_reverse; + // Reverse kernel backprop dimensions. 
+ array<TensorIndex, 5> kernel_reverse; if (isColMajor) { - kernel_reverse[0] = false; - kernel_reverse[1] = false; - kernel_reverse[2] = true; - kernel_reverse[3] = true; - kernel_reverse[4] = true; + kernel_reverse = {false, false, true, true, true}; } else { - kernel_reverse[0] = true; - kernel_reverse[1] = true; - kernel_reverse[2] = true; - kernel_reverse[3] = false; - kernel_reverse[4] = false; + kernel_reverse = {true, true, true, false, false}; } - DSizes<TensorIndex, NumDims> strides; - for (int i = 0; i < NumDims; i++) { - strides[i] = 1; - } - if (isColMajor) { - strides[1] = stridePlanes; - strides[2] = strideRows; - strides[3] = strideCols; - } else { - strides[NumDims - 2] = stridePlanes; - strides[NumDims - 3] = strideRows; - strides[NumDims - 4] = strideCols; - } - return choose( - Cond<internal::traits<Input>::Layout == ColMajor>(), - input.reshape(input_dims) - .contract(output_backward + // Create convolution input (aka source of patches) from output backward + // tensor by shuffling dimensions. + const auto the_input = + output_backward.shuffle(output_backward_shuffle).eval(); + + // Create convolution kernel (aka filter) from input by shuffling and + // reshaping. 
+ const auto the_kernel = + input.shuffle(input_shuffle).reshape(input_dims).eval(); + + return choose(Cond<internal::traits<Input>::Layout == ColMajor>(), + the_kernel.contract( + the_input .extract_volume_patches( inputPlanes, inputRows, inputCols, 1, 1, 1, stridePlanes, strideRows, strideCols, - - padding_ztop, padding_zbottom, padding_top, - padding_bottom, padding_left, padding_right) + top_pad_planes, bottom_pad_planes, top_pad_rows, + bottom_pad_rows, left_pad_cols, right_pad_cols) .reshape(pre_contract_dims), - contract_dims) - .reshape(kernel_dims) - .reverse(kernel_reverse) - .shuffle(kernel_shuffle), - output_backward - .extract_volume_patches(inputPlanes, inputRows, inputCols, 1, 1, 1, - stridePlanes, strideRows, strideCols, - padding_ztop, padding_zbottom, padding_top, - padding_bottom, padding_left, padding_right) - .reshape(pre_contract_dims) - .contract(input.reshape(input_dims), contract_dims) - .reshape(kernel_dims) - .reverse(kernel_reverse) - .shuffle(kernel_shuffle)); + contract_dims), + the_input + .extract_volume_patches( + inputPlanes, inputRows, inputCols, 1, 1, 1, + stridePlanes, strideRows, strideCols, top_pad_planes, + bottom_pad_planes, top_pad_rows, bottom_pad_rows, + left_pad_cols, right_pad_cols) + .reshape(pre_contract_dims) + .contract(the_kernel, contract_dims)) + .reshape(post_contract_dims) + .shuffle(kernel_shuffle) + .reverse(kernel_reverse); } } // end namespace Eigen diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h index cb0a76dac4..960920c55b 100644 --- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h @@ -189,14 +189,19 @@ SpatialConvolutionBackwardInput( } #endif - // Reorder the dimensions to filters X patch_rows X patch_cols X channels + // Reorder the dimensions to: + // filters x patch_rows x patch_cols x channels array<TensorIndex, 4> 
kernel_shuffle; if (isColMajor) { + // From: filters x channels x rows x cols + // To: filters x rows x cols x channels kernel_shuffle[0] = 0; kernel_shuffle[1] = 2; kernel_shuffle[2] = 3; kernel_shuffle[3] = 1; } else { + // From: cols x rows x channels x filters + // To: channels x cols x rows x filters kernel_shuffle[0] = 2; kernel_shuffle[1] = 0; kernel_shuffle[2] = 1; @@ -233,8 +238,8 @@ SpatialConvolutionBackwardInput( } } - // We will contract along the fused dimension that contains the kernelFilters, - // the kernelRows and the kernelCols. + // We will contract along the collapsed dimension that contains the + // kernelFilters, the kernelRows and the kernelCols. array<IndexPair<TensorIndex>, 1> contract_dims; if (isColMajor) { // col-major: kernel.contract(output.patches) @@ -327,23 +332,16 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional< const TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, 2>, const OutputBackward>, - const TensorShufflingOp< - const array<typename internal::traits<OutputBackward>::Index, - 2>, - const TensorReshapingOp< - const DSizes<typename internal::traits<Input>::Index, 2>, - const TensorImagePatchOp<Dynamic, Dynamic, - const Input> > > > >, + const TensorReshapingOp< + const DSizes<typename internal::traits<Input>::Index, 2>, + const TensorImagePatchOp<Dynamic, Dynamic, const Input> > > >, TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, 4>, const TensorContractionOp< const array<IndexPair<typename internal::traits<Input>::Index>, 1>, - const TensorShufflingOp< - const array<typename internal::traits<OutputBackward>::Index, - 2>, - const TensorReshapingOp< - const DSizes<typename internal::traits<Input>::Index, 2>, - const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >, + const TensorReshapingOp< + const DSizes<typename internal::traits<Input>::Index, 2>, + const TensorImagePatchOp<Dynamic, Dynamic, const Input> >, const TensorReshapingOp< const 
DSizes<typename internal::traits<Input>::Index, 2>, const OutputBackward> > > >::type @@ -451,12 +449,16 @@ SpatialConvolutionBackwardKernel( eigen_assert(output_dims[0] == pre_contract_dims[0]); } - array<TensorIndex, 2> shuffle_dims; - shuffle_dims[0] = 1; - shuffle_dims[1] = 0; - + // We will contract along the collapsed dimension that contains the + // outputCols, outputRows and OTHERS. array<IndexPair<TensorIndex>, 1> contract_dims; - contract_dims[0] = IndexPair<TensorIndex>(1, 0); + if (isColMajor) { + // col-major: output_backward.contract(input.patches) + contract_dims[0] = IndexPair<TensorIndex>(1, 1); + } else { + // row-major: input.patches.contract(output_backward) + contract_dims[0] = IndexPair<TensorIndex>(0, 0); + } // After the contraction, the kernel will have the desired shape // out_depth X in_shape X kernel_rows X kernel_cols @@ -482,8 +484,7 @@ SpatialConvolutionBackwardKernel( kernelRows, kernelCols, row_stride, col_stride, row_in_stride, col_in_stride, 1, 1, padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)) - .reshape(pre_contract_dims) - .shuffle(shuffle_dims), + .reshape(pre_contract_dims), contract_dims) .reshape(kernel_dims), input @@ -492,7 +493,6 @@ SpatialConvolutionBackwardKernel( padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)) .reshape(pre_contract_dims) - .shuffle(shuffle_dims) .contract(output_backward.reshape(output_dims), contract_dims) .reshape(kernel_dims)); } diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc index 2229ec9659..673ec1458b 100644 --- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc +++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc @@ -1248,11 +1248,14 @@ TEST(EigenBackwardSpatialConvolutionsTest, const int output_cols = input_cols - patch_cols + 1; const int output_planes = input_planes - patch_planes + 1; - Tensor<float, 
4> input(input_depth, input_planes, input_rows, input_cols); + // TODO(ezhulenev): Support backward kernel convolution without batch + // dimension. + Tensor<float, 5> input(input_depth, input_planes, input_rows, input_cols, + /*num_batches*/ 1); Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows, patch_cols); - Tensor<float, 4> output_backward(output_depth, output_planes, output_rows, - output_cols); + Tensor<float, 5> output_backward(output_depth, output_planes, output_rows, + output_cols, /*num_batches*/ 1); output_backward = output_backward.constant(11.0f) + output_backward.random(); input = input.constant(2.0f) + input.random(); @@ -1282,9 +1285,9 @@ TEST(EigenBackwardSpatialConvolutionsTest, if (output_i >= 0 && output_i < output_planes && output_j >= 0 && output_j < output_rows && output_k >= 0 && output_k < output_cols) { - expected += - input(id, i, j, k) * - output_backward(od, output_i, output_j, output_k); + expected += input(id, i, j, k, /*batch*/ 0) * + output_backward(od, output_i, output_j, + output_k, /*batch*/ 0); } } } @@ -1311,12 +1314,14 @@ TEST(EigenBackwardSpatialConvolutionsTest, const int output_cols = input_cols - patch_cols + 1; const int output_planes = input_planes - patch_planes + 1; - Tensor<float, 4, RowMajor> input(input_cols, input_rows, input_planes, - input_depth); + // TODO(ezhulenev): Support backward kernel convolution without batch + // dimension. 
+ Tensor<float, 5, RowMajor> input(/*num_batches*/ 1, input_cols, input_rows, + input_planes, input_depth); Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes, input_depth, output_depth); - Tensor<float, 4, RowMajor> output_backward(output_cols, output_rows, - output_planes, output_depth); + Tensor<float, 5, RowMajor> output_backward( + /*num_batches*/ 1, output_cols, output_rows, output_planes, output_depth); output_backward = output_backward.constant(11.0f) + output_backward.random(); input = input.constant(2.0f) + input.random(); @@ -1346,9 +1351,9 @@ TEST(EigenBackwardSpatialConvolutionsTest, if (output_i >= 0 && output_i < output_planes && output_j >= 0 && output_j < output_rows && output_k >= 0 && output_k < output_cols) { - expected += - input(k, j, i, id) * - output_backward(output_k, output_j, output_i, od); + expected += input(/*batch*/ 0, k, j, i, id) * + output_backward(/*batch*/ 0, output_k, output_j, + output_i, od); } } } diff --git a/tensorflow/core/kernels/eigen_benchmark.h b/tensorflow/core/kernels/eigen_benchmark.h index 46ad38fb77..87e41b89b3 100644 --- a/tensorflow/core/kernels/eigen_benchmark.h +++ b/tensorflow/core/kernels/eigen_benchmark.h @@ -76,6 +76,9 @@ class SpatialConvolutionBenchmarksSuite { void SpatialConvolutionBackwardInput(Dimensions input_dims, Dimensions filter_dims) { + using OutputBackward = TTypes<float, 4>::ConstTensor; + using InputBackward = TTypes<float, 4>::Tensor; + Dimensions output_dims(input_dims[0], // batch input_dims[1], // input_height input_dims[2], // input_width @@ -85,37 +88,37 @@ class SpatialConvolutionBenchmarksSuite { Eigen::Index input_rows = input_dims[1]; Eigen::Index input_cols = input_dims[2]; - Scalar* input_data = - static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); Scalar* filter_data = static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims))); - Scalar* output_data = + Scalar* output_backward_data = 
static_cast<Scalar*>(device_.allocate(BufferSize(output_dims))); + Scalar* input_backward_data = + static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); - device_.memset(input_data, 123, BufferSize(input_dims)); device_.memset(filter_data, 123, BufferSize(filter_dims)); + device_.memset(output_backward_data, 123, BufferSize(output_dims)); - Input input(input_data, input_dims); Filter filter(filter_data, filter_dims); - Output output(output_data, output_dims); + OutputBackward output_backward(output_backward_data, output_dims); + InputBackward input_backward(input_backward_data, input_dims); ::tensorflow::testing::StartTiming(); for (int i = 0; i < iters_; ++i) { - output.device(device_) = Eigen::SpatialConvolutionBackwardInput( - filter, input, input_rows, input_cols); - tensorflow::testing::DoNotOptimize(output); + input_backward.device(device_) = Eigen::SpatialConvolutionBackwardInput( + filter, output_backward, input_rows, input_cols); + tensorflow::testing::DoNotOptimize(input_backward); } ::tensorflow::testing::StopTiming(); - device_.deallocate(input_data); device_.deallocate(filter_data); - device_.deallocate(output_data); + device_.deallocate(output_backward_data); + device_.deallocate(input_backward_data); } void SpatialConvolutionBackwardKernel(Dimensions input_dims, Dimensions filter_dims) { using OutputBackward = TTypes<float, 4>::ConstTensor; - using FilterGrad = TTypes<float, 4>::Tensor; + using FilterBackward = TTypes<float, 4>::Tensor; Dimensions output_dims(input_dims[0], // batch input_dims[1], // input_height @@ -130,7 +133,7 @@ class SpatialConvolutionBenchmarksSuite { static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); Scalar* output_backward_data = static_cast<Scalar*>(device_.allocate(BufferSize(output_dims))); - Scalar* filter_data = + Scalar* filter_backward_data = static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims))); device_.memset(input_data, 123, BufferSize(input_dims)); @@ -138,19 +141,19 @@ class 
SpatialConvolutionBenchmarksSuite { Input input(input_data, input_dims); OutputBackward output_backward(output_backward_data, input_dims); - FilterGrad filter_grad(filter_data, filter_dims); + FilterBackward filter_backward(filter_backward_data, filter_dims); ::tensorflow::testing::StartTiming(); for (int i = 0; i < iters_; ++i) { - filter_grad.device(device_) = Eigen::SpatialConvolutionBackwardKernel( + filter_backward.device(device_) = Eigen::SpatialConvolutionBackwardKernel( input, output_backward, filter_rows, filter_cols); - tensorflow::testing::DoNotOptimize(filter_grad); + tensorflow::testing::DoNotOptimize(filter_backward); } ::tensorflow::testing::StopTiming(); device_.deallocate(input_data); device_.deallocate(output_backward_data); - device_.deallocate(filter_data); + device_.deallocate(filter_backward_data); } private: @@ -215,42 +218,45 @@ class CuboidConvolutionBenchmarksSuite { input_dims[3], // input_planes filter_dims[4]); // filter_count + using OutputBackward = TTypes<float, 5>::ConstTensor; + using InputBackward = TTypes<float, 5>::Tensor; + // Assuming that the convolution had SAME padding. 
Eigen::Index input_rows = input_dims[1]; Eigen::Index input_cols = input_dims[2]; Eigen::Index input_planes = input_dims[3]; - Scalar* input_data = - static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); Scalar* filter_data = static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims))); - Scalar* output_data = + Scalar* output_backward_data = static_cast<Scalar*>(device_.allocate(BufferSize(output_dims))); + Scalar* input_backward_data = + static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); - device_.memset(input_data, 123, BufferSize(input_dims)); device_.memset(filter_data, 123, BufferSize(filter_dims)); + device_.memset(output_backward_data, 123, BufferSize(output_dims)); - Input input(input_data, input_dims); Filter filter(filter_data, filter_dims); - Output output(output_data, output_dims); + OutputBackward output_backward(output_backward_data, output_dims); + InputBackward input_backward(input_backward_data, input_dims); ::tensorflow::testing::StartTiming(); for (int i = 0; i < iters_; ++i) { - output.device(device_) = Eigen::CuboidConvolutionBackwardInput( - filter, input, input_planes, input_rows, input_cols); - tensorflow::testing::DoNotOptimize(output); + input_backward.device(device_) = Eigen::CuboidConvolutionBackwardInput( + filter, output_backward, input_planes, input_rows, input_cols); + tensorflow::testing::DoNotOptimize(input_backward); } ::tensorflow::testing::StopTiming(); - device_.deallocate(input_data); device_.deallocate(filter_data); - device_.deallocate(output_data); + device_.deallocate(output_backward_data); + device_.deallocate(input_backward_data); } void CuboidConvolutionBackwardKernel(Dimensions input_dims, Dimensions filter_dims) { using OutputBackward = TTypes<float, 5>::ConstTensor; - using FilterGrad = TTypes<float, 5>::Tensor; + using FilterBackward = TTypes<float, 5>::Tensor; Dimensions output_dims(input_dims[0], // batch input_dims[1], // input_height @@ -267,7 +273,7 @@ class 
CuboidConvolutionBenchmarksSuite { static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); Scalar* output_backward_data = static_cast<Scalar*>(device_.allocate(BufferSize(output_dims))); - Scalar* filter_data = + Scalar* filter_backward_data = static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims))); device_.memset(input_data, 123, BufferSize(input_dims)); @@ -275,19 +281,19 @@ class CuboidConvolutionBenchmarksSuite { Input input(input_data, input_dims); OutputBackward output_backward(output_backward_data, output_dims); - FilterGrad filter_grad(filter_data, filter_dims); + FilterBackward filter_backward(filter_backward_data, filter_dims); ::tensorflow::testing::StartTiming(); for (int i = 0; i < iters_; ++i) { - filter_grad.device(device_) = Eigen::CuboidConvolutionBackwardKernel( + filter_backward.device(device_) = Eigen::CuboidConvolutionBackwardKernel( input, output_backward, filter_planes, filter_rows, filter_cols); - tensorflow::testing::DoNotOptimize(filter_grad); + tensorflow::testing::DoNotOptimize(filter_backward); } ::tensorflow::testing::StopTiming(); device_.deallocate(input_data); device_.deallocate(output_backward_data); - device_.deallocate(filter_data); + device_.deallocate(filter_backward_data); } private: diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc index 2a8308ef9a..ec949ddc84 100644 --- a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc +++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc @@ -48,8 +48,10 @@ void SpatialConvolution(int iters, int num_threads, benchmark.SpatialConvolution(input_dims, filter_dims); - auto output_size = input_dims.TotalSize(); - auto flops = output_size * (input_depth * filter_height * filter_width); + auto num_computed_elements = + (input_dims.TotalSize() / input_depth) * filter_count; + auto flops = + num_computed_elements * (input_depth * filter_height * filter_width); ::tensorflow::testing::ItemsProcessed(flops * 
iters); } @@ -75,8 +77,9 @@ void SpatialConvolutionBackwardInput(int iters, int num_threads, benchmark.SpatialConvolutionBackwardInput(input_dims, filter_dims); - auto output_size = input_dims.TotalSize(); - auto flops = output_size * (input_depth * filter_height * filter_width); + auto num_computed_elements = input_dims.TotalSize(); + auto flops = + num_computed_elements * (input_depth * filter_height * filter_width); ::tensorflow::testing::ItemsProcessed(flops * iters); } @@ -102,8 +105,9 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads, benchmark.SpatialConvolutionBackwardKernel(input_dims, filter_dims); - auto filter_size = filter_dims.TotalSize(); - auto flops = filter_size * (input_batches * input_height * input_width); + auto num_computed_elements = filter_dims.TotalSize(); + auto flops = + num_computed_elements * (input_batches * input_height * input_width); ::tensorflow::testing::ItemsProcessed(flops * iters); } @@ -123,6 +127,7 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads, #define BM_SpatialConvolution(NT, N, H, W, C, FC, FH, FW, LABEL) \ static void BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, \ FW)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ SpatialConvolution(iters, NT, N, H, W, C, FC, FH, FW); \ } \ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, FW)) @@ -130,6 +135,7 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads, #define BM_SpatialConvolutionBwdInput(NT, N, H, W, C, FC, FH, FW, LABEL) \ static void BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, \ FH, FW)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ SpatialConvolutionBackwardInput(iters, NT, N, H, W, C, FC, FH, FW); \ } \ BENCHMARK( \ @@ -138,6 +144,7 @@ void SpatialConvolutionBackwardKernel(int iters, int num_threads, #define BM_SpatialConvolutionBwdKernel(NT, N, H, W, C, FC, FH, FW, LABEL) \ static void BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, 
H, W, C, FC, \ FH, FW)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ SpatialConvolutionBackwardKernel(iters, NT, N, H, W, C, FC, FH, FW); \ } \ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \ @@ -263,8 +270,9 @@ void CuboidConvolution(int iters, int num_threads, benchmark.CuboidConvolution(input_dims, filter_dims); - auto output_size = input_dims.TotalSize(); - auto flops = output_size * + auto num_computed_elements = + (input_dims.TotalSize() / input_depth) * filter_count; + auto flops = num_computed_elements * (input_depth * filter_height * filter_width * filter_planes); ::tensorflow::testing::ItemsProcessed(flops * iters); } @@ -292,8 +300,8 @@ void CuboidConvolutionBackwardInput(int iters, int num_threads, benchmark.CuboidConvolutionBackwardInput(input_dims, filter_dims); - auto output_size = input_dims.TotalSize(); - auto flops = output_size * + auto num_computed_elements = input_dims.TotalSize(); + auto flops = num_computed_elements * (input_depth * filter_height * filter_width * filter_planes); ::tensorflow::testing::ItemsProcessed(flops * iters); } @@ -321,9 +329,9 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, benchmark.CuboidConvolutionBackwardKernel(input_dims, filter_dims); - auto filter_size = filter_dims.TotalSize(); - auto flops = - filter_size * (input_batches * input_height * input_width * input_planes); + auto num_computed_elements = filter_dims.TotalSize(); + auto flops = num_computed_elements * + (input_batches * input_height * input_width * input_planes); ::tensorflow::testing::ItemsProcessed(flops * iters); } @@ -348,6 +356,7 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, #define BM_CuboidConvolution(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \ static void BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, \ FP)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ CuboidConvolution(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \ } \ BENCHMARK( \ 
@@ -356,6 +365,7 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, #define BM_CuboidConvolutionBwdInput(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \ static void BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \ FH, FW, FP)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ CuboidConvolutionBackwardInput(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \ } \ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \ @@ -365,6 +375,7 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, LABEL) \ static void BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, \ FC, FH, FW, FP)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ CuboidConvolutionBackwardKernel(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \ } \ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, FC, \ @@ -395,8 +406,17 @@ void CuboidConvolutionBackwardKernel(int iters, int num_threads, BM_CuboidConvolutions(8, // batch size 25, 25, 25, 4, // input: height, width, panes, depth 16, 5, 5, 5, // filter: count, height, width, panes - "conv3d"); - -BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d"); - -BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d"); + "conv3d_depth4"); +BM_CuboidConvolutions(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8"); +BM_CuboidConvolutions(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1"); +BM_CuboidConvolutions(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2"); + +BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4"); +BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8"); +BM_CuboidConvolutionsBwdInput(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1"); +BM_CuboidConvolutionsBwdInput(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2"); + +BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4"); +BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8"); 
+BM_CuboidConvolutionsBwdKernel(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1"); +BM_CuboidConvolutionsBwdKernel(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2"); diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h index 62e9f9123d..6a9a2accd8 100644 --- a/tensorflow/core/kernels/eigen_cuboid_convolution.h +++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h @@ -21,6 +21,1412 @@ limitations under the License. namespace Eigen { +namespace internal { + +// WARNING: Most of the code here implicitly assumes that the matrix is in +// ColMajor layout. This is guaranteed by the tensor contraction (see +// TensorContraction.h). +// +// Inside Eigen a tensor contraction is represented by a matrix multiplication. +// We don't want to actually extract volume patches and reshape the result into +// a matrix (this involves allocating huge extra memory), so the patch +// extraction and reshape operations are implicit. +// +// TensorContractionInputMapper takes a matrix index and returns the coefficient +// (or the packet) of the "virtual tensor", that would be at that index if we +// were to actually reshape the result of patch extraction. +// +// TensorContractionSubMapper provides a similar view into the "virtual matrix" +// at the given vertical and horizontal offsets. +// +// "Virtual matrix" dimensions: +// *0: kernelChannels * kernelPlanes * kernelRows * kernelCols +// 1: out_planes * out_height * out_width * OTHERS (e.g batches, etc...) 
+// +// *) extracted patches are continuous in memory (innermost dimension assuming +// col major layout) +// +// With this dimensions: +// row - offset within a single patch (in code: patchId) +// col - index of the extracted patch (in code: patchIndex) +// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions) +// +template <typename NewDimension, DenseIndex Planes, DenseIndex Rows, + DenseIndex Cols, typename ArgType, typename Device, typename Scalar_, + typename Index, typename nocontract_t, typename contract_t, int Side, + int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, + int Alignment> +class TensorContractionInputMapper< + Scalar_, Index, Side, + TensorEvaluator<const TensorReshapingOp<NewDimension, + const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> { + public: + typedef Scalar_ Scalar; + typedef TensorContractionInputMapper< + Scalar, Index, Side, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + Self; + typedef TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + SubMapper; + typedef SubMapper VectorMapper; + typedef SubMapper LinearMapper; + typedef typename packet_traits<Scalar>::type Packet; + + EIGEN_DEVICE_FUNC + TensorContractionInputMapper( + const TensorEvaluator< + const TensorReshapingOp< + NewDimension, + const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >, + Device>& tensor, + const nocontract_t&, const nocontract_t&, const contract_t&, + const contract_t&) + : 
m_impl(tensor.impl().impl()) { + if (internal::traits<ArgType>::Layout == ColMajor) { + m_patch_depth = tensor.impl().dimensions()[0]; + m_patch_planes = tensor.impl().dimensions()[1]; + m_patch_rows = tensor.impl().dimensions()[2]; + m_patch_cols = tensor.impl().dimensions()[3]; + m_num_patches = tensor.impl().dimensions()[4]; + } else { + const int NumDims = tensor.impl().dimensions().size(); + m_patch_depth = tensor.impl().dimensions()[NumDims - 1]; + m_patch_planes = tensor.impl().dimensions()[NumDims - 2]; + m_patch_rows = tensor.impl().dimensions()[NumDims - 3]; + m_patch_cols = tensor.impl().dimensions()[NumDims - 4]; + m_num_patches = tensor.impl().dimensions()[NumDims - 5]; + } + + // Strides for navigating through the single patch. + m_patch_plane_stride = m_patch_depth; + m_patch_row_stride = m_patch_planes * m_patch_plane_stride; + m_patch_col_stride = m_patch_rows * m_patch_row_stride; + + // Strides for the output tensor. + // IMPORTANT: These strides are used to locate an element in a patch at a + // depth zero (channel), which is not quite the same as "traditional" + // stride. 
+ m_rowStride = m_patch_planes; + m_colStride = m_patch_rows * m_rowStride; + m_patchStride = m_colStride * m_patch_cols * m_patch_depth; + m_otherStride = m_patchStride * m_num_patches; + + m_outputPlanes = tensor.impl().outputPlanes(); + m_outputRows = tensor.impl().outputRows(); + m_outputCols = tensor.impl().outputCols(); + + m_outputPlanesRows = m_outputPlanes * m_outputRows; + + m_plane_strides = tensor.impl().userPlaneStride(); + m_row_strides = tensor.impl().userRowStride(); + m_col_strides = tensor.impl().userColStride(); + + m_in_plane_strides = tensor.impl().userInPlaneStride(); + m_in_row_strides = tensor.impl().userInRowStride(); + m_in_col_strides = tensor.impl().userInColStride(); + + m_patch_plane_inflate_strides = tensor.impl().planeInflateStride(); + m_patch_row_inflate_strides = tensor.impl().rowInflateStride(); + m_patch_col_inflate_strides = tensor.impl().colInflateStride(); + + if (internal::traits<ArgType>::Layout == ColMajor) { + m_inputDepth = tensor.impl().impl().dimensions()[0]; + m_inputPlanes = tensor.impl().impl().dimensions()[1]; + m_inputRows = tensor.impl().impl().dimensions()[2]; + m_inputCols = tensor.impl().impl().dimensions()[3]; + } else { + const int NumDims = tensor.impl().impl().dimensions().size(); + m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1]; + m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2]; + m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3]; + m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4]; + } + + // Strides for navigating through the input tensor. 
+ m_planeInputStride = m_inputDepth; + m_rowInputStride = m_inputDepth * m_inputPlanes; + m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes; + m_patchInputStride = + m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes; + + m_planePaddingTop = tensor.impl().planePaddingTop(); + m_rowPaddingTop = tensor.impl().rowPaddingTop(); + m_colPaddingLeft = tensor.impl().colPaddingLeft(); + + m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches); + + m_fastPatchPlaneStride = + internal::TensorIntDivisor<Index>(m_patch_plane_stride); + m_fastPatchRowStride = + internal::TensorIntDivisor<Index>(m_patch_row_stride); + m_fastPatchColStride = + internal::TensorIntDivisor<Index>(m_patch_col_stride); + + m_fastInputPlaneStride = + internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides); + m_fastInputRowStride = + internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides); + m_fastInputColStride = + internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides); + + m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride); + m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride); + + m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth); + m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows); + m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes); + m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows); + m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols); + + m_fastOutputPlanesRows = + internal::TensorIntDivisor<Index>(m_outputPlanesRows); + } + + EIGEN_DEVICE_FUNC + TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper) + : m_impl(base_mapper.m_impl) { + m_patch_depth = base_mapper.m_patch_depth; + m_patch_planes = base_mapper.m_patch_planes; + m_patch_rows = base_mapper.m_patch_rows; + m_patch_cols = base_mapper.m_patch_cols; + m_num_patches = base_mapper.m_num_patches; + + m_patch_plane_stride = 
base_mapper.m_patch_plane_stride; + m_patch_row_stride = base_mapper.m_patch_row_stride; + m_patch_col_stride = base_mapper.m_patch_col_stride; + + m_rowStride = base_mapper.m_rowStride; + m_colStride = base_mapper.m_colStride; + m_patchStride = base_mapper.m_patchStride; + m_otherStride = base_mapper.m_otherStride; + + m_planeInputStride = base_mapper.m_planeInputStride; + m_rowInputStride = base_mapper.m_rowInputStride; + m_colInputStride = base_mapper.m_colInputStride; + m_patchInputStride = base_mapper.m_patchInputStride; + m_otherInputStride = base_mapper.m_otherInputStride; + + m_inputDepth = base_mapper.m_inputDepth; + m_inputPlanes = base_mapper.m_inputPlanes; + m_inputRows = base_mapper.m_inputRows; + m_inputCols = base_mapper.m_inputCols; + + m_outputPlanes = base_mapper.m_outputPlanes; + m_outputRows = base_mapper.m_outputRows; + m_outputCols = base_mapper.m_outputCols; + + m_plane_strides = base_mapper.m_plane_strides; + m_row_strides = base_mapper.m_row_strides; + m_col_strides = base_mapper.m_col_strides; + + m_in_plane_strides = base_mapper.m_in_plane_strides; + m_in_row_strides = base_mapper.m_in_row_strides; + m_in_col_strides = base_mapper.m_in_col_strides; + + m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides; + m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides; + m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides; + + m_planePaddingTop = base_mapper.m_planePaddingTop; + m_rowPaddingTop = base_mapper.m_rowPaddingTop; + m_colPaddingLeft = base_mapper.m_colPaddingLeft; + + m_outputPlanesRows = base_mapper.m_outputPlanesRows; + + m_fastNumPatches = base_mapper.m_fastNumPatches; + m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride; + m_fastPatchRowStride = base_mapper.m_fastPatchRowStride; + m_fastPatchColStride = base_mapper.m_fastPatchColStride; + m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride; + m_fastInputRowStride = base_mapper.m_fastInputRowStride; + 
m_fastInputColStride = base_mapper.m_fastInputColStride; + m_fastRowStride = base_mapper.m_fastRowStride; + m_fastColStride = base_mapper.m_fastColStride; + m_fastOutputPlanes = base_mapper.m_fastOutputPlanes; + m_fastOutputRows = base_mapper.m_fastOutputRows; + m_fastOutputCols = base_mapper.m_fastOutputCols; + m_fastDimZero = base_mapper.m_fastDimZero; + m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows; + } + + // If true, turns off some optimizations for loading packets since the image + // patches are "non-standard" such as there are non-trivial strides or + // inflations in the input. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { + return m_in_plane_strides != 1 || m_in_row_strides != 1 || + m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 || + m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1; + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const { + return SubMapper(*this, i, j); + } + + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const { + return LinearMapper(*this, i, j); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex); + return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + // Load the coefficient at the patchIndex location instead of the usual + // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the + // gpu code. 
+ EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); + return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex); + return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + // Load the packet at the patchIndex location instead of the usual m_rowIndex, + // m_colIndex, m_otherIndex. This is currently only used by the gpu code. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); + return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const { + return m_impl; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, + const Index baseIndex) const { + const Index inputIndex = depth + baseIndex; + return m_impl.template packet<Unaligned>(inputIndex); + } + + private: + friend class TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, 
packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment>; + + // Load coefficient from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + // Find the offset of the element wrt the location of the first element. + const Index patchOffset = patchId / m_fastDimZero; + + const Index colOffset = patchOffset / m_fastColStride; + const Index inputCol = colIndex + colOffset * m_in_col_strides; + const Index origInputCol = + (m_patch_col_inflate_strides == 1) + ? inputCol + : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0); + + const Index rowOffset = + (patchOffset - colOffset * m_colStride) / m_fastRowStride; + const Index inputRow = rowIndex + rowOffset * m_in_row_strides; + const Index origInputRow = + (m_patch_row_inflate_strides == 1) + ? inputRow + : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0); + + const Index planeOffset = + patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; + const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides; + const Index origInputPlane = + (m_patch_plane_inflate_strides == 1) + ? inputPlane + : ((inputPlane >= 0) ? 
(inputPlane / m_fastInputPlaneStride) : 0); + + if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 || + origInputCol >= m_inputCols || origInputRow >= m_inputRows || + origInputPlane >= m_inputPlanes || + (inputCol != origInputCol * m_patch_col_inflate_strides) || + (inputRow != origInputRow * m_patch_row_inflate_strides) || + (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) { + return Scalar(0); + } + + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + origInputPlane * m_planeInputStride + + origInputRow * m_rowInputStride + + origInputCol * m_colInputStride + otherIndex; + + return m_impl.coeff(inputIndex); + } + + // This is the same as loadCoeff(...), but optimized for all `inflate_strides` + // and `in_strides` equal to 1 (template specialization without templates). + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + eigen_assert(!nonStandardPatches()); + + // Find the offset of the element wrt the location of the first element. 
+ const Index patchOffset = patchId / m_fastDimZero; + + const Index colOffset = patchOffset / m_fastColStride; + const Index rowOffset = + (patchOffset - colOffset * m_colStride) / m_fastRowStride; + const Index planeOffset = + patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; + + const Index inputCol = colIndex + colOffset; + const Index inputRow = rowIndex + rowOffset; + const Index inputPlane = planeIndex + planeOffset; + + if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || + inputRow >= m_inputRows || inputPlane < 0 || + inputPlane >= m_inputPlanes) { + return Scalar(0); + } + + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + inputPlane * m_planeInputStride + + inputRow * m_rowInputStride + + inputCol * m_colInputStride + otherIndex; + + return m_impl.coeff(inputIndex); + } + + // Load packet from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. 
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex,
                                        Index rowIndex, Index colIndex,
                                        Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;

    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId <
                 patchDepth() * patchPlanes() * patchRows() * patchCols());

    // Non-standard patches (input strides or inflation != 1) are assembled
    // one coefficient at a time; everything else takes the standard path.
    if (nonStandardPatches()) {
      return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
                                    otherIndex);
    }
    return loadPacketStandard(patchId, planeIndex, rowIndex, colIndex,
                              otherIndex);
  }

  // Packet load for standard patches (asserted via !nonStandardPatches()).
  // When the patch depth is a multiple of the packet size this forwards to
  // loadPacketFast(). Otherwise it examines the first and the last element of
  // the packet: a packet that is entirely in padding yields zeros, a packet
  // that is entirely inside the input (same column and row, planes in range)
  // is loaded directly, and anything else falls back to the per-coefficient
  // packetWithPossibleZero().
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index planeIndex,
                                                Index rowIndex, Index colIndex,
                                                Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId <
                 patchDepth() * patchPlanes() * patchRows() * patchCols());
    eigen_assert(!nonStandardPatches());

    if ((patchDepth() % packetSize) == 0) {
      return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
                            otherIndex);
    } else {
      // Offsets and input calculation here are identical to
      // loadCoeffStandard(...), but repeated twice.

      // Index [0] describes the first element of the packet, index [1] the
      // last one (patchId + packetSize - 1).
      const Index patchOffsets[2] = {
          patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};

      const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
                                   patchOffsets[1] / m_fastColStride};
      eigen_assert(colOffsets[0] <= colOffsets[1]);

      const Index inputCols[2] = {colIndex + colOffsets[0],
                                  colIndex + colOffsets[1]};
      // Whole packet is to the right of the input or to the left of it.
      if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
        return internal::pset1<Packet>(Scalar(0));
      }

      if (inputCols[0] == inputCols[1]) {
        const Index rowOffsets[2] = {
            (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
            (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
        eigen_assert(rowOffsets[0] <= rowOffsets[1]);
        const Index inputRows[2] = {rowIndex + rowOffsets[0],
                                    rowIndex + rowOffsets[1]};

        // Whole packet is below or above the input rows.
        if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
          return internal::pset1<Packet>(Scalar(0));
        }

        if (inputRows[0] == inputRows[1]) {
          const Index planeOffsets[2] = {
              patchOffsets[0] - colOffsets[0] * m_colStride -
                  rowOffsets[0] * m_rowStride,
              patchOffsets[1] - colOffsets[1] * m_colStride -
                  rowOffsets[1] * m_rowStride};
          eigen_assert(planeOffsets[0] <= planeOffsets[1]);
          const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
                                        planeIndex + planeOffsets[1]};

          // Whole packet is outside the input planes.
          if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
            return internal::pset1<Packet>(Scalar(0));
          }

          // First and last element are both inside the input: load the
          // packet directly starting at the first element.
          if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
            const Index depth = patchId - patchOffsets[0] * patchDepth();
            const Index inputIndex =
                depth + inputPlanes[0] * m_planeInputStride +
                inputRows[0] * m_rowInputStride +
                inputCols[0] * m_colInputStride + otherIndex;
            return m_impl.template packet<Unaligned>(inputIndex);
          }
        }
      }
    }

    // The packet straddles a column/row boundary or is partially padded.
    return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
                                  otherIndex);
  }

  // Packet load for standard patches whose depth is a multiple of the packet
  // size (both asserted). The assertion on (patchId + packetSize - 1) checks
  // that the first and last element of the packet share one patch offset, so
  // a single bounds check covers the whole packet.
  EIGEN_DEVICE_FUNC
  EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex,
                                            Index rowIndex, Index colIndex,
                                            Index otherIndex) const {
    const Index packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
    eigen_assert(patchId <
                 patchDepth() * patchPlanes() * patchRows() * patchCols());

    eigen_assert(!nonStandardPatches());
    eigen_assert((patchDepth() % packetSize) == 0);

    // Find the offset of the element wrt the location of the first element.
    const Index patchOffset = patchId / m_fastDimZero;
    eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);

    const Index colOffset = patchOffset / m_fastColStride;
    const Index rowOffset =
        (patchOffset - colOffset * m_colStride) / m_fastRowStride;
    const Index planeOffset =
        patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;

    const Index inputCol = colIndex + colOffset;
    const Index inputRow = rowIndex + rowOffset;
    const Index inputPlane = planeIndex + planeOffset;

    // Outside the input: the whole packet is zeros.
    if (inputCol < 0 || inputRow < 0 || inputPlane < 0 ||
        inputCol >= m_inputCols || inputRow >= m_inputRows ||
        inputPlane >= m_inputPlanes) {
      return internal::pset1<Packet>(Scalar(0));
    }

    const Index depth = patchId - patchOffset * patchDepth();
    const Index inputIndex = depth + inputPlane * m_planeInputStride +
                             inputRow * m_rowInputStride +
                             inputCol * m_colInputStride + otherIndex;
    return m_impl.template packet<Unaligned>(inputIndex);
  }

  // Slow path: builds the packet one coefficient at a time through
  // loadCoeff(), which returns zero for every element it rejects.
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
  packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex,
                         Index colIndex, Index otherIndex) const {
    const int packetSize = internal::unpacket_traits<Packet>::size;
    EIGEN_ALIGN_MAX
    typename internal::remove_const<Scalar>::type values[packetSize];
    for (int i = 0; i < packetSize; ++i) {
      values[i] =
          loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex);
    }
    Packet rslt = internal::pload<Packet>(values);
    return rslt;
  }

  // Precompute the indices (plane, row, col, other) of the first element of
  // the given patch index, within the output tensor of the TensorVolumePatchOp.
  // On return, plane/row/col indices are already scaled by the user strides
  // and shifted by the padding (so they may be negative), and otherIndex is
  // already multiplied by m_patchInputStride.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
      Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex,
      Index& otherIndex) const {
    const size_t NumInputDims = array_size<
        typename TensorEvaluator<ArgType, Device>::Dimensions>::value;

    // Check if patchIndex might contain batch and other dimensions.
    otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches;

    // Compute index of the patch within the batch (and other dimensions).
    const Index patch3DIndex = (NumInputDims == 4)
                                   ? patchIndex
                                   : (patchIndex - otherIndex * m_num_patches);

    otherIndex *= m_patchInputStride;

    // Decompose the 3D patch index into (col, row, plane) output coordinates.
    colIndex = patch3DIndex / m_fastOutputPlanesRows;
    rowIndex =
        (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
    planeIndex =
        patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes;

    // Convert output coordinates into (padded) input coordinates.
    colIndex = colIndex * m_col_strides - m_colPaddingLeft;
    rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
    planeIndex = planeIndex * m_plane_strides - m_planePaddingTop;
  }

  Index m_patch_depth;   // number of channels in the patch
  Index m_patch_planes;  // number of planes in the patch
  Index m_patch_rows;    // number of rows in the patch
  Index m_patch_cols;    // number of columns in the patch
  Index m_num_patches;   // number of patches to extract

  // Strides for navigating through the single patch.
  Index m_patch_plane_stride;
  Index m_patch_row_stride;
  Index m_patch_col_stride;

  // Strides for the output tensor (depth is not the part of the stride).
  Index m_rowStride;
  Index m_colStride;
  Index m_patchStride;
  Index m_otherStride;

  Index m_planeInputStride;  // Plane stride in the input tensor
  Index m_rowInputStride;    // Row stride in the input tensor
  Index m_colInputStride;    // Col stride in the input tensor
  Index m_patchInputStride;  // Patch stride in the input tensor
  Index m_otherInputStride;

  Index m_inputDepth;   // Depth of the input tensor
  Index m_inputPlanes;  // Number of planes in the input tensor
  Index m_inputRows;    // Number of rows in the input tensor
  Index m_inputCols;    // Number of cols in the input tensor

  Index m_outputPlanes;      // Number of output planes
  Index m_outputRows;        // Number of output rows
  Index m_outputCols;        // Number of output cols
  Index m_outputPlanesRows;  // Cached outputPlanes * outputRows.

  Index m_plane_strides;  // User specified plane stride
  Index m_row_strides;    // User specified row stride
  Index m_col_strides;    // User specified col stride

  // User specified plane/row/col atrous convolution strides.
  Index m_in_plane_strides;
  Index m_in_row_strides;
  Index m_in_col_strides;

  // User specified plane/row/col inflation strides in the image patch.
  Index m_patch_plane_inflate_strides;
  Index m_patch_row_inflate_strides;
  Index m_patch_col_inflate_strides;

  Index m_planePaddingTop;  // Plane padding
  Index m_rowPaddingTop;    // Row padding
  Index m_colPaddingLeft;   // Column padding

  // Fast representation of various divisors.
+ internal::TensorIntDivisor<Index> m_fastNumPatches; + + internal::TensorIntDivisor<Index> m_fastPatchPlaneStride; + internal::TensorIntDivisor<Index> m_fastPatchRowStride; + internal::TensorIntDivisor<Index> m_fastPatchColStride; + + internal::TensorIntDivisor<Index> m_fastInputPlaneStride; + internal::TensorIntDivisor<Index> m_fastInputRowStride; + internal::TensorIntDivisor<Index> m_fastInputColStride; + + internal::TensorIntDivisor<Index> m_fastRowStride; + internal::TensorIntDivisor<Index> m_fastColStride; + + internal::TensorIntDivisor<Index> m_fastDimZero; // aka output depth + internal::TensorIntDivisor<Index> m_fastOutputPlanes; + internal::TensorIntDivisor<Index> m_fastOutputRows; + internal::TensorIntDivisor<Index> m_fastOutputCols; + internal::TensorIntDivisor<Index> m_fastOutputPlanesRows; + + const TensorEvaluator<ArgType, Device> m_impl; +}; + +template <typename NewDimension, DenseIndex Planes, DenseIndex Rows, + DenseIndex Cols, typename ArgType, typename Device, typename Scalar, + typename Index, typename nocontract_t, typename contract_t, int Side, + int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, + int Alignment> +class TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator<const TensorReshapingOp<NewDimension, + const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> { + public: + typedef typename packet_traits<Scalar>::type Packet; + typedef typename packet_traits<Scalar>::half HalfPacket; + + typedef TensorContractionInputMapper< + Scalar, Index, Side, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + ParentMapper; + typedef TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator<const TensorReshapingOp< + 
NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + Self; + typedef Self LinearMapper; + + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( + const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) + : m_base_mapper(base_mapper), + m_depth_offset(vert_offset), + m_col_offset(horiz_offset) { + m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, + m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( + const Self& base_mapper, Index vert_offset, Index horiz_offset) + : m_base_mapper(base_mapper.m_base_mapper), + m_depth_offset(vert_offset + base_mapper.m_depth_offset), + m_col_offset(horiz_offset + base_mapper.m_col_offset) { + m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, + m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { + return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex, + m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, + Index j) const { + return m_base_mapper(i + m_depth_offset, j + m_col_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { + return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex, + m_rowIndex, m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, + Index j) const { + return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset, + j + m_col_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar + loadCoeffStandard(Index i) const { + return m_base_mapper.loadCoeffStandard( + i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const { + return 
m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex, + m_rowIndex, m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet + loadPacketStandard(Index i) const { + return m_base_mapper.loadPacketStandard( + i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); + } + template <typename Packet> + EIGEN_DEVICE_FUNC bool aligned(Index) const { + return false; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { + return m_base_mapper.nonStandardPatches(); + } + + // Max(Col|Row|Plane|Depth): compute the upper limit for the column, row, + // plane and depth index respectively that fits into the peeled_k elements + // starting at m_depth_offset. + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const { + const Index max_col = + fastPatchColStride().divide(m_depth_offset + peeled_k); + return std::min<Index>(1 + max_col, patchCols()); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, + const Index col) const { + const Index max_row = fastPatchRowStride().divide( + m_depth_offset + peeled_k - col * patchColStride()); + return std::min<Index>(1 + max_row, patchRows()); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col, + const Index row) const { + const Index max_plane = fastPatchPlaneStride().divide( + m_depth_offset + peeled_k - col * patchColStride() - + row * patchRowStride()); + return std::min<Index>(1 + max_plane, patchPlanes()); + } + + // MaxDepth uses only the remaining number of elements in the peeled_k. 
+ EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, + const Index start_depth) const { + return std::min<Index>(start_depth + num_elements, patchDepth()); + } + + // Every register matters in this code, so sometimes to prevent register + // spilling, instead of the variable that you would expect to see, we use + // another one, that is guaranteed to have the same value. E.g. patch depth is + // always the same as input depth, and it's also the same as input plane + // stride. Bunch of other parameters have similar relations. + + typedef internal::TensorIntDivisor<Index> IndexDivisor; + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchDepth() const { + eigen_assert(m_base_mapper.m_patch_depth == + m_base_mapper.m_planeInputStride && + "Patch depth must be equal to plane input stride."); + return m_base_mapper.m_planeInputStride; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchPlanes() const { + eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride && + "Patch planes must be equal to row stride."); + return m_base_mapper.m_rowStride; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRows() const { + return m_base_mapper.m_patch_rows; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchCols() const { + return m_base_mapper.m_patch_cols; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchPlaneStride() const { + eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride && + "Patch depth must be equal to patch plane stride."); + return patchDepth(); + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRowStride() const { + return m_base_mapper.m_patch_row_stride; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchColStride() const { + return m_base_mapper.m_patch_col_stride; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const { + eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride && + "Patch depth must be equal to patch 
plane stride."); + return m_base_mapper.m_fastDimZero; // patch_depth + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const { + return m_base_mapper.m_fastPatchRowStride; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const { + return m_base_mapper.m_fastPatchColStride; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, + const Index baseIndex) const { + const Index inputIndex = depth + baseIndex; + return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const { + const Index p = m_planeIndex + plane; + return p < 0 || p >= m_base_mapper.m_inputPlanes; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padRow(const Index row) const { + const Index r = m_rowIndex + row; + return r < 0 || r >= m_base_mapper.m_inputRows; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padCol(const Index col) const { + const Index c = m_colIndex + col; + return c < 0 || c >= m_base_mapper.m_inputCols; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row, + const Index col) const { + const Index p = m_planeIndex + plane; + const Index r = m_rowIndex + row; + const Index c = m_colIndex + col; + return p * m_base_mapper.m_planeInputStride + + r * m_base_mapper.m_rowInputStride + + c * m_base_mapper.m_colInputStride + m_otherIndex; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index planeOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + const Index rowOffset = + (patchOffset - colOffset * m_base_mapper.m_colStride) / + m_base_mapper.m_fastRowStride; + const Index planeOffset = patchOffset - + colOffset * m_base_mapper.m_colStride - + rowOffset * m_base_mapper.m_rowStride; + return planeOffset; + } + + EIGEN_DEVICE_FUNC + 
EIGEN_ALWAYS_INLINE Index rowOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + const Index rowOffset = + (patchOffset - colOffset * m_base_mapper.m_colStride) / + m_base_mapper.m_fastRowStride; + return rowOffset; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index colOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + return colOffset; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index depthOffset() const { + return m_depth_offset % patchDepth(); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper + getLinearMapper(Index i, Index j) const { + return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset); + } + + private: + const ParentMapper& m_base_mapper; + Index m_depth_offset; // First row in the input matrix + Index m_col_offset; // First col in the input matrix + + // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base + // indices for the first element in a patch specified by col_offset + // (see computeBaseIndices(...) for details). + Index m_planeIndex; + Index m_rowIndex; + Index m_colIndex; + Index m_otherIndex; +}; + +// Arrange a block of the right input matrix (in our case it's always a "virtual +// matrix" constructed from extracted volume patches) in contiguous memory. +// +// Given column major input (A0 beside A1 in memory): +// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0 +// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1 +// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2 +// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3 +// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4 +// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5 +// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6 +// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7 +// A8 ... +// ... +// +// *) A, B, C, ... - patches extracted from the original input. +// *) A0, A1, A2 ... - values from the same patch at different offsets. 
//
// The traversal (packed rhs memory) order (B0 besides A0 in memory):
//   A0 B0 C0 D0 A1 B1 C1 D1 ...
//   E0 F0 G0 H0 E1 F1 G1 H1 ...
//   ...
//   Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
//
// This traversal order must be the same as in default gemm_pack_rhs defined in
// GeneralBlockPanelKernel.h.
//
// *) nr - number of registers along the 'n' dimension.
//    See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
//    Multiplication" paper.
template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
          DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
          typename Index, typename nocontract_t, typename contract_t,
          int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
          int Alignment, int nr>
struct gemm_pack_rhs<
    Scalar, Index,
    TensorContractionSubMapper<
        Scalar, Index, Rhs,
        TensorEvaluator<const TensorReshapingOp<
                            NewDimension, const TensorVolumePatchOp<
                                              Planes, Rows, Cols, ArgType> >,
                        Device>,
        nocontract_t, contract_t, packet_size, inner_dim_contiguous,
        inner_dim_reordered, Alignment>,
    nr, ColMajor, false, false> {
  typedef TensorContractionSubMapper<
      Scalar, Index, Rhs,
      TensorEvaluator<const TensorReshapingOp<
                          NewDimension, const TensorVolumePatchOp<
                                            Planes, Rows, Cols, ArgType> >,
                      Device>,
      nocontract_t, contract_t, packet_size, inner_dim_contiguous,
      inner_dim_reordered, Alignment>
      SubMapper;

  typedef SubMapper DataMapper;
  typedef typename packet_traits<Scalar>::type Packet;

  // Only blocks of four columns are supported by this specialization.
  EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);

  // Packs `depth` x `cols` coefficients of `rhs` into `block`.
  //
  // Columns are processed four at a time; for each group the four columns are
  // interleaved, so `block` advances by 4 scalars per depth element. Columns
  // left over after the last full group of four are copied one at a time.
  // `stride` and `offset` must both be 0 (asserted).
  EIGEN_DEVICE_FUNC
  EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
                                    Index depth, Index cols, Index stride = 0,
                                    Index offset = 0) const {
    eigen_assert(stride == 0);
    eigen_assert(offset == 0);

    // Number of columns covered by full groups of four, and number of depth
    // elements covered by whole packets.
    const Index packet_cols4 = (cols / 4) * 4;
    const Index peeled_k = (depth / packet_size) * packet_size;
    const bool non_standard_patches = rhs.nonStandardPatches();

    for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
      // One mapper per column of the current group.
      const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
      const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
      const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
      const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);

      // Number of depth elements already packed for this group.
      Index k = 0;
      if ((packet_size % 4) == 0 && !non_standard_patches) {
        // FAST PATH:
        // Iterate over patch columns, rows and planes if we know that a single
        // packet does not span across multiple planes, rows or columns.
        if ((rhs.patchDepth() % packet_size) == 0) {
          const Index start_col = rhs.colOffset();
          const Index max_col = rhs.maxCol(peeled_k);

          for (Index c = start_col; c < max_col; ++c) {
            eigen_assert(k <= peeled_k);

            // Only the first column starts at a non-zero row offset.
            const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
            const Index max_row = rhs.maxRow(peeled_k, c);

            const bool pad_col0 = dm0.padCol(c);
            const bool pad_col1 = dm1.padCol(c);
            const bool pad_col2 = dm2.padCol(c);
            const bool pad_col3 = dm3.padCol(c);

            for (Index r = start_row; r < max_row; ++r) {
              eigen_assert(k <= peeled_k);

              // Only the first (column, row) starts at a non-zero plane
              // offset.
              const Index start_plane = ((c == start_col) && (r == start_row))
                                            ? rhs.planeOffset()
                                            : 0;
              const Index max_plane = rhs.maxPlane(peeled_k, c, r);

              const bool pad_row0 = pad_col0 || dm0.padRow(r);
              const bool pad_row1 = pad_col1 || dm1.padRow(r);
              const bool pad_row2 = pad_col2 || dm2.padRow(r);
              const bool pad_row3 = pad_col3 || dm3.padRow(r);

              for (Index p = start_plane; p < max_plane; ++p) {
                eigen_assert(k <= peeled_k);

                // An element is padding if its column, row or plane is
                // outside the input.
                const bool pad0 = pad_row0 || dm0.padPlane(p);
                const bool pad1 = pad_row1 || dm1.padPlane(p);
                const bool pad2 = pad_row2 || dm2.padPlane(p);
                const bool pad3 = pad_row3 || dm3.padPlane(p);

                const Index idx0 = dm0.baseIndex(p, r, c);
                const Index idx1 = dm1.baseIndex(p, r, c);
                const Index idx2 = dm2.baseIndex(p, r, c);
                const Index idx3 = dm3.baseIndex(p, r, c);

                // Only the very first (column, row, plane) starts at a
                // non-zero depth offset.
                const Index start_depth =
                    ((c == start_col) && (r == start_row) && (p == start_plane))
                        ? rhs.depthOffset()
                        : 0;
                const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
                eigen_assert((max_depth - start_depth) % packet_size == 0);

                for (Index d = start_depth; d < max_depth; d += packet_size) {
                  eigen_assert(k < peeled_k);
                  // Load one packet per column (zeros for padding), then
                  // transpose so that the four columns are interleaved.
                  PacketBlock<Packet, 4> kernel;
                  kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
                                          : rhs.packetNoPadding(d, idx0);
                  kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
                                          : rhs.packetNoPadding(d, idx1);
                  kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
                                          : rhs.packetNoPadding(d, idx2);
                  kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
                                          : rhs.packetNoPadding(d, idx3);
                  ptranspose(kernel);
                  pstoreu(block + 0 * packet_size, kernel.packet[0]);
                  pstoreu(block + 1 * packet_size, kernel.packet[1]);
                  pstoreu(block + 2 * packet_size, kernel.packet[2]);
                  pstoreu(block + 3 * packet_size, kernel.packet[3]);
                  block += 4 * packet_size;
                  k += packet_size;
                }
              }
            }
          }

          // The loop above should fill peeled_k elements.
          eigen_assert(peeled_k == k);

        } else {
          // Packet can span multiple planes, rows or columns, so we have to go
          // through the slower "standard" path.
          for (; k < peeled_k; k += packet_size) {
            PacketBlock<Packet, 4> kernel;
            kernel.packet[0] = dm0.loadPacketStandard(k);
            kernel.packet[1] = dm1.loadPacketStandard(k);
            kernel.packet[2] = dm2.loadPacketStandard(k);
            kernel.packet[3] = dm3.loadPacketStandard(k);
            ptranspose(kernel);
            pstoreu(block + 0 * packet_size, kernel.packet[0]);
            pstoreu(block + 1 * packet_size, kernel.packet[1]);
            pstoreu(block + 2 * packet_size, kernel.packet[2]);
            pstoreu(block + 3 * packet_size, kernel.packet[3]);
            block += 4 * packet_size;
          }
        }
      }

      // Copy the remaining coefficients of the column block after the peeled_k.
      // (If the packet path above was skipped, k is still 0 and this copies
      // the whole depth range coefficient by coefficient.)
      if (!non_standard_patches) {
        for (; k < depth; k++) {
          block[0] = dm0.loadCoeffStandard(k);
          block[1] = dm1.loadCoeffStandard(k);
          block[2] = dm2.loadCoeffStandard(k);
          block[3] = dm3.loadCoeffStandard(k);
          block += 4;
        }
      } else {
        for (; k < depth; k++) {
          block[0] = dm0(k);
          block[1] = dm1(k);
          block[2] = dm2(k);
          block[3] = dm3(k);
          block += 4;
        }
      }
    }

    // Copy the remaining columns one at a time (nr==1).
    for (Index j2 = packet_cols4; j2 < cols; ++j2) {
      const SubMapper dm0 = rhs.getLinearMapper(0, j2);
      for (Index k = 0; k < depth; k++) {
        *block = dm0(k);
        block += 1;
      }
    }
  }
};

// Template specialization for packet_size = 2. We must special-case packet
// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows, + DenseIndex Cols, typename ArgType, typename Device, typename Scalar, + typename Index, typename nocontract_t, typename contract_t, + bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, + int nr> +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous, + inner_dim_reordered, Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous, + inner_dim_reordered, Alignment> + SubMapper; + typedef SubMapper DataMapper; + typedef typename packet_traits<Scalar>::type Packet; + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + const int packet_size = 2; + + const Index packet_cols4 = (cols / 4) * 4; + const Index peeled_k = (depth / packet_size) * packet_size; + const bool non_standard_patches = rhs.nonStandardPatches(); + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + Index k = 0; + if (!non_standard_patches) { + // FAST PATH: + // Iterate over patch columns, rows and planes if we know that a single + // packet do not span across multiple planes, 
rows or columns. + if ((rhs.patchDepth() % packet_size) == 0) { + const Index start_col = rhs.colOffset(); + const Index max_col = rhs.maxCol(peeled_k); + + for (Index c = start_col; c < max_col; ++c) { + eigen_assert(k <= peeled_k); + + const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; + const Index max_row = rhs.maxRow(peeled_k, c); + + const bool pad_col0 = dm0.padCol(c); + const bool pad_col1 = dm1.padCol(c); + const bool pad_col2 = dm2.padCol(c); + const bool pad_col3 = dm3.padCol(c); + + for (Index r = start_row; r < max_row; ++r) { + eigen_assert(k <= peeled_k); + + const Index start_plane = ((c == start_col) && (r == start_row)) + ? rhs.planeOffset() + : 0; + const Index max_plane = rhs.maxPlane(peeled_k, c, r); + + const bool pad_row0 = dm0.padRow(r); + const bool pad_row1 = dm1.padRow(r); + const bool pad_row2 = dm2.padRow(r); + const bool pad_row3 = dm3.padRow(r); + + for (Index p = start_plane; p < max_plane; ++p) { + eigen_assert(k <= peeled_k); + + const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p); + const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p); + const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p); + const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p); + + const Index idx0 = dm0.baseIndex(p, r, c); + const Index idx1 = dm1.baseIndex(p, r, c); + const Index idx2 = dm2.baseIndex(p, r, c); + const Index idx3 = dm3.baseIndex(p, r, c); + + const Index start_depth = + ((c == start_col) && (r == start_row) && (p == start_plane)) + ? rhs.depthOffset() + : 0; + const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); + eigen_assert((max_depth - start_depth) % packet_size == 0); + + for (Index d = start_depth; d < max_depth; d += packet_size) { + eigen_assert(k < peeled_k); + PacketBlock<Packet, 2> kernel0; + PacketBlock<Packet, 2> kernel1; + kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) + : rhs.packetNoPadding(d, idx0); + kernel0.packet[1] = pad1 ? 
pset1<Packet>(Scalar(0)) + : rhs.packetNoPadding(d, idx1); + kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0)) + : rhs.packetNoPadding(d, idx2); + kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0)) + : rhs.packetNoPadding(d, idx3); + ptranspose(kernel0); + ptranspose(kernel1); + pstoreu(block + 0 * packet_size, kernel0.packet[0]); + pstoreu(block + 1 * packet_size, kernel1.packet[0]); + pstoreu(block + 2 * packet_size, kernel0.packet[1]); + pstoreu(block + 3 * packet_size, kernel1.packet[1]); + block += 4 * packet_size; + k += packet_size; + } + } + } + } + + // The loop above should fill peeled_k elements. + eigen_assert(peeled_k == k); + + } else { + for (; k < peeled_k; k += packet_size) { + PacketBlock<Packet, 2> kernel0; + PacketBlock<Packet, 2> kernel1; + kernel0.packet[0] = dm0.loadPacketStandard(k); + kernel0.packet[1] = dm1.loadPacketStandard(k); + kernel1.packet[0] = dm2.loadPacketStandard(k); + kernel1.packet[1] = dm3.loadPacketStandard(k); + ptranspose(kernel0); + ptranspose(kernel1); + pstoreu(block + 0 * packet_size, kernel0.packet[0]); + pstoreu(block + 1 * packet_size, kernel1.packet[0]); + pstoreu(block + 2 * packet_size, kernel0.packet[1]); + pstoreu(block + 3 * packet_size, kernel1.packet[1]); + block += 4 * packet_size; + } + } + } + + // Copy the remaining coefficients of the column block after the peeled_k. + if (!rhs.nonStandardPatches()) { + for (; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = dm3(k); + block += 4; + } + } + } + + // Copy the remaining columns one at a time (nr==1). 
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; + +// Special case for non-vectorized types such as float16 (packet_size = 1). +template <typename NewDimension, DenseIndex Planes, DenseIndex Rows, + DenseIndex Cols, typename ArgType, typename Device, typename Scalar, + typename Index, typename nocontract_t, typename contract_t, + bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, + int nr> +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous, + inner_dim_reordered, Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, + Alignment> + SubMapper; + typedef SubMapper DataMapper; + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + const Index packet_cols4 = (cols / 4) * 4; + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + if (!rhs.nonStandardPatches()) { + for (Index k = 0; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = 
dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (Index k = 0; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = dm3(k); + block += 4; + } + } + } + + // Copy the remaining columns one at a time (nr==1). + for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; + +} // namespace internal + /** CuboidConvolution * \ingroup CXX11_NeuralNetworks_Module * @@ -98,7 +1504,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel, isColMajor ? kern.dimensions()[1] : kern.dimensions()[3]; // Spatial size of the kernel. - const TensorIndex kernelDepth = + const TensorIndex kernelPlanes = isColMajor ? kern.dimensions()[2] : kern.dimensions()[2]; const TensorIndex kernelRows = isColMajor ? kern.dimensions()[3] : kern.dimensions()[1]; @@ -118,27 +1524,27 @@ CuboidConvolution(const Input& input, const Kernel& kernel, const TensorIndex inputCols = isColMajor ? 
in.dimension(3) : in.dimension(NumDims - 4); - TensorIndex out_depth; + TensorIndex out_planes; TensorIndex out_height; TensorIndex out_width; switch (padding_type) { case PADDING_VALID: - out_depth = Eigen::divup(inputPlanes - kernelDepth + 1, - static_cast<TensorIndex>(stridePlanes)); + out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1, + static_cast<TensorIndex>(stridePlanes)); out_height = Eigen::divup(inputRows - kernelRows + 1, static_cast<TensorIndex>(strideRows)); out_width = Eigen::divup(inputCols - kernelCols + 1, static_cast<TensorIndex>(strideCols)); break; case PADDING_SAME: - out_depth = + out_planes = Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes)); out_height = Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows)); out_width = Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols)); break; default: - out_depth = 0; + out_planes = 0; out_height = 0; out_width = 0; eigen_assert(false && "unexpected padding"); @@ -147,9 +1553,9 @@ CuboidConvolution(const Input& input, const Kernel& kernel, DSizes<TensorIndex, 2> kernel_dims; if (isColMajor) { kernel_dims[0] = kernelFilters; - kernel_dims[1] = kernelChannels * kernelDepth * kernelRows * kernelCols; + kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols; } else { - kernel_dims[0] = kernelChannels * kernelDepth * kernelRows * kernelCols; + kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols; kernel_dims[1] = kernelFilters; } @@ -160,15 +1566,15 @@ CuboidConvolution(const Input& input, const Kernel& kernel, DSizes<TensorIndex, 2> pre_contract_dims; if (isColMajor) { pre_contract_dims[0] = - kernelChannels * kernelDepth * kernelRows * kernelCols; - pre_contract_dims[1] = out_depth * out_height * out_width; + kernelChannels * kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[1] = out_planes * out_height * out_width; for (int i = 4; i < NumDims; ++i) { pre_contract_dims[1] *= in.dimension(i); } } else { 
pre_contract_dims[1] = - kernelChannels * kernelDepth * kernelRows * kernelCols; - pre_contract_dims[0] = out_depth * out_height * out_width; + kernelChannels * kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[0] = out_planes * out_height * out_width; for (int i = 0; i < NumDims - 4; ++i) { pre_contract_dims[0] *= in.dimension(i); } @@ -187,7 +1593,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel, DSizes<TensorIndex, NumDims> post_contract_dims; if (isColMajor) { post_contract_dims[0] = kernelFilters; - post_contract_dims[1] = out_depth; + post_contract_dims[1] = out_planes; post_contract_dims[2] = out_height; post_contract_dims[3] = out_width; for (int i = 4; i < NumDims; ++i) { @@ -195,7 +1601,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel, } } else { post_contract_dims[NumDims - 1] = kernelFilters; - post_contract_dims[NumDims - 2] = out_depth; + post_contract_dims[NumDims - 2] = out_planes; post_contract_dims[NumDims - 3] = out_height; post_contract_dims[NumDims - 4] = out_width; for (int i = 0; i < NumDims - 4; ++i) { @@ -208,13 +1614,13 @@ CuboidConvolution(const Input& input, const Kernel& kernel, kernel.reshape(kernel_dims) .contract(input .extract_volume_patches( - kernelDepth, kernelRows, kernelCols, stridePlanes, + kernelPlanes, kernelRows, kernelCols, stridePlanes, strideRows, strideCols, padding_type) .reshape(pre_contract_dims), contract_dims) .reshape(post_contract_dims), input - .extract_volume_patches(kernelDepth, kernelRows, kernelCols, + .extract_volume_patches(kernelPlanes, kernelRows, kernelCols, stridePlanes, strideRows, strideCols, padding_type) .reshape(pre_contract_dims) diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h index a4dff4b91c..e926d73f87 100644 --- a/tensorflow/core/kernels/eigen_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h @@ -22,8 +22,36 @@ namespace Eigen { namespace internal { 
-// TODO: Consolidate this part of the code with the image patch extraction code -// since they are both very similar. +// WARNING: Most of the code here implicitly assumes that the matrix is in +// ColMajor layout. This is guaranteed by the tensor contraction (see +// TensorContraction.h). +// +// Inside Eigen a tensor contraction is represented by a matrix multiplication. +// We don't want to actually extract image patches and reshape the result into +// a matrix (this involves allocating huge extra memory), so the patch +// extraction and reshape operations are implicit. +// +// TensorContractionInputMapper takes a matrix index and returns the coefficient +// (or the packet) of the "virtual tensor" that would be at that index if we +// were to actually reshape the result of patch extraction. +// +// TensorContractionSubMapper provides a similar view into the "virtual matrix" +// at the given vertical and horizontal offsets. +// +// "Virtual matrix" dimensions: +// *0: kernelChannels * kernelRows * kernelCols +// 1: out_height * out_width * OTHERS (e.g. batches, etc.) +// +// *) extracted patches are contiguous in memory (innermost dimension assuming +// col major layout) +// +// With these dimensions: +// row - offset within a single patch (in code: patchId) +// col - index of the extracted patch (in code: patchIndex) +// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions) +// +// TODO(ezhulenev): Consolidate this part of the code with the image patch +// extraction code since they are both very similar. 
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols, typename ArgType, typename Device, typename Scalar_, typename Index, typename nocontract_t, typename contract_t, int Side, int packet_size, @@ -77,12 +105,17 @@ class TensorContractionInputMapper< m_patch_cols = tensor.impl().dimensions()[2]; m_num_patches = tensor.impl().dimensions()[3]; } else { - const int NumDims = tensor.impl().dimensions().size(); + const size_t NumDims = tensor.impl().dimensions().size(); patch_depth = tensor.impl().dimensions()[NumDims - 1]; patch_rows = tensor.impl().dimensions()[NumDims - 2]; m_patch_cols = tensor.impl().dimensions()[NumDims - 3]; m_num_patches = tensor.impl().dimensions()[NumDims - 4]; } + + // Strides for navigating through the single patch. + m_patch_row_stride = patch_depth; + m_patch_col_stride = patch_rows * m_patch_row_stride; + m_patch_row_inflate_strides = tensor.impl().rowInflateStride(); m_patch_col_inflate_strides = tensor.impl().colInflateStride(); @@ -111,6 +144,10 @@ class TensorContractionInputMapper< m_rowPaddingTop = tensor.impl().rowPaddingTop(); m_colPaddingLeft = tensor.impl().colPaddingLeft(); + m_fastPatchRowStride = + internal::TensorIntDivisor<Index>(m_patch_row_stride); + m_fastPatchColStride = + internal::TensorIntDivisor<Index>(m_patch_col_stride); m_fastInputRowStride = internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides); m_fastInputColStride = @@ -126,6 +163,10 @@ class TensorContractionInputMapper< : m_impl(base_mapper.m_impl) { m_patch_cols = base_mapper.m_patch_cols; m_num_patches = base_mapper.m_num_patches; + + m_patch_row_stride = base_mapper.m_patch_row_stride; + m_patch_col_stride = base_mapper.m_patch_col_stride; + m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides; m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides; @@ -148,6 +189,8 @@ class TensorContractionInputMapper< m_rowPaddingTop = base_mapper.m_rowPaddingTop; m_colPaddingLeft = base_mapper.m_colPaddingLeft; + 
m_fastPatchRowStride = base_mapper.m_fastPatchRowStride; + m_fastPatchColStride = base_mapper.m_fastPatchColStride; m_fastInputRowStride = base_mapper.m_fastInputRowStride; m_fastInputColStride = base_mapper.m_fastInputColStride; m_fastNumPatches = base_mapper.m_fastNumPatches; @@ -238,6 +281,8 @@ class TensorContractionInputMapper< nocontract_t, contract_t, packet_size, inner_dim_contiguous, inner_dim_reordered, Alignment>; + // Load coefficient from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex, Index colIndex, Index otherIndex) const { @@ -250,6 +295,7 @@ class TensorContractionInputMapper< (m_patch_col_inflate_strides == 1) ? inputCol : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0); + const Index rowOffset = patchOffset - colOffset * m_colStride; const Index inputRow = rowIndex + rowOffset * m_in_row_strides; const Index origInputRow = @@ -268,6 +314,8 @@ class TensorContractionInputMapper< return m_impl.coeff(inputIndex); } + // This is the same as loadCoeff(...), but optimized for all `inflate_strides` + // and `in_strides` equal to 1 (template specialization without templates). EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex, Index colIndex, @@ -276,10 +324,9 @@ class TensorContractionInputMapper< // Find the offset of the element wrt the location of the first element. 
const Index patchOffset = patchId / m_fastDimZero; - const Index colOffset = patchOffset / m_fastColStride; - const Index inputCol = colIndex + colOffset; const Index rowOffset = patchOffset - colOffset * m_colStride; + const Index inputCol = colIndex + colOffset; const Index inputRow = rowIndex + rowOffset; if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || inputRow >= m_inputRows) { @@ -291,6 +338,8 @@ class TensorContractionInputMapper< return m_impl.coeff(inputIndex); } + // Load packet from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex, Index colIndex, @@ -318,12 +367,14 @@ class TensorContractionInputMapper< if ((patchDepth() % packetSize) == 0) { return loadPacketFast(patchId, rowIndex, colIndex, otherIndex); } else { + // Offsets and input calculation here are identical to + // loadCoeffStandard(...), but repeated twice. 
+ const Index patchOffsets[2] = { patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero}; const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, patchOffsets[1] / m_fastColStride}; - const Index inputCols[2] = {colIndex + colOffsets[0], colIndex + colOffsets[1]}; if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { @@ -371,8 +422,8 @@ class TensorContractionInputMapper< eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset); const Index colOffset = patchOffset / m_fastColStride; - const Index inputCol = colIndex + colOffset; const Index rowOffset = patchOffset - colOffset * m_colStride; + const Index inputCol = colIndex + colOffset; const Index inputRow = rowIndex + rowOffset; if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols || inputRow >= m_inputRows) { @@ -401,7 +452,7 @@ class TensorContractionInputMapper< EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices( Index patchIndex, Index& rowIndex, Index& colIndex, Index& otherIndex) const { - const int NumInputDims = array_size< + const size_t NumInputDims = array_size< typename TensorEvaluator<ArgType, Device>::Dimensions>::value; otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches; const Index patch2DIndex = (NumInputDims == 3) @@ -414,8 +465,15 @@ class TensorContractionInputMapper< rowIndex = rowIndex * m_row_strides - m_rowPaddingTop; } - Index m_patch_cols; // number of colums in the patch - Index m_num_patches; // number of patches to extract. + Index m_patch_cols; // number of columns in the patch + Index m_num_patches; // number of patches to extract. + + // Strides for navigating through the single patch. 
+ Index m_patch_row_stride; + Index m_patch_col_stride; + internal::TensorIntDivisor<Index> m_fastPatchRowStride; + internal::TensorIntDivisor<Index> m_fastPatchColStride; + Index m_patch_row_inflate_strides; // the strides for row inflation in the // image patch Index m_patch_col_inflate_strides; // the strides for col inflation in the @@ -549,6 +607,40 @@ class TensorContractionSubMapper< return m_base_mapper.nonStandardPatches(); } + // Max(Col|Row|Depth): compute the upper limit for the column, row and depth + // index respectively that fits into the peeled_k elements starting at + // m_depth_offset. + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const { + const Index max_col = + fastPatchColStride().divide(m_depth_offset + peeled_k); + return std::min<Index>(1 + max_col, patchCols()); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k, + const Index col) const { + const Index max_row = fastPatchRowStride().divide( + m_depth_offset + peeled_k - col * patchColStride()); + return std::min<Index>(1 + max_row, patchRows()); + } + + // MaxDepth uses only the remaining number of elements in the peeled_k. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements, + const Index start_depth) const { + return std::min<Index>(start_depth + num_elements, patchDepth()); + } + + // Every register matters in this code, so sometimes to prevent register + // spilling, instead of the variable that you would expect to see, we use + // another one, that is guaranteed to have the same value. E.g. patch depth is + // always the same as input depth, and it's also the same as input row stride. + // Bunch of other parameters have similar relations. 
+ + typedef internal::TensorIntDivisor<Index> IndexDivisor; + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_base_mapper.m_rowInputStride; @@ -563,6 +655,28 @@ class TensorContractionSubMapper< } EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRowStride() const { + eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride && + "Patch depth must be equal to patch row stride."); + return patchDepth(); + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchColStride() const { + return m_base_mapper.m_patch_col_stride; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const { + eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride && + "Patch depth must be equal to patch row stride."); + return m_base_mapper.m_fastDimZero; // patch_depth + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const { + return m_base_mapper.m_fastPatchColStride; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, const Index baseIndex) const { const Index inputIndex = depth + baseIndex; @@ -603,8 +717,7 @@ class TensorContractionSubMapper< EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Index depthOffset() const { - const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth(); - return patchOffset; + return m_depth_offset % patchDepth(); } EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper @@ -617,12 +730,44 @@ class TensorContractionSubMapper< Index m_depth_offset; // First row in the input matrix Index m_col_offset; // First col in the input matrix - Index m_rowIndex; // precomputed row index corresponding to the col offset - Index m_colIndex; // precomputed col index corresponding to the col offset - Index - m_otherIndex; // precomputed other index corresponding to the col offset + // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base + // indices for the first element in a patch specified by col_offset + // (see 
computeBaseIndices(...) for details). + Index m_rowIndex; + Index m_colIndex; + Index m_otherIndex; }; +// Arrange a block of the right input matrix (in our case it's always a "virtual +// matrix" constructed from extracted image patches) in contiguous memory. +// +// Given column major input (A0 beside A1 in memory): +// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0 +// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1 +// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2 +// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3 +// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4 +// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5 +// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6 +// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7 +// A8 ... +// ... +// +// *) A, B, C, ... - patches extracted from the original input. +// *) A0, A1, A2 ... - values from the same patch at different offsets. +// +// The traversal (packed rhs memory) order (B0 besides A0 in memory): +// A0 B0 C0 D0 A1 B1 C1 D1 ... +// E0 F0 G0 H0 E1 F1 G1 H1 ... +// ... +// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4) +// +// This traversal order must be the same as in default gemm_pack_rhs defined in +// GeneralBlockPanelKernel.h. +// +// *) nr - number of registers along the 'n' dimension. +// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix +// Multiplication" paper. 
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols, typename ArgType, typename Device, typename Scalar, typename Index, typename nocontract_t, typename contract_t, int packet_size, @@ -649,9 +794,9 @@ struct gemm_pack_rhs< inner_dim_reordered, Alignment> SubMapper; typedef SubMapper DataMapper; + typedef typename packet_traits<Scalar>::type Packet; - EIGEN_DEVICE_FUNC - static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; } + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, @@ -660,9 +805,6 @@ struct gemm_pack_rhs< eigen_assert(stride == 0); eigen_assert(offset == 0); - EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); - typedef typename packet_traits<Scalar>::type Packet; - const Index packet_cols4 = (cols / 4) * 4; const Index peeled_k = (depth / packet_size) * packet_size; const bool non_standard_patches = rhs.nonStandardPatches(); @@ -675,30 +817,27 @@ struct gemm_pack_rhs< Index k = 0; if ((packet_size % 4) == 0 && !non_standard_patches) { - const Index patch_depth = rhs.patchDepth(); - if ((patch_depth % packet_size) == 0) { - const Index patch_cols = rhs.patchCols(); - const Index patch_rows = rhs.patchRows(); - - const Index startCol = rhs.colOffset(); - const Index max_cols = std::min<Index>( - ceil_div(peeled_k, patch_rows * patch_depth) + startCol, - patch_cols); - - for (Index c = startCol; c < max_cols; ++c) { - eigen_assert(k < peeled_k); - const Index startRow = (c == startCol) ? rhs.rowOffset() : 0; - const Index max_rows = std::min<Index>( - ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) + - startRow, - patch_rows); + // FAST PATH: + // Iterate over patch columns and rows, if we know that a single + // packet do not span across multiple rows or columns. 
+ if ((rhs.patchDepth() % packet_size) == 0) { + const Index start_col = rhs.colOffset(); + const Index max_col = rhs.maxCol(peeled_k); + + for (Index c = start_col; c < max_col; ++c) { + eigen_assert(k <= peeled_k); + + const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; + const Index max_row = rhs.maxRow(peeled_k, c); const bool pad_col0 = dm0.padCol(c); const bool pad_col1 = dm1.padCol(c); const bool pad_col2 = dm2.padCol(c); const bool pad_col3 = dm3.padCol(c); - for (Index r = startRow; r < max_rows; ++r) { - eigen_assert(k < peeled_k); + + for (Index r = start_row; r < max_row; ++r) { + eigen_assert(k <= peeled_k); + const bool pad0 = pad_col0 || dm0.padRow(r); const bool pad1 = pad_col1 || dm1.padRow(r); const bool pad2 = pad_col2 || dm2.padRow(r); @@ -709,14 +848,13 @@ struct gemm_pack_rhs< const Index idx2 = dm2.baseIndex(r, c); const Index idx3 = dm3.baseIndex(r, c); - const Index startDepth = - ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0; - const Index max_depth = - std::min<Index>(peeled_k - c * patch_rows * patch_depth - - r * patch_depth + startDepth, - patch_depth); - eigen_assert((max_depth - startDepth) % packet_size == 0); - for (Index d = startDepth; d < max_depth; d += packet_size) { + const Index start_depth = ((c == start_col) && (r == start_row)) + ? rhs.depthOffset() + : 0; + const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); + eigen_assert((max_depth - start_depth) % packet_size == 0); + + for (Index d = start_depth; d < max_depth; d += packet_size) { eigen_assert(k < peeled_k); PacketBlock<Packet, 4> kernel; kernel.packet[0] = pad0 ? 
pset1<Packet>(Scalar(0)) @@ -738,19 +876,9 @@ struct gemm_pack_rhs< } } - for (; k < peeled_k; k += packet_size) { - PacketBlock<Packet, 4> kernel; - kernel.packet[0] = dm0.loadPacketFast(k); - kernel.packet[1] = dm1.loadPacketFast(k); - kernel.packet[2] = dm2.loadPacketFast(k); - kernel.packet[3] = dm3.loadPacketFast(k); - ptranspose(kernel); - pstoreu(block + 0 * packet_size, kernel.packet[0]); - pstoreu(block + 1 * packet_size, kernel.packet[1]); - pstoreu(block + 2 * packet_size, kernel.packet[2]); - pstoreu(block + 3 * packet_size, kernel.packet[3]); - block += 4 * packet_size; - } + // The loop above should fill peeled_k elements. + eigen_assert(peeled_k == k); + } else { for (; k < peeled_k; k += packet_size) { PacketBlock<Packet, 4> kernel; @@ -767,6 +895,8 @@ struct gemm_pack_rhs< } } } + + // Copy the remaining coefficients of the column block after the peeled_k. if (!rhs.nonStandardPatches()) { for (; k < depth; k++) { block[0] = dm0.loadCoeffStandard(k); @@ -824,9 +954,9 @@ struct gemm_pack_rhs< Alignment> SubMapper; typedef SubMapper DataMapper; + typedef typename packet_traits<Scalar>::type Packet; - EIGEN_DEVICE_FUNC - static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; } + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, @@ -835,9 +965,6 @@ struct gemm_pack_rhs< eigen_assert(stride == 0); eigen_assert(offset == 0); - EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); - typedef typename packet_traits<Scalar>::type Packet; - const int packet_size = 2; const Index packet_cols4 = (cols / 4) * 4; const Index peeled_k = (depth / packet_size) * packet_size; @@ -851,30 +978,27 @@ struct gemm_pack_rhs< Index k = 0; if (!non_standard_patches) { - const Index patch_depth = rhs.patchDepth(); - if ((patch_depth % packet_size) == 0) { - const Index patch_cols = rhs.patchCols(); - const Index patch_rows = rhs.patchRows(); - - 
const Index startCol = rhs.colOffset(); - const Index max_cols = std::min<Index>( - ceil_div(peeled_k, patch_rows * patch_depth) + startCol, - patch_cols); - - for (Index c = startCol; c < max_cols; ++c) { - eigen_assert(k < peeled_k); - const Index startRow = (c == startCol) ? rhs.rowOffset() : 0; - const Index max_rows = std::min<Index>( - ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) + - startRow, - patch_rows); + // FAST PATH: + // Iterate over patch columns and rows if we know that a single + // packet do not span across multiple rows or columns. + if ((rhs.patchDepth() % packet_size) == 0) { + const Index start_col = rhs.colOffset(); + const Index max_col = rhs.maxCol(peeled_k); + + for (Index c = start_col; c < max_col; ++c) { + eigen_assert(k <= peeled_k); + + const Index start_row = (c == start_col) ? rhs.rowOffset() : 0; + const Index max_row = rhs.maxRow(peeled_k, c); const bool pad_col0 = dm0.padCol(c); const bool pad_col1 = dm1.padCol(c); const bool pad_col2 = dm2.padCol(c); const bool pad_col3 = dm3.padCol(c); - for (Index r = startRow; r < max_rows; ++r) { - eigen_assert(k < peeled_k); + + for (Index r = start_row; r < max_row; ++r) { + eigen_assert(k <= peeled_k); + const bool pad0 = pad_col0 || dm0.padRow(r); const bool pad1 = pad_col1 || dm1.padRow(r); const bool pad2 = pad_col2 || dm2.padRow(r); @@ -885,14 +1009,13 @@ struct gemm_pack_rhs< const Index idx2 = dm2.baseIndex(r, c); const Index idx3 = dm3.baseIndex(r, c); - const Index startDepth = - ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0; - const Index max_depth = - std::min<Index>(peeled_k - c * patch_rows * patch_depth - - r * patch_depth + startDepth, - patch_depth); - eigen_assert((max_depth - startDepth) % packet_size == 0); - for (Index d = startDepth; d < max_depth; d += packet_size) { + const Index start_depth = ((c == start_col) && (r == start_row)) + ? 
rhs.depthOffset() + : 0; + const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth); + eigen_assert((max_depth - start_depth) % packet_size == 0); + + for (Index d = start_depth; d < max_depth; d += packet_size) { eigen_assert(k < peeled_k); PacketBlock<Packet, 2> kernel0; PacketBlock<Packet, 2> kernel1; @@ -916,22 +1039,12 @@ struct gemm_pack_rhs< } } - for (; k < peeled_k; k += packet_size) { - PacketBlock<Packet, 2> kernel0; - PacketBlock<Packet, 2> kernel1; - kernel0.packet[0] = dm0.loadPacketFast(k); - kernel0.packet[1] = dm1.loadPacketFast(k); - kernel1.packet[0] = dm2.loadPacketFast(k); - kernel1.packet[1] = dm3.loadPacketFast(k); - ptranspose(kernel0); - ptranspose(kernel1); - pstoreu(block + 0 * packet_size, kernel0.packet[0]); - pstoreu(block + 1 * packet_size, kernel1.packet[0]); - pstoreu(block + 2 * packet_size, kernel0.packet[1]); - pstoreu(block + 3 * packet_size, kernel1.packet[1]); - block += 4 * packet_size; - } + // The loop above should fill peeled_k elements. + eigen_assert(peeled_k == k); + } else { + // Packet can span multiple rows or columns, so we have to go + // though the slower "standard" path. for (; k < peeled_k; k += packet_size) { PacketBlock<Packet, 2> kernel0; PacketBlock<Packet, 2> kernel1; @@ -949,7 +1062,9 @@ struct gemm_pack_rhs< } } } - if (!rhs.nonStandardPatches()) { + + // Copy the remaining coefficients of the column block after the peeled_k. + if (!non_standard_patches) { for (; k < depth; k++) { block[0] = dm0.loadCoeffStandard(k); block[1] = dm1.loadCoeffStandard(k); @@ -968,7 +1083,7 @@ struct gemm_pack_rhs< } } - // copy the remaining columns one at a time (nr==1) + // Copy the remaining columns one at a time (nr==1). 
for (Index j2 = packet_cols4; j2 < cols; ++j2) { const SubMapper dm0 = rhs.getLinearMapper(0, j2); for (Index k = 0; k < depth; k++) { @@ -1006,8 +1121,7 @@ struct gemm_pack_rhs< SubMapper; typedef SubMapper DataMapper; - EIGEN_DEVICE_FUNC - static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; } + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, @@ -1016,8 +1130,6 @@ struct gemm_pack_rhs< eigen_assert(stride == 0); eigen_assert(offset == 0); - EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); - const Index packet_cols4 = (cols / 4) * 4; for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { @@ -1045,7 +1157,7 @@ struct gemm_pack_rhs< } } - // copy the remaining columns one at a time (nr==1) + // Copy the remaining columns one at a time (nr==1). for (Index j2 = packet_cols4; j2 < cols; ++j2) { const SubMapper dm0 = rhs.getLinearMapper(0, j2); for (Index k = 0; k < depth; k++) { diff --git a/tensorflow/core/kernels/eigen_volume_patch.h b/tensorflow/core/kernels/eigen_volume_patch.h index a3d795813d..80ab745bfe 100644 --- a/tensorflow/core/kernels/eigen_volume_patch.h +++ b/tensorflow/core/kernels/eigen_volume_patch.h @@ -43,6 +43,7 @@ struct CustomTensorEvaluator { IsAligned = false, PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess, BlockAccess = false, + PreferBlockAccess = false, Layout = TensorEvaluator<ArgType, Device>::Layout, CoordAccess = NumDims == 6, RawAccess = false diff --git a/tensorflow/core/kernels/extract_volume_patches_op.cc b/tensorflow/core/kernels/extract_volume_patches_op.cc new file mode 100644 index 0000000000..52cd078a35 --- /dev/null +++ b/tensorflow/core/kernels/extract_volume_patches_op.cc @@ -0,0 +1,197 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +/* +See extract_image_patches_op* files and docs for extract_image_patches in +../ops/image_ops.cc. + +Rates are not supported as of now, but the comments hint how to edit the code +when rates are to be added. +*/ + +#define USE_EIGEN_TENSOR +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/extract_volume_patches_op.h" +#include <vector> +#include "tensorflow/core/framework/numeric_op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/util/tensor_format.h" + +namespace tensorflow { + +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +static inline void ParseAttributeVec5(OpKernelConstruction* context, + const string& attr_name, + std::vector<int32>* attr) { + OP_REQUIRES_OK(context, context->GetAttr(attr_name, attr)); + OP_REQUIRES( + context, (*attr)[0] == 1 && (*attr)[4] == 1, + errors::Unimplemented("Only support ", attr_name, " across space.")); + OP_REQUIRES(context, (*attr)[1] >= 1 && (*attr)[2] >= 1 && (*attr)[3] >= 1, + 
errors::OutOfRange(attr_name, " is out of range.")); +} + +template <typename Device, typename T> +class ExtractVolumePatchesOp : public UnaryOp<T> { + public: + explicit ExtractVolumePatchesOp(OpKernelConstruction* context) + : UnaryOp<T>(context) { + ParseAttributeVec5(context, "ksizes", &ksizes_); + ParseAttributeVec5(context, "strides", &strides_); + // ParseAttributeVec5(context, "rates", &rates_); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + // Input tensor is of the following dimensions: + // [ batch, in_planes, in_rows, in_cols, channels ] + const Tensor& input = context->input(0); + OP_REQUIRES(context, input.dims() == 5, + errors::InvalidArgument("input must be 5-dimensional", + input.shape().DebugString())); + + const int batch = input.dim_size(0); + const int in_planes = input.dim_size(1); + const int in_rows = input.dim_size(2); + const int in_cols = input.dim_size(3); + const int depth = input.dim_size(4); + + const int ksize_planes = ksizes_[1]; + const int ksize_rows = ksizes_[2]; + const int ksize_cols = ksizes_[3]; + + const int stride_planes = strides_[1]; + const int stride_rows = strides_[2]; + const int stride_cols = strides_[3]; + + /* + // TODO(hsgkim): enable rates + // Rates are disabled as of now due to Eigen's definitions of + // `extract_volume_patch` functions; none of them accept rates + // as its argument and rates are fixed to (1, 1, 1, 1, 1). A + // workaround has to be found for this. + // In order to enable rates, uncomment the following lines and use + // ksize_*_eff instead of ksize_* for the second argument of + // GetWindowedOutputSize calls. 
+ + const int rate_planes = rates_[1]; + const int rate_rows = rates_[2]; + const int rate_cols = rates_[3]; + + const int ksize_planes_eff = ksize_planes + + (ksize_planes - 1) * (rate_planes - 1); + const int ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1); + const int ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1); + */ + + int64 out_planes = 0, out_rows = 0, out_cols = 0; + int64 pad_planes = 0, pad_rows = 0, pad_cols = 0; + OP_REQUIRES_OK(context, + GetWindowedOutputSize(in_planes, ksize_planes, stride_planes, + padding_, &out_planes, &pad_planes)); + OP_REQUIRES_OK(context, + GetWindowedOutputSize(in_rows, ksize_rows, stride_rows, + padding_, &out_rows, &pad_rows)); + OP_REQUIRES_OK(context, + GetWindowedOutputSize(in_cols, ksize_cols, stride_cols, + padding_, &out_cols, &pad_cols)); + + const std::vector<int64> out_sizes = { + batch, out_planes, out_rows, out_cols, + ksize_planes * ksize_rows * ksize_cols * depth}; + TensorShape out_shape(out_sizes); + + Tensor* output = nullptr; + OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); + + // If there is nothing to compute, return. + if (out_shape.num_elements() == 0) { + return; + } + + functor::ExtractVolumePatchesForward<Device, T>()( + context->eigen_device<Device>(), input.tensor<T, 5>(), ksize_planes, + ksize_rows, ksize_cols, stride_planes, stride_rows, stride_cols, + /* rate_planes, rate_rows, rate_cols, */ + BrainPadding2EigenPadding(padding_), output->tensor<T, 5>()); + } + + private: + std::vector<int32> ksizes_; + std::vector<int32> strides_; + // std::vector<int32> rates_; + + Padding padding_; + + TF_DISALLOW_COPY_AND_ASSIGN(ExtractVolumePatchesOp); +}; + +// Registration of the CPU implementations. 
+#define REGISTER(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("ExtractVolumePatches").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + ExtractVolumePatchesOp<CPUDevice, T>); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER); + +#undef REGISTER + +#if GOOGLE_CUDA + +// Forward declarations of the functor specializations for GPU. +namespace functor { + +// clang-format off +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void ExtractVolumePatchesForward<GPUDevice, T>::operator()( \ + const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input, \ + int patch_planes, int patch_rows, int patch_cols, \ + int stride_planes, int stride_rows, int stride_cols, \ + /* int rate_planes, int rate_rows, int rate_cols, */ \ + const Eigen::PaddingType& padding, \ + typename TTypes<T, 5>::Tensor output); \ + extern template struct ExtractVolumePatchesForward<GPUDevice, T>; +// clang-format on + +TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC); + +#undef DECLARE_GPU_SPEC + +} // namespace functor + +// Registration of the GPU implementations. +#define REGISTER(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("ExtractVolumePatches").Device(DEVICE_GPU).TypeConstraint<T>("T"), \ + ExtractVolumePatchesOp<GPUDevice, T>); + +TF_CALL_GPU_NUMBER_TYPES(REGISTER); + +#undef REGISTER + +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/extract_volume_patches_op.h b/tensorflow/core/kernels/extract_volume_patches_op.h new file mode 100644 index 0000000000..7e0502b770 --- /dev/null +++ b/tensorflow/core/kernels/extract_volume_patches_op.h @@ -0,0 +1,58 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_ +#define TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_ + +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_volume_patch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +namespace tensorflow { +namespace functor { + +template <typename Device, typename T> +struct ExtractVolumePatchesForward { + void operator()(const Device& d, typename TTypes<T, 5>::ConstTensor input, + int patch_planes, int patch_rows, int patch_cols, + int stride_planes, int stride_rows, int stride_cols, + /* int rate_planes, int rate_rows, int rate_cols, */ + const Eigen::PaddingType& padding, + typename TTypes<T, 5>::Tensor output) { + const int64 N = std::max(input.size(), output.size()); + if (N <= std::numeric_limits<Index32>::max()) { + auto output_32bit = To32Bit(output); + output_32bit.device(d) = + To32Bit(input) + .extract_volume_patches(patch_cols, patch_rows, patch_planes, + stride_cols, stride_rows, stride_planes, + padding) + .reshape(output_32bit.dimensions()); + } else { + output.device(d) = + input + .extract_volume_patches(patch_cols, patch_rows, patch_planes, + stride_cols, stride_rows, stride_planes, + padding) + .reshape(output.dimensions()); + } + } +}; + +} // namespace functor +} // namespace tensorflow + +#endif // TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_ diff --git 
a/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc new file mode 100644 index 0000000000..c636493602 --- /dev/null +++ b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc @@ -0,0 +1,38 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/extract_volume_patches_op.h" +#include "tensorflow/core/framework/register_types.h" + +namespace tensorflow { + +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { + +#define REGISTER(T) template struct ExtractVolumePatchesForward<GPUDevice, T>; + +TF_CALL_GPU_NUMBER_TYPES(REGISTER); + +#undef REGISTER + +} // end namespace functor +} // end namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/fuzzing/BUILD b/tensorflow/core/kernels/fuzzing/BUILD index 8bfa40304e..f2e0b2558f 100644 --- a/tensorflow/core/kernels/fuzzing/BUILD +++ b/tensorflow/core/kernels/fuzzing/BUILD @@ -43,4 +43,6 @@ tf_ops_fuzz_target_lib("example_proto_fast_parsing") tf_ops_fuzz_target_lib("parse_tensor_op") +tf_ops_fuzz_target_lib("decode_compressed") + tf_ops_fuzz_target_lib("decode_json_example") diff --git a/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc new file mode 
100644 index 0000000000..0a56f4b63f --- /dev/null +++ b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" + +namespace tensorflow { +namespace fuzzing { + +class FuzzDecodeCompressed : public FuzzStringInputOp { + void BuildGraph(const Scope& scope) override { + auto input = + tensorflow::ops::Placeholder(scope.WithOpName("input1"), DT_STRING); + auto d1 = tensorflow::ops::DecodeCompressed( + scope.WithOpName("d1"), input, + tensorflow::ops::DecodeCompressed::CompressionType("")); + auto d2 = tensorflow::ops::DecodeCompressed( + scope.WithOpName("d2"), input, + tensorflow::ops::DecodeCompressed::CompressionType("ZLIB")); + auto d3 = tensorflow::ops::DecodeCompressed( + scope.WithOpName("d3"), input, + tensorflow::ops::DecodeCompressed::CompressionType("GZIP")); + Scope grouper = + scope.WithControlDependencies(std::vector<tensorflow::Operation>{ + d1.output.op(), d2.output.op(), d3.output.op()}); + (void)tensorflow::ops::NoOp(grouper.WithOpName("output")); + } +}; + +STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeCompressed); + +} // namespace fuzzing +} // namespace tensorflow diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc 
b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc index c90ad2cfeb..ada1235449 100644 --- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc @@ -31,9 +31,37 @@ class FuzzParseTensor : public FuzzSession { } void FuzzImpl(const uint8_t* data, size_t size) final { + // We need to be sure that we don't request too many elements (i.e., we + // don't make ASAN OOM). In theory, a tensor shape can have arbitrary large + // number of elements, up to the limit of the memory available to the OS. + // However, due to the tracing done in ASAN, after 2^32 bytes of requested + // memory we would get a crash in the fuzzer (see b/34190148). Hence, let's + // try parsing the proto here, check that the size (if valid) is below a + // maximum threshold (using 2^20 for convenience), and then run the + // remainder of the fuzzer testing. Of course, this duplicates some work + // but it's better than repeating the investigation whenever Autofuzz + // detects another similar OOM. + string as_string = string(reinterpret_cast<const char*>(data), size); + TensorProto proto; + if (!ParseProtoUnlimited(&proto, as_string)) { + LOG(WARNING) << "Unable to parse proto of tensor\n"; + return; + } + if (!TensorShape::IsValid(proto.tensor_shape())) { + LOG(WARNING) << "Invalid tensor shape\n"; + return; + } + TensorShape shape(proto.tensor_shape()); + const int64 num_elements = shape.num_elements(); + const int64 max_num_elements = 1 << 20; + if (num_elements > max_num_elements) { + LOG(WARNING) << "Requiring a tensor with too many elements\n"; + return; + } + + // Now we can do the actual fuzz implementation Tensor input_tensor(tensorflow::DT_STRING, TensorShape({})); - input_tensor.scalar<string>()() = - string(reinterpret_cast<const char*>(data), size); + input_tensor.scalar<string>()() = as_string; // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object! 
RunOneInput(input_tensor).IgnoreError(); } diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h index cd2873bdca..7710cf93d6 100644 --- a/tensorflow/core/kernels/gather_functor.h +++ b/tensorflow/core/kernels/gather_functor.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/platform/prefetch.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h index 66ae7f0894..1c78de253e 100644 --- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h @@ -114,7 +114,7 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> { generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator( slice_size, Tindices, Tparams, Tout, &error_loc); -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) // Eigen implementation below is not highly performant. gather_nd_generator // does not seem to be called in parallel, leading to very poor performance. // Additionally, since it uses scalar (Tscratch) to invoke 'generate', it @@ -123,15 +123,15 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> { // is considerably more efficient. 
#pragma omp parallel for for (Eigen::DenseIndex i = 0; i < batch_size; i++) { - const Eigen::array<Eigen::DenseIndex, 1> loc = i; + const Eigen::array<Eigen::DenseIndex, 1> loc{i}; gather_nd_generator(loc); } -#else +#else // INTEL_MKL && ENABLE_MKL Tscratch.device(d) = Tscratch.reshape(reshape_dims) .broadcast(broadcast_dims) .generate(gather_nd_generator) .sum(); -#endif +#endif // INTEL_MKL && ENABLE_MKL // error_loc() returns -1 if there's no out-of-bounds index, // otherwise it returns the location of an OOB index in Tindices. diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h index c7dbefa0b4..86146f75f4 100644 --- a/tensorflow/core/kernels/gpu_utils.h +++ b/tensorflow/core/kernels/gpu_utils.h @@ -123,8 +123,7 @@ class AutoTuneMap { string GetActionSummary(StringPiece action, const Parameters& params, const Config& config) { return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(), - std::string(action).c_str(), - params.ToString().c_str(), + string(action).c_str(), params.ToString().c_str(), config.ToString().c_str()); } diff --git a/tensorflow/core/kernels/histogram_op_gpu.cu.cc b/tensorflow/core/kernels/histogram_op_gpu.cu.cc index a88e9b0ddc..374a05850e 100644 --- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc +++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc @@ -18,7 +18,7 @@ limitations under the License. 
#define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "external/cub_archive/cub/device/device_histogram.cuh" +#include "third_party/cub/device/device_histogram.cuh" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc index bca1cff41c..2088c13586 100644 --- a/tensorflow/core/kernels/list_kernels.cc +++ b/tensorflow/core/kernels/list_kernels.cc @@ -77,9 +77,9 @@ static Status TensorListDeviceCopy( return Status::OK(); } -#define REGISTER_LIST_COPY(DIRECTION) \ - INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ - TensorList, DIRECTION, TensorList::kTypeName, TensorListDeviceCopy) +#define REGISTER_LIST_COPY(DIRECTION) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \ + TensorListDeviceCopy) REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); @@ -92,8 +92,7 @@ Status TensorListShape(const TensorList& t, TensorShape* s) { return Status::OK(); } -REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorList::kTypeName, - TensorListShape); +REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorListShape); bool TensorList::Decode(const VariantTensorData& data) { tensors = data.tensors(); @@ -625,12 +624,11 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(bfloat16); #undef REGISTER_TENSOR_LIST_FROM_TENSOR_CPU REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, - TensorList, TensorList::kTypeName, + TensorList, TensorListBinaryAdd<CPUDevice>); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, TensorList, - TensorList::kTypeName, TensorListZerosLike<CPUDevice>); } // namespace tensorflow diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc 
index c591226b76..a00bf700ca 100644 --- a/tensorflow/core/kernels/list_kernels.cu.cc +++ b/tensorflow/core/kernels/list_kernels.cu.cc @@ -94,11 +94,10 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(bool); #undef REGISTER_TENSOR_LIST_FROM_TENSOR_GPU REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU, - TensorList, TensorList::kTypeName, + TensorList, TensorListBinaryAdd<GPUDevice>); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, TensorList, - TensorList::kTypeName, TensorListZerosLike<GPUDevice>); } // namespace tensorflow diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h index 066a1d603b..72581c9293 100644 --- a/tensorflow/core/kernels/list_kernels.h +++ b/tensorflow/core/kernels/list_kernels.h @@ -374,7 +374,12 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, y->tensors.reserve(x.tensors.size()); for (const Tensor& t : x.tensors) { Tensor out_tensor; - TF_RETURN_IF_ERROR(c->allocate_temp(t.dtype(), t.shape(), &out_tensor)); + AllocatorAttributes attr; + if (t.dtype() == DT_VARIANT) { + attr.set_on_host(true); + } + TF_RETURN_IF_ERROR( + c->allocate_temp(t.dtype(), t.shape(), &out_tensor, attr)); switch (out_tensor.dtype()) { #define DTYPE_CASE(dtype) \ case DataTypeToEnum<dtype>::value: \ @@ -385,6 +390,20 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, TF_CALL_POD_TYPES(DTYPE_CASE) #undef DTYPE_CASE + + case DataTypeToEnum<Variant>::value: { + const TensorList* inner_x = t.scalar<Variant>()().get<TensorList>(); + if (inner_x == nullptr) { + return errors::InvalidArgument("Input handle is not a list. 
Saw: '", + t.scalar<Variant>()().DebugString(), + "'"); + } + TensorList inner_y; + TF_RETURN_IF_ERROR(TensorListZerosLike<Device>(c, *inner_x, &inner_y)); + out_tensor.scalar<Variant>()() = std::move(inner_y); + break; + } + default: return errors::InvalidArgument( "Trying to compute zeros_like for unsupported dtype ", diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc index 6b6a14e9a7..1ded012f3c 100644 --- a/tensorflow/core/kernels/logging_ops.cc +++ b/tensorflow/core/kernels/logging_ops.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include <iostream> +#include "absl/strings/str_split.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -90,6 +91,59 @@ class PrintOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp); +class PrintV2Op : public OpKernel { + public: + explicit PrintV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_stream", &output_stream_)); + + auto output_stream_index = + std::find(std::begin(valid_output_streams_), + std::end(valid_output_streams_), output_stream_); + + if (output_stream_index == std::end(valid_output_streams_)) { + string error_msg = strings::StrCat( + "Unknown output stream: ", output_stream_, ", Valid streams are:"); + for (auto valid_stream : valid_output_streams_) { + strings::StrAppend(&error_msg, " ", valid_stream); + } + OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg)); + } + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* input_; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_)); + const string& msg = input_->scalar<string>()(); + + if (output_stream_ == "stdout") { + std::cout << msg << std::endl; + } else if (output_stream_ == "stderr") { + std::cerr << msg << std::endl; + } else if 
(output_stream_ == "log(info)") { + LOG(INFO) << msg << std::endl; + } else if (output_stream_ == "log(warning)") { + LOG(WARNING) << msg << std::endl; + } else if (output_stream_ == "log(error)") { + LOG(ERROR) << msg << std::endl; + } else { + string error_msg = strings::StrCat( + "Unknown output stream: ", output_stream_, ", Valid streams are:"); + for (auto valid_stream : valid_output_streams_) { + strings::StrAppend(&error_msg, " ", valid_stream); + } + OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg)); + } + } + + const char* valid_output_streams_[6] = {"stdout", "stderr", "log(info)", + "log(warning)", "log(error)"}; + + private: + string output_stream_; +}; + +REGISTER_KERNEL_BUILDER(Name("PrintV2").Device(DEVICE_CPU), PrintV2Op); + class TimestampOp : public OpKernel { public: explicit TimestampOp(OpKernelConstruction* context) : OpKernel(context) {} diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc index 5e6958f364..a259d995fa 100644 --- a/tensorflow/core/kernels/logging_ops_test.cc +++ b/tensorflow/core/kernels/logging_ops_test.cc @@ -23,11 +23,33 @@ limitations under the License. 
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/ops_testutil.h" #include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { namespace { +class PrintingV2GraphTest : public OpsTestBase { + protected: + Status Init(const string& output_stream = "log(warning)") { + TF_CHECK_OK(NodeDefBuilder("op", "PrintV2") + .Input(FakeInput(DT_STRING)) + .Attr("output_stream", output_stream) + .Finalize(node_def())); + return InitOp(); + } +}; + +TEST_F(PrintingV2GraphTest, StringSuccess) { + TF_ASSERT_OK(Init()); + AddInputFromArray<string>(TensorShape({}), {"bar"}); + TF_ASSERT_OK(RunOpKernel()); +} + +TEST_F(PrintingV2GraphTest, InvalidOutputStream) { + ASSERT_NE(::tensorflow::Status::OK(), (Init("invalid_output_stream"))); +} + class PrintingGraphTest : public OpsTestBase { protected: Status Init(DataType input_type1, DataType input_type2, string msg = "", diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index 2e8d9c623c..a495758861 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -50,7 +50,7 @@ class MutableHashTableOfScalars final : public LookupInterface { MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {} size_t size() const override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return table_.size(); } @@ -60,7 +60,7 @@ class MutableHashTableOfScalars final : public LookupInterface { const auto key_values = key.flat<K>(); auto value_values = value->flat<V>(); - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (int64 i = 0; i < key_values.size(); ++i) { value_values(i) = gtl::FindWithDefault( table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); @@ -95,7 +95,7 @@ class MutableHashTableOfScalars final : public LookupInterface { } Status ExportValues(OpKernelContext* ctx) override { - mutex_lock l(mu_); 
+ tf_shared_lock l(mu_); int64 size = table_.size(); Tensor* keys; @@ -125,7 +125,7 @@ class MutableHashTableOfScalars final : public LookupInterface { int64 MemoryUsed() const override { int64 ret = 0; - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (unsigned i = 0; i < table_.bucket_count(); ++i) { size_t bucket_size = table_.bucket_size(i); if (bucket_size == 0) { @@ -138,7 +138,6 @@ class MutableHashTableOfScalars final : public LookupInterface { } private: - // TODO(andreasst): consider using a read/write lock or a concurrent map mutable mutex mu_; std::unordered_map<K, V> table_ GUARDED_BY(mu_); }; @@ -158,7 +157,7 @@ class MutableHashTableOfTensors final : public LookupInterface { } size_t size() const override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return table_.size(); } @@ -169,7 +168,7 @@ class MutableHashTableOfTensors final : public LookupInterface { auto value_values = value->flat_inner_dims<V, 2>(); int64 value_dim = value_shape_.dim_size(0); - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (int64 i = 0; i < key_values.size(); ++i) { ValueArray* value_vec = gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i))); @@ -219,7 +218,7 @@ class MutableHashTableOfTensors final : public LookupInterface { } Status ExportValues(OpKernelContext* ctx) override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); int64 size = table_.size(); int64 value_dim = value_shape_.dim_size(0); @@ -254,7 +253,7 @@ class MutableHashTableOfTensors final : public LookupInterface { int64 MemoryUsed() const override { int64 ret = 0; - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (unsigned i = 0; i < table_.bucket_count(); ++i) { size_t bucket_size = table_.bucket_size(i); if (bucket_size == 0) { @@ -268,7 +267,6 @@ class MutableHashTableOfTensors final : public LookupInterface { private: TensorShape value_shape_; - // TODO(andreasst): consider using a read/write lock or a concurrent map mutable mutex mu_; typedef gtl::InlinedVector<V, 4> ValueArray; 
std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_); @@ -335,7 +333,7 @@ class MutableDenseHashTable final : public LookupInterface { } size_t size() const override LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return num_entries_; } @@ -355,7 +353,7 @@ class MutableDenseHashTable final : public LookupInterface { auto value_matrix = value->shaped<V, 2>({num_elements, value_size}); const auto default_flat = default_value.flat<V>(); - mutex_lock l(mu_); + tf_shared_lock l(mu_); const auto key_buckets_matrix = key_buckets_.AccessTensor(ctx)->template matrix<K>(); const auto value_buckets_matrix = @@ -451,7 +449,7 @@ class MutableDenseHashTable final : public LookupInterface { } Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); + tf_shared_lock l(mu_); Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx); Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx); TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor)); @@ -493,7 +491,7 @@ class MutableDenseHashTable final : public LookupInterface { TensorShape value_shape() const override { return value_shape_; } int64 MemoryUsed() const override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() + value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes(); } diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc index bdc3b5778f..dd89597369 100644 --- a/tensorflow/core/kernels/map_stage_op.cc +++ b/tensorflow/core/kernels/map_stage_op.cc @@ -410,8 +410,9 @@ class StagingMap : public ResourceBase { copy_or_move_tensors(&it->second, *key, *indices, tuple)); // Remove entry if all the values have been consumed - if (!std::any_of(it->second.begin(), it->second.end(), - std::mem_fn(&OptionalTensor::has_value))) { + if (!std::any_of( + it->second.begin(), it->second.end(), + [](const OptionalTensor& tensor) { return 
tensor.has_value(); })) { map_.erase(it); } @@ -444,8 +445,9 @@ class StagingMap : public ResourceBase { *key = it->first; // Remove entry if all the values have been consumed - if (!std::any_of(it->second.begin(), it->second.end(), - std::mem_fn(&OptionalTensor::has_value))) { + if (!std::any_of( + it->second.begin(), it->second.end(), + [](const OptionalTensor& tensor) { return tensor.has_value(); })) { map_.erase(it); } diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc index 79967aab38..4ad390a411 100644 --- a/tensorflow/core/kernels/matmul_op.cc +++ b/tensorflow/core/kernels/matmul_op.cc @@ -578,7 +578,7 @@ struct MatMulFunctor<SYCLDevice, T> { .Label("cublas"), \ MatMulOp<GPUDevice, T, true /* cublas */>) -#if defined(INTEL_MKL) +#if defined(INTEL_MKL) && defined(ENABLE_MKL) // MKL does not support half, bfloat16 and int32 types for // matrix-multiplication, so register the kernel to use default Eigen based @@ -606,9 +606,9 @@ TF_CALL_double(REGISTER_CPU); TF_CALL_complex64(REGISTER_CPU_EIGEN); TF_CALL_complex128(REGISTER_CPU_EIGEN); TF_CALL_double(REGISTER_CPU_EIGEN); -#endif +#endif // INTEL_MKL_DNN_ONLY -#else // INTEL MKL +#else // INTEL_MKL && ENABLE_MKL TF_CALL_float(REGISTER_CPU); TF_CALL_double(REGISTER_CPU); TF_CALL_half(REGISTER_CPU); @@ -616,7 +616,7 @@ TF_CALL_bfloat16(REGISTER_CPU); TF_CALL_int32(REGISTER_CPU); TF_CALL_complex64(REGISTER_CPU); TF_CALL_complex128(REGISTER_CPU); -#endif +#endif // INTEL_MKL && ENABLE_MKL #if GOOGLE_CUDA TF_CALL_float(REGISTER_GPU); diff --git a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc index 10e468ce46..693ed8a8f0 100644 --- a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc +++ b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc @@ -114,9 +114,7 @@ class MergeV2CheckpointsOpTest : public OpsTestBase { // Exercises "delete_old_dirs". 
for (int i = 0; i < 2; ++i) { int directory_found = - Env::Default() - ->IsDirectory(std::string(io::Dirname(prefixes[i]))) - .code(); + Env::Default()->IsDirectory(string(io::Dirname(prefixes[i]))).code(); if (delete_old_dirs) { EXPECT_EQ(error::NOT_FOUND, directory_found); } else { diff --git a/tensorflow/core/kernels/mirror_pad_op.h b/tensorflow/core/kernels/mirror_pad_op.h index cc4b6941b9..62aa7d5c29 100644 --- a/tensorflow/core/kernels/mirror_pad_op.h +++ b/tensorflow/core/kernels/mirror_pad_op.h @@ -103,6 +103,7 @@ struct TensorEvaluator<const TensorMirrorPadOp<PaddingDimensions, ArgType>, IsAligned = false, PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess, BlockAccess = false, + PreferBlockAccess = false, Layout = TensorEvaluator<ArgType, Device>::Layout, CoordAccess = true, RawAccess = false diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc index 0841395dc3..bc135de11e 100644 --- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc @@ -223,10 +223,12 @@ class BatchMatMulMkl : public OpKernel { Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ BatchMatMulMkl<CPUDevice, TYPE>) +#ifdef ENABLE_MKL TF_CALL_float(REGISTER_BATCH_MATMUL_MKL); TF_CALL_double(REGISTER_BATCH_MATMUL_MKL); TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL); TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL); +#endif // ENABLE_MKL } // end namespace tensorflow #endif diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 9b10c3f3d6..184e0cb003 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -1083,7 +1083,7 @@ class MklConvOp : public OpKernel { #endif // Register 2D operations -#define REGISTER_MKL_CPU(T) \ +#define REGISTER_MKL_CPU_2D(T) \ REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \ .Device(DEVICE_CPU) \ .TypeConstraint<T>("T") \ @@ -1100,16 +1100,16 @@ class 
MklConvOp : public OpKernel { .Label(mkl_op_registry::kMklOpLabel), \ MklDummyOp<CPUDevice, T>); -TF_CALL_float(REGISTER_MKL_CPU); +TF_CALL_float(REGISTER_MKL_CPU_2D); // Register 3D operations -#define REGISTER_MKL_CPU(T) \ +#define REGISTER_MKL_CPU_3D(T) \ REGISTER_KERNEL_BUILDER(Name("_MklConv3D") \ .Device(DEVICE_CPU) \ .TypeConstraint<T>("T") \ .Label(mkl_op_registry::kMklOpLabel), \ MklConvOp<CPUDevice, T, false>); -TF_CALL_float(REGISTER_MKL_CPU); +TF_CALL_float(REGISTER_MKL_CPU_3D); } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_conv_ops_test.cc b/tensorflow/core/kernels/mkl_conv_ops_test.cc new file mode 100644 index 0000000000..a055351337 --- /dev/null +++ b/tensorflow/core/kernels/mkl_conv_ops_test.cc @@ -0,0 +1,407 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" + +#if defined(INTEL_MKL_DNN_ONLY) +#include "third_party/intel_mkl_dnn/include/mkldnn.h" +#include "tensorflow/core/util/mkl_util.h" +#endif + +// TODO(ezhulenev): Add numerical tests that will compare results of default +// (aka Eigen) convolutions with MKL convolutions. + +// -------------------------------------------------------------------------- // +// Performance Benchmarks. // +// -------------------------------------------------------------------------- // + +// Compare performance of default Tensorflow convolution kernels (Eigen) with +// MKL kernels on CPU. 
+ +// Before running these benchmarks configure OpenMP environment variables: +// export KMP_BLOCKTIME=0 +// export OMP_NUM_THREADS=${num_threads} + +namespace tensorflow { + +struct Conv2DDimensions { + Conv2DDimensions(int n, int h, int w, int c, int fc, int fh, int fw) + : input_batches(n), + input_height(h), + input_width(w), + input_depth(c), + filter_count(fc), + filter_height(fh), + filter_width(fw) {} + + int input_batches; + int input_height; + int input_width; + int input_depth; + int filter_count; + int filter_height; + int filter_width; +}; + +static Tensor GetRandomTensor(const TensorShape& shape) { + Tensor tensor(DT_FLOAT, TensorShape(shape)); + tensor.flat<float>() = tensor.flat<float>().setRandom(); + return tensor; +} + +// Get a random Tensor for the Conv2D input. +static Tensor GetRandomInputTensor(const Conv2DDimensions& dims) { + return GetRandomTensor({dims.input_batches, dims.input_height, + dims.input_width, dims.input_depth}); +} + +// Get a random Tensor for the Conv2D filter. +static Tensor GetRandomFilterTensor(const Conv2DDimensions& dims) { + return GetRandomTensor({dims.filter_height, dims.filter_width, + dims.input_depth, dims.filter_count}); +} + +// Get a random Tensor for the Conv2D output (assuming SAME padding). +static Tensor GetRandomOutputTensor(const Conv2DDimensions& dims) { + return GetRandomTensor({dims.input_batches, dims.input_height, + dims.input_width, dims.filter_count}); +} + +// Get a Tensor encoding Conv2D input shape. +static Tensor GetInputSizesTensor(const Conv2DDimensions& dims) { + return test::AsTensor<int32>({dims.input_batches, dims.input_height, + dims.input_width, dims.input_depth}); +} + +// Get a Tensor encoding Conv2D filter shape. 
+static Tensor GetFilterSizesTensor(const Conv2DDimensions& dims) { + return test::AsTensor<int32>({dims.filter_height, dims.filter_width, + dims.input_depth, dims.filter_count}); +} + +#if defined(INTEL_MKL_DNN_ONLY) +static Tensor NonMklTensor() { + MklDnnShape non_mkl_shape; + non_mkl_shape.SetMklTensor(false); + + auto size = static_cast<int64>(non_mkl_shape.GetSerializeBufferSize()); + Tensor tensor(DT_UINT8, {size}); + + non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(), + size * sizeof(uint8)); + return tensor; +} +#endif + +static Graph* DefaultConv2D(const Conv2DDimensions& dims) { + auto* graph = new Graph(OpRegistry::Global()); + + Tensor input_t = GetRandomInputTensor(dims); + Tensor filter_t = GetRandomFilterTensor(dims); + + Node* input = test::graph::Constant(graph, input_t, "input"); + Node* filter = test::graph::Constant(graph, filter_t, "filter"); + + Node* conv2d; + TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d"), "Conv2D") + .Input(input) + .Input(filter) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, 1, 1, 1}) + .Attr("padding", "SAME") + .Finalize(graph, &conv2d)); + + return graph; +} + +#if defined(INTEL_MKL_DNN_ONLY) +static Graph* MklConv2D(const Conv2DDimensions& dims) { + auto* graph = new Graph(OpRegistry::Global()); + + Tensor input_t = GetRandomInputTensor(dims); + Tensor filter_t = GetRandomFilterTensor(dims); + + Node* input = test::graph::Constant(graph, input_t, "input"); + Node* filter = test::graph::Constant(graph, filter_t, "filter"); + + Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl"); + + Node* conv2d; + TF_CHECK_OK(NodeBuilder(graph->NewName("mkl_conv_2d"), "_MklConv2D") + .Input(input) + .Input(filter) + .Input(not_mkl_shape) + .Input(not_mkl_shape) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, 1, 1, 1}) + .Attr("padding", "SAME") + .Attr("_kernel", "MklOp") + .Finalize(graph, &conv2d)); + + return graph; +} +#endif + +static Graph* DefaultConv2DBwdInput(const Conv2DDimensions& 
dims) { + auto* graph = new Graph(OpRegistry::Global()); + + Tensor input_sizes_t = GetInputSizesTensor(dims); + Tensor filter_t = GetRandomFilterTensor(dims); + Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding + + Node* input_sizes = + test::graph::Constant(graph, input_sizes_t, "input_sizes"); + Node* filter = test::graph::Constant(graph, filter_t, "filter"); + Node* out_backprop = + test::graph::Constant(graph, out_backprop_t, "out_backprop"); + + Node* conv2d_bwd_input; + TF_CHECK_OK( + NodeBuilder(graph->NewName("conv_2d_bwd_input"), "Conv2DBackpropInput") + .Input(input_sizes) + .Input(filter) + .Input(out_backprop) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, 1, 1, 1}) + .Attr("padding", "SAME") + .Finalize(graph, &conv2d_bwd_input)); + + return graph; +} + +#if defined(INTEL_MKL_DNN_ONLY) +static Graph* MklConv2DBwdInput(const Conv2DDimensions& dims) { + auto* graph = new Graph(OpRegistry::Global()); + + Tensor input_sizes_t = GetInputSizesTensor(dims); + Tensor filter_t = GetRandomFilterTensor(dims); + Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding + + Node* input_sizes = + test::graph::Constant(graph, input_sizes_t, "input_sizes"); + Node* filter = test::graph::Constant(graph, filter_t, "filter"); + Node* out_backprop = + test::graph::Constant(graph, out_backprop_t, "out_backprop"); + + Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl"); + + Node* conv2d_bwd_input; + TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_input"), + "_MklConv2DBackpropInput") + .Input(input_sizes) + .Input(filter) + .Input(out_backprop) + .Input(not_mkl_shape) + .Input(not_mkl_shape) + .Input(not_mkl_shape) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, 1, 1, 1}) + .Attr("padding", "SAME") + .Attr("_kernel", "MklOp") + .Finalize(graph, &conv2d_bwd_input)); + + return graph; +} +#endif + +static Graph* DefaultConv2DBwdFilter(const Conv2DDimensions& dims) { + auto* graph = new 
Graph(OpRegistry::Global()); + + Tensor input_t = GetRandomInputTensor(dims); + Tensor filter_sizes_t = GetFilterSizesTensor(dims); + Tensor filter_t = GetRandomFilterTensor(dims); + Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding + + Node* input = test::graph::Constant(graph, input_t, "input"); + Node* filter_sizes = + test::graph::Constant(graph, filter_sizes_t, "filter_sizes"); + Node* out_backprop = + test::graph::Constant(graph, out_backprop_t, "out_backprop"); + + Node* conv2d_bwd_filter; + TF_CHECK_OK( + NodeBuilder(graph->NewName("conv_2d_bwd_filter"), "Conv2DBackpropFilter") + .Input(input) + .Input(filter_sizes) + .Input(out_backprop) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, 1, 1, 1}) + .Attr("padding", "SAME") + .Finalize(graph, &conv2d_bwd_filter)); + + return graph; +} + +#if defined(INTEL_MKL_DNN_ONLY) +static Graph* MklConv2DBwdFilter(const Conv2DDimensions& dims) { + Graph* graph = new Graph(OpRegistry::Global()); + + Tensor input_t = GetRandomInputTensor(dims); + Tensor filter_sizes_t = GetFilterSizesTensor(dims); + Tensor filter_t = GetRandomFilterTensor(dims); + Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding + + Node* input = test::graph::Constant(graph, input_t, "input"); + Node* filter_sizes = + test::graph::Constant(graph, filter_sizes_t, "filter_sizes"); + Node* out_backprop = + test::graph::Constant(graph, out_backprop_t, "out_backprop"); + + Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl"); + + Node* conv2d_bwd_filter; + TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_filter"), + "_MklConv2DBackpropFilter") + .Input(input) + .Input(filter_sizes) + .Input(out_backprop) + .Input(not_mkl_shape) + .Input(not_mkl_shape) + .Input(not_mkl_shape) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, 1, 1, 1}) + .Attr("padding", "SAME") + .Attr("_kernel", "MklOp") + .Finalize(graph, &conv2d_bwd_filter)); + + return graph; +} +#endif + +// Macro arguments 
names: --------------------------------------------------- // +// N: batch size +// H: height +// W: width +// C: channels +// FC: filter count +// FH: filter height +// FW: filter width + +#define BM_CONCAT(a, b) a##b + +#define BM_NAME(p, type, N, H, W, C, FC, FH, FW) \ + BM_CONCAT(BM_##p##_##type##_in_##N##_##H##_##W##_##C, _f_##FC##_##FH##_##FW) + +// Flops computation in these benchmarks are the same as in +// eigen_benchmark_cpu_test.cc. + +#define BM_Conv2DT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \ + static void BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, \ + FW)(int iters) { \ + testing::SetLabel(LABEL); \ + \ + int64 num_computed_elements = (N) * (H) * (W) * (FC); \ + int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \ + testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \ + \ + Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \ + test::Benchmark(#type, BM_CONCAT(kind, Conv2D)(dims)).Run(iters); \ + } \ + BENCHMARK(BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, FW)) + +#if defined(INTEL_MKL_DNN_ONLY) +#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \ + BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \ + BM_Conv2DT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL); +#else +#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \ + BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL); +#endif + +#define BM_Conv2DBwdInputT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \ + static void BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, \ + FW)(int iters) { \ + testing::SetLabel(LABEL); \ + \ + int64 num_computed_elements = (N) * (H) * (W) * (C); \ + int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \ + testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \ + \ + Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \ + test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdInput)(dims)).Run(iters); \ + } \ + BENCHMARK(BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, FW)) + 
+#if defined(INTEL_MKL_DNN_ONLY) +#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \ + BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \ + BM_Conv2DBwdInputT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL); +#else +#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \ + BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL); +#endif + +#define BM_Conv2DBwdFilterT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \ + static void BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, \ + FW)(int iters) { \ + testing::SetLabel(LABEL); \ + \ + int64 num_computed_elements = (FH) * (FW) * (C) * (FC); \ + int64 flops_per_iter = num_computed_elements * ((N) * (H) * (W)); \ + testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \ + \ + Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \ + test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdFilter)(dims)).Run(iters); \ + } \ + BENCHMARK(BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, FW)) + +#if defined(INTEL_MKL_DNN_ONLY) +#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \ + BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \ + BM_Conv2DBwdFilterT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL); +#else +#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \ + BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL); +#endif + +// ImageNet Convolutions ---------------------------------------------------- // + +BM_Conv2D(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3"); +BM_Conv2D(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5"); +BM_Conv2D(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3"); +BM_Conv2D(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5"); +BM_Conv2D(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3"); +BM_Conv2D(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5"); +BM_Conv2D(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3"); + +BM_Conv2DBwdInput(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3"); +BM_Conv2DBwdInput(32, 
28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5"); +BM_Conv2DBwdInput(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3"); +BM_Conv2DBwdInput(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5"); +BM_Conv2DBwdInput(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3"); +BM_Conv2DBwdInput(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5"); +BM_Conv2DBwdInput(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3"); + +BM_Conv2DBwdFilter(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3"); +BM_Conv2DBwdFilter(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5"); +BM_Conv2DBwdFilter(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3"); +BM_Conv2DBwdFilter(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5"); +BM_Conv2DBwdFilter(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3"); +BM_Conv2DBwdFilter(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5"); +BM_Conv2DBwdFilter(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3"); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc index 077d62ce32..f4788f4851 100644 --- a/tensorflow/core/kernels/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -217,7 +217,7 @@ class MklMatMulOp : public OpKernel { reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta, reinterpret_cast<MKL_Complex16*>(c), ldc); } -#endif +#endif // !INTEL_MKL_DNN_ONLY }; #define REGISTER_CPU(T) \ @@ -225,6 +225,7 @@ class MklMatMulOp : public OpKernel { Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); +#ifdef ENABLE_MKL // TODO(inteltf) Consider template specialization when adding/removing // additional types TF_CALL_float(REGISTER_CPU); @@ -233,7 +234,8 @@ TF_CALL_float(REGISTER_CPU); TF_CALL_double(REGISTER_CPU); TF_CALL_complex64(REGISTER_CPU); TF_CALL_complex128(REGISTER_CPU); -#endif +#endif // !INTEL_MKL_DNN_ONLY +#endif // ENABLE_MKL } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc 
b/tensorflow/core/kernels/mkl_pooling_ops_common.cc index ec6d241e17..5398e6113f 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc @@ -34,11 +34,11 @@ using mkldnn::prop_kind; template <typename T> void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) { - if (fwdParams.alg_kind != pooling_max && fwdParams.alg_kind != pooling_avg && - fwdParams.alg_kind != pooling_avg_include_padding && - fwdParams.alg_kind != pooling_avg_exclude_padding) { - assert("Pooling algorithm kind is not supported\n"); - } + DCHECK(fwdParams.alg_kind == pooling_max || + fwdParams.alg_kind == pooling_avg || + fwdParams.alg_kind == pooling_avg_include_padding || + fwdParams.alg_kind == pooling_avg_exclude_padding) + << "Pooling algorithm kind is not supported"; context_.alg_kind = fwdParams.alg_kind; // create memory desc @@ -102,7 +102,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data, static_cast<void*>(const_cast<T*>(src_data))); context_.dst_mem->set_data_handle(static_cast<void*>(dst_data)); if (context_.alg_kind == pooling_max) { // max pooling must have ws - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(ws_data); } context_.fwd_stream->submit(context_.fwd_primitives); @@ -111,7 +111,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data, context_.src_mem->set_data_handle(DummyData); context_.dst_mem->set_data_handle(DummyData); if (context_.alg_kind == pooling_max) { // max pooling must have ws - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(DummyData); } } @@ -120,11 +120,11 @@ template class MklPoolingFwdPrimitive<float>; template <typename T> void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) { - if (bwdParams.alg_kind != pooling_max && bwdParams.alg_kind != pooling_avg && - bwdParams.alg_kind != pooling_avg_include_padding && - 
bwdParams.alg_kind != pooling_avg_exclude_padding) { - assert("Pooling algorithm kind is not supported\n"); - } + DCHECK(bwdParams.alg_kind == pooling_max || + bwdParams.alg_kind == pooling_avg || + bwdParams.alg_kind == pooling_avg_include_padding || + bwdParams.alg_kind == pooling_avg_exclude_padding) + << "Pooling algorithm kind is not supported"; context_.alg_kind = bwdParams.alg_kind; // check whether it is 2d or 3d @@ -190,7 +190,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data, static_cast<void*>(const_cast<T*>(diff_dst_data))); context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data)); if (context_.alg_kind == pooling_max) { - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(const_cast<void*>(ws_data)); } @@ -199,7 +199,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data, context_.diff_dst_mem->set_data_handle(DummyData); context_.diff_src_mem->set_data_handle(DummyData); if (context_.alg_kind == pooling_max) { - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(DummyData); } } diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index f4cfc48af5..84385356e1 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -40,7 +40,6 @@ using mkldnn::memory; #include "mkl_dnn.h" #include "mkl_dnn_types.h" #endif -#include "tensorflow/core/platform/default/logging.h" #include "tensorflow/core/util/mkl_util.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc index 04d8a1bdeb..cfab529662 100644 --- a/tensorflow/core/kernels/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl_softmax_op.cc @@ -88,6 +88,7 @@ class MklSoftmaxOp : public OpKernel { break; default: OP_REQUIRES_OK(context, errors::Aborted("Input dims must be <= 5 and >=1")); + return; } // Create softmax memory for src, 
dst: both are defined in mkl_util.h, // they are wrapper diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc index 7a64788448..82dfece4a2 100644 --- a/tensorflow/core/kernels/multinomial_op.cc +++ b/tensorflow/core/kernels/multinomial_op.cc @@ -75,7 +75,7 @@ struct MultinomialFunctor<CPUDevice, T, OutputType> { // lambda. Since we want to let each worker have its own copy, we pass // "gen" by reference and explicitly do a copy assignment here. random::PhiloxRandom gen_copy = gen; - // Skip takes units of 128 bytes. +3 is so rounding doesn't lead to + // Skip takes units of 128 bits. +3 is so rounding doesn't lead to // us using the same state in different batches. gen_copy.Skip(start_row * (num_samples + 3) / 4); random::SimplePhilox simple_philox(&gen_copy); diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc index 5d9257e20b..81ce6d6e95 100644 --- a/tensorflow/core/kernels/non_max_suppression_op.cc +++ b/tensorflow/core/kernels/non_max_suppression_op.cc @@ -75,28 +75,28 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context, } // Return intersection-over-union overlap between boxes i and j -static inline float IOUGreaterThanThreshold( - typename TTypes<float, 2>::ConstTensor boxes, int i, int j, - float iou_threshold) { - const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2)); - const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3)); - const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2)); - const float xmax_i = std::max<float>(boxes(i, 1), boxes(i, 3)); - const float ymin_j = std::min<float>(boxes(j, 0), boxes(j, 2)); - const float xmin_j = std::min<float>(boxes(j, 1), boxes(j, 3)); - const float ymax_j = std::max<float>(boxes(j, 0), boxes(j, 2)); - const float xmax_j = std::max<float>(boxes(j, 1), boxes(j, 3)); - const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i); - const float area_j = (ymax_j - ymin_j) * 
(xmax_j - xmin_j); - if (area_i <= 0 || area_j <= 0) return 0.0; - const float intersection_ymin = std::max<float>(ymin_i, ymin_j); - const float intersection_xmin = std::max<float>(xmin_i, xmin_j); - const float intersection_ymax = std::min<float>(ymax_i, ymax_j); - const float intersection_xmax = std::min<float>(xmax_i, xmax_j); - const float intersection_area = - std::max<float>(intersection_ymax - intersection_ymin, 0.0) * - std::max<float>(intersection_xmax - intersection_xmin, 0.0); - const float iou = intersection_area / (area_i + area_j - intersection_area); +template <typename T> +static inline bool IOUGreaterThanThreshold( + typename TTypes<T, 2>::ConstTensor boxes, int i, int j, T iou_threshold) { + const T ymin_i = std::min<T>(boxes(i, 0), boxes(i, 2)); + const T xmin_i = std::min<T>(boxes(i, 1), boxes(i, 3)); + const T ymax_i = std::max<T>(boxes(i, 0), boxes(i, 2)); + const T xmax_i = std::max<T>(boxes(i, 1), boxes(i, 3)); + const T ymin_j = std::min<T>(boxes(j, 0), boxes(j, 2)); + const T xmin_j = std::min<T>(boxes(j, 1), boxes(j, 3)); + const T ymax_j = std::max<T>(boxes(j, 0), boxes(j, 2)); + const T xmax_j = std::max<T>(boxes(j, 1), boxes(j, 3)); + const T area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i); + const T area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j); + if (area_i <= static_cast<T>(0) || area_j <= static_cast<T>(0)) return 0; + const T intersection_ymin = std::max<T>(ymin_i, ymin_j); + const T intersection_xmin = std::max<T>(xmin_i, xmin_j); + const T intersection_ymax = std::min<T>(ymax_i, ymax_j); + const T intersection_xmax = std::min<T>(xmax_i, xmax_j); + const T intersection_area = + std::max<T>(intersection_ymax - intersection_ymin, static_cast<T>(0.0)) * + std::max<T>(intersection_xmax - intersection_xmin, static_cast<T>(0.0)); + const T iou = intersection_area / (area_i + area_j - intersection_area); return iou > iou_threshold; } @@ -106,11 +106,13 @@ static inline bool OverlapsGreaterThanThreshold( return overlaps(i, j) > 
overlap_threshold; } +template <typename T> static inline std::function<bool(int, int)> CreateIOUSuppressCheckFn( const Tensor& boxes, float threshold) { - typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>(); - return std::bind(&IOUGreaterThanThreshold, boxes_data, std::placeholders::_1, - std::placeholders::_2, threshold); + typename TTypes<T, 2>::ConstTensor boxes_data = boxes.tensor<T, 2>(); + return std::bind(&IOUGreaterThanThreshold<T>, boxes_data, + std::placeholders::_1, std::placeholders::_2, + static_cast<T>(threshold)); } static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn( @@ -121,6 +123,7 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn( std::placeholders::_1, std::placeholders::_2, threshold); } +template <typename T> void DoNonMaxSuppressionOp( OpKernelContext* context, const Tensor& scores, int num_boxes, const Tensor& max_output_size, const float score_threshold, @@ -128,13 +131,13 @@ void DoNonMaxSuppressionOp( bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) { const int output_size = max_output_size.scalar<int>()(); - std::vector<float> scores_data(num_boxes); - std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin()); + std::vector<T> scores_data(num_boxes); + std::copy_n(scores.flat<T>().data(), num_boxes, scores_data.begin()); // Data structure for selection candidate in NMS. 
struct Candidate { int box_index; - float score; + T score; }; auto cmp = [](const Candidate bs_i, const Candidate bs_j) { @@ -143,13 +146,13 @@ void DoNonMaxSuppressionOp( std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)> candidate_priority_queue(cmp); for (int i = 0; i < scores_data.size(); ++i) { - if (scores_data[i] > score_threshold) { + if (static_cast<float>(scores_data[i]) > score_threshold) { candidate_priority_queue.emplace(Candidate({i, scores_data[i]})); } } std::vector<int> selected; - std::vector<float> selected_scores; + std::vector<T> selected_scores; Candidate next_candidate; while (selected.size() < output_size && !candidate_priority_queue.empty()) { @@ -176,7 +179,7 @@ void DoNonMaxSuppressionOp( int num_valid_outputs = selected.size(); if (pad_to_max_output_size) { selected.resize(output_size, 0); - selected_scores.resize(output_size, 0); + selected_scores.resize(output_size, static_cast<T>(0)); } if (ptr_num_valid_outputs) { *ptr_num_valid_outputs = num_valid_outputs; @@ -221,18 +224,19 @@ class NonMaxSuppressionOp : public OpKernel { if (!context->status().ok()) { return; } - auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_); + auto suppress_check_fn = + CreateIOUSuppressCheckFn<float>(boxes, iou_threshold_); const float score_threshold_val = std::numeric_limits<float>::lowest(); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size, + score_threshold_val, suppress_check_fn); } private: float iou_threshold_; }; -template <typename Device> +template <typename Device, typename T> class NonMaxSuppressionV2Op : public OpKernel { public: explicit NonMaxSuppressionV2Op(OpKernelConstruction* context) @@ -264,11 +268,12 @@ class NonMaxSuppressionV2Op : public OpKernel { if (!context->status().ok()) { return; } - auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, 
iou_threshold_val); + auto suppress_check_fn = + CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val); const float score_threshold_val = std::numeric_limits<float>::lowest(); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size, + score_threshold_val, suppress_check_fn); } }; @@ -325,7 +330,7 @@ class NonMaxSuppressionV3V4Base : public OpKernel { float score_threshold_val_; }; -template <typename Device> +template <typename Device, typename T> class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base { public: explicit NonMaxSuppressionV3Op(OpKernelConstruction* context) @@ -334,14 +339,14 @@ class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base { protected: void DoComputeAndPostProcess(OpKernelContext* context) override { auto suppress_check_fn = - CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_); - DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, - score_threshold_val_, suppress_check_fn); + DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_, + score_threshold_val_, suppress_check_fn); } }; -template <typename Device> +template <typename Device, typename T> class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base { public: explicit NonMaxSuppressionV4Op(OpKernelConstruction* context) @@ -353,12 +358,12 @@ class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base { protected: void DoComputeAndPostProcess(OpKernelContext* context) override { auto suppress_check_fn = - CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_); int num_valid_outputs; - DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, - score_threshold_val_, suppress_check_fn, - pad_to_max_output_size_, &num_valid_outputs); + DoNonMaxSuppressionOp<T>(context, scores_, 
num_boxes_, max_output_size_, + score_threshold_val_, suppress_check_fn, + pad_to_max_output_size_, &num_valid_outputs); // Allocate scalar output tensor for number of indices computed. Tensor* num_outputs_t = nullptr; @@ -413,22 +418,37 @@ class NonMaxSuppressionWithOverlapsOp : public OpKernel { auto suppress_check_fn = CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size, + score_threshold_val, suppress_check_fn); } }; REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU), NonMaxSuppressionOp<CPUDevice>); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU), - NonMaxSuppressionV2Op<CPUDevice>); +REGISTER_KERNEL_BUILDER( + Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_CPU), + NonMaxSuppressionV2Op<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2") + .TypeConstraint<Eigen::half>("T") + .Device(DEVICE_CPU), + NonMaxSuppressionV2Op<CPUDevice, Eigen::half>); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU), - NonMaxSuppressionV3Op<CPUDevice>); +REGISTER_KERNEL_BUILDER( + Name("NonMaxSuppressionV3").TypeConstraint<float>("T").Device(DEVICE_CPU), + NonMaxSuppressionV3Op<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3") + .TypeConstraint<Eigen::half>("T") + .Device(DEVICE_CPU), + NonMaxSuppressionV3Op<CPUDevice, Eigen::half>); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU), - NonMaxSuppressionV4Op<CPUDevice>); +REGISTER_KERNEL_BUILDER( + Name("NonMaxSuppressionV4").TypeConstraint<float>("T").Device(DEVICE_CPU), + NonMaxSuppressionV4Op<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4") + .TypeConstraint<Eigen::half>("T") + .Device(DEVICE_CPU), + NonMaxSuppressionV4Op<CPUDevice, Eigen::half>); 
REGISTER_KERNEL_BUILDER( Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index 876a1704c7..fc1c9003aa 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/placer.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/function.h" @@ -104,13 +105,6 @@ class PartitionedCallOp : public AsyncOpKernel { for (auto d : lib->device_mgr()->ListDevices()) { device_set.AddDevice(d); } - Placer placer(graph.get(), &device_set); - OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done); - - std::unordered_map<string, std::unique_ptr<Graph>> subgraphs; - OP_REQUIRES_OK_ASYNC( - ctx, PartitionHelper(device_set, std::move(graph), &subgraphs), - done); // The FunctionLibraryRuntime's library cannot be mutated from within // an OpKernel, so functions are instantiated in an overlay library. @@ -124,6 +118,47 @@ class PartitionedCallOp : public AsyncOpKernel { new FunctionLibraryDefinition(*lib->GetFunctionLibraryDefinition()); overlay_libs_.emplace(lib, overlay_lib); + GraphOptimizationPassOptions optimization_options; + // TODO(akshayka): Thread SessionOptions (if any) into this kernel, or + // make it possible to specify the relevant options via attributes. 
+ SessionOptions session_options; + session_options.env = ctx->env(); + optimization_options.session_options = &session_options; + optimization_options.graph = &graph; + optimization_options.flib_def = overlay_lib; + optimization_options.device_set = &device_set; + OP_REQUIRES_OK_ASYNC( + ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::PRE_PLACEMENT, optimization_options), + done); + Placer placer(graph.get(), &device_set); + OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done); + OP_REQUIRES_OK_ASYNC( + ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_PLACEMENT, optimization_options), + done); + OP_REQUIRES_OK_ASYNC( + ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, + optimization_options), + done); + + std::unordered_map<string, std::unique_ptr<Graph>> subgraphs; + OP_REQUIRES_OK_ASYNC( + ctx, PartitionHelper(device_set, std::move(graph), &subgraphs), + done); + optimization_options.graph = nullptr; + optimization_options.device_set = nullptr; + optimization_options.partition_graphs = &subgraphs; + OP_REQUIRES_OK_ASYNC(ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_PARTITIONING, + optimization_options), + done); + auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>(); for (const auto& pair : subgraphs) { // TODO(akshayka): Fail gracefully if the set of devices corresponds @@ -175,7 +210,7 @@ class PartitionedCallOp : public AsyncOpKernel { TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value)); DataType dtype = attr_value->type(); if (dtype == DT_RESOURCE) { - ResourceHandle handle = args[index].flat<ResourceHandle>()(0); + const ResourceHandle& handle = args[index].flat<ResourceHandle>()(0); node->set_assigned_device_name(handle.device()); } } diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h index 5fb1c92f94..272aa3b4f5 100644 --- 
a/tensorflow/core/kernels/queue_base.h +++ b/tensorflow/core/kernels/queue_base.h @@ -19,6 +19,7 @@ limitations under the License. #include <deque> #include <vector> +#include "absl/base/macros.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/queue_interface.h" #include "tensorflow/core/framework/tensor.h" @@ -82,6 +83,9 @@ class QueueBase : public QueueInterface { // NOTE(mrry): This method is deprecated. Use // `tensorflow::batch_util::CopySliceToElement()` defined in // "./batch_util.h" instead. + ABSL_DEPRECATED( + "Use `tensorflow::batch_util::CopySliceToElement()` defined in " + "\"./batch_util.h\" instead.") static Status CopyElementToSlice(const Tensor& element, Tensor* parent, int64 index); diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc index c4d404259b..97ddc852f7 100644 --- a/tensorflow/core/kernels/queue_ops.cc +++ b/tensorflow/core/kernels/queue_ops.cc @@ -65,7 +65,7 @@ class FakeQueueOp : public OpKernel { } void Compute(OpKernelContext* context) override { - ResourceHandle ref = context->input(0).flat<ResourceHandle>()(0); + const ResourceHandle& ref = context->input(0).flat<ResourceHandle>()(0); handle_.AccessTensor(context)->flat<string>()(0) = ref.container(); handle_.AccessTensor(context)->flat<string>()(1) = ref.name(); context->set_output_ref(0, &mu_, handle_.AccessTensor(context)); diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index e37232539f..04a53697c0 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -231,7 +231,13 @@ class RandomUniformIntOp : public OpKernel { errors::InvalidArgument("maxval must be 0-D, got shape ", maxval.shape().DebugString())); - // Verify that minval < maxval + // Allocate output, and exit early if possible + Tensor* output; + OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output)); + if (output->NumElements() == 0) return; + + // 
Verify that minval < maxval. This check intentionally happens after the + // early exit for empty output. Zero impossible things are fine. IntType lo = minval.scalar<IntType>()(); IntType hi = maxval.scalar<IntType>()(); OP_REQUIRES( @@ -243,8 +249,6 @@ class RandomUniformIntOp : public OpKernel { Distribution; Distribution dist(lo, hi); - Tensor* output; - OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output)); auto output_flat = output->flat<IntType>(); functor::FillPhiloxRandom<Device, Distribution>()( ctx, ctx->eigen_device<Device>(), diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h index 88b3c2ac76..bb8254eaac 100644 --- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h +++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h @@ -21,11 +21,11 @@ limitations under the License. #define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "external/cub_archive/cub/device/device_reduce.cuh" -#include "external/cub_archive/cub/device/device_segmented_reduce.cuh" -#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh" -#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh" -#include "external/cub_archive/cub/warp/warp_reduce.cuh" +#include "third_party/cub/device/device_reduce.cuh" +#include "third_party/cub/device/device_segmented_reduce.cuh" +#include "third_party/cub/iterator/counting_input_iterator.cuh" +#include "third_party/cub/iterator/transform_input_iterator.cuh" +#include "third_party/cub/warp/warp_reduce.cuh" #include "cuda/include/cuComplex.h" #include "tensorflow/core/kernels/reduction_ops.h" #include "tensorflow/core/lib/core/bits.h" diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc index 9cf953f4bf..8bfa44b2d0 100644 --- a/tensorflow/core/kernels/reduction_ops_max.cc +++ b/tensorflow/core/kernels/reduction_ops_max.cc @@ -50,6 +50,8 @@ 
TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS); .TypeConstraint<int64>("Tidx") \ .HostMemory("reduction_indices"), \ ReductionOp<GPUDevice, type, int64, Eigen::internal::MaxReducer<type>>); + +REGISTER_GPU_KERNELS(Eigen::half); REGISTER_GPU_KERNELS(float); REGISTER_GPU_KERNELS(double); REGISTER_GPU_KERNELS(int64); diff --git a/tensorflow/core/kernels/regex_full_match_op.cc b/tensorflow/core/kernels/regex_full_match_op.cc index 5863a2c8e4..7edaaad8f7 100644 --- a/tensorflow/core/kernels/regex_full_match_op.cc +++ b/tensorflow/core/kernels/regex_full_match_op.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -56,4 +57,36 @@ class RegexFullMatchOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU), RegexFullMatchOp); +class StaticRegexFullMatchOp : public OpKernel { + public: + explicit StaticRegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string pattern; + OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern)); + re_ = MakeUnique<RE2>(pattern); + OP_REQUIRES(ctx, re_->ok(), + errors::InvalidArgument("Invalid pattern: ", pattern, + ", error: ", re_->error())); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + const auto& input_flat = input_tensor->flat<string>(); + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(), + &output_tensor)); + auto output_flat = output_tensor->flat<bool>(); + for (size_t i = 0; i < input_flat.size(); ++i) { + output_flat(i) = RE2::FullMatch(input_flat(i), *re_); + } + } + + private: + std::unique_ptr<RE2> re_; +}; + +REGISTER_KERNEL_BUILDER(Name("StaticRegexFullMatch").Device(DEVICE_CPU), + StaticRegexFullMatchOp); + } // namespace 
tensorflow diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index 194a711d98..26f107f940 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -47,7 +47,7 @@ std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts( std::unordered_set<string> retval; for (const string& node_name_and_port : node_names_and_ports) { const TensorId tid = ParseTensorName(node_name_and_port); - retval.emplace(std::string(tid.first)); + retval.emplace(tid.first); } return retval; } @@ -64,7 +64,7 @@ Node* FindMutableNodeByName(const string& name, Graph* graph) { const NodeDef* FindNodeDefByName(const string& input, const GraphDef& graph_def) { const TensorId tid = ParseTensorName(input); - const string name = std::string(tid.first); + const string name = string(tid.first); for (const NodeDef& node_def : graph_def.node()) { if (node_def.name() == name) { return &node_def; @@ -423,7 +423,7 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( std::vector<DataType> data_types; std::vector<TensorShape> shapes; const TensorId tid = ParseTensorName(name_and_port); - const string node_name = std::string(tid.first); + const string node_name(tid.first); const int port = tid.second; const NodeDef* node_def = FindNodeDefByName(node_name, graph_def); CHECK_NOTNULL(node_def); @@ -522,8 +522,7 @@ RemoteFusedGraphExecuteUtils::GetTensorShapeType( const TensorShapeMap& tensor_shape_map, const string& node_name) { if (node_name.find(':') != string::npos) { const TensorId tid = ParseTensorName(node_name); - return GetTensorShapeType(tensor_shape_map, std::string(tid.first), - tid.second); + return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second); } else { return GetTensorShapeType(tensor_shape_map, node_name, 0); } @@ -570,7 +569,7 @@ 
RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto( const TensorId tid = ParseTensorName(name); CHECK_EQ(tensor_shape_map->count(name), 0); tensor_shape_map->emplace( - std::string(tid.first), + string(tid.first), std::make_pair(tid.second, std::make_pair(tensor.dtype(), tensor.shape()))); } @@ -692,7 +691,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( std::vector<NodeBuilder::NodeOut> node_out_list; for (const string& input : inputs) { const TensorId tid = ParseTensorName(input); - Node* node = FindMutableNodeByName(std::string(tid.first), graph); + Node* node = FindMutableNodeByName(string(tid.first), graph); CHECK_NOTNULL(node); node_out_list.emplace_back(node, tid.second); } @@ -848,7 +847,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( for (const string& subgraph_input : std::get<1>(cluster)) { const TensorId tid = ParseTensorName(subgraph_input); - const string subgraph_input_name = std::string(tid.first); + const string subgraph_input_name(tid.first); const int subgraph_input_port = tid.second; const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def); CHECK_NOTNULL(node_def); @@ -895,7 +894,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( std::deque<const Node*> queue; for (const string& output : border_outputs) { const TensorId tid = ParseTensorName(output); - const string& output_node_name = std::string(tid.first); + const string output_node_name(tid.first); for (const Node* node : graph.nodes()) { if (output_node_name == node->name()) { queue.push_back(node); @@ -975,7 +974,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( for (int j = 0; j < border_outputs.size(); ++j) { const string& output = border_outputs.at(j); const TensorId tid = ParseTensorName(output); - const string output_name = std::string(tid.first); + const string output_name(tid.first); Node* src_node = edge->src(); if (src_node != nullptr && src_node->name() == 
output_name && edge->src_output() == tid.second) { @@ -995,12 +994,11 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( // RemoteFusedGraphExecuteOpNode for (const string& output : outputs) { const TensorId output_tid = ParseTensorName(output); - const string output_name = std::string(output_tid.first); + const string output_name(output_tid.first); for (size_t i = 0; i < border_outputs.size(); ++i) { const TensorId subgraph_output_tid = ParseTensorName(border_outputs.at(i)); - const string& subgraph_output_name = - std::string(subgraph_output_tid.first); + const string subgraph_output_name(subgraph_output_tid.first); if (output_name == subgraph_output_name) { LOG(INFO) << "As graph output and subgraph output are same, " << "the graph output node is replaced by identity node"; @@ -1435,7 +1433,7 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions( GraphDef* graph_def) { const TensorId tid = ParseTensorName(input); CHECK_EQ(0, tid.second); - const string node_name = std::string(tid.first); + const string node_name(tid.first); for (NodeDef& node : *graph_def->mutable_node()) { if (node.name() != node_name) { continue; diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index ebcfb673d1..427044ca67 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -51,7 +51,9 @@ limitations under the License. #define EIGEN_USE_GPU #endif -#include "tensorflow/core/kernels/resource_variable_ops.h" +#include <memory> +#include <vector> + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -60,10 +62,12 @@ limitations under the License. 
#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/dense_update_functor.h" #include "tensorflow/core/kernels/gather_functor.h" +#include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/scatter_functor.h" #include "tensorflow/core/kernels/training_op_helpers.h" #include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -72,6 +76,8 @@ limitations under the License. namespace tensorflow { REGISTER_RESOURCE_HANDLE_KERNEL(Var); +REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU), + ResourceHandlesOp<Var>); ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_)); @@ -79,7 +85,7 @@ ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) { void ReadVariableOp::Compute(OpKernelContext* ctx) { Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, 0); + const ResourceHandle& handle = HandleFromInput(ctx, 0); const auto status = LookupResource(ctx, handle, &variable); OP_REQUIRES(ctx, status.ok(), errors::FailedPrecondition( @@ -101,13 +107,58 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) { ctx->set_output(0, t); } +ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) { + int n; + OP_REQUIRES_OK(c, c->GetAttr("N", &n)); + OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_)); + OP_REQUIRES(c, n == dtypes_.size(), + errors::InvalidArgument( + "Mismatched number of arguments to ReadVariablesOp (", n, + " vs. 
", dtypes_.size(), ")")); +} + +void ReadVariablesOp::Compute(OpKernelContext* ctx) { + std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables( + dtypes_.size()); + std::vector<const ResourceHandle*> handles(dtypes_.size()); + for (size_t i = 0; i < dtypes_.size(); ++i) { + handles[i] = &HandleFromInput(ctx, i); + } + const auto status = LookupResources(ctx, handles, &variables); + OP_REQUIRES(ctx, status.ok(), + errors::FailedPrecondition( + "Error while reading resource variable. This could mean that " + "the variable was uninitialized. ", + status.ToString())); + + for (size_t i = 0; i < dtypes_.size(); ++i) { + // We're acquiring a reference to the underlying buffer while + // holding a shared lock to guarantee ordering of reads and + // writes. + tf_shared_lock ml(*variables[i]->mu()); + const Tensor& t = *variables[i]->tensor(); + OP_REQUIRES(ctx, dtypes_[i] == t.dtype(), + errors::InvalidArgument( + "Trying to read variable ", handles[i]->name(), + " from Container: ", handles[i]->container(), + " with wrong dtype. 
Expected ", DataTypeString(dtypes_[i]), + " got ", DataTypeString(t.dtype()))); + ctx->set_output(i, t); + } +} + REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU), ReadVariableOp); +REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU), + ReadVariablesOp); #if GOOGLE_CUDA REGISTER_KERNEL_BUILDER( Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"), ReadVariableOp); +REGISTER_KERNEL_BUILDER( + Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"), + ReadVariablesOp); #define REGISTER_GPU_KERNELS(type) \ namespace functor { \ @@ -121,7 +172,12 @@ REGISTER_KERNEL_BUILDER( .Device(DEVICE_GPU) \ .HostMemory("resource") \ .TypeConstraint<type>("dtype"), \ - ResourceHandleOp<Var>) + ResourceHandleOp<Var>) \ + REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp") \ + .Device(DEVICE_GPU) \ + .HostMemory("resources") \ + .TypeConstraint<type>("dtypes"), \ + ResourceHandlesOp<Var>) TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); TF_CALL_int64(REGISTER_GPU_KERNELS); diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h index 9b60106f13..cffb732c38 100644 --- a/tensorflow/core/kernels/resource_variable_ops.h +++ b/tensorflow/core/kernels/resource_variable_ops.h @@ -28,6 +28,16 @@ class ReadVariableOp : public OpKernel { DataType dtype_; }; +class ReadVariablesOp : public OpKernel { + public: + explicit ReadVariablesOp(OpKernelConstruction* c); + void Compute(OpKernelContext* ctx) override; + bool IsExpensive() override { return false; } + + private: + DataTypeVector dtypes_; +}; + class DestroyResourceOp : public OpKernel { public: explicit DestroyResourceOp(OpKernelConstruction* ctx); diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc index 15a707a9c6..cded417986 100644 --- a/tensorflow/core/kernels/reverse_sequence_op.cc +++ b/tensorflow/core/kernels/reverse_sequence_op.cc @@ -64,7 +64,7 @@ void 
CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim), errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim, "), ", "(", seq_lens.NumElements(), - " vs. ", input.dim_size(batch_dim))); + " vs. ", input.dim_size(batch_dim), ")")); for (size_t d = 0; d < seq_lens_vec.size(); ++d) { OP_REQUIRES(context, seq_lens_vec[d] >= 0, @@ -91,7 +91,7 @@ void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) { OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim), errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim, "), ", "(", seq_lens.NumElements(), - " vs. ", input.dim_size(batch_dim))); + " vs. ", input.dim_size(batch_dim), ")")); } template <> @@ -127,6 +127,7 @@ class ReverseSequenceOp : public OpKernel { auto seq_lens_t = seq_lens.vec<Tlen>(); CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_); + if (!context->status().ok()) return; const int input_dims = input.dims(); diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index e335e38bdc..82546d581a 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -161,9 +161,12 @@ void RestoreTensor(OpKernelContext* context, // If we cannot find a cached reader we will allocate our own. 
std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader; - const checkpoint::TensorSliceReader* reader = - context->slice_reader_cache()->GetReader(file_pattern, open_func, - preferred_shard); + const checkpoint::TensorSliceReader* reader = nullptr; + + if (context->slice_reader_cache()) { + reader = context->slice_reader_cache()->GetReader(file_pattern, open_func, + preferred_shard); + } if (!reader) { allocated_reader.reset(new checkpoint::TensorSliceReader( file_pattern, open_func, preferred_shard)); diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc index ab4de6c815..180eb3ca34 100644 --- a/tensorflow/core/kernels/save_restore_v2_ops.cc +++ b/tensorflow/core/kernels/save_restore_v2_ops.cc @@ -220,9 +220,9 @@ class MergeV2Checkpoints : public OpKernel { context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix)); if (delete_old_dirs_) { - const string& merged_dir = std::string(io::Dirname(merged_prefix)); + const string merged_dir(io::Dirname(merged_prefix)); for (const string& input_prefix : input_prefixes) { - const string& dirname = std::string(io::Dirname(input_prefix)); + const string dirname(io::Dirname(input_prefix)); if (dirname == merged_dir) continue; Status status = env->DeleteDir(dirname); // For sharded save, only the first delete will go through and all diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index e0194605ce..2f8aede427 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -145,6 +145,7 @@ class ScatterNdUpdateOp : public OpKernel { if (dtype_ == DT_RESOURCE) { Var* v; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); + core::ScopedUnref scoped_unref(v); mutex_lock m(*v->mu()); DoCompute(c); } else if (use_exclusive_lock_) { diff --git a/tensorflow/core/kernels/searchsorted_op.cc b/tensorflow/core/kernels/searchsorted_op.cc new file mode 100644 index 
0000000000..dc627ac77a --- /dev/null +++ b/tensorflow/core/kernels/searchsorted_op.cc @@ -0,0 +1,249 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/kernels/searchsorted_op.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +typedef Eigen::ThreadPoolDevice CPUDevice; +typedef Eigen::GpuDevice GPUDevice; + +namespace functor { +template <typename T, typename OutType> +struct UpperBoundFunctor<CPUDevice, T, OutType> { + static Status Compute(OpKernelContext* context, + const typename TTypes<T, 1>::ConstTensor& sorted_inputs, + const typename TTypes<T, 1>::ConstTensor& values, + int batch_size, int num_inputs, int num_values, + typename TTypes<OutType, 1>::Tensor* output) { + // TODO(eriche): If anyone ever needs this to be faster, we can multithread. 
+ for (int b = 0; b < batch_size; ++b) { + const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs; + OutType* output_ptr = output->data() + b * num_values; + for (int i = 0; i < num_values; ++i) { + output_ptr[i] = + std::upper_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs, + values(i + b * num_values)) - + sorted_inputs_ptr; + } + } + + return Status::OK(); + } +}; + +template <typename T, typename OutType> +struct LowerBoundFunctor<CPUDevice, T, OutType> { + static Status Compute(OpKernelContext* context, + const typename TTypes<T, 1>::ConstTensor& sorted_inputs, + const typename TTypes<T, 1>::ConstTensor& values, + int batch_size, int num_inputs, int num_values, + typename TTypes<OutType, 1>::Tensor* output) { + // TODO(eriche): If anyone ever needs this to be faster, we can multithread. + for (int b = 0; b < batch_size; ++b) { + const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs; + OutType* output_ptr = output->data() + b * num_values; + for (int i = 0; i < num_values; ++i) { + output_ptr[i] = + std::lower_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs, + values(i + b * num_values)) - + sorted_inputs_ptr; + } + } + + return Status::OK(); + } +}; +} // namespace functor + +template <typename Device, typename T, typename OutType> +class UpperBoundOp : public OpKernel { + public: + explicit UpperBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& sorted_inputs_t = ctx->input(0); + const Tensor& values_t = ctx->input(1); + + // must have same batch dim_size for both + OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0), + Status(error::INVALID_ARGUMENT, + "Leading dim_size of both tensors must match.")); + + // this is required because we do indexing in int32 on the GPU + OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(), + Status(error::INVALID_ARGUMENT, + "values tensor size must less than INT_MAX")); + + 
Tensor* output_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t)); + + if (output_t->dtype() == DT_INT32) { + OP_REQUIRES(ctx, + FastBoundsCheck(sorted_inputs_t.dim_size(1), + std::numeric_limits<int>::max()), + errors::InvalidArgument("trailing dim_size must less than " + "INT_MAX for int32 output type, was ", + sorted_inputs_t.dim_size(1))); + } + + auto output = output_t->template flat<OutType>(); + const auto sorted_inputs = sorted_inputs_t.template flat<T>(); + const auto values = values_t.template flat<T>(); + OP_REQUIRES_OK( + ctx, functor::UpperBoundFunctor<Device, T, OutType>::Compute( + ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0), + sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output)); + } +}; + +template <typename Device, typename T, typename OutType> +class LowerBoundOp : public OpKernel { + public: + explicit LowerBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + const Tensor& sorted_inputs_t = ctx->input(0); + const Tensor& values_t = ctx->input(1); + + // must have same batch dim_size for both + OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0), + Status(error::INVALID_ARGUMENT, + "Leading dim_size of both tensors must match.")); + + // this is required because we do indexing in int32 on the GPU + OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(), + Status(error::INVALID_ARGUMENT, + "values tensor size must less than INT_MAX")); + + Tensor* output_t; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t)); + + if (output_t->dtype() == DT_INT32) { + OP_REQUIRES(ctx, + FastBoundsCheck(sorted_inputs_t.dim_size(1), + std::numeric_limits<int>::max()), + errors::InvalidArgument("trailing dim_size must less than " + "INT_MAX for int32 output type, was ", + sorted_inputs_t.dim_size(1))); + } + + auto output = output_t->template flat<OutType>(); + const auto sorted_inputs = 
sorted_inputs_t.template flat<T>(); + const auto values = values_t.template flat<T>(); + OP_REQUIRES_OK( + ctx, functor::LowerBoundFunctor<Device, T, OutType>::Compute( + ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0), + sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output)); + } +}; + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("UpperBound") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int32>("out_type"), \ + UpperBoundOp<CPUDevice, type, int32>); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("UpperBound") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int64>("out_type"), \ + UpperBoundOp<CPUDevice, type, int64>); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#if GOOGLE_CUDA + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("UpperBound") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int32>("out_type"), \ + UpperBoundOp<GPUDevice, type, int32>); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("UpperBound") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int64>("out_type"), \ + UpperBoundOp<GPUDevice, type, int64>); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#endif // GOOGLE_CUDA + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("LowerBound") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int32>("out_type"), \ + LowerBoundOp<CPUDevice, type, int32>); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("LowerBound") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int64>("out_type"), \ + 
LowerBoundOp<CPUDevice, type, int64>); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#if GOOGLE_CUDA + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("LowerBound") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int32>("out_type"), \ + LowerBoundOp<GPUDevice, type, int32>); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("LowerBound") \ + .Device(DEVICE_GPU) \ + .TypeConstraint<type>("T") \ + .TypeConstraint<int64>("out_type"), \ + LowerBoundOp<GPUDevice, type, int64>); + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +#endif // GOOGLE_CUDA +} // namespace tensorflow diff --git a/tensorflow/core/kernels/searchsorted_op.h b/tensorflow/core/kernels/searchsorted_op.h new file mode 100644 index 0000000000..f075bf0fa2 --- /dev/null +++ b/tensorflow/core/kernels/searchsorted_op.h @@ -0,0 +1,52 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_ +#define TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace functor { + +template <typename Device, typename T, typename OutType> +struct UpperBoundFunctor { + // Searches for values in sorted_inputs and returns the greatest possible + // index where they maintain sorted order. + static Status Compute(OpKernelContext* context, + const typename TTypes<T, 1>::ConstTensor& sorted_inputs, + const typename TTypes<T, 1>::ConstTensor& values, + int batch_size, int num_inputs, int num_values, + typename TTypes<OutType, 1>::Tensor* output); +}; + +template <typename Device, typename T, typename OutType> +struct LowerBoundFunctor { + // Searches for values in sorted_inputs and returns the lowest possible + // index where they maintain sorted order. + static Status Compute(OpKernelContext* context, + const typename TTypes<T, 1>::ConstTensor& sorted_inputs, + const typename TTypes<T, 1>::ConstTensor& values, + int batch_size, int num_inputs, int num_values, + typename TTypes<OutType, 1>::Tensor* output); +}; +} // namespace functor + +} // end namespace tensorflow +#endif // TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_ diff --git a/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc new file mode 100644 index 0000000000..263b5bf298 --- /dev/null +++ b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc @@ -0,0 +1,126 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#define EIGEN_USE_GPU + +#include "tensorflow/core/kernels/searchsorted_op.h" + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + +namespace tensorflow { +typedef Eigen::GpuDevice GPUDevice; + +namespace { +template <typename T, typename OutType> +__global__ void UpperBoundKernel(const T* sorted_inputs, int batch_size, + int sorted_inputs_size, int values_size, + const T* values, OutType* outputs) { + CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) { + int bid = work_unit_id / values_size; + T value = values[work_unit_id]; + outputs[work_unit_id] = cuda_helper::upper_bound<T, OutType>( + sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value); + } +} + +template <typename T, typename OutType> +__global__ void LowerBoundKernel(const T* sorted_inputs, int batch_size, + int sorted_inputs_size, int values_size, + const T* values, OutType* outputs) { + CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) { + int bid = work_unit_id / values_size; + T value = values[work_unit_id]; + outputs[work_unit_id] = cuda_helper::lower_bound<T, OutType>( + sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value); + } +} +} // namespace + 
+namespace functor { +template <typename T, typename OutType> +struct UpperBoundFunctor<GPUDevice, T, OutType> { + static Status Compute(OpKernelContext* context, + const typename TTypes<T, 1>::ConstTensor& sorted_inputs, + const typename TTypes<T, 1>::ConstTensor& values, + int batch_size, int num_inputs, int num_values, + typename TTypes<OutType, 1>::Tensor* output) { + const cudaStream_t& stream = GetCudaStream(context); + CudaLaunchConfig config = + GetCudaLaunchConfig(values.size(), context->eigen_gpu_device()); + + UpperBoundKernel<T> + <<<config.block_count, config.thread_per_block, 0, stream>>>( + sorted_inputs.data(), batch_size, num_inputs, num_values, + values.data(), output->data()); + + return Status::OK(); + } +}; + +template <typename T, typename OutType> +struct LowerBoundFunctor<GPUDevice, T, OutType> { + static Status Compute(OpKernelContext* context, + const typename TTypes<T, 1>::ConstTensor& sorted_inputs, + const typename TTypes<T, 1>::ConstTensor& values, + int batch_size, int num_inputs, int num_values, + typename TTypes<OutType, 1>::Tensor* output) { + const cudaStream_t& stream = GetCudaStream(context); + CudaLaunchConfig config = + GetCudaLaunchConfig(values.size(), context->eigen_gpu_device()); + + LowerBoundKernel<T> + <<<config.block_count, config.thread_per_block, 0, stream>>>( + sorted_inputs.data(), batch_size, num_inputs, num_values, + values.data(), output->data()); + + return Status::OK(); + } +}; +} // namespace functor + +#define REGISTER_GPU_SPEC(type) \ + template struct functor::UpperBoundFunctor<GPUDevice, type, int32>; + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC); +#undef REGISTER_GPU_SPEC + +#define REGISTER_GPU_SPEC(type) \ + template struct functor::UpperBoundFunctor<GPUDevice, type, int64>; + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC); +#undef REGISTER_GPU_SPEC + +#define REGISTER_GPU_SPEC(type) \ + template struct functor::LowerBoundFunctor<GPUDevice, type, int32>; + 
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC); +#undef REGISTER_GPU_SPEC + +#define REGISTER_GPU_SPEC(type) \ + template struct functor::LowerBoundFunctor<GPUDevice, type, int64>; + +TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC); +#undef REGISTER_GPU_SPEC +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/shape_op_test.cc b/tensorflow/core/kernels/shape_op_test.cc index 9cd590ae61..30cb1e0a7f 100644 --- a/tensorflow/core/kernels/shape_op_test.cc +++ b/tensorflow/core/kernels/shape_op_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/abi.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -60,8 +61,7 @@ Status GetShapeFromKnownVecSize(const KnownVecSize& ks, TensorShape* s) { REGISTER_UNARY_VARIANT_DECODE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE"); -REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE", - GetShapeFromKnownVecSize); +REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, GetShapeFromKnownVecSize); static void ExpectHasError(const Status& s, StringPiece substr) { EXPECT_TRUE(str_util::StrContains(s.ToString(), substr)) @@ -94,9 +94,9 @@ TEST_F(ShapeOpTest, Simple) { Status s = session.Run({{input, variant_tensor}}, {shape_output}, &outputs); EXPECT_FALSE(s.ok()); ExpectHasError( - s, - "No unary variant shape function found for Variant type_name: " - "NO KNOWN SHAPE"); + s, strings::StrCat( + "No unary variant shape function found for Variant type_index: ", + port::MaybeAbiDemangle(MakeTypeIndex<NoKnownShape>().name()))); } { diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc index 77594479cb..97f77e45b6 100644 --- a/tensorflow/core/kernels/slice_op.cc +++ b/tensorflow/core/kernels/slice_op.cc @@ -411,7 +411,7 @@ class MklSliceOp 
: public OpKernel { context->input(0).tensor<T, NDIM>(), indices, sizes); } }; -#endif +#endif // INTEL_MKL // Forward declarations of the functor specializations for declared in the // sharded source files. @@ -440,18 +440,14 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N); #undef DECLARE_CPU_SPEC } // namespace functor -#ifndef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #define REGISTER_SLICE(type) \ REGISTER_KERNEL_BUILDER(Name("Slice") \ .Device(DEVICE_CPU) \ .TypeConstraint<type>("T") \ .HostMemory("begin") \ .HostMemory("size"), \ - SliceOp<CPUDevice, type>) - -TF_CALL_POD_STRING_TYPES(REGISTER_SLICE); -TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE); -#undef REGISTER_SLICE + MklSliceOp<CPUDevice, type>) #else #define REGISTER_SLICE(type) \ REGISTER_KERNEL_BUILDER(Name("Slice") \ @@ -459,12 +455,12 @@ TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE); .TypeConstraint<type>("T") \ .HostMemory("begin") \ .HostMemory("size"), \ - MklSliceOp<CPUDevice, type>) + SliceOp<CPUDevice, type>) +#endif // INTEL_MKL && ENABLE_MKL TF_CALL_POD_STRING_TYPES(REGISTER_SLICE); TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE); #undef REGISTER_SLICE -#endif // INTEL_MKL #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. 
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator.h b/tensorflow/core/kernels/sparse_conditional_accumulator.h index 11149c4d16..a4453bd7ab 100644 --- a/tensorflow/core/kernels/sparse_conditional_accumulator.h +++ b/tensorflow/core/kernels/sparse_conditional_accumulator.h @@ -50,10 +50,10 @@ class SparseConditionalAccumulator public: SparseConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape, - const string& name) + const string& name, const string& reduction_type) : TypedConditionalAccumulatorBase< std::tuple<const Tensor*, const Tensor*, const Tensor*>>( - dtype, shape, name) { + dtype, shape, name, reduction_type) { accum_idx_vec_ = nullptr; count_element_ = nullptr; accum_val_ = nullptr; diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc index 80bc1f1934..1e542a26a7 100644 --- a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc @@ -34,8 +34,8 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { Creator GetCreator() const override { return [this](ConditionalAccumulatorBase** ret) { SparseConditionalAccumulator<Device, T>* accumulator = - new SparseConditionalAccumulator<Device, T>(dtype_, shape_, - cinfo_.name()); + new SparseConditionalAccumulator<Device, T>( + dtype_, shape_, cinfo_.name(), reduction_type_); *ret = accumulator; return Status::OK(); }; diff --git a/tensorflow/core/kernels/split_lib_gpu.cu.cc b/tensorflow/core/kernels/split_lib_gpu.cu.cc index 393818730b..a4a59dbcbc 100644 --- a/tensorflow/core/kernels/split_lib_gpu.cu.cc +++ b/tensorflow/core/kernels/split_lib_gpu.cu.cc @@ -54,6 +54,7 @@ void SplitCustom<Device, T>::operator()( TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS); TF_CALL_complex64(DEFINE_GPU_KERNELS); TF_CALL_complex128(DEFINE_GPU_KERNELS); +TF_CALL_int64(DEFINE_GPU_KERNELS); 
TF_CALL_bfloat16(DEFINE_GPU_KERNELS); #undef DEFINE_GPU_KERNELS diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc index 7cc3c532c9..11db72bfa3 100644 --- a/tensorflow/core/kernels/split_op.cc +++ b/tensorflow/core/kernels/split_op.cc @@ -49,7 +49,12 @@ class SplitOpBase : public OpKernel { void ComputeEasyCases(OpKernelContext* context, bool* done) { const Tensor& input = context->input(1); const TensorShape& input_shape = input.shape(); - const int32 split_dim_orig = context->input(0).flat<int32>()(0); + const Tensor& split_dim_tensor = context->input(0); + OP_REQUIRES( + context, split_dim_tensor.shape().dims() == 0, + errors::InvalidArgument("split_dim must be a scalar but has rank ", + split_dim_tensor.shape().dims())); + const int32 split_dim_orig = split_dim_tensor.flat<int32>()(0); const int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; const int32 num_split = num_outputs(); diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc index 65296f61fd..add4afafc9 100644 --- a/tensorflow/core/kernels/stack_ops.cc +++ b/tensorflow/core/kernels/stack_ops.cc @@ -131,10 +131,8 @@ class Stack : public ResourceBase { }; Status GetStack(OpKernelContext* ctx, Stack** stack) { - string key; if (ctx->input_dtype(0) == DT_RESOURCE) { - auto resource = ctx->input(0).flat<ResourceHandle>()(0); - key = resource.name(); + return LookupResource(ctx, HandleFromInput(ctx, 0), stack); } else { Tensor Tstack_handle = ctx->mutable_input(0, false); if (Tstack_handle.NumElements() != 2) { @@ -144,18 +142,18 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) { } const string& container = Tstack_handle.flat<string>()(0); const string& stack_name = Tstack_handle.flat<string>()(1); - key = strings::StrCat(container, stack_name); - } - ResourceMgr* rm = ctx->resource_manager(); - if (rm == nullptr) { - return errors::Internal("No resource manager."); - } - auto* step_container = 
ctx->step_container(); - if (step_container == nullptr) { - return errors::Internal("No step container."); + string key = strings::StrCat(container, stack_name); + ResourceMgr* rm = ctx->resource_manager(); + if (rm == nullptr) { + return errors::Internal("No resource manager."); + } + auto* step_container = ctx->step_container(); + if (step_container == nullptr) { + return errors::Internal("No step container."); + } + TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack)); + return Status::OK(); } - TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack)); - return Status::OK(); } std::atomic<int64> Stack::stack_counter{0}; diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 7b537fef5b..f0575de4d9 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -306,6 +306,7 @@ class StridedSliceAssignOp : public OpKernel { Var* v; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &v)); + core::ScopedUnref scoped_unref(v); mutex_lock ml(*v->mu()); OP_REQUIRES_OK(context, PrepareToUpdateVariable<Device, T>(context, v->tensor())); diff --git a/tensorflow/core/kernels/string_format_op.cc b/tensorflow/core/kernels/string_format_op.cc new file mode 100644 index 0000000000..e4a1887f8d --- /dev/null +++ b/tensorflow/core/kernels/string_format_op.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <iostream> +#include "absl/strings/str_split.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +class StringFormatOp : public OpKernel { + public: + explicit StringFormatOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string template_; + OP_REQUIRES_OK(ctx, ctx->GetAttr("template", &template_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("placeholder", &placeholder_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_)); + + split_template_ = absl::StrSplit(template_, placeholder_); + int64 num_placeholders = split_template_.size() - 1; + OP_REQUIRES(ctx, ctx->num_inputs() == num_placeholders, + errors::InvalidArgument(strings::StrCat( + "num placeholders in template and num inputs must match: ", + num_placeholders, " vs. 
", ctx->num_inputs()))); + } + + void Compute(OpKernelContext* ctx) override { + Tensor* formatted_string = nullptr; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &formatted_string)); + + string msg; + strings::StrAppend(&msg, split_template_[0].c_str()); + for (int i = 0; i < ctx->num_inputs(); ++i) { + strings::StrAppend(&msg, ctx->input(i).SummarizeValue(summarize_, true)); + strings::StrAppend(&msg, split_template_[i + 1].c_str()); + } + + formatted_string->scalar<string>()() = msg; + } + + private: + int32 summarize_ = 0; + string placeholder_; + std::vector<std::string> split_template_; +}; + +REGISTER_KERNEL_BUILDER(Name("StringFormat").Device(DEVICE_CPU), + StringFormatOp); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/string_format_op_test.cc b/tensorflow/core/kernels/string_format_op_test.cc new file mode 100644 index 0000000000..13130a5797 --- /dev/null +++ b/tensorflow/core/kernels/string_format_op_test.cc @@ -0,0 +1,66 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace tensorflow { +namespace { + +class StringFormatGraphTest : public OpsTestBase { + protected: + Status Init(int num_inputs, DataType input_type, + const string& template_ = "%s", const string& placeholder = "%s", + int summarize = 3) { + TF_CHECK_OK(NodeDefBuilder("op", "StringFormat") + .Input(FakeInput(num_inputs, input_type)) + .Attr("template", template_) + .Attr("placeholder", placeholder) + .Attr("summarize", summarize) + .Finalize(node_def())); + return InitOp(); + } +}; + +TEST_F(StringFormatGraphTest, Int32Success_7) { + TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s")); + + AddInputFromArray<int32>(TensorShape({7}), {1, 2, 3, 4, 5, 6, 7}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({})); + test::FillValues<string>(&expected, {"First tensor: [1 2 3 ... 5 6 7]"}); + test::ExpectTensorEqual<string>(expected, *GetOutput(0)); +} + +TEST_F(StringFormatGraphTest, Int32Success_3_3) { + TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s", "%s", 1)); + + AddInputFromArray<int32>(TensorShape({3, 3}), {1, 2, 3, 4, 5, 6, 7, 8, 9}); + TF_ASSERT_OK(RunOpKernel()); + Tensor expected(allocator(), DT_STRING, TensorShape({})); + test::FillValues<string>(&expected, {"First tensor: [[1 ... 3]\n ..." + "\n [7 ... 
9]]"}); + test::ExpectTensorEqual<string>(expected, *GetOutput(0)); +} + +} // end namespace +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/string_length_op.cc b/tensorflow/core/kernels/string_length_op.cc index a6829b29d9..435a7abdca 100644 --- a/tensorflow/core/kernels/string_length_op.cc +++ b/tensorflow/core/kernels/string_length_op.cc @@ -14,13 +14,18 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/kernels/string_util.h" namespace tensorflow { namespace { class StringLengthOp : public OpKernel { public: - using OpKernel::OpKernel; + explicit StringLengthOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string unit; + OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit)); + OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_)); + } void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); @@ -32,10 +37,22 @@ class StringLengthOp : public OpKernel { auto src = input.flat<string>(); auto dst = output->flat<int32>(); - for (int n = 0; n < src.size(); ++n) { - dst(n) = src(n).size(); + switch (unit_) { + case CharUnit::BYTE: + for (int n = 0; n < src.size(); ++n) { + dst(n) = src(n).size(); + } + break; + case CharUnit::UTF8_CHAR: + for (int n = 0; n < src.size(); ++n) { + dst(n) = UTF8StrLen(src(n)); + } + break; } } + + private: + CharUnit unit_ = CharUnit::BYTE; }; REGISTER_KERNEL_BUILDER(Name("StringLength").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc index 2aeafa28c4..544dca96ba 100644 --- a/tensorflow/core/kernels/string_strip_op.cc +++ b/tensorflow/core/kernels/string_strip_op.cc @@ -43,7 +43,7 @@ class StringStripOp : public OpKernel { for (int64 i = 0; i < input.size(); ++i) { StringPiece entry(input(i)); str_util::RemoveWhitespaceContext(&entry); - 
output(i) = std::string(entry); + output(i) = string(entry); } } }; diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc new file mode 100644 index 0000000000..3a9803a052 --- /dev/null +++ b/tensorflow/core/kernels/string_util.cc @@ -0,0 +1,63 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/string_util.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace { +inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; } +} // namespace + +namespace tensorflow { + +// Sets encoding value based on str. +Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding) { + if (str == "UTF8") { + *encoding = UnicodeEncoding::UTF8; + } else { + return errors::InvalidArgument(strings::StrCat( + "Invalid encoding \"", str, "\": Should be one of: UTF8")); + } + return Status::OK(); +} + +// Sets unit value based on str. +Status ParseCharUnit(const string& str, CharUnit* unit) { + if (str == "BYTE") { + *unit = CharUnit::BYTE; + } else if (str == "UTF8_CHAR") { + *unit = CharUnit::UTF8_CHAR; + } else { + return errors::InvalidArgument(strings::StrCat( + "Invalid unit \"", str, "\": Should be one of: BYTE, UTF8_CHAR")); + } + return Status::OK(); +} + +// Return the number of Unicode characters in a UTF-8 string. 
+// Result may be incorrect if the input string is not valid UTF-8. +int32 UTF8StrLen(const string& string) { + const int32 byte_size = string.size(); + const char* const end = string.data() + byte_size; + const char* ptr = string.data(); + int32 skipped_count = 0; + while (ptr < end) { + skipped_count += IsTrailByte(*ptr++) ? 1 : 0; + } + const int32 result = byte_size - skipped_count; + return result; +} + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h new file mode 100644 index 0000000000..390cf57702 --- /dev/null +++ b/tensorflow/core/kernels/string_util.h @@ -0,0 +1,45 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ + +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Enumeration for unicode encodings. Used by ops such as +// tf.strings.unicode_encode and tf.strings.unicode_decode. +// TODO(edloper): Add support for: +// UTF16, UTF32, UTF16BE, UTF32BE, UTF16LE, UTF32LE +enum class UnicodeEncoding { UTF8 }; + +// Enumeration for character units. Used by string ops such as +// tf.strings.length and tf.substr. +// TODO(edloper): Add support for: UTF32_CHAR, etc. +enum class CharUnit { BYTE, UTF8_CHAR }; + +// Sets `encoding` based on `str`. 
+Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding); + +// Sets `unit` value based on `str`. +Status ParseCharUnit(const string& str, CharUnit* unit); + +// Returns the number of Unicode characters in a UTF-8 string. +// Result may be incorrect if the input string is not valid UTF-8. +int32 UTF8StrLen(const string& string); + +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index 22e45918a0..07f1d6e767 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <cstddef> +#include <cstdlib> #include <string> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -25,6 +27,8 @@ limitations under the License. 
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/bcast.h" namespace tensorflow { @@ -64,26 +68,28 @@ class SubstrOp : public OpKernel { const T len = tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()()); for (size_t i = 0; i < input_tensor.NumElements(); ++i) { - string in = input(i); + StringPiece in(input(i)); OP_REQUIRES( - context, FastBoundsCheck(pos, in.size() + 1), + context, FastBoundsCheck(std::abs(pos), in.size() + 1), errors::InvalidArgument("pos ", pos, " out of range for string", "b'", in, "' at index ", i)); - output(i) = in.substr(pos, len); + StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + output(i).assign(sub_in.data(), sub_in.size()); } } else { // Perform Op element-wise with tensor pos/len auto pos_flat = pos_tensor.flat<T>(); auto len_flat = len_tensor.flat<T>(); for (size_t i = 0; i < input_tensor.NumElements(); ++i) { - string in = input(i); + StringPiece in(input(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); OP_REQUIRES( - context, FastBoundsCheck(pos, in.size() + 1), + context, FastBoundsCheck(std::abs(pos), in.size() + 1), errors::InvalidArgument("pos ", pos, " out of range for string", "b'", in, "' at index ", i)); - output(i) = in.substr(pos, len); + StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + output(i).assign(sub_in.data(), sub_in.size()); } } } else { @@ -142,14 +148,16 @@ class SubstrOp : public OpKernel { // Iterate through broadcasted tensors and perform substr for (int i = 0; i < output_shape.dim_size(0); ++i) { - string in = input_bcast(i); + StringPiece in(input_bcast(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); const T len = 
tensorflow::internal::SubtleMustCopy(len_bcast(i)); OP_REQUIRES( - context, FastBoundsCheck(pos, input_bcast(i).size() + 1), + context, + FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1), errors::InvalidArgument("pos ", pos, " out of range for string", "b'", in, "' at index ", i)); - output(i) = in.substr(pos, len); + StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + output(i).assign(sub_in.data(), sub_in.size()); } break; } @@ -192,16 +200,18 @@ class SubstrOp : public OpKernel { // Iterate through broadcasted tensors and perform substr for (int i = 0; i < output_shape.dim_size(0); ++i) { for (int j = 0; j < output_shape.dim_size(1); ++j) { - string in = input_bcast(i, j); + StringPiece in(input_bcast(i, j)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i, j)); - OP_REQUIRES(context, FastBoundsCheck(pos, in.size() + 1), - errors::InvalidArgument( - "pos ", pos, " out of range for ", "string b'", - in, "' at index (", i, ", ", j, ")")); - output(i, j) = in.substr(pos, len); + OP_REQUIRES( + context, FastBoundsCheck(std::abs(pos), in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index (", i, + ", ", j, ")")); + StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + output(i, j).assign(sub_in.data(), sub_in.size()); } } break; @@ -213,6 +223,16 @@ class SubstrOp : public OpKernel { } } } + + private: + // This adjusts the requested position. Note it does not perform any bound + // checks. 
+ T AdjustedPosIndex(const T pos_requested, const StringPiece s) { + if (pos_requested < 0) { + return s.size() + pos_requested; + } + return pos_requested; + } }; #define REGISTER_SUBSTR(type) \ diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc new file mode 100644 index 0000000000..2e07050260 --- /dev/null +++ b/tensorflow/core/kernels/substr_op_test.cc @@ -0,0 +1,105 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <string> + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Test data from the TensorFlow README.md. 
+const char* lines[] = { + "**TensorFlow** is an open source software library for numerical " + "computation using data flow graphs.", + "The graph nodes represent mathematical operations, while the graph edges " + "represent the multidimensional data arrays (tensors) that flow between " + "them.", + "This flexible architecture enables you to deploy computation to one or " + "more CPUs or GPUs in a desktop, server, or mobile device without " + "rewriting code.", + "TensorFlow also includes " + "[TensorBoard](https://www.tensorflow.org/guide/" + "summaries_and_tensorboard), a data visualization toolkit.", + "TensorFlow was originally developed by researchers and engineers working " + "on the Google Brain team within Google's Machine Intelligence Research " + "organization for the purposes of conducting machine learning and deep " + "neural networks research.", + "The system is general enough to be applicable in a wide variety of other " + "domains, as well.", + "TensorFlow provides stable Python API and C APIs as well as without API " + "backwards compatibility guarantee like C++, Go, Java, JavaScript and " + "Swift."}; + +Tensor GetTestTensor(int batch) { + const int sz = TF_ARRAYSIZE(lines); + Tensor t(DT_STRING, {batch}); + auto s = t.flat<string>(); + for (int i = 0; i < batch; ++i) { + s(i) = lines[i % sz]; + } + return t; +} + +Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor position(DT_INT32, TensorShape({})); + position.flat<int32>().setConstant(pos); + Tensor length(DT_INT32, TensorShape({})); + length.flat<int32>().setConstant(len); + + TF_CHECK_OK(NodeBuilder("substr_op", "Substr") + .Input(test::graph::Constant(g, input)) + .Input(test::graph::Constant(g, position)) + .Input(test::graph::Constant(g, length)) + .Finalize(g, nullptr /* node */)); + return g; +} + +void BM_Substr(int iters, int batch_size) { + testing::StopTiming(); + 
testing::ItemsProcessed(static_cast<int64>(iters)); + testing::UseRealTime(); + Tensor input = GetTestTensor(batch_size); + Graph* g = SetupSubstrGraph(input, 3, 30); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); +} + +BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg( + 256); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc index 765467bc1e..0e6c0ddccc 100644 --- a/tensorflow/core/kernels/tensor_array.cc +++ b/tensorflow/core/kernels/tensor_array.cc @@ -62,7 +62,8 @@ TF_CALL_complex128(TENSOR_ARRAY_WRITE_OR_ADD_GPU); } #define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T) -TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU) +TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU); +TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU); #undef TENSOR_ARRAY_SET_ZERO_CPU #if GOOGLE_CUDA diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h index e8dc4fad21..384a63e945 100644 --- a/tensorflow/core/kernels/tensor_array.h +++ b/tensorflow/core/kernels/tensor_array.h @@ -81,7 +81,8 @@ Status TensorSetZero(OpKernelContext* ctx, Tensor* value) { Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value); #define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T) -TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU) +TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU); +TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU); #undef TENSOR_ARRAY_SET_ZERO_CPU #if GOOGLE_CUDA diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 632b65e9b6..a97a71b344 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -259,6 +259,7 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayV3").Device(DEVICE_CPU), TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); TF_CALL_complex64(REGISTER_GPU); TF_CALL_complex128(REGISTER_GPU); 
+TF_CALL_int64(REGISTER_GPU); REGISTER_GPU(bfloat16); #undef REGISTER_GPU @@ -290,14 +291,14 @@ class TensorArrayGradOp : public TensorArrayCreationOp { } } else { container = "_tensor_arrays"; - auto resource = ctx->input(0).flat<ResourceHandle>()(0); + const auto& resource = ctx->input(0).flat<ResourceHandle>()(0); if (StringPiece(resource.name()).substr(0, container.size()) != container) { return errors::InvalidArgument("Wrong input container. ", resource.name()); } tensor_array_name = - std::string(StringPiece(resource.name()).substr(container.size())); + string(StringPiece(resource.name()).substr(container.size())); } auto output_handle = tensor_array_output_handle->flat<string>(); @@ -576,6 +577,7 @@ TF_CALL_ALL_TYPES(REGISTER_READ) TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); TF_CALL_complex64(REGISTER_GPU); TF_CALL_complex128(REGISTER_GPU); +TF_CALL_int64(REGISTER_GPU); REGISTER_GPU(bfloat16); #undef REGISTER_GPU @@ -1218,6 +1220,7 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK); TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); TF_CALL_complex64(REGISTER_GPU); TF_CALL_complex128(REGISTER_GPU); +TF_CALL_int64(REGISTER_GPU); #undef REGISTER_GPU #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc index ca296d5aa0..2fbe1fe7cb 100644 --- a/tensorflow/core/kernels/topk_op_gpu.cu.cc +++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc @@ -20,9 +20,9 @@ limitations under the License. 
#include <cmath> #include <vector> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh" -#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh" -#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh" +#include "third_party/cub/device/device_segmented_radix_sort.cuh" +#include "third_party/cub/iterator/counting_input_iterator.cuh" +#include "third_party/cub/iterator/transform_input_iterator.cuh" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc index d3c4f62071..4262a5404b 100644 --- a/tensorflow/core/kernels/training_op_helpers.cc +++ b/tensorflow/core/kernels/training_op_helpers.cc @@ -15,13 +15,16 @@ limitations under the License. #include "tensorflow/core/kernels/training_op_helpers.h" +#include "tensorflow/core/util/ptr_util.h" + namespace tensorflow { -mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) { +mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, + Var** maybe_resource) { + *maybe_resource = nullptr; if (ctx->input_dtype(input) == DT_RESOURCE) { - Var* var; - if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) { - return var->mu(); + if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) { + return (*maybe_resource)->mu(); } else { ctx->CtxFailureWithWarning( errors::Internal("Invalid variable reference.")); @@ -32,12 +35,13 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) { } // MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes -// in address order to mitigate deadlock. Returns a vector of acquired mutexes. -// Safe to pass duplicates - will only lock each distinct mutex once. If -// do_lock is false, returns immediately. 
Note that this silently doesn't lock -// mutexes for invalid variable references; in all usages this is followed by -// GetInputTensor which will signal a failure. -std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder( +// in address order to mitigate deadlock. Returns a structure that, when +// deleted, will release the acquired mutexes. Safe to pass duplicates - will +// only lock each distinct mutex once. If do_lock is false, returns +// immediately. Note that this silently doesn't lock mutexes for invalid +// variable references; in all usages this is followed by GetInputTensor which +// will signal a failure. +VariableInputLockHolder MaybeLockVariableInputMutexesInOrder( OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) { bool any_resource = false; for (auto i : input_ids) { @@ -46,14 +50,16 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder( break; } } - std::vector<mutex_lock> locks; if (!do_lock && !any_resource) { - return locks; + return VariableInputLockHolder({}, {}); } + std::vector<Var*> vars; std::vector<mutex*> mutexes; std::vector<int> acquire_order; for (auto input : input_ids) { - mutex* mutex = GetTrainingVariableMutex(ctx, input); + Var* var; + mutex* mutex = GetTrainingVariableMutex(ctx, input, &var); + if (var) vars.push_back(var); // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3). 
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) { acquire_order.push_back(mutexes.size()); @@ -63,13 +69,19 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder( std::sort(acquire_order.begin(), acquire_order.end(), [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; }); + std::unique_ptr<std::vector<mutex_lock>> locks = + MakeUnique<std::vector<mutex_lock>>(); + locks->reserve(acquire_order.size()); + for (auto input : acquire_order) { - mutex* mu = GetTrainingVariableMutex(ctx, input); + Var* var; + mutex* mu = GetTrainingVariableMutex(ctx, input, &var); + core::ScopedUnref scoped_unref(var); if (mu != nullptr) { - locks.emplace_back(*mu); + locks->emplace_back(*mu); } } - return locks; + return VariableInputLockHolder(std::move(vars), std::move(locks)); } void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input, diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h index 071cb371a7..9f173a80f7 100644 --- a/tensorflow/core/kernels/training_op_helpers.h +++ b/tensorflow/core/kernels/training_op_helpers.h @@ -23,9 +23,42 @@ limitations under the License. namespace tensorflow { -mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input); +// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`. +// +// If `input` corresponds to a `DT_RESOURCE`-type variable input, +// `*maybe_resource` will be updated to contain the underlying resource, and the +// caller will be responsible for calling `Unref()` on that resource. +mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, + Var** maybe_resource); -std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder( +// Utility structure that releases a sequence of borrowed mutexes when it is +// deleted. 
+struct VariableInputLockHolder { + public: + VariableInputLockHolder(std::vector<Var*> vars, + std::unique_ptr<std::vector<mutex_lock>> locks) + : vars_(std::move(vars)), locks_(std::move(locks)) {} + + VariableInputLockHolder(VariableInputLockHolder&& other) + : vars_(std::move(other.vars_)), locks_(std::move(other.locks_)) {} + + ~VariableInputLockHolder() { + // Release the locks before unreffing the Vars, because each lock + // is potentially borrowed from a Var in vars_. + locks_.reset(); + for (Var* var : vars_) { + var->Unref(); + } + } + + private: + std::vector<Var*> vars_; + // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly, + // because a `std::vector<mutex_lock>` is not movable on all platforms. + std::unique_ptr<std::vector<mutex_lock>> locks_; +}; + +VariableInputLockHolder MaybeLockVariableInputMutexesInOrder( OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids); void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input, diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 9a07ded17d..acf162deec 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -561,7 +561,9 @@ class ApplyAdadeltaOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - mutex* mu = GetTrainingVariableMutex(ctx, 0); + Var* resource; + mutex* mu = GetTrainingVariableMutex(ctx, 0, &resource); + core::ScopedUnref scoped_unref(resource); if (use_exclusive_lock_ && mu != nullptr) { mutex_lock l1(*mu); // Don't try to acquire a lock on the second ref as they share the same @@ -710,7 +712,9 @@ class SparseApplyAdadeltaOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - mutex* mu = GetTrainingVariableMutex(ctx, 0); + Var* var; + mutex* mu = GetTrainingVariableMutex(ctx, 0, &var); + core::ScopedUnref scoped_unref(var); // mu_accum is actually the same mutex as mu_var since currently we use a // global mutex. 
// diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index 0f0f65c5a3..48e392c070 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -218,7 +218,7 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx, perm, out); } -#if defined(INTEL_MKL) +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #define REGISTER(T) \ REGISTER_KERNEL_BUILDER(Name("Transpose") \ .Device(DEVICE_CPU) \ @@ -230,11 +230,8 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx, .TypeConstraint<T>("T") \ .HostMemory("perm"), \ MklConjugateTransposeCpuOp); -TF_CALL_ALL_TYPES(REGISTER); -#undef REGISTER - -#else // INTEL_MKL +#else // INTEL_MKL && ENABLE_MKL #define REGISTER(T) \ REGISTER_KERNEL_BUILDER(Name("Transpose") \ .Device(DEVICE_CPU) \ @@ -246,9 +243,10 @@ TF_CALL_ALL_TYPES(REGISTER); .TypeConstraint<T>("T") \ .HostMemory("perm"), \ ConjugateTransposeCpuOp); +#endif // INTEL_MKL && ENABLE_MKL + TF_CALL_ALL_TYPES(REGISTER) #undef REGISTER -#endif // INTEL_MKL #if GOOGLE_CUDA Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, diff --git a/tensorflow/core/kernels/typed_conditional_accumulator_base.h b/tensorflow/core/kernels/typed_conditional_accumulator_base.h index 9dedb618f9..ca341e511e 100644 --- a/tensorflow/core/kernels/typed_conditional_accumulator_base.h +++ b/tensorflow/core/kernels/typed_conditional_accumulator_base.h @@ -35,8 +35,9 @@ class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase { public: TypedConditionalAccumulatorBase(const DataType& dtype, const PartialTensorShape& shape, - const string& name) - : ConditionalAccumulatorBase(dtype, shape, name) {} + const string& name, + const string& reduction_type) + : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {} /** * Attempts to add a gradient to the accumulator. 
An ApplyGrad attempt is diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc index 62e814ff77..8d839ba85a 100644 --- a/tensorflow/core/kernels/unravel_index_op.cc +++ b/tensorflow/core/kernels/unravel_index_op.cc @@ -97,10 +97,12 @@ class UnravelIndexOp : public OpKernel { auto output = output_tensor->matrix<Tidx>(); - Eigen::array<int64, 2> reshape{{dims_tensor.NumElements(), 1}}; - Eigen::array<int64, 2> bcast({1, indices_tensor.NumElements()}); - Eigen::array<int64, 2> indices_reshape{{1, indices_tensor.NumElements()}}; - Eigen::array<int64, 2> indices_bcast({dims_tensor.NumElements(), 1}); + Eigen::array<Eigen::Index, 2> reshape{{dims_tensor.NumElements(), 1}}; + Eigen::array<Eigen::Index, 2> bcast({1, indices_tensor.NumElements()}); + Eigen::array<Eigen::Index, 2> indices_reshape{ + {1, indices_tensor.NumElements()}}; + Eigen::array<Eigen::Index, 2> indices_bcast( + {dims_tensor.NumElements(), 1}); output = indices_tensor.vec<Tidx>() .reshape(indices_reshape) diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h index 8879d9dd4c..2255597651 100644 --- a/tensorflow/core/kernels/where_op_gpu.cu.h +++ b/tensorflow/core/kernels/where_op_gpu.cu.h @@ -21,10 +21,10 @@ limitations under the License. 
#define EIGEN_USE_GPU #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "external/cub_archive/cub/device/device_reduce.cuh" -#include "external/cub_archive/cub/device/device_select.cuh" -#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh" -#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh" +#include "third_party/cub/device/device_reduce.cuh" +#include "third_party/cub/device/device_select.cuh" +#include "third_party/cub/iterator/counting_input_iterator.cuh" +#include "third_party/cub/iterator/transform_input_iterator.cuh" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/bounds_check.h" diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc index ed2bf3e8e2..1bf46b5e46 100644 --- a/tensorflow/core/kernels/whole_file_read_ops.cc +++ b/tensorflow/core/kernels/whole_file_read_ops.cc @@ -134,7 +134,7 @@ class WriteFileOp : public OpKernel { "Contents tensor must be scalar, but had shape: ", contents_input->shape().DebugString())); const string& filename = filename_input->scalar<string>()(); - const string dir = std::string(io::Dirname(filename)); + const string dir(io::Dirname(filename)); if (!context->env()->FileExists(dir).ok()) { OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir)); } diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h index 982901a39c..d5cbe6c616 100644 --- a/tensorflow/core/lib/core/errors.h +++ b/tensorflow/core/lib/core/errors.h @@ -136,11 +136,9 @@ string FormatNodeNamesForError(const T& names) { ::tensorflow::strings::StrAppend(output, FormatNodeNameForError(s)); }); } -// TODO(b/113350742): Consolidate the two different formats `{{key value}}` and -// `^^key:value^^` in a follow-on CL. 
// LINT.IfChange inline string FormatColocationNodeForError(const string& name) { - return strings::StrCat("^^colocation_node:", name, "^^"); + return strings::StrCat("{{colocation_node ", name, "}}"); } // LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py) template <typename T> diff --git a/tensorflow/core/lib/core/status.h b/tensorflow/core/lib/core/status.h index 49f74ff47f..eb0ff555a5 100644 --- a/tensorflow/core/lib/core/status.h +++ b/tensorflow/core/lib/core/status.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h index e7b17c9b36..6edff139ae 100644 --- a/tensorflow/core/lib/core/stringpiece.h +++ b/tensorflow/core/lib/core/stringpiece.h @@ -26,13 +26,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_ #define TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_ -#include <assert.h> -#include <stddef.h> -#include <string.h> -#include <iosfwd> -#include <string> #include "absl/strings/string_view.h" -#include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc index 99684ae47b..9ccd911b0e 100644 --- a/tensorflow/core/lib/core/threadpool.cc +++ b/tensorflow/core/lib/core/threadpool.cc @@ -17,6 +17,7 @@ limitations under the License. 
#define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/platform/context.h" #include "tensorflow/core/platform/denormal.h" #include "tensorflow/core/platform/logging.h" @@ -120,6 +121,54 @@ void ThreadPool::Schedule(std::function<void()> fn) { impl_->Schedule(std::move(fn)); } +int ThreadPool::NumShardsUsedByTransformRangeConcurrently( + const int64 block_size, const int64 total) { + if (block_size <= 0 || total <= 1 || total <= block_size || + NumThreads() == 1) { + return 1; + } + return (total + block_size - 1) / block_size; +} + +// This functionality is similar to parallelFor, except that reasoning about +// the number of shards used is significantly easier. +void ThreadPool::TransformRangeConcurrently( + const int64 block_size, const int64 total, + const std::function<void(int64, int64)>& fn) { + const int num_shards_used = + NumShardsUsedByTransformRangeConcurrently(block_size, total); + if (num_shards_used == 1) { + fn(0, total); + return; + } + + // Adapted from Eigen's parallelFor implementation. + BlockingCounter counter(num_shards_used); + std::function<void(int64, int64)> handle_range = + [=, &handle_range, &counter, &fn](int64 first, int64 last) { + while (last - first > block_size) { + // Find something near the midpoint which is a multiple of block size. + const int64 mid = first + ((last - first) / 2 + block_size - 1) / + block_size * block_size; + Schedule([=, &handle_range]() { handle_range(mid, last); }); + last = mid; + } + // Single block or less, execute directly. + fn(first, last); + counter.DecrementCount(); // The shard is done. + }; + if (num_shards_used <= NumThreads()) { + // Avoid a thread hop by running the root of the tree and one block on the + // main thread. + handle_range(0, total); + } else { + // Execute the root in the thread pool to avoid running work on more than + // numThreads() threads. 
+ Schedule([=, &handle_range]() { handle_range(0, total); }); + } + counter.Wait(); +} + void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit, std::function<void(int64, int64)> fn) { impl_->ParallelFor(total, cost_per_unit, std::move(fn)); diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h index 74df7c84a4..e14ad7ac64 100644 --- a/tensorflow/core/lib/core/threadpool.h +++ b/tensorflow/core/lib/core/threadpool.h @@ -59,6 +59,20 @@ class ThreadPool { // Schedules fn() for execution in the pool of threads. void Schedule(std::function<void()> fn); + // Requires 0 < block_size <= total. + // Spawns k threads and calls fn(i*block_size, (i+1)*block_size) from the + // ith thread (i>=0). When (i+1)*block_size > total, fn(i*block_size, total) + // is called instead. k = NumShardsUsedByTransformRangeConcurrently(...). + // Note that when there aren't enough threads in the pool to achieve full + // parallelism, function calls will be automatically queued. + void TransformRangeConcurrently(const int64 block_size, const int64 total, + const std::function<void(int64, int64)>& fn); + + // Returns the number of threads spawned by calling TransformRangeConcurrently + // with these parameters. + int NumShardsUsedByTransformRangeConcurrently(const int64 block_size, + const int64 total); + // ParallelFor shards the "total" units of work assuming each unit of work // having roughly "cost_per_unit" cost, in cycles. Each unit of work is // indexed 0, 1, ..., total - 1. 
Each shard contains 1 or more units of work diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc index 320f3ebb83..db996b783f 100644 --- a/tensorflow/core/lib/core/threadpool_test.cc +++ b/tensorflow/core/lib/core/threadpool_test.cc @@ -61,6 +61,67 @@ TEST(ThreadPool, DoWork) { } } +void RunSharding(int64 block_size, int64 total, ThreadPool* threads) { + mutex mu; + int64 num_shards = 0; + int64 num_done_work = 0; + std::vector<bool> work(total, false); + threads->TransformRangeConcurrently( + block_size, total, + [=, &mu, &num_shards, &num_done_work, &work](int64 start, int64 end) { + VLOG(1) << "Shard [" << start << "," << end << ")"; + EXPECT_GE(start, 0); + EXPECT_LE(end, total); + mutex_lock l(mu); + ++num_shards; + for (; start < end; ++start) { + EXPECT_FALSE(work[start]); // No duplicate + ++num_done_work; + work[start] = true; + } + }); + LOG(INFO) << block_size << " " << total; + const int64 num_workers = (total + block_size - 1) / block_size; + EXPECT_EQ(num_done_work, total); + if (num_workers < threads->NumThreads()) { + // If the intention is to limit the parallelism explicitly, we'd + // better honor it. Ideally, even if per_thread_max_parallelism > + // num_workers, we should expect that Shard() implementation do + // not over-shard. Unfortunately, ThreadPoolDevice::parallelFor + // tends to over-shard. 
+ EXPECT_LE(num_shards, 1 + num_workers); + } +} + +// Adapted from work_sharder_test.cc +TEST(SparseUtilsTest, TransformRangeConcurrently) { + ThreadPool threads(Env::Default(), "test", 16); + for (auto block_size : {1, 7, 10, 64, 100, 256, 1000, 9999}) { + for (auto diff : {0, 1, 11, 102, 1003, 10005, 1000007}) { + const int64 total = block_size + diff; + RunSharding(block_size, total, &threads); + } + } +} + +TEST(SparseUtilsTest, NumShardsUsedByTransformRangeConcurrently) { + ThreadPool threads(Env::Default(), "test", 16); + EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently( + 3 /* block_size */, 3 /* total */)); + EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently( + 3 /* block_size */, 4 /* total */)); + EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently( + 3 /* block_size */, 5 /* total */)); + EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently( + 3 /* block_size */, 6 /* total */)); + EXPECT_EQ(3, threads.NumShardsUsedByTransformRangeConcurrently( + 3 /* block_size */, 7 /* total */)); + EXPECT_EQ(7, threads.NumShardsUsedByTransformRangeConcurrently( + 1 /* block_size */, 7 /* total */)); + EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently( + 0 /* block_size */, 7 /* total */)); +} + TEST(ThreadPool, ParallelFor) { Context outer_context(ContextKind::kThread); // Make ParallelFor use as many threads as possible. diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h index c18dc9ad1a..2d622dc229 100644 --- a/tensorflow/core/lib/gtl/inlined_vector.h +++ b/tensorflow/core/lib/gtl/inlined_vector.h @@ -13,674 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -// An InlinedVector<T,N,A> is like a std::vector<T,A>, except that storage -// for sequences of length <= N are provided inline without requiring -// any heap allocation. Typically N is very small (e.g., 4) so that -// sequences that are expected to be short do not require allocations. -// -// Only some of the std::vector<> operations are currently implemented. -// Other operations may be added as needed to facilitate migrating -// code that uses std::vector<> to InlinedVector<>. -// -// NOTE: If you want an inlined version to replace use of a -// std::vector<bool>, consider using util::bitmap::InlinedBitVector<NBITS> -// in util/bitmap/inlined_bitvector.h -// -// TODO(billydonahue): change size_t to size_type where appropriate. - #ifndef TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ #define TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ -#include <stddef.h> -#include <stdlib.h> -#include <string.h> -#include <sys/types.h> -#include <algorithm> -#include <cstddef> -#include <iterator> -#include <memory> -#include <type_traits> -#include <vector> - -#include "tensorflow/core/lib/gtl/manual_constructor.h" -#include "tensorflow/core/platform/byte_order.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mem.h" +#include "absl/container/inlined_vector.h" +// TODO(kramerb): This is kept only because lots of targets transitively depend +// on it. Remove all targets' dependencies. 
+#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include <initializer_list> // NOLINT(build/include_order) - namespace tensorflow { namespace gtl { -template <typename T, int N> -class InlinedVector { - public: - typedef T value_type; - typedef T* pointer; - typedef const T* const_pointer; - typedef T& reference; - typedef const T& const_reference; - typedef size_t size_type; - typedef std::ptrdiff_t difference_type; - typedef pointer iterator; - typedef const_pointer const_iterator; - - // Create an empty vector - InlinedVector(); - - // Create a vector with n copies of value_type(). - explicit InlinedVector(size_t n); - - // Create a vector with n copies of elem - InlinedVector(size_t n, const value_type& elem); - - // Create and initialize with the elements [range_start .. range_end). - // The unused enable_if argument restricts this constructor so that it is - // elided when value_type is an integral type. This prevents ambiguous - // interpretation between a call to this constructor with two integral - // arguments and a call to the preceding (n, elem) constructor. - template <typename InputIterator> - InlinedVector( - InputIterator range_start, InputIterator range_end, - typename std::enable_if<!std::is_integral<InputIterator>::value>::type* = - NULL) { - InitRep(); - AppendRange(range_start, range_end); - } - - InlinedVector(std::initializer_list<value_type> init) { - InitRep(); - AppendRange(init.begin(), init.end()); - } - - InlinedVector(const InlinedVector& v); - - ~InlinedVector() { clear(); } - - InlinedVector& operator=(const InlinedVector& v) { - // Optimized to avoid reallocation. - // Prefer reassignment to copy construction for elements. 
- const size_t s = size(); - const size_t vs = v.size(); - if (s < vs) { // grow - reserve(vs); - if (s) std::copy(v.begin(), v.begin() + s, begin()); - std::copy(v.begin() + s, v.end(), std::back_inserter(*this)); - } else { // maybe shrink - erase(begin() + vs, end()); - std::copy(v.begin(), v.end(), begin()); - } - return *this; - } - - size_t size() const { return size_internal(); } - - bool empty() const { return (size() == 0); } - - // Return number of elements that can be stored in vector - // without requiring a reallocation of underlying memory - size_t capacity() const { - if (is_inline()) { - return kFit; - } else { - return static_cast<size_t>(1) << u_.data[kSize - 2]; - } - } - - // Return a pointer to the underlying array. - // Only result[0,size()-1] are defined. - pointer data() { - if (is_inline()) { - return reinterpret_cast<T*>(u_.data); - } else { - return outofline_pointer(); - } - } - const_pointer data() const { - return const_cast<InlinedVector<T, N>*>(this)->data(); - } - - // Remove all elements - void clear() { - DiscardStorage(); - u_.data[kSize - 1] = 0; - } - - // Return the ith element - // REQUIRES: 0 <= i < size() - const value_type& at(size_t i) const { - DCHECK_LT(i, size()); - return data()[i]; - } - const value_type& operator[](size_t i) const { - DCHECK_LT(i, size()); - return data()[i]; - } - - // Return a non-const reference to the ith element - // REQUIRES: 0 <= i < size() - value_type& at(size_t i) { - DCHECK_LT(i, size()); - return data()[i]; - } - value_type& operator[](size_t i) { - DCHECK_LT(i, size()); - return data()[i]; - } - - value_type& back() { - DCHECK(!empty()); - return at(size() - 1); - } - - const value_type& back() const { - DCHECK(!empty()); - return at(size() - 1); - } - - value_type& front() { - DCHECK(!empty()); - return at(0); - } - - const value_type& front() const { - DCHECK(!empty()); - return at(0); - } - - // Append a T constructed with args to the vector. - // Increases size() by one. 
- // Amortized complexity: O(1) - // Worst-case complexity: O(size()) - template <typename... Args> - void emplace_back(Args&&... args) { - size_t s = size(); - DCHECK_LE(s, capacity()); - if (s < capacity()) { - new (data() + s) T(std::forward<Args>(args)...); - set_size_internal(s + 1); - } else { - EmplaceBackSlow(std::forward<Args>(args)...); - } - } - - // Append t to the vector. - // Increases size() by one. - // Amortized complexity: O(1) - // Worst-case complexity: O(size()) - void push_back(const value_type& t) { emplace_back(t); } - void push_back(value_type&& t) { emplace_back(std::move(t)); } - - inline void pop_back() { - DCHECK(!empty()); - const size_t s = size(); - Destroy(data() + s - 1, 1); - set_size_internal(s - 1); - } - - // Resizes the vector to contain "n" elements. - // If "n" is smaller than the initial size, extra elements are destroyed. - // If "n" is larger than the initial size, enough copies of "elem" - // are appended to increase the size to "n". If "elem" is omitted, - // new elements are value-initialized. - void resize(size_t n) { Resize<ValueInit>(n, nullptr); } - void resize(size_t n, const value_type& elem) { Resize<Fill>(n, &elem); } - - iterator begin() { return data(); } - const_iterator begin() const { return data(); } - - iterator end() { return data() + size(); } - const_iterator end() const { return data() + size(); } - - iterator insert(iterator pos, const value_type& v); - - iterator erase(iterator pos) { - DCHECK_LT(pos, end()); - DCHECK_GE(pos, begin()); - std::copy(pos + 1, end(), pos); - pop_back(); - return pos; - } - - iterator erase(iterator first, iterator last); - - // Enlarges the underlying representation so it can hold at least - // "n" elements without reallocation. - // Does not change size() or the actual contents of the vector. - void reserve(size_t n) { - if (n > capacity()) { - // Make room for new elements - Grow<Move>(n); - } - } - - // Swap the contents of *this with other. 
- // REQUIRES: value_type is swappable and copyable. - void swap(InlinedVector& other); - - private: - // Representation can either be inlined or out-of-line. - // In either case, at least sizeof(void*) + 8 bytes are available. - // - // Inlined: - // Last byte holds the length. - // First (length*sizeof(T)) bytes stores the elements. - // Outlined: - // Last byte holds kSentinel. - // Second-last byte holds lg(capacity) - // Preceding 6 bytes hold size. - // First sizeof(T*) bytes hold pointer. - - // Compute rep size. - static const size_t kSizeUnaligned = N * sizeof(T) + 1; // Room for tag - static const size_t kSize = ((kSizeUnaligned + 15) / 16) * 16; // Align - - // See how many fit T we can fit inside kSize, but no more than 254 - // since 255 is used as sentinel tag for out-of-line allocation. - static const unsigned int kSentinel = 255; - static const size_t kFit1 = (kSize - 1) / sizeof(T); - static const size_t kFit = (kFit1 >= kSentinel) ? (kSentinel - 1) : kFit1; - - union { - unsigned char data[kSize]; - // Force data to be aligned enough for a pointer. - T* unused_aligner; - } u_; - - inline void InitRep() { u_.data[kSize - 1] = 0; } - inline bool is_inline() const { return u_.data[kSize - 1] != kSentinel; } - - inline T* outofline_pointer() const { - T* ptr; - memcpy(&ptr, &u_.data[0], sizeof(ptr)); - return ptr; - } - - inline void set_outofline_pointer(T* p) { - memcpy(&u_.data[0], &p, sizeof(p)); - } - - inline uint64_t outofline_word() const { - uint64_t word; - memcpy(&word, &u_.data[kSize - 8], sizeof(word)); - return word; - } - - inline void set_outofline_word(uint64_t w) { - memcpy(&u_.data[kSize - 8], &w, sizeof(w)); - } - - inline size_t size_internal() const { - uint8_t s = static_cast<uint8_t>(u_.data[kSize - 1]); - if (s != kSentinel) { - return static_cast<size_t>(s); - } else { - const uint64_t word = outofline_word(); - if (port::kLittleEndian) { - // The sentinel and capacity bits are most-significant bits in word. 
- return static_cast<size_t>(word & 0xffffffffffffull); - } else { - // The sentinel and capacity bits are least-significant bits in word. - return static_cast<size_t>(word >> 16); - } - } - } - - void set_size_internal(size_t n) { - if (is_inline()) { - DCHECK_LT(n, kSentinel); - u_.data[kSize - 1] = static_cast<unsigned char>(n); - } else { - uint64_t word; - if (port::kLittleEndian) { - // The sentinel and capacity bits are most-significant bits in word. - word = (static_cast<uint64_t>(n) | - (static_cast<uint64_t>(u_.data[kSize - 2]) << 48) | - (static_cast<uint64_t>(kSentinel) << 56)); - } else { - // The sentinel and capacity bits are least-significant bits in word. - word = ((static_cast<uint64_t>(n) << 16) | - (static_cast<uint64_t>(u_.data[kSize - 2]) << 8) | - (static_cast<uint64_t>(kSentinel))); - } - set_outofline_word(word); - DCHECK_EQ(u_.data[kSize - 1], kSentinel) << n; - } - } - - void DiscardStorage() { - T* base = data(); - size_t n = size(); - Destroy(base, n); - if (!is_inline()) { - port::Free(base); - } - } - - template <typename... Args> - void EmplaceBackSlow(Args&&... args) { - const size_t s = size(); - DCHECK_EQ(s, capacity()); - Grow<Move, Construct>(s + 1, std::forward<Args>(args)...); - set_size_internal(s + 1); - } - - // Movers for Grow - // Does nothing. - static void Nop(T* src, size_t n, T* dst) {} - - // Moves srcs[0,n-1] contents to dst[0,n-1]. - static void Move(T* src, size_t n, T* dst) { - for (size_t i = 0; i < n; i++) { - new (dst + i) T(std::move(*(src + i))); - } - } - - // Initializers for Resize. - // Initializes dst[0,n-1] with empty constructor. - static void ValueInit(const T*, size_t n, T* dst) { - for (size_t i = 0; i < n; i++) { - new (dst + i) T(); - } - } - - // Initializes dst[0,n-1] with copies of *src. 
- static void Fill(const T* src, size_t n, T* dst) { - for (size_t i = 0; i < n; i++) { - new (dst + i) T(*src); - } - } - - void Destroy(T* src, int n) { - if (!std::is_trivially_destructible<T>::value) { - for (int i = 0; i < n; i++) { - (src + i)->~T(); - } - } - } - - // Initialization methods for Grow. - // 1) Leave uninitialized memory. - struct Uninitialized { - void operator()(T*) const {} - }; - // 2) Construct a T with args at not-yet-initialized memory pointed by dst. - struct Construct { - template <class... Args> - void operator()(T* dst, Args&&... args) const { - new (dst) T(std::forward<Args>(args)...); - } - }; - - // Grow so that capacity >= n. Uses Mover to move existing elements - // to new buffer, and possibly initialize the new element according - // to InitType. - // We pass the InitType and Mover as template arguments so that - // this code compiles even if T does not support copying or default - // construction. - template <void(Mover)(T*, size_t, T*), class InitType = Uninitialized, - class... Args> - void Grow(size_t n, Args&&... args) { - size_t s = size(); - DCHECK_LE(s, capacity()); - - // Compute new capacity by repeatedly doubling current capacity - size_t target = 1; - size_t target_lg = 0; - while (target < kFit || target < n) { - // TODO(psrc): Check and avoid overflow? - target_lg++; - target <<= 1; - } - - T* src = data(); - T* dst = static_cast<T*>(port::Malloc(target * sizeof(T))); - - // Need to copy elem before discarding src since it might alias src. - InitType{}(dst + s, std::forward<Args>(args)...); - Mover(src, s, dst); - DiscardStorage(); - - u_.data[kSize - 1] = kSentinel; - u_.data[kSize - 2] = static_cast<unsigned char>(target_lg); - set_size_internal(s); - DCHECK_EQ(capacity(), target); - set_outofline_pointer(dst); - } - - // Resize to size n. Any new elements are initialized by passing - // elem and the destination to Initializer. 
We pass the Initializer - // as a template argument so that this code compiles even if T does - // not support copying. - template <void(Initializer)(const T*, size_t, T*)> - void Resize(size_t n, const T* elem) { - size_t s = size(); - if (n <= s) { - Destroy(data() + n, s - n); - set_size_internal(n); - return; - } - reserve(n); - DCHECK_GE(capacity(), n); - set_size_internal(n); - Initializer(elem, n - s, data() + s); - } - - template <typename Iter> - void AppendRange(Iter first, Iter last, std::input_iterator_tag); - - // Faster path for forward iterators. - template <typename Iter> - void AppendRange(Iter first, Iter last, std::forward_iterator_tag); - - template <typename Iter> - void AppendRange(Iter first, Iter last); -}; - -// Provide linkage for constants. -template <typename T, int N> -const size_t InlinedVector<T, N>::kSizeUnaligned; -template <typename T, int N> -const size_t InlinedVector<T, N>::kSize; -template <typename T, int N> -const unsigned int InlinedVector<T, N>::kSentinel; -template <typename T, int N> -const size_t InlinedVector<T, N>::kFit1; -template <typename T, int N> -const size_t InlinedVector<T, N>::kFit; - -template <typename T, int N> -inline void swap(InlinedVector<T, N>& a, InlinedVector<T, N>& b) { - a.swap(b); -} - -template <typename T, int N> -inline bool operator==(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin()); -} - -template <typename T, int N> -inline bool operator!=(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return !(a == b); -} - -template <typename T, int N> -inline bool operator<(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end()); -} - -template <typename T, int N> -inline bool operator>(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return b < a; -} - -template <typename T, int N> -inline bool 
operator<=(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return !(b < a); -} - -template <typename T, int N> -inline bool operator>=(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return !(a < b); -} - -// ======================================== -// Implementation - -template <typename T, int N> -inline InlinedVector<T, N>::InlinedVector() { - InitRep(); -} - -template <typename T, int N> -inline InlinedVector<T, N>::InlinedVector(size_t n) { - InitRep(); - if (n > capacity()) { - Grow<Nop>(n); // Must use Nop in case T is not copyable - } - set_size_internal(n); - ValueInit(nullptr, n, data()); -} - -template <typename T, int N> -inline InlinedVector<T, N>::InlinedVector(size_t n, const value_type& elem) { - InitRep(); - if (n > capacity()) { - Grow<Nop>(n); // Can use Nop since we know we have nothing to copy - } - set_size_internal(n); - Fill(&elem, n, data()); -} - -template <typename T, int N> -inline InlinedVector<T, N>::InlinedVector(const InlinedVector& v) { - InitRep(); - *this = v; -} - -template <typename T, int N> -typename InlinedVector<T, N>::iterator InlinedVector<T, N>::insert( - iterator pos, const value_type& v) { - DCHECK_GE(pos, begin()); - DCHECK_LE(pos, end()); - if (pos == end()) { - push_back(v); - return end() - 1; - } - size_t s = size(); - size_t idx = std::distance(begin(), pos); - if (s == capacity()) { - Grow<Move>(s + 1); - } - CHECK_LT(s, capacity()); - pos = begin() + idx; // Reset 'pos' into a post-enlarge iterator. 
- Fill(data() + s - 1, 1, data() + s); // data[s] = data[s-1] - std::copy_backward(pos, data() + s - 1, data() + s); - *pos = v; - - set_size_internal(s + 1); - return pos; -} - -template <typename T, int N> -typename InlinedVector<T, N>::iterator InlinedVector<T, N>::erase( - iterator first, iterator last) { - DCHECK_LE(begin(), first); - DCHECK_LE(first, last); - DCHECK_LE(last, end()); - - size_t s = size(); - ptrdiff_t erase_gap = std::distance(first, last); - std::copy(last, data() + s, first); - Destroy(data() + s - erase_gap, erase_gap); - set_size_internal(s - erase_gap); - return first; -} - -template <typename T, int N> -void InlinedVector<T, N>::swap(InlinedVector& other) { - using std::swap; // Augment ADL with std::swap. - if (&other == this) { - return; - } - - InlinedVector* a = this; - InlinedVector* b = &other; - - const bool a_inline = a->is_inline(); - const bool b_inline = b->is_inline(); - - if (!a_inline && !b_inline) { - // Just swap the top-level representations. - T* aptr = a->outofline_pointer(); - T* bptr = b->outofline_pointer(); - a->set_outofline_pointer(bptr); - b->set_outofline_pointer(aptr); - - uint64_t aword = a->outofline_word(); - uint64_t bword = b->outofline_word(); - a->set_outofline_word(bword); - b->set_outofline_word(aword); - return; - } - - // Make a the larger of the two to reduce number of cases. - size_t a_size = a->size(); - size_t b_size = b->size(); - if (a->size() < b->size()) { - swap(a, b); - swap(a_size, b_size); - } - DCHECK_GE(a_size, b_size); - - if (b->capacity() < a_size) { - b->Grow<Move>(a_size); - } - - // One is inline and one is not. - // 'a' is larger. Swap the elements up to the smaller array size. 
- std::swap_ranges(a->data(), a->data() + b_size, b->data()); - std::uninitialized_copy(a->data() + b_size, a->data() + a_size, - b->data() + b_size); - Destroy(a->data() + b_size, a_size - b_size); - a->set_size_internal(b_size); - b->set_size_internal(a_size); - DCHECK_EQ(b->size(), a_size); - DCHECK_EQ(a->size(), b_size); -} - -template <typename T, int N> -template <typename Iter> -inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last, - std::input_iterator_tag) { - std::copy(first, last, std::back_inserter(*this)); -} - -template <typename T, int N> -template <typename Iter> -inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last, - std::forward_iterator_tag) { - typedef typename std::iterator_traits<Iter>::difference_type Length; - Length length = std::distance(first, last); - size_t s = size(); - reserve(s + length); - std::uninitialized_copy_n(first, length, data() + s); - set_size_internal(s + length); -} - -template <typename T, int N> -template <typename Iter> -inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last) { - typedef typename std::iterator_traits<Iter>::iterator_category IterTag; - AppendRange(first, last, IterTag()); -} +using absl::InlinedVector; } // namespace gtl } // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc deleted file mode 100644 index 2721885c4a..0000000000 --- a/tensorflow/core/lib/gtl/inlined_vector_test.cc +++ /dev/null @@ -1,898 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. 
-You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/lib/gtl/inlined_vector.h" - -#include <list> -#include <memory> -#include <string> -#include <vector> - -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -typedef tensorflow::gtl::InlinedVector<int, 8> IntVec; - -// A type that counts number of live occurrences of the type -static int64 instances = 0; -class Instance { - public: - int value_; - explicit Instance(int x) : value_(x) { instances++; } - Instance(const Instance& x) : value_(x.value_) { instances++; } - ~Instance() { instances--; } - - friend inline void swap(Instance& a, Instance& b) { - using std::swap; - swap(a.value_, b.value_); - } - - friend std::ostream& operator<<(std::ostream& o, const Instance& v) { - return o << "[value:" << v.value_ << "]"; - } -}; - -typedef tensorflow::gtl::InlinedVector<Instance, 8> InstanceVec; - -// A simple reference counted class to make sure that the proper elements are -// destroyed in the erase(begin, end) test. 
-class RefCounted { - public: - RefCounted(int value, int* count) : value_(value), count_(count) { Ref(); } - - RefCounted(const RefCounted& v) : value_(v.value_), count_(v.count_) { - VLOG(5) << "[RefCounted: copy" - << " from count @" << v.count_ << "]"; - Ref(); - } - - ~RefCounted() { - Unref(); - count_ = nullptr; - } - - friend void swap(RefCounted& a, RefCounted& b) { - using std::swap; - swap(a.value_, b.value_); - swap(a.count_, b.count_); - } - - RefCounted& operator=(RefCounted v) { - using std::swap; - swap(*this, v); - return *this; - } - - void Ref() const { - CHECK(count_ != nullptr); - ++(*count_); - VLOG(5) << "[Ref: refcount " << *count_ << " on count @" << count_ << "]"; - } - - void Unref() const { - --(*count_); - CHECK_GE(*count_, 0); - VLOG(5) << "[Unref: refcount " << *count_ << " on count @" << count_ << "]"; - } - - int count() const { return *count_; } - - friend std::ostream& operator<<(std::ostream& o, const RefCounted& v) { - return o << "[value:" << v.value_ << ", count:" << *v.count_ << "]"; - } - - int value_; - int* count_; -}; - -typedef tensorflow::gtl::InlinedVector<RefCounted, 8> RefCountedVec; - -// A class with a vtable pointer -class Dynamic { - public: - virtual ~Dynamic() {} - - friend std::ostream& operator<<(std::ostream& o, const Dynamic& v) { - return o << "[Dynamic]"; - } -}; - -typedef tensorflow::gtl::InlinedVector<Dynamic, 8> DynamicVec; - -// Append 0..len-1 to *v -static void Fill(IntVec* v, int len, int offset = 0) { - for (int i = 0; i < len; i++) { - v->push_back(i + offset); - } -} - -static IntVec Fill(int len, int offset = 0) { - IntVec v; - Fill(&v, len, offset); - return v; -} - -TEST(IntVec, SimpleOps) { - for (int len = 0; len < 20; len++) { - IntVec v; - const IntVec& cv = v; // const alias - - Fill(&v, len); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - - for (int i = 0; i < len; i++) { - EXPECT_EQ(i, v[i]); - } - EXPECT_EQ(v.begin(), v.data()); - EXPECT_EQ(cv.begin(), cv.data()); - - 
int counter = 0; - for (IntVec::iterator iter = v.begin(); iter != v.end(); ++iter) { - EXPECT_EQ(counter, *iter); - counter++; - } - EXPECT_EQ(counter, len); - - counter = 0; - for (IntVec::const_iterator iter = v.begin(); iter != v.end(); ++iter) { - EXPECT_EQ(counter, *iter); - counter++; - } - EXPECT_EQ(counter, len); - - if (len > 0) { - EXPECT_EQ(0, v.front()); - EXPECT_EQ(len - 1, v.back()); - v.pop_back(); - EXPECT_EQ(len - 1, v.size()); - for (size_t i = 0; i < v.size(); ++i) { - EXPECT_EQ(i, v[i]); - } - } - } -} - -TEST(IntVec, Erase) { - for (int len = 1; len < 20; len++) { - for (int i = 0; i < len; ++i) { - IntVec v; - Fill(&v, len); - v.erase(v.begin() + i); - EXPECT_EQ(len - 1, v.size()); - for (int j = 0; j < i; ++j) { - EXPECT_EQ(j, v[j]); - } - for (int j = i; j < len - 1; ++j) { - EXPECT_EQ(j + 1, v[j]); - } - } - } -} - -// At the end of this test loop, the elements between [erase_begin, erase_end) -// should have reference counts == 0, and all others elements should have -// reference counts == 1. -TEST(RefCountedVec, EraseBeginEnd) { - for (int len = 1; len < 20; ++len) { - for (int erase_begin = 0; erase_begin < len; ++erase_begin) { - for (int erase_end = erase_begin; erase_end <= len; ++erase_end) { - std::vector<int> counts(len, 0); - RefCountedVec v; - for (int i = 0; i < len; ++i) { - v.push_back(RefCounted(i, &counts[i])); - } - - int erase_len = erase_end - erase_begin; - - v.erase(v.begin() + erase_begin, v.begin() + erase_end); - - EXPECT_EQ(len - erase_len, v.size()); - - // Check the elements before the first element erased. - for (int i = 0; i < erase_begin; ++i) { - EXPECT_EQ(i, v[i].value_); - } - - // Check the elements after the first element erased. - for (size_t i = erase_begin; i < v.size(); ++i) { - EXPECT_EQ(i + erase_len, v[i].value_); - } - - // Check that the elements at the beginning are preserved. 
- for (int i = 0; i < erase_begin; ++i) { - EXPECT_EQ(1, counts[i]); - } - - // Check that the erased elements are destroyed - for (int i = erase_begin; i < erase_end; ++i) { - EXPECT_EQ(0, counts[i]); - } - - // Check that the elements at the end are preserved. - for (int i = erase_end; i < len; ++i) { - EXPECT_EQ(1, counts[i]); - } - } - } - } -} - -struct NoDefaultCtor { - explicit NoDefaultCtor(int) {} -}; -struct NoCopy { - NoCopy() {} - NoCopy(const NoCopy&) = delete; -}; -struct NoAssign { - NoAssign() {} - NoAssign& operator=(const NoAssign&) = delete; -}; -struct MoveOnly { - MoveOnly() {} - MoveOnly(MoveOnly&&) = default; - MoveOnly& operator=(MoveOnly&&) = default; -}; -TEST(InlinedVectorTest, NoDefaultCtor) { - tensorflow::gtl::InlinedVector<NoDefaultCtor, 1> v(10, NoDefaultCtor(2)); - (void)v; -} -TEST(InlinedVectorTest, NoCopy) { - tensorflow::gtl::InlinedVector<NoCopy, 1> v(10); - (void)v; -} -TEST(InlinedVectorTest, NoAssign) { - tensorflow::gtl::InlinedVector<NoAssign, 1> v(10); - (void)v; -} -TEST(InlinedVectorTest, MoveOnly) { - gtl::InlinedVector<MoveOnly, 2> v; - v.push_back(MoveOnly{}); - v.push_back(MoveOnly{}); - v.push_back(MoveOnly{}); -} - -TEST(IntVec, Insert) { - for (int len = 0; len < 20; len++) { - for (int pos = 0; pos <= len; pos++) { - IntVec v; - Fill(&v, len); - v.insert(v.begin() + pos, 9999); - EXPECT_EQ(v.size(), len + 1); - for (int i = 0; i < pos; i++) { - EXPECT_EQ(v[i], i); - } - EXPECT_EQ(v[pos], 9999); - for (size_t i = pos + 1; i < v.size(); i++) { - EXPECT_EQ(v[i], i - 1); - } - } - } -} - -TEST(RefCountedVec, InsertConstructorDestructor) { - // Make sure the proper construction/destruction happen during insert - // operations. 
- for (int len = 0; len < 20; len++) { - SCOPED_TRACE(len); - for (int pos = 0; pos <= len; pos++) { - SCOPED_TRACE(pos); - std::vector<int> counts(len, 0); - int inserted_count = 0; - RefCountedVec v; - for (int i = 0; i < len; ++i) { - SCOPED_TRACE(i); - v.push_back(RefCounted(i, &counts[i])); - } - - for (auto elem : counts) { - EXPECT_EQ(1, elem); - } - - RefCounted insert_element(9999, &inserted_count); - EXPECT_EQ(1, inserted_count); - v.insert(v.begin() + pos, insert_element); - EXPECT_EQ(2, inserted_count); - // Check that the elements at the end are preserved. - for (auto elem : counts) { - EXPECT_EQ(1, elem); - } - EXPECT_EQ(2, inserted_count); - } - } -} - -TEST(IntVec, Resize) { - for (int len = 0; len < 20; len++) { - IntVec v; - Fill(&v, len); - - // Try resizing up and down by k elements - static const int kResizeElem = 1000000; - for (int k = 0; k < 10; k++) { - // Enlarging resize - v.resize(len + k, kResizeElem); - EXPECT_EQ(len + k, v.size()); - EXPECT_LE(len + k, v.capacity()); - for (int i = 0; i < len + k; i++) { - if (i < len) { - EXPECT_EQ(i, v[i]); - } else { - EXPECT_EQ(kResizeElem, v[i]); - } - } - - // Shrinking resize - v.resize(len, kResizeElem); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - for (int i = 0; i < len; i++) { - EXPECT_EQ(i, v[i]); - } - } - } -} - -TEST(IntVec, InitWithLength) { - for (int len = 0; len < 20; len++) { - IntVec v(len, 7); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - for (int i = 0; i < len; i++) { - EXPECT_EQ(7, v[i]); - } - } -} - -TEST(IntVec, CopyConstructorAndAssignment) { - for (int len = 0; len < 20; len++) { - IntVec v; - Fill(&v, len); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - - IntVec v2(v); - EXPECT_EQ(v, v2); - - for (int start_len = 0; start_len < 20; start_len++) { - IntVec v3; - Fill(&v3, start_len, 99); // Add dummy elements that should go away - v3 = v; - EXPECT_EQ(v, v3); - } - } -} - -TEST(OverheadTest, Storage) { - // Check for size 
overhead. - using tensorflow::gtl::InlinedVector; - EXPECT_EQ(2 * sizeof(int*), sizeof(InlinedVector<int*, 1>)); - EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 2>)); - EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 3>)); - EXPECT_EQ(6 * sizeof(int*), sizeof(InlinedVector<int*, 4>)); - - EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 1>)); - EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 2>)); - EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 3>)); - EXPECT_EQ(2 * sizeof(char*), - sizeof(InlinedVector<char, 2 * sizeof(char*) - 1>)); - EXPECT_EQ(4 * sizeof(char*), sizeof(InlinedVector<char, 2 * sizeof(char*)>)); -} - -TEST(IntVec, Clear) { - for (int len = 0; len < 20; len++) { - SCOPED_TRACE(len); - IntVec v; - Fill(&v, len); - v.clear(); - EXPECT_EQ(0, v.size()); - EXPECT_EQ(v.begin(), v.end()); - } -} - -TEST(IntVec, Reserve) { - for (size_t len = 0; len < 20; len++) { - IntVec v; - Fill(&v, len); - - for (size_t newlen = 0; newlen < 100; newlen++) { - const int* start_rep = v.data(); - v.reserve(newlen); - const int* final_rep = v.data(); - if (newlen <= len) { - EXPECT_EQ(start_rep, final_rep); - } - EXPECT_LE(newlen, v.capacity()); - - // Filling up to newlen should not change rep - while (v.size() < newlen) { - v.push_back(0); - } - EXPECT_EQ(final_rep, v.data()); - } - } -} - -template <typename T> -static std::vector<typename T::value_type> Vec(const T& src) { - std::vector<typename T::value_type> result; - for (const auto& elem : src) { - result.push_back(elem); - } - return result; -} - -TEST(IntVec, SelfRefPushBack) { - std::vector<string> std_v; - tensorflow::gtl::InlinedVector<string, 4> v; - const string s = "A quite long string to ensure heap."; - std_v.push_back(s); - v.push_back(s); - for (int i = 0; i < 20; ++i) { - EXPECT_EQ(std_v, Vec(v)); - - v.push_back(v.back()); - std_v.push_back(std_v.back()); - } - EXPECT_EQ(std_v, Vec(v)); -} - -TEST(IntVec, SelfRefPushBackWithMove) { - std::vector<string> std_v; 
- gtl::InlinedVector<string, 4> v; - const string s = "A quite long string to ensure heap."; - std_v.push_back(s); - v.push_back(s); - for (int i = 0; i < 20; ++i) { - EXPECT_EQ(v.back(), std_v.back()); - - v.push_back(std::move(v.back())); - std_v.push_back(std::move(std_v.back())); - } - EXPECT_EQ(v.back(), std_v.back()); -} - -TEST(IntVec, Swap) { - for (int l1 = 0; l1 < 20; l1++) { - SCOPED_TRACE(l1); - for (int l2 = 0; l2 < 20; l2++) { - SCOPED_TRACE(l2); - IntVec a = Fill(l1, 0); - IntVec b = Fill(l2, 100); - { - using std::swap; - swap(a, b); - } - EXPECT_EQ(l1, b.size()); - EXPECT_EQ(l2, a.size()); - for (int i = 0; i < l1; i++) { - SCOPED_TRACE(i); - EXPECT_EQ(i, b[i]); - } - for (int i = 0; i < l2; i++) { - SCOPED_TRACE(i); - EXPECT_EQ(100 + i, a[i]); - } - } - } -} - -TEST(InstanceVec, Swap) { - for (int l1 = 0; l1 < 20; l1++) { - for (int l2 = 0; l2 < 20; l2++) { - InstanceVec a, b; - for (int i = 0; i < l1; i++) a.push_back(Instance(i)); - for (int i = 0; i < l2; i++) b.push_back(Instance(100 + i)); - EXPECT_EQ(l1 + l2, instances); - { - using std::swap; - swap(a, b); - } - EXPECT_EQ(l1 + l2, instances); - EXPECT_EQ(l1, b.size()); - EXPECT_EQ(l2, a.size()); - for (int i = 0; i < l1; i++) { - EXPECT_EQ(i, b[i].value_); - } - for (int i = 0; i < l2; i++) { - EXPECT_EQ(100 + i, a[i].value_); - } - } - } -} - -TEST(IntVec, EqualAndNotEqual) { - IntVec a, b; - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - - a.push_back(3); - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - b.push_back(3); - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - - b.push_back(7); - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - a.push_back(6); - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - a.clear(); - b.clear(); - for (int i = 0; i < 100; i++) { - a.push_back(i); - b.push_back(i); - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - - b[i] = b[i] + 1; - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - b[i] = b[i] - 1; // Back to before - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != 
b); - } -} - -TEST(IntVec, RelationalOps) { - IntVec a, b; - EXPECT_FALSE(a < b); - EXPECT_FALSE(b < a); - EXPECT_FALSE(a > b); - EXPECT_FALSE(b > a); - EXPECT_TRUE(a <= b); - EXPECT_TRUE(b <= a); - EXPECT_TRUE(a >= b); - EXPECT_TRUE(b >= a); - b.push_back(3); - EXPECT_TRUE(a < b); - EXPECT_FALSE(b < a); - EXPECT_FALSE(a > b); - EXPECT_TRUE(b > a); - EXPECT_TRUE(a <= b); - EXPECT_FALSE(b <= a); - EXPECT_FALSE(a >= b); - EXPECT_TRUE(b >= a); -} - -TEST(InstanceVec, CountConstructorsDestructors) { - const int start = instances; - for (int len = 0; len < 20; len++) { - InstanceVec v; - for (int i = 0; i < len; i++) { - v.push_back(Instance(i)); - } - EXPECT_EQ(start + len, instances); - - { // Copy constructor should create 'len' more instances. - InstanceVec v_copy(v); - EXPECT_EQ(start + len + len, instances); - } - EXPECT_EQ(start + len, instances); - - // Enlarging resize() must construct some objects - v.resize(len + 10, Instance(100)); - EXPECT_EQ(start + len + 10, instances); - - // Shrinking resize() must destroy some objects - v.resize(len, Instance(100)); - EXPECT_EQ(start + len, instances); - - // reserve() must not increase the number of initialized objects - v.reserve(len + 1000); - EXPECT_EQ(start + len, instances); - - // pop_back() and erase() must destroy one object - if (len > 0) { - v.pop_back(); - EXPECT_EQ(start + len - 1, instances); - if (!v.empty()) { - v.erase(v.begin()); - EXPECT_EQ(start + len - 2, instances); - } - } - } - EXPECT_EQ(start, instances); -} - -TEST(InstanceVec, CountConstructorsDestructorsOnAssignment) { - const int start = instances; - for (int len = 0; len < 20; len++) { - for (int longorshort = 0; longorshort <= 1; ++longorshort) { - InstanceVec longer, shorter; - for (int i = 0; i < len; i++) { - longer.push_back(Instance(i)); - shorter.push_back(Instance(i)); - } - longer.push_back(Instance(len)); - EXPECT_EQ(start + len + len + 1, instances); - - if (longorshort) { - shorter = longer; - EXPECT_EQ(start + (len + 1) + (len 
+ 1), instances); - } else { - longer = shorter; - EXPECT_EQ(start + len + len, instances); - } - } - } - EXPECT_EQ(start, instances); -} - -TEST(RangedConstructor, SimpleType) { - std::vector<int> source_v = {4, 5, 6, 7}; - // First try to fit in inline backing - tensorflow::gtl::InlinedVector<int, 4> v(source_v.begin(), source_v.end()); - tensorflow::gtl::InlinedVector<int, 4> empty4; - EXPECT_EQ(4, v.size()); - EXPECT_EQ(empty4.capacity(), v.capacity()); // Must still be inline - EXPECT_EQ(4, v[0]); - EXPECT_EQ(5, v[1]); - EXPECT_EQ(6, v[2]); - EXPECT_EQ(7, v[3]); - - // Now, force a re-allocate - tensorflow::gtl::InlinedVector<int, 2> realloc_v(source_v.begin(), - source_v.end()); - tensorflow::gtl::InlinedVector<int, 2> empty2; - EXPECT_EQ(4, realloc_v.size()); - EXPECT_LT(empty2.capacity(), realloc_v.capacity()); - EXPECT_EQ(4, realloc_v[0]); - EXPECT_EQ(5, realloc_v[1]); - EXPECT_EQ(6, realloc_v[2]); - EXPECT_EQ(7, realloc_v[3]); -} - -TEST(RangedConstructor, ComplexType) { - // We also use a list here to pass a different flavor of iterator (e.g. not - // random-access). 
- std::list<Instance> source_v = {Instance(0)}; - - // First try to fit in inline backing - tensorflow::gtl::InlinedVector<Instance, 1> v(source_v.begin(), - source_v.end()); - tensorflow::gtl::InlinedVector<Instance, 1> empty1; - EXPECT_EQ(1, v.size()); - EXPECT_EQ(empty1.capacity(), v.capacity()); // Must still be inline - EXPECT_EQ(0, v[0].value_); - - std::list<Instance> source_v2 = {Instance(0), Instance(1), Instance(2), - Instance(3)}; - // Now, force a re-allocate - tensorflow::gtl::InlinedVector<Instance, 1> realloc_v(source_v2.begin(), - source_v2.end()); - EXPECT_EQ(4, realloc_v.size()); - EXPECT_LT(empty1.capacity(), realloc_v.capacity()); - EXPECT_EQ(0, realloc_v[0].value_); - EXPECT_EQ(1, realloc_v[1].value_); - EXPECT_EQ(2, realloc_v[2].value_); - EXPECT_EQ(3, realloc_v[3].value_); -} - -TEST(RangedConstructor, ElementsAreConstructed) { - std::vector<string> source_v = {"cat", "dog"}; - - // Force expansion and re-allocation of v. Ensures that when the vector is - // expanded that new elements are constructed. 
- tensorflow::gtl::InlinedVector<string, 1> v(source_v.begin(), source_v.end()); - EXPECT_EQ("cat", v[0]); - EXPECT_EQ("dog", v[1]); -} - -TEST(InitializerListConstructor, SimpleTypeWithInlineBacking) { - auto vec = tensorflow::gtl::InlinedVector<int, 3>{4, 5, 6}; - EXPECT_EQ(3, vec.size()); - EXPECT_EQ(3, vec.capacity()); - EXPECT_EQ(4, vec[0]); - EXPECT_EQ(5, vec[1]); - EXPECT_EQ(6, vec[2]); -} - -TEST(InitializerListConstructor, SimpleTypeWithReallocationRequired) { - auto vec = tensorflow::gtl::InlinedVector<int, 2>{4, 5, 6}; - EXPECT_EQ(3, vec.size()); - EXPECT_LE(3, vec.capacity()); - EXPECT_EQ(4, vec[0]); - EXPECT_EQ(5, vec[1]); - EXPECT_EQ(6, vec[2]); -} - -TEST(InitializerListConstructor, DisparateTypesInList) { - EXPECT_EQ((std::vector<int>{-7, 8}), - Vec(tensorflow::gtl::InlinedVector<int, 2>{-7, 8ULL})); - - EXPECT_EQ( - (std::vector<string>{"foo", "bar"}), - Vec(tensorflow::gtl::InlinedVector<string, 2>{"foo", string("bar")})); -} - -TEST(InitializerListConstructor, ComplexTypeWithInlineBacking) { - tensorflow::gtl::InlinedVector<Instance, 1> empty; - auto vec = tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0)}; - EXPECT_EQ(1, vec.size()); - EXPECT_EQ(empty.capacity(), vec.capacity()); - EXPECT_EQ(0, vec[0].value_); -} - -TEST(InitializerListConstructor, ComplexTypeWithReallocationRequired) { - auto vec = - tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0), Instance(1)}; - EXPECT_EQ(2, vec.size()); - EXPECT_LE(2, vec.capacity()); - EXPECT_EQ(0, vec[0].value_); - EXPECT_EQ(1, vec[1].value_); -} - -TEST(DynamicVec, DynamicVecCompiles) { - DynamicVec v; - (void)v; -} - -static void BM_InlinedVectorFill(int iters, int len) { - for (int i = 0; i < iters; i++) { - IntVec v; - for (int j = 0; j < len; j++) { - v.push_back(j); - } - } - testing::BytesProcessed((int64{iters} * len) * sizeof(int)); -} -BENCHMARK(BM_InlinedVectorFill)->Range(0, 1024); - -static void BM_InlinedVectorFillRange(int iters, int len) { - std::unique_ptr<int[]> ia(new 
int[len]); - for (int j = 0; j < len; j++) { - ia[j] = j; - } - for (int i = 0; i < iters; i++) { - IntVec TF_ATTRIBUTE_UNUSED v(ia.get(), ia.get() + len); - } - testing::BytesProcessed((int64{iters} * len) * sizeof(int)); -} -BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024); - -static void BM_StdVectorFill(int iters, int len) { - for (int i = 0; i < iters; i++) { - std::vector<int> v; - v.reserve(len); - for (int j = 0; j < len; j++) { - v.push_back(j); - } - } - testing::BytesProcessed((int64{iters} * len) * sizeof(int)); -} -BENCHMARK(BM_StdVectorFill)->Range(0, 1024); - -bool StringRepresentedInline(string s) { - const char* chars = s.data(); - string s1 = std::move(s); - return s1.data() != chars; -} - -static void BM_InlinedVectorFillString(int iters, int len) { - string strings[4] = {"a quite long string", "another long string", - "012345678901234567", "to cause allocation"}; - for (int i = 0; i < iters; i++) { - gtl::InlinedVector<string, 8> v; - for (int j = 0; j < len; j++) { - v.push_back(strings[j & 3]); - } - } - testing::ItemsProcessed(int64{iters} * len); -} -BENCHMARK(BM_InlinedVectorFillString)->Range(0, 1024); - -static void BM_StdVectorFillString(int iters, int len) { - string strings[4] = {"a quite long string", "another long string", - "012345678901234567", "to cause allocation"}; - for (int i = 0; i < iters; i++) { - std::vector<string> v; - v.reserve(len); - for (int j = 0; j < len; j++) { - v.push_back(strings[j & 3]); - } - } - testing::ItemsProcessed(int64{iters} * len); - // The purpose of the benchmark is to verify that inlined vector is - // efficient when moving is more efficient than copying. To do so, we - // use strings that are larger than the small string optimization. - CHECK(!StringRepresentedInline(strings[0])); -} -BENCHMARK(BM_StdVectorFillString)->Range(0, 1024); - -namespace { -struct Buffer { // some arbitrary structure for benchmarking. 
- char* base; - int length; - int capacity; - void* user_data; -}; -} // anonymous namespace - -static void BM_InlinedVectorTenAssignments(int iters, int len) { - typedef tensorflow::gtl::InlinedVector<Buffer, 2> BufferVec; - - BufferVec src; - src.resize(len); - - iters *= 10; - BufferVec dst; - for (int i = 0; i < iters; i++) { - dst = src; - } -} -BENCHMARK(BM_InlinedVectorTenAssignments) - ->Arg(0) - ->Arg(1) - ->Arg(2) - ->Arg(3) - ->Arg(4) - ->Arg(20); - -static void BM_CreateFromInitializerList(int iters) { - for (; iters > 0; iters--) { - tensorflow::gtl::InlinedVector<int, 4> x{1, 2, 3}; - (void)x[0]; - } -} -BENCHMARK(BM_CreateFromInitializerList); - -namespace { - -struct LargeSwappable { - LargeSwappable() : d_(1024, 17) {} - ~LargeSwappable() {} - LargeSwappable(const LargeSwappable& o) : d_(o.d_) {} - - friend void swap(LargeSwappable& a, LargeSwappable& b) { - using std::swap; - swap(a.d_, b.d_); - } - - LargeSwappable& operator=(LargeSwappable o) { - using std::swap; - swap(*this, o); - return *this; - } - - std::vector<int> d_; -}; - -} // namespace - -static void BM_LargeSwappableElements(int iters, int len) { - typedef tensorflow::gtl::InlinedVector<LargeSwappable, 32> Vec; - Vec a(len); - Vec b; - while (--iters >= 0) { - using std::swap; - swap(a, b); - } -} -BENCHMARK(BM_LargeSwappableElements)->Range(0, 1024); - -} // namespace tensorflow diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h index e2927689d2..117b6a0bb8 100644 --- a/tensorflow/core/lib/io/block_builder.h +++ b/tensorflow/core/lib/io/block_builder.h @@ -20,6 +20,7 @@ limitations under the License. 
#include <stdint.h> #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h index e3649fd0c9..38fb0c5d86 100644 --- a/tensorflow/core/lib/io/path.h +++ b/tensorflow/core/lib/io/path.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_LIB_IO_PATH_H_ #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index c24628be57..e22adcd569 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -108,10 +108,60 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) { return Status::OK(); } -Status RecordReader::ReadRecord(uint64* offset, string* record) { - static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); - static const size_t kFooterSize = sizeof(uint32); +Status RecordReader::GetMetadata(Metadata* md) { + if (!md) { + return errors::InvalidArgument( + "Metadata object call to GetMetadata() was null"); + } + + // Compute the metadata of the TFRecord file if not cached. + if (!cached_metadata_) { + TF_RETURN_IF_ERROR(input_stream_->Reset()); + + int64 data_size = 0; + int64 entries = 0; + + // Within the loop, we always increment offset positively, so this + // loop should be guaranteed to either return after reaching EOF + // or encountering an error. + uint64 offset = 0; + string record; + while (true) { + // Read header, containing size of data. + Status s = ReadChecksummed(offset, sizeof(uint64), &record); + if (!s.ok()) { + if (errors::IsOutOfRange(s)) { + // We should reach out of range when the record file is complete. + break; + } + return s; + } + + // Read the length of the data. 
+ const uint64 length = core::DecodeFixed64(record.data()); + + // Skip reading the actual data since we just want the number + // of records and the size of the data. + TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(length + kFooterSize)); + offset += kHeaderSize + length + kFooterSize; + + // Increment running stats. + data_size += length; + ++entries; + } + cached_metadata_.reset(new Metadata()); + cached_metadata_->stats.entries = entries; + cached_metadata_->stats.data_size = data_size; + cached_metadata_->stats.file_size = + data_size + (kHeaderSize + kFooterSize) * entries; + } + + md->stats = cached_metadata_->stats; + return Status::OK(); +} + +Status RecordReader::ReadRecord(uint64* offset, string* record) { // Position the input stream. int64 curr_pos = input_stream_->Tell(); int64 desired_pos = static_cast<int64>(*offset); diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h index c05f9e1b36..17444660d4 100644 --- a/tensorflow/core/lib/io/record_reader.h +++ b/tensorflow/core/lib/io/record_reader.h @@ -58,6 +58,26 @@ class RecordReaderOptions { // Note: this class is not thread safe; external synchronization required. class RecordReader { public: + // Format of a single record: + // uint64 length + // uint32 masked crc of length + // byte data[length] + // uint32 masked crc of data + static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); + static const size_t kFooterSize = sizeof(uint32); + + // Statistics (sizes are in units of bytes) + struct Stats { + int64 file_size = -1; + int64 data_size = -1; + int64 entries = -1; // Number of values + }; + + // Metadata for the TFRecord file. + struct Metadata { + Stats stats; + }; + // Create a reader that will return log records from "*file". // "*file" must remain live while this Reader is in use. explicit RecordReader( @@ -71,6 +91,17 @@ class RecordReader { // OUT_OF_RANGE for end of file, or something else for an error. 
Status ReadRecord(uint64* offset, string* record); + // Return the metadata of the Record file. + // + // The current implementation scans the file to completion, + // skipping over the data regions, to extract the metadata once + // on the first call to GetMetadata(); the result is cached and + // reused by later calls. The scan resets the underlying input + // stream before reading. An improved implementation + // would change RecordWriter to write the metadata into TFRecord + // so that GetMetadata() could be a const method. + // + // 'md' must not be nullptr; InvalidArgument is returned if it is. + Status GetMetadata(Metadata* md); + private: Status ReadChecksummed(uint64 offset, size_t n, string* result); @@ -78,6 +109,8 @@ class RecordReader { std::unique_ptr<InputStreamInterface> input_stream_; bool last_read_failed_; + std::unique_ptr<Metadata> cached_metadata_; + TF_DISALLOW_COPY_AND_ASSIGN(RecordReader); }; diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc index 13bea1f8f1..a88d34d293 100644 --- a/tensorflow/core/lib/io/record_reader_writer_test.cc +++ b/tensorflow/core/lib/io/record_reader_writer_test.cc @@ -147,6 +147,13 @@ TEST(RecordReaderWriterTest, TestBasics) { EXPECT_EQ("abc", record); TF_CHECK_OK(reader.ReadRecord(&offset, &record)); EXPECT_EQ("defg", record); + + io::RecordReader::Metadata md; + TF_ASSERT_OK(reader.GetMetadata(&md)); + EXPECT_EQ(2, md.stats.entries); + EXPECT_EQ(7, md.stats.data_size); + // Two entries have 16 bytes of header/footer each. 
+ EXPECT_EQ(39, md.stats.file_size); } } } diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index 6e71d23e71..2c6db2487e 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -88,10 +88,6 @@ RecordWriter::~RecordWriter() { } } -static uint32 MaskedCrc(const char* data, size_t n) { - return crc32c::Mask(crc32c::Value(data, n)); -} - Status RecordWriter::WriteRecord(StringPiece data) { if (dest_ == nullptr) { return Status(::tensorflow::error::FAILED_PRECONDITION, @@ -102,13 +98,10 @@ Status RecordWriter::WriteRecord(StringPiece data) { // uint32 masked crc of length // byte data[length] // uint32 masked crc of data - char header[sizeof(uint64) + sizeof(uint32)]; - core::EncodeFixed64(header + 0, data.size()); - core::EncodeFixed32(header + sizeof(uint64), - MaskedCrc(header, sizeof(uint64))); - char footer[sizeof(uint32)]; - core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size())); - + char header[kHeaderSize]; + char footer[kFooterSize]; + PopulateHeader(header, data.data(), data.size()); + PopulateFooter(footer, data.data(), data.size()); TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); TF_RETURN_IF_ERROR(dest_->Append(data)); return dest_->Append(StringPiece(footer, sizeof(footer))); diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h index 2f6afa5487..1212e1fafb 100644 --- a/tensorflow/core/lib/io/record_writer.h +++ b/tensorflow/core/lib/io/record_writer.h @@ -16,8 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_ #define TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_ +#include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/hash/crc32c.h" #if !defined(IS_SLIM_BUILD) #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_outputbuffer.h" @@ -41,12 +43,20 @@ class RecordWriterOptions { // Options specific to zlib compression. #if !defined(IS_SLIM_BUILD) - ZlibCompressionOptions zlib_options; + tensorflow::io::ZlibCompressionOptions zlib_options; #endif // IS_SLIM_BUILD }; class RecordWriter { public: + // Format of a single record: + // uint64 length + // uint32 masked crc of length + // byte data[length] + // uint32 masked crc of data + static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); + static const size_t kFooterSize = sizeof(uint32); + // Create a writer that will append data to "*dest". // "*dest" must be initially empty. // "*dest" must remain live while this Writer is in use. @@ -72,13 +82,35 @@ class RecordWriter { // are invalid. Status Close(); + // Utility method to populate TFRecord headers. Populates record-header in + // "header[0,kHeaderSize-1]". The record-header is based on data[0, n-1]. + inline static void PopulateHeader(char* header, const char* data, size_t n); + + // Utility method to populate TFRecord footers. Populates record-footer in + // "footer[0,kFooterSize-1]". The record-footer is based on data[0, n-1]. 
+ inline static void PopulateFooter(char* footer, const char* data, size_t n); + private: WritableFile* dest_; RecordWriterOptions options_; + inline static uint32 MaskedCrc(const char* data, size_t n) { + return crc32c::Mask(crc32c::Value(data, n)); + } + TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter); }; +void RecordWriter::PopulateHeader(char* header, const char* data, size_t n) { + core::EncodeFixed64(header + 0, n); + core::EncodeFixed32(header + sizeof(uint64), + MaskedCrc(header, sizeof(uint64))); +} + +void RecordWriter::PopulateFooter(char* footer, const char* data, size_t n) { + core::EncodeFixed32(footer, MaskedCrc(data, n)); +} + } // namespace io } // namespace tensorflow diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc index da514bd21c..946d7188d3 100644 --- a/tensorflow/core/lib/io/recordio_test.cc +++ b/tensorflow/core/lib/io/recordio_test.cc @@ -58,7 +58,7 @@ class StringDest : public WritableFile { Status Close() override { return Status::OK(); } Status Flush() override { return Status::OK(); } Status Sync() override { return Status::OK(); } - Status Append(const StringPiece& slice) override { + Status Append(StringPiece slice) override { contents_->append(slice.data(), slice.size()); return Status::OK(); } diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc index 877ac40f1c..9cebbf40c6 100644 --- a/tensorflow/core/lib/io/table_test.cc +++ b/tensorflow/core/lib/io/table_test.cc @@ -98,7 +98,7 @@ class StringSink : public WritableFile { Status Flush() override { return Status::OK(); } Status Sync() override { return Status::OK(); } - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { contents_.append(data.data(), data.size()); return Status::OK(); } diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.cc b/tensorflow/core/lib/io/zlib_outputbuffer.cc index 84b47c171f..cba139e6ad 100644 --- 
a/tensorflow/core/lib/io/zlib_outputbuffer.cc +++ b/tensorflow/core/lib/io/zlib_outputbuffer.cc @@ -143,7 +143,7 @@ Status ZlibOutputBuffer::FlushOutputBufferToFile() { return Status::OK(); } -Status ZlibOutputBuffer::Append(const StringPiece& data) { +Status ZlibOutputBuffer::Append(StringPiece data) { // If there is sufficient free space in z_stream_input_ to fit data we // add it there and return. // If there isn't enough space we deflate the existing contents of diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h index 3d86d89a99..ccad2fda44 100644 --- a/tensorflow/core/lib/io/zlib_outputbuffer.h +++ b/tensorflow/core/lib/io/zlib_outputbuffer.h @@ -62,7 +62,7 @@ class ZlibOutputBuffer : public WritableFile { // to file when the buffer is full. // // To immediately write contents to file call `Flush()`. - Status Append(const StringPiece& data) override; + Status Append(StringPiece data) override; // Deflates any cached input and writes all output to file. 
Status Flush() override; diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc index 50ed8bdb3b..f7a359eb5b 100644 --- a/tensorflow/core/lib/jpeg/jpeg_mem.cc +++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc @@ -152,7 +152,9 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { cinfo.scale_denom = ratio; cinfo.dct_method = flags.dct_method; - jpeg_start_decompress(&cinfo); + // Determine the output image size before attempting decompress to prevent + // OOM'ing during the decompress. + jpeg_calc_output_dimensions(&cinfo); int64 total_size = static_cast<int64>(cinfo.output_height) * static_cast<int64>(cinfo.output_width); @@ -170,6 +172,8 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) { return nullptr; } + jpeg_start_decompress(&cinfo); + JDIMENSION target_output_width = cinfo.output_width; JDIMENSION target_output_height = cinfo.output_height; JDIMENSION skipped_scanlines = 0; diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h index c204d52cfe..9e4e1989dd 100644 --- a/tensorflow/core/lib/monitoring/collection_registry.h +++ b/tensorflow/core/lib/monitoring/collection_registry.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace monitoring { diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h index 756e5c2af8..bc4365e439 100644 --- a/tensorflow/core/lib/monitoring/metric_def.h +++ b/tensorflow/core/lib/monitoring/metric_def.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace monitoring { diff --git a/tensorflow/core/lib/png/png_io.h b/tensorflow/core/lib/png/png_io.h index bb5d20fb68..c876c5156a 100644 --- a/tensorflow/core/lib/png/png_io.h +++ b/tensorflow/core/lib/png/png_io.h @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/png.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace png { diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc index 36d939e061..c536b5688e 100644 --- a/tensorflow/core/lib/wav/wav_io.cc +++ b/tensorflow/core/lib/wav/wav_io.cc @@ -232,6 +232,11 @@ Status DecodeLin16WaveAsFloatVector(const string& wav_string, "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format); } TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset)); + if (*channel_count < 1) { + return errors::InvalidArgument( + "Bad number of channels for WAV: Expected at least 1, but got ", + *channel_count); + } TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset)); uint32 bytes_per_second; TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset)); diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 15cafaddf4..c9f80df5e4 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -2503,6 +2503,116 @@ REGISTER_OP("ExtractImagePatches") // -------------------------------------------------------------------------- +// To enable rates, uncomment all lines commented below and use ksize_*_eff +// as the second parameter of all GetWindowedOutputSizeVerbose calls instead +// of ksize_*. 
+REGISTER_OP("ExtractVolumePatches") + .Input("input: T") + .Output("patches: T") + .Attr("ksizes: list(int) >= 5") + .Attr("strides: list(int) >= 5") + /* .Attr("rates: list(int) >= 5") */ + .Attr("T: realnumbertype") + .Attr(GetPaddingAttrString()) + .SetShapeFn([](InferenceContext* c) { + // Input must be rank 5; dims 1-3 are the spatial planes/rows/cols and + // dim 4 is depth. The output is also rank 5, with its last dim equal to + // depth * ksize_planes * ksize_rows * ksize_cols. + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); + + std::vector<int32> ksizes; + TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes)); + if (ksizes.size() != 5) { + return errors::InvalidArgument( + "ExtractVolumePatches requires the ksizes attribute to contain 5 " + "values, but got: ", + ksizes.size()); + } + + std::vector<int32> strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + if (strides.size() != 5) { + return errors::InvalidArgument( + "ExtractVolumePatches requires the stride attribute to contain 5 " + "values, but got: ", + strides.size()); + } + + /* + // TODO(hsgkim): Enable rates. + // See extract_volume_patches_op.cc for why rates are disabled now. 
+ + std::vector<int32> rates; + TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates)); + if (rates.size() != 5) { + return errors::InvalidArgument( + "ExtractVolumePatches requires the rates attribute to contain 5 " + "values, but got: ", + rates.size()); + } + */ + + int32 ksize_planes = ksizes[1]; + int32 ksize_rows = ksizes[2]; + int32 ksize_cols = ksizes[3]; + + int32 stride_planes = strides[1]; + int32 stride_rows = strides[2]; + int32 stride_cols = strides[3]; + + /* + int32 rate_planes = rates[1]; + int32 rate_rows = rates[2]; + int32 rate_cols = rates[3]; + + int32 ksize_planes_eff = ksize_planes + + (ksize_planes - 1) * (rate_planes - 1); + int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1); + int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1); + */ + + DimensionHandle batch_size_dim = c->Dim(input_shape, 0); + DimensionHandle in_planes_dim = c->Dim(input_shape, 1); + DimensionHandle in_rows_dim = c->Dim(input_shape, 2); + DimensionHandle in_cols_dim = c->Dim(input_shape, 3); + DimensionHandle output_depth_dim; + TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4), + ksize_planes * ksize_rows * ksize_cols, + &output_depth_dim)); + + if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) || + !c->ValueKnown(in_cols_dim)) { + // Spatial sizes are unknown, but the output is still rank 5 (same as + // the fully-known case below), so emit three unknown spatial dims. + ShapeHandle output_shape = + c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim, + InferenceContext::kUnknownDim, output_depth_dim}); + c->set_output(0, output_shape); + return Status::OK(); + } + auto in_planes = c->Value(in_planes_dim); + auto in_rows = c->Value(in_rows_dim); + auto in_cols = c->Value(in_cols_dim); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + int64 output_planes, output_rows, output_cols; + int64 padding_before, padding_after; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_planes, ksize_planes, stride_planes, padding, &output_planes, + &padding_before, &padding_after)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( 
+ in_rows, ksize_rows, stride_rows, padding, &output_rows, + &padding_before, &padding_after)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_cols, ksize_cols, stride_cols, padding, &output_cols, + &padding_before, &padding_after)); + ShapeHandle output_shape = + c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols, + output_depth_dim}); + c->set_output(0, output_shape); + return Status::OK(); + }); + +// -------------------------------------------------------------------------- + REGISTER_OP("Bitcast") .Input("input: T") .Output("output: type") @@ -2824,6 +2934,34 @@ Status ScatterNdShape(InferenceContext* c) { } // namespace +REGISTER_OP("UpperBound") + .Input("sorted_inputs: T") + .Input("values: T") + .Output("output: out_type") + .Attr("T: type") + .Attr("out_type: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape)); + c->set_output(0, c->input(1)); + return Status::OK(); + }); + +REGISTER_OP("LowerBound") + .Input("sorted_inputs: T") + .Input("values: T") + .Output("output: out_type") + .Attr("T: type") + .Attr("out_type: {int32, int64} = DT_INT32") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape)); + c->set_output(0, c->input(1)); + return Status::OK(); + }); + REGISTER_OP("ScatterNd") .Input("indices: Tindices") .Input("updates: T") diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc index 01452b3e85..b8cf538554 100644 --- a/tensorflow/core/ops/boosted_trees_ops.cc +++ b/tensorflow/core/ops/boosted_trees_ops.cc @@ -22,6 +22,10 @@ limitations under the License. 
namespace tensorflow { +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource); REGISTER_OP("IsBoostedTreesEnsembleInitialized") @@ -176,6 +180,8 @@ REGISTER_OP("BoostedTreesMakeStatsSummary") return Status::OK(); }); +// TODO(nponomareva): when/if creating the new op for unbucketized data, rename +// bucketized_features to features. REGISTER_OP("BoostedTreesPredict") .Input("tree_ensemble_handle: resource") .Input("bucketized_features: num_bucketized_features * int32") @@ -354,4 +360,125 @@ REGISTER_OP("BoostedTreesCenterBias") return Status::OK(); }); +REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource); + +REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized") + .Input("quantile_stream_resource_handle: resource") + .Output("is_initialized: bool") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + c->set_output(0, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesCreateQuantileStreamResource") + .Attr("max_elements: int = 1099511627776") // 1 << 40 + .Input("quantile_stream_resource_handle: resource") + .Input("epsilon: float") + .Input("num_streams: int64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input)); + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesMakeQuantileSummaries") + .Attr("num_features: int >= 0") + .Input("float_values: num_features * float") + .Input("example_weights: float") + .Input("epsilon: float") + .Output("summaries: num_features * float") + .SetShapeFn([](InferenceContext* c) { + int num_features; + 
TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features)); + ShapeHandle example_weights_shape; + TF_RETURN_IF_ERROR( + c->WithRank(c->input(num_features), 1, &example_weights_shape)); + for (int i = 0; i < num_features; ++i) { + ShapeHandle feature_shape; + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape)); + TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0), + c->Dim(example_weights_shape, 0), + &unused_dim)); + // the columns are value, weight, min_rank, max_rank. + c->set_output(i, c->MakeShape({c->UnknownDim(), 4})); + } + // epsilon must be a scalar. + ShapeHandle unused_input; + TF_RETURN_IF_ERROR( + c->WithRank(c->input(num_features + 1), 0, &unused_input)); + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries") + .Attr("num_features: int >= 0") + .Input("quantile_stream_resource_handle: resource") + .Input("summaries: num_features * float") + .SetShapeFn([](InferenceContext* c) { + int num_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features)); + // resource handle must be a scalar. + shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + // each summary must be rank 2. + for (int i = 1; i < num_features + 1; i++) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input)); + } + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesQuantileStreamResourceFlush") + .Attr("generate_quantiles: bool = False") + .Input("quantile_stream_resource_handle: resource") + .Input("num_buckets: int64") + .SetShapeFn([](InferenceContext* c) { + // All the inputs are scalars. 
+ shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries") + .Attr("num_features: int >= 0") + .Input("quantile_stream_resource_handle: resource") + .Output("bucket_boundaries: num_features * float") + .SetShapeFn([](InferenceContext* c) { + int num_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features)); + shape_inference::ShapeHandle unused_input; + // resource handle must be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + for (int i = 0; i < num_features; i++) { + c->set_output(i, c->Vector(c->UnknownDim())); + } + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesBucketize") + .Attr("num_features: int >= 0") + .Input("float_values: num_features * float") + .Input("bucket_boundaries: num_features * float") + .Output("buckets: num_features * int32") + .SetShapeFn([](InferenceContext* c) { + int num_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features)); + ShapeHandle feature_shape; + DimensionHandle unused_dim; + for (int i = 0; i < num_features; i++) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape)); + TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0), + c->Dim(c->input(0), 0), &unused_dim)); + } + // Bucketized result should have same dimension as input. 
+ for (int i = 0; i < num_features; i++) { + c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 1})); + } + return Status::OK(); + }); + } // namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index cb0cb46752..cac4259356 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -11360,6 +11360,29 @@ op { is_commutative: true } op { + name: "BoostedTreesBucketize" + input_arg { + name: "float_values" + type: DT_FLOAT + number_attr: "num_features" + } + input_arg { + name: "bucket_boundaries" + type: DT_FLOAT + number_attr: "num_features" + } + output_arg { + name: "buckets" + type: DT_INT32 + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } +} +op { name: "BoostedTreesCalculateBestGainsPerFeature" input_arg { name: "node_id_range" @@ -11469,6 +11492,29 @@ op { is_stateful: true } op { + name: "BoostedTreesCreateQuantileStreamResource" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "epsilon" + type: DT_FLOAT + } + input_arg { + name: "num_streams" + type: DT_INT64 + } + attr { + name: "max_elements" + type: "int" + default_value { + i: 1099511627776 + } + } + is_stateful: true +} +op { name: "BoostedTreesDeserializeEnsemble" input_arg { name: "tree_ensemble_handle" @@ -11562,6 +11608,32 @@ op { is_stateful: true } op { + name: "BoostedTreesMakeQuantileSummaries" + input_arg { + name: "float_values" + type: DT_FLOAT + number_attr: "num_features" + } + input_arg { + name: "example_weights" + type: DT_FLOAT + } + input_arg { + name: "epsilon" + type: DT_FLOAT + } + output_arg { + name: "summaries" + type: DT_FLOAT + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } +} +op { name: "BoostedTreesMakeStatsSummary" input_arg { name: "node_ids" @@ -11631,6 +11703,83 @@ op { 
is_stateful: true } op { + name: "BoostedTreesQuantileStreamResourceAddSummaries" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "summaries" + type: DT_FLOAT + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } + is_stateful: true +} +op { + name: "BoostedTreesQuantileStreamResourceFlush" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "num_buckets" + type: DT_INT64 + } + attr { + name: "generate_quantiles" + type: "bool" + default_value { + b: false + } + } + is_stateful: true +} +op { + name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + output_arg { + name: "bucket_boundaries" + type: DT_FLOAT + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } + is_stateful: true +} +op { + name: "BoostedTreesQuantileStreamResourceHandleOp" + output_arg { + name: "resource" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} +op { name: "BoostedTreesSerializeEnsemble" input_arg { name: "tree_ensemble_handle" @@ -13070,6 +13219,71 @@ op { is_stateful: true } op { + name: "ConditionalAccumulator" + output_arg { + name: "handle" + type: DT_STRING + is_ref: true + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "container" + type: 
"string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "reduction_type" + type: "string" + default_value { + s: "MEAN" + } + allowed_values { + list { + s: "MEAN" + s: "SUM" + } + } + } + is_stateful: true +} +op { name: "Conj" input_arg { name: "input" @@ -21688,6 +21902,59 @@ op { } } op { + name: "ExtractVolumePatches" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "patches" + type_attr: "T" + } + attr { + name: "ksizes" + type: "list(int)" + has_minimum: true + minimum: 5 + } + attr { + name: "strides" + type: "list(int)" + has_minimum: true + minimum: 5 + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_INT64 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } +} +op { name: "FFT" input_arg { name: "input" @@ -23838,6 +24105,85 @@ op { } } op { + name: "FusedBatchNorm" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "scale" + type_attr: "T" + } + input_arg { + name: "offset" + type_attr: "T" + } + input_arg { + name: "mean" + type_attr: "T" + } + input_arg { + name: "variance" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "batch_mean" + type_attr: "T" + } + output_arg { + name: "batch_variance" + type_attr: "T" + } + output_arg { + name: "reserve_space_1" + type_attr: "T" + } + output_arg { + name: "reserve_space_2" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + } + } + } + attr { + name: "epsilon" + type: "float" + default_value { + f: 0.0001 + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + 
allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } + } + attr { + name: "is_training" + type: "bool" + default_value { + b: true + } + } +} +op { name: "FusedBatchNormGrad" input_arg { name: "y_backprop" @@ -23911,6 +24257,168 @@ op { } } op { + name: "FusedBatchNormGrad" + input_arg { + name: "y_backprop" + type_attr: "T" + } + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "scale" + type_attr: "T" + } + input_arg { + name: "reserve_space_1" + type_attr: "T" + } + input_arg { + name: "reserve_space_2" + type_attr: "T" + } + output_arg { + name: "x_backprop" + type_attr: "T" + } + output_arg { + name: "scale_backprop" + type_attr: "T" + } + output_arg { + name: "offset_backprop" + type_attr: "T" + } + output_arg { + name: "reserve_space_3" + type_attr: "T" + } + output_arg { + name: "reserve_space_4" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + } + } + } + attr { + name: "epsilon" + type: "float" + default_value { + f: 0.0001 + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } + } + attr { + name: "is_training" + type: "bool" + default_value { + b: true + } + } +} +op { + name: "FusedBatchNormGradV2" + input_arg { + name: "y_backprop" + type_attr: "T" + } + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "scale" + type: DT_FLOAT + } + input_arg { + name: "reserve_space_1" + type_attr: "U" + } + input_arg { + name: "reserve_space_2" + type_attr: "U" + } + output_arg { + name: "x_backprop" + type_attr: "T" + } + output_arg { + name: "scale_backprop" + type_attr: "U" + } + output_arg { + name: "offset_backprop" + type_attr: "U" + } + output_arg { + name: "reserve_space_3" + type_attr: "U" + } + output_arg { + name: "reserve_space_4" + type_attr: "U" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } + attr { 
+ name: "U" + type: "type" + allowed_values { + list { + type: DT_FLOAT + } + } + } + attr { + name: "epsilon" + type: "float" + default_value { + f: 0.0001 + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + } + attr { + name: "is_training" + type: "bool" + default_value { + b: true + } + } +} +op { name: "FusedBatchNormGradV2" input_arg { name: "y_backprop" @@ -23958,6 +24466,7 @@ op { allowed_values { list { type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT } } @@ -24068,6 +24577,12 @@ op { default_value { s: "NHWC" } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } } attr { name: "is_training" @@ -24245,6 +24760,96 @@ op { } } op { + name: "FusedBatchNormV2" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "scale" + type_attr: "U" + } + input_arg { + name: "offset" + type_attr: "U" + } + input_arg { + name: "mean" + type_attr: "U" + } + input_arg { + name: "variance" + type_attr: "U" + } + output_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "batch_mean" + type_attr: "U" + } + output_arg { + name: "batch_variance" + type_attr: "U" + } + output_arg { + name: "reserve_space_1" + type_attr: "U" + } + output_arg { + name: "reserve_space_2" + type_attr: "U" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + } + } + } + attr { + name: "U" + type: "type" + allowed_values { + list { + type: DT_FLOAT + } + } + } + attr { + name: "epsilon" + type: "float" + default_value { + f: 0.0001 + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } + } + attr { + name: "is_training" + type: "bool" + default_value { + b: true + } + } +} +op { name: "FusedPadConv2D" input_arg { name: "input" @@ -27127,6 +27732,18 @@ op { is_stateful: true } op { + name: "IsBoostedTreesQuantileStreamResourceInitialized" + input_arg { + name: 
"quantile_stream_resource_handle" + type: DT_RESOURCE + } + output_arg { + name: "is_initialized" + type: DT_BOOL + } + is_stateful: true +} +op { name: "IsFinite" input_arg { name: "x" @@ -29162,6 +29779,38 @@ op { } } op { + name: "LowerBound" + input_arg { + name: "sorted_inputs" + type_attr: "T" + } + input_arg { + name: "values" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "T" + type: "type" + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "MakeIterator" input_arg { name: "dataset" @@ -29381,6 +30030,49 @@ op { } } op { + name: "MapDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } +} +op { name: "MapDefun" input_arg { name: "arguments" @@ -34842,6 +35534,29 @@ op { } } op { + name: "ModelDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "Mul" input_arg { name: "x" @@ -34949,6 +35664,134 @@ op { is_commutative: true } op { + name: "MultiDeviceIterator" + output_arg { + name: "handle" + type: DT_RESOURCE + } + attr { + name: "devices" + type: "list(string)" + has_minimum: true + minimum: 
1 + } + attr { + name: "shared_name" + type: "string" + } + attr { + name: "container" + type: "string" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "MultiDeviceIteratorFromStringHandle" + input_arg { + name: "string_handle" + type: DT_STRING + } + output_arg { + name: "multi_device_iterator" + type: DT_RESOURCE + } + attr { + name: "output_types" + type: "list(type)" + default_value { + list { + } + } + has_minimum: true + } + attr { + name: "output_shapes" + type: "list(shape)" + default_value { + list { + } + } + has_minimum: true + } + is_stateful: true +} +op { + name: "MultiDeviceIteratorGetNextFromShard" + input_arg { + name: "multi_device_iterator" + type: DT_RESOURCE + } + input_arg { + name: "shard_num" + type: DT_INT32 + } + input_arg { + name: "incarnation_id" + type: DT_INT64 + } + output_arg { + name: "components" + type_list_attr: "output_types" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "MultiDeviceIteratorInit" + input_arg { + name: "dataset" + type: DT_VARIANT + } + input_arg { + name: "multi_device_iterator" + type: DT_RESOURCE + } + input_arg { + name: "max_buffer_size" + type: DT_INT64 + } + output_arg { + name: "incarnation_id" + type: DT_INT64 + } + is_stateful: true +} +op { + name: "MultiDeviceIteratorToStringHandle" + input_arg { + name: "multi_device_iterator" + type: DT_RESOURCE + } + output_arg { + name: "string_handle" + type: DT_STRING + } + is_stateful: true +} +op { name: "Multinomial" input_arg { name: "logits" @@ -35682,6 +36525,42 @@ op { } } op { + name: "NonMaxSuppressionV2" + input_arg { + name: "boxes" + type_attr: "T" + } + input_arg { + name: "scores" + type_attr: "T" + } 
+ input_arg { + name: "max_output_size" + type: DT_INT32 + } + input_arg { + name: "iou_threshold" + type: DT_FLOAT + } + output_arg { + name: "selected_indices" + type: DT_INT32 + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } +} +op { name: "NonMaxSuppressionV3" input_arg { name: "boxes" @@ -35709,6 +36588,46 @@ op { } } op { + name: "NonMaxSuppressionV3" + input_arg { + name: "boxes" + type_attr: "T" + } + input_arg { + name: "scores" + type_attr: "T" + } + input_arg { + name: "max_output_size" + type: DT_INT32 + } + input_arg { + name: "iou_threshold" + type: DT_FLOAT + } + input_arg { + name: "score_threshold" + type: DT_FLOAT + } + output_arg { + name: "selected_indices" + type: DT_INT32 + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } +} +op { name: "NonMaxSuppressionV4" input_arg { name: "boxes" @@ -35747,6 +36666,57 @@ op { } } op { + name: "NonMaxSuppressionV4" + input_arg { + name: "boxes" + type_attr: "T" + } + input_arg { + name: "scores" + type_attr: "T" + } + input_arg { + name: "max_output_size" + type: DT_INT32 + } + input_arg { + name: "iou_threshold" + type: DT_FLOAT + } + input_arg { + name: "score_threshold" + type: DT_FLOAT + } + output_arg { + name: "selected_indices" + type: DT_INT32 + } + output_arg { + name: "valid_outputs" + type: DT_INT32 + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } + attr { + name: "pad_to_max_output_size" + type: "bool" + default_value { + b: false + } + } +} +op { name: "NonMaxSuppressionWithOverlaps" input_arg { name: "overlaps" @@ -37037,6 +38007,54 @@ op { } } op { + name: "ParallelInterleaveDatasetV2" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + 
type_list_attr: "Targuments" + } + input_arg { + name: "cycle_length" + type: DT_INT64 + } + input_arg { + name: "block_length" + type: DT_INT64 + } + input_arg { + name: "num_parallel_calls" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "ParallelMapDataset" input_arg { name: "input_dataset" @@ -37118,6 +38136,53 @@ op { } } op { + name: "ParallelMapDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "num_parallel_calls" + type: DT_INT32 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } +} +op { name: "ParameterizedTruncatedNormal" input_arg { name: "shape" @@ -38366,6 +39431,30 @@ op { is_stateful: true } op { + name: "PrintV2" + input_arg { + name: "input" + type: DT_STRING + } + attr { + name: "output_stream" + type: "string" + default_value { + s: "stderr" + } + allowed_values { + list { + s: "stdout" + s: "stderr" + s: "log(info)" + s: "log(warning)" + s: "log(error)" + } + } + } + is_stateful: true +} +op { name: "PriorityQueue" output_arg { name: "handle" @@ -59334,6 +60423,29 @@ op { } } op { + name: "Softplus" + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "activations" + type_attr: 
"T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { name: "SoftplusGrad" input_arg { name: "gradients" @@ -59470,6 +60582,33 @@ op { } } op { + name: "SoftplusGrad" + input_arg { + name: "gradients" + type_attr: "T" + } + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "backprops" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { name: "Softsign" input_arg { name: "features" @@ -59590,6 +60729,29 @@ op { } } op { + name: "Softsign" + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "activations" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { name: "SoftsignGrad" input_arg { name: "gradients" @@ -59726,6 +60888,33 @@ op { } } op { + name: "SoftsignGrad" + input_arg { + name: "gradients" + type_attr: "T" + } + input_arg { + name: "features" + type_attr: "T" + } + output_arg { + name: "backprops" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } +} +op { name: "SpaceToBatch" input_arg { name: "input" @@ -64500,6 +65689,71 @@ op { is_stateful: true } op { + name: "SparseConditionalAccumulator" + output_arg { + name: "handle" + type: DT_STRING + is_ref: true + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } 
+ attr { + name: "shape" + type: "shape" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "reduction_type" + type: "string" + default_value { + s: "MEAN" + } + allowed_values { + list { + s: "MEAN" + s: "SUM" + } + } + } + is_stateful: true +} +op { name: "SparseCross" input_arg { name: "indices" @@ -69293,6 +70547,21 @@ op { } } op { + name: "StaticRegexFullMatch" + input_arg { + name: "input" + type: DT_STRING + } + output_arg { + name: "output" + type: DT_BOOL + } + attr { + name: "pattern" + type: "string" + } +} +op { name: "StaticRegexReplace" input_arg { name: "input" @@ -69594,6 +70863,43 @@ op { } } op { + name: "StringFormat" + input_arg { + name: "inputs" + type_list_attr: "T" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "T" + type: "list(type)" + has_minimum: true + } + attr { + name: "template" + type: "string" + default_value { + s: "%s" + } + } + attr { + name: "placeholder" + type: "string" + default_value { + s: "%s" + } + } + attr { + name: "summarize" + type: "int" + default_value { + i: 3 + } + } +} +op { name: "StringJoin" input_arg { name: "inputs" @@ -69630,6 +70936,30 @@ op { } } op { + name: "StringLength" + input_arg { + name: "input" + type: DT_STRING + } + output_arg { + name: "output" + type: DT_INT32 + } + attr { + name: "unit" + type: "string" + default_value { + s: "BYTE" + } + allowed_values { + list { + s: "BYTE" + s: "UTF8_CHAR" + } + } + } +} +op { name: "StringSplit" input_arg { name: "input" @@ -74673,6 +76003,38 @@ op { is_stateful: true } op { + name: "UpperBound" + input_arg { + name: "sorted_inputs" + type_attr: "T" + } + input_arg { + name: "values" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "T" + type: "type" + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_INT32 + } + 
allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "VarHandleOp" output_arg { name: "resource" @@ -75008,9 +76370,21 @@ op { type: DT_VARIANT } input_arg { - name: "window_size" + name: "size" + type: DT_INT64 + } + input_arg { + name: "shift" + type: DT_INT64 + } + input_arg { + name: "stride" type: DT_INT64 } + input_arg { + name: "drop_remainder" + type: DT_BOOL + } output_arg { name: "handle" type: DT_VARIANT @@ -75247,6 +76621,62 @@ op { is_stateful: true } op { + name: "Xdivy" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} +op { + name: "Xlogy" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} +op { name: "ZerosLike" input_arg { name: "x" diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc index f78f7a897a..f84142c992 100644 --- a/tensorflow/core/ops/cudnn_rnn_ops.cc +++ b/tensorflow/core/ops/cudnn_rnn_ops.cc @@ -37,7 +37,6 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; - REGISTER_OP("CudnnRNNParamsSize") .Input("num_layers: int32") .Input("num_units: int32") @@ -52,11 +51,16 @@ REGISTER_OP("CudnnRNNParamsSize") .Attr("seed2: int = 0") .Output("params_size: S") .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + // num_layers, num_units, and input_size should be scalars. 
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + c->set_output(0, c->Vector(1)); return Status::OK(); }); - REGISTER_OP("CudnnRNN") .Input("input: T") .Input("input_h: T") @@ -248,7 +252,6 @@ REGISTER_OP("CudnnRNNParamsToCanonical") return Status::OK(); }); - REGISTER_OP("CudnnRNNCanonicalToParams") .Input("num_layers: int32") .Input("num_units: int32") diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc index 2dd867561b..13c3b933f4 100644 --- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc +++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc @@ -26,7 +26,16 @@ namespace tensorflow { TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) { ShapeInferenceTestOp op("CudnnRNNParamsSize"); - INFER_OK(op, "[1];[1];[1]", "[1]"); + INFER_OK(op, "[];[];[]", "[1]"); + INFER_OK(op, "?;[];[]", "[1]"); + INFER_OK(op, "[];?;[]", "[1]"); + INFER_OK(op, "[];[];?", "[1]"); + INFER_OK(op, "[];?;?", "[1]"); + INFER_OK(op, "?;?;?", "[1]"); + + INFER_ERROR("Shape must be rank 0 ", op, "[1,2];?;[]"); + INFER_ERROR("Shape must be rank 0 ", op, "?;[2];[]"); + INFER_ERROR("Shape must be rank 0 ", op, "?;?;[1]"); } TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) { diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index eed0bce174..ffab8ad661 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -419,6 +419,7 @@ REGISTER_OP("ConditionalAccumulator") .Attr("shape: shape") .Attr("container: string = ''") .Attr("shared_name: string = ''") + .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->Vector(2)); @@ -456,6 +457,7 @@ REGISTER_OP("SparseConditionalAccumulator") .Attr("shape: shape") .Attr("container: string = ''") .Attr("shared_name: string = ''") + .Attr("reduction_type: { 'MEAN', 
'SUM' } = 'MEAN' ") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->Vector(2)); diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index f03639e833..1ada623cf5 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -198,6 +198,7 @@ REGISTER_OP("MapDataset") .Attr("Targuments: list(type) >= 0") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("use_inter_op_parallelism: bool = true") .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("ParallelMapDataset") @@ -209,6 +210,7 @@ REGISTER_OP("ParallelMapDataset") .Attr("Targuments: list(type) >= 0") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("use_inter_op_parallelism: bool = true") .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("MapAndBatchDataset") @@ -325,6 +327,19 @@ REGISTER_OP("ParallelInterleaveDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("ParallelInterleaveDatasetV2") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("cycle_length: int64") + .Input("block_length: int64") + .Input("num_parallel_calls: int64") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("GroupByReducerDataset") .Input("input_dataset: variant") .Input("key_func_other_arguments: Tkey_func_other_arguments") @@ -381,14 +396,20 @@ REGISTER_OP("FilterByLastComponentDataset") REGISTER_OP("WindowDataset") .Input("input_dataset: variant") - .Input("window_size: int64") + .Input("size: int64") + .Input("shift: int64") + .Input("stride: int64") + .Input("drop_remainder: bool") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") 
.SetShapeFn([](shape_inference::InferenceContext* c) { shape_inference::ShapeHandle unused; - // batch_size should be a scalar. + // size, shift, stride, and drop_remainder should be scalars. TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); return shape_inference::ScalarShape(c); }); @@ -858,6 +879,13 @@ REGISTER_OP("IteratorGetNextAsOptional") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("ModelDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("MapDefun") .Input("arguments: Targuments") .Output("output: output_types") @@ -866,7 +894,7 @@ REGISTER_OP("MapDefun") .Attr("output_shapes: list(shape) >= 1") .Attr("f: func") .SetShapeFn([](shape_inference::InferenceContext* c) { - std::vector<TensorShape> output_shapes; + std::vector<PartialTensorShape> output_shapes; TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); if (output_shapes.size() != c->num_outputs()) { return errors::InvalidArgument( @@ -876,6 +904,10 @@ REGISTER_OP("MapDefun") int64 dim_zero = -1; for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) { + if (c->Rank(c->input(i)) == 0) { + return errors::InvalidArgument( + "Inputs must have rank at least 1. 
Input ", i, " has rank of 0"); + } auto dim_handle = c->Dim(c->input(i), 0); if (c->ValueKnown(dim_handle)) { if (dim_zero == -1) { @@ -900,4 +932,41 @@ REGISTER_OP("MapDefun") return Status::OK(); }); +REGISTER_OP("MultiDeviceIterator") + .Output("handle: resource") + .Attr("devices: list(string) >= 1") + .Attr("shared_name: string") + .Attr("container: string") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("MultiDeviceIteratorInit") + .Input("dataset: variant") + .Input("multi_device_iterator: resource") + .Input("max_buffer_size: int64") + .Output("incarnation_id: int64") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("MultiDeviceIteratorGetNextFromShard") + .Input("multi_device_iterator: resource") + .Input("shard_num: int32") + .Input("incarnation_id: int64") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(IteratorGetNextShapeFn); + +REGISTER_OP("MultiDeviceIteratorToStringHandle") + .Input("multi_device_iterator: resource") + .Output("string_handle: string") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("MultiDeviceIteratorFromStringHandle") + .Input("string_handle: string") + .Output("multi_device_iterator: resource") + .Attr("output_types: list(type) >= 0 = []") + .Attr("output_shapes: list(shape) >= 0 = []") + .SetShapeFn(shape_inference::ScalarShape); + } // namespace tensorflow diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 11ca0bd259..5427275284 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -683,11 +683,12 @@ REGISTER_OP("NonMaxSuppression") }); REGISTER_OP("NonMaxSuppressionV2") - .Input("boxes: float") - .Input("scores: float") + .Input("boxes: T") + .Input("scores: T") .Input("max_output_size: int32") .Input("iou_threshold: float") .Output("selected_indices: 
int32") + .Attr("T: {half, float} = DT_FLOAT") .SetShapeFn([](InferenceContext* c) { // Get inputs and validate ranks. ShapeHandle boxes; @@ -711,22 +712,24 @@ REGISTER_OP("NonMaxSuppressionV2") }); REGISTER_OP("NonMaxSuppressionV3") - .Input("boxes: float") - .Input("scores: float") + .Input("boxes: T") + .Input("scores: T") .Input("max_output_size: int32") .Input("iou_threshold: float") .Input("score_threshold: float") .Output("selected_indices: int32") + .Attr("T: {half, float} = DT_FLOAT") .SetShapeFn(NMSShapeFn); REGISTER_OP("NonMaxSuppressionV4") - .Input("boxes: float") - .Input("scores: float") + .Input("boxes: T") + .Input("scores: T") .Input("max_output_size: int32") .Input("iou_threshold: float") .Input("score_threshold: float") .Output("selected_indices: int32") .Output("valid_outputs: int32") + .Attr("T: {half, float} = DT_FLOAT") .Attr("pad_to_max_output_size: bool = false") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(NMSShapeFn(c)); diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc index 639d211767..2034d3601b 100644 --- a/tensorflow/core/ops/logging_ops.cc +++ b/tensorflow/core/ops/logging_ops.cc @@ -20,6 +20,8 @@ limitations under the License. namespace tensorflow { +using shape_inference::InferenceContext; + REGISTER_OP("Assert") .Input("condition: bool") .Input("data: T") @@ -44,6 +46,23 @@ REGISTER_OP("Print") WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Print"); +REGISTER_OP("PrintV2") + .Input("input: string") + .SetIsStateful() + .Attr( + "output_stream: {'stdout', 'stderr', 'log(info)', " + "'log(warning)', 'log(error)'} = 'stderr'") + .SetShapeFn([](InferenceContext* c) { + // Make sure that the input is a scalar. 
+ if (c->Rank(c->input(0)) != 0) { + return errors::InvalidArgument("input must be a scalar, but has rank: ", + c->Rank(c->input(0))); + } + return Status::OK(); + }); + +WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("PrintV2"); + // ---------------------------------------------------------------------------- // Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as // inputs or outputs in various ways. diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc index 07f876cb90..55dcc50325 100644 --- a/tensorflow/core/ops/math_grad.cc +++ b/tensorflow/core/ops/math_grad.cc @@ -549,6 +549,40 @@ Status PowGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Pow", PowGrad); +Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForBinaryCwise(g, { + {{"zeros"}, "ZerosLike", {"x"}}, + {{"is_x_zero"}, "NotEqual", {"x", "zeros"}}, + {{"is_zero_cast"}, "Cast", {"is_x_zero"}, + {{"SrcT", DT_BOOL}, {"DstT", "$T"}}}, + {{"safe_logy"}, "Xlogy", {"is_zero_cast", "y"}}, + {{"xlogygrad"}, "Xdivy", {"x", "y"}}, + {{"gx"}, "Mul", {"safe_logy", "dz"}}, + {{"gy"}, "Mul", {"xlogygrad", "dz"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Xlogy", XlogyGrad); + +Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForBinaryCwise(g, { + {{"zeros"}, "ZerosLike", {"x"}}, + {{"is_x_zero"}, "NotEqual", {"x", "zeros"}}, + {{"is_zero_cast"}, "Cast", {"is_x_zero"}, + {{"SrcT", DT_BOOL}, {"DstT", "$T"}}}, + {{"safe_divy"}, "Xdivy", {"is_zero_cast", "y"}}, + {{"y2"}, "Square", {"y"}}, + {{"negy2"}, "Neg", {"y2"}}, + {{"xdivygrad"}, "Xdivy", {"x", "negy2"}}, + {{"gx"}, "Mul", {"safe_divy", "dz"}}, + {{"gy"}, "Mul", {"xdivygrad", "dz"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Xdivy", XdivyGrad); + Status MaximumMinimumGradHelper(const string& comparator, const AttrSlice& attrs, FunctionDef* g) { // clang-format off diff --git 
a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc index 5ee79809ac..9fc6b34147 100644 --- a/tensorflow/core/ops/math_grad_test.cc +++ b/tensorflow/core/ops/math_grad_test.cc @@ -909,6 +909,46 @@ TEST_F(MathGradTest, ComplexPow) { } #endif // TENSORFLOW_USE_SYCL +TEST_F(MathGradTest, Xlogy) { + auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f}, + TensorShape({2, 3})); + auto y = test::AsTensor<float>({.5f, 2.f}, TensorShape({2, 1})); + Tensor dx; + Tensor dy; + auto g = [](float x, float y) -> float { return x == 0. ? 0. : std::log(y); }; + auto h = [](float x, float y) -> float { return x == 0. ? 0. : x / y; }; + SymGrad("Xlogy", x, y, &dx, &dy); + test::ExpectClose( + dx, test::AsTensor<float>({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f), + g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)}, + TensorShape({2, 3}))); + test::ExpectClose( + dy, test::AsTensor<float>({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f), + h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)}, + TensorShape({2, 1}))); +} + +TEST_F(MathGradTest, Xdivy) { + auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f}, + TensorShape({2, 3})); + auto y = test::AsTensor<float>({.5f, 2.f}, TensorShape({2, 1})); + Tensor dx; + Tensor dy; + auto g = [](float x, float y) -> float { return x == 0. ? 0. : 1 / y; }; + auto h = [](float x, float y) -> float { + return x == 0. ? 0. 
: -x / (y * y); + }; + SymGrad("Xdivy", x, y, &dx, &dy); + test::ExpectClose( + dx, test::AsTensor<float>({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f), + g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)}, + TensorShape({2, 3}))); + test::ExpectClose( + dy, test::AsTensor<float>({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f), + h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)}, + TensorShape({2, 1}))); +} + TEST_F(MathGradTest, Maximum) { auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f}, TensorShape({2, 3})); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 717263a9b0..3eff728f03 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -429,6 +429,20 @@ Returns (x - y)(x - y) element-wise. [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) )doc"); +REGISTER_OP("Xlogy") + .Input("x: T") + .Input("y: T") + .Output("z: T") + .Attr("T: {half, float, double, complex64, complex128}") + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); + +REGISTER_OP("Xdivy") + .Input("x: T") + .Input("y: T") + .Output("z: T") + .Attr("T: {half, float, double, complex64, complex128}") + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); + #undef BINARY_FEWER #undef BINARY_MORE diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 2485fa4717..d1d81b27cc 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -178,7 +178,7 @@ REGISTER_OP("FusedBatchNorm") .Output("reserve_space_2: T") .Attr("T: {float}") .Attr("epsilon: float = 0.0001") - .Attr("data_format: string = 'NHWC'") + .Attr(GetConvnetDataFormatAttrString()) .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormShape); @@ -196,7 +196,7 @@ REGISTER_OP("FusedBatchNormV2") .Attr("T: {half, bfloat16, float}") .Attr("U: {float}") .Attr("epsilon: float = 0.0001") - .Attr("data_format: string = 'NHWC'") + .Attr(GetConvnetDataFormatAttrString()) .Attr("is_training: bool = true") 
.SetShapeFn(shape_inference::FusedBatchNormShape); @@ -213,7 +213,7 @@ REGISTER_OP("FusedBatchNormGrad") .Output("reserve_space_4: T") .Attr("T: {float}") .Attr("epsilon: float = 0.0001") - .Attr("data_format: string = 'NHWC'") + .Attr(GetConvnetDataFormatAttrString()) .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormGradShape); @@ -231,7 +231,7 @@ REGISTER_OP("FusedBatchNormGradV2") .Attr("T: {half, bfloat16, float}") .Attr("U: {float}") .Attr("epsilon: float = 0.0001") - .Attr("data_format: string = 'NHWC'") + .Attr(GetConvnetDataFormatAttrString()) .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormGradShape); @@ -1009,32 +1009,30 @@ REGISTER_OP("SeluGrad") .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::MergeBothInputsShapeFn); -// TODO(b/111515541): change T to {half, bfloat16, float, double} REGISTER_OP("Softplus") .Input("features: T") .Output("activations: T") - .Attr("T: realnumbertype") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::UnchangedShape); REGISTER_OP("SoftplusGrad") .Input("gradients: T") .Input("features: T") .Output("backprops: T") - .Attr("T: realnumbertype") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::MergeBothInputsShapeFn); -// TODO(b/111515541): change T to {half, bfloat16, float, double} REGISTER_OP("Softsign") .Input("features: T") .Output("activations: T") - .Attr("T: realnumbertype") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::UnchangedShape); REGISTER_OP("SoftsignGrad") .Input("gradients: T") .Input("features: T") .Output("backprops: T") - .Attr("T: realnumbertype") + .Attr("T: {half, bfloat16, float, double}") .SetShapeFn(shape_inference::MergeBothInputsShapeFn); // -------------------------------------------------------------------------- diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 4419f93d0c..e173c2d072 100644 --- 
a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -4272,6 +4272,29 @@ op { is_commutative: true } op { + name: "BoostedTreesBucketize" + input_arg { + name: "float_values" + type: DT_FLOAT + number_attr: "num_features" + } + input_arg { + name: "bucket_boundaries" + type: DT_FLOAT + number_attr: "num_features" + } + output_arg { + name: "buckets" + type: DT_INT32 + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } +} +op { name: "BoostedTreesCalculateBestGainsPerFeature" input_arg { name: "node_id_range" @@ -4381,6 +4404,29 @@ op { is_stateful: true } op { + name: "BoostedTreesCreateQuantileStreamResource" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "epsilon" + type: DT_FLOAT + } + input_arg { + name: "num_streams" + type: DT_INT64 + } + attr { + name: "max_elements" + type: "int" + default_value { + i: 1099511627776 + } + } + is_stateful: true +} +op { name: "BoostedTreesDeserializeEnsemble" input_arg { name: "tree_ensemble_handle" @@ -4474,6 +4520,32 @@ op { is_stateful: true } op { + name: "BoostedTreesMakeQuantileSummaries" + input_arg { + name: "float_values" + type: DT_FLOAT + number_attr: "num_features" + } + input_arg { + name: "example_weights" + type: DT_FLOAT + } + input_arg { + name: "epsilon" + type: DT_FLOAT + } + output_arg { + name: "summaries" + type: DT_FLOAT + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } +} +op { name: "BoostedTreesMakeStatsSummary" input_arg { name: "node_ids" @@ -4543,6 +4615,83 @@ op { is_stateful: true } op { + name: "BoostedTreesQuantileStreamResourceAddSummaries" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "summaries" + type: DT_FLOAT + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } + is_stateful: true +} +op { + name: 
"BoostedTreesQuantileStreamResourceFlush" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "num_buckets" + type: DT_INT64 + } + attr { + name: "generate_quantiles" + type: "bool" + default_value { + b: false + } + } + is_stateful: true +} +op { + name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + output_arg { + name: "bucket_boundaries" + type: DT_FLOAT + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } + is_stateful: true +} +op { + name: "BoostedTreesQuantileStreamResourceHandleOp" + output_arg { + name: "resource" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} +op { name: "BoostedTreesSerializeEnsemble" input_arg { name: "tree_ensemble_handle" @@ -5592,6 +5741,19 @@ op { s: "" } } + attr { + name: "reduction_type" + type: "string" + default_value { + s: "MEAN" + } + allowed_values { + list { + s: "MEAN" + s: "SUM" + } + } + } is_stateful: true } op { @@ -10025,6 +10187,59 @@ op { } } op { + name: "ExtractVolumePatches" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "patches" + type_attr: "T" + } + attr { + name: "ksizes" + type: "list(int)" + has_minimum: true + minimum: 5 + } + attr { + name: "strides" + type: "list(int)" + has_minimum: true + minimum: 5 + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_INT64 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } +} +op { name: "FFT" input_arg { 
name: "input" @@ -11244,6 +11459,12 @@ op { default_value { s: "NHWC" } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } } attr { name: "is_training" @@ -11317,6 +11538,12 @@ op { default_value { s: "NHWC" } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } } attr { name: "is_training" @@ -11401,6 +11628,12 @@ op { default_value { s: "NHWC" } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } } attr { name: "is_training" @@ -11485,6 +11718,12 @@ op { default_value { s: "NHWC" } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } } attr { name: "is_training" @@ -13149,6 +13388,18 @@ op { is_stateful: true } op { + name: "IsBoostedTreesQuantileStreamResourceInitialized" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + output_arg { + name: "is_initialized" + type: DT_BOOL + } + is_stateful: true +} +op { name: "IsFinite" input_arg { name: "x" @@ -14362,6 +14613,38 @@ op { } } op { + name: "LowerBound" + input_arg { + name: "sorted_inputs" + type_attr: "T" + } + input_arg { + name: "values" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "T" + type: "type" + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "MakeIterator" input_arg { name: "dataset" @@ -14542,6 +14825,13 @@ op { has_minimum: true minimum: 1 } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } } op { name: "MapDefun" @@ -16540,6 +16830,29 @@ op { } } op { + name: "ModelDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "Mul" input_arg { name: "x" @@ 
-16576,6 +16889,134 @@ op { is_commutative: true } op { + name: "MultiDeviceIterator" + output_arg { + name: "handle" + type: DT_RESOURCE + } + attr { + name: "devices" + type: "list(string)" + has_minimum: true + minimum: 1 + } + attr { + name: "shared_name" + type: "string" + } + attr { + name: "container" + type: "string" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "MultiDeviceIteratorFromStringHandle" + input_arg { + name: "string_handle" + type: DT_STRING + } + output_arg { + name: "multi_device_iterator" + type: DT_RESOURCE + } + attr { + name: "output_types" + type: "list(type)" + default_value { + list { + } + } + has_minimum: true + } + attr { + name: "output_shapes" + type: "list(shape)" + default_value { + list { + } + } + has_minimum: true + } + is_stateful: true +} +op { + name: "MultiDeviceIteratorGetNextFromShard" + input_arg { + name: "multi_device_iterator" + type: DT_RESOURCE + } + input_arg { + name: "shard_num" + type: DT_INT32 + } + input_arg { + name: "incarnation_id" + type: DT_INT64 + } + output_arg { + name: "components" + type_list_attr: "output_types" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "MultiDeviceIteratorInit" + input_arg { + name: "dataset" + type: DT_VARIANT + } + input_arg { + name: "multi_device_iterator" + type: DT_RESOURCE + } + input_arg { + name: "max_buffer_size" + type: DT_INT64 + } + output_arg { + name: "incarnation_id" + type: DT_INT64 + } + is_stateful: true +} +op { + name: "MultiDeviceIteratorToStringHandle" + input_arg { + name: "multi_device_iterator" + type: DT_RESOURCE + } + output_arg { + name: "string_handle" + type: DT_STRING + } + is_stateful: true +} 
+op { name: "Multinomial" input_arg { name: "logits" @@ -17078,11 +17519,11 @@ op { name: "NonMaxSuppressionV2" input_arg { name: "boxes" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "scores" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "max_output_size" @@ -17096,16 +17537,29 @@ op { name: "selected_indices" type: DT_INT32 } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } } op { name: "NonMaxSuppressionV3" input_arg { name: "boxes" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "scores" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "max_output_size" @@ -17123,16 +17577,29 @@ op { name: "selected_indices" type: DT_INT32 } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } } op { name: "NonMaxSuppressionV4" input_arg { name: "boxes" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "scores" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "max_output_size" @@ -17155,6 +17622,19 @@ op { type: DT_INT32 } attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } + attr { name: "pad_to_max_output_size" type: "bool" default_value { @@ -18192,6 +18672,54 @@ op { } } op { + name: "ParallelInterleaveDatasetV2" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "cycle_length" + type: DT_INT64 + } + input_arg { + name: "block_length" + type: DT_INT64 + } + input_arg { + name: "num_parallel_calls" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: 
true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "ParallelMapDataset" input_arg { name: "input_dataset" @@ -18230,6 +18758,13 @@ op { has_minimum: true minimum: 1 } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } } op { name: "ParameterizedTruncatedNormal" @@ -19223,6 +19758,30 @@ op { is_stateful: true } op { + name: "PrintV2" + input_arg { + name: "input" + type: DT_STRING + } + attr { + name: "output_stream" + type: "string" + default_value { + s: "stderr" + } + allowed_values { + list { + s: "stdout" + s: "stderr" + s: "log(info)" + s: "log(warning)" + s: "log(error)" + } + } + } + is_stateful: true +} +op { name: "PriorityQueue" output_arg { name: "handle" @@ -28179,18 +28738,10 @@ op { type: "type" allowed_values { list { + type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE - type: DT_INT32 - type: DT_UINT8 - type: DT_INT16 - type: DT_INT8 - type: DT_INT64 - type: DT_BFLOAT16 - type: DT_UINT16 - type: DT_HALF - type: DT_UINT32 - type: DT_UINT64 } } } @@ -28214,18 +28765,10 @@ op { type: "type" allowed_values { list { + type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE - type: DT_INT32 - type: DT_UINT8 - type: DT_INT16 - type: DT_INT8 - type: DT_INT64 - type: DT_BFLOAT16 - type: DT_UINT16 - type: DT_HALF - type: DT_UINT32 - type: DT_UINT64 } } } @@ -28245,18 +28788,10 @@ op { type: "type" allowed_values { list { + type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE - type: DT_INT32 - type: DT_UINT8 - type: DT_INT16 - type: DT_INT8 - type: DT_INT64 - type: DT_BFLOAT16 - type: DT_UINT16 - type: DT_HALF - type: DT_UINT32 - type: DT_UINT64 } } } @@ -28280,18 +28815,10 @@ op { type: "type" allowed_values { list { + type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE - type: DT_INT32 - type: DT_UINT8 - type: DT_INT16 - type: DT_INT8 - type: DT_INT64 - type: DT_BFLOAT16 - type: DT_UINT16 
- type: DT_HALF - type: DT_UINT32 - type: DT_UINT64 } } } @@ -29610,6 +30137,19 @@ op { s: "" } } + attr { + name: "reduction_type" + type: "string" + default_value { + s: "MEAN" + } + allowed_values { + list { + s: "MEAN" + s: "SUM" + } + } + } is_stateful: true } op { @@ -32108,6 +32648,21 @@ op { } } op { + name: "StaticRegexFullMatch" + input_arg { + name: "input" + type: DT_STRING + } + output_arg { + name: "output" + type: DT_BOOL + } + attr { + name: "pattern" + type: "string" + } +} +op { name: "StaticRegexReplace" input_arg { name: "input" @@ -32409,6 +32964,43 @@ op { } } op { + name: "StringFormat" + input_arg { + name: "inputs" + type_list_attr: "T" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "T" + type: "list(type)" + has_minimum: true + } + attr { + name: "template" + type: "string" + default_value { + s: "%s" + } + } + attr { + name: "placeholder" + type: "string" + default_value { + s: "%s" + } + } + attr { + name: "summarize" + type: "int" + default_value { + i: 3 + } + } +} +op { name: "StringJoin" input_arg { name: "inputs" @@ -32443,6 +33035,19 @@ op { name: "output" type: DT_INT32 } + attr { + name: "unit" + type: "string" + default_value { + s: "BYTE" + } + allowed_values { + list { + s: "BYTE" + s: "UTF8_CHAR" + } + } + } } op { name: "StringSplit" @@ -35628,6 +36233,38 @@ op { is_stateful: true } op { + name: "UpperBound" + input_arg { + name: "sorted_inputs" + type_attr: "T" + } + input_arg { + name: "values" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "out_type" + } + attr { + name: "T" + type: "type" + } + attr { + name: "out_type" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "VarHandleOp" output_arg { name: "resource" @@ -35873,9 +36510,21 @@ op { type: DT_VARIANT } input_arg { - name: "window_size" + name: "size" + type: DT_INT64 + } + input_arg { + name: "shift" type: DT_INT64 } + 
input_arg { + name: "stride" + type: DT_INT64 + } + input_arg { + name: "drop_remainder" + type: DT_BOOL + } output_arg { name: "handle" type: DT_VARIANT @@ -36112,6 +36761,62 @@ op { is_stateful: true } op { + name: "Xdivy" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} +op { + name: "Xlogy" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} +op { name: "ZerosLike" input_arg { name: "x" diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc index 79ca96d249..eff453241d 100644 --- a/tensorflow/core/ops/parsing_ops.cc +++ b/tensorflow/core/ops/parsing_ops.cc @@ -343,10 +343,11 @@ REGISTER_OP("DecodeCSV") // Validate the record_defaults inputs. 
for (int i = 1; i < c->num_inputs(); ++i) { ShapeHandle v; - TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &v)); - if (c->Value(c->Dim(v, 0)) > 1) { + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); + if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { return errors::InvalidArgument( - "Shape of a default must be a length-0 or length-1 vector"); + "Shape of a default must be a length-0 or length-1 vector, or a " + "scalar."); } } diff --git a/tensorflow/core/ops/parsing_ops_test.cc b/tensorflow/core/ops/parsing_ops_test.cc index c65e66d1a8..ba594e400c 100644 --- a/tensorflow/core/ops/parsing_ops_test.cc +++ b/tensorflow/core/ops/parsing_ops_test.cc @@ -52,9 +52,12 @@ TEST(ParsingOpsTest, DecodeCSV_ShapeFn) { INFER_OK(op, "[1,2,?,4];?;?", "in0;in0"); INFER_OK(op, "[1,2,?,4];[?];[?]", "in0;in0"); + // Scalar defaults are ok + INFER_OK(op, "?;?;[]", "in0;in0"); + // Check errors in the record_defaults inputs. - INFER_ERROR("must be rank 1", op, "?;?;[]"); - INFER_ERROR("must be rank 1", op, "?;[];?"); + INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;?;[1,2]"); + INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;[3,4];?"); INFER_ERROR("Shape of a default must be", op, "?;?;[2]"); INFER_ERROR("Shape of a default must be", op, "?;[2];?"); } diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index 26499540f1..adc9cd1486 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -19,6 +19,7 @@ #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/errors.h" using ::tensorflow::shape_inference::InferenceContext; using ::tensorflow::shape_inference::ShapeAndType; @@ -56,6 +57,36 @@ Status ReadVariableShapeFn(InferenceContext* c) { return Status::OK(); } +Status ReadVariablesShapeFn(InferenceContext* 
c) { + int n; + TF_RETURN_IF_ERROR(c->GetAttr("N", &n)); + DataTypeVector value_dtypes; + TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &value_dtypes)); + if (n != value_dtypes.size()) { + return errors::InvalidArgument( + "Mismatched number of arguments to ReadVariablesOp"); + } + for (int i = 0; i < n; ++i) { + ShapeAndType shape_and_type; + auto* handle_data = c->input_handle_shapes_and_types(i); + if (handle_data == nullptr || handle_data->empty()) { + shape_and_type.shape = c->UnknownShape(); + shape_and_type.dtype = DT_INVALID; + } else { + shape_and_type = (*handle_data)[0]; + if (shape_and_type.dtype != value_dtypes[i]) { + return errors::InvalidArgument( + "Trying to read variable with wrong dtype. " + "Expected ", + DataTypeString(shape_and_type.dtype), " got ", + DataTypeString(value_dtypes[i])); + } + } + c->set_output(i, shape_and_type.shape); + } + return Status::OK(); +} + } // namespace REGISTER_OP("VarHandleOp") @@ -79,12 +110,53 @@ REGISTER_OP("VarHandleOp") return Status::OK(); }); +REGISTER_OP("_VarHandlesOp") + .Attr("containers: list(string)") + .Attr("shared_names: list(string)") + .Attr("N: int >= 0") + .Attr("dtypes: list(type)") + .Attr("shapes: list(shape)") + .Output("resources: N * resource") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + int n; + TF_RETURN_IF_ERROR(c->GetAttr("N", &n)); + DataTypeVector dtypes; + TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes)); + std::vector<PartialTensorShape> shapes; + TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes)); + if (dtypes.size() != n) { + return errors::InvalidArgument("Mismatched number of dtypes (n=", n, + ", num dtypes=", dtypes.size(), ")"); + } + if (shapes.size() != n) { + return errors::InvalidArgument("Mismatched number of shapes (n=", n, + ", num shapes=", shapes.size(), ")"); + } + for (int i = 0; i < n; ++i) { + c->set_output(i, c->Scalar()); + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &s)); + 
c->set_output_handle_shapes_and_types( + i, std::vector<ShapeAndType>{{s, dtypes[i]}}); + } + + return Status::OK(); + }); + REGISTER_OP("ReadVariableOp") .Input("resource: resource") .Output("value: dtype") .Attr("dtype: type") .SetShapeFn(ReadVariableShapeFn); +REGISTER_OP("_ReadVariablesOp") + .Attr("N: int >= 0") + .Input("resources: N * resource") + .Output("values: dtypes") + .Attr("dtypes: list(type)") + .SetShapeFn(ReadVariablesShapeFn); + Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FunctionDefHelper::Define( diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index 7aa1e71809..da1d2a6432 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "absl/strings/str_split.h" #include "tensorflow/core/framework/common_shape_fns.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/shape_inference.h" @@ -56,6 +57,12 @@ REGISTER_OP("RegexFullMatch") return Status::OK(); }); +REGISTER_OP("StaticRegexFullMatch") + .Input("input: string") + .Attr("pattern: string") + .Output("output: bool") + .SetShapeFn(shape_inference::UnchangedShape); + REGISTER_OP("StringToHashBucketFast") .Input("input: string") .Output("output: int64") @@ -96,6 +103,32 @@ REGISTER_OP("AsString") .Attr("fill: string = ''") .SetShapeFn(shape_inference::UnchangedShape); +REGISTER_OP("StringFormat") + .Input("inputs: T") + .Output("output: string") + .Attr("T: list(type) >= 0") + .Attr("template: string = '%s'") + .Attr("placeholder: string = '%s'") + .Attr("summarize: int = 3") + .SetShapeFn([](InferenceContext* c) { + string template_; + string placeholder; + TF_RETURN_IF_ERROR(c->GetAttr("template", &template_)); + TF_RETURN_IF_ERROR(c->GetAttr("placeholder", 
&placeholder)); + + std::vector<std::string> split_template; + split_template = absl::StrSplit(template_, placeholder); + int64 num_placeholders = split_template.size() - 1; + if (c->num_inputs() != num_placeholders) { + return errors::InvalidArgument(strings::StrCat( + "num placeholders in template and num inputs must match: ", + num_placeholders, " vs. ", c->num_inputs())); + } + + c->set_output(0, c->Scalar()); + return Status::OK(); + }); + REGISTER_OP("StringJoin") .Input("inputs: N * string") .Attr("N: int") @@ -170,6 +203,7 @@ REGISTER_OP("StringStrip") REGISTER_OP("StringLength") .Input("input: string") .Output("output: int32") + .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'") .SetShapeFn(shape_inference::UnchangedShape); REGISTER_OP("EncodeBase64") diff --git a/tensorflow/core/platform/abi.cc b/tensorflow/core/platform/abi.cc index e597a490d6..d7a13a3528 100644 --- a/tensorflow/core/platform/abi.cc +++ b/tensorflow/core/platform/abi.cc @@ -37,13 +37,13 @@ extern "C" char* __unDName(char* output_string, const char* name, namespace tensorflow { namespace port { -std::string MaybeAbiDemangle(const char* name) { +string MaybeAbiDemangle(const char* name) { #if defined(_MSC_VER) std::unique_ptr<char> demangled{__unDName(nullptr, name, 0, std::malloc, std::free, static_cast<unsigned short>(0))}; - return std::string(demangled.get() != nullptr ? demangled.get() : name); + return string(demangled.get() != nullptr ? demangled.get() : name); #else int status = 0; std::unique_ptr<char, void (*)(void*)> res{ diff --git a/tensorflow/core/platform/abi.h b/tensorflow/core/platform/abi.h index 591e83b0c4..d1498a6a64 100644 --- a/tensorflow/core/platform/abi.h +++ b/tensorflow/core/platform/abi.h @@ -17,11 +17,12 @@ limitations under the License. 
#define TENSORFLOW_CORE_PLATFORM_ABI_H_ #include <string> +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace port { -std::string MaybeAbiDemangle(const char* name); +string MaybeAbiDemangle(const char* name); } // namespace port } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc index a1be4aacce..5e1eabee5b 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.cc +++ b/tensorflow/core/platform/cloud/curl_http_request.cc @@ -394,9 +394,9 @@ size_t CurlHttpRequest::HeaderCallback(const void* ptr, size_t size, .StopCapture() .OneLiteral(": ") .GetResult(&value, &name)) { - string str_value = std::string(value); + string str_value(value); str_util::StripTrailingWhitespace(&str_value); - that->response_headers_[std::string(name)] = str_value; + that->response_headers_[string(name)] = str_value; } return size * nmemb; } diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 9d33787bd5..83ea8539ed 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -25,6 +25,7 @@ limitations under the License. #ifdef _WIN32 #include <io.h> // for _mktemp #endif +#include "absl/base/macros.h" #include "include/json/json.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -63,7 +64,7 @@ constexpr int kGetChildrenDefaultPageSize = 1000; // The HTTP response code "308 Resume Incomplete". constexpr uint64 HTTP_CODE_RESUME_INCOMPLETE = 308; // The environment variable that overrides the size of the readahead buffer. -// DEPRECATED. Use GCS_BLOCK_SIZE_MB instead. +ABSL_DEPRECATED("Use GCS_BLOCK_SIZE_MB instead.") constexpr char kReadaheadBufferSize[] = "GCS_READAHEAD_BUFFER_SIZE_BYTES"; // The environment variable that disables the GCS block cache for reads. 
// This is the explicit alternative to setting BLOCK_SIZE or MAX_SIZE to 0, and @@ -179,13 +180,13 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket, return errors::InvalidArgument("GCS path doesn't start with 'gs://': ", fname); } - *bucket = std::string(bucketp); + *bucket = string(bucketp); if (bucket->empty() || *bucket == ".") { return errors::InvalidArgument("GCS path doesn't contain a bucket name: ", fname); } str_util::ConsumePrefix(&objectp, "/"); - *object = std::string(objectp); + *object = string(objectp); if (!empty_object_ok && object->empty()) { return errors::InvalidArgument("GCS path doesn't contain an object name: ", fname); @@ -224,7 +225,7 @@ std::set<string> AddAllSubpaths(const std::vector<string>& paths) { for (const string& path : paths) { StringPiece subpath = io::Dirname(path); while (!subpath.empty()) { - result.emplace(std::string(subpath)); + result.emplace(string(subpath)); subpath = io::Dirname(subpath); } } @@ -371,7 +372,7 @@ class GcsWritableFile : public WritableFile { ~GcsWritableFile() override { Close().IgnoreError(); } - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { TF_RETURN_IF_ERROR(CheckWritable()); sync_needed_ = true; outfile_ << data; @@ -723,7 +724,7 @@ GcsFileSystem::GcsFileSystem() { if (!header_name.empty() && !header_value.empty()) { additional_header_.reset(new std::pair<const string, const string>( - std::string(header_name), std::string(header_value))); + string(header_name), string(header_value))); VLOG(1) << "GCS additional header ENABLED. " << "Name: " << additional_header_->first << ", " @@ -1229,7 +1230,7 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern, // Find the fixed prefix by looking for the first wildcard. 
const string& fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\")); - const string& dir = std::string(io::Dirname(fixed_prefix)); + const string dir(io::Dirname(fixed_prefix)); if (dir.empty()) { return errors::InvalidArgument( "A GCS pattern doesn't have a bucket name: ", pattern); @@ -1326,7 +1327,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, " doesn't match the prefix ", object_prefix)); } if (!relative_path.empty() || include_self_directory_marker) { - result->emplace_back(std::string(relative_path)); + result->emplace_back(relative_path); } if (++retrieved_results >= max_results) { return Status::OK(); @@ -1354,7 +1355,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, "Unexpected response: the returned folder name ", prefix_str, " doesn't match the prefix ", object_prefix); } - result->emplace_back(std::string(relative_path)); + result->emplace_back(relative_path); if (++retrieved_results >= max_results) { return Status::OK(); } diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc index ee6ba7b041..9b85cae9b9 100644 --- a/tensorflow/core/platform/cloud/oauth_client.cc +++ b/tensorflow/core/platform/cloud/oauth_client.cc @@ -216,7 +216,7 @@ Status OAuthClient::GetTokenFromServiceAccountJson( // Send the request to the Google OAuth 2.0 server to get the token. 
std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); std::vector<char> response_buffer; - request->SetUri(std::string(oauth_server_uri)); + request->SetUri(string(oauth_server_uri)); request->SetPostFromBuffer(request_body.c_str(), request_body.size()); request->SetResultBuffer(&response_buffer); TF_RETURN_IF_ERROR(request->Send()); @@ -248,7 +248,7 @@ Status OAuthClient::GetTokenFromRefreshTokenJson( std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); std::vector<char> response_buffer; - request->SetUri(std::string(oauth_server_uri)); + request->SetUri(string(oauth_server_uri)); request->SetPostFromBuffer(request_body.c_str(), request_body.size()); request->SetResultBuffer(&response_buffer); TF_RETURN_IF_ERROR(request->Send()); diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc index 4ffa72288b..1cd0641cd3 100644 --- a/tensorflow/core/platform/cloud/oauth_client_test.cc +++ b/tensorflow/core/platform/cloud/oauth_client_test.cc @@ -126,9 +126,9 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) { EXPECT_EQ("urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer", grant_type); - int last_dot = std::string(assertion).find_last_of("."); - string header_dot_claim = std::string(assertion.substr(0, last_dot)); - string signature_encoded = std::string(assertion.substr(last_dot + 1)); + int last_dot = assertion.rfind('.'); + string header_dot_claim(assertion.substr(0, last_dot)); + string signature_encoded(assertion.substr(last_dot + 1)); // Check that 'signature' signs 'header_dot_claim'. 
diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h index 92aa72be89..941ab7ad65 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system.h +++ b/tensorflow/core/platform/cloud/retrying_file_system.h @@ -177,7 +177,7 @@ class RetryingWritableFile : public WritableFile { Close().IgnoreError(); } - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { return RetryingUtils::CallWithRetries( [this, &data]() { return base_file_->Append(data); }, initial_delay_microseconds_); diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc index ec2c470db7..5910fef1d2 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc +++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc @@ -72,7 +72,7 @@ class MockRandomAccessFile : public RandomAccessFile { class MockWritableFile : public WritableFile { public: explicit MockWritableFile(const ExpectedCalls& calls) : calls_(calls) {} - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { return calls_.ConsumeNextCall("Append"); } Status Close() override { return calls_.ConsumeNextCall("Close"); } diff --git a/tensorflow/core/platform/cord.h b/tensorflow/core/platform/cord.h new file mode 100644 index 0000000000..7c5c6655be --- /dev/null +++ b/tensorflow/core/platform/cord.h @@ -0,0 +1,26 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CORD_H_ +#define TENSORFLOW_CORE_PLATFORM_CORD_H_ + +// Include appropriate platform-dependent implementations +#if defined(PLATFORM_GOOGLE) +#include "tensorflow/core/platform/google/cord.h" +#else +#include "tensorflow/core/platform/default/cord.h" +#endif + +#endif // TENSORFLOW_CORE_PLATFORM_CORD_H_ diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 07b2e3426b..bb841aeab7 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -625,6 +625,7 @@ def tf_additional_lib_deps(): """Additional dependencies needed to build TF libraries.""" return [ "@com_google_absl//absl/base:base", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/types:span", "@com_google_absl//absl/types:optional", ] + if_static( diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl index 3a012c23fd..37475feebe 100644 --- a/tensorflow/core/platform/default/build_config_root.bzl +++ b/tensorflow/core/platform/default/build_config_root.bzl @@ -3,64 +3,64 @@ # be separate to avoid cyclic references. 
def tf_cuda_tests_tags(): - return ["requires-gpu"] + return ["requires-gpu", "local", "gpu"] def tf_sycl_tests_tags(): - return ["requires-gpu"] + return ["requires-gpu", "local", "gpu"] def tf_additional_plugin_deps(): - return select({ - str(Label("//tensorflow:with_xla_support")): [ - str(Label("//tensorflow/compiler/jit")) - ], - "//conditions:default": [], - }) + return select({ + str(Label("//tensorflow:with_xla_support")): [ + str(Label("//tensorflow/compiler/jit")), + ], + "//conditions:default": [], + }) def tf_additional_xla_deps_py(): - return [] + return [] def tf_additional_grpc_deps_py(): - return [] + return [] def tf_additional_license_deps(): - return select({ - str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"], - "//conditions:default": [], - }) + return select({ + str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"], + "//conditions:default": [], + }) def tf_additional_verbs_deps(): - return select({ - str(Label("//tensorflow:with_verbs_support")): [ - str(Label("//tensorflow/contrib/verbs:verbs_server_lib")), - str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")), - ], - "//conditions:default": [], - }) + return select({ + str(Label("//tensorflow:with_verbs_support")): [ + str(Label("//tensorflow/contrib/verbs:verbs_server_lib")), + str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")), + ], + "//conditions:default": [], + }) def tf_additional_mpi_deps(): - return select({ - str(Label("//tensorflow:with_mpi_support")): [ - str(Label("//tensorflow/contrib/mpi:mpi_server_lib")), - ], - "//conditions:default": [], - }) + return select({ + str(Label("//tensorflow:with_mpi_support")): [ + str(Label("//tensorflow/contrib/mpi:mpi_server_lib")), + ], + "//conditions:default": [], + }) def tf_additional_gdr_deps(): - return select({ - str(Label("//tensorflow:with_gdr_support")): [ - str(Label("//tensorflow/contrib/gdr:gdr_server_lib")), - ], - "//conditions:default": [], - }) + return select({ + 
str(Label("//tensorflow:with_gdr_support")): [ + str(Label("//tensorflow/contrib/gdr:gdr_server_lib")), + ], + "//conditions:default": [], + }) -def if_static(extra_deps, otherwise=[]): - return select({ - str(Label("//tensorflow:framework_shared_object")): otherwise, - "//conditions:default": extra_deps, - }) +def if_static(extra_deps, otherwise = []): + return select({ + str(Label("//tensorflow:framework_shared_object")): otherwise, + "//conditions:default": extra_deps, + }) -def if_dynamic_kernels(extra_deps, otherwise=[]): - return select({ - str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps, - "//conditions:default": otherwise, - }) +def if_dynamic_kernels(extra_deps, otherwise = []): + return select({ + str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps, + "//conditions:default": otherwise, + }) diff --git a/tensorflow/core/platform/default/cord.h b/tensorflow/core/platform/default/cord.h new file mode 100644 index 0000000000..5823374d1a --- /dev/null +++ b/tensorflow/core/platform/default/cord.h @@ -0,0 +1,21 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_ +#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_ + +// TODO(ebrevdo): Fill this in. 
+ +#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_ diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc index ccddf1eafc..83c65dbfa9 100644 --- a/tensorflow/core/platform/default/device_tracer.cc +++ b/tensorflow/core/platform/default/device_tracer.cc @@ -321,6 +321,16 @@ class DeviceTracerImpl : public DeviceTracer, return nullptr; } + bool IsEnabledForAnnotations() const override { + // We are always enabled for 'Annotations'. + return true; + } + + bool IsEnabledForActivities(bool is_expensive) const override { + // We don't do anything with 'Activities' so we are never 'enabled'. + return false; + } + protected: // This callback is used exclusively by CUPTIManager. friend class CUPTIManager; diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc index 305a9a682f..2e32abdffb 100644 --- a/tensorflow/core/platform/env_test.cc +++ b/tensorflow/core/platform/env_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/null_file_system.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" @@ -345,7 +346,13 @@ TEST_F(DefaultEnvTest, LocalTempFilename) { // Write something to the temporary file. std::unique_ptr<WritableFile> file_to_write; TF_CHECK_OK(env->NewWritableFile(filename, &file_to_write)); +#if defined(PLATFORM_GOOGLE) + TF_CHECK_OK(file_to_write->Append("Nu")); + TF_CHECK_OK(file_to_write->Append(absl::Cord("ll"))); +#else + // TODO(ebrevdo): Remove this version. 
TF_CHECK_OK(file_to_write->Append("Null")); +#endif TF_CHECK_OK(file_to_write->Close()); TF_CHECK_OK(env->FileExists(filename)); diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h index 077b1d79cf..156af6cdea 100644 --- a/tensorflow/core/platform/file_system.h +++ b/tensorflow/core/platform/file_system.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/file_statistics.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/platform.h" @@ -252,7 +253,15 @@ class WritableFile { virtual ~WritableFile(); /// \brief Append 'data' to the file. - virtual Status Append(const StringPiece& data) = 0; + virtual Status Append(StringPiece data) = 0; + + // TODO(ebrevdo): Remove this ifdef when absl is updated. +#if defined(PLATFORM_GOOGLE) + // \brief Append 'data' to the file. + virtual Status Append(const absl::Cord& cord) { + return errors::Unimplemented("Append(absl::Cord) is not implemented"); + } +#endif /// \brief Close the file. 
/// diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc index 8cdb08f51b..eb35531e9f 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc @@ -282,7 +282,7 @@ class HDFSWritableFile : public WritableFile { } } - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { if (hdfs_->hdfsWrite(fs_, file_, data.data(), static_cast<tSize>(data.size())) == -1) { return IOError(filename_, errno); diff --git a/tensorflow/core/platform/posix/posix_file_system.cc b/tensorflow/core/platform/posix/posix_file_system.cc index 47bfa020ce..c7afab9583 100644 --- a/tensorflow/core/platform/posix/posix_file_system.cc +++ b/tensorflow/core/platform/posix/posix_file_system.cc @@ -91,7 +91,7 @@ class PosixWritableFile : public WritableFile { } } - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { size_t r = fwrite(data.data(), 1, data.size(), file_); if (r != data.size()) { return IOError(filename_, errno); diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc index ce0f6cd741..e0b8e37745 100644 --- a/tensorflow/core/platform/s3/s3_file_system.cc +++ b/tensorflow/core/platform/s3/s3_file_system.cc @@ -211,7 +211,7 @@ class S3WritableFile : public WritableFile { std::ios_base::binary | std::ios_base::trunc | std::ios_base::in | std::ios_base::out)) {} - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { if (!outfile_) { return errors::FailedPrecondition( "The internal temporary file is not writable."); diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h index e5851f1dfe..aefbe64425 100644 --- a/tensorflow/core/platform/tracing.h +++ b/tensorflow/core/platform/tracing.h @@ -155,6 +155,13 @@ class TraceCollector { StringPiece name_part1, 
StringPiece name_part2, bool is_expensive) const = 0; + // Returns true if this annotation tracing is enabled for any op. + virtual bool IsEnabledForAnnotations() const = 0; + + // Returns true if this activity handle tracking is enabled for an op of the + // given expensiveness. + virtual bool IsEnabledForActivities(bool is_expensive) const = 0; + protected: static string ConcatenateNames(StringPiece first, StringPiece second); diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc index 9079a5ccaa..6cf79634d7 100644 --- a/tensorflow/core/platform/windows/windows_file_system.cc +++ b/tensorflow/core/platform/windows/windows_file_system.cc @@ -150,7 +150,7 @@ class WindowsWritableFile : public WritableFile { } } - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { DWORD bytes_written = 0; DWORD data_size = static_cast<DWORD>(data.size()); BOOL write_result = diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index da3a99565e..85cd02350a 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -68,7 +68,7 @@ message GPUOptions { // after the process starts. Users are required to use vendor // specific mechanisms (e.g., CUDA_VISIBLE_DEVICES) to control the // physical to visible device mapping prior to invoking TensorFlow. - // 2. In the code, the ids in this list are also called "CUDA GPU id"s, + // 2. In the code, the ids in this list are also called "platform GPU id"s, // and the 'virtual' ids of GPU devices (i.e. the ids in the device // name "/device:GPU:<id>") are also called "TF GPU id"s. Please // refer to third_party/tensorflow/core/common_runtime/gpu/gpu_id.h @@ -390,9 +390,12 @@ message ConfigProto { message Experimental { // Task name for group resolution. string collective_group_leader = 1; - // Whether the client will format templated errors. 
For example, the string: - // "The node was defined on ^^node:Foo:${file}:${line}^^". - bool client_handles_error_formatting = 2; + + // We removed the flag client_handles_error_formatting. Marking the tag + // number as reserved. + // TODO(shikharagarwal): Should we just remove this tag so that it can be + // used in future for other purpose? + reserved 2; // Which executor to use, the default executor will be used // if it is an empty string or "DEFAULT" diff --git a/tensorflow/core/protobuf/replay_log.proto b/tensorflow/core/protobuf/replay_log.proto new file mode 100644 index 0000000000..7644314fc9 --- /dev/null +++ b/tensorflow/core/protobuf/replay_log.proto @@ -0,0 +1,47 @@ +syntax = "proto3"; + +option cc_enable_arenas = true; +package tensorflow; + +import "tensorflow/core/framework/graph.proto"; +import "tensorflow/core/protobuf/cluster.proto"; +import "tensorflow/core/protobuf/master.proto"; + +// Records the creation of a new replay session. We record the device listing +// here to capture the state of the cluster. 
+message NewReplaySession { + ListDevicesResponse devices = 1; + string session_handle = 2; +} + +message ReplayOp { + double start_time_us = 31; + double end_time_us = 32; + + oneof op { + CreateSessionRequest create_session = 1; + ExtendSessionRequest extend_session = 2; + PartialRunSetupRequest partial_run_setup = 3; + RunStepRequest run_step = 4; + CloseSessionRequest close_session = 5; + ListDevicesRequest list_devices = 6; + ResetRequest reset_request = 7; + MakeCallableRequest make_callable = 8; + RunCallableRequest run_callable = 9; + ReleaseCallableRequest release_callable = 10; + NewReplaySession new_replay_session = 11; + } + + oneof response { + CreateSessionResponse create_session_response = 21; + ExtendSessionResponse extend_session_response = 22; + PartialRunSetupResponse partial_run_setup_response = 23; + RunStepResponse run_step_response = 24; + CloseSessionResponse close_session_response = 25; + ListDevicesResponse list_devices_response = 26; + ResetResponse reset_request_response = 27; + MakeCallableResponse make_callable_response = 28; + RunCallableResponse run_callable_response = 29; + ReleaseCallableResponse release_callable_response = 30; + } +} diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index 07f984ceea..bb8f88336d 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -75,6 +75,8 @@ message RewriterConfig { // Try to allocate some independent Op outputs contiguously in order to // merge or eliminate downstream Ops (off by default). Toggle scoped_allocator_optimization = 15; + // Force small ops onto the CPU (default is ON). + Toggle pin_to_host_optimization = 18; // Controls how many times we run the optimizers in meta optimizer (default // is once). 
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 4129c93af5..b043a69431 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -19,12 +19,12 @@ limitations under the License. // TensorFlow uses semantic versioning, see http://semver.org/. #define TF_MAJOR_VERSION 1 -#define TF_MINOR_VERSION 10 +#define TF_MINOR_VERSION 11 #define TF_PATCH_VERSION 0 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1", // "-beta", "-rc", "-rc.1") -#define TF_VERSION_SUFFIX "" +#define TF_VERSION_SUFFIX "-rc1" #define TF_STR_HELPER(x) #x #define TF_STR(x) TF_STR_HELPER(x) diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h index 973e315f09..24002e72a0 100644 --- a/tensorflow/core/util/ctc/ctc_beam_entry.h +++ b/tensorflow/core/util/ctc/ctc_beam_entry.h @@ -1,4 +1,3 @@ -// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ diff --git a/tensorflow/core/util/ctc/ctc_beam_scorer.h b/tensorflow/core/util/ctc/ctc_beam_scorer.h index 1a622babe1..1e45a8abd3 100644 --- a/tensorflow/core/util/ctc/ctc_beam_scorer.h +++ b/tensorflow/core/util/ctc/ctc_beam_scorer.h @@ -1,4 +1,3 @@ -// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange // Collection of scoring classes that can be extended and provided to the // CTCBeamSearchDecoder to incorporate additional scoring logic (such as a diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h index 5e2aeb7830..6fbb1ed0da 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search.h +++ b/tensorflow/core/util/ctc/ctc_beam_search.h @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h index 3be36822e5..b55d7d77ac 100644 --- a/tensorflow/core/util/ctc/ctc_decoder.h +++ b/tensorflow/core/util/ctc/ctc_decoder.h @@ -1,4 +1,3 @@ -// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +// LINT.IfChange #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ diff --git a/tensorflow/core/util/ctc/ctc_loss_util.h b/tensorflow/core/util/ctc/ctc_loss_util.h index 36be9e92ef..054412d388 100644 --- a/tensorflow/core/util/ctc/ctc_loss_util.h +++ b/tensorflow/core/util/ctc/ctc_loss_util.h @@ -1,4 +1,3 @@ -// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_ diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h index 540adb58d4..f6f0408ccc 100644 --- a/tensorflow/core/util/cuda_kernel_helper.h +++ b/tensorflow/core/util/cuda_kernel_helper.h @@ -93,11 +93,11 @@ __device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXorSync( } namespace cuda_helper { -template <typename IntType> -__device__ IntType upper_bound(IntType* first, IntType count, IntType val) { - IntType* orig = first; - IntType* it = nullptr; - IntType step = 0; +template <typename T, typename OutType = int32> +__device__ OutType upper_bound(const T* first, OutType count, T val) { + const T* orig = first; + const T* it = nullptr; + OutType step = 0; while (count > 0) { it = first; step = count / 2; @@ -112,6 +112,27 @@ __device__ IntType upper_bound(IntType* first, IntType count, IntType val) { return first - orig; } + +template <typename T, typename OutType = int32> +__device__ OutType lower_bound(const T* first, OutType count, T val) { + const T* orig = first; 
+ const T* it = nullptr; + OutType step = 0; + while (count > 0) { + it = first; + step = count / 2; + it += step; + if (*it < val) { + first = ++it; + count -= step + 1; + } else { + count = step; + } + } + + return first - orig; +} + } // namespace cuda_helper } // namespace tensorflow diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 680211edff..cf7ffd8149 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -34,9 +34,8 @@ limitations under the License. #endif #ifdef INTEL_MKL_ML_ONLY -// Using pragma message since #warning doesn't work with all compilers -#pragma message("Compiling for INTEL MKL ML only will be deprecated soon.") -#pragma message("Please use MKL DNN (the default option for --config=mkl)") +#error \ + "Compiling for INTEL MKL ML only is no longer supported.Please use MKL DNN (the default option for --config=mkl)" #endif #ifdef INTEL_MKL_ML_ONLY diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc index c081ceae57..e01058dff6 100644 --- a/tensorflow/core/util/port.cc +++ b/tensorflow/core/util/port.cc @@ -38,10 +38,10 @@ bool CudaSupportsHalfMatMulAndConv() { } bool IsMklEnabled() { -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) return true; #else return false; -#endif +#endif // INTEL_MKL && ENABLE_MKL } } // end namespace tensorflow diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc index 204b933051..546b0a833c 100644 --- a/tensorflow/core/util/sparse/group_iterator.cc +++ b/tensorflow/core/util/sparse/group_iterator.cc @@ -21,8 +21,8 @@ namespace sparse { void GroupIterable::IteratorStep::UpdateEndOfGroup() { ++next_loc_; - int64 N = iter_->ix_.dim_size(0); - auto ix_t = iter_->ix_.template matrix<int64>(); + const auto& ix_t = iter_->ix_matrix_; + const int64 N = ix_t.dimension(0); while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) { ++next_loc_; } @@ -54,7 +54,7 @@ 
GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++( std::vector<int64> Group::group() const { std::vector<int64> g; - auto ix_t = iter_->ix_.template matrix<int64>(); + const auto& ix_t = iter_->ix_matrix_; for (const int d : iter_->group_dims_) { g.push_back(ix_t(loc_, d)); } @@ -62,8 +62,8 @@ std::vector<int64> Group::group() const { } TTypes<int64>::UnalignedConstMatrix Group::indices() const { - return TTypes<int64>::UnalignedConstMatrix( - &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_); + return TTypes<int64>::UnalignedConstMatrix(&(iter_->ix_matrix_(loc_, 0)), + next_loc_ - loc_, iter_->dims_); } } // namespace sparse diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h index 3fa8cb6116..14610c61d9 100644 --- a/tensorflow/core/util/sparse/group_iterator.h +++ b/tensorflow/core/util/sparse/group_iterator.h @@ -79,6 +79,7 @@ class GroupIterable { GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims) : ix_(ix), + ix_matrix_(ix_.matrix<int64>()), vals_(vals), dims_(dims), group_dims_(group_dims.begin(), group_dims.end()) {} @@ -127,7 +128,8 @@ class GroupIterable { private: friend class Group; - Tensor ix_; + const Tensor ix_; + const TTypes<int64>::ConstMatrix ix_matrix_; Tensor vals_; const int dims_; const gtl::InlinedVector<int64, 8> group_dims_; diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h index 0f04b65f60..b9ca8ab395 100644 --- a/tensorflow/core/util/sparse/sparse_tensor.h +++ b/tensorflow/core/util/sparse/sparse_tensor.h @@ -20,6 +20,7 @@ limitations under the License. 
#include <numeric> #include <vector> +#include "absl/base/macros.h" #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" @@ -95,21 +96,21 @@ class SparseTensor { SparseTensor() : dims_(0) {} - // DEPRECATED: use Create() functions instead of constructors directly. + ABSL_DEPRECATED("Use Create() functions instead of constructors directly.") SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape) : SparseTensor(ix, vals, TensorShapeToVector(shape), UndefinedOrder(TensorShapeToVector(shape))) {} - // DEPRECATED: use Create() functions instead of constructors directly. + ABSL_DEPRECATED("Use Create() functions instead of constructors directly.") SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape) : SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {} - // DEPRECATED: use Create() functions instead of constructors directly. + ABSL_DEPRECATED("use Create() functions instead of constructors directly.") SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape, const VarDimArray order) : SparseTensor(ix, vals, TensorShapeToVector(shape), order) {} - // DEPRECATED: use Create() functions instead of constructors directly. + ABSL_DEPRECATED("Use Create() functions instead of constructors directly.") SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape, const VarDimArray order) : ix_(ix), @@ -237,9 +238,10 @@ class SparseTensor { static Status Split(const SparseTensor& tensor, const int split_dim, const int num_split, std::vector<SparseTensor>* result); - // DEPRECATED: use the form of Split() that takes an output pointer and - // returns a status instead. 
template <typename T> + ABSL_DEPRECATED( + "Use the form of Split() that takes an output pointer and returns a " + "status instead.") static std::vector<SparseTensor> Split(const SparseTensor& tensor, const int split_dim, const int num_split, diff --git a/tensorflow/core/util/tensor_bundle/naming.h b/tensorflow/core/util/tensor_bundle/naming.h index 6539d565e2..7b101971a8 100644 --- a/tensorflow/core/util/tensor_bundle/naming.h +++ b/tensorflow/core/util/tensor_bundle/naming.h @@ -35,6 +35,7 @@ limitations under the License. #define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_NAMING_H_ #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc index f4bd2950e9..74f0713a61 100644 --- a/tensorflow/core/util/work_sharder.cc +++ b/tensorflow/core/util/work_sharder.cc @@ -50,6 +50,8 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total, max_parallelism); } +// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you +// to directly specify the shard size. void Sharder::Do(int64 total, int64 cost_per_unit, const Work& work, const Runner& runner, int max_parallelism) { cost_per_unit = std::max(int64{1}, cost_per_unit); diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h index b12c31c1ae..9db85a54c6 100644 --- a/tensorflow/core/util/work_sharder.h +++ b/tensorflow/core/util/work_sharder.h @@ -23,6 +23,9 @@ limitations under the License. namespace tensorflow { +// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you +// to directly specify the shard size. Use this function only if you want to +// manually cap parallelism. // Shards the "total" unit of work assuming each unit of work having // roughly "cost_per_unit". Each unit of work is indexed 0, 1, ..., // total - 1. Each shard contains 1 or more units of work and the |