Diffstat (limited to 'tensorflow/core')
373 files changed, 14768 insertions, 6982 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 84b11024fd..9bcf5b0865 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -85,11 +85,12 @@ load( "tf_cc_tests", "tf_copts", "tf_cuda_library", + "tf_features_nomodules_if_android", "tf_gen_op_libs", "tf_generate_proto_text_sources", "tf_genrule_cmd_append_to_srcs", "tf_opts_nortti_if_android", - "tf_features_nomodules_if_android", + "transitive_hdrs", ) load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl") load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu") @@ -120,16 +121,16 @@ load( "tf_additional_libdevice_srcs", "tf_additional_minimal_lib_srcs", "tf_additional_mpi_lib_defines", - "tf_additional_proto_hdrs", "tf_additional_proto_compiler_hdrs", + "tf_additional_proto_hdrs", "tf_additional_proto_srcs", "tf_additional_test_deps", "tf_additional_test_srcs", "tf_additional_verbs_lib_defines", "tf_jspb_proto_library", "tf_kernel_tests_linkstatic", - "tf_lib_proto_parsing_deps", "tf_lib_proto_compiler_deps", + "tf_lib_proto_parsing_deps", "tf_nano_proto_library", "tf_platform_hdrs", "tf_platform_srcs", @@ -168,6 +169,7 @@ COMMON_PROTO_SRCS = [ "example/example.proto", "example/feature.proto", "framework/allocation_description.proto", + "framework/api_def.proto", "framework/attr_value.proto", "framework/cost_graph.proto", "framework/device_attributes.proto", @@ -179,7 +181,6 @@ COMMON_PROTO_SRCS = [ "framework/log_memory.proto", "framework/node_def.proto", "framework/op_def.proto", - "framework/api_def.proto", "framework/reader_base.proto", "framework/remote_fused_graph_execute_info.proto", "framework/resource_handle.proto", @@ -299,6 +300,7 @@ filegroup( name = "platform_base_hdrs", srcs = [ "platform/byte_order.h", + "platform/cord.h", "platform/env_time.h", "platform/logging.h", "platform/macros.h", @@ -695,7 +697,24 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":lib_internal", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + ], +) + +cc_library( + name = "feature_util", + srcs = ["example/feature_util.cc"], + hdrs = [ + "example/feature_util.h", + "platform/types.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":core_stringpiece", + ":platform_protobuf", + ":protos_all_cc", ], ) @@ -703,6 +722,7 @@ cc_library( name = "abi", srcs = ["platform/abi.cc"], hdrs = ["platform/abi.h"], + deps = [":platform_base"], ) cc_library( @@ -857,7 +877,6 @@ tf_cuda_library( "util/bcast.h", "util/cuda_kernel_helper.h", "util/device_name_utils.h", - "util/env_var.h", "util/events_writer.h", "util/example_proto_fast_parsing.h", "util/example_proto_helper.h", @@ -872,7 +891,6 @@ tf_cuda_library( "util/sparse/sparse_tensor.h", "util/stat_summarizer.h", "util/stat_summarizer_options.h", - "util/status_util.h", "util/stream_executor_util.h", "util/strided_slice_op.h", "util/tensor_format.h", @@ -939,15 +957,6 @@ cc_library( ) cc_library( - name = "status_util", - hdrs = ["util/status_util.h"], - deps = [ - ":graph", - ":lib", - ], -) - -cc_library( name = "reader_base", srcs = ["framework/reader_base.cc"], hdrs = ["framework/reader_base.h"], @@ -1347,6 +1356,7 @@ cc_library( "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", "//tensorflow/core/kernels:mkl_softmax_op", + "//tensorflow/core/kernels:mkl_transpose_op", "//tensorflow/core/kernels:mkl_tfconv_op", "//tensorflow/core/kernels:mkl_aggregate_ops", ]) + if_cuda([ @@ -1418,9 +1428,11 @@ cc_library( ":test", ":testlib_ops", 
"//tensorflow/cc:scope", + "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:ops_testutil", "//tensorflow/core/kernels:ops_util", + "//tensorflow/core/kernels:random_ops", ], ) @@ -1910,6 +1922,13 @@ tf_pyclif_proto_library( ) tf_pyclif_proto_library( + name = "protobuf/config_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "protobuf/config.proto", + visibility = ["//visibility:public"], +) + +tf_pyclif_proto_library( name = "protobuf/device_properties_pyclif", proto_lib = ":protos_all_cc", proto_srcfile = "protobuf/device_properties.proto", @@ -2048,6 +2067,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [ "platform/snappy.h", "platform/tensor_coding.h", "platform/tracing.h", + "util/env_var.h", ] # Replicated for lib_internal and lib_internal_impl. @@ -2087,6 +2107,7 @@ cc_library( "platform/*.cc", "platform/profile_utils/**/*.cc", "framework/resource_handle.cc", + "util/env_var.cc", ], exclude = [ "**/*test*", @@ -2442,7 +2463,6 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [ "framework/unique_tensor_references.h", "framework/variant.h", "util/command_line_flags.h", - "util/env_var.h", "util/equal_graph_def.h", "util/presized_cuckoo_map.h", "util/tensor_slice_set.h", @@ -2518,6 +2538,7 @@ tf_cuda_library( "util/memmapped_file_system_writer.*", "util/stats_calculator.*", "util/version_info.cc", + "util/env_var.cc", ], ) + select({ "//tensorflow:windows": [], @@ -2762,7 +2783,6 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/step_stats_collector.h", "common_runtime/threadpool_device.h", "common_runtime/tracing_device.h", - "common_runtime/visitable_allocator.h", "common_runtime/process_state.h", "common_runtime/pool_allocator.h", "graph/gradients.h", @@ -3229,12 +3249,10 @@ tf_cc_tests( "lib/gtl/edit_distance_test.cc", "lib/gtl/flatmap_test.cc", "lib/gtl/flatset_test.cc", - "lib/gtl/inlined_vector_test.cc", "lib/gtl/int_type_test.cc", "lib/gtl/iterator_range_test.cc", "lib/gtl/manual_constructor_test.cc", "lib/gtl/map_util_test.cc", - "lib/gtl/optional_test.cc", "lib/gtl/top_n_test.cc", "lib/hash/crc32c_test.cc", "lib/hash/hash_test.cc", @@ -3560,7 +3578,6 @@ tf_cc_tests( "util/semver_test.cc", "util/sparse/sparse_tensor_test.cc", "util/stat_summarizer_test.cc", - "util/status_util_test.cc", "util/tensor_format_test.cc", "util/tensor_slice_reader_test.cc", "util/tensor_slice_set_test.cc", @@ -3585,7 +3602,6 @@ tf_cc_tests( ":ops", ":protos_all_cc", ":protos_test_cc", - ":status_util", ":test", ":test_main", ":testlib", @@ -3724,6 +3740,7 @@ tf_cc_test_mkl( ":core_cpu_internal", ":framework", ":framework_internal", + ":lib", ":test", ":test_main", ":testlib", @@ -4078,6 +4095,7 @@ tf_cuda_cc_test( ":testlib", "//third_party/eigen3", "//tensorflow/cc:cc_ops", + "//tensorflow/core/kernels:collective_ops", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:dense_update_ops", @@ -4119,6 +4137,7 @@ tf_cc_test( "//tensorflow/cc:cc_ops", # Link with support for TensorFlow Debugger (tfdbg). 
"//tensorflow/core/debug", + "//tensorflow/core/kernels:collective_ops", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:dense_update_ops", @@ -4701,6 +4720,18 @@ cc_library( ] + tf_additional_libdevice_deps(), ) +transitive_hdrs( + name = "headers", + visibility = ["//tensorflow:__subpackages__"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:stream_executor", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets go here (must be at the end). diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt new file mode 100644 index 0000000000..cdaeb5091c --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt @@ -0,0 +1,34 @@ +op { + graph_op_name: "BoostedTreesBucketize" + visibility: HIDDEN + in_arg { + name: "float_values" + description: <<END +float; List of Rank 2 Tensor each containing float values for a single feature. +END + } + in_arg { + name: "bucket_boundaries" + description: <<END +float; List of Rank 1 Tensors each containing the bucket boundaries for a single +feature. +END + } + out_arg { + name: "buckets" + description: <<END +int; List of Rank 2 Tensors each containing the bucketized values for a single feature. +END + } + attr { + name: "num_features" + description: <<END +inferred int; number of features. +END + } + summary: "Bucketize each feature based on bucket boundaries." + description: <<END +An op that returns a list of float tensors, where each tensor represents the +bucketized values for a single feature. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt new file mode 100644 index 0000000000..20da1295f6 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt @@ -0,0 +1,29 @@ +op { + graph_op_name: "BoostedTreesCreateQuantileStreamResource" + visibility: HIDDEN + in_arg { + name: "quantile_stream_resource_handle" + description: <<END +resource; Handle to quantile stream resource. +END + } + in_arg { + name: "epsilon" + description: <<END +float; The required approximation error of the stream resource. +END + } + in_arg { + name: "num_streams" + description: <<END +int; The number of streams managed by the resource that shares the same epsilon. +END + } + attr { + name: "max_elements" + description : <<END +int; The maximum number of data points that can be fed to the stream. +END + } + summary: "Create the Resource for Quantile Streams." +} diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt new file mode 100644 index 0000000000..ca111af312 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt @@ -0,0 +1,40 @@ +op { + graph_op_name: "BoostedTreesMakeQuantileSummaries" + visibility: HIDDEN + in_arg { + name: "float_values" + description: <<END +float; List of Rank 2 Tensors each containing values for a single feature. +END + } + in_arg { + name: "example_weights" + description: <<END +float; Rank 1 Tensor with weights per instance. 
+END + } + in_arg { + name: "epsilon" + description: <<END +float; The required maximum approximation error. +END + } + out_arg { + name: "summaries" + description: <<END +float; List of Rank 2 Tensors each containing the quantile summary (value, weight, +min_rank, max_rank) of a single feature. +END + } + attr { + name: "num_features" + description: <<END +int; Inferred from the size of float_values. +The number of float features. +END + } + summary: "Makes the summary of quantiles for the batch." + description: <<END +An op that takes a list of tensors and outputs the quantile summaries for each tensor. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt new file mode 100644 index 0000000000..bbeecbf32b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt @@ -0,0 +1,22 @@ +op { + graph_op_name: "BoostedTreesQuantileStreamResourceAddSummaries" + visibility: HIDDEN + in_arg { + name: "quantile_stream_resource_handle" + description: <<END +resource handle referring to a QuantileStreamResource. +END + } + in_arg { + name: "summaries" + description: <<END +string; List of Rank 2 Tensor each containing the summaries for a single feature. +END + } + summary: "Add the quantile summaries to each quantile stream resource." + description: <<END +An op that adds a list of quantile summaries to a quantile stream resource. Each +summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank) +for a single feature. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt new file mode 100644 index 0000000000..2fd94efa10 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt @@ -0,0 +1,31 @@ +op { + graph_op_name: "BoostedTreesQuantileStreamResourceFlush" + visibility: HIDDEN + in_arg { + name: "quantile_stream_resource_handle" + description: <<END +resource handle referring to a QuantileStreamResource. +END + } + in_arg { + name: "num_buckets", + description: <<END +int; approximate number of buckets unless using generate_quantiles. +END + } + attr { + name: "generate_quantiles" + description: <<END +bool; If True, the output will be the num_quantiles for each stream where the ith +entry is the ith quantile of the input with an approximation error of epsilon. +Duplicate values may be present. +If False, the output will be the points in the histogram that we got which roughly +translates to 1/epsilon boundaries and without any duplicates. +Default to False. +END + } + summary: "Flush the summaries for a quantile stream resource." + description: <<END +An op that flushes the summaries for a quantile stream resource. 
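For example (an illustrative end-to-end sketch, not part of this patch: the wrapper module path and argument values are assumptions, with wrapper names derived mechanically from the op names defined above; `float_values` and `example_weights` are assumed inputs):

    # A minimal sketch of the intended flow, assuming the generated Python
    # wrappers in tensorflow.python.ops.gen_boosted_trees_ops.
    from tensorflow.python.ops import gen_boosted_trees_ops as bt_ops

    # float_values: list of rank-2 float tensors, one per feature.
    # example_weights: rank-1 float tensor of per-instance weights.
    handle = bt_ops.boosted_trees_quantile_stream_resource_handle_op()
    bt_ops.boosted_trees_create_quantile_stream_resource(
        handle, epsilon=0.01, num_streams=len(float_values),
        max_elements=1 << 16)
    # Accumulate per-feature summaries for a batch, then flush them so
    # bucket boundaries can be generated from the accumulated summaries.
    summaries = bt_ops.boosted_trees_make_quantile_summaries(
        float_values, example_weights, epsilon=0.01)
    bt_ops.boosted_trees_quantile_stream_resource_add_summaries(
        handle, summaries)
    bt_ops.boosted_trees_quantile_stream_resource_flush(
        handle, num_buckets=64)
    boundaries = bt_ops.boosted_trees_quantile_stream_resource_get_bucket_boundaries(
        handle, num_features=len(float_values))
    buckets = bt_ops.boosted_trees_bucketize(float_values, boundaries)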
+END +} diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt new file mode 100644 index 0000000000..206672802f --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt @@ -0,0 +1,27 @@ +op { + graph_op_name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries" + visibility: HIDDEN + in_arg { + name: "quantile_stream_resource_handle" + description: <<END +resource handle referring to a QuantileStreamResource. +END + } + out_arg { + name: "bucket_boundaries" + description: <<END +float; List of Rank 1 Tensors each containing the bucket boundaries for a feature. +END + } + attr { + name: "num_features" + description: <<END +inferred int; number of features to get bucket boundaries for. +END + } + summary: "Generate the bucket boundaries for each feature based on accumulated summaries." + description: <<END +An op that returns a list of float tensors for a quantile stream resource. Each +tensor is Rank 1 containing bucket boundaries for a single feature. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt new file mode 100644 index 0000000000..cb7786c051 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt @@ -0,0 +1,5 @@ +op { + graph_op_name: "BoostedTreesQuantileStreamResourceHandleOp" + visibility: HIDDEN + summary: "Creates a handle to a BoostedTreesQuantileStreamResource." +} diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt index e39213cbc7..440800704e 100644 --- a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt @@ -11,7 +11,8 @@ END name: "record_defaults" description: <<END One tensor per column of the input record, with either a -scalar default value for that column or empty if the column is required. +scalar default value for that column or an empty vector if the column is +required. END } out_arg { diff --git a/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt new file mode 100644 index 0000000000..758eeb96f0 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt @@ -0,0 +1,20 @@ +op { + graph_op_name: "IsBoostedTreesQuantileStreamResourceInitialized" + visibility: HIDDEN + in_arg { + name: "quantile_stream_resource_handle" + description: <<END +resource; The reference to quantile stream resource handle. +END + } + out_arg { + name: "is_initialized" + description: <<END +bool; True if the resource is initialized, False otherwise. +END + } + summary: "Checks whether a quantile stream has been initialized." + description: <<END +An Op that checks if quantile stream resource is initialized. 
+END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt new file mode 100644 index 0000000000..171add16d4 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt @@ -0,0 +1,14 @@ +op { + graph_op_name: "ModelDataset" + visibility: HIDDEN + in_arg { + name: "input_dataset" + description: <<END +A variant tensor representing the input dataset. +END + } + summary: "Identity transformation that models performance." + description: <<END +Identity transformation that models performance. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt new file mode 100644 index 0000000000..27bc4013c3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt @@ -0,0 +1,13 @@ +op { + graph_op_name: "ParallelInterleaveDatasetV2" + visibility: HIDDEN + attr { + name: "f" + description: <<END +A function mapping elements of `input_dataset`, concatenated with +`other_arguments`, to a Dataset variant that contains elements matching +`output_types` and `output_shapes`. +END + } + summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`." +} diff --git a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt index 8cef243aee..30fd97a0d7 100644 --- a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt @@ -9,7 +9,7 @@ END in_arg { name: "pattern" description: <<END -A 1-D string tensor of the regular expression to match the input. +A scalar string tensor containing the regular expression to match the input. END } out_arg { diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt index 35f55fe106..d33a36ce06 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. END } diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt index 70a07d9b4c..afdc39da96 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. END } diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt index b2e3eece38..026b5b3991 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. 
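For example (an illustrative sketch; `tf.segment_min` is the 1.x Python wrapper for this op, and the tensor values are arbitrary), for `data` with a first dimension of size 3, `segment_ids` must be a vector of size 3:

    data = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    tf.segment_min(data, segment_ids=tf.constant([0, 0, 1]))
    # => [[1.0, 2.0], [5.0, 6.0]]  (element-wise min over the rows of each segment)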
END } diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt index 7bac02e23d..a168eed87f 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. END } diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt index a73306a892..876b860824 100644 --- a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt @@ -3,7 +3,7 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s +A 1-D tensor whose size is equal to the size of `data`'s first dimension. Values should be sorted and can be repeated. END } diff --git a/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt new file mode 100644 index 0000000000..6d9d9908ca --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt @@ -0,0 +1,29 @@ +op { + graph_op_name: "StaticRegexFullMatch" + in_arg { + name: "input" + description: <<END +A string tensor of the text to be processed. +END + } + out_arg { + name: "output" + description: <<END +A bool tensor with the same shape as `input`. +END + } + attr { + name: "pattern" + description: "The regular expression to match the input." + } + summary: "Check if the input matches the regex pattern." + description: <<END +The input is a string tensor of any shape. The pattern is the +regular expression to be matched with every element of the input tensor. +The boolean values (True or False) of the output tensor indicate +if the input matches the regex pattern provided. + +The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax) +END + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt b/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt index 8d6fc04847..9a89a4e8e7 100644 --- a/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt @@ -32,7 +32,7 @@ END description: <<END a bitmask where a bit i being 1 means to ignore the begin value and instead use the largest interval possible. At runtime -begin[i] will be replaced with `[0, n-1) if `stride[i] > 0` or +begin[i] will be replaced with `[0, n-1)` if `stride[i] > 0` or `[-1, n-1]` if `stride[i] < 0` END } diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt index 8fc1e5cba3..5246090ab3 100644 --- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt @@ -32,8 +32,10 @@ For each string in the input `Tensor`, creates a substring starting at index If `len` defines a substring that would extend beyond the length of the input string, then as many characters as possible are used. -If `pos` is negative or specifies a character index larger than any of the input -strings, then an `InvalidArgumentError` is thrown. +A negative `pos` indicates distance within the string backwards from the end. 
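For example (an illustrative sketch via the 1.x Python wrapper `tf.substr`; the input strings are arbitrary):

    tf.substr(tf.constant(["Hello", "TensorFlow"]), pos=-5, len=5)
    # => ["Hello", "rFlow"]  (pos=-5 starts 5 characters from the end of each string)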
+ +If `pos` specifies an index which is out of range for any of the input strings, +then an `InvalidArgumentError` is thrown. `pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on Op creation. diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt index 907c6d2022..7a60e4387a 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt @@ -3,15 +3,14 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s -first dimension. -END +A tensor whose shape is a prefix of `data.shape`.END } out_arg { name: "output" description: <<END -Has same shape as data, except for dimension 0 which -has size `num_segments`. +Has same shape as data, except for the first `segment_ids.rank` +dimensions, which are replaced with a single dimension which has size +`num_segments`. END } summary: "Computes the maximum along segments of a tensor." @@ -24,13 +23,16 @@ This operator is similar to the unsorted segment sum operator found [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). Instead of computing the sum over segments, it computes the maximum such that: -\\(output_i = \max_j data_j\\) where max is over `j` such -that `segment_ids[j] == i`. +\\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such +that `segment_ids[j...] == i`. If the maximum is empty for a given segment ID `i`, it outputs the smallest possible value for the specific numeric type, `output[i] = numeric_limits<T>::lowest()`. +If the given segment ID `i` is negative, then the corresponding value is +dropped, and will not be included in the result. + <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;"> <img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt> </div> diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt index 37dd973b23..7e139ddf4d 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt @@ -3,15 +3,15 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s -first dimension. +A tensor whose shape is a prefix of `data.shape`. END } out_arg { name: "output" description: <<END -Has same shape as data, except for dimension 0 which -has size `num_segments`. +Has same shape as data, except for the first `segment_ids.rank` +dimensions, which are replaced with a single dimension which has size +`num_segments`. END } summary: "Computes the minimum along segments of a tensor." @@ -24,11 +24,14 @@ This operator is similar to the unsorted segment sum operator found [(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum). Instead of computing the sum over segments, it computes the minimum such that: -\\(output_i = \min_j data_j\\) where min is over `j` such -that `segment_ids[j] == i`. +\\(output_i = \min_{j...} data_[j...]\\) where min is over tuples `j...` such +that `segment_ids[j...] == i`. If the minimum is empty for a given segment ID `i`, it outputs the largest possible value for the specific numeric type, `output[i] = numeric_limits<T>::max()`. 
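For example (an illustrative sketch via the 1.x Python wrapper; the values are arbitrary):

    c = tf.constant([[1, 2, 3, 4], [5, 6, 7, 8], [4, 3, 2, 1]])
    tf.unsorted_segment_min(c, segment_ids=tf.constant([0, 1, 0]),
                            num_segments=2)
    # => [[1, 2, 2, 1], [5, 6, 7, 8]]  (element-wise min over rows with the same id)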
+ +If the given segment ID `i` is negative, then the corresponding value is +dropped, and will not be included in the result. END } diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt index efbc023705..9c8ea3b620 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt @@ -3,15 +3,15 @@ op { in_arg { name: "segment_ids" description: <<END -A 1-D tensor whose rank is equal to the rank of `data`'s -first dimension. +A tensor whose shape is a prefix of `data.shape`. END } out_arg { name: "output" description: <<END -Has same shape as data, except for dimension 0 which -has size `num_segments`. +Has same shape as data, except for the first `segment_ids.rank` +dimensions, which are replaced with a single dimension which has size +`num_segments`. END } summary: "Computes the product along segments of a tensor." @@ -25,9 +25,12 @@ This operator is similar to the unsorted segment sum operator found Instead of computing the sum over segments, it computes the product of all entries belonging to a segment such that: -\\(output_i = \prod_j data_j\\) where the product is over `j` such -that `segment_ids[j] == i`. +\\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples +`j...` such that `segment_ids[j...] == i`. If there is no entry for a given segment ID `i`, it outputs 1. + +If the given segment ID `i` is negative, then the corresponding value is +dropped, and will not be included in the result. END } diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt index a8874950eb..7e5d9265c2 100644 --- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt @@ -21,7 +21,7 @@ Read for an explanation of segments. Computes a tensor such that -\\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such +\\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids` need not be sorted and need not cover all values in the full range of valid values. diff --git a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt index 1bc3660479..01387b7527 100644 --- a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt @@ -2,10 +2,31 @@ op { visibility: HIDDEN graph_op_name: "WindowDataset" in_arg { - name: "window_size" + name: "size" description: <<END A scalar representing the number of elements to accumulate in a window. END } + in_arg { + name: "shift" + description: <<END +A scalar representing the steps moving the sliding window forward in one +iteration. It must be positive. +END + } + in_arg { + name: "stride" + description: <<END +A scalar representing the stride of the input elements of the sliding window. +It must be positive. +END + } + in_arg { + name: "drop_remainder" + description: <<END +A scalar representing whether a window should be dropped in case its size is +smaller than desired. +END + } summary: "A dataset that creates window datasets from the input dataset." 
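For example (an illustrative sketch, assuming the `tf.data.Dataset.window` Python wrapper backed by this op; the input range is arbitrary):

    ds = tf.data.Dataset.range(7).window(size=3, shift=2, stride=1,
                                         drop_remainder=True)
    # Yields three nested window datasets: [0, 1, 2], [2, 3, 4], [4, 5, 6].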
} diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc index 3bf0532491..3843ea9e60 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.cc +++ b/tensorflow/core/common_runtime/bfc_allocator.cc @@ -31,7 +31,7 @@ namespace tensorflow { BFCAllocator::BFCAllocator(SubAllocator* sub_allocator, size_t total_memory, bool allow_growth, const string& name) - : suballocator_(sub_allocator), + : sub_allocator_(sub_allocator), name_(name), free_chunks_list_(kInvalidChunkHandle), next_allocation_id_(1) { @@ -72,7 +72,7 @@ BFCAllocator::~BFCAllocator() { VLOG(2) << "Number of regions allocated: " << region_manager_.regions().size(); for (const auto& region : region_manager_.regions()) { - suballocator_->Free(region.ptr(), region.memory_size()); + sub_allocator_->Free(region.ptr(), region.memory_size()); } for (BinNum b = 0; b < kNumBins; b++) { @@ -108,7 +108,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) { // Try allocating. size_t bytes = std::min(curr_region_allocation_bytes_, available_bytes); - void* mem_addr = suballocator_->Alloc(alignment, bytes); + void* mem_addr = sub_allocator_->Alloc(alignment, bytes); if (mem_addr == nullptr && !started_backpedal_) { // Only backpedal once. started_backpedal_ = true; @@ -119,7 +119,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) { while (mem_addr == nullptr) { bytes = RoundedBytes(bytes * kBackpedalFactor); if (bytes < rounded_bytes) break; - mem_addr = suballocator_->Alloc(alignment, bytes); + mem_addr = sub_allocator_->Alloc(alignment, bytes); } } @@ -158,10 +158,6 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) { // Insert the chunk into the right bin. InsertFreeChunkIntoBin(h); - // Invoke visitors on newly allocated region. - for (const auto& visitor : region_visitors_) { - visitor(mem_addr, bytes); - } return true; } @@ -490,15 +486,6 @@ void BFCAllocator::FreeAndMaybeCoalesce(BFCAllocator::ChunkHandle h) { InsertFreeChunkIntoBin(coalesced_chunk); } -void BFCAllocator::AddAllocVisitor(Visitor visitor) { - VLOG(1) << "AddVisitor"; - mutex_lock l(lock_); - region_visitors_.push_back(visitor); - for (const auto& region : region_manager_.regions()) { - visitor(region.ptr(), region.memory_size()); - } -} - bool BFCAllocator::TracksAllocationSizes() { return true; } size_t BFCAllocator::RequestedSize(const void* ptr) { @@ -596,7 +583,7 @@ string BFCAllocator::RenderOccupancy() { region_offset += region.memory_size(); } - return std::string(rendered, resolution); + return string(rendered, resolution); } void BFCAllocator::DumpMemoryLog(size_t num_bytes) { diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h index 20e1dab1d5..364071e066 100644 --- a/tensorflow/core/common_runtime/bfc_allocator.h +++ b/tensorflow/core/common_runtime/bfc_allocator.h @@ -23,7 +23,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/common_runtime/allocator_retry.h" -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/gtl/stl_util.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/macros.h" @@ -42,7 +42,7 @@ namespace tensorflow { // coalescing. One assumption we make is that the process using this // allocator owns pretty much all of the memory, and that nearly // all requests to allocate memory go through this interface. 
-class BFCAllocator : public VisitableAllocator { +class BFCAllocator : public Allocator { public: // Takes ownership of sub_allocator. BFCAllocator(SubAllocator* sub_allocator, size_t total_memory, @@ -55,11 +55,6 @@ class BFCAllocator : public VisitableAllocator { const AllocationAttributes& allocation_attr) override; void DeallocateRaw(void* ptr) override; - void AddAllocVisitor(Visitor visitor) override; - - // Does nothing, because memory is never freed. - void AddFreeVisitor(Visitor visitor) override {} - bool TracksAllocationSizes() override; size_t RequestedSize(const void* ptr) override; @@ -423,7 +418,7 @@ class BFCAllocator : public VisitableAllocator { // of the available memory. bool started_backpedal_ = false; - std::unique_ptr<SubAllocator> suballocator_; + std::unique_ptr<SubAllocator> sub_allocator_; string name_; // Structures mutable after construction @@ -435,9 +430,6 @@ class BFCAllocator : public VisitableAllocator { // Pointer to head of linked list of free Chunks ChunkHandle free_chunks_list_ GUARDED_BY(lock_); - // Called once on each region, ASAP. - std::vector<Visitor> region_visitors_ GUARDED_BY(lock_); - // Counter containing the next unique identifier to assign to a // newly-created chunk. int64 next_allocation_id_ GUARDED_BY(lock_); diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index f8cb854b52..cf3d1f0b79 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -358,7 +358,7 @@ static Status WrappedTensorDeviceCopy( #define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION) \ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ - Tensor, DIRECTION, "tensorflow::Tensor", WrappedTensorDeviceCopy) + Tensor, DIRECTION, WrappedTensorDeviceCopy) REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index bf1d78ec65..af5d5b17e7 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -451,8 +451,22 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, RunState run_state(step_id, &devices_); run_state.rendez = new IntraProcessRendezvous(device_mgr_.get()); #ifndef __ANDROID__ - // Set up for collectives if the RunOption declares a key. - if (run_options.experimental().collective_graph_key() > 0) { + // Set up for collectives if ExecutorsAndKeys declares a key. + if (executors_and_keys->collective_graph_key != + BuildGraphOptions::kNoCollectiveGraphKey) { + if (run_options.experimental().collective_graph_key() != + BuildGraphOptions::kNoCollectiveGraphKey) { + // If a collective_graph_key was specified in run_options, ensure that it + // matches what came out of GraphExecutionState::BuildGraph(). 
+ if (run_options.experimental().collective_graph_key() != + executors_and_keys->collective_graph_key) { + return errors::Internal( + "collective_graph_key in RunOptions ", + run_options.experimental().collective_graph_key(), + " should match collective_graph_key from optimized graph ", + executors_and_keys->collective_graph_key); + } + } if (!collective_executor_mgr_) { std::unique_ptr<DeviceResolverInterface> drl( new DeviceResolverLocal(device_mgr_.get())); @@ -678,10 +692,16 @@ Status DirectSession::Run(const RunOptions& run_options, // Check if we already have an executor for these arguments. ExecutorsAndKeys* executors_and_keys; RunStateArgs run_state_args(run_options.debug_options()); + run_state_args.collective_graph_key = + run_options.experimental().collective_graph_key(); TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names, target_nodes, &executors_and_keys, &run_state_args)); + { + mutex_lock l(collective_graph_key_lock_); + collective_graph_key_ = executors_and_keys->collective_graph_key; + } // Configure a call frame for the step, which we use to feed and // fetch values to and from the executors. @@ -1116,6 +1136,8 @@ Status DirectSession::CreateExecutors( BuildGraphOptions options; options.callable_options = callable_options; options.use_function_convention = !run_state_args->is_partial_run; + options.collective_graph_key = + callable_options.run_options().experimental().collective_graph_key(); std::unique_ptr<FunctionInfo> func_info(new FunctionInfo); std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys); @@ -1123,9 +1145,9 @@ Status DirectSession::CreateExecutors( ek->callable_options = callable_options; std::unordered_map<string, std::unique_ptr<Graph>> graphs; - TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def, - run_state_args, &ek->input_types, - &ek->output_types)); + TF_RETURN_IF_ERROR(CreateGraphs( + options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types, + &ek->output_types, &ek->collective_graph_key)); if (run_state_args->is_partial_run) { ek->graph = std::move(run_state_args->graph); @@ -1180,14 +1202,11 @@ Status DirectSession::CreateExecutors( auto opseg = device->op_segment(); params.create_kernel = [this, lib, opseg](const NodeDef& ndef, OpKernel** kernel) { - // We do not share the kernel via the OpSegment if the node is - // stateless, or a function. // NOTE(mrry): We must not share function kernels (implemented // using `CallOp`) between subgraphs, because `CallOp::handle_` // is tied to a particular subgraph. Even if the function itself // is stateful, the `CallOp` that invokes it is not. - if (!lib->IsStateful(ndef.op()) || - lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) { + if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) { return lib->CreateKernel(ndef, kernel); } auto create_fn = [lib, &ndef](OpKernel** kernel) { @@ -1200,13 +1219,11 @@ Status DirectSession::CreateExecutors( create_fn); }; params.delete_kernel = [lib](OpKernel* kernel) { - // If the node is stateful, opseg owns it. Otherwise, delete it. - if (kernel && !lib->IsStateful(kernel->type_string())) { + if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) delete kernel; - } }; - optimizer.Optimize(lib, options_.env, device, &iter->second, + optimizer.Optimize(lib, options_.env, device, &partition_graph, /*shape_map=*/nullptr); // TensorFlow Debugger (tfdbg) inserts debug nodes in the graph. 
@@ -1353,6 +1370,9 @@ Status DirectSession::GetOrCreateExecutors( } *callable_options.mutable_run_options()->mutable_debug_options() = run_state_args->debug_options; + callable_options.mutable_run_options() + ->mutable_experimental() + ->set_collective_graph_key(run_state_args->collective_graph_key); std::unique_ptr<ExecutorsAndKeys> ek; std::unique_ptr<FunctionInfo> func_info; TF_RETURN_IF_ERROR( @@ -1379,7 +1399,7 @@ Status DirectSession::CreateGraphs( std::unordered_map<string, std::unique_ptr<Graph>>* outputs, std::unique_ptr<FunctionLibraryDefinition>* flib_def, RunStateArgs* run_state_args, DataTypeVector* input_types, - DataTypeVector* output_types) { + DataTypeVector* output_types, int64* collective_graph_key) { mutex_lock l(graph_def_lock_); std::unique_ptr<ClientGraph> client_graph; @@ -1403,6 +1423,7 @@ Status DirectSession::CreateGraphs( TF_RETURN_IF_ERROR( execution_state->BuildGraph(subgraph_options, &client_graph)); } + *collective_graph_key = client_graph->collective_graph_key; if (subgraph_options.callable_options.feed_size() != client_graph->feed_types.size()) { diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index 55a6fbce6d..c2cf3c7fd7 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -117,6 +117,9 @@ class DirectSession : public Session { ::tensorflow::Status ReleaseCallable(CallableHandle handle) override; private: + // For access to collective_graph_key_. + friend class DirectSessionCollectiveTest; + // We create one executor and its dependent library runtime for // every partition. struct PerPartitionExecutorsAndLib { @@ -150,6 +153,8 @@ class DirectSession : public Session { DataTypeVector output_types; CallableOptions callable_options; + + int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey; }; // A FunctionInfo object is created for every unique set of feeds/fetches. @@ -203,6 +208,7 @@ class DirectSession : public Session { string handle; std::unique_ptr<Graph> graph; const DebugOptions& debug_options; + int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey; }; // Initializes the base execution state given the 'graph', @@ -234,7 +240,7 @@ class DirectSession : public Session { std::unordered_map<string, std::unique_ptr<Graph>>* outputs, std::unique_ptr<FunctionLibraryDefinition>* flib_def, RunStateArgs* run_state_args, DataTypeVector* input_types, - DataTypeVector* output_types); + DataTypeVector* output_types, int64* collective_graph_key); ::tensorflow::Status RunInternal(int64 step_id, const RunOptions& run_options, CallFrameInterface* call_frame, @@ -391,6 +397,10 @@ class DirectSession : public Session { Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr; + // For testing collective graph key generation. 
+ mutex collective_graph_key_lock_; + int64 collective_graph_key_ GUARDED_BY(collective_graph_key_lock_) = -1; + TF_DISALLOW_COPY_AND_ASSIGN(DirectSession); // EXPERIMENTAL: debugger (tfdbg) related diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 4b51b20bb1..65e816c202 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -1255,7 +1255,7 @@ TEST(DirectSessionTest, RunHandleTest) { ASSERT_TRUE(s.ok()); ASSERT_EQ(1, outputs.size()); - ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()(); + const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()(); Tensor string_handle(DT_STRING, {}); string_handle.flat<string>().setConstant(resource_handle.name()); @@ -1308,7 +1308,7 @@ TEST(DirectSessionTest, RunHandleTest_Callable) { ASSERT_TRUE(s.ok()); ASSERT_EQ(1, outputs.size()); - ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()(); + const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()(); Tensor string_handle(DT_STRING, {}); string_handle.flat<string>().setConstant(resource_handle.name()); @@ -2218,4 +2218,121 @@ BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10); BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10); } // namespace + +class DirectSessionCollectiveTest : public ::testing::Test { + public: + // Creates a graph with CollectiveOps inside functions and runs it. Returns + // the generated collective_graph_key. + Status RunGraphWithCollectiveFunctions(bool add_unused_function, + int64* collective_graph_key) { + GraphDef g = CreateGraph(add_unused_function); + const Tensor t1 = + test::AsTensor<float>({0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1}); + const Tensor t2 = + test::AsTensor<float>({0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3}); + auto session = CreateSession(); + TF_RETURN_IF_ERROR(session->Create(g)); + std::vector<Tensor> outputs; + TF_RETURN_IF_ERROR( + session->Run({{"input1:0", t1}, {"input2:0", t2}}, {}, + {"collective_call1:0", "collective_call2:0"}, &outputs)); + DirectSession* direct_session = static_cast<DirectSession*>(session.get()); + { + mutex_lock l(direct_session->collective_graph_key_lock_); + *collective_graph_key = direct_session->collective_graph_key_; + } + return Status::OK(); + } + + private: + // Creates a function with name `function_name` and a single CollectiveReduce + // node with instance key set as `instance_key`. + FunctionDef CollectiveFunction(const string& function_name, + int instance_key) { + return FunctionDefHelper::Define( + // Function name + function_name, + // In def + {"arg:float"}, + // Out def + {"reduce:float"}, + // Attr def + {}, + // Node def + {{ + {"reduce"}, + "CollectiveReduce", + {"arg"}, + {{"group_size", 2}, + {"group_key", 1}, + {"instance_key", instance_key}, + {"subdiv_offsets", gtl::ArraySlice<int32>({0})}, + {"merge_op", "Add"}, + {"final_op", "Div"}, + {"T", DT_FLOAT}}, + }}); + } + + // Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and + // CPU1, with instance_key 1, and appropriate placeholder inputs. If + // `add_unused_function` is true, adds another CollectiveFunction with + // instance_key 2 that is not invoked in the graph. 
+ GraphDef CreateGraph(bool add_unused_function) { + GraphDef g; + FunctionDef collective_function = + CollectiveFunction("CollectiveFunction1", 1); + FunctionDefLibrary* lib = g.mutable_library(); + *lib->add_function() = collective_function; + if (add_unused_function) { + FunctionDef unused_function = + CollectiveFunction("CollectiveFunction2", 2); + *lib->add_function() = unused_function; + } + + // Inputs. + AttrValue dtype_attr; + SetAttrValue(DT_FLOAT, &dtype_attr); + NodeDef input1; + input1.set_name("input1"); + input1.set_op("Placeholder"); + input1.mutable_attr()->insert({"dtype", dtype_attr}); + NodeDef input2; + input2.set_name("input2"); + input2.set_op("Placeholder"); + input2.mutable_attr()->insert({"dtype", dtype_attr}); + + // CollectiveReduce on CPU0 with instance_key 1. + NodeDef collective_call1; + collective_call1.set_name("collective_call1"); + collective_call1.set_op("CollectiveFunction1"); + collective_call1.add_input("input1"); + collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:0"); + // CollectiveReduce on CPU1 with instance_key 1. + NodeDef collective_call2; + collective_call2.set_name("collective_call2"); + collective_call2.set_op("CollectiveFunction1"); + collective_call2.add_input("input2"); + collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:1"); + + *g.add_node() = input1; + *g.add_node() = input2; + *g.add_node() = collective_call1; + *g.add_node() = collective_call2; + + return g; + } +}; + +#ifndef GOOGLE_CUDA +// TODO(ayushd): enable this test for GPU builds. +TEST_F(DirectSessionCollectiveTest, + TestCollectiveGraphKeyUsesOnlyCalledFunctions) { + int64 key1; + TF_ASSERT_OK(RunGraphWithCollectiveFunctions(false, &key1)); + int64 key2; + TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key2)); + ASSERT_EQ(key1, key2); +} +#endif + } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 39a3b49cd1..18420b60fd 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -32,34 +32,55 @@ bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) { return default_val; } +std::unique_ptr<thread::ThreadPool> EagerThreadPool( + const SessionOptions& opts) { + SessionOptions opts_copy(opts); + if (opts_copy.config.inter_op_parallelism_threads() == 0) { + // Eager defaults to a single thread when no threads are specified. 
+ opts_copy.config.set_inter_op_parallelism_threads(1); + } + + return std::unique_ptr<thread::ThreadPool>( + NewThreadPoolFromSessionOptions(opts_copy)); +} + } // namespace EagerContext::EagerContext(const SessionOptions& opts, ContextDevicePlacementPolicy default_policy, - bool async, std::unique_ptr<DeviceMgr> device_mgr, + bool async, + std::unique_ptr<const DeviceMgr> device_mgr, Rendezvous* rendezvous) + : EagerContext(opts, default_policy, async, device_mgr.release(), + /*device_mgr_owned*/ true, rendezvous) {} + +EagerContext::EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, + bool async, const DeviceMgr* device_mgr, + bool device_mgr_owned, Rendezvous* rendezvous) : policy_(default_policy), - local_device_manager_(std::move(device_mgr)), - local_unowned_device_manager_(nullptr), - devices_(local_device_manager_->ListDevices()), + devices_(device_mgr->ListDevices()), rendezvous_(rendezvous), - thread_pool_(NewThreadPoolFromSessionOptions(opts)), + thread_pool_(EagerThreadPool(opts)), pflr_(new ProcessFunctionLibraryRuntime( - local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION, - &func_lib_def_, {}, thread_pool_.get())), + device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {}, + thread_pool_.get())), log_device_placement_(opts.config.log_device_placement()), num_active_steps_(0), async_default_(async), + log_memory_(LogMemory::IsEnabled()), env_(opts.env), use_send_tensor_rpc_(false) { - InitDeviceMapAndAsync(); - if (opts.config.inter_op_parallelism_threads() > 0) { - runner_ = [this](std::function<void()> closure) { - this->thread_pool_->Schedule(closure); - }; + if (device_mgr_owned) { + local_device_manager_.reset(device_mgr); + local_unowned_device_manager_ = nullptr; } else { - runner_ = [](std::function<void()> closure) { closure(); }; + local_unowned_device_manager_ = device_mgr; } + InitDeviceMapAndAsync(); + runner_ = [this](std::function<void()> closure) { + this->thread_pool_->Schedule(std::move(closure)); + }; } void EagerContext::InitDeviceMapAndAsync() { diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 3c95ac590d..5ed6057ec6 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/core/distributed_runtime/eager/eager_client.h" #include "tensorflow/core/distributed_runtime/server_lib.h" #endif +#include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/rendezvous.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" @@ -65,10 +66,17 @@ enum ContextDevicePlacementPolicy { class EagerContext { public: - explicit EagerContext(const SessionOptions& opts, - ContextDevicePlacementPolicy default_policy, bool async, - std::unique_ptr<DeviceMgr> device_mgr, - Rendezvous* rendezvous); + // TODO: remove this constructor once we migrate all callers to the next one. + EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, bool async, + std::unique_ptr<const DeviceMgr> device_mgr, + Rendezvous* rendezvous); + + EagerContext(const SessionOptions& opts, + ContextDevicePlacementPolicy default_policy, bool async, + const DeviceMgr* device_mgr, bool device_mgr_owned, + Rendezvous* rendezvous); + ~EagerContext(); // Returns the function library runtime for the given device. 
@@ -134,6 +142,7 @@ class EagerContext { void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel); bool LogDevicePlacement() { return log_device_placement_; } + bool LogMemory() { return log_memory_; } Rendezvous* GetRendezvous() { return rendezvous_; } @@ -207,8 +216,8 @@ class EagerContext { thread_local_policies_ GUARDED_BY(policy_map_mu_); // Only one of the below is set. - std::unique_ptr<DeviceMgr> local_device_manager_; - DeviceMgr* local_unowned_device_manager_; + std::unique_ptr<const DeviceMgr> local_device_manager_; + const DeviceMgr* local_unowned_device_manager_; std::unique_ptr<DeviceMgr> remote_device_manager_; // Devices owned by device_manager @@ -254,6 +263,8 @@ class EagerContext { std::unordered_map<std::thread::id, bool> thread_local_async_ GUARDED_BY(async_map_mu_); + const bool log_memory_; + Env* const env_; #ifndef __ANDROID__ diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 5b3a64ba98..1da1326a9a 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -296,7 +296,7 @@ Status EagerLocalExecute(EagerOperation* op, LOG(INFO) << "Executing op " << ndef.op() << " in device " << device->name(); } - kernel = new KernelAndDevice(ctx->GetRendezvous()); + kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory()); auto* flr = ctx->func_lib(device); if (flr == nullptr) { diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc index 3d61ff4dc2..83d8425477 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc @@ -32,21 +32,6 @@ limitations under the License. 
namespace tensorflow { // static -Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef, - KernelAndDevice* out) { - OpKernel* k = nullptr; - Status s = CreateOpKernel(device->device_type().c_str(), device, - device->GetAllocator(AllocatorAttributes()), - nullptr, ndef, TF_GRAPH_DEF_VERSION, &k); - out->device_ = device; - out->kernel_.reset(k); - out->flib_ = nullptr; - out->runner_ = nullptr; - out->default_runner_ = [](std::function<void()> f) { f(); }; - return s; -} - -// static Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, std::function<void(std::function<void()>)>* runner, KernelAndDevice* out) { @@ -95,6 +80,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container, params.slice_reader_cache = &slice_reader_cache_; params.rendezvous = rendez_; params.cancellation_manager = &cm_; + params.log_memory = log_memory_; if (stats != nullptr) { params.track_allocations = true; } diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h index 0ef419cbaa..04151a1171 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device.h +++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h @@ -52,12 +52,12 @@ class KernelAndDevice { static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib, std::function<void(std::function<void()>)>* runner, KernelAndDevice* out); - // TODO(ashankar): Remove this - static Status InitOp(Device* device, const NodeDef& ndef, - KernelAndDevice* out); - KernelAndDevice(tensorflow::Rendezvous* rendez) - : device_(nullptr), flib_(nullptr), rendez_(rendez) {} + KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory) + : device_(nullptr), + flib_(nullptr), + rendez_(rendez), + log_memory_(log_memory) {} // TODO(ashankar): Handle list-valued inputs. 
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs, @@ -87,6 +87,7 @@ class KernelAndDevice { DataTypeVector output_dtypes_; std::function<void(std::function<void()>)>* runner_; std::function<void(std::function<void()>)> default_runner_; + const bool log_memory_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc index 6abe98f53c..da280b2317 100644 --- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc +++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc @@ -104,7 +104,7 @@ void BM_KernelAndDeviceInit(int iters) { .NumInputs(2) .BuildNodeDef()); TestEnv env; - KernelAndDevice k(nullptr); + KernelAndDevice k(nullptr, false); tensorflow::testing::StartTiming(); for (int i = 0; i < iters; ++i) { TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(), @@ -127,7 +127,7 @@ void BM_KernelAndDeviceRun(int iters) { .NumInputs(inputs.size()) .BuildNodeDef()); TestEnv env; - KernelAndDevice kernel(nullptr); + KernelAndDevice kernel(nullptr, false); TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(), nullptr, &kernel)); tensorflow::testing::StartTiming(); diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc index b912f7d37b..d58724cbfa 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.cc +++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc @@ -125,7 +125,6 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) { Status TensorHandle::NumDims(int* num_dims) { if (IsRemote()) { TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); - CHECK(remote_shape_ != nullptr); *num_dims = remote_shape_->dims(); } else { TF_RETURN_IF_ERROR(WaitReady()); @@ -153,6 +152,21 @@ Status TensorHandle::Dim(int dim_index, int64* dim) { return Status::OK(); } +Status TensorHandle::NumElements(int64* num_elements) { + if (IsRemote()) { + TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false)); + *num_elements = remote_shape_->num_elements(); + } else { + TF_RETURN_IF_ERROR(WaitReady()); + DCHECK(IsReady()); + DCHECK(num_elements != nullptr); + + *num_elements = tensor_.NumElements(); + } + + return Status::OK(); +} + Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) { if (!IsRemote()) { return errors::FailedPrecondition( diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h index 1bc9c6531a..e55f1a0338 100644 --- a/tensorflow/core/common_runtime/eager/tensor_handle.h +++ b/tensorflow/core/common_runtime/eager/tensor_handle.h @@ -113,6 +113,7 @@ class TensorHandle : public core::RefCounted { Status NumDims(int* num_dims); Status Dim(int dim_index, int64* dim); + Status NumElements(int64* num_elements); // Return the op_id and output num if the handle refers to a remote tensor. 
Status RemoteAddress(int64* op_id, int32* output_num); diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 3ef6d35182..d0a0767d6b 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -76,56 +76,47 @@ bool IsInitializationOp(const Node* node) { namespace nodestats { inline int64 NowInNsec() { return Env::Default()->NowNanos(); } -void SetScheduled(NodeExecStatsWrapper* stats, int64 micros) { +void SetScheduled(NodeExecStatsInterface* stats, int64 micros) { if (!stats) return; stats->SetScheduled(micros * EnvTime::kMicrosToNanos); } -void SetAllStart(NodeExecStatsWrapper* stats) { +void SetAllStart(NodeExecStatsInterface* stats) { if (!stats) return; stats->RecordExecutorStarted(); } -void SetOpStart(NodeExecStatsWrapper* stats) { +void SetOpStart(NodeExecStatsInterface* stats) { if (!stats) return; stats->RecordComputeStarted(); } -void SetOpEnd(NodeExecStatsWrapper* stats) { +void SetOpEnd(NodeExecStatsInterface* stats) { if (!stats) return; stats->RecordComputeEnded(); } -void SetAllEnd(NodeExecStatsWrapper* stats) { +void SetAllEnd(NodeExecStatsInterface* stats) { if (!stats) return; stats->RecordExecutorEnded(); } -void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) { +void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) { if (!stats) return; stats->SetOutput(slot, v); } -void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) { +void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) { if (!stats) return; stats->SetMemory(ctx); } -void SetReferencedTensors(NodeExecStatsWrapper* stats, +void SetReferencedTensors(NodeExecStatsInterface* stats, const TensorReferenceVector& tensors) { if (!stats) return; stats->SetReferencedTensors(tensors); } -// Sets the timeline_label field of *stats, using data from *node. -// Returns true iff the node is a transfer node. -bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) { - if (!stats) { - return false; - } - return stats->SetTimelineLabel(node); -} - } // namespace nodestats class ExecutorImpl; @@ -1301,7 +1292,7 @@ class ExecutorState { // After item->kernel computation is done, processes its outputs. Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, - EntryVector* outputs, NodeExecStatsWrapper* stats); + EntryVector* outputs, NodeExecStatsInterface* stats); // After processing the outputs, propagates the outputs to their dsts. // Contents of *outputs are left in an indeterminate state after @@ -1312,7 +1303,7 @@ class ExecutorState { // "node" just finishes. Takes ownership of "stats". Returns true if // execution has completed. 
bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, - NodeExecStatsWrapper* stats, + NodeExecStatsInterface* stats, TaggedNodeReadyQueue* inline_ready); // Schedule all the expensive nodes in 'ready', and put all the inexpensive @@ -1482,6 +1473,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) { const Status fill_status = device->FillContextMap(graph, &device_context_map_); if (!fill_status.ok()) { + delete this; done(fill_status); return; } @@ -1492,6 +1484,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) { ready.push_back(TaggedNode{n, root_frame_, 0, false}); } if (ready.empty()) { + delete this; done(Status::OK()); } else { num_outstanding_ops_ = ready.size(); @@ -1511,7 +1504,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) { struct ExecutorState::AsyncState { AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node, const NodeItem* _item, Entry* _first_input, - NodeExecStatsWrapper* _stats) + NodeExecStatsInterface* _stats) : saved_inputs(*p.inputs), saved_input_device_contexts(*p.input_device_contexts), saved_input_alloc_attrs(*p.input_alloc_attrs), @@ -1536,7 +1529,7 @@ struct ExecutorState::AsyncState { const NodeItem* item; Entry* first_input; OpKernelContext ctx; - NodeExecStatsWrapper* stats; + NodeExecStatsInterface* stats; private: OpKernelContext::Params* ParamsButClearingEigenGPUDevice( @@ -1581,7 +1574,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { params.stats_collector = stats_collector_; Status s; - NodeExecStatsWrapper* stats = nullptr; + NodeExecStatsInterface* stats = nullptr; EntryVector outputs; bool completed = false; inline_ready.push_back(tagged_node); @@ -1611,7 +1604,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { if (stats_collector_ && !tagged_node.is_dead) { // track allocations if and only if we are collecting statistics params.track_allocations = true; - stats = new NodeExecStatsWrapper(node->name()); + stats = stats_collector_->CreateNodeExecStats(node); nodestats::SetScheduled(stats, scheduled_nsec); nodestats::SetAllStart(stats); } @@ -1669,7 +1662,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { auto done = [this, state]() { Device* device = impl_->params_.device; - NodeExecStatsWrapper* stats = state->stats; // Shorthand + NodeExecStatsInterface* stats = state->stats; // Shorthand Entry* first_input = state->first_input; // Shorthand nodestats::SetOpEnd(stats); @@ -1860,7 +1853,7 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input, Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx, EntryVector* outputs, - NodeExecStatsWrapper* stats) { + NodeExecStatsInterface* stats) { const Node* node = item.node; DCHECK_EQ(0, outputs->size()); outputs->resize(item.num_outputs); @@ -2078,16 +2071,15 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node, bool ExecutorState::NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready, - NodeExecStatsWrapper* stats, + NodeExecStatsInterface* stats, TaggedNodeReadyQueue* inline_ready) { nodestats::SetAllEnd(stats); - if (stats_collector_ != nullptr && - !nodestats::SetTimelineLabel(node, stats)) { - // Only record non-transfer nodes. - // Transfers 'stats' ownership to 'stats_collector_'. 
- stats_collector_->Save(impl_->params_.device->name(), stats); - } else if (stats) { - delete stats; + if (stats) { + if (stats_collector_) { + stats->Done(impl_->params_.device->name()); + } else { + delete stats; + } } bool abort_run = false; diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc index 46bb8d92f8..472865ca43 100644 --- a/tensorflow/core/common_runtime/function.cc +++ b/tensorflow/core/common_runtime/function.cc @@ -414,9 +414,8 @@ Status FunctionLibraryRuntimeImpl::CreateKernel( device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef, &fbody->fdef.signature(), this, fbody->arg_types, input_memory_types, fbody->ret_types, output_memory_types, graph_def_version_, &s); - *kernel = new CallOp(handle, &construction); - if (!s.ok()) { - delete *kernel; + if (s.ok()) { + *kernel = new CallOp(handle, &construction); } return s; } @@ -615,11 +614,14 @@ void PruneFunctionBody(Graph* g) { std::unordered_set<const Node*> nodes; for (auto n : g->nodes()) { // NOTE(mrry): "_Retval" nodes are stateful, and so will be added - // to the seed set of `nodes`. + // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we + // specifically exclude them as seeds, to avoid unconditionally executing + // unused argument nodes (e.g. in a function like `lambda x, y: y`). // TODO(mrry): Investigate whether the `n->IsControlFlow()` test is // still needed. It would be preferable to prune entire loops and/or // conditionals if they are not used in the graph. - if (n->IsControlFlow() || n->op_def().is_stateful()) { + if (n->IsControlFlow() || + (n->op_def().is_stateful() && n->type_string() != kArgOp)) { nodes.insert(n); } } @@ -925,29 +927,18 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle, } DCHECK(run_opts.runner != nullptr); - Executor::Args* exec_args = new Executor::Args; + Executor::Args exec_args; // Inherit the step_id from the caller. - exec_args->step_id = run_opts.step_id; - exec_args->rendezvous = run_opts.rendezvous; - exec_args->stats_collector = run_opts.stats_collector; - exec_args->cancellation_manager = run_opts.cancellation_manager; - exec_args->collective_executor = run_opts.collective_executor; - exec_args->step_container = run_opts.step_container; - exec_args->runner = *run_opts.runner; - exec_args->call_frame = frame; - - item->exec->RunAsync( - // Executor args - *exec_args, - // Done callback. - std::bind( - [item, frame, exec_args](DoneCallback done, - // Start unbound arguments. 
- const Status& status) { - delete exec_args; - done(status); - }, - std::move(done), std::placeholders::_1)); + exec_args.step_id = run_opts.step_id; + exec_args.rendezvous = run_opts.rendezvous; + exec_args.stats_collector = run_opts.stats_collector; + exec_args.cancellation_manager = run_opts.cancellation_manager; + exec_args.collective_executor = run_opts.collective_executor; + exec_args.step_container = run_opts.step_container; + exec_args.runner = *run_opts.runner; + exec_args.call_frame = frame; + + item->exec->RunAsync(exec_args, std::move(done)); } bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) { diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc index 120f480198..7bab9be9a6 100644 --- a/tensorflow/core/common_runtime/function_test.cc +++ b/tensorflow/core/common_runtime/function_test.cc @@ -802,9 +802,9 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { // Name "SquareAndAddOneWithStatefulNodes", // Args - {"x: int32"}, + {"x: int32", "y: float32"}, // Return values - {"y: int32"}, + {"z: int32"}, // Attrs {}, // Nodes @@ -822,12 +822,13 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { "RandomUniform", {"shape"}, {{"T", T}, {"dtype", DT_FLOAT}}}, - // y = Add<T>(a, o) - {{"y"}, "Add", {"a", "o"}, {{"T", T}}}}); + // z = Add<T>(a, o) + {{"z"}, "Add", {"a", "o"}, {{"T", T}}}}); Init({stateful_func}); auto x = test::AsTensor<int32>({1, 2, 3, 4}); - Tensor y; + auto y = test::AsTensor<float>({1.0, 2.0, 3.0, 4.0}); + Tensor z; FunctionLibraryRuntime::Handle handle; TF_CHECK_OK( @@ -837,18 +838,19 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) { StepStatsCollector stats_collector(&stats); FunctionLibraryRuntime::Options opts; opts.stats_collector = &stats_collector; - TF_CHECK_OK(Run(flr0_, handle, opts, {x}, {&y})); + TF_CHECK_OK(Run(flr0_, handle, opts, {x, y}, {&z})); TF_CHECK_OK(flr0_->ReleaseHandle(handle)); TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {}, - {x}, {&y})); - test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({2, 5, 10, 17})); + {x, y}, {&z})); + test::ExpectTensorEqual<int>(z, test::AsTensor<int32>({2, 5, 10, 17})); stats_collector.FinalizeAndSwap(&stats); - // Note that we do not expect the nodes named "x1", "x2", or "x3" to execute. + // Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to + // execute. std::set<string> expected_node_names( - {"_SOURCE", "shape", "x", "o", "a", "keep_me", "y", "y_RetVal"}); + {"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"}); std::set<string> executed_node_names; for (const auto& node_stats : stats.dev_stats()[0].node_stats()) { executed_node_names.insert(node_stats.node_name()); diff --git a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h index 636cd43575..6bd29ef775 100644 --- a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h +++ b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h @@ -26,8 +26,12 @@ namespace tensorflow { class CUDAHostAllocator : public SubAllocator { public: // Note: stream_exec cannot be null. 
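+ // 'numa_node' identifies the NUMA node of the pinned memory. Each alloc
+ // visitor is invoked as visitor(ptr, numa_node, num_bytes) right after a
+ // region is obtained from the driver, and each free visitor right before
+ // a region is returned to it.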
- explicit CUDAHostAllocator(se::StreamExecutor* stream_exec) - : stream_exec_(stream_exec) { + explicit CUDAHostAllocator(se::StreamExecutor* stream_exec, int numa_node, + const std::vector<Visitor>& alloc_visitors, + const std::vector<Visitor>& free_visitors) + : SubAllocator(alloc_visitors, free_visitors), + stream_exec_(stream_exec), + numa_node_(numa_node) { CHECK(stream_exec_ != nullptr); } ~CUDAHostAllocator() override {} @@ -39,19 +43,23 @@ class CUDAHostAllocator : public SubAllocator { if (ptr == nullptr) { LOG(WARNING) << "could not allocate pinned host memory of size: " << num_bytes; + return ptr; } + VisitAlloc(ptr, numa_node_, num_bytes); } return ptr; } void Free(void* ptr, size_t num_bytes) override { if (ptr != nullptr) { + VisitFree(ptr, numa_node_, num_bytes); stream_exec_->HostMemoryDeallocate(ptr); } } private: se::StreamExecutor* stream_exec_; // not owned, non-null + const int numa_node_; TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator); }; diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc index c8db384b64..44ffce77a1 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc @@ -22,19 +22,15 @@ limitations under the License. namespace tensorflow { -GPUBFCAllocator::GPUBFCAllocator(PlatformGpuId platform_gpu_id, +GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory, const string& name) - : GPUBFCAllocator(platform_gpu_id, total_memory, GPUOptions(), name) {} + : GPUBFCAllocator(sub_allocator, total_memory, GPUOptions(), name) {} -GPUBFCAllocator::GPUBFCAllocator(PlatformGpuId platform_gpu_id, +GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory, const GPUOptions& gpu_options, const string& name) - : BFCAllocator( - new GPUMemAllocator( - GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), - gpu_options.per_process_gpu_memory_fraction() > 1.0 || - gpu_options.experimental().use_unified_memory()), - total_memory, gpu_options.allow_growth(), name) {} + : BFCAllocator(sub_allocator, total_memory, gpu_options.allow_growth(), + name) {} } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h index 435ffb4959..3470f7a9f7 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h @@ -31,28 +31,20 @@ limitations under the License. namespace tensorflow { -// A GPU memory allocator that implements a 'best-fit with coalescing' -// algorithm. -class GPUBFCAllocator : public BFCAllocator { - public: - // 'platform_gpu_id' refers to the ID of the GPU device within - // the process and must reference a valid ID in the process. - GPUBFCAllocator(PlatformGpuId platform_gpu_id, size_t total_memory, - const string& name); - GPUBFCAllocator(PlatformGpuId platform_gpu_id, size_t total_memory, - const GPUOptions& gpu_options, const string& name); - virtual ~GPUBFCAllocator() {} - - TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator); -}; - // Suballocator for GPU memory. class GPUMemAllocator : public SubAllocator { public: + // 'platform_gpu_id' refers to the ID of the GPU device within + // the process and must reference a valid ID in the process. // Note: stream_exec cannot be null. 
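+ // Alloc and free visitors are invoked as visitor(ptr, platform_gpu_id,
+ // num_bytes) for every device-memory region this suballocator allocates
+ // or frees.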
explicit GPUMemAllocator(se::StreamExecutor* stream_exec, - bool use_unified_memory) - : stream_exec_(stream_exec), use_unified_memory_(use_unified_memory) { + PlatformGpuId gpu_id, bool use_unified_memory, + const std::vector<Visitor>& alloc_visitors, + const std::vector<Visitor>& free_visitors) + : SubAllocator(alloc_visitors, free_visitors), + stream_exec_(stream_exec), + gpu_id_(gpu_id), + use_unified_memory_(use_unified_memory) { CHECK(stream_exec_ != nullptr); } ~GPUMemAllocator() override {} @@ -65,12 +57,14 @@ class GPUMemAllocator : public SubAllocator { } else { ptr = stream_exec_->AllocateArray<char>(num_bytes).opaque(); } + VisitAlloc(ptr, gpu_id_.value(), num_bytes); } return ptr; } void Free(void* ptr, size_t num_bytes) override { if (ptr != nullptr) { + VisitFree(ptr, gpu_id_.value(), num_bytes); if (use_unified_memory_) { stream_exec_->UnifiedMemoryDeallocate(ptr); } else { @@ -82,11 +76,25 @@ class GPUMemAllocator : public SubAllocator { private: se::StreamExecutor* stream_exec_; // not owned, non-null + const PlatformGpuId gpu_id_; const bool use_unified_memory_ = false; TF_DISALLOW_COPY_AND_ASSIGN(GPUMemAllocator); }; +// A GPU memory allocator that implements a 'best-fit with coalescing' +// algorithm. +class GPUBFCAllocator : public BFCAllocator { + public: + GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory, + const string& name); + GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory, + const GPUOptions& gpu_options, const string& name); + ~GPUBFCAllocator() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator); +}; + } // namespace tensorflow #endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_ diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc index 518ccba580..e313135d8d 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc @@ -21,6 +21,7 @@ limitations under the License. 
#include <vector> #include "tensorflow/core/common_runtime/gpu/gpu_id.h" +#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h" #include "tensorflow/core/common_runtime/gpu/gpu_init.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -46,7 +47,11 @@ static void CheckStats(Allocator* a, int64 num_allocs, int64 bytes_in_use, } TEST(GPUBFCAllocatorTest, NoDups) { - GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); CheckStats(&a, 0, 0, 0, 0); // Allocate a lot of raw pointers @@ -75,7 +80,11 @@ TEST(GPUBFCAllocatorTest, NoDups) { } TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) { - GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); // Allocate 256 raw pointers of sizes between 100 bytes and about // a meg random::PhiloxRandom philox(123, 17); @@ -133,7 +142,11 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) { } TEST(GPUBFCAllocatorTest, ExerciseCoalescing) { - GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); CheckStats(&a, 0, 0, 0, 0); float* first_ptr = a.Allocate<float>(1024); @@ -168,18 +181,30 @@ TEST(GPUBFCAllocatorTest, ExerciseCoalescing) { } TEST(GPUBFCAllocatorTest, AllocateZeroBufSize) { - GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); float* ptr = a.Allocate<float>(0); EXPECT_EQ(nullptr, ptr); } TEST(GPUBFCAllocatorTest, TracksSizes) { - GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); EXPECT_EQ(true, a.TracksAllocationSizes()); } TEST(GPUBFCAllocatorTest, AllocatedVsRequested) { - GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); float* t1 = a.Allocate<float>(1); EXPECT_EQ(4, a.RequestedSize(t1)); EXPECT_EQ(256, a.AllocatedSize(t1)); @@ -187,8 +212,12 @@ TEST(GPUBFCAllocatorTest, AllocatedVsRequested) { } TEST(GPUBFCAllocatorTest, TestCustomMemoryLimit) { + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* 
sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); // Configure a 1MiB byte limit - GPUBFCAllocator a(PlatformGpuId(0), 1 << 20, "GPU_0_bfc"); + GPUBFCAllocator a(sub_allocator, 1 << 20, "GPU_0_bfc"); float* first_ptr = a.Allocate<float>(1 << 6); float* second_ptr = a.Allocate<float>(1 << 20); @@ -203,7 +232,11 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocationsWithGrowth) { options.set_allow_growth(true); // Max of 2GiB, but starts out small. - GPUBFCAllocator a(PlatformGpuId(0), 1LL << 31, options, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1LL << 31, options, "GPU_0_bfc"); // Allocate 10 raw pointers of sizes between 100 bytes and about // 64 megs. @@ -264,8 +297,15 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocationsWithGrowth) { } TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) { - GPUBFCAllocator a(PlatformGpuId(0), 1UL << 60, "GPU_0_bfc"); - GPUBFCAllocator b(PlatformGpuId(0), 1UL << 60, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1UL << 60, "GPU_0_bfc"); + sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator b(sub_allocator, 1UL << 60, "GPU_0_bfc"); void* amem = a.AllocateRaw(1, 1); void* bmem = b.AllocateRaw(1, 1 << 30); a.DeallocateRaw(amem); @@ -273,7 +313,11 @@ TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) { } static void BM_Allocation(int iters) { - GPUBFCAllocator a(PlatformGpuId(0), 1uLL << 33, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc"); // Exercise a few different allocation sizes std::vector<size_t> sizes = {256, 4096, 16384, 524288, 512, 1048576, 10485760, 104857600, @@ -289,7 +333,11 @@ static void BM_Allocation(int iters) { BENCHMARK(BM_Allocation); static void BM_AllocationThreaded(int iters, int num_threads) { - GPUBFCAllocator a(PlatformGpuId(0), 1uLL << 33, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc"); thread::ThreadPool pool(Env::Default(), "test", num_threads); std::atomic_int_fast32_t count(iters); mutex done_lock; @@ -325,7 +373,11 @@ BENCHMARK(BM_AllocationThreaded)->Arg(1)->Arg(4)->Arg(16); // A more complex benchmark that defers deallocation of an object for // "delay" allocations.
static void BM_AllocationDelayed(int iters, int delay) { - GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); // Exercise a few different allocation sizes std::vector<int> sizes = {256, 4096, 16384, 4096, 512, 1024, 1024}; int size_index = 0; @@ -363,7 +415,11 @@ class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test { // only methods inside this class can access private members of BFCAllocator. void TestBinDebugInfo() { - GPUBFCAllocator a(PlatformGpuId(0), 1 << 30, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc"); std::vector<void*> initial_ptrs; std::vector<size_t> initial_ptrs_allocated_sizes; @@ -441,7 +497,11 @@ class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test { } void TestLog2FloorNonZeroSlow() { - GPUBFCAllocator a(PlatformGpuId(0), 1 /* total_memory */, "GPU_0_bfc"); + PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUBFCAllocator a(sub_allocator, 1 /* total_memory */, "GPU_0_bfc"); EXPECT_EQ(-1, a.Log2FloorNonZeroSlow(0)); EXPECT_EQ(0, a.Log2FloorNonZeroSlow(1)); EXPECT_EQ(1, a.Log2FloorNonZeroSlow(2)); diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc index 553a5628ad..d85ca8892f 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc @@ -27,7 +27,7 @@ limitations under the License. namespace tensorflow { -GPUcudaMallocAllocator::GPUcudaMallocAllocator(VisitableAllocator* allocator, +GPUcudaMallocAllocator::GPUcudaMallocAllocator(Allocator* allocator, PlatformGpuId platform_gpu_id) : base_allocator_(allocator) { stream_exec_ = @@ -61,14 +61,6 @@ void GPUcudaMallocAllocator::DeallocateRaw(void* ptr) { #endif // GOOGLE_CUDA } -void GPUcudaMallocAllocator::AddAllocVisitor(Visitor visitor) { - return base_allocator_->AddAllocVisitor(visitor); -} - -void GPUcudaMallocAllocator::AddFreeVisitor(Visitor visitor) { - return base_allocator_->AddFreeVisitor(visitor); -} - bool GPUcudaMallocAllocator::TracksAllocationSizes() { return false; } } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h index 8f38cc5a18..8df3724bc4 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h @@ -19,7 +19,7 @@ limitations under the License. 
#include <memory> #include "tensorflow/core/common_runtime/gpu/gpu_id.h" -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" @@ -29,20 +29,18 @@ namespace tensorflow { // An allocator that wraps a GPU allocator and adds debugging // functionality that verifies that users do not write outside their // allocated memory. -class GPUcudaMallocAllocator : public VisitableAllocator { +class GPUcudaMallocAllocator : public Allocator { public: - explicit GPUcudaMallocAllocator(VisitableAllocator* allocator, + explicit GPUcudaMallocAllocator(Allocator* allocator, PlatformGpuId platform_gpu_id); ~GPUcudaMallocAllocator() override; string Name() override { return "gpu_debug"; } void* AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void* ptr) override; - void AddAllocVisitor(Visitor visitor) override; - void AddFreeVisitor(Visitor visitor) override; bool TracksAllocationSizes() override; private: - VisitableAllocator* base_allocator_ = nullptr; // owned + Allocator* base_allocator_ = nullptr; // owned se::StreamExecutor* stream_exec_; // Not owned. diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc index badb021aa5..989ddbe4af 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc @@ -73,7 +73,7 @@ void InitMask(se::StreamExecutor* exec, void* ptr, int64* mask) { // ----------------------------------------------------------------------------- // GPUDebugAllocator // ----------------------------------------------------------------------------- -GPUDebugAllocator::GPUDebugAllocator(VisitableAllocator* allocator, +GPUDebugAllocator::GPUDebugAllocator(Allocator* allocator, PlatformGpuId platform_gpu_id) : base_allocator_(allocator) { stream_exec_ = @@ -112,14 +112,6 @@ void GPUDebugAllocator::DeallocateRaw(void* ptr) { base_allocator_->DeallocateRaw(ptr); } -void GPUDebugAllocator::AddAllocVisitor(Visitor visitor) { - return base_allocator_->AddAllocVisitor(visitor); -} - -void GPUDebugAllocator::AddFreeVisitor(Visitor visitor) { - return base_allocator_->AddFreeVisitor(visitor); -} - bool GPUDebugAllocator::TracksAllocationSizes() { return true; } size_t GPUDebugAllocator::RequestedSize(const void* ptr) { @@ -159,7 +151,7 @@ bool GPUDebugAllocator::CheckFooter(void* ptr) { // ----------------------------------------------------------------------------- // GPUNanResetAllocator // ----------------------------------------------------------------------------- -GPUNanResetAllocator::GPUNanResetAllocator(VisitableAllocator* allocator, +GPUNanResetAllocator::GPUNanResetAllocator(Allocator* allocator, PlatformGpuId platform_gpu_id) : base_allocator_(allocator) { stream_exec_ = @@ -202,14 +194,6 @@ void GPUNanResetAllocator::DeallocateRaw(void* ptr) { base_allocator_->DeallocateRaw(ptr); } -void GPUNanResetAllocator::AddAllocVisitor(Visitor visitor) { - return base_allocator_->AddAllocVisitor(visitor); -} - -void GPUNanResetAllocator::AddFreeVisitor(Visitor visitor) { - return base_allocator_->AddFreeVisitor(visitor); -} - size_t GPUNanResetAllocator::RequestedSize(const void* ptr) { return base_allocator_->RequestedSize(ptr); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h 
b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h index 9e007ed8c1..17757a106c 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h @@ -21,7 +21,7 @@ limitations under the License. #include <unordered_map> #include "tensorflow/core/common_runtime/gpu/gpu_id.h" -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/stream_executor.h" #include "tensorflow/core/platform/types.h" @@ -31,16 +31,14 @@ namespace tensorflow { // An allocator that wraps a GPU allocator and adds debugging // functionality that verifies that users do not write outside their // allocated memory. -class GPUDebugAllocator : public VisitableAllocator { +class GPUDebugAllocator : public Allocator { public: - explicit GPUDebugAllocator(VisitableAllocator* allocator, + explicit GPUDebugAllocator(Allocator* allocator, PlatformGpuId platform_gpu_id); ~GPUDebugAllocator() override; string Name() override { return "gpu_debug"; } void* AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void* ptr) override; - void AddAllocVisitor(Visitor visitor) override; - void AddFreeVisitor(Visitor visitor) override; bool TracksAllocationSizes() override; size_t RequestedSize(const void* ptr) override; size_t AllocatedSize(const void* ptr) override; @@ -53,7 +51,7 @@ class GPUDebugAllocator : public VisitableAllocator { bool CheckFooter(void* ptr); private: - VisitableAllocator* base_allocator_ = nullptr; // owned + Allocator* base_allocator_ = nullptr; // owned se::StreamExecutor* stream_exec_; // Not owned. @@ -63,23 +61,21 @@ class GPUDebugAllocator : public VisitableAllocator { // An allocator that wraps a GPU allocator and resets the memory on // allocation and free to 'NaN', helping to identify cases where the // user forgets to initialize the memory. -class GPUNanResetAllocator : public VisitableAllocator { +class GPUNanResetAllocator : public Allocator { public: - explicit GPUNanResetAllocator(VisitableAllocator* allocator, + explicit GPUNanResetAllocator(Allocator* allocator, PlatformGpuId platform_gpu_id); ~GPUNanResetAllocator() override; string Name() override { return "gpu_nan_reset"; } void* AllocateRaw(size_t alignment, size_t num_bytes) override; void DeallocateRaw(void* ptr) override; - void AddAllocVisitor(Visitor visitor) override; - void AddFreeVisitor(Visitor visitor) override; size_t RequestedSize(const void* ptr) override; size_t AllocatedSize(const void* ptr) override; void GetStats(AllocatorStats* stats) override; void ClearStats() override; private: - VisitableAllocator* base_allocator_ = nullptr; // owned + Allocator* base_allocator_ = nullptr; // owned se::StreamExecutor* stream_exec_; // Not owned. 
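With the visitor hooks removed from the wrapper classes, a pass-through wrapper needs only the plain Allocator surface; visitor registration now happens once, on the SubAllocator, at construction time. A minimal sketch of a wrapper in this style against the post-change API (the class name and the call counting are illustrative, not part of this commit):

#include <atomic>

#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Pass-through wrapper: counts AllocateRaw calls and delegates all real
// work, including size tracking, to the wrapped allocator.
class CountingAllocator : public Allocator {
 public:
  explicit CountingAllocator(Allocator* base) : base_(base) {}

  string Name() override { return "counting_" + base_->Name(); }
  void* AllocateRaw(size_t alignment, size_t num_bytes) override {
    num_allocs_.fetch_add(1);
    return base_->AllocateRaw(alignment, num_bytes);
  }
  void DeallocateRaw(void* ptr) override { base_->DeallocateRaw(ptr); }
  bool TracksAllocationSizes() override {
    return base_->TracksAllocationSizes();
  }
  size_t RequestedSize(const void* ptr) override {
    return base_->RequestedSize(ptr);
  }
  size_t AllocatedSize(const void* ptr) override {
    return base_->AllocatedSize(ptr);
  }
  int64 num_allocs() const { return num_allocs_.load(); }

 private:
  Allocator* base_ = nullptr;  // not owned
  std::atomic<int64> num_allocs_{0};
};

}  // namespace tensorflow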
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc index bc3e3a8c35..aca08a7e33 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc @@ -35,7 +35,10 @@ namespace { TEST(GPUDebugAllocatorTest, OverwriteDetection_None) { const PlatformGpuId platform_gpu_id(0); - GPUDebugAllocator a(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""), + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), platform_gpu_id); auto stream_exec = GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); @@ -60,7 +63,10 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Header) { EXPECT_DEATH( { const PlatformGpuId platform_gpu_id(0); - GPUDebugAllocator a(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""), + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), platform_gpu_id); auto stream_exec = GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); @@ -93,7 +99,10 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) { EXPECT_DEATH( { const PlatformGpuId platform_gpu_id(0); - GPUDebugAllocator a(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""), + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), platform_gpu_id); auto stream_exec = GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); @@ -123,7 +132,10 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) { TEST(GPUDebugAllocatorTest, ResetToNan) { const PlatformGpuId platform_gpu_id(0); - GPUNanResetAllocator a(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""), + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUNanResetAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), platform_gpu_id); auto stream_exec = GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(); @@ -165,8 +177,11 @@ TEST(GPUDebugAllocatorTest, ResetToNan) { TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) { const PlatformGpuId platform_gpu_id(0); // NaN reset must be the outer-most allocator. 
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); GPUNanResetAllocator a( - new GPUDebugAllocator(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""), + new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), platform_gpu_id), platform_gpu_id); auto stream_exec = @@ -208,15 +223,21 @@ TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) { TEST(GPUDebugAllocatorTest, TracksSizes) { const PlatformGpuId platform_gpu_id(0); - GPUDebugAllocator a(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""), + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); + GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), platform_gpu_id); EXPECT_EQ(true, a.TracksAllocationSizes()); } TEST(GPUDebugAllocatorTest, AllocatedVsRequested) { const PlatformGpuId platform_gpu_id(0); + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, false /*use_unified_memory*/, {}, {}); GPUNanResetAllocator a( - new GPUDebugAllocator(new GPUBFCAllocator(platform_gpu_id, 1 << 30, ""), + new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""), platform_gpu_id), platform_gpu_id); float* t1 = a.Allocate<float>(1); diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc index 4bf23bc017..cf3faf68ff 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc @@ -41,7 +41,6 @@ limitations under the License. #include "tensorflow/core/common_runtime/gpu/gpu_util.h" #include "tensorflow/core/common_runtime/gpu_device_context.h" #include "tensorflow/core/common_runtime/local_device.h" -#include "tensorflow/core/common_runtime/visitable_allocator.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_kernel.h" @@ -285,6 +284,38 @@ BaseGPUDevice::~BaseGPUDevice() { for (auto ctx : device_contexts_) ctx->Unref(); } +// This should be idempotent if already initialized. 
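+// Scratch space is now allocated lazily, on the first call to
+// ReinitializeGpuDevice() below, instead of eagerly in BaseGPUDevice::Init().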
+Status BaseGPUDevice::InitScratchBuffers() { + mutex_lock l(scratch_init_mutex_); + if (scratch_.size() < max_streams_) { + for (int i = 0; i < max_streams_; i++) { + DCHECK(streams_[i]); + if (scratch_.size() > i && scratch_[i]) continue; + size_t scratch_buffer_size = + Eigen::kCudaScratchSize + sizeof(unsigned int); + void* scratch_buffer = gpu_allocator_->AllocateRaw( + Allocator::kAllocatorAlignment, scratch_buffer_size); + if (scratch_buffer == nullptr) { + return errors::FailedPrecondition( + "Failed to allocate scratch buffer for device ", + tf_gpu_id_.value()); + } + se::DeviceMemory<char> mem( + se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size)); + + bool ok = executor_->SynchronousMemZero( + &mem, Eigen::kCudaScratchSize + sizeof(unsigned int)); + if (!ok) { + return errors::FailedPrecondition( + "Failed to memcopy into scratch buffer for device ", + tf_gpu_id_.value()); + } + scratch_.push_back(static_cast<char*>(scratch_buffer)); + } + } + return Status::OK(); +} + Status BaseGPUDevice::Init(const SessionOptions& options) { auto executor_status = GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id_); if (!executor_status.status().ok()) { @@ -303,27 +334,6 @@ Status BaseGPUDevice::Init(const SessionOptions& options) { for (int i = 0; i < max_streams_; i++) { streams_.push_back(StreamGroupFactory::Global().GetOrCreate( tf_gpu_id_, i, executor_, options.config.gpu_options())); - - size_t scratch_buffer_size = Eigen::kCudaScratchSize + sizeof(unsigned int); - void* scratch_buffer = gpu_allocator_->AllocateRaw( - Allocator::kAllocatorAlignment, scratch_buffer_size); - if (scratch_buffer == nullptr) { - return errors::FailedPrecondition( - "Failed to allocate scratch buffer for device ", tf_gpu_id_.value()); - } - scratch_.push_back(static_cast<char*>(scratch_buffer)); - - se::DeviceMemory<char> mem( - se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size)); - - bool ok = executor_->SynchronousMemZero( - &mem, Eigen::kCudaScratchSize + sizeof(unsigned int)); - if (!ok) { - return errors::FailedPrecondition( - "Failed to memcopy into scratch buffer for device ", - tf_gpu_id_.value()); - } - device_contexts_.push_back(new GPUDeviceContext( i, streams_.back()->compute, streams_.back()->host_to_device, streams_.back()->device_to_host, streams_.back()->device_to_device)); @@ -870,10 +880,11 @@ PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice() { return new ConcretePerOpGpuDevice(); } -void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context, - PerOpGpuDevice* device, - DeviceContext* dc, - Allocator* allocator) { +Status BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context, + PerOpGpuDevice* device, + DeviceContext* dc, + Allocator* allocator) { + TF_RETURN_IF_ERROR(InitScratchBuffers()); if (dc) { const GPUDeviceContext* gpu_dc = static_cast<GPUDeviceContext*>(dc); const int stream_id = gpu_dc->stream_id(); @@ -884,6 +895,7 @@ void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context, } else { ReinitializeDevice(context, device, 0, allocator); } + return Status::OK(); } Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr, diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h index 684cc0c1de..b25fe8645f 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_device.h +++ b/tensorflow/core/common_runtime/gpu/gpu_device.h @@ -86,8 +86,9 @@ class BaseGPUDevice : public LocalDevice { // The caller owns the returned device. 
PerOpGpuDevice* MakeGpuDevice() override; - void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, - DeviceContext* dc, Allocator* allocator) override; + Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, + DeviceContext* dc, + Allocator* allocator) override; // Returns the platform GPU id of this device within the native driver system; // e.g., for CUDA this is the ordinal of the GPU within the system. @@ -125,6 +126,7 @@ class BaseGPUDevice : public LocalDevice { class StreamGroupFactory; gtl::InlinedVector<StreamGroup*, 4> streams_; + mutex scratch_init_mutex_; gtl::InlinedVector<char*, 4> scratch_; std::vector<GPUDeviceContext*> device_contexts_; GpuDeviceInfo* gpu_device_info_ = nullptr; @@ -135,6 +137,9 @@ class BaseGPUDevice : public LocalDevice { std::unique_ptr<EventMgr> em_; std::unique_ptr<thread::ThreadPool> thread_pool_; + // Initialize scratch buffers used by Eigen. + Status InitScratchBuffers(); + void ReinitializeDevice(OpKernelContext* context, PerOpGpuDevice* device, int stream_id, Allocator* allocator); diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc index a5b46382f1..3e95374fda 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc @@ -76,12 +76,16 @@ GPUProcessState::GPUProcessState() : gpu_device_enabled_(false) { // This function is defined for debugging problems with the allocators. GPUProcessState::~GPUProcessState() { CHECK_EQ(this, instance_); - for (auto p : gpu_allocators_) { - delete p; - } instance_ = nullptr; } +int GPUProcessState::BusIdForGPU(TfGpuId tf_gpu_id) { + // Return the NUMA node associated with the GPU's StreamExecutor. + se::StreamExecutor* se = + GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie(); + return se->GetDeviceDescription().numa_node(); +} + Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options, TfGpuId tf_gpu_id, size_t total_bytes) { @@ -93,13 +97,10 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options, if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) { gpu_allocators_.resize(tf_gpu_id.value() + 1); - if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) - gpu_al_.resize(tf_gpu_id.value() + 1); } - if (gpu_allocators_[tf_gpu_id.value()] == nullptr) { - VisitableAllocator* gpu_allocator; - + AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()]; + if (allocator_parts.allocator.get() == nullptr) { // Validate allocator types.
if (!allocator_type.empty() && allocator_type != "BFC") { LOG(ERROR) << "Invalid allocator type: " << allocator_type; @@ -108,8 +109,18 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options, PlatformGpuId platform_gpu_id; TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id)); - gpu_allocator = - new GPUBFCAllocator(platform_gpu_id, total_bytes, options, + int bus_id = BusIdForGPU(tf_gpu_id); + while (bus_id >= gpu_visitors_.size()) { + gpu_visitors_.push_back({}); + } + GPUMemAllocator* sub_allocator = new GPUMemAllocator( + GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(), + platform_gpu_id, + (options.per_process_gpu_memory_fraction() > 1.0 || + options.experimental().use_unified_memory()), + gpu_visitors_[bus_id], {}); + Allocator* gpu_allocator = + new GPUBFCAllocator(sub_allocator, total_bytes, options, strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc")); // If true, checks for memory overwrites by writing @@ -124,34 +135,25 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options, gpu_allocator = new GPUcudaMallocAllocator(gpu_allocator, platform_gpu_id); } - gpu_allocators_[tf_gpu_id.value()] = gpu_allocator; - - // If there are any pending AllocVisitors for this bus, add - // them now. - se::StreamExecutor* se = - GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie(); - int bus_id = se->GetDeviceDescription().numa_node(); - if (bus_id >= 0 && bus_id < static_cast<int64>(gpu_visitors_.size())) { - for (const auto& v : gpu_visitors_[bus_id]) { - gpu_allocator->AddAllocVisitor(v); - } - } + + Allocator* recording_allocator = nullptr; if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) { ProcessState::MemDesc md; md.loc = ProcessState::MemDesc::GPU; md.dev_index = platform_gpu_id.value(); md.gpu_registered = false; md.nic_registered = true; - if (static_cast<int64>(gpu_al_.size()) <= tf_gpu_id.value()) { - gpu_al_.resize(tf_gpu_id.value() + 1); - } - gpu_al_[tf_gpu_id.value()] = new internal::RecordingAllocator( + recording_allocator = new internal::RecordingAllocator( &process_state_->mem_desc_map_, gpu_allocator, md, &mu_); } + allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator), sub_allocator, + std::unique_ptr<Allocator>(recording_allocator)}; + } + if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) { + return allocator_parts.recording_allocator.get(); + } else { + return allocator_parts.allocator.get(); } - if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) - return gpu_al_[tf_gpu_id.value()]; - return gpu_allocators_[tf_gpu_id.value()]; #else LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda."; return nullptr; @@ -173,11 +175,12 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) { tf_shared_lock lock(mu_); if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types && - static_cast<int>(cuda_al_.size()) > 0) { - return cuda_al_[0]; + !cuda_host_allocators_.empty() && + cuda_host_allocators_[0].recording_allocator != nullptr) { + return cuda_host_allocators_[0].recording_allocator.get(); } if (static_cast<int>(cuda_host_allocators_.size()) > numa_node) { - return cuda_host_allocators_[0]; + return cuda_host_allocators_[0].allocator.get(); } } @@ -191,7 +194,7 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) { // it knows is valid. 
se::StreamExecutor* se = nullptr; for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) { - if (gpu_allocators_[i] != nullptr) { + if (gpu_allocators_[i].allocator != nullptr) { se = GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie(); break; } @@ -200,6 +203,15 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) { CHECK_NE(nullptr, se); while (static_cast<int>(cuda_host_allocators_.size()) <= numa_node) { + while (cuda_host_alloc_visitors_.size() <= numa_node) { + cuda_host_alloc_visitors_.push_back({}); + } + while (cuda_host_free_visitors_.size() <= numa_node) { + cuda_host_free_visitors_.push_back({}); + } + SubAllocator* sub_allocator = new CUDAHostAllocator( + se, numa_node, cuda_host_alloc_visitors_[numa_node], + cuda_host_free_visitors_[numa_node]); // TODO(zheng-xq): evaluate whether 64GB by default is the best choice. int64 cuda_host_mem_limit_in_mb = -1; Status status = ReadInt64FromEnvVar("TF_CUDA_HOST_MEM_LIMIT_IN_MB", @@ -209,62 +221,92 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) { LOG(ERROR) << "GetCUDAHostAllocator: " << status.error_message(); } int64 cuda_host_mem_limit = cuda_host_mem_limit_in_mb * (1LL << 20); - VisitableAllocator* allocator = - new BFCAllocator(new CUDAHostAllocator(se), cuda_host_mem_limit, + Allocator* allocator = + new BFCAllocator(sub_allocator, cuda_host_mem_limit, true /*allow_growth*/, "cuda_host_bfc" /*name*/); - if (LogMemory::IsEnabled()) { + if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) { // Wrap the allocator to track allocation ids for better logging // at the cost of performance. - allocator = new TrackingVisitableAllocator(allocator, true); + allocator = new TrackingAllocator(allocator, true); } - cuda_host_allocators_.push_back(allocator); + cuda_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator), + sub_allocator, + std::unique_ptr<Allocator>(nullptr)}); + AllocatorParts& allocator_parts = cuda_host_allocators_.back(); if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) { ProcessState::MemDesc md; md.loc = ProcessState::MemDesc::CPU; md.dev_index = 0; md.gpu_registered = true; md.nic_registered = false; - cuda_al_.push_back(new internal::RecordingAllocator( - &process_state_->mem_desc_map_, cuda_host_allocators_.back(), md, - &mu_)); + allocator_parts.recording_allocator.reset( + new internal::RecordingAllocator(&process_state_->mem_desc_map_, + allocator_parts.allocator.get(), md, + &mu_)); } } - if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) - return cuda_al_[0]; - return cuda_host_allocators_[0]; + if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) { + return cuda_host_allocators_[0].recording_allocator.get(); + } else { + return cuda_host_allocators_[0].allocator.get(); + } } void GPUProcessState::AddGPUAllocVisitor(int bus_id, - const AllocVisitor& visitor) { - CHECK(process_state_); + const SubAllocator::Visitor& visitor) { #if GOOGLE_CUDA mutex_lock lock(mu_); - for (int i = 0; i < static_cast<int64>(gpu_allocators_.size()); ++i) { - se::StreamExecutor* se = - GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie(); - if (gpu_allocators_[i] && - (se->GetDeviceDescription().numa_node() + 1) == bus_id) { - gpu_allocators_[i]->AddAllocVisitor(visitor); - } - } + CHECK(gpu_allocators_.empty()) // Crash OK + << "AddGPUAllocVisitor must be called before " + "first call to GetGPUAllocator."; while (bus_id >= static_cast<int64>(gpu_visitors_.size())) { - 
gpu_visitors_.push_back(std::vector<AllocVisitor>()); + gpu_visitors_.push_back(std::vector<SubAllocator::Visitor>()); } gpu_visitors_[bus_id].push_back(visitor); #endif // GOOGLE_CUDA } +void GPUProcessState::AddCUDAHostAllocVisitor( + int numa_node, const SubAllocator::Visitor& visitor) { +#if GOOGLE_CUDA + mutex_lock lock(mu_); + CHECK(cuda_host_allocators_.empty()) // Crash OK + << "AddCUDAHostAllocVisitor must be called before " + "first call to GetCUDAHostAllocator."; + while (numa_node >= static_cast<int64>(cuda_host_alloc_visitors_.size())) { + cuda_host_alloc_visitors_.push_back(std::vector<SubAllocator::Visitor>()); + } + cuda_host_alloc_visitors_[numa_node].push_back(visitor); +#endif // GOOGLE_CUDA +} + +void GPUProcessState::AddCUDAHostFreeVisitor( + int numa_node, const SubAllocator::Visitor& visitor) { +#if GOOGLE_CUDA + mutex_lock lock(mu_); + CHECK(cuda_host_allocators_.empty()) // Crash OK + << "AddCUDAHostFreeVisitor must be called before " + "first call to GetCUDAHostAllocator."; + while (numa_node >= static_cast<int64>(cuda_host_free_visitors_.size())) { + cuda_host_free_visitors_.push_back(std::vector<SubAllocator::Visitor>()); + } + cuda_host_free_visitors_[numa_node].push_back(visitor); +#endif // GOOGLE_CUDA +} + void GPUProcessState::TestOnlyReset() { - process_state_->ProcessState::TestOnlyReset(); + if (process_state_) { + process_state_->ProcessState::TestOnlyReset(); + } { mutex_lock lock(mu_); gpu_device_enabled_ = false; + gpu_allocators_.clear(); gpu_visitors_.clear(); - gtl::STLDeleteElements(&gpu_allocators_); - gtl::STLDeleteElements(&cuda_host_allocators_); - gtl::STLDeleteElements(&gpu_al_); - gtl::STLDeleteElements(&cuda_al_); + cuda_host_allocators_.clear(); + cuda_host_alloc_visitors_.clear(); + cuda_host_free_visitors_.clear(); } } diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.h b/tensorflow/core/common_runtime/gpu/gpu_process_state.h index cb41c3c6bd..43e9a31660 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_process_state.h +++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.h @@ -32,7 +32,6 @@ limitations under the License. namespace tensorflow { class Allocator; -class VisitableAllocator; class PoolAllocator; // Singleton that manages per-process state when GPUs are present. @@ -72,18 +71,30 @@ class GPUProcessState { virtual Allocator* GetCUDAHostAllocator(int numa_node); - // Registers a function to be called once on every new Region - // allocated by every GPURegionAllocator proximate to the specified - // bus. The AllocVisitor is provided with a memory pointer and the - // size of the area it identifies. The pointer is not guaranteed to - // be valid after the call terminates. The intention is for this - // interface to be used for network device memory registration. - // "bus_id" is platform-specific. On many platforms it - // should be 0. On machines with multiple PCIe buses, it should be - // the index of one of the PCIe buses. If the bus_id is invalid, - // results are undefined. - typedef std::function<void(void*, size_t)> AllocVisitor; - virtual void AddGPUAllocVisitor(int bus_id, const AllocVisitor& visitor); + // Registers a Visitor to be invoked on new chunks of memory allocated by the + // SubAllocator of every GPU proximate to the specified bus. The AllocVisitor + // is provided with a memory pointer, a GPU id, and the size of the area it + // identifies. The pointer is not guaranteed to be valid after the call + // terminates. 
The intention is for this interface to be used for network + // device memory registration. "bus_id" is platform-specific. On many + // platforms it should be 0. On machines with multiple PCIe buses, it should + // be the index of one of the PCIe buses (maybe the NUMA node at which the + // PCIe is rooted). If the bus_id is invalid, results are undefined. + virtual void AddGPUAllocVisitor(int bus_id, + const SubAllocator::Visitor& visitor); + + // Registers a Visitor to be invoked on new chunks of memory allocated by + // the SubAllocator of the CUDAHostAllocator for the given numa_node. + virtual void AddCUDAHostAllocVisitor(int numa_node, + const SubAllocator::Visitor& visitor); + + // Registers a Visitor to be invoked on each chunk handed back for freeing to + // the SubAllocator of the CUDAHostAllocator for the given numa_node. + virtual void AddCUDAHostFreeVisitor(int numa_node, + const SubAllocator::Visitor& visitor); + + // Returns bus_id for the given GPU id. + virtual int BusIdForGPU(TfGpuId tf_gpu_id); protected: GPUProcessState(); @@ -103,16 +114,21 @@ class GPUProcessState { mutex mu_; - std::vector<VisitableAllocator*> gpu_allocators_ GUARDED_BY(mu_); - std::vector<std::vector<AllocVisitor>> gpu_visitors_ GUARDED_BY(mu_); - std::vector<Allocator*> cuda_host_allocators_ GUARDED_BY(mu_); + struct AllocatorParts { + std::unique_ptr<Allocator> allocator; + SubAllocator* sub_allocator; // owned by allocator + std::unique_ptr<Allocator> recording_allocator; + }; + std::vector<AllocatorParts> gpu_allocators_ GUARDED_BY(mu_); + std::vector<std::vector<SubAllocator::Visitor>> gpu_visitors_ GUARDED_BY(mu_); - virtual ~GPUProcessState(); + std::vector<AllocatorParts> cuda_host_allocators_ GUARDED_BY(mu_); + std::vector<std::vector<SubAllocator::Visitor>> cuda_host_alloc_visitors_ + GUARDED_BY(mu_); + std::vector<std::vector<SubAllocator::Visitor>> cuda_host_free_visitors_ + GUARDED_BY(mu_); - // Optional RecordingAllocators that wrap the corresponding - // Allocators for runtime attribute use analysis. 
- std::vector<Allocator*> gpu_al_ GUARDED_BY(mu_); - std::vector<Allocator*> cuda_al_ GUARDED_BY(mu_); + virtual ~GPUProcessState(); friend class GPUDeviceTest; }; diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc index 583bff2c07..6b2f6547b0 100644 --- a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc +++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc @@ -31,7 +31,8 @@ TEST(PoolAllocatorTest, ZeroSizeBuffers) { 2 /*pool_size_limit*/, false /*auto_resize*/, new CUDAHostAllocator( platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .ValueOrDie()), + .ValueOrDie(), + 0 /*numa_node*/, {}, {}), new NoopRounder, "pool"); EXPECT_EQ(nullptr, pool.AllocateRaw(4 /*alignment*/, 0 /*num_bytes*/)); @@ -49,7 +50,8 @@ TEST(PoolAllocatorTest, ZeroSizePool) { 0 /*pool_size_limit*/, false /*auto_resize*/, new CUDAHostAllocator( platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .ValueOrDie()), + .ValueOrDie(), + 0 /*numa_node*/, {}, {}), new NoopRounder, "pool"); EXPECT_EQ(0, pool.get_from_pool_count()); @@ -82,7 +84,8 @@ TEST(PoolAllocatorTest, Alignment) { 0 /*pool_size_limit*/, false /*auto_resize*/, new CUDAHostAllocator( platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .ValueOrDie()), + .ValueOrDie(), + 0 /*numa_node*/, {}, {}), new NoopRounder, "pool"); for (int i = 0; i < 16; ++i) { size_t alignment = 1 << i; @@ -97,8 +100,8 @@ TEST(PoolAllocatorTest, Alignment) { TEST(PoolAllocatorTest, AutoResize) { PoolAllocator pool(2 /*pool_size_limit*/, true /*auto_resize*/, - new BasicCPUAllocator(0 /*numa_node*/), new NoopRounder, - "pool"); + new BasicCPUAllocator(0 /*numa_node*/, {}, {}), + new NoopRounder, "pool"); // Alloc/dealloc 10 sizes just a few times, confirming pool size // stays at 2. @@ -123,14 +126,32 @@ TEST(PoolAllocatorTest, AutoResize) { } TEST(PoolAllocatorTest, CudaHostAllocator) { + int alloc_count = 0; + int64 alloc_size = 0; + SubAllocator::Visitor alloc_visitor = + [&alloc_count, &alloc_size](void* ptr, int numa_node, int64 size) { + ++alloc_count; + alloc_size += size; + }; + int free_count = 0; + int64 free_size = 0; + SubAllocator::Visitor free_visitor = + [&free_count, &free_size](void* ptr, int numa_node, int64 size) { + ++free_count; + free_size += size; + }; se::Platform* platform = se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie(); - PoolAllocator pool( - 2 /*pool_size_limit*/, false /*auto_resize*/, - new CUDAHostAllocator( - platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .ValueOrDie()), - new NoopRounder, "pool"); + CUDAHostAllocator* sub_allocator = new CUDAHostAllocator( + platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) + .ValueOrDie(), + 0 /*numa_node*/, {alloc_visitor}, {free_visitor}); + PoolAllocator pool(2 /*pool_size_limit*/, false /*auto_resize*/, + sub_allocator, new NoopRounder, "pool"); + EXPECT_EQ(0, alloc_count); + EXPECT_EQ(0, alloc_size); + EXPECT_EQ(0, free_count); + EXPECT_EQ(0, free_size); // Repeatedly Get a 16-byte value, confirming that there's only // one real allocation. @@ -138,6 +159,10 @@ TEST(PoolAllocatorTest, CudaHostAllocator) { EXPECT_EQ(0, pool.get_from_pool_count()); EXPECT_EQ(1, pool.allocated_count()); EXPECT_NE(nullptr, p1_16); + EXPECT_EQ(1, alloc_count); // Underlying suballoc of 16 bytes + // Each suballocation includes a 16B ChunkPrefix. 
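The kChunkPrefixSize constant that follows reflects PoolAllocator's chunk layout: every suballocation carries a small header ahead of the user pointer so DeallocateRaw can recover the chunk. A self-contained sketch of that bookkeeping, simplified from the real PrepareChunk/FindPrefix and assuming an LP64 platform for the 16-byte size:

#include <cassert>
#include <cstddef>
#include <cstdlib>

// Simplified chunk header; the real TensorFlow struct may differ in detail.
struct ChunkPrefix {
  std::size_t num_bytes;  // full size of the underlying suballocation
  void* chunk_ptr;        // start of the real chunk, for freeing
};

void* PrepareChunk(void* chunk, std::size_t user_bytes) {
  ChunkPrefix* prefix = static_cast<ChunkPrefix*>(chunk);
  prefix->num_bytes = user_bytes + sizeof(ChunkPrefix);
  prefix->chunk_ptr = chunk;
  return prefix + 1;  // the caller sees memory just past the prefix
}

ChunkPrefix* FindPrefix(void* user_ptr) {
  return static_cast<ChunkPrefix*>(user_ptr) - 1;
}

int main() {
  static_assert(sizeof(ChunkPrefix) == 16, "16B prefix, as in the test");
  void* chunk = std::malloc(16 + sizeof(ChunkPrefix));
  void* user = PrepareChunk(chunk, 16);
  // A 16-byte user request costs 16 + 16 prefix bytes from the suballocator,
  // which is exactly the "16 + (alloc_count * kChunkPrefixSize)" assertion.
  assert(FindPrefix(user)->num_bytes == 32);
  std::free(FindPrefix(user)->chunk_ptr);
  return 0;
}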
+ static const int kChunkPrefixSize = 16; + EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size); pool.DeallocateRaw(p1_16); // Pool contents {16} EXPECT_EQ(1, pool.put_count()); @@ -148,6 +173,9 @@ TEST(PoolAllocatorTest, CudaHostAllocator) { pool.DeallocateRaw(p2_16); // Put it back. // Pool contents {16} EXPECT_EQ(2, pool.put_count()); + EXPECT_EQ(1, alloc_count); // Underlying suballoc of 16 bytes + EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size); + EXPECT_EQ(0, free_count); // Get two more values of different sizes. void* p3_4 = pool.AllocateRaw(4, 4); @@ -160,6 +188,9 @@ TEST(PoolAllocatorTest, CudaHostAllocator) { void* p4_2 = pool.AllocateRaw(4, 2); // Get a third size buffer. EXPECT_NE(nullptr, p4_2); EXPECT_EQ(0, pool.evicted_count()); + EXPECT_EQ(3, alloc_count); + EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size); + EXPECT_EQ(0, free_count); // The pool is full: when we put back p4_2, the 16-byte buffer // should be evicted since it was least recently inserted. @@ -167,6 +198,10 @@ TEST(PoolAllocatorTest, CudaHostAllocator) { // Pool contents {2, 4} EXPECT_EQ(4, pool.put_count()); EXPECT_EQ(1, pool.evicted_count()); + EXPECT_EQ(3, alloc_count); + EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size); + EXPECT_EQ(1, free_count); + EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size); // Re-getting and putting size 2 or 4 should not alter pool size or // num-evicted. @@ -180,12 +215,20 @@ TEST(PoolAllocatorTest, CudaHostAllocator) { EXPECT_EQ(6, pool.put_count()); EXPECT_EQ(3, pool.allocated_count()); EXPECT_EQ(1, pool.evicted_count()); + EXPECT_EQ(3, alloc_count); + EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size); + EXPECT_EQ(1, free_count); + EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size); pool.Clear(); EXPECT_EQ(0, pool.get_from_pool_count()); EXPECT_EQ(0, pool.put_count()); EXPECT_EQ(0, pool.allocated_count()); EXPECT_EQ(0, pool.evicted_count()); + EXPECT_EQ(3, alloc_count); + EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size); + EXPECT_EQ(3, free_count); + EXPECT_EQ(16 + 4 + 2 + (free_count * kChunkPrefixSize), free_size); } TEST(PoolAllocatorTest, Pow2Rounder) { @@ -206,7 +249,8 @@ TEST(PoolAllocatorTest, Name) { 2 /*pool_size_limit*/, false /*auto_resize*/, new CUDAHostAllocator( platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0)) - .ValueOrDie()), + .ValueOrDie(), + 0 /*numa_node*/, {}, {}), new NoopRounder, "pool"); EXPECT_EQ("pool", pool.Name()); } diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc index 346befc255..4475fa979e 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.cc +++ b/tensorflow/core/common_runtime/graph_execution_state.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/graph_execution_state.h" #include <memory> +#include <set> #include <string> #include <unordered_set> #include <utility> @@ -560,6 +561,10 @@ Status GraphExecutionState::OptimizeGraph( grappler::GrapplerItem item; item.id = "tf_graph"; graph_->ToGraphDef(&item.graph); + // TODO(b/114748242): Add a unit test to test this bug fix. 
+ if (flib_def_) { + *item.graph.mutable_library() = flib_def_->ToProto(); + } item.fetch.insert(item.fetch.end(), options.callable_options.fetch().begin(), @@ -727,12 +732,50 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options, TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping( OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options)); + int64 collective_graph_key = options.collective_graph_key; + if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) { + // BuildGraphOptions does not specify a collective_graph_key. Check all + // nodes in the Graph and FunctionLibraryDefinition for collective ops and + // if found, initialize a collective_graph_key as a hash of the ordered set + // of instance keys. + std::set<int32> instance_key_set; + for (Node* node : optimized_graph->nodes()) { + if (node->IsCollective()) { + int32 instance_key; + TF_RETURN_IF_ERROR( + GetNodeAttr(node->attrs(), "instance_key", &instance_key)); + instance_key_set.emplace(instance_key); + } else { + const FunctionDef* fdef = optimized_flib->Find(node->def().op()); + if (fdef != nullptr) { + for (const NodeDef& ndef : fdef->node_def()) { + if (ndef.op() == "CollectiveReduce" || + ndef.op() == "CollectiveBcastSend" || + ndef.op() == "CollectiveBcastRecv") { + int32 instance_key; + TF_RETURN_IF_ERROR( + GetNodeAttr(ndef, "instance_key", &instance_key)); + instance_key_set.emplace(instance_key); + } + } + } + } + } + if (!instance_key_set.empty()) { + uint64 hash = 0x8774aa605c729c72ULL; + for (int32 instance_key : instance_key_set) { + hash = Hash64Combine(instance_key, hash); + } + collective_graph_key = hash; + } + } + // Copy the extracted graph in order to make its node ids dense, // since the local CostModel used to record its stats is sized by // the largest node id. std::unique_ptr<ClientGraph> dense_copy( new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types, - rewrite_metadata.fetch_types)); + rewrite_metadata.fetch_types, collective_graph_key)); CopyGraph(*optimized_graph, &dense_copy->graph); // TODO(vrv): We should check invariants of the graph here. diff --git a/tensorflow/core/common_runtime/graph_execution_state.h b/tensorflow/core/common_runtime/graph_execution_state.h index d44a24c87b..9cabe478a6 100644 --- a/tensorflow/core/common_runtime/graph_execution_state.h +++ b/tensorflow/core/common_runtime/graph_execution_state.h @@ -50,17 +50,20 @@ struct GraphExecutionStateOptions { // BuildGraphOptions. struct ClientGraph { explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib, - DataTypeVector feed_types, DataTypeVector fetch_types) + DataTypeVector feed_types, DataTypeVector fetch_types, + int64 collective_graph_key) : flib_def(std::move(flib)), graph(flib_def.get()), feed_types(std::move(feed_types)), - fetch_types(std::move(fetch_types)) {} + fetch_types(std::move(fetch_types)), + collective_graph_key(collective_graph_key) {} // Each client-graph gets its own function library since optimization passes // post rewrite for execution might want to introduce new functions. 
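The collective_graph_key logic above derives a stable key by hashing the ordered set of collective instance keys, so two clients that build the same collectives agree on the key. A self-contained sketch of the idea; Hash64Combine's real mixing function is not reproduced here, so the combiner below is a stand-in:

#include <cstdint>
#include <set>

// Stand-in for tensorflow::Hash64Combine; only the shape of the fold matters.
uint64_t HashCombine(uint64_t value, uint64_t seed) {
  return seed ^ (value + 0x9e3779b97f4a7c15ULL + (seed << 6) + (seed >> 2));
}

// std::set iterates in sorted order, so the fold below is independent of the
// order in which collective nodes were discovered in the graph.
uint64_t CollectiveGraphKey(const std::set<int32_t>& instance_keys) {
  uint64_t hash = 0x8774aa605c729c72ULL;  // same seed as the change above
  for (int32_t key : instance_keys) {
    hash = HashCombine(static_cast<uint64_t>(key), hash);
  }
  return hash;
}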
std::unique_ptr<FunctionLibraryDefinition> flib_def; Graph graph; DataTypeVector feed_types; DataTypeVector fetch_types; + int64 collective_graph_key; }; // GraphExecutionState is responsible for generating an diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc index 0a1797fa19..f9aef3af70 100644 --- a/tensorflow/core/common_runtime/graph_runner.cc +++ b/tensorflow/core/common_runtime/graph_runner.cc @@ -56,7 +56,7 @@ class SimpleRendezvous : public Rendezvous { } mutex_lock l(mu_); - string edge_name = std::string(parsed.edge_name); + string edge_name(parsed.edge_name); if (table_.count(edge_name) > 0) { return errors::Internal("Send of an already sent tensor"); } @@ -69,7 +69,7 @@ class SimpleRendezvous : public Rendezvous { Tensor tensor; Status status = Status::OK(); { - string key = std::string(parsed.edge_name); + string key(parsed.edge_name); mutex_lock l(mu_); if (table_.count(key) <= 0) { status = errors::Internal("Did not find key ", key); diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index 6b76e7e0e7..538a70668a 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -23,10 +23,11 @@ limitations under the License. #include <cstdlib> #include "tensorflow/core/common_runtime/bfc_allocator.h" -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/common_runtime/pool_allocator.h" #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/mem.h" +#include "tensorflow/core/platform/numa.h" #ifndef INTEL_MKL_DNN_ONLY #include "i_malloc.h" @@ -38,19 +39,113 @@ typedef unsigned int uint; namespace tensorflow { -class MklSubAllocator : public SubAllocator { +class MklSubAllocator : public BasicCPUAllocator { public: + MklSubAllocator() : BasicCPUAllocator(port::kNUMANoAffinity, {}, {}) {} ~MklSubAllocator() override {} +}; + +// CPU allocator that handles small-size allocations by calling +// suballocator directly. Mostly, it is just a wrapper around a suballocator +// (that calls malloc and free directly) with support for bookkeeping. +class MklSmallSizeAllocator : public Allocator { + public: + MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory, + const string& name) + : sub_allocator_(sub_allocator), name_(name) { + stats_.bytes_limit = total_memory; + } + ~MklSmallSizeAllocator() override {} + + TF_DISALLOW_COPY_AND_ASSIGN(MklSmallSizeAllocator); + + inline string Name() override { return name_; } + + void* AllocateRaw(size_t alignment, size_t num_bytes) override { + void* ptr = sub_allocator_->Alloc(alignment, num_bytes); + if (ptr != nullptr) { + std::pair<void*, size_t> map_val(ptr, num_bytes); + mutex_lock l(mutex_); + // Check that insertion in the hash map was successful. + CHECK(map_.insert(map_val).second); + // Increment statistics for small-size allocations. + IncrementStats(num_bytes); + } + return ptr; + } + + void DeallocateRaw(void* ptr) override { + if (ptr == nullptr) { + LOG(ERROR) << "tried to deallocate nullptr"; + return; + } + + mutex_lock l(mutex_); + auto map_iter = map_.find(ptr); + if (map_iter != map_.end()) { + // Call free visitors. 
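// (The visitor calls happen inside sub_allocator_->Free(): per the
// SubAllocator changes elsewhere in this diff, BasicCPUAllocator::Free
// invokes VisitFree on any registered free visitors before releasing the
// memory, so this wrapper only needs to forward the pointer and size.)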
+ size_t dealloc_bytes = map_iter->second; + sub_allocator_->Free(ptr, dealloc_bytes); + DecrementStats(dealloc_bytes); + map_.erase(map_iter); + } else { + LOG(ERROR) << "tried to deallocate invalid pointer"; + return; + } + } + + inline bool IsSmallSizeAllocation(const void* ptr) const { + mutex_lock l(mutex_); + return map_.find(ptr) != map_.end(); + } + + void GetStats(AllocatorStats* stats) override { + mutex_lock l(mutex_); + *stats = stats_; + } - void* Alloc(size_t alignment, size_t num_bytes) override { - return port::AlignedMalloc(num_bytes, alignment); + void ClearStats() override { + mutex_lock l(mutex_); + stats_.Clear(); } - void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); } + + private: + // Increment statistics for the allocator handling small allocations. + inline void IncrementStats(size_t alloc_size) + EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + ++stats_.num_allocs; + stats_.bytes_in_use += alloc_size; + stats_.max_bytes_in_use = + std::max(stats_.max_bytes_in_use, stats_.bytes_in_use); + stats_.max_alloc_size = + std::max(alloc_size, static_cast<size_t>(stats_.max_alloc_size)); + } + + // Decrement statistics for the allocator handling small allocations. + inline void DecrementStats(size_t dealloc_size) + EXCLUSIVE_LOCKS_REQUIRED(mutex_) { + stats_.bytes_in_use -= dealloc_size; + } + + SubAllocator* sub_allocator_; // Not owned by this class. + + // Mutex for protecting updates to map of allocations. + mutable mutex mutex_; + + // Allocator name + string name_; + + // Hash map to keep track of "small" allocations + // We do not use BFC allocator for small allocations. + std::unordered_map<const void*, size_t> map_ GUARDED_BY(mutex_); + + // Allocator stats for small allocs + AllocatorStats stats_ GUARDED_BY(mutex_); }; /// CPU allocator for MKL that wraps BFC allocator and intercepts /// and redirects memory allocation calls from MKL. -class MklCPUAllocator : public VisitableAllocator { +class MklCPUAllocator : public Allocator { public: // Constructor and other standard functions @@ -62,7 +157,10 @@ class MklCPUAllocator : public VisitableAllocator { MklCPUAllocator() { TF_CHECK_OK(Initialize()); } - ~MklCPUAllocator() override { delete allocator_; } + ~MklCPUAllocator() override { + delete small_size_allocator_; + delete large_size_allocator_; + } Status Initialize() { VLOG(2) << "MklCPUAllocator: In MklCPUAllocator"; @@ -96,8 +194,15 @@ class MklCPUAllocator : public VisitableAllocator { } VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes; - allocator_ = new BFCAllocator(new MklSubAllocator, max_mem_bytes, - kAllowGrowth, kName); + + sub_allocator_ = new MklSubAllocator(); + + // SubAllocator is owned by BFCAllocator, so we do not need to deallocate + // it in MklSmallSizeAllocator. + small_size_allocator_ = + new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName); + large_size_allocator_ = + new BFCAllocator(sub_allocator_, max_mem_bytes, kAllowGrowth, kName); #ifndef INTEL_MKL_DNN_ONLY // For redirecting all allocations from MKL to this allocator // From: http://software.intel.com/en-us/node/528565 @@ -112,23 +217,45 @@ class MklCPUAllocator : public VisitableAllocator { inline string Name() override { return kName; } inline void* AllocateRaw(size_t alignment, size_t num_bytes) override { - return allocator_->AllocateRaw(alignment, num_bytes); + // If the allocation size is less than threshold, call small allocator, + // otherwise call large-size allocator (BFC). 
We found that BFC allocator + // does not deliver good performance for small allocations when + // inter_op_parallelism_threads is high. + return (num_bytes < kSmallAllocationsThreshold) + ? small_size_allocator_->AllocateRaw(alignment, num_bytes) + : large_size_allocator_->AllocateRaw(alignment, num_bytes); } inline void DeallocateRaw(void* ptr) override { - allocator_->DeallocateRaw(ptr); + // Check if ptr is for "small" allocation. If it is, then call Free + // directly. Otherwise, call BFC to handle free. + if (small_size_allocator_->IsSmallSizeAllocation(ptr)) { + small_size_allocator_->DeallocateRaw(ptr); + } else { + large_size_allocator_->DeallocateRaw(ptr); + } } - void GetStats(AllocatorStats* stats) override { allocator_->GetStats(stats); } - - void ClearStats() override { allocator_->ClearStats(); } - - void AddAllocVisitor(Visitor visitor) override { - allocator_->AddAllocVisitor(visitor); + void GetStats(AllocatorStats* stats) override { + AllocatorStats l_stats, s_stats; + small_size_allocator_->GetStats(&s_stats); + large_size_allocator_->GetStats(&l_stats); + + // Combine statistics from small-size and large-size allocator. + stats->num_allocs = l_stats.num_allocs + s_stats.num_allocs; + stats->bytes_in_use = l_stats.bytes_in_use + s_stats.bytes_in_use; + stats->max_bytes_in_use = + l_stats.max_bytes_in_use + s_stats.max_bytes_in_use; + + // Since small-size allocations go to MklSmallSizeAllocator, + // max_alloc_size from large_size_allocator would be the maximum + // size allocated by MklCPUAllocator. + stats->max_alloc_size = l_stats.max_alloc_size; } - void AddFreeVisitor(Visitor visitor) override { - allocator_->AddFreeVisitor(visitor); + void ClearStats() override { + small_size_allocator_->ClearStats(); + large_size_allocator_->ClearStats(); } private: @@ -148,26 +275,33 @@ class MklCPUAllocator : public VisitableAllocator { Status s = Status(error::Code::UNIMPLEMENTED, "Unimplemented case for hooking MKL function."); TF_CHECK_OK(s); // way to assert with an error message - return nullptr; // return a value and make static code analyzers happy + return nullptr; // return a value and make static code analyzers happy } static inline void* ReallocHook(void* ptr, size_t size) { Status s = Status(error::Code::UNIMPLEMENTED, "Unimplemented case for hooking MKL function."); TF_CHECK_OK(s); // way to assert with an error message - return nullptr; // return a value and make static code analyzers happy + return nullptr; // return a value and make static code analyzers happy } - /// Do we allow growth in BFC Allocator + // Do we allow growth in BFC Allocator static const bool kAllowGrowth = true; - /// Name + // Name static constexpr const char* kName = "mklcpu"; - /// The alignment that we need for the allocations + // The alignment that we need for the allocations static constexpr const size_t kAlignment = 64; - VisitableAllocator* allocator_; // owned by this class + Allocator* large_size_allocator_; // owned by this class + MklSmallSizeAllocator* small_size_allocator_; // owned by this class. + + SubAllocator* sub_allocator_; // not owned by this class + + // Size in bytes that defines the upper-bound for "small" allocations. + // Any allocation below this threshold is "small" allocation. 
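A self-contained sketch of the dispatch pattern just described, with malloc standing in for both backends; the map plays the same role as IsSmallSizeAllocation in routing frees back to the right tier:

#include <cstddef>
#include <cstdlib>
#include <mutex>
#include <unordered_map>

// Sketch of the small/large dispatch only; malloc stands in for both the
// small-size path and the BFC path.
class TwoTierAllocator {
 public:
  static constexpr std::size_t kSmallThreshold = 4096;  // from the change

  void* Allocate(std::size_t num_bytes) {
    void* ptr = std::malloc(num_bytes);
    if (ptr != nullptr && num_bytes < kSmallThreshold) {
      std::lock_guard<std::mutex> lock(mu_);
      small_allocs_[ptr] = num_bytes;  // bookkeeping lets Deallocate route
    }
    return ptr;
  }

  void Deallocate(void* ptr) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      auto it = small_allocs_.find(ptr);
      if (it != small_allocs_.end()) {
        small_allocs_.erase(it);  // small path: free directly
        std::free(ptr);
        return;
      }
    }
    std::free(ptr);  // large path: would go back to BFC instead
  }

 private:
  std::mutex mu_;
  std::unordered_map<void*, std::size_t> small_allocs_;
};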
+ static constexpr const size_t kSmallAllocationsThreshold = 4096; // Prevent copying and assignment TF_DISALLOW_COPY_AND_ASSIGN(MklCPUAllocator); diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc index d581f45a90..3b59995433 100644 --- a/tensorflow/core/common_runtime/placer.cc +++ b/tensorflow/core/common_runtime/placer.cc @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/util/status_util.h" namespace tensorflow { @@ -255,9 +254,11 @@ class ColocationGraph { old_root_member.device_name, allow_soft_placement_); if (!s.ok()) { - return errors::InvalidArgument("Cannot colocate nodes '", x.name(), - "' and '", y.name(), ": ", - s.error_message()); + return errors::InvalidArgument( + "Cannot colocate nodes ", + errors::FormatColocationNodeForError(x.name()), " and ", + errors::FormatColocationNodeForError(y.name()), ": ", + s.error_message()); } // Ensure that the common root has at least one supported device @@ -268,8 +269,10 @@ class ColocationGraph { old_root_member.supported_device_types); if (new_root_member.supported_device_types.empty()) { return errors::InvalidArgument( - "Cannot colocate nodes '", x.name(), "' and '", y.name(), - "' because no device type supports both of those nodes and the " + "Cannot colocate nodes ", + errors::FormatColocationNodeForError(x.name()), " and ", + errors::FormatColocationNodeForError(y.name()), + " because no device type supports both of those nodes and the " "other nodes colocated with them.", DebugInfo(x_root), DebugInfo(y_root)); } @@ -377,8 +380,9 @@ class ColocationGraph { // merged set device is different, so print both. return errors::InvalidArgument( "Could not satisfy explicit device specification '", - node->requested_device(), - "' because the node was colocated with a group of nodes that " + node->requested_device(), "' because the node ", + errors::FormatColocationNodeForError(node->name()), + " was colocated with a group of nodes that ", "required incompatible device '", DeviceNameUtils::ParsedNameToString( members_[node_root].device_name), @@ -810,10 +814,10 @@ Status Placer::Run() { std::vector<Device*>* devices; Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { - return AttachDef(errors::InvalidArgument( - "Cannot assign a device for operation ", - RichNodeName(node), ": ", status.error_message()), - *node); + return AttachDef( + errors::InvalidArgument("Cannot assign a device for operation ", + node->name(), ": ", status.error_message()), + *node); } // Returns the first device in sorted devices list so we will always @@ -857,10 +861,10 @@ Status Placer::Run() { std::vector<Device*>* devices; Status status = colocation_graph.GetDevicesForNode(node, &devices); if (!status.ok()) { - return AttachDef(errors::InvalidArgument( - "Cannot assign a device for operation ", - RichNodeName(node), ": ", status.error_message()), - *node); + return AttachDef( + errors::InvalidArgument("Cannot assign a device for operation ", + node->name(), ": ", status.error_message()), + *node); } int assigned_device = -1; @@ -926,22 +930,4 @@ void Placer::LogDeviceAssignment(const Node* node) const { } } -bool Placer::ClientHandlesErrorFormatting() const { - return options_ != nullptr && - options_->config.experimental().client_handles_error_formatting(); -} - -// Returns the node name in single quotes. 
If the client handles formatted -// errors, appends a formatting tag which the client will reformat into, for -// example, " (defined at filename:123)". -string Placer::RichNodeName(const Node* node) const { - string quoted_name = strings::StrCat("'", node->name(), "'"); - if (ClientHandlesErrorFormatting()) { - string file_and_line = error_format_tag(*node, "${defined_at}"); - return strings::StrCat(quoted_name, file_and_line); - } else { - return quoted_name; - } -} - } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h index cefcdd25db..f97ffe7372 100644 --- a/tensorflow/core/common_runtime/placer.h +++ b/tensorflow/core/common_runtime/placer.h @@ -87,8 +87,6 @@ class Placer { // placement if the SessionOptions entry in 'options_' requests it. void AssignAndLog(int assigned_device, Node* node) const; void LogDeviceAssignment(const Node* node) const; - bool ClientHandlesErrorFormatting() const; - string RichNodeName(const Node* node) const; Graph* const graph_; // Not owned. const DeviceSet* const devices_; // Not owned. diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc index 87f2f2ceb9..9b8a95e3b6 100644 --- a/tensorflow/core/common_runtime/placer_test.cc +++ b/tensorflow/core/common_runtime/placer_test.cc @@ -800,11 +800,11 @@ TEST_F(PlacerTest, TestInvalidMultipleColocationGroups) { } Status s = Place(&g); - EXPECT_TRUE( - str_util::StrContains(s.error_message(), - "Cannot colocate nodes 'foo' and 'in' because no " - "device type supports both of those nodes and the " - "other nodes colocated with them")); + EXPECT_TRUE(str_util::StrContains( + s.error_message(), + "Cannot colocate nodes {{colocation_node foo}} and " + "{{colocation_node in}} because no device type supports both of those " + "nodes and the other nodes colocated with them")); } TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) { @@ -867,9 +867,9 @@ TEST_F(PlacerTest, TestColocationGroupWithUnsatisfiableReferenceConnections) { Status s = Place(&g); EXPECT_TRUE(str_util::StrContains( s.error_message(), - "Cannot colocate nodes 'var3' and 'assign3' because no " - "device type supports both of those nodes and the other " - "nodes colocated with them.")); + "Cannot colocate nodes {{colocation_node var3}} and {{colocation_node " + "assign3}} because no device type supports both of those nodes and the " + "other nodes colocated with them.")); } TEST_F(PlacerTest, TestColocationAndReferenceConnections) { @@ -1154,36 +1154,12 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) { } SessionOptions options; - options.config.mutable_experimental()->set_client_handles_error_formatting( - true); Status s = Place(&g, &options); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); LOG(WARNING) << s.error_message(); EXPECT_TRUE(str_util::StrContains(s.error_message(), - "Cannot assign a device for operation 'in'" - "^^node:in:${defined_at}^^")); -} - -// Test that the "Cannot assign a device" error message does not contain a -// format tag when not it shouldn't -TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementNoFormatTag) { - Graph g(OpRegistry::Global()); - { // Scope for temporary variables used to construct g. 
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately); - ops::SourceOp("TestDevice", - b.opts().WithName("in").WithDevice("/device:fakegpu:11")); - TF_EXPECT_OK(BuildGraph(b, &g)); - } - - SessionOptions options; - options.config.mutable_experimental()->set_client_handles_error_formatting( - false); - Status s = Place(&g, &options); - EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(str_util::StrContains( - s.error_message(), "Cannot assign a device for operation 'in'")); - EXPECT_FALSE(str_util::StrContains( - s.error_message(), "'in' (defined at ^^node:in:${file}:${line}^^)")); + "Cannot assign a device for operation in")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), "{{node in}}")); } // Test that placement fails when a node requests an explicit device that is not @@ -1289,8 +1265,9 @@ TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) { Status s = Place(&g); EXPECT_EQ(error::INVALID_ARGUMENT, s.code()); - EXPECT_TRUE(str_util::StrContains( - s.error_message(), "Cannot colocate nodes 'var' and 'assign'")); + EXPECT_TRUE(str_util::StrContains(s.error_message(), + "Cannot colocate nodes {{colocation_node " + "var}} and {{colocation_node assign}}")); } // Test that a generator node follows its consumers (where there are several diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc index 10a24ed14c..66dc8f3322 100644 --- a/tensorflow/core/common_runtime/pool_allocator.cc +++ b/tensorflow/core/common_runtime/pool_allocator.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/numbers.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -39,8 +40,7 @@ PoolAllocator::PoolAllocator(size_t pool_size_limit, bool auto_resize, auto_resize_(auto_resize), pool_size_limit_(pool_size_limit), allocator_(allocator), - size_rounder_(size_rounder), - allocation_begun_(false) { + size_rounder_(size_rounder) { if (auto_resize) { CHECK_LT(size_t{0}, pool_size_limit) << "size limit must be > 0 if auto_resize is true."; @@ -92,7 +92,6 @@ ChunkPrefix* FindPrefix(void* user_ptr) { } // namespace void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { - if (!allocation_begun_) allocation_begun_ = true; if (num_bytes == 0) return nullptr; // If alignment is larger than kPoolAlignment, increase num_bytes so that we @@ -128,9 +127,6 @@ void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) { return PrepareChunk(r, alignment, num_bytes); } else { void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes); - for (const auto& v : alloc_visitors_) { - v(ptr, num_bytes); - } return PrepareChunk(ptr, alignment, num_bytes); } } @@ -140,9 +136,6 @@ void PoolAllocator::DeallocateRaw(void* ptr) { ChunkPrefix* cp = FindPrefix(ptr); CHECK_LE((void*)cp, (void*)ptr); if (!has_size_limit_ && !auto_resize_) { - for (const auto& v : free_visitors_) { - v(cp, cp->num_bytes); - } allocator_->Free(cp, cp->num_bytes); } else { mutex_lock lock(mutex_); @@ -163,9 +156,6 @@ void PoolAllocator::Clear() { mutex_lock lock(mutex_); for (auto iter : pool_) { PtrRecord* pr = iter.second; - for (const auto& v : free_visitors_) { - v(pr->ptr, pr->num_bytes); - } allocator_->Free(pr->ptr, pr->num_bytes); delete pr; } @@ -220,9 +210,6 @@ void PoolAllocator::EvictOne() { DCHECK(iter != pool_.end()); } pool_.erase(iter); - for (const auto& v : free_visitors_) { 
- v(prec->ptr, prec->num_bytes); - } allocator_->Free(prec->ptr, prec->num_bytes); delete prec; ++evicted_count_; @@ -268,28 +255,19 @@ void PoolAllocator::EvictOne() { } } -void PoolAllocator::AddAllocVisitor(Visitor visitor) { - mutex_lock lock(mutex_); - CHECK(!allocation_begun_) - << "AddAllocVisitor may not be called after pool allocation " - << "has begun."; - alloc_visitors_.push_back(visitor); -} - -void PoolAllocator::AddFreeVisitor(Visitor visitor) { - mutex_lock lock(mutex_); - CHECK(!allocation_begun_) - << "AddFreeVisitor may not be called after pool allocation " - << "has begun."; - free_visitors_.push_back(visitor); -} - void* BasicCPUAllocator::Alloc(size_t alignment, size_t num_bytes) { - return port::AlignedMalloc(num_bytes, static_cast<int>(alignment)); + void* ptr = nullptr; + if (num_bytes > 0) { + ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment)); + VisitAlloc(ptr, numa_node_, num_bytes); + } + return ptr; } void BasicCPUAllocator::Free(void* ptr, size_t num_bytes) { - port::AlignedFree(ptr); + if (num_bytes > 0) { + VisitFree(ptr, numa_node_, num_bytes); + port::AlignedFree(ptr); + } } - } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/pool_allocator.h b/tensorflow/core/common_runtime/pool_allocator.h index 607734445b..5b4623ba10 100644 --- a/tensorflow/core/common_runtime/pool_allocator.h +++ b/tensorflow/core/common_runtime/pool_allocator.h @@ -16,14 +16,13 @@ limitations under the License. #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_ #define TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_ -// Simple LRU pool allocators for various flavors of CPU RAM that -// implement the VisitableAllocator interface. +// Simple LRU pool allocators for various flavors of CPU RAM. #include <atomic> #include <map> #include <memory> #include <vector> -#include "tensorflow/core/common_runtime/visitable_allocator.h" +#include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -41,7 +40,7 @@ class RoundUpInterface { // Size-limited pool of memory buffers obtained from a SubAllocator // instance. Pool eviction policy is LRU. -class PoolAllocator : public VisitableAllocator { +class PoolAllocator : public Allocator { public: // "pool_size_limit" is the maximum number of returned, re-usable // memory buffers to keep in the pool. If pool_size_limit == 0, the @@ -64,14 +63,6 @@ class PoolAllocator : public VisitableAllocator { void DeallocateRaw(void* ptr) override; - // REQUIRES: The following functions may only be called prior - // to the first Allocate*() call. Once allocation has begun, it is - // illegal to register another visitor. - - void AddAllocVisitor(Visitor visitor) override; - - void AddFreeVisitor(Visitor visitor) override; - // Allocate an unused memory region of size "num_bytes". Fetch from // the pool if available, otherwise call allocator_. void* Get(size_t num_bytes); @@ -141,12 +132,6 @@ class PoolAllocator : public VisitableAllocator { int64 put_count_ GUARDED_BY(mutex_) = 0; int64 allocated_count_ GUARDED_BY(mutex_) = 0; int64 evicted_count_ GUARDED_BY(mutex_) = 0; - // Write access to these is guarded by mutex_, but not read - // access. They may only be modified prior to the first - // allocation. Later attempts to modify will fail. - std::vector<Visitor> alloc_visitors_; - std::vector<Visitor> free_visitors_; - std::atomic<bool> allocation_begun_; }; // Do-nothing rounder. 
Passes through sizes unchanged. @@ -166,7 +151,9 @@ class Pow2Rounder : public RoundUpInterface { class BasicCPUAllocator : public SubAllocator { public: // Argument numa_node is currently ignored. - explicit BasicCPUAllocator(int numa_node) : numa_node_(numa_node) {} + BasicCPUAllocator(int numa_node, const std::vector<Visitor>& alloc_visitors, + const std::vector<Visitor>& free_visitors) + : SubAllocator(alloc_visitors, free_visitors), numa_node_(numa_node) {} ~BasicCPUAllocator() override {} @@ -176,6 +163,8 @@ class BasicCPUAllocator : public SubAllocator { private: int numa_node_; + + TF_DISALLOW_COPY_AND_ASSIGN(BasicCPUAllocator); }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc index 447338e7bd..bcaa37fc8a 100644 --- a/tensorflow/core/common_runtime/process_state.cc +++ b/tensorflow/core/common_runtime/process_state.cc @@ -71,20 +71,28 @@ ProcessState::MemDesc ProcessState::PtrType(const void* ptr) { return MemDesc(); } -VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) { +Allocator* ProcessState::GetCPUAllocator(int numa_node) { CHECK_GE(numa_node, 0); if (!numa_enabled_) numa_node = 0; mutex_lock lock(mu_); while (cpu_allocators_.size() <= static_cast<size_t>(numa_node)) { + // If visitors have been defined we need an Allocator built from + // a SubAllocator. Prefer BFCAllocator, but fall back to PoolAllocator + // depending on env var setting. + const bool alloc_visitors_defined = + (!cpu_alloc_visitors_.empty() || !cpu_free_visitors_.empty()); bool use_bfc_allocator = false; - // TODO(reedwm): Switch default to BGFAllocator if it's at least as fast and - // efficient. - Status status = ReadBoolFromEnvVar("TF_CPU_ALLOCATOR_USE_BFC", false, - &use_bfc_allocator); + Status status = ReadBoolFromEnvVar( + "TF_CPU_ALLOCATOR_USE_BFC", alloc_visitors_defined, &use_bfc_allocator); if (!status.ok()) { LOG(ERROR) << "GetCPUAllocator: " << status.error_message(); } - VisitableAllocator* allocator; + Allocator* allocator = nullptr; + SubAllocator* sub_allocator = + (alloc_visitors_defined || use_bfc_allocator) + ? new BasicCPUAllocator(numa_enabled_ ? numa_node : -1, + cpu_alloc_visitors_, cpu_free_visitors_) + : nullptr; if (use_bfc_allocator) { // TODO(reedwm): evaluate whether 64GB by default is the best choice. int64 cpu_mem_limit_in_mb = -1; @@ -95,34 +103,63 @@ VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) { LOG(ERROR) << "GetCPUAllocator: " << status.error_message(); } int64 cpu_mem_limit = cpu_mem_limit_in_mb * (1LL << 20); - allocator = new BFCAllocator( - new BasicCPUAllocator(numa_enabled_ ? numa_node : -1), cpu_mem_limit, - true /*allow_growth*/, "bfc_cpu_allocator_for_gpu" /*name*/); + DCHECK(sub_allocator); + allocator = + new BFCAllocator(sub_allocator, cpu_mem_limit, true /*allow_growth*/, + "bfc_cpu_allocator_for_gpu" /*name*/); VLOG(2) << "Using BFCAllocator with memory limit of " << cpu_mem_limit_in_mb << " MB for ProcessState CPU allocator"; - } else { - allocator = new PoolAllocator( - 100 /*pool_size_limit*/, true /*auto_resize*/, - new BasicCPUAllocator(numa_enabled_ ? 
numa_node : -1), - new NoopRounder, "cpu_pool"); + } else if (alloc_visitors_defined) { + DCHECK(sub_allocator); + allocator = + new PoolAllocator(100 /*pool_size_limit*/, true /*auto_resize*/, + sub_allocator, new NoopRounder, "cpu_pool"); VLOG(2) << "Using PoolAllocator for ProcessState CPU allocator " << "numa_enabled_=" << numa_enabled_ << " numa_node=" << numa_node; + } else { + DCHECK(!sub_allocator); + allocator = cpu_allocator(); } - if (LogMemory::IsEnabled()) { + if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) { // Wrap the allocator to track allocation ids for better logging // at the cost of performance. - allocator = new TrackingVisitableAllocator(allocator, true); + allocator = new TrackingAllocator(allocator, true); } cpu_allocators_.push_back(allocator); + if (!sub_allocator) { + DCHECK(cpu_alloc_visitors_.empty() && cpu_free_visitors_.empty()); + } } return cpu_allocators_[numa_node]; } +void ProcessState::AddCPUAllocVisitor(SubAllocator::Visitor visitor) { + VLOG(1) << "AddCPUAllocVisitor"; + mutex_lock lock(mu_); + CHECK_EQ(0, cpu_allocators_.size()) // Crash OK + << "AddCPUAllocVisitor must be called prior to first call to " + "ProcessState::GetCPUAllocator"; + cpu_alloc_visitors_.push_back(std::move(visitor)); +} + +void ProcessState::AddCPUFreeVisitor(SubAllocator::Visitor visitor) { + mutex_lock lock(mu_); + CHECK_EQ(0, cpu_allocators_.size()) // Crash OK + << "AddCPUFreeVisitor must be called prior to first call to " + "ProcessState::GetCPUAllocator"; + cpu_free_visitors_.push_back(std::move(visitor)); +} + void ProcessState::TestOnlyReset() { mutex_lock lock(mu_); + // Don't delete this value because it's static. + Allocator* default_cpu_allocator = cpu_allocator(); mem_desc_map_.clear(); - gtl::STLDeleteElements(&cpu_allocators_); + for (Allocator* a : cpu_allocators_) { + if (a != default_cpu_allocator) delete a; + } + cpu_allocators_.clear(); gtl::STLDeleteElements(&cpu_al_); } diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h index 2892677333..cac312d849 100644 --- a/tensorflow/core/common_runtime/process_state.h +++ b/tensorflow/core/common_runtime/process_state.h @@ -30,7 +30,6 @@ limitations under the License. namespace tensorflow { class Allocator; -class VisitableAllocator; class PoolAllocator; // Singleton that manages per-process state, e.g. allocation of @@ -65,7 +64,15 @@ class ProcessState { // Returns the one CPUAllocator used for the given numa_node. // TEMPORARY: ignores numa_node. - VisitableAllocator* GetCPUAllocator(int numa_node); + Allocator* GetCPUAllocator(int numa_node); + + // Registers alloc visitor for the CPU allocator(s). + // REQUIRES: must be called before GetCPUAllocator. + void AddCPUAllocVisitor(SubAllocator::Visitor v); + + // Registers free visitor for the CPU allocator(s). + // REQUIRES: must be called before GetCPUAllocator. 
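A compact, runnable restatement of the selection logic above; ReadBoolFromEnvVar accepts more spellings than this getenv sketch, and the enum is purely illustrative:

#include <cstdlib>
#include <cstring>

enum class CpuAllocatorKind { kDefault, kPool, kBfc };

CpuAllocatorKind ChooseCpuAllocator(bool visitors_defined) {
  // New behavior: once any alloc/free visitor is registered, BFC becomes the
  // default, because only suballocator-backed allocators can run visitors.
  bool use_bfc = visitors_defined;
  if (const char* env = std::getenv("TF_CPU_ALLOCATOR_USE_BFC")) {
    use_bfc = (std::strcmp(env, "true") == 0 || std::strcmp(env, "1") == 0);
  }
  if (use_bfc) return CpuAllocatorKind::kBfc;
  if (visitors_defined) return CpuAllocatorKind::kPool;
  return CpuAllocatorKind::kDefault;  // plain process-wide cpu_allocator()
}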
+ void AddCPUFreeVisitor(SubAllocator::Visitor v); typedef std::unordered_map<const void*, MemDesc> MDMap; @@ -87,7 +94,9 @@ class ProcessState { mutex mu_; - std::vector<VisitableAllocator*> cpu_allocators_ GUARDED_BY(mu_); + std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_); + std::vector<SubAllocator::Visitor> cpu_alloc_visitors_ GUARDED_BY(mu_); + std::vector<SubAllocator::Visitor> cpu_free_visitors_ GUARDED_BY(mu_); virtual ~ProcessState(); diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h index 103eee03b3..9d59264899 100644 --- a/tensorflow/core/common_runtime/renamed_device.h +++ b/tensorflow/core/common_runtime/renamed_device.h @@ -72,9 +72,10 @@ class RenamedDevice : public Device { return underlying_->MakeGpuDevice(); } - void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, - DeviceContext* dc, Allocator* allocator) override { - underlying_->ReinitializeGpuDevice(context, device, dc, allocator); + Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device, + DeviceContext* dc, + Allocator* allocator) override { + return underlying_->ReinitializeGpuDevice(context, device, dc, allocator); } Status MakeTensorFromProto(const TensorProto& tensor_proto, diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc index 1e3fed0d6f..43ca3f1e3e 100644 --- a/tensorflow/core/common_runtime/rendezvous_util.cc +++ b/tensorflow/core/common_runtime/rendezvous_util.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/core/common_runtime/rendezvous_util.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/util/reffed_status_callback.h" diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc index 65ff356e73..5b1915755d 100644 --- a/tensorflow/core/common_runtime/session_state.cc +++ b/tensorflow/core/common_runtime/session_state.cc @@ -70,7 +70,7 @@ Status TensorStore::SaveTensors(const std::vector<string>& output_names, // Save only the tensors in output_names in the session. for (const string& name : output_names) { TensorId id(ParseTensorName(name)); - const string& op_name = std::string(id.first); + const string op_name(id.first); auto it = tensors_.find(op_name); if (it != tensors_.end()) { // Save the tensor to the session state. diff --git a/tensorflow/core/common_runtime/single_threaded_cpu_device.h b/tensorflow/core/common_runtime/single_threaded_cpu_device.h index 04d5af9087..22650b0d83 100644 --- a/tensorflow/core/common_runtime/single_threaded_cpu_device.h +++ b/tensorflow/core/common_runtime/single_threaded_cpu_device.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/threadpool.h" namespace tensorflow { diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc index 9c2510e6a9..a70ab93d4a 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.cc +++ b/tensorflow/core/common_runtime/step_stats_collector.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace { @@ -40,46 +41,24 @@ struct AllocStats { }; } // namespace -NodeExecStatsWrapper::NodeExecStatsWrapper(const string& node_name) - : NodeExecStatsWrapper(new NodeExecStats) { - stats_->set_node_name(node_name); -} -NodeExecStatsWrapper::NodeExecStatsWrapper(NodeExecStats* stats) - : stats_(stats) {} - -void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* v) { - DCHECK(v); - NodeOutput* no = stats_->add_output(); - no->set_slot(slot); - v->FillDescription(no->mutable_tensor_description()); -} - -void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) { - for (const auto& allocator_pair : ctx->wrapped_allocators()) { - AddAllocation(allocator_pair.first, allocator_pair.second); - } - auto* ms = stats_->mutable_memory_stats(); - ms->set_temp_memory_size(ctx->temp_memory_allocated()); - for (const auto& alloc_id : ctx->persistent_alloc_ids()) { - ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); - } - ms->set_persistent_memory_size(ctx->persistent_memory_allocated()); +NodeExecStatsWrapper::NodeExecStatsWrapper( + const Node* node, StepStatsCollector* step_stats_collector) + : NodeExecStatsWrapper(MakeUnique<NodeExecStats>(), node, + step_stats_collector) { + stats_->set_node_name(node->name()); } -void NodeExecStatsWrapper::SetReferencedTensors( - const TensorReferenceVector& tensors) { - // be careful not to increment the reference count on any tensor - // while recording the information - for (size_t i = 0; i < tensors.size(); ++i) { - AllocationDescription* description = stats_->add_referenced_tensor(); - tensors.at(i).FillDescription(description); - } -} - -// TODO(tucker): merge with the DetailText function in session.cc -// in a common location. -bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) { - bool is_transfer_node = false; +NodeExecStatsWrapper::NodeExecStatsWrapper( + std::unique_ptr<NodeExecStats> stats, const Node* node, + StepStatsCollector* step_stats_collector) + : stats_(std::move(stats)), + node_(node), + step_stats_collector_(step_stats_collector) {} + +void NodeExecStatsWrapper::Done(const string& device) { + // TODO(tucker): merge with the DetailText function in session.cc in a common + // location. 
+ DCHECK(node_); string memory; for (auto& all : stats_->memory()) { int64 tot = all.total_bytes(); @@ -96,31 +75,96 @@ bool NodeExecStatsWrapper::SetTimelineLabel(const Node* node) { } } } - const AttrSlice attrs = node->attrs(); + const AttrSlice attrs = node_->attrs(); string text; - if (IsSend(node)) { + if (IsSend(node_)) { string tensor_name; TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name)); string recv_device; TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device)); - text = strings::StrCat(memory, node->name(), " = ", node->type_string(), + text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(", tensor_name, " @", recv_device); - is_transfer_node = true; - } else if (IsRecv(node)) { + } else if (IsRecv(node_)) { string tensor_name; TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name)); string send_device; TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device)); - text = strings::StrCat(memory, node->name(), " = ", node->type_string(), + text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(", tensor_name, " @", send_device); - is_transfer_node = true; } else { text = - strings::StrCat(memory, node->name(), " = ", node->type_string(), "(", - str_util::Join(node->requested_inputs(), ", "), ")"); + strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(", + str_util::Join(node_->requested_inputs(), ", "), ")"); } stats_->set_timeline_label(text); - return is_transfer_node; + step_stats_collector_->Save(device, this); +} + +void NodeExecStatsWrapper::RecordExecutorStarted() { + int64 now_nanos = Env::Default()->NowNanos(); + stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos); + stats_->set_all_start_nanos(now_nanos); +} + +void NodeExecStatsWrapper::RecordComputeStarted() { + int64 now_nanos = Env::Default()->NowNanos(); + DCHECK_NE(stats_->all_start_micros(), 0); + DCHECK_NE(stats_->all_start_nanos(), 0); + stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos - + stats_->all_start_micros()); + stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos()); +} + +void NodeExecStatsWrapper::RecordComputeEnded() { + int64 now_nanos = Env::Default()->NowNanos(); + DCHECK_NE(stats_->all_start_micros(), 0); + DCHECK_NE(stats_->all_start_nanos(), 0); + stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos - + stats_->all_start_micros()); + stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos()); +} + +void NodeExecStatsWrapper::RecordExecutorEnded() { + int64 now_nanos = Env::Default()->NowNanos(); + DCHECK_NE(stats_->all_start_micros(), 0); + DCHECK_NE(stats_->all_start_nanos(), 0); + stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos - + stats_->all_start_micros()); + stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos()); +} + +void NodeExecStatsWrapper::SetScheduled(int64 nanos) { + stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos); + stats_->set_scheduled_nanos(nanos); +} + +void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) { + for (const auto& allocator_pair : ctx->wrapped_allocators()) { + AddAllocation(allocator_pair.first, allocator_pair.second); + } + auto* ms = stats_->mutable_memory_stats(); + ms->set_temp_memory_size(ctx->temp_memory_allocated()); + for (const auto& alloc_id : ctx->persistent_alloc_ids()) { + ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id); + } + ms->set_persistent_memory_size(ctx->persistent_memory_allocated()); +} + +void 
NodeExecStatsWrapper::SetOutput(int slot, const Tensor* tensor) { + DCHECK(tensor); + NodeOutput* node_output = stats_->add_output(); + node_output->set_slot(slot); + tensor->FillDescription(node_output->mutable_tensor_description()); +} + +void NodeExecStatsWrapper::SetReferencedTensors( + const TensorReferenceVector& tensors) { + // be careful not to increment the reference count on any tensor + // while recording the information + for (size_t i = 0; i < tensors.size(); ++i) { + AllocationDescription* description = stats_->add_referenced_tensor(); + tensors.at(i).FillDescription(description); + } } void NodeExecStatsWrapper::AddAllocation( @@ -150,8 +194,8 @@ void NodeExecStatsWrapper::Finalize() { allocations_.clear(); } -StepStatsCollector::StepStatsCollector(StepStats* ss) - : finalized_(false), step_stats_(ss) {} +StepStatsCollector::StepStatsCollector(StepStats* step_stats) + : finalized_(false), step_stats_(step_stats) {} static int ExtractGpuWithStreamAll(string device_name) { // Check if the device name matches the ".*gpu:(\\d+)/stream:all$" regexp, @@ -176,7 +220,7 @@ static int ExtractGpuWithStreamAll(string device_name) { } else { // Convert the captured string into an integer. But first we need to put // the digits back in order - string ordered_capture = std::string(capture); + string ordered_capture(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; CHECK(strings::safe_strto32(ordered_capture, &gpu_id)); @@ -205,7 +249,7 @@ static int ExtractGpuWithoutStream(string device_name) { } else { // Convert the captured string into an integer. But first we need to put // the digits back in order - string ordered_capture = std::string(capture); + string ordered_capture(capture); std::reverse(ordered_capture.begin(), ordered_capture.end()); int gpu_id; CHECK(strings::safe_strto32(ordered_capture, &gpu_id)); @@ -252,7 +296,7 @@ void StepStatsCollector::BuildCostModel( for (auto& itr : per_device_stats) { const StringPiece device_name = itr.first; - const int gpu_id = ExtractGpuWithoutStream(std::string(device_name)); + const int gpu_id = ExtractGpuWithoutStream(string(device_name)); if (gpu_id >= 0) { // Reference the gpu hardware stats in addition to the regular stats // for this gpu device if they're available. 
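The Record* bodies moved above share one pattern: store a single absolute start timestamp, then record everything else as deltas from it, in both microseconds and nanoseconds. A minimal std::chrono sketch of the same bookkeeping (EnvTime's clock is treated as equivalent here):

#include <chrono>
#include <cstdint>

constexpr int64_t kMicrosToNanos = 1000;  // mirrors EnvTime::kMicrosToNanos

struct NodeTimings {
  int64_t all_start_nanos = 0;
  int64_t op_start_rel_nanos = 0;

  static int64_t NowNanos() {
    return std::chrono::duration_cast<std::chrono::nanoseconds>(
               std::chrono::steady_clock::now().time_since_epoch())
        .count();
  }

  void RecordExecutorStarted() { all_start_nanos = NowNanos(); }

  // Deltas keep the proto small: one absolute timestamp plus relative
  // offsets are enough to reconstruct the whole per-node timeline.
  void RecordComputeStarted() {
    op_start_rel_nanos = NowNanos() - all_start_nanos;
  }

  int64_t op_start_rel_micros() const {
    return op_start_rel_nanos / kMicrosToNanos;
  }
};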
@@ -338,28 +382,40 @@ void StepStatsCollector::BuildCostModel( } } -void StepStatsCollector::Save(const string& device, NodeExecStats* nt) { - Save(device, new NodeExecStatsWrapper(nt)); +void StepStatsCollector::Save(const string& device, + NodeExecStats* node_stats_pb) { + Save(device, + new NodeExecStatsWrapper(std::unique_ptr<NodeExecStats>(node_stats_pb), + nullptr, this)); } void StepStatsCollector::Save(const string& device, - NodeExecStatsWrapper* stats) { - if (!stats) return; - VLOG(1) << "Save dev " << device << " nt " << stats->stats(); + NodeExecStatsWrapper* node_stats) { + if (!node_stats) return; + VLOG(1) << "Save dev " << device << " node stats " << node_stats->stats(); { mutex_lock l(mu_); if (finalized_) { LOG(WARNING) << "stats saved after finalize will not be collected."; } - if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) { + if (!step_stats_ || collected_nodes_ >= kMaxCollectedNodes) { VLOG(1) << "step_stats_ nullptr or already collected too many nodes."; - delete stats; + delete node_stats; return; } - auto& dss = dev_stats_[device]; - dss.push_back(std::unique_ptr<NodeExecStatsWrapper>(stats)); - collectedNodes++; + auto& device_stats = dev_stats_[device]; + device_stats.push_back(std::unique_ptr<NodeExecStatsWrapper>(node_stats)); + collected_nodes_++; + } +} + +NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats( + const Node* node) { + // Only collect statistics for non-transfer nodes. + if (IsSend(node) || IsRecv(node)) { + return nullptr; } + return new NodeExecStatsWrapper(node, this); } string StepStatsCollector::ReportAllocsOnResourceExhausted(const string& err) { @@ -446,12 +502,12 @@ void StepStatsCollector::Finalize() { FinalizeInternal(); } -void StepStatsCollector::FinalizeAndSwap(StepStats* ss) { +void StepStatsCollector::FinalizeAndSwap(StepStats* step_stats) { mutex_lock l(mu_); CHECK(step_stats_); FinalizeInternal(); - ss->Swap(step_stats_); - collectedNodes = 0; + step_stats->Swap(step_stats_); + collected_nodes_ = 0; } void StepStatsCollector::FinalizeInternal() { diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h index 7206fbf427..4365b11b19 100644 --- a/tensorflow/core/common_runtime/step_stats_collector.h +++ b/tensorflow/core/common_runtime/step_stats_collector.h @@ -36,81 +36,78 @@ class Node; class NodeExecStats; class OpKernelContext; class StepStats; +class StepStatsCollector; class Tensor; class TrackingAllocator; -// Wraps NodeExecStats and adds allocation to it. -class NodeExecStatsWrapper { +// Statistics collection interface for individual node execution. +// +// See `NodeExecStatsWrapper` for a concrete implementation of this interface +// that interfaces with the `Session` layer. +class NodeExecStatsInterface { public: - NodeExecStatsWrapper(const string& node_name); - // Owns 'stats'. - NodeExecStatsWrapper(NodeExecStats* stats); + virtual ~NodeExecStatsInterface() {} - // Destructor calls Finalize() to release the TrackingAllocators. - ~NodeExecStatsWrapper() { Finalize(); } - - // Records the absolute time in nanoseconds at which this node became - // runnable (i.e. was scheduled for execution). - void SetScheduled(int64 nanos) { - stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos); - stats_->set_scheduled_nanos(nanos); - } + // Called when the statistics collection for the node has finished. Once this + // method is called, the caller should not make assumptions about the validity + // of this object. 
+ virtual void Done(const string& device) = 0; // Called immediately after this node starts being processed by the executor. - void RecordExecutorStarted() { - int64 now_nanos = Env::Default()->NowNanos(); - stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos); - stats_->set_all_start_nanos(now_nanos); - } + virtual void RecordExecutorStarted() = 0; // Called immediately before this node's `Compute()` or `ComputeAsync()` // method is called. - void RecordComputeStarted() { - int64 now_nanos = Env::Default()->NowNanos(); - DCHECK_NE(stats_->all_start_micros(), 0); - DCHECK_NE(stats_->all_start_nanos(), 0); - stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos - - stats_->all_start_micros()); - stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos()); - } + virtual void RecordComputeStarted() = 0; // Called immediately after this node's `Compute()` method returned (or, for // asynchronous operations, the callback passed to its `ComputeAsync()` method // was called). - void RecordComputeEnded() { - int64 now_nanos = Env::Default()->NowNanos(); - DCHECK_NE(stats_->all_start_micros(), 0); - DCHECK_NE(stats_->all_start_nanos(), 0); - stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos - - stats_->all_start_micros()); - stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos()); - } + virtual void RecordComputeEnded() = 0; // Called immediately after this executor finishes processing this node. - void RecordExecutorEnded() { - int64 now_nanos = Env::Default()->NowNanos(); - DCHECK_NE(stats_->all_start_micros(), 0); - DCHECK_NE(stats_->all_start_nanos(), 0); - stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos - - stats_->all_start_micros()); - stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos()); - } - - // Records information about the tensor produced by this node at the given - // output slot. - void SetOutput(int slot, const Tensor* v); + virtual void RecordExecutorEnded() = 0; // Records information about the memory allocated during the execution of this // node. - void SetMemory(OpKernelContext* ctx); + virtual void SetMemory(OpKernelContext* ctx) = 0; + + // Records information about the tensor produced by this node at the given + // output slot. + virtual void SetOutput(int slot, const Tensor* tensor) = 0; // Records information about the tensors that were accessed during the // execution of this node. - void SetReferencedTensors(const TensorReferenceVector& tensors); + virtual void SetReferencedTensors(const TensorReferenceVector& tensors) = 0; - // Sets the timeline_label field of the wrapped NodeExecStats, using data - // from *node. Returns true iff the node is a transfer node. - bool SetTimelineLabel(const Node* node); + // Records the absolute time in nanoseconds at which this node became + // runnable (i.e. was scheduled for execution). + virtual void SetScheduled(int64 nanos) = 0; +}; + +// Wraps NodeExecStats and adds allocation to it. +class NodeExecStatsWrapper : public NodeExecStatsInterface { + public: + // Does not take ownership of `node` or `step_stats_collector`. + NodeExecStatsWrapper(const Node* node, + StepStatsCollector* step_stats_collector); + + // Takes ownership of 'stats' but not `node` or `step_stats_collector`. + NodeExecStatsWrapper(std::unique_ptr<NodeExecStats> stats, const Node* node, + StepStatsCollector* step_stats_collector); + + // Destructor calls Finalize() to release the TrackingAllocators. 
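The comment on Done() above encodes an ownership handshake that is easy to miss: after Done(), the collector owns the stats object and the caller must drop its pointer. A self-contained sketch of that handshake with toy types:

#include <memory>
#include <string>
#include <vector>

class StatsCollector;

class NodeStats {
 public:
  explicit NodeStats(StatsCollector* collector) : collector_(collector) {}
  void Done(const std::string& device);

 private:
  StatsCollector* collector_;  // not owned
};

class StatsCollector {
 public:
  void Save(const std::string& device, NodeStats* stats) {
    // The collector takes ownership here; `device` would key the bucket.
    saved_.emplace_back(stats);
  }

 private:
  std::vector<std::unique_ptr<NodeStats>> saved_;
};

void NodeStats::Done(const std::string& device) {
  collector_->Save(device, this);  // after this line, `this` may be deleted
}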
+  ~NodeExecStatsWrapper() { Finalize(); }
+
+  void Done(const string& device) override;
+  void RecordExecutorStarted() override;
+  void RecordComputeStarted() override;
+  void RecordComputeEnded() override;
+  void RecordExecutorEnded() override;
+  void SetMemory(OpKernelContext* ctx) override;
+  void SetOutput(int slot, const Tensor* tensor) override;
+  void SetReferencedTensors(const TensorReferenceVector& tensors) override;
+  void SetScheduled(int64 nanos) override;
 
  private:
   friend class StepStatsCollector;
@@ -128,9 +125,11 @@ class NodeExecStatsWrapper {
   gtl::InlinedVector<std::pair<AllocatorMemoryUsed*, TrackingAllocator*>, 2>
       allocations_;
   std::unique_ptr<NodeExecStats> stats_;
+  const Node* const node_;                          // Not owned.
+  StepStatsCollector* const step_stats_collector_;  // Not owned.
 };
 
-// Statistics collection interface for individual node execution.
+// Statistics collection interface for step execution.
 //
 // See `StepStatsCollector` for a concrete implementation of this interface
 // that interfaces with the `Session` layer.
@@ -138,8 +137,9 @@ class StepStatsCollectorInterface {
  public:
   virtual ~StepStatsCollectorInterface() {}
 
-  // Saves `stats` to the collector.
-  virtual void Save(const string& device, NodeExecStatsWrapper* stats) = 0;
+  // Creates an instance of `NodeExecStatsInterface` that should be used for
+  // collecting statistics about individual node execution.
+  virtual NodeExecStatsInterface* CreateNodeExecStats(const Node* node) = 0;
 
   // Generates a string reporting the currently used memory based
   // on ResourceExhausted OOM `err` message.
@@ -154,8 +154,8 @@ class StepStatsCollectorInterface {
 // Each DeviceStats object holds multiple NodeExecStats.
 class StepStatsCollector : public StepStatsCollectorInterface {
  public:
-  // Does not take ownership of `ss`.
-  explicit StepStatsCollector(StepStats* ss);
+  // Does not take ownership of `step_stats`.
+  explicit StepStatsCollector(StepStats* step_stats);
 
   // BuildCostModel builds or updates a CostModel managed by cost_model_manager,
   // using the currently collected DeviceStats associated with the devices in
@@ -164,11 +164,12 @@ class StepStatsCollector : public StepStatsCollectorInterface {
       CostModelManager* cost_model_manager,
       const std::unordered_map<string, const Graph*>& device_map);
 
-  // Save saves nt to the DeviceStats object associated with device.
+  // Saves node statistics to the DeviceStats object associated with device.
   // Should be called before Finalize.
-  void Save(const string& device, NodeExecStats* nt);
-  void Save(const string& device, NodeExecStatsWrapper* stats) override;
+  void Save(const string& device, NodeExecStats* node_stats_pb);
+  void Save(const string& device, NodeExecStatsWrapper* node_stats);
 
+  NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override;
   string ReportAllocsOnResourceExhausted(const string& err) override;
 
   // The following 2 Finalize methods populate the StepStats passed
@@ -176,20 +177,22 @@ class StepStatsCollector : public StepStatsCollectorInterface {
   // User shouldn't call Save() methods after Finalize.
   void Finalize();
   // swaps the content of StepStats* from constructor with 'ss'.
-  void FinalizeAndSwap(StepStats* ss);
+  void FinalizeAndSwap(StepStats* step_stats);
 
  private:
+  // TODO(suharshs): Make this configurable if it's not possible to find a value
+  // that works for all cases.
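Taken together, the two headers above imply the following executor-side flow: per-node statistics objects are no longer constructed directly but are handed out by the collector. A rough sketch of the intended calling sequence, assuming `collector`, `node`, `ctx`, and `device_name` are locals already available in the executor (these names are illustrative, not from the patch):

```cpp
// Illustrative only: drives the NodeExecStatsInterface returned by the
// collector through one node execution.
NodeExecStatsInterface* stats = collector->CreateNodeExecStats(node);
if (stats != nullptr) {  // nullptr is returned for transfer (Send/Recv) nodes.
  stats->RecordExecutorStarted();
  stats->RecordComputeStarted();
  // ... run the kernel's Compute()/ComputeAsync() here ...
  stats->RecordComputeEnded();
  stats->SetMemory(ctx);
  stats->RecordExecutorEnded();
  stats->Done(device_name);  // Hands the object back; do not touch it again.
}
```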
+ static const uint64 kMaxCollectedNodes = 1 << 20; + + typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeStatsVector; + void FinalizeInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_); - typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeExecStatsVec; - // TODO(suharshs): Make this configurable if its not possible to find a value - // that works for all cases. - const uint64 kMaxCollectedNodes = 1 << 20; mutex mu_; bool finalized_ GUARDED_BY(mu_); - std::unordered_map<string, NodeExecStatsVec> dev_stats_ GUARDED_BY(mu_); + std::unordered_map<string, NodeStatsVector> dev_stats_ GUARDED_BY(mu_); StepStats* step_stats_ GUARDED_BY(mu_); - uint64 collectedNodes GUARDED_BY(mu_) = 0; + uint64 collected_nodes_ GUARDED_BY(mu_) = 0; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/tracing_device.h b/tensorflow/core/common_runtime/tracing_device.h index 39215efa35..e1b163074f 100644 --- a/tensorflow/core/common_runtime/tracing_device.h +++ b/tensorflow/core/common_runtime/tracing_device.h @@ -35,8 +35,11 @@ class TracingDevice : public Device { : Device(env, attributes) {} void Compute(OpKernel* op_kernel, OpKernelContext* context) override { + const tracing::TraceCollector* trace_collector = + tracing::GetTraceCollector(); if (TF_PREDICT_FALSE( - tracing::GetTraceCollector() || + (trace_collector && + trace_collector->IsEnabled(op_kernel->IsExpensive())) || tracing::GetEventCollector(tracing::EventCategory::kCompute))) { const string& op_name = op_kernel->name(); tracing::ScopedActivity activity(op_name, op_kernel->type_string(), diff --git a/tensorflow/core/common_runtime/visitable_allocator.h b/tensorflow/core/common_runtime/visitable_allocator.h deleted file mode 100644 index ae0563a96a..0000000000 --- a/tensorflow/core/common_runtime/visitable_allocator.h +++ /dev/null @@ -1,79 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_ -#define TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_ - -#include <functional> -#include "tensorflow/core/framework/allocator.h" -#include "tensorflow/core/framework/tracking_allocator.h" - -namespace tensorflow { - -// Subclass VisitableAllocator instead of Allocator when a memory -// allocator needs to enable some kind of registration/deregistration -// of memory areas. -class VisitableAllocator : public Allocator { - public: - // Visitor gets called with a pointer to a memory area and its - // size in bytes. - typedef std::function<void(void*, size_t)> Visitor; - - // Register a visitor guaranteed to be called exactly once on each - // chunk of memory newly allocated from the underlying device. - // Typically, chunks will be reused and possibly sub-divided by a - // pool manager, so the calls will happen only once per process - // execution, not once per tensor (re)allocation. 
- virtual void AddAllocVisitor(Visitor visitor) = 0; - - // Register a visitor guaranteed to be called on each chunk of - // memory returned to the underlying device. - virtual void AddFreeVisitor(Visitor visitor) = 0; -}; - -// Needed for cases when a VisitableAllocator gets wrapped for tracking. -// Multiple-inheritance is considered acceptable in this case because -// VisitableAllocator is a pure virtual interface and only TrackingAllocator -// has default implementation. -class TrackingVisitableAllocator : public TrackingAllocator, - public VisitableAllocator { - public: - TrackingVisitableAllocator(VisitableAllocator* allocator, bool track_ids) - : TrackingAllocator(allocator, track_ids), allocator_(allocator) {} - ~TrackingVisitableAllocator() override {} - - string Name() override { return TrackingAllocator::Name(); } - - void* AllocateRaw(size_t alignment, size_t num_bytes) override { - return TrackingAllocator::AllocateRaw(alignment, num_bytes); - } - - void DeallocateRaw(void* ptr) override { - TrackingAllocator::DeallocateRaw(ptr); - } - - void AddAllocVisitor(Visitor visitor) override { - allocator_->AddAllocVisitor(visitor); - } - - void AddFreeVisitor(Visitor visitor) override { - allocator_->AddFreeVisitor(visitor); - } - - protected: - VisitableAllocator* allocator_; -}; -} // namespace tensorflow -#endif // TENSORFLOW_CORE_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_ diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc index 38863db1cc..6994dec3b5 100644 --- a/tensorflow/core/debug/debug_io_utils.cc +++ b/tensorflow/core/debug/debug_io_utils.cc @@ -693,6 +693,7 @@ uint64 DebugFileIO::diskBytesUsed = 0; mutex DebugFileIO::bytes_mu(LINKER_INITIALIZED); bool DebugFileIO::requestDiskByteUsage(uint64 bytes) { + mutex_lock l(bytes_mu); if (globalDiskBytesLimit == 0) { const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT"); if (env_tfdbg_disk_bytes_limit == nullptr || @@ -707,7 +708,6 @@ bool DebugFileIO::requestDiskByteUsage(uint64 bytes) { if (bytes == 0) { return true; } - mutex_lock l(bytes_mu); if (diskBytesUsed + bytes < globalDiskBytesLimit) { diskBytesUsed += bytes; return true; diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc index 6c146036ae..f7a2967d00 100644 --- a/tensorflow/core/distributed_runtime/graph_mgr.cc +++ b/tensorflow/core/distributed_runtime/graph_mgr.cc @@ -233,14 +233,11 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, params.function_library = lib; params.create_kernel = [session, lib, opseg](const NodeDef& ndef, OpKernel** kernel) { - // We do not share the kernel via the OpSegment if the node is - // stateless, or a function. // NOTE(mrry): We must not share function kernels (implemented // using `CallOp`) between subgraphs, because `CallOp::handle_` // is tied to a particular subgraph. Even if the function itself // is stateful, the `CallOp` that invokes it is not. - if (!lib->IsStateful(ndef.op()) || - lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) { + if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) { return lib->CreateKernel(ndef, kernel); } auto create_fn = [lib, &ndef](OpKernel** kernel) { @@ -252,8 +249,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef, return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn); }; params.delete_kernel = [lib](OpKernel* kernel) { - // If the node is stateful, opseg owns it. Otherwise, delete it. 
- if (kernel && !lib->IsStateful(kernel->type_string())) { + if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) { delete kernel; } }; diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc index abd07e37b7..8e9eec1ed9 100644 --- a/tensorflow/core/distributed_runtime/master_session.cc +++ b/tensorflow/core/distributed_runtime/master_session.cc @@ -449,7 +449,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions( *c->req.mutable_graph_options() = session_opts_.config.graph_options(); *c->req.mutable_debug_options() = callable_opts_.run_options().debug_options(); - c->req.set_collective_graph_key(bg_opts_.collective_graph_key); + c->req.set_collective_graph_key(client_graph()->collective_graph_key); VLOG(2) << "Register " << c->req.graph_def().DebugString(); auto cb = [c, &done](const Status& s) { c->status = s; @@ -1111,10 +1111,6 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) { h = Hash64(watch_summary.c_str(), watch_summary.size(), h); } - if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) { - h = Hash64Combine(opts.collective_graph_key, h); - } - return h; } @@ -1788,10 +1784,10 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg, Status s = run_status; if (s.ok()) { pss->end_micros = Env::Default()->NowMicros(); - if (rcg->build_graph_options().collective_graph_key != + if (rcg->client_graph()->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) { env_->collective_executor_mgr->RetireStepId( - rcg->build_graph_options().collective_graph_key, step_id); + rcg->client_graph()->collective_graph_key, step_id); } // Schedule post-processing and cleanup to be done asynchronously. rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata); @@ -1850,7 +1846,7 @@ Status MasterSession::DoRunWithLocalExecution( // Keeps the highest 8 bits 0x01: we reserve some bits of the // step_id for future use. - uint64 step_id = NewStepId(bgopts.collective_graph_key); + uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key); TRACEPRINTF("stepid %llu", step_id); std::unique_ptr<ProfileHandler> ph; @@ -1914,8 +1910,7 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg, // Prepare. int64 count = rcg->get_and_increment_execution_count(); - const uint64 step_id = - NewStepId(rcg->build_graph_options().collective_graph_key); + const uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key); TRACEPRINTF("stepid %llu", step_id); const RunOptions& run_options = rcg->callable_options().run_options(); diff --git a/tensorflow/core/example/example.proto b/tensorflow/core/example/example.proto index e7142a4ef9..e36e51d8d5 100644 --- a/tensorflow/core/example/example.proto +++ b/tensorflow/core/example/example.proto @@ -199,7 +199,13 @@ message Example { // to determine if all features within the FeatureList must // have the same size. The same holds for this FeatureList across multiple // examples. -// +// - For sequence modeling, e.g.: +// http://colah.github.io/posts/2015-08-Understanding-LSTMs/ +// https://github.com/tensorflow/nmt +// the feature lists represent a sequence of frames. +// In this scenario, all FeatureLists in a SequenceExample have the same +// number of Feature messages, so that the ith element in each FeatureList +// is part of the ith frame (or time step). 
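As a concrete illustration of the frame-alignment rule described above, here is a hedged sketch that builds a conformant `SequenceExample` through the generated C++ proto API; the feature names are made up:

```cpp
#include "tensorflow/core/example/example.pb.h"

// Builds a SequenceExample whose two FeatureLists each hold exactly two
// Feature messages, so the ith entries jointly describe frame i.
tensorflow::SequenceExample MakeTwoFrameExample() {
  tensorflow::SequenceExample se;
  auto& lists = *se.mutable_feature_lists()->mutable_feature_list();
  for (float brightness : {0.1f, 0.9f}) {
    lists["frame_brightness"].add_feature()->mutable_float_list()->add_value(
        brightness);
  }
  for (const char* label : {"dark", "bright"}) {
    lists["frame_label"].add_feature()->mutable_bytes_list()->add_value(label);
  }
  return se;
}
```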
// Examples of conformant and non-conformant examples' FeatureLists: // // Conformant FeatureLists: diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc index 888ed0c57b..84cee5569c 100644 --- a/tensorflow/core/framework/allocator.cc +++ b/tensorflow/core/framework/allocator.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/allocator_registry.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/tracking_allocator.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" @@ -56,6 +57,14 @@ void RunResourceDtor(ResourceHandle* p, size_t n) { for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle(); } +void Allocator::RunVariantCtor(Variant* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) new (p) Variant(); +} + +void Allocator::RunVariantDtor(Variant* p, size_t n) { + for (size_t i = 0; i < n; ++p, ++i) p->~Variant(); +} + // If true, cpu allocator collects more stats. static bool cpu_allocator_collect_stats = false; // If true, cpu allocator collects full stats. @@ -187,7 +196,7 @@ class CPUAllocatorFactory : public AllocatorFactory { class CPUSubAllocator : public SubAllocator { public: explicit CPUSubAllocator(CPUAllocator* cpu_allocator) - : cpu_allocator_(cpu_allocator) {} + : SubAllocator({}, {}), cpu_allocator_(cpu_allocator) {} void* Alloc(size_t alignment, size_t num_bytes) override { return cpu_allocator_->AllocateRaw(alignment, num_bytes); @@ -213,4 +222,22 @@ Allocator* cpu_allocator() { } return cpu_alloc; } + +SubAllocator::SubAllocator(const std::vector<Visitor>& alloc_visitors, + const std::vector<Visitor>& free_visitors) + : alloc_visitors_(alloc_visitors), free_visitors_(free_visitors) {} + +void SubAllocator::VisitAlloc(void* ptr, int index, size_t num_bytes) { + for (const auto& v : alloc_visitors_) { + v(ptr, index, num_bytes); + } +} + +void SubAllocator::VisitFree(void* ptr, int index, size_t num_bytes) { + // Although we don't guarantee any order of visitor application, strive + // to apply free visitors in reverse order of alloc visitors. + for (int i = free_visitors_.size() - 1; i >= 0; --i) { + free_visitors_[i](ptr, index, num_bytes); + } +} } // namespace tensorflow diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h index 774b1fe137..8c23604625 100644 --- a/tensorflow/core/framework/allocator.h +++ b/tensorflow/core/framework/allocator.h @@ -23,12 +23,14 @@ limitations under the License. #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/type_traits.h" -#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { +class Variant; + // Attributes for a single allocation call. Different calls to the same // allocator could potentially have different allocation attributes. 
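With `VisitableAllocator` deleted (see above), the visitor mechanism now lives on `SubAllocator` itself, and concrete sub-allocators are responsible for firing the hooks. A minimal sketch, assuming a host-memory sub-allocator pinned to NUMA node 0; `AlignedHostSubAllocator` is a hypothetical name:

```cpp
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {

class AlignedHostSubAllocator : public SubAllocator {
 public:
  AlignedHostSubAllocator(const std::vector<Visitor>& alloc_visitors,
                          const std::vector<Visitor>& free_visitors)
      : SubAllocator(alloc_visitors, free_visitors) {}

  void* Alloc(size_t alignment, size_t num_bytes) override {
    void* ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
    VisitAlloc(ptr, /*index=*/0, num_bytes);  // Required on fresh chunks.
    return ptr;
  }

  void Free(void* ptr, size_t num_bytes) override {
    VisitFree(ptr, /*index=*/0, num_bytes);  // Immediately before release.
    port::AlignedFree(ptr);
  }
};

}  // namespace tensorflow
```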
struct AllocationAttributes { @@ -228,13 +230,9 @@ class Allocator { for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle(); } - virtual void RunVariantCtor(Variant* p, size_t n) { - for (size_t i = 0; i < n; ++p, ++i) new (p) Variant(); - } + virtual void RunVariantCtor(Variant* p, size_t n); - virtual void RunVariantDtor(Variant* p, size_t n) { - for (size_t i = 0; i < n; ++p, ++i) p->~Variant(); - } + virtual void RunVariantDtor(Variant* p, size_t n); // TODO(jeff): Maybe provide some interface to give info about // current allocation state (total number of bytes available for @@ -390,13 +388,36 @@ void EnableCPUAllocatorStats(bool enable); // full statistics. By default, it's disabled. void EnableCPUAllocatorFullStats(bool enable); -// Abstract interface of an object that does the underlying suballoc/free of -// memory for a higher-level allocator. +// An object that does the underlying suballoc/free of memory for a higher-level +// allocator. The expectation is that the higher-level allocator is doing some +// kind of cache or pool management so that it will call SubAllocator::Alloc and +// Free relatively infrequently, compared to the number of times its own +// AllocateRaw and Free methods are called. class SubAllocator { public: + // Visitor gets called with a pointer to a memory area and its + // size in bytes. The index value will be numa_node for a CPU + // allocator and GPU id for a GPU allocator. + typedef std::function<void(void*, int index, size_t)> Visitor; + + SubAllocator(const std::vector<Visitor>& alloc_visitors, + const std::vector<Visitor>& free_visitors); + virtual ~SubAllocator() {} virtual void* Alloc(size_t alignment, size_t num_bytes) = 0; virtual void Free(void* ptr, size_t num_bytes) = 0; + + protected: + // Implementation of Alloc() method must call this on newly allocated + // value. + void VisitAlloc(void* ptr, int index, size_t num_bytes); + + // Implementation of Free() method must call this on value to be + // freed immediately before deallocation. + void VisitFree(void* ptr, int index, size_t num_bytes); + + const std::vector<Visitor> alloc_visitors_; + const std::vector<Visitor> free_visitors_; }; } // namespace tensorflow diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h index 24f282ce84..e907c52ba9 100644 --- a/tensorflow/core/framework/allocator_registry.h +++ b/tensorflow/core/framework/allocator_registry.h @@ -21,6 +21,7 @@ limitations under the License. #include <vector> #include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/numa.h" namespace tensorflow { diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc index 1a3994736c..4ffd732f8e 100644 --- a/tensorflow/core/framework/attr_value_util_test.cc +++ b/tensorflow/core/framework/attr_value_util_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include <numeric> #include <vector> #include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc index b0b27ce94f..284dafb886 100644 --- a/tensorflow/core/framework/dataset.cc +++ b/tensorflow/core/framework/dataset.cc @@ -19,7 +19,7 @@ limitations under the License. 
#include "tensorflow/core/graph/node_builder.h" namespace tensorflow { - +namespace data { namespace { // A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor. @@ -179,6 +179,13 @@ Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx, return Status::OK(); } +void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val, + Node** output) { + *output = ops::SourceOp( + "Placeholder", + b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape())); +} + void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val, Node** output) { *output = ops::SourceOp( @@ -322,4 +329,5 @@ void BackgroundWorker::WorkerLoop() { } } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index e06ca68bca..91b1e61d3c 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/framework/dataset_stateful_op_whitelist.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/model.h" #include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -40,6 +41,15 @@ limitations under the License. namespace tensorflow { +// Forward declarations to avoid introducing a dependency on headers in +// "tensorflow/core/graph/...". +class GraphDefBuilder; +class Node; + +namespace data { +// A constant that can be used to enable auto-tuning. +constexpr int kAutoTune = -1; + class DatasetBase; class SerializationContext; @@ -66,11 +76,6 @@ class IteratorStateWriter { virtual ~IteratorStateWriter() {} }; -// Forward declarations to avoid introducing a dependency on headers in -// "tensorflow/core/graph/...". -class GraphDefBuilder; -class Node; - // Wrapper around GraphDefBuilder. Used to serialize Dataset graph. class GraphDefBuilderWrapper { public: @@ -110,10 +115,11 @@ class GraphDefBuilderWrapper { return Status::OK(); } - // Adds a Const node with Tensor value to the Graph. + // Adds a `Const` node for the given tensor value to the graph. + // // `*output` contains a pointer to the output `Node`. It is guaranteed to be - // non-null if the method returns with an OK status. - // The returned Node pointer is owned by the backing Graph of GraphDefBuilder. + // non-null if the method returns with an OK status. The returned `Node` + // pointer is owned by the backing graph of `GraphDefBuilder`. Status AddTensor(const Tensor& val, Node** output) { AddTensorInternal(val, output); if (*output == nullptr) { @@ -122,6 +128,20 @@ class GraphDefBuilderWrapper { return Status::OK(); } + // Adds a `Placeholder` node for the given tensor value to the graph. + // + // `*output` contains a pointer to the output `Node`. It is guaranteed to be + // non-null if the method returns with an OK status. The returned `Node` + // pointer is owned by the backing graph of `GraphDefBuilder`. 
+  Status AddPlaceholder(const Tensor& val, Node** output) {
+    AddPlaceholderInternal(val, output);
+    if (*output == nullptr) {
+      return errors::Internal(
+          "AddPlaceholder: Failed to build Placeholder op.");
+    }
+    return Status::OK();
+  }
+
   Status AddDataset(const DatasetBase* dataset,
                     const std::vector<Node*>& inputs, Node** output) {
     return AddDataset(dataset, inputs, {}, output);
@@ -168,6 +188,7 @@ class GraphDefBuilderWrapper {
   }
 
  private:
+  void AddPlaceholderInternal(const Tensor& val, Node** output);
   void AddTensorInternal(const Tensor& val, Node** output);
 
   Status EnsureFunctionIsStateless(const FunctionLibraryDefinition& flib_def,
@@ -206,8 +227,7 @@ class GraphDefBuilderWrapper {
     return (str_util::EndsWith(op_def->name(), "Dataset") &&
             op_def->output_arg_size() == 1 &&
             op_def->output_arg(0).type() == DT_VARIANT) ||
-           dataset::WhitelistedStatefulOpRegistry::Global()->Contains(
-               op_def->name());
+           WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
   }
 
   bool HasAttr(const string& op_type_name, const string& attr_name) const;
@@ -274,6 +294,9 @@ class IteratorContext {
 
     // The Allocator to be used to allocate the output of an iterator.
     std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
+
+    // If non-null, identifies the object used for performance modeling.
+    std::shared_ptr<model::Model> model = nullptr;
   };
 
   explicit IteratorContext(Params params) : params_(std::move(params)) {}
@@ -325,6 +348,10 @@ class IteratorContext {
     return params_.stats_aggregator_getter;
   }
 
+  std::shared_ptr<model::Model> model() { return params_.model; }
+
+  Params params() { return params_; }
+
  private:
   Params params_;
 };
@@ -334,7 +361,8 @@ class SerializationContext {
  public:
   struct Params {
     bool allow_stateful_functions = false;
-    const FunctionLibraryDefinition* flib_def;  // Not owned.
+    const FunctionLibraryDefinition* flib_def = nullptr;  // Not owned.
+    std::vector<std::pair<string, Tensor>>* input_list = nullptr;  // Not owned.
   };
 
   explicit SerializationContext(Params params) : params_(std::move(params)) {}
@@ -343,6 +371,10 @@ class SerializationContext {
 
   const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; }
 
+  std::vector<std::pair<string, Tensor>>* input_list() {
+    return params_.input_list;
+  }
+
  private:
   Params params_;
 
@@ -354,7 +386,11 @@ class SerializationContext {
 // defined below.
 class IteratorBase {
  public:
-  virtual ~IteratorBase() {}
+  virtual ~IteratorBase() {
+    for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
+      (*rit)();
+    }
+  }
 
   // Gets the next output from the range that this iterator is traversing.
   //
@@ -388,6 +424,10 @@ class IteratorBase {
   // in the outputs of this iterator.
   virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
 
+  // Returns a string that identifies the sequence of iterators leading up to
+  // this iterator.
+  virtual const string& prefix() const = 0;
+
   // Performs initialization that needs to happen outside of a constructor to
   // properly propagate errors.
   virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
@@ -427,6 +467,18 @@ class IteratorBase {
                                  IteratorStateReader* reader) {
     return errors::Unimplemented("RestoreInternal");
   }
+
+ private:
+  friend class DatasetBase;  // for access to `AddCleanupFunction`
+
+  // Registers a cleanup function to be called upon object destruction.
+  //
+  // Registered functions are invoked in the reverse order of registration.
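The reverse-order guarantee just stated mirrors how member destructors run; a tiny self-contained demonstration of the same iteration pattern used by `~IteratorBase()`:

```cpp
#include <functional>
#include <iostream>
#include <vector>

int main() {
  std::vector<std::function<void()>> cleanup_fns;
  cleanup_fns.push_back([] { std::cout << "registered first, runs last\n"; });
  cleanup_fns.push_back([] { std::cout << "registered second, runs first\n"; });
  for (auto rit = cleanup_fns.rbegin(); rit != cleanup_fns.rend(); ++rit) {
    (*rit)();  // LIFO: the most recently registered cleanup runs first.
  }
  return 0;
}
```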
+ void AddCleanupFunction(std::function<void()>&& cleanup_fn) { + cleanup_fns_.push_back(std::move(cleanup_fn)); + } + + std::vector<std::function<void()>> cleanup_fns_; }; // Represents runtime information needed to construct a dataset. @@ -476,6 +528,27 @@ class DatasetBase : public core::RefCounted { Status MakeIterator(IteratorContext* ctx, const string& prefix, std::unique_ptr<IteratorBase>* iterator) const { *iterator = MakeIteratorInternal(prefix); + if (ctx->model()) { + // The prefix might contain an index. We need to strip it to make it + // possible for the model to successfully identify the output node. + string sanitized_prefix = prefix; + if (str_util::EndsWith(prefix, "]")) { + sanitized_prefix = prefix.substr(0, prefix.rfind('[')); + } + std::shared_ptr<model::Node> node = + ctx->model()->AddNode((*iterator)->prefix(), sanitized_prefix); + std::vector<string> tokens = + str_util::Split((*iterator)->prefix(), ':', str_util::SkipEmpty()); + node->set_name(tokens[tokens.size() - 1]); + std::shared_ptr<model::Model> model = ctx->model(); + const string& prefix = (*iterator)->prefix(); + (*iterator)->AddCleanupFunction([model, node, prefix]() { + if (node->output()) { + node->output()->remove_input(node); + } + model->RemoveNode(prefix); + }); + } return (*iterator)->Initialize(ctx); } @@ -502,6 +575,8 @@ class DatasetBase : public core::RefCounted { IteratorStateWriter* writer) const; protected: + friend class DatasetToGraphOp; // For access to graph related members. + class DatasetGraphDefBuilder : public GraphDefBuilderWrapper { public: DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {} @@ -519,8 +594,6 @@ class DatasetBase : public core::RefCounted { virtual std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const = 0; - friend class DatasetToGraphOp; // For access to graph related members. - private: const string name_; }; @@ -543,7 +616,7 @@ class DatasetBaseIterator : public IteratorBase { ~DatasetBaseIterator() override { params_.dataset->Unref(); } // The sequence of iterators leading up to this iterator. - const string& prefix() const { return params_.prefix; } + const string& prefix() const override { return params_.prefix; } const DataTypeVector& output_dtypes() const override { return params_.dataset->output_dtypes(); @@ -556,7 +629,23 @@ class DatasetBaseIterator : public IteratorBase { Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) final { tracing::ScopedActivity activity(params_.prefix); - Status s = GetNextInternal(ctx, out_tensors, end_of_sequence); + Status s; + if (ctx->model()) { + std::shared_ptr<model::Node> node = + ctx->model()->LookupNode(params_.prefix); + if (node->output()) { + node->output()->stop_work(); + } + node->start_work(); + s = GetNextInternal(ctx, out_tensors, end_of_sequence); + node->stop_work(); + node->add_element(); + if (node->output()) { + node->output()->start_work(); + } + } else { + s = GetNextInternal(ctx, out_tensors, end_of_sequence); + } if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) { s = errors::Internal( "Iterator \"", params_.prefix, @@ -583,6 +672,60 @@ class DatasetBaseIterator : public IteratorBase { return strings::StrCat(params_.prefix, ":", name); } + // When performance modeling is enabled, this method adds a constant parameter + // to the model node corresponding to this iterator. 
+ void AddConstantParameter(IteratorContext* ctx, const string& name, + int64 value) { + if (ctx->model()) { + std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix()); + if (node) { + node->add_constant_param(name, value); + } + } + } + + // When performance modeling is enabled, this method adds a tunable parameter + // to the model node corresponding to this iterator. + // + // The `set_fn` function should set the tunable parameter to the value of + // its input argument. The function should be thread-safe; in particular, the + // state it updates should be protected by a lock as the function can be + // invoked asynchronously. It is guaranteed that this function will not be + // invoked after the iterator is deleted because the model node that owns + // the function is deleted when the iterator is deleted. + void AddTunableParameter(IteratorContext* ctx, const string& name, + int64 value, int64 min, int64 max, + std::function<void(int64)>&& set_fn) { + if (ctx->model()) { + std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix()); + if (node) { + node->add_tunable_param(name, value, min, max, std::move(set_fn)); + } + } + } + + // When performance modeling is enabled, this method records the fact that + // a thread of this iterator has started work. + void StartWork(IteratorContext* ctx) { + if (ctx->model()) { + std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix()); + if (node) { + node->start_work(); + } + } + } + + // When performance modeling is enabled, this method records the fact that + // a thread of this iterator has stopped work. + void StopWork(IteratorContext* ctx) { + if (ctx->model()) { + std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix()); + if (node) { + node->stop_work(); + } + } + } + private: BaseParams params_; }; @@ -730,6 +873,21 @@ class BackgroundWorker { std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_); }; +} // namespace data + +// TODO(b/114112161): Remove these aliases when all users have moved over to the +// `tensorflow::data` namespace. +using data::DatasetBase; +using data::DatasetContext; +using data::DatasetIterator; +using data::DatasetOpKernel; +using data::IteratorBase; +using data::IteratorContext; +using data::IteratorStateReader; +using data::IteratorStateWriter; +using data::SerializationContext; +using data::UnaryDatasetOpKernel; + } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_H_ diff --git a/tensorflow/core/framework/dataset_stateful_op_whitelist.h b/tensorflow/core/framework/dataset_stateful_op_whitelist.h index 3b48999edb..74bd39cb61 100644 --- a/tensorflow/core/framework/dataset_stateful_op_whitelist.h +++ b/tensorflow/core/framework/dataset_stateful_op_whitelist.h @@ -16,38 +16,38 @@ limitations under the License. #ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_ #define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_ +#include <unordered_set> #include "tensorflow/core/lib/core/status.h" namespace tensorflow { -namespace dataset { +namespace data { // Registry for stateful ops that need to be used in dataset functions. // See below macro for usage details. 
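For reference, the registration macro defined further below is meant to be used at namespace scope in the translation unit that defines the kernel; the op name here is invented for illustration:

```cpp
// Hypothetical stateful op being whitelisted for use inside dataset functions.
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("MyStatefulLookupOp");
```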
class WhitelistedStatefulOpRegistry { public: - Status Add(StringPiece op_name) { - op_names_.insert(op_name); + Status Add(string op_name) { + op_names_.insert(std::move(op_name)); return Status::OK(); } - bool Contains(StringPiece op_name) { - return op_names_.find(op_name) != op_names_.end(); - } + bool Contains(const string& op_name) { return op_names_.count(op_name); } static WhitelistedStatefulOpRegistry* Global() { - static WhitelistedStatefulOpRegistry* reg = - new WhitelistedStatefulOpRegistry; + static auto* reg = new WhitelistedStatefulOpRegistry; return reg; } private: - WhitelistedStatefulOpRegistry() {} - WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy); + WhitelistedStatefulOpRegistry() = default; + WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy) = + delete; WhitelistedStatefulOpRegistry operator=( - WhitelistedStatefulOpRegistry const& copy); - std::set<StringPiece> op_names_; + WhitelistedStatefulOpRegistry const& copy) = delete; + + std::unordered_set<string> op_names_; }; -} // namespace dataset +} // namespace data // Use this macro to whitelist an op that is marked stateful but needs to be // used inside a map_fn in an input pipeline. This is only needed if you wish @@ -67,10 +67,9 @@ class WhitelistedStatefulOpRegistry { WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name) #define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \ WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) -#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \ - static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \ - ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \ - name) +#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \ + static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::data::WhitelistedStatefulOpRegistry::Global()->Add(name) } // namespace tensorflow diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 794250a2c1..53ac639b4c 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -214,10 +214,12 @@ class DeviceBase { // This is overridden by GPU devices to reinitialize the derived // type returned by MakeGpuDevice. 
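Because `ReinitializeGpuDevice` now reports failure via `Status` instead of returning `void` (see the hunk below), call sites must propagate the error. A sketch of a conforming caller; the wrapper function itself is hypothetical:

```cpp
// Propagates the new Status rather than assuming reinitialization succeeds.
Status SetupPerOpGpuDevice(DeviceBase* device, OpKernelContext* context,
                           PerOpGpuDevice* gpu_device, DeviceContext* dc,
                           Allocator* allocator) {
  TF_RETURN_IF_ERROR(
      device->ReinitializeGpuDevice(context, gpu_device, dc, allocator));
  return Status::OK();
}
```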
- virtual void ReinitializeGpuDevice(OpKernelContext* /*context*/, - PerOpGpuDevice* /*device*/, - DeviceContext* /*dc*/, - Allocator* /*allocator*/) {} + virtual Status ReinitializeGpuDevice(OpKernelContext* /*context*/, + PerOpGpuDevice* /*device*/, + DeviceContext* /*dc*/, + Allocator* /*allocator*/) { + return Status::OK(); + } // Unimplemented by default virtual const DeviceAttributes& attributes() const; diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index 26f32677af..a17959a448 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1154,6 +1154,17 @@ Status FunctionLibraryDefinition::LookUp( return default_registry_->LookUp(op, op_reg_data); } +string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const { + tf_shared_lock l(mu_); + int index = 0; + string name = strings::StrCat(prefix, index); + while (function_defs_.find(name) != function_defs_.end()) { + ++index; + name = strings::StrCat(prefix, index); + } + return name; +} + const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( const NodeDef& ndef) const { if (ndef.op() != kGradientOp) { @@ -1283,6 +1294,18 @@ FunctionDef FunctionDefHelper::Create( for (const auto& r : ret_def) { fdef.mutable_ret()->insert({r.first, r.second}); } + + auto* op_def_registry = OpRegistry::Global(); + // Check if any op is stateful. + for (const auto& n : node_def) { + const OpDef* op_def = nullptr; + auto status = op_def_registry->LookUpOpDef(n.op, &op_def); + // Lookup can fail if e.g. we are calling a function that was not yet + // defined. If it happens, conservatively assume the op is stateful. + if (!status.ok() || op_def->is_stateful()) { + fdef.mutable_signature()->set_is_stateful(true); + } + } return fdef; } @@ -1344,6 +1367,7 @@ FunctionDef FunctionDefHelper::Define(const string& name, strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first); } } + if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true); } // Returns diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index 03296a7761..e01eb7503d 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -358,6 +358,10 @@ class FunctionLibraryDefinition : public OpRegistryInterface { const OpRegistrationData** op_reg_data) const override LOCKS_EXCLUDED(mu_); + // Generates new function name with the specified prefix that is unique + // across this library. + string UniqueFunctionName(StringPiece prefix) const LOCKS_EXCLUDED(mu_); + // Ops created for function arguments bear the name given by `kArgOp`; those // created for return values bear the name given by `kRetOp`. 
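A short usage sketch of the new `UniqueFunctionName` helper declared above: it tries `prefix0`, `prefix1`, ... until it finds a name not yet present in the library (the empty library here is an assumption):

```cpp
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
// With no functions registered, the first candidate is available.
string fn_name = flib.UniqueFunctionName("inlined_");  // yields "inlined_0"
```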
static constexpr const char* const kArgOp = "_Arg"; diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index 46b169dddc..d5c203d276 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -91,6 +91,40 @@ FunctionDef IsZero() { }); } +FunctionDef RandomUniform() { + const Tensor kZero = test::AsScalar<int64>(0); + const Tensor kTen = test::AsScalar<int64>(10); + + return FDH::Define( + // Name + "RandomUniform", + // Args + {"x: T"}, + // Return values + {"random_uniform: int64"}, + // Attr def + {"T:{float, double, int32, int64, string}"}, + {{{"random_uniform/shape"}, + "Const", + {}, + {{"value", kZero}, {"dtype", DT_INT64}}}, + {{"random_uniform/min"}, + "Const", + {}, + {{"value", kZero}, {"dtype", DT_INT64}}}, + {{"random_uniform/max"}, + "Const", + {}, + {{"value", kTen}, {"dtype", DT_INT64}}}, + {{"random_uniform"}, + "RandomUniformInt", + {}, + {{"T", DT_INT64}, + {"Tout", DT_INT64}, + {"seed", 87654321}, + {"seed2", 42}}}}); +} + FunctionDef XTimesTwo() { const Tensor kTwo = test::AsScalar<int64>(2); return FDH::Define( @@ -110,6 +144,22 @@ FunctionDef XTimesTwo() { }); } +FunctionDef XAddX() { + return FDH::Define( + // Name + "XAddX", + // Args + {"x: T"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + { + {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}}, + }); +} + FunctionDef XTimesTwoInt32() { const Tensor kTwo = test::AsScalar<int64>(2); return FDH::Define( diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h index 6d6476b936..a01743423b 100644 --- a/tensorflow/core/framework/function_testlib.h +++ b/tensorflow/core/framework/function_testlib.h @@ -63,6 +63,9 @@ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes, // x:T -> x * 2. FunctionDef XTimesTwo(); +// x:T -> x + x. +FunctionDef XAddX(); + // x:T -> x * 2, where x is int32. FunctionDef XTimesTwoInt32(); @@ -81,6 +84,9 @@ FunctionDef NonZero(); // x: T -> bool. FunctionDef IsZero(); +// x: T -> int64 +FunctionDef RandomUniform(); + // x:T, y:T -> y:T, x:T FunctionDef Swap(); diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc new file mode 100644 index 0000000000..112298c344 --- /dev/null +++ b/tensorflow/core/framework/model.cc @@ -0,0 +1,365 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/model.h" + +#include <memory> + +#include "tensorflow/core/lib/gtl/map_util.h" + +namespace tensorflow { +namespace data { +namespace model { + +// TODO(jsimsa): Use `Node` subclassing instead of types and node statements. 
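The `XAddX` and `RandomUniform` helpers added to `function_testlib` above plug into the usual library plumbing; a sketch of how a test might register them, assuming the `tensorflow::test::function` namespace that file already uses:

```cpp
FunctionDefLibrary lib;
*lib.add_function() = tensorflow::test::function::XAddX();
*lib.add_function() = tensorflow::test::function::RandomUniform();
FunctionLibraryDefinition flib_def(OpRegistry::Global(), lib);
```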
+void Node::CollectTunables( + std::vector<std::shared_ptr<Node::Tunable>>* tunables) { + mutex_lock l(mu_); + for (auto input : inputs_) { + input->CollectTunables(tunables); + } + switch (type_) { + case Type::MAP_AND_BATCH: + case Type::PARALLEL_INTERLEAVE_V2: + case Type::PARALLEL_MAP: { + if (auto* tunable_param = + gtl::FindOrNull(tunable_params_, "parallelism")) { + tunables->push_back(*tunable_param); + } + return; + } + default: + return; + } +} + +int64 Node::GetParameterValue(const string& name) { + if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) { + return (*tunable_param)->value; + } + return constant_params_[name]; +} + +int64 Node::ProcessingTimeLocked() { + switch (type_) { + case Type::BATCH: + case Type::MAP_AND_BATCH: + case Type::PADDED_BATCH: { + int64 batch_size = GetParameterValue("batch_size"); + return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs(); + } + case Type::FILTER: { + std::shared_ptr<Node> input = inputs_.front(); + double ratio = static_cast<double>(input->num_elements()) / + static_cast<double>(num_elements_); + return NanosPerElementLocked() + + static_cast<int64>(ratio * + static_cast<double>(ProcessingTimeForInputs())); + } + case Type::FLAT_MAP: + case Type::INTERLEAVE: + case Type::PARALLEL_INTERLEAVE: + case Type::PARALLEL_INTERLEAVE_V2: { + // TODO(jsimsa): model the first input + // TODO(jsimsa): use processing time history as a prior for future inputs + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); + } + int64 processing_time = + ProcessingTimeForInputs() - inputs_.front()->ProcessingTime(); + return NanosPerElementLocked() + + static_cast<double>(processing_time) / + static_cast<double>(inputs_.size() - 1); + } + case Type::CACHE: + case Type::CONCATENATE: + case Type::MAP: + case Type::PARALLEL_MAP: + case Type::PREFETCH: + // TODO(jsimsa): use processing time history as a prior for future inputs + case Type::REPEAT: + case Type::SHUFFLE: + case Type::SKIP: + case Type::TAKE: + case Type::ZIP: { + return NanosPerElementLocked() + ProcessingTimeForInputs(); + } + default: + return NanosPerElementLocked(); + } +} + +int64 Node::OutputTimeLocked(std::vector<int64>* input_times) { + switch (type_) { + case Type::BATCH: + case Type::PADDED_BATCH: { + double batch_size = GetParameterValue("batch_size"); + int64 old_value = (*input_times)[input_times->size() - 1]; + (*input_times)[input_times->size() - 1] = static_cast<int64>( + static_cast<double>(old_value + NanosPerElementLocked()) / + batch_size); + auto cleanup = gtl::MakeCleanup([input_times, old_value]() { + (*input_times)[input_times->size() - 1] = old_value; + }); + return NanosPerElementLocked() + + batch_size * OutputTimeForInputs(input_times); + } + case Type::FILTER: { + std::shared_ptr<Node> input = inputs_.front(); + int64 old_value = (*input_times)[input_times->size() - 1]; + double ratio = static_cast<double>(input->num_elements()) / + static_cast<double>(num_elements_); + (*input_times)[input_times->size() - 1] = static_cast<int64>( + static_cast<double>(old_value + NanosPerElementLocked()) / ratio); + auto cleanup = gtl::MakeCleanup([input_times, old_value]() { + (*input_times)[input_times->size() - 1] = old_value; + }); + return NanosPerElementLocked() + + static_cast<int64>( + static_cast<double>(OutputTimeForInputs(input_times)) * ratio); + } + case Type::FLAT_MAP: + case Type::INTERLEAVE: { + // TODO(jsimsa): model the first input + // TODO(jsimsa): use cycle length metadata instead of `inputs_.size() - 1` + if (inputs_.size() 
<= 1) { + return NanosPerElementLocked(); + } + int64 delta = + static_cast<int64>(static_cast<double>(NanosPerElementLocked()) * + static_cast<double>(inputs_.size() - 1)); + (*input_times)[input_times->size() - 1] += delta; + auto cleanup = gtl::MakeCleanup([input_times, delta]() { + (*input_times)[input_times->size() - 1] -= delta; + }); + int64 output_time = OutputTimeForInputs(input_times) - + inputs_.front()->OutputTime(input_times); + return NanosPerElementLocked() + + static_cast<double>(output_time) / + static_cast<double>(inputs_.size() - 1); + } + case Type::MAP_AND_BATCH: { + double batch_size = GetParameterValue("batch_size"); + double parallelism = GetParameterValue("parallelism"); + int64 delta = + static_cast<int64>(static_cast<double>(NanosPerElementLocked()) / + (batch_size * parallelism)); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 output_time = static_cast<int64>( + static_cast<double>(NanosPerElementLocked()) / parallelism + + batch_size * OutputTimeForInputs(input_times)); + return std::max(0LL, + output_time - input_times->at(input_times->size() - 2)); + } + case Type::PARALLEL_INTERLEAVE: { + // TODO(jsimsa): model the first input + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); + } + int64 delta = static_cast<double>(NanosPerElementLocked()) * + static_cast<double>(inputs_.size() - 1); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 inputs_output_time = OutputTimeForInputs(input_times) - + inputs_.front()->OutputTime(input_times); + double parallelism = GetParameterValue("parallelism"); + int64 output_time = + NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) / + static_cast<double>(inputs_.size() - 1)) / + parallelism); + return std::max(0LL, + output_time - input_times->at(input_times->size() - 2)); + } + case Type::PARALLEL_INTERLEAVE_V2: { + // TODO(jsimsa): model the first input + if (inputs_.size() <= 1) { + return NanosPerElementLocked(); + } + int64 delta = static_cast<double>(NanosPerElementLocked()) * + static_cast<double>(inputs_.size() - 1); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 inputs_output_time = OutputTimeForInputs(input_times) - + inputs_.front()->OutputTime(input_times); + double parallelism = + std::min(static_cast<int>(GetParameterValue("cycle_length")), + static_cast<int>(GetParameterValue("parallelism"))); + int64 output_time = + NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) / + static_cast<double>(inputs_.size() - 1)) / + parallelism); + return std::max(0LL, + output_time - input_times->at(input_times->size() - 2)); + } + case Type::PARALLEL_MAP: { + double parallelism = + std::min(port::NumSchedulableCPUs(), + static_cast<int>(GetParameterValue("parallelism"))); + int64 delta = static_cast<int64>( + static_cast<double>(NanosPerElementLocked()) / parallelism); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); }); + int64 output_time = + static_cast<double>(NanosPerElementLocked()) / parallelism + + OutputTimeForInputs(input_times); + return std::max(0LL, + output_time - input_times->at(input_times->size() - 2)); + } + case Type::PREFETCH: { + int64 delta = NanosPerElementLocked(); + input_times->push_back(delta); + auto cleanup = + gtl::MakeCleanup([input_times]() { input_times->pop_back(); 
});
+      return std::max(0LL, NanosPerElementLocked() +
+                               OutputTimeForInputs(input_times) -
+                               input_times->at(input_times->size() - 2));
+    }
+    case Type::CACHE:
+    case Type::CONCATENATE:
+    case Type::MAP:
+    case Type::REPEAT:
+    case Type::SHUFFLE:
+    case Type::SKIP:
+    case Type::TAKE:
+    case Type::ZIP: {
+      int64 delta = NanosPerElementLocked();
+      (*input_times)[input_times->size() - 1] += delta;
+      auto cleanup = gtl::MakeCleanup([input_times, delta]() {
+        (*input_times)[input_times->size() - 1] -= delta;
+      });
+      return NanosPerElementLocked() + OutputTimeForInputs(input_times);
+    }
+    default:
+      return NanosPerElementLocked();
+  }
+}
+
+std::shared_ptr<Node> Model::AddNode(const string& name,
+                                     const string& output_name) {
+  mutex_lock l(mu_);
+  std::shared_ptr<Node> output;
+  auto it = lookup_table_.find(output_name);
+  if (it != lookup_table_.end()) {
+    output = it->second;
+  }
+  std::shared_ptr<Node> node(new Node(id_counter_++, output));
+  if (!output_) {
+    output_ = node;
+  }
+  if (output) {
+    output->add_input(node);
+  }
+  lookup_table_.insert(std::make_pair(name, node));
+  return node;
+}
+
+std::shared_ptr<Node> Model::LookupNode(const string& name) {
+  tf_shared_lock l(mu_);
+  std::shared_ptr<Node> result;
+  auto it = lookup_table_.find(name);
+  if (it != lookup_table_.end()) {
+    result = it->second;
+  }
+  return result;
+}
+
+// The optimization algorithm starts by setting all tunable parallelism
+// parameters to 1. It then repeatedly identifies the parameter whose
+// increase in parallelism decreases the output time the most. This process is
+// repeated until all parameters reach their maximum values or the
+// projected output time is less than or equal to the processing time needed to
+// produce an element divided by the CPU budget.
+void Model::Optimize(int64 cpu_budget) {
+  mutex_lock l(optimization_mu_);
+  std::vector<std::shared_ptr<Node::Tunable>> tunables;
+  {
+    mutex_lock l2(mu_);
+    const int64 processing_time = ProcessingTime();
+    tunables = CollectTunables();
+    for (auto tunable : tunables) {
+      tunable->value = 1;
+    }
+    while (true) {
+      const int64 output_time = OutputTime();
+      bool all_tunables = true;
+      for (auto& tunable : tunables) {
+        if (tunable->value < tunable->max) {
+          all_tunables = false;
+          break;
+        }
+      }
+      if (output_time < processing_time / cpu_budget || all_tunables) {
+        break;
+      }
+      int64 best_delta = -1;
+      Node::Tunable* best_tunable = nullptr;
+      for (auto& tunable : tunables) {
+        if (tunable->value == tunable->max) {
+          continue;
+        }
+        tunable->value++;
+        int64 delta = output_time - OutputTime();
+        if (delta > best_delta) {
+          best_delta = delta;
+          best_tunable = tunable.get();
+        }
+        tunable->value--;
+      }
+      if (!best_tunable) {
+        // NOTE: This can happen because we are performing the optimization
+        // while the model data is changing. If this becomes an issue, we should
+        // look into performing the optimization using a model snapshot.
+        break;
+      }
+      best_tunable->value++;
+    }
+  }
+  // The `set_fn` functions should be invoked without holding a lock to avoid a
+  // potential deadlock.
+  for (auto& tunable : tunables) {
+    tunable->set_fn(tunable->value);
+  }
+}
+
+void Model::RemoveNode(const string& prefix) {
+  // Nodes are not allowed to be removed when optimization is in progress to
+  // prevent the optimization from trying to access an iterator that was
+  // concurrently deleted.
+  mutex_lock l(optimization_mu_);
+  mutex_lock l2(mu_);
+  lookup_table_.erase(prefix);
+}
+
+std::vector<std::shared_ptr<Node::Tunable>> Model::CollectTunables() {
+  std::vector<std::shared_ptr<Node::Tunable>> tunables;
+  output_->CollectTunables(&tunables);
+  return tunables;
+}
+
+int64 Model::OutputTime() {
+  std::vector<int64> input_times(1, 0);
+  return output_->OutputTime(&input_times);
+}
+
+int64 Model::ProcessingTime() { return output_->ProcessingTime(); }
+
+}  // namespace model
+}  // namespace data
+}  // namespace tensorflow
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
new file mode 100644
index 0000000000..f88ec06ef3
--- /dev/null
+++ b/tensorflow/core/framework/model.h
@@ -0,0 +1,379 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+
+#include <list>
+#include <memory>
+#include <string>
+#include <thread>  // (b/114492873): move this include into core/platform
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+namespace data {
+namespace model {
+
+class Model;
+class Node;
+
+// Abstract representation of a TensorFlow input pipeline node. It collects
+// information about inputs to this node, processing time spent executing the
+// node logic, the number of elements produced by the node, and various other
+// information (e.g. batch size or execution parallelism).
+//
+// Developers of tf.data transformations are not expected to interact with this
+// class directly. Boilerplate code for creating the abstract representation of
+// the input pipeline and collecting common information has been added to the
+// implementation of `DatasetBase` and `DatasetBaseIterator`, respectively.
+//
+// In addition, `DatasetBaseIterator` provides wrappers that can be used for
+// transformation-specific information collection. The `AddConstantParameter`
+// and `AddTunableParameter` wrappers can be used to pass parameter values to
+// the modeling framework, while the `StartWork` and `StopWork` wrappers should
+// be used to correctly account for processing time of multi-threaded
+// transformations that yield the CPU; such transformations should invoke
+// `StartWork()` when a transformation thread starts executing (e.g. when
+// created or woken up) and `StopWork()` when a transformation thread stops
+// executing (e.g. when returning or waiting).
+//
+// TODO(jsimsa): Create an API to capture the abstract semantics of each
+// tf.data transformation and replace switch-case blocks with inheritance.
+class Node {
+ public:
+  Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {}
+
+  // Adds a constant parameter.
+ void add_constant_param(const string& name, int64 value) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + constant_params_[name] = value; + } + + // Records that the node produced an element. + void add_element() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + num_elements_++; + } + + // Adds an input. + void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + inputs_.push_back(node); + } + + // Increments the aggregate processing time by the given delta. + void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + processing_time_ += delta; + } + + // Adds a tunable parameter. + void add_tunable_param(const string& name, int64 value, int64 min, int64 max, + std::function<void(int64)>&& set_fn) + LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + tunable_params_[name] = + std::make_shared<Tunable>(value, min, max, std::move(set_fn)); + } + + // Returns the unique node ID. + int64 id() LOCKS_EXCLUDED(mu_) { return id_; } + + // Returns the node inputs. + std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return inputs_; + } + + // Returns the node name. + const string& name() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return name_; + } + + // Returns the number of elements produced by the node. + int64 num_elements() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return num_elements_; + } + + // Returns the node output. + std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return output_; + } + + // Removes an input. + void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + inputs_.remove(input); + } + + // Sets the node name. + void set_name(const string& name) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + name_ = name; + type_ = TypeFromName(name); + } + + // Set the node output. + void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + output_ = output; + } + + // Records that a node thread has started work. + void start_work() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos(); + } + + // Records that a node thread has stopped work. + void stop_work() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + auto iter = work_start_.find(std::this_thread::get_id()); + CHECK(work_start_.end() != iter) + << "Encountered a stop event that was not preceded by a start event."; + processing_time_ += Env::Default()->NowNanos() - iter->second; + work_start_.erase(iter); + } + + private: + // Represents a tunable parameter. + struct Tunable { + Tunable(int64 value, int64 min, int64 max, + std::function<void(int64)> set_fn) + : value(value), min(min), max(max), set_fn(std::move(set_fn)) {} + + int64 value; + int64 min; + int64 max; + std::function<void(int64)> set_fn; + }; + + enum class Type { + BATCH = 0, + CACHE, + CONCATENATE, + FILTER, + FLAT_MAP, + INTERLEAVE, + MAP, + MAP_AND_BATCH, + PADDED_BATCH, + PARALLEL_INTERLEAVE, + PARALLEL_INTERLEAVE_V2, + PARALLEL_MAP, + PREFETCH, + REPEAT, + SHUFFLE, + SKIP, + TAKE, + ZIP, + UNKNOWN, + }; + + // Collects tunable parameters in the subtree rooted in this node. + void CollectTunables(std::vector<std::shared_ptr<Node::Tunable>>* tunables) + LOCKS_EXCLUDED(mu_); + + // Gets a value of the given parameter (tunable or constant). + int64 GetParameterValue(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns the per-element processing time spent in this node. 
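The `start_work()`/`stop_work()` pair above implements a per-thread stopwatch: each thread stamps its start time into `work_start_`, and the matching stop adds the elapsed nanoseconds to `processing_time_`. A self-contained sketch of the same accounting idea, stripped of TF types:

```cpp
#include <chrono>
#include <map>
#include <mutex>
#include <thread>

// Simplified analogue of Node's work accounting (assumption: every Stop()
// is preceded by a matching Start() on the same thread).
class WorkClock {
 public:
  void Start() {
    std::lock_guard<std::mutex> l(mu_);
    start_[std::this_thread::get_id()] = NowNanos();
  }
  void Stop() {
    std::lock_guard<std::mutex> l(mu_);
    auto it = start_.find(std::this_thread::get_id());
    total_nanos_ += NowNanos() - it->second;
    start_.erase(it);
  }
  long long total_nanos() {
    std::lock_guard<std::mutex> l(mu_);
    return total_nanos_;
  }

 private:
  static long long NowNanos() {
    return std::chrono::duration_cast<std::chrono::nanoseconds>(
               std::chrono::steady_clock::now().time_since_epoch())
        .count();
  }

  std::mutex mu_;
  std::map<std::thread::id, long long> start_;
  long long total_nanos_ = 0;
};
```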
+ int64 NanosPerElement() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + return NanosPerElementLocked(); + } + + int64 NanosPerElementLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (num_elements_ == 0) { + return 0; + } + return static_cast<int64>(static_cast<double>(processing_time_) / static_cast<double>(num_elements_)); + } + + // Returns the per-element output time for this node. + int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + return OutputTimeLocked(input_times); + } + + int64 OutputTimeLocked(std::vector<int64>* input_times) + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + int64 OutputTimeForInputs(std::vector<int64>* input_times) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 sum = 0; + for (auto input : inputs_) { + sum += input->OutputTime(input_times); + } + return sum; + } + + // Returns the per-element processing time spent in the subtree rooted in this + // node. + int64 ProcessingTime() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + return ProcessingTimeLocked(); + } + + int64 ProcessingTimeLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Returns the per-element processing time spent in the inputs of this node. + int64 ProcessingTimeForInputs() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 sum = 0; + for (auto input : inputs_) { + sum += input->ProcessingTimeLocked(); + } + return sum; + } + + Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (name_ == "Batch") { + return Type::BATCH; + } + if (str_util::EndsWith(name_, "Cache")) { + return Type::CACHE; + } + if (name_ == "Concatenate") { + return Type::CONCATENATE; + } + if (name_ == "Filter") { + return Type::FILTER; + } + if (name_ == "FlatMap") { + return Type::FLAT_MAP; + } + if (name_ == "Interleave") { + return Type::INTERLEAVE; + } + if (name_ == "Map") { + return Type::MAP; + } + if (name_ == "MapAndBatch") { + return Type::MAP_AND_BATCH; + } + if (name_ == "PaddedBatch") { + return Type::PADDED_BATCH; + } + if (name_ == "ParallelInterleave") { + return Type::PARALLEL_INTERLEAVE; + } + if (name_ == "ParallelInterleaveV2") { + return Type::PARALLEL_INTERLEAVE_V2; + } + if (name_ == "ParallelMap") { + return Type::PARALLEL_MAP; + } + if (name_ == "Prefetch") { + return Type::PREFETCH; + } + if (str_util::EndsWith(name_, "Repeat")) { + return Type::REPEAT; + } + if (name_ == "Shuffle") { + return Type::SHUFFLE; + } + if (str_util::EndsWith(name_, "Skip")) { + return Type::SKIP; + } + if (str_util::EndsWith(name_, "Take")) { + return Type::TAKE; + } + if (name_ == "Zip") { + return Type::ZIP; + } + return Type::UNKNOWN; + } + + mutex mu_; + const int64 id_; + Type type_ GUARDED_BY(mu_); + string name_ GUARDED_BY(mu_); + int64 processing_time_ GUARDED_BY(mu_) = 0; + int64 num_elements_ GUARDED_BY(mu_) = 0; + std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_); + std::map<string, int64> constant_params_ GUARDED_BY(mu_); + // Tunables are shared with the model during optimization. + std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_); + std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_); + std::shared_ptr<Node> output_ GUARDED_BY(mu_); + + friend class Model; +}; + +// Abstract representation of a TensorFlow input pipeline that can be used +// for collecting runtime information and optimizing performance. It collects +// runtime information about execution of the input pipeline that is used to +// create a performance model, which is in turn used to identify optimal values +// of tunable parameters.
+// +// Developers of tf.data transformations are not expected to interact with this +// class directly. Boilerplate code for creating the abstract representation of +// the input pipeline and collecting runtime information has been added to the +// implementation of `DatasetBase` and `DatasetBaseIterator` respectively. +class Model { + public: + Model() = default; + + // Returns the model output node. + std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) { + tf_shared_lock l(mu_); + return output_; + } + + // Adds a node with the given name and given output (identified by name). + std::shared_ptr<Node> AddNode(const string& name, const string& output_name) + LOCKS_EXCLUDED(mu_); + + // Looks up the node using the given name. + std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_); + + // Runs optimization. + void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_); + + // Removes the node identified by the given name. + void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_); + + private: + std::vector<std::shared_ptr<Node::Tunable>> CollectTunables() + EXCLUSIVE_LOCKS_REQUIRED(mu_); + + int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Used for coordination between different input pipeline threads. + mutex mu_; + // Used for preventing iterator deletion when optimization is in progress + // because the optimization may try to update the values of tunable + // parameters. + mutex optimization_mu_ ACQUIRED_BEFORE(mu_); + int64 id_counter_ GUARDED_BY(mu_) = 1; + std::shared_ptr<Node> output_ GUARDED_BY(mu_); + std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_); +}; + +} // namespace model +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_ diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc index bacc1d72c4..42ec315a32 100644 --- a/tensorflow/core/framework/node_def_util.cc +++ b/tensorflow/core/framework/node_def_util.cc @@ -403,6 +403,14 @@ Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, return OutputTypesForNode(node_def, op_def, outputs); } +Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def, + int* num_outputs) { + DataTypeVector outputs; + TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs)); + *num_outputs = outputs.size(); + return Status::OK(); +} + Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) { if (node_def.op() != op_def.name()) { return errors::InvalidArgument("NodeDef op '", node_def.op(), diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 499034cab2..7528d3d306 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -261,6 +261,10 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def, // REQUIRES: ValidateOpDef(op_def).ok() Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def, DataTypeVector* inputs, DataTypeVector* outputs); +// Computes the number of outputs for a specific node. +// REQUIRES: ValidateOpDef(op_def).ok() +Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def, + int* num_outputs); // Validates that the NodeDef: // * Defines all expected attrs from the OpDef.
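For orientation before moving on to the op_kernel.cc changes: a minimal sketch of how the model.h API added above fits together. This is editorial, not part of the commit; the node names, the tunable range, and ExampleWorker are invented for illustration, and the real call sites live in DatasetBase and DatasetBaseIterator.

#include "tensorflow/core/framework/model.h"

namespace tensorflow {
namespace data {

// Hypothetical transformation thread driving the modeling API sketched above.
void ExampleWorker(model::Model* m) {
  // Register a node named "Map" whose output is the node named "Prefetch".
  std::shared_ptr<model::Node> node = m->AddNode("Map", "Prefetch");
  // Expose parallelism to the optimizer as a tunable in [1, 8]; Optimize()
  // calls the callback with whatever value it settles on.
  node->add_tunable_param("parallelism", /*value=*/1, /*min=*/1, /*max=*/8,
                          [](int64 parallelism) { /* apply the new value */ });
  node->start_work();   // thread starts executing (created or woken up)
  // ... produce one element here ...
  node->stop_work();    // thread stops executing (returns or waits)
  node->add_element();  // feeds num_elements_, and hence NanosPerElement()
  m->Optimize(/*cpu_budget=*/port::NumSchedulableCPUs());
}

}  // namespace data
}  // namespace tensorflow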
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc index c694e10193..3e34bf0418 100644 --- a/tensorflow/core/framework/op_kernel.cc +++ b/tensorflow/core/framework/op_kernel.cc @@ -41,6 +41,7 @@ limitations under the License. #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -80,10 +81,8 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs, // OpKernel ------------------------------------------------------------------ -// TODO(mrry): Convert to std::make_unique when available. OpKernel::OpKernel(OpKernelConstruction* context) - : OpKernel(context, - std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {} + : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {} OpKernel::OpKernel(OpKernelConstruction* context, std::unique_ptr<const NodeDef> node_def) @@ -266,9 +265,12 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs) params_->ensure_eigen_gpu_device(); if (params_->eigen_gpu_device != nullptr) { Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes()); - params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device, - params_->op_device_context, - eigen_gpu_allocator); + Status s = params_->device->ReinitializeGpuDevice( + this, params_->eigen_gpu_device, params_->op_device_context, + eigen_gpu_allocator); + if (!s.ok()) { + SetStatus(s); + } } if (params_->record_tensor_accesses) { referenced_tensors_.Init(); @@ -525,10 +527,8 @@ std::unique_ptr<Tensor> OpKernelContext::forward_input( return nullptr; } } - // TODO(rmlarsen): Use MakeUnique here. There is already a copy in - // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of - // general cleanup of ownership in this code. - std::unique_ptr<Tensor> output_tensor(new Tensor()); + + auto output_tensor = MakeUnique<Tensor>(); CHECK(output_tensor->CopyFrom(*input.tensor, output_shape)); return output_tensor; } diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h index e752599de1..4bbd6c3d7d 100644 --- a/tensorflow/core/framework/op_kernel.h +++ b/tensorflow/core/framework/op_kernel.h @@ -372,18 +372,37 @@ class OpKernelConstruction { template <typename ListType, typename ElementType> class OpArgIterator { public: - typedef OpArgIterator<ListType, ElementType> ME; + using iterator_category = std::forward_iterator_tag; + using value_type = ElementType; + using pointer = ElementType*; + using reference = ElementType&; + using difference_type = ptrdiff_t; + OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {} - bool operator==(const ME& rhs) { + + bool operator==(const OpArgIterator& rhs) { DCHECK(list_ == rhs.list_); return i_ == rhs.i_; } - bool operator!=(const ME& rhs) { + + bool operator!=(const OpArgIterator& rhs) { DCHECK(list_ == rhs.list_); return i_ != rhs.i_; } - void operator++() { ++i_; } - ElementType& operator*() { return (*list_)[i_]; } + + OpArgIterator operator++() { // prefix ++it + ++i_; + return *this; + } + + OpArgIterator operator++(int) { // postfix it++ + OpArgIterator old_value = *this; + ++i_; + return old_value; + } + + reference operator*() { return (*list_)[i_]; } + pointer operator->() { return &(*list_)[i_]; } private: const ListType* const list_; @@ -394,7 +413,7 @@ class OpArgIterator { // that are passed to the op as a single named argument. 
class OpInputList { public: - typedef OpArgIterator<OpInputList, const Tensor&> Iterator; + typedef OpArgIterator<OpInputList, const Tensor> Iterator; OpInputList() : ctx_(nullptr), start_(0), stop_(0) {} OpInputList(OpKernelContext* ctx, int start, int stop) : ctx_(ctx), start_(start), stop_(stop) {} diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc index dfc5aa7747..75ed4a4eaf 100644 --- a/tensorflow/core/framework/op_segment.cc +++ b/tensorflow/core/framework/op_segment.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/core/framework/op_segment.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -99,4 +100,11 @@ void OpSegment::RemoveHold(const string& session_handle) { delete item; } +bool OpSegment::ShouldOwnKernel(FunctionLibraryRuntime* lib, + const string& node_op) { + // OpSegment should not own the kernel if the node is stateless or a function. + return lib->IsStateful(node_op) && + lib->GetFunctionLibraryDefinition()->Find(node_op) == nullptr; +} + } // end namespace tensorflow diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h index 4433a2554f..37d939ea2b 100644 --- a/tensorflow/core/framework/op_segment.h +++ b/tensorflow/core/framework/op_segment.h @@ -60,6 +60,10 @@ class OpSegment { Status FindOrCreate(const string& session_handle, const string& node_name, OpKernel** kernel, CreateKernelFn create_fn); + // Returns true if OpSegment should own the kernel. + static bool ShouldOwnKernel(FunctionLibraryRuntime* lib, + const string& node_op); + private: // op name -> OpKernel typedef std::unordered_map<string, OpKernel*> KernelMap; diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index 0a19861efd..ebdaaec153 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -271,7 +271,7 @@ string ContainerInfo::DebugString() const { "]"); } -ResourceHandle HandleFromInput(OpKernelContext* ctx, int input) { +const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input) { return ctx->input(input).flat<ResourceHandle>()(0); } diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index f8a587c9b5..d58deaa3fc 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -79,7 +79,7 @@ class ResourceBase : public core::RefCounted { virtual string DebugString() = 0; // Returns memory used by this resource. - virtual int64 MemoryUsed() const { return 0; }; + virtual int64 MemoryUsed() const { return 0; } }; // Container used for per-step resources. @@ -234,7 +234,7 @@ ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx, const string& name); // Returns a resource handle from a numbered op input.
-ResourceHandle HandleFromInput(OpKernelContext* ctx, int input); +const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input); Status HandleFromInput(OpKernelContext* ctx, StringPiece input, ResourceHandle* handle); @@ -348,6 +348,8 @@ class ResourceHandleOp : public OpKernel { void Compute(OpKernelContext* ctx) override; + bool IsExpensive() override { return false; } + private: string container_; string name_; diff --git a/tensorflow/core/framework/stats_aggregator.h b/tensorflow/core/framework/stats_aggregator.h index 4a18efc940..af53ed0a3c 100644 --- a/tensorflow/core/framework/stats_aggregator.h +++ b/tensorflow/core/framework/stats_aggregator.h @@ -25,6 +25,8 @@ namespace tensorflow { class Summary; +namespace data { + // A `StatsAggregator` accumulates statistics incrementally. A // `StatsAggregator` can accumulate multiple different statistics, distinguished // by a string name. @@ -87,6 +89,7 @@ class StatsAggregatorResource : public ResourceBase { const std::shared_ptr<StatsAggregator> stats_aggregator_; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_ diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 1b19ab5da3..696fd277cd 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -37,11 +37,12 @@ namespace tensorflow { class AllocationDescription; class Allocator; class OpKernelContext; +class Tensor; class TensorBuffer; class TensorCApi; class TensorDescription; class TensorProto; -class VariantTensorData; + namespace batch_util { Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index); Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index); diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index 84a373c196..9a78cdc91e 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/math/math_util.h" diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h index 4bda8f9eb8..a7cf600bab 100644 --- a/tensorflow/core/framework/tensor_util.h +++ b/tensorflow/core/framework/tensor_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_ #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include <vector> diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h index 15b1add2c1..2e96b05787 100644 --- a/tensorflow/core/framework/types.h +++ b/tensorflow/core/framework/types.h @@ -30,7 +30,6 @@ limitations under the License. #include "tensorflow/core/framework/numeric_types.h" #include "tensorflow/core/framework/resource_handle.h" #include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" @@ -39,6 +38,8 @@ limitations under the License. 
namespace tensorflow { +class Variant; + // MemoryType is used to describe whether input or output Tensors of // an OpKernel should reside in "Host memory" (e.g., CPU memory) or // "Device" Memory (CPU memory for CPU devices, GPU memory for GPU diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc index 5a507804b0..d43e3c72ec 100644 --- a/tensorflow/core/framework/variant.cc +++ b/tensorflow/core/framework/variant.cc @@ -23,11 +23,11 @@ limitations under the License. namespace tensorflow { -bool Variant::TryDecode(Variant* out) const { - const VariantTensorDataProto* p = get<VariantTensorDataProto>(); - if (p == nullptr) return false; - VariantTensorData data(*p); - return out->Decode(data); +bool Variant::Decode(VariantTensorData data) { + if (!is_empty()) { + return value_->Decode(std::move(data)); + } + return true; } template <> @@ -54,13 +54,12 @@ string TypeNameVariant(const VariantTensorDataProto& value) { template <> void EncodeVariant(const VariantTensorDataProto& value, VariantTensorData* data) { - data->FromProto(value); + data->FromConstProto(value); } template <> -bool DecodeVariant(const VariantTensorData& data, - VariantTensorDataProto* value) { - data.ToProto(value); +bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value) { + data->ToProto(value); return true; } @@ -70,8 +69,8 @@ void EncodeVariant(const VariantTensorDataProto& value, string* buf) { } template <> -bool DecodeVariant(const string& buf, VariantTensorDataProto* value) { - return value->ParseFromString(buf); +bool DecodeVariant(string* buf, VariantTensorDataProto* value) { + return value->ParseFromString(*buf); } void EncodeVariantList(const Variant* variant_array, int64 n, @@ -93,8 +92,10 @@ bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d, if (variant_array[i].is_empty()) { variant_array[i] = VariantTensorDataProto(); } + // TODO(ebrevdo): Replace with StringPiece? Any way to make this a + // zero-copy operation that keeps a reference to the data in d? string str(d->Data(sizes[i]), sizes[i]); - if (!variant_array[i].Decode(str)) return false; + if (!variant_array[i].Decode(std::move(str))) return false; if (!DecodeUnaryVariant(&variant_array[i])) { LOG(ERROR) << "Could not decode variant with type_name: \"" << variant_array[i].TypeName() diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h index 52732801a0..10eabbc85f 100644 --- a/tensorflow/core/framework/variant.h +++ b/tensorflow/core/framework/variant.h @@ -23,7 +23,6 @@ limitations under the License. #include <unordered_map> #include <utility> -#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove #include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/core/status.h" @@ -38,17 +37,19 @@ string TypeNameVariant(const T& value); template <typename T> string DebugStringVariant(const T& value); +// Allows for specializations of Variant Decoding. `data` may be modified in +// the process of decoding to `value`. 
template <typename T> -void EncodeVariant(const T& value, VariantTensorData* data); +bool DecodeVariant(VariantTensorData* data, T* value); template <typename T> -bool DecodeVariant(const VariantTensorData& data, T* value); +bool DecodeVariant(string* buf, T* value); template <typename T> -void EncodeVariant(const T& value, string* buf); +void EncodeVariant(const T& value, VariantTensorData* data); template <typename T> -bool DecodeVariant(const string& buf, T* value); +void EncodeVariant(const T& value, string* buf); // This is an implementation of a type-erased container that can store an // object of any type. The implementation is very similar to std::any, but has @@ -67,7 +68,7 @@ bool DecodeVariant(const string& buf, T* value); // // string TypeName() const; // void Encode(VariantTensorData* data) const; -// void Decode(const VariantTensorData& data); +// void Decode(VariantTensorData data); // // Simple POD types can elide the Encode/Decode functions, they are provided by // helper methods. @@ -121,7 +122,7 @@ bool DecodeVariant(const string& buf, T* value); // x.Encode(&serialized_f); // // Variant y = Foo(); // default constructed Foo. -// y.Decode(&serialized_f); +// y.Decode(std::move(serialized_f)); // EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>()); // // @@ -145,10 +146,6 @@ bool DecodeVariant(const string& buf, T* value); // EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo. // EXPECT_EQ(MakeTypeIndex<VariantTensorDataProto>(), // y_type_unknown.TypeId()); -// // Decode and get y_type_unknown; compare to value in x. -// Foo f_decoded; -// EXPECT_TRUE(x.MaybeDecodeAndCopy(&f_decoded)); -// EXPECT_EQ(f_decoded, f); // class Variant { public: @@ -241,12 +238,7 @@ class Variant { } // Deserialize `data` and update the stored object. - bool Decode(const VariantTensorData& data) { - if (!is_empty()) { - return value_->Decode(data); - } - return true; - } + bool Decode(VariantTensorData data); // Helper methods to directly serialize/deserialize from strings. 
void Encode(string* buf) const { @@ -254,31 +246,13 @@ class Variant { value_->Encode(buf); } } - bool Decode(const string& buf) { + bool Decode(string buf) { if (!is_empty()) { - return value_->Decode(buf); + return value_->Decode(std::move(buf)); } return true; } - template <typename T> - bool MaybeDecodeAndCopy(T* out) const { - const T* ret = get<T>(); - if (ret != nullptr) { - *out = std::move(*ret); - return true; - }; - Variant decoded = T(); - if (!TryDecode(&decoded)) return false; - T* decoded_ret = decoded.get<T>(); - CHECK_NOTNULL(decoded_ret); - *out = std::move(*decoded_ret); - return true; - } - - private: - bool TryDecode(Variant* out) const; - private: struct in_place_t {}; static constexpr in_place_t in_place{}; @@ -292,9 +266,9 @@ class Variant { virtual string TypeName() const = 0; virtual string DebugString() const = 0; virtual void Encode(VariantTensorData* data) const = 0; - virtual bool Decode(const VariantTensorData& data) = 0; + virtual bool Decode(VariantTensorData data) = 0; virtual void Encode(string* buf) const = 0; - virtual bool Decode(const string& data) = 0; + virtual bool Decode(string data) = 0; }; template <typename T> @@ -325,15 +299,13 @@ class Variant { EncodeVariant(value, data); } - bool Decode(const VariantTensorData& data) override { - return DecodeVariant(data, &value); + bool Decode(VariantTensorData data) override { + return DecodeVariant(&data, &value); } void Encode(string* buf) const override { EncodeVariant(value, buf); } - bool Decode(const string& buf) override { - return DecodeVariant(buf, &value); - } + bool Decode(string buf) override { return DecodeVariant(&buf, &value); } T value; }; diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h index f155aa4892..5e08e5a7a6 100644 --- a/tensorflow/core/framework/variant_encode_decode.h +++ b/tensorflow/core/framework/variant_encode_decode.h @@ -22,6 +22,7 @@ limitations under the License. 
#include <vector> #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/variant_tensor_data.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/abi.h" @@ -81,7 +82,7 @@ void EncodeVariantImpl(const T& value, // Specialization for POD type template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, true /* is_pod */, false /* Tensor */, false /* protobuf */>, T* value) { @@ -90,7 +91,7 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for tensorflow::Tensor template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, false /* is_pod */, true /* Tensor */, false /* protobuf */>, T* value) { @@ -100,7 +101,7 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for protobuf template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, false /* is_pod */, false /* Tensor */, true /* protobuf */>, T* value) { @@ -111,11 +112,11 @@ bool DecodeVariantImpl(const VariantTensorData& data, // Specialization for other types template <typename T> -bool DecodeVariantImpl(const VariantTensorData& data, +bool DecodeVariantImpl(VariantTensorData data, TypeResolver<T, false /* is_pod */, false /* Tensor */, false /* protobuf */>, T* value) { - return value->Decode(data); + return value->Decode(std::move(data)); } template <typename C, typename = void> @@ -224,8 +225,8 @@ void EncodeVariant(const T& value, VariantTensorData* data) { } template <typename T> -bool DecodeVariant(const VariantTensorData& data, T* value) { - return DecodeVariantImpl(data, TypeResolver<T>(), value); +bool DecodeVariant(VariantTensorData* data, T* value) { + return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value); } template <typename T> @@ -238,26 +239,31 @@ void EncodeVariant(const T& value, string* buf) { } template <typename T> -bool DecodeVariant(const string& buf, T* value) { +bool DecodeVariant(string* buf, T* value) { VariantTensorData data; - if (!data.ParseFromString(buf)) return false; - if (!DecodeVariantImpl(data, TypeResolver<T>(), value)) return false; + if (!data.ParseFromString(*buf)) return false; + if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) { + return false; + } return true; } // Specializations for VariantTensorDataProto template <> string TypeNameVariant(const VariantTensorDataProto& value); + template <> void EncodeVariant(const VariantTensorDataProto& value, VariantTensorData* data); + template <> -bool DecodeVariant(const VariantTensorData& data, - VariantTensorDataProto* value); +bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value); + template <> void EncodeVariant(const VariantTensorDataProto& value, string* buf); + template <> -bool DecodeVariant(const string& buf, VariantTensorDataProto* value); +bool DecodeVariant(string* buf, VariantTensorDataProto* value); // Encodes an array of Variant objects in to the given StringListEncoder. // `variant_array` is assumed to point to an array of `n` Variant objects. 
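To make the new by-value decode contract concrete, here is a sketch of a user-defined variant type written against the signatures above. It is illustrative only: MyCounter is an invented type, and the sketch assumes the existing set_metadata/get_metadata helpers on VariantTensorData.

#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_tensor_data.h"

namespace tensorflow {

struct MyCounter {
  int64 count = 0;

  string TypeName() const { return "MyCounter"; }
  void Encode(VariantTensorData* data) const { data->set_metadata(count); }
  // Decode now receives its argument by value, so an implementation is free
  // to move buffers out of `data` instead of copying them.
  bool Decode(VariantTensorData data) { return data.get_metadata(&count); }
};

}  // namespace tensorflow

Round-tripping through Variant then moves the serialized state end to end, exactly as in the updated variant.h comment: Variant y = MyCounter(); y.Decode(std::move(buf)); consumes buf rather than copying it.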
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc index 60fa7bd559..daa744e877 100644 --- a/tensorflow/core/framework/variant_op_copy_test.cc +++ b/tensorflow/core/framework/variant_op_copy_test.cc @@ -90,15 +90,15 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(StoredTensorValue, "StoredTensorValue"); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( StoredTensorValue, VariantDeviceCopyDirection::HOST_TO_DEVICE, - "StoredTensorValue", StoredTensorValue::CopyCPUToGPU); + StoredTensorValue::CopyCPUToGPU); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_HOST, - "StoredTensorValue", StoredTensorValue::CopyGPUToCPU); + StoredTensorValue::CopyGPUToCPU); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_DEVICE, - "StoredTensorValue", StoredTensorValue::CopyGPUToGPU); + StoredTensorValue::CopyGPUToGPU); REGISTER_OP("CreateTestVariant") .Input("input: T") diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc index ee07db1aee..ef5b240aea 100644 --- a/tensorflow/core/framework/variant_op_registry.cc +++ b/tensorflow/core/framework/variant_op_registry.cc @@ -38,21 +38,19 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() { } UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn( - StringPiece type_name) { - auto found = shape_fns.find(type_name); + const TypeIndex& type_index) { + auto found = shape_fns.find(type_index); if (found == shape_fns.end()) return nullptr; return &found->second; } -void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name, +void UnaryVariantOpRegistry::RegisterShapeFn(const TypeIndex& type_index, const VariantShapeFn& shape_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape"; - VariantShapeFn* existing = GetShapeFn(type_name); + VariantShapeFn* existing = GetShapeFn(type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantShapeFn for type_name: " << type_name - << " already registered"; - shape_fns.insert(std::pair<StringPiece, VariantShapeFn>( - GetPersistentStringPiece(type_name), shape_fn)); + << "Unary VariantShapeFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered"; + shape_fns.insert(std::pair<TypeIndex, VariantShapeFn>(type_index, shape_fn)); } Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) { @@ -60,11 +58,11 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) { CHECK_EQ(variant_tensor.dims(), 0); const Variant& v = variant_tensor.scalar<Variant>()(); UnaryVariantOpRegistry::VariantShapeFn* shape_fn = - UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName()); + UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeId()); if (shape_fn == nullptr) { return errors::Internal( - "No unary variant shape function found for Variant type_name: ", - v.TypeName()); + "No unary variant shape function found for Variant type_index: ", + port::MaybeAbiDemangle(v.TypeId().name())); } return (*shape_fn)(v, shape); } @@ -79,7 +77,7 @@ Status ScalarShape(const T&, TensorShape* shape) { } // namespace #define REGISTER_VARIANT_SHAPE_TYPE(T) \ - REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>); + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, ScalarShape<T>); // No encode/shape registered for std::complex<> and Eigen::half // objects yet. 
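The call-site effect of the re-keying above is that the string argument simply drops out of each registration macro: the key becomes MakeTypeIndex<T>(), computed from the registered type itself, so registrations can no longer collide merely by reusing a string. A sketch with an invented MyVariant type, where MyVariantShapeFn and MyVariantAddFn are assumed to have the Status(const MyVariant&, TensorShape*) and Status(OpKernelContext*, const MyVariant&, const MyVariant&, MyVariant*) signatures used throughout this file. (The decode registry keeps its string key, since decoding must start from a serialized type_name rather than a live C++ object.)

// Before this change, call sites supplied a string key:
//   REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(MyVariant, "MyVariant",
//                                         MyVariantShapeFn);
// After it, the type itself is the key:
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(MyVariant, MyVariantShapeFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          MyVariant, MyVariantAddFn);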
@@ -143,25 +141,24 @@ REGISTER_VARIANT_DECODE_TYPE(double); UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* UnaryVariantOpRegistry::GetDeviceCopyFn( - const VariantDeviceCopyDirection direction, StringPiece type_name) { - auto found = device_copy_fns.find(std::make_pair(direction, type_name)); + const VariantDeviceCopyDirection direction, const TypeIndex& type_index) { + auto found = device_copy_fns.find(std::make_pair(direction, type_index)); if (found == device_copy_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterDeviceCopyFn( - const VariantDeviceCopyDirection direction, const string& type_name, + const VariantDeviceCopyDirection direction, const TypeIndex& type_index, const AsyncVariantDeviceCopyFn& device_copy_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy"; - AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name); + AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index); CHECK_EQ(existing, nullptr) << "UnaryVariantDeviceCopy for direction: " << direction - << " and type_name: " << type_name << " already registered"; + << " and type_index: " << port::MaybeAbiDemangle(type_index.name()) + << " already registered"; device_copy_fns.insert( - std::pair<std::pair<VariantDeviceCopyDirection, StringPiece>, - AsyncVariantDeviceCopyFn>( - std::make_pair(direction, GetPersistentStringPiece(type_name)), - device_copy_fn)); + std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>, + AsyncVariantDeviceCopyFn>(std::make_pair(direction, type_index), + device_copy_fn)); } Status VariantDeviceCopy( @@ -170,35 +167,34 @@ Status VariantDeviceCopy( const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) { UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction, - from.TypeName()); + from.TypeId()); if (device_copy_fn == nullptr) { return errors::Internal( "No unary variant device copy function found for direction: ", - direction, " and Variant type_name: ", from.TypeName()); + direction, " and Variant type_index: ", + port::MaybeAbiDemangle(from.TypeId().name())); } return (*device_copy_fn)(from, to, copy_fn); } // Special casing UnaryOpFn per op and per device. 
UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn( - VariantUnaryOp op, StringPiece device, StringPiece type_name) { - auto found = unary_op_fns.find({op, device, type_name}); + VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) { + auto found = unary_op_fns.find({op, device, type_index}); if (found == unary_op_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterUnaryOpFn( - VariantUnaryOp op, const string& device, const string& type_name, + VariantUnaryOp op, const string& device, const TypeIndex& type_index, const VariantUnaryOpFn& unary_op_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp"; - VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name); + VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantUnaryOpFn for type_name: " << type_name + << "Unary VariantUnaryOpFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered for device type: " << device; unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>( - {op, GetPersistentStringPiece(device), - GetPersistentStringPiece(type_name)}, - unary_op_fn)); + {op, GetPersistentStringPiece(device), type_index}, unary_op_fn)); } namespace { @@ -212,7 +208,7 @@ Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t, #define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \ - DEVICE_CPU, T, TF_STR(T), \ + DEVICE_CPU, T, \ ZerosLikeVariantPrimitiveType<T>); // No zeros_like registered for std::complex<> or Eigen::half objects yet. @@ -226,24 +222,22 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool); // Special casing BinaryOpFn per op and per device. UnaryVariantOpRegistry::VariantBinaryOpFn* UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device, - StringPiece type_name) { - auto found = binary_op_fns.find({op, device, type_name}); + const TypeIndex& type_index) { + auto found = binary_op_fns.find({op, device, type_index}); if (found == binary_op_fns.end()) return nullptr; return &found->second; } void UnaryVariantOpRegistry::RegisterBinaryOpFn( - VariantBinaryOp op, const string& device, const string& type_name, + VariantBinaryOp op, const string& device, const TypeIndex& type_index, const VariantBinaryOpFn& add_fn) { - CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp"; - VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name); + VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index); CHECK_EQ(existing, nullptr) - << "Unary VariantBinaryOpFn for type_name: " << type_name + << "Unary VariantBinaryOpFn for type_index: " + << port::MaybeAbiDemangle(type_index.name()) << " already registered for device type: " << device; binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>( - {op, GetPersistentStringPiece(device), - GetPersistentStringPiece(type_name)}, - add_fn)); + {op, GetPersistentStringPiece(device), type_index}, add_fn)); } namespace { @@ -257,8 +251,7 @@ Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b, #define REGISTER_VARIANT_ADD_TYPE(T) \ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \ - T, TF_STR(T), \ - AddVariantPrimitiveType<T>); + T, AddVariantPrimitiveType<T>); // No add registered for std::complex<> or Eigen::half objects yet. 
REGISTER_VARIANT_ADD_TYPE(int); diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h index e6a2665a56..7eb37e859f 100644 --- a/tensorflow/core/framework/variant_op_registry.h +++ b/tensorflow/core/framework/variant_op_registry.h @@ -22,10 +22,14 @@ limitations under the License. #define EIGEN_USE_THREADS +#include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/framework/type_index.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/variant.h" #include "tensorflow/core/framework/variant_encode_decode.h" +#include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/abi.h" namespace tensorflow { @@ -90,10 +94,11 @@ class UnaryVariantOpRegistry { AsyncVariantDeviceCopyFn; // Add a shape lookup function to the registry. - void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn); + void RegisterShapeFn(const TypeIndex& type_index, + const VariantShapeFn& shape_fn); - // Returns nullptr if no shape function was found for the given TypeName. - VariantShapeFn* GetShapeFn(StringPiece type_name); + // Returns nullptr if no shape function was found for the given TypeIndex. + VariantShapeFn* GetShapeFn(const TypeIndex& type_index); // Add a decode function to the registry. void RegisterDecodeFn(const string& type_name, @@ -104,33 +109,33 @@ class UnaryVariantOpRegistry { // Add a copy-to-GPU function to the registry. void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction, - const string& type_name, + const TypeIndex& type_index, const AsyncVariantDeviceCopyFn& device_copy_fn); // Returns nullptr if no copy function was found for the given // TypeName and direction. AsyncVariantDeviceCopyFn* GetDeviceCopyFn( - const VariantDeviceCopyDirection direction, StringPiece type_name); + const VariantDeviceCopyDirection direction, const TypeIndex& type_index); // Add a unary op function to the registry. void RegisterUnaryOpFn(VariantUnaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const VariantUnaryOpFn& unary_op_fn); // Returns nullptr if no unary op function was found for the given // op, device, and TypeName. VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device, - StringPiece type_name); + const TypeIndex& type_index); // Add a binary op function to the registry. void RegisterBinaryOpFn(VariantBinaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const VariantBinaryOpFn& add_fn); // Returns nullptr if no binary op function was found for the given // op, device and TypeName. VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device, - StringPiece type_name); + const TypeIndex& type_index); // Get a pointer to a global UnaryVariantOpRegistry object static UnaryVariantOpRegistry* Global(); @@ -145,24 +150,26 @@ class UnaryVariantOpRegistry { static std::unordered_set<string>* PersistentStringStorage(); private: - std::unordered_map<StringPiece, VariantShapeFn, StringPieceHasher> shape_fns; - std::unordered_map<StringPiece, VariantDecodeFn, StringPieceHasher> - decode_fns; + struct TypeIndexHash { + std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); } + }; + + gtl::FlatMap<TypeIndex, VariantShapeFn, TypeIndexHash> shape_fns; + gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns; // Map std::pair<Direction, type_name> to function. 
struct PairHash { template <typename Direction> - std::size_t operator()(const std::pair<Direction, StringPiece>& x) const { + std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const { // The hash of an enum is just its value as a std::size_t. std::size_t ret = static_cast<std::size_t>(std::get<0>(x)); - ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x))); + ret = Hash64Combine(ret, std::get<1>(x).hash_code()); return ret; } - StringPieceHasher sp_hasher_; }; - std::unordered_map<std::pair<VariantDeviceCopyDirection, StringPiece>, - AsyncVariantDeviceCopyFn, PairHash> + gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>, + AsyncVariantDeviceCopyFn, PairHash> device_copy_fns; // Map std::tuple<Op, device, type_name> to function. @@ -172,10 +179,11 @@ class UnaryVariantOpRegistry { // and references therein template <typename Op> struct FuncTuple { - FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname) - : op_type_(op), device_(dev), typename_(tname){}; + FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index) + : op_type_(op), device_(dev), type_index_(type_index) {} Op op_type_; - StringPiece device_, typename_; + StringPiece device_; + TypeIndex type_index_; }; // friend declaration for operator== // needed for clang @@ -184,11 +192,11 @@ class UnaryVariantOpRegistry { struct TupleHash { template <typename Op> std::size_t operator()( - const std::tuple<Op, StringPiece, StringPiece>& x) const { + const std::tuple<Op, StringPiece, TypeIndex>& x) const { // The hash of an enum is just its value as a std::size_t. std::size_t ret = static_cast<std::size_t>(std::get<0>(x)); ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x))); - ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x))); + ret = Hash64Combine(ret, std::get<2>(x).hash_code()); return ret; } @@ -197,14 +205,14 @@ class UnaryVariantOpRegistry { // The hash of an enum is just its value as a std::size_t. std::size_t ret = static_cast<std::size_t>(x.op_type_); ret = Hash64Combine(ret, sp_hasher_(x.device_)); - ret = Hash64Combine(ret, sp_hasher_(x.typename_)); + ret = Hash64Combine(ret, x.type_index_.hash_code()); return ret; } StringPieceHasher sp_hasher_; }; - std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash> + gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash> unary_op_fns; - std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash> + gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash> binary_op_fns; // Find or insert a string into a persistent string storage @@ -225,7 +233,7 @@ template <typename Op> inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs, const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) { return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) && - (lhs.typename_ == rhs.typename_); + (lhs.type_index_ == rhs.type_index_); } // Gets a TensorShape from a Tensor containing a scalar Variant. 
// Returns an Internal error if the Variant does not have a registered shape @@ -276,7 +284,7 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v, Variant* v_out) { const string& device = DeviceName<Device>::value; UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn = - UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName()); + UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId()); if (unary_op_fn == nullptr) { return errors::Internal( "No unary variant unary_op function found for unary variant op enum: ", @@ -297,15 +305,15 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v, template <typename Device> Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op, const Variant& a, const Variant& b, Variant* out) { - if (a.TypeName() != b.TypeName()) { + if (a.TypeId() != b.TypeId()) { return errors::Internal( "BianryOpVariants: Variants a and b have different " - "type names: '", + "type ids. Type names: '", a.TypeName(), "' vs. '", b.TypeName(), "'"); } const string& device = DeviceName<Device>::value; UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn = - UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName()); + UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId()); if (binary_op_fn == nullptr) { return errors::Internal( "No unary variant binary_op function found for binary variant op " @@ -323,16 +331,18 @@ class UnaryVariantShapeRegistration { public: typedef std::function<Status(const T& t, TensorShape*)> LocalVariantShapeFn; - UnaryVariantShapeRegistration(const string& type_name, + UnaryVariantShapeRegistration(const TypeIndex& type_index, const LocalVariantShapeFn& shape_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterShapeFn( - type_name, - [type_name, shape_fn](const Variant& v, TensorShape* s) -> Status { + type_index, + [type_index_name, shape_fn](const Variant& v, + TensorShape* s) -> Status { const T* t = v.get<T>(); if (t == nullptr) { return errors::Internal( - "VariantShapeFn: Could not access object, type_name: ", - type_name); + "VariantShapeFn: Could not access object, type_index: ", + type_index_name); } return shape_fn(*t, s); }); @@ -355,11 +365,11 @@ class UnaryVariantDecodeRegistration { return false; } Variant decoded = T(); - VariantTensorData data(*t); - if (!decoded.Decode(data)) { + VariantTensorData data(std::move(*t)); + if (!decoded.Decode(std::move(data))) { return false; } - *v = std::move(decoded); + std::swap(decoded, *v); return true; }); } @@ -372,11 +382,12 @@ class UnaryVariantDeviceCopyRegistration { UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)> LocalVariantDeviceCopyFn; UnaryVariantDeviceCopyRegistration( - const VariantDeviceCopyDirection direction, const string& type_name, + const VariantDeviceCopyDirection direction, const TypeIndex& type_index, const LocalVariantDeviceCopyFn& device_copy_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn( - direction, type_name, - [type_name, device_copy_fn]( + direction, type_index, + [type_index_name, device_copy_fn]( const Variant& from, Variant* to, UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn device_copy_tensor_fn) -> Status { @@ -384,8 +395,8 @@ class UnaryVariantDeviceCopyRegistration { *to = T(); if (from.get<T>() == nullptr) { return errors::Internal( - "VariantCopyToGPUFn: Could not access 
object, type_name: ", - type_name); + "VariantCopyToGPUFn: Could not access object, type_index: ", + type_index_name); } const T& t = *from.get<T>(); T* t_out = to->get<T>(); @@ -401,18 +412,19 @@ class UnaryVariantUnaryOpRegistration { public: UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const LocalVariantUnaryOpFn& unary_op_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn( - op, device, type_name, - [type_name, unary_op_fn](OpKernelContext* ctx, const Variant& v, - Variant* v_out) -> Status { + op, device, type_index, + [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v, + Variant* v_out) -> Status { DCHECK_NE(v_out, nullptr); *v_out = T(); if (v.get<T>() == nullptr) { return errors::Internal( - "VariantUnaryOpFn: Could not access object, type_name: ", - type_name); + "VariantUnaryOpFn: Could not access object, type_index: ", + type_index_name); } const T& t = *v.get<T>(); T* t_out = v_out->get<T>(); @@ -429,23 +441,25 @@ class UnaryVariantBinaryOpRegistration { public: UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device, - const string& type_name, + const TypeIndex& type_index, const LocalVariantBinaryOpFn& binary_op_fn) { + const string type_index_name = port::MaybeAbiDemangle(type_index.name()); UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn( - op, device, type_name, - [type_name, binary_op_fn](OpKernelContext* ctx, const Variant& a, - const Variant& b, Variant* out) -> Status { + op, device, type_index, + [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a, + const Variant& b, + Variant* out) -> Status { DCHECK_NE(out, nullptr); *out = T(); if (a.get<T>() == nullptr) { return errors::Internal( - "VariantBinaryOpFn: Could not access object 'a', type_name: ", - type_name); + "VariantBinaryOpFn: Could not access object 'a', type_index: ", + type_index_name); } if (b.get<T>() == nullptr) { return errors::Internal( - "VariantBinaryOpFn: Could not access object 'b', type_name: ", - type_name); + "VariantBinaryOpFn: Could not access object 'b', type_index: ", + type_index_name); } const T& t_a = *a.get<T>(); const T& t_b = *b.get<T>(); @@ -459,19 +473,19 @@ class UnaryVariantBinaryOpRegistration { // Register a unary shape variant function with the signature: // Status ShapeFn(const T& t, TensorShape* s); -// to Variants having TypeName type_name. -#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, type_name, shape_function) \ - REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name, \ - shape_function) +// to Variants having TypeIndex type_index. 
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, shape_function) \ + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, T, MakeTypeIndex<T>(), shape_function) -#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_name, \ - shape_function) \ - REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, shape_function) +#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_index, \ + shape_function) \ + REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, shape_function) -#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, \ +#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, \ shape_function) \ static variant_op_registry_fn_registration::UnaryVariantShapeRegistration<T> \ - register_unary_variant_op_shape_registration_fn_##ctr(type_name, \ + register_unary_variant_op_shape_registration_fn_##ctr(type_index, \ shape_function) // Register a unary decode variant function for the given type. @@ -519,63 +533,63 @@ class UnaryVariantBinaryOpRegistration { // ****** NOTE ****** // FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE. // ****** NOTE ****** -#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ - T, direction, type_name, device_copy_fn) \ - INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ - __COUNTER__, T, direction, type_name, device_copy_fn) +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction, \ + device_copy_fn) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, T, direction, MakeTypeIndex<T>(), device_copy_fn) #define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \ - ctr, T, direction, type_name, device_copy_fn) \ + ctr, T, direction, type_index, device_copy_fn) \ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ - ctr, T, direction, type_name, device_copy_fn) + ctr, T, direction, type_index, device_copy_fn) -#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ - ctr, T, direction, type_name, device_copy_fn) \ - static variant_op_registry_fn_registration:: \ - UnaryVariantDeviceCopyRegistration<T> \ - register_unary_variant_op_device_copy_fn_##ctr(direction, type_name, \ - device_copy_fn) +#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \ + ctr, T, direction, type_index, device_copy_fn) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantDeviceCopyRegistration<T> \ + register_unary_variant_op_device_copy_fn_##ctr( \ + direction, type_index, device_copy_fn) // Register a unary unary_op variant function with the signature: // Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out); -// to Variants having TypeName type_name, for device string device, +// to Variants having TypeIndex type_index, for device string device, // for UnaryVariantOp enum op. 
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \ - unary_op_function) \ - REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ - __COUNTER__, op, device, T, type_name, unary_op_function) +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, \ + unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, MakeTypeIndex<T>(), unary_op_function) -#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ - ctr, op, device, T, type_name, unary_op_function) \ - REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \ - unary_op_function) +#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \ + ctr, op, device, T, type_index, unary_op_function) \ + REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \ + type_index, unary_op_function) #define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \ - ctr, op, device, T, type_name, unary_op_function) \ + ctr, op, device, T, type_index, unary_op_function) \ static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \ T> \ - register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \ unary_op_function) // Register a binary_op variant function with the signature: // Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out); -// to Variants having TypeName type_name, for device string device, +// to Variants having TypeIndex type_index, for device string device, // for BinaryVariantOp enum OP. -#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \ - binary_op_function) \ - REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ - __COUNTER__, op, device, T, type_name, binary_op_function) +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, \ + binary_op_function) \ + REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ + __COUNTER__, op, device, T, MakeTypeIndex<T>(), binary_op_function) #define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \ - ctr, op, device, T, type_name, binary_op_function) \ + ctr, op, device, T, type_index, binary_op_function) \ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ - ctr, op, device, T, type_name, binary_op_function) + ctr, op, device, T, type_index, binary_op_function) -#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ - ctr, op, device, T, type_name, binary_op_function) \ - static variant_op_registry_fn_registration:: \ - UnaryVariantBinaryOpRegistration<T> \ - register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \ +#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \ + ctr, op, device, T, type_index, binary_op_function) \ + static variant_op_registry_fn_registration:: \ + UnaryVariantBinaryOpRegistration<T> \ + register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \ binary_op_function) } // end namespace tensorflow diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc index 7055e62c0e..b2443e8676 100644 --- a/tensorflow/core/framework/variant_op_registry_test.cc +++ b/tensorflow/core/framework/variant_op_registry_test.cc @@ -89,41 +89,37 @@ struct VariantValue { int value; }; -REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue", - VariantValue::ShapeFn); +REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, VariantValue::ShapeFn); REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST 
VariantValue"); INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE, - "TEST VariantValue", VariantValue::CPUToGPUCopyFn); + VariantValue::CPUToGPUCopyFn); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, VariantValue, - "TEST VariantValue", VariantValue::CPUZerosLikeFn); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, VariantValue, - "TEST VariantValue", VariantValue::GPUZerosLikeFn); REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, - VariantValue, "TEST VariantValue", - VariantValue::CPUAddFn); + VariantValue, VariantValue::CPUAddFn); REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU, - VariantValue, "TEST VariantValue", - VariantValue::GPUAddFn); + VariantValue, VariantValue::GPUAddFn); } // namespace TEST(VariantOpShapeRegistryTest, TestBasic) { - EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"), + class Blah {}; + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex<Blah>()), nullptr); - auto* shape_fn = - UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue"); + auto* shape_fn = UnaryVariantOpRegistry::Global()->GetShapeFn( + MakeTypeIndex<VariantValue>()); EXPECT_NE(shape_fn, nullptr); TensorShape shape; @@ -142,10 +138,11 @@ TEST(VariantOpShapeRegistryTest, TestBasic) { TEST(VariantOpShapeRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantShapeFn f; - string kTypeName = "fjfjfj"; - registry.RegisterShapeFn(kTypeName, f); - EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f), - "fjfjfj already registered"); + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); + registry.RegisterShapeFn(kTypeIndex, f); + EXPECT_DEATH(registry.RegisterShapeFn(kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpDecodeRegistryTest, TestBasic) { @@ -180,13 +177,14 @@ TEST(VariantOpDecodeRegistryTest, TestDuplicate) { TEST(VariantOpCopyToGPURegistryTest, TestBasic) { // No registered copy fn for GPU<->GPU. 
- EXPECT_EQ( - UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( - VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"), - nullptr); + EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( + VariantDeviceCopyDirection::DEVICE_TO_DEVICE, + MakeTypeIndex<VariantValue>()), + nullptr); auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn( - VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue"); + VariantDeviceCopyDirection::HOST_TO_DEVICE, + MakeTypeIndex<VariantValue>()); EXPECT_NE(copy_to_gpu_fn, nullptr); VariantValue vv{true /* early_exit */}; @@ -208,17 +206,19 @@ TEST(VariantOpCopyToGPURegistryTest, TestBasic) { TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE, - kTypeName, f); + kTypeIndex, f); EXPECT_DEATH(registry.RegisterDeviceCopyFn( - VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f), - "fjfjfj already registered"); + VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( - ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; @@ -242,8 +242,9 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) { #if GOOGLE_CUDA TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn( - ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"), + ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 0 /* value */}; @@ -269,25 +270,26 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) { TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantUnaryOpFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); - registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName, - f); + registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, + kTypeIndex, f); EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, - DEVICE_CPU, kTypeName, f), - "fjfjfj already registered"); + DEVICE_CPU, kTypeIndex, f), + "FjFjFj already registered"); - registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName, - f); + registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, + kTypeIndex, f); EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, - DEVICE_GPU, kTypeName, f), - "fjfjfj already registered"); + DEVICE_GPU, kTypeIndex, f), + "FjFjFj already registered"); } TEST(VariantOpAddRegistryTest, TestBasicCPU) { - return; + class Blah {}; EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( - ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"), + ADD_VARIANT_BINARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; @@ -312,8 +314,9 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) { #if GOOGLE_CUDA TEST(VariantOpAddRegistryTest, TestBasicGPU) { + class Blah {}; 
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn( - ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"), + ADD_VARIANT_BINARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()), nullptr); VariantValue vv_early_exit{true /* early_exit */, 3 /* value */}; @@ -340,17 +343,18 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) { TEST(VariantOpAddRegistryTest, TestDuplicate) { UnaryVariantOpRegistry registry; UnaryVariantOpRegistry::VariantBinaryOpFn f; - string kTypeName = "fjfjfj"; + class FjFjFj {}; + const auto kTypeIndex = MakeTypeIndex<FjFjFj>(); - registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f); + registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f); EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, - kTypeName, f), - "fjfjfj already registered"); + kTypeIndex, f), + "FjFjFj already registered"); - registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f); + registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeIndex, f); EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, - kTypeName, f), - "fjfjfj already registered"); + kTypeIndex, f), + "FjFjFj already registered"); } } // namespace tensorflow diff --git a/tensorflow/core/framework/variant_tensor_data.cc b/tensorflow/core/framework/variant_tensor_data.cc index 99712dc114..3e67e4a864 100644 --- a/tensorflow/core/framework/variant_tensor_data.cc +++ b/tensorflow/core/framework/variant_tensor_data.cc @@ -22,8 +22,8 @@ namespace tensorflow { VariantTensorData::VariantTensorData() {} -VariantTensorData::VariantTensorData(const VariantTensorDataProto& proto) { - FromProto(proto); +VariantTensorData::VariantTensorData(VariantTensorDataProto proto) { + FromProto(std::move(proto)); } VariantTensorData::~VariantTensorData() {} @@ -52,7 +52,19 @@ void VariantTensorData::ToProto(VariantTensorDataProto* proto) const { } } -bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) { +bool VariantTensorData::FromProto(VariantTensorDataProto proto) { + // TODO(ebrevdo): Do this lazily. + set_type_name(proto.type_name()); + set_metadata(proto.metadata()); + for (const auto& tensor : proto.tensors()) { + Tensor tmp; + if (!tmp.FromProto(tensor)) return false; + tensors_.push_back(tmp); + } + return true; +} + +bool VariantTensorData::FromConstProto(const VariantTensorDataProto& proto) { set_type_name(proto.type_name()); set_metadata(proto.metadata()); for (const auto& tensor : proto.tensors()) { @@ -75,10 +87,10 @@ bool VariantTensorData::SerializeToString(string* buf) { return proto.SerializeToString(buf); } -bool VariantTensorData::ParseFromString(const string& s) { +bool VariantTensorData::ParseFromString(string s) { VariantTensorDataProto proto; const bool status = proto.ParseFromString(s); - if (status) FromProto(proto); + if (status) FromProto(std::move(proto)); return status; } diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h index 7500e77d43..8a240ee1e3 100644 --- a/tensorflow/core/framework/variant_tensor_data.h +++ b/tensorflow/core/framework/variant_tensor_data.h @@ -19,13 +19,13 @@ limitations under the License. #include <algorithm> #include <vector> +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { class VariantTensorDataProto; -class Tensor; // The serialization format for Variant objects. 
Objects with references to // other Tensors can simply store those tensors in the `tensors` field, and @@ -38,7 +38,7 @@ class Tensor; class VariantTensorData { public: VariantTensorData(); - VariantTensorData(const VariantTensorDataProto& proto); + VariantTensorData(VariantTensorDataProto proto); ~VariantTensorData(); // Name of the type of objects being serialized. @@ -68,12 +68,14 @@ class VariantTensorData { // Conversion to and from VariantTensorDataProto void ToProto(VariantTensorDataProto* proto) const; - bool FromProto(const VariantTensorDataProto& proto); + // This allows optimizations via std::move. + bool FromProto(VariantTensorDataProto proto); + bool FromConstProto(const VariantTensorDataProto& proto); // Serialization via VariantTensorDataProto string SerializeAsString() const; bool SerializeToString(string* buf); - bool ParseFromString(const string& s); + bool ParseFromString(string s); string DebugString() const; diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc index eef5c47d15..08d09de7b8 100644 --- a/tensorflow/core/framework/variant_test.cc +++ b/tensorflow/core/framework/variant_test.cc @@ -144,8 +144,8 @@ TEST(VariantTest, TypeMismatch) { struct TensorList { void Encode(VariantTensorData* data) const { data->tensors_ = vec; } - bool Decode(const VariantTensorData& data) { - vec = data.tensors_; + bool Decode(VariantTensorData data) { + vec = std::move(data.tensors_); return true; } @@ -186,7 +186,7 @@ TEST(VariantTest, TensorListTest) { x.Encode(&serialized); Variant y = TensorList(); - y.Decode(serialized); + y.Decode(std::move(serialized)); const TensorList& decoded_vec = *y.get<TensorList>(); for (int i = 0; i < 4; ++i) { @@ -204,15 +204,6 @@ TEST(VariantTest, TensorListTest) { EXPECT_EQ(y_unknown.DebugString(), strings::StrCat( "Variant<type: TensorList value: ", data.DebugString(), ">")); - - TensorList unknown_decoded_vec; - EXPECT_TRUE(y_unknown.MaybeDecodeAndCopy(&unknown_decoded_vec)); - for (int i = 0; i < 4; ++i) { - EXPECT_EQ(unknown_decoded_vec.vec[i].flat<int>()(0), i); - } - for (int i = 0; i < 4; ++i) { - EXPECT_EQ(unknown_decoded_vec.vec[i + 4].flat<float>()(0), 2 * i); - } } TEST(VariantTest, VariantArray) { diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc index ee10194142..eeb5c14eaa 100644 --- a/tensorflow/core/graph/graph_constructor.cc +++ b/tensorflow/core/graph/graph_constructor.cc @@ -1042,12 +1042,12 @@ Status GraphConstructor::Convert() { } if (processed < node_defs_.size()) { - LOG(WARNING) << "IN " << __func__ << (node_defs_.size() - processed) + LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed) << " NODES IN A CYCLE"; for (int64 i = 0; i < node_defs_.size(); i++) { if (pending_count_[i] != 0) { LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i]) - << "WITH PENDING COUNT = " << pending_count_[i]; + << " WITH PENDING COUNT = " << pending_count_[i]; } } return errors::InvalidArgument(node_defs_.size() - processed, @@ -1162,7 +1162,9 @@ Status GraphConstructor::PopulateMissingUnusedInputMapKeys() { const NodeDef* node_def = node_defs_[pair->second.gdef_index]; const OpDef* op_def; TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def)); - if (key.second >= op_def->output_arg_size()) { + int num_outputs; + TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs)); + if (key.second >= num_outputs) { // key's index out of bounds missing_unused_input_map_keys_->push_back(key); 
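[Editor's sketch] On the FromProto/ParseFromString signature changes above: taking the proto by value and std::move-ing it is the standard C++ sink-argument pattern. A self-contained illustration, with a hypothetical Data struct standing in for VariantTensorDataProto:

    #include <string>
    #include <utility>
    #include <vector>

    struct Data {  // stand-in for VariantTensorDataProto
      std::string metadata;
      std::vector<std::string> tensors;
    };

    struct Holder {  // stand-in for VariantTensorData
      std::string metadata;
      std::vector<std::string> tensors;
      // By-value parameter: a caller holding an rvalue pays one move, not a deep copy.
      bool FromData(Data d) {
        metadata = std::move(d.metadata);
        tensors = std::move(d.tensors);
        return true;
      }
    };

    int main() {
      Data d{"meta", {"t0", "t1"}};
      Holder h;
      h.FromData(std::move(d));  // contents are moved, not copied
    }

Callers that must keep their proto intact take the const-reference path instead, which is what the added FromConstProto overload is for.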
} diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc index 73142ebde7..3eef6bd2bd 100644 --- a/tensorflow/core/graph/graph_constructor_test.cc +++ b/tensorflow/core/graph/graph_constructor_test.cc @@ -199,6 +199,10 @@ REGISTER_OP("TestOneInputOneOutput") .Output("y: T") .Attr("T: {float, int64}") .SetShapeFn(shape_inference::UnchangedShape); +REGISTER_OP("TestVariadicOutput") + .Output("outputs: N * int32") + .Attr("N: int >= 0") + .SetShapeFn(shape_inference::UnknownShape); REGISTER_OP("TestDefaultAttr") .Attr("default_int: int=31415") .SetShapeFn(shape_inference::NoOutputs); @@ -1463,12 +1467,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) { opts.input_map[TensorId("DNE", 0)] = TensorId("input", 0); // Unused but not missing opts.input_map[TensorId("t1", 0)] = TensorId("W1", 0); + // Unused but not missing + opts.input_map[TensorId("variadic", 4)] = TensorId("input", 0); ExpectOK( R"EOF( node { name: 'W2' op: 'TestParams' } node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] } node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] } - node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] } + node { name: 'variadic' op: 'TestVariadicOutput' + attr { key: "N" value { i: 5 } } } )EOF", opts, &refiner, &results); diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index 2e644fe987..f5b0105862 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/framework/node_def_util.h" +#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph.h" #include "tensorflow/core/graph/node_builder.h" diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc index ea7788f654..0a38aa1c91 100644 --- a/tensorflow/core/graph/testlib.cc +++ b/tensorflow/core/graph/testlib.cc @@ -485,6 +485,33 @@ Node* DiagPart(Graph* g, Node* in, DataType type) { return ret; } +Node* CheckNumerics(Graph* g, Node* in, const string& message) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics") + .Input(in) + .Attr("message", message) + .Finalize(g, &ret)); + return ret; +} + +Node* Arg(Graph* g, int64 index, DataType type) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg") + .Attr("T", type) + .Attr("index", index) + .Finalize(g, &ret)); + return ret; +} + +Node* Retval(Graph* g, int64 index, Node* in) { + Node* ret; + TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval") + .Input(in) + .Attr("index", index) + .Finalize(g, &ret)); + return ret; +} + void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); } } // end namespace graph diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h index 8585b35a19..bd0284d43a 100644 --- a/tensorflow/core/graph/testlib.h +++ b/tensorflow/core/graph/testlib.h @@ -209,6 +209,15 @@ Node* Diag(Graph* g, Node* in, DataType type); // Add a DiagPart node in "g". Node* DiagPart(Graph* g, Node* in, DataType type); +// Add a CheckNumerics node in "g". +Node* CheckNumerics(Graph* g, Node* in, const string& message); + +// Add an _Arg node in "g". +Node* Arg(Graph* g, int64 index, DataType type); + +// Add a _Retval node in "g". 
+Node* Retval(Graph* g, int64 index, Node* in); + } // end namespace graph } // end namespace test } // end namespace tensorflow diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 6710ff9df3..56c8339d57 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -260,13 +260,13 @@ typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) { } bool IsEnqueue(const NodeDef& n) { - return (n.op().find("Enqueue") != std::string::npos && - n.op().find("EnqueueMany") == std::string::npos); + return (n.op().find("Enqueue") != string::npos && + n.op().find("EnqueueMany") == string::npos); } bool IsDequeue(const NodeDef& n) { - return (n.op().find("Dequeue") != std::string::npos && - n.op().find("DequeueMany") == std::string::npos); + return (n.op().find("Dequeue") != string::npos && + n.op().find("DequeueMany") == string::npos); } bool HasAnyUnknownDimensions(const TensorShapeProto& proto) { @@ -345,6 +345,56 @@ void VerboseLogUnknownDimensionSources( } } +bool IsShapeFullyDefinedIntegerVectorOrScalar( + InferenceContext* ic, const ShapeHandle& shape, + const ShapeHandle& tensor_as_shape, const DataType& dtype) { + if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 || + !ic->FullyDefined(tensor_as_shape) || + (dtype != DT_INT32 && dtype != DT_INT64)) { + return false; + } + return true; +} + +// Returned tensor's shape is like `shape`, and its values and dtype are from +// `tensor_as_shape` and `dtype`. +TensorProto MakeTensorProtoFromShape(InferenceContext* ic, + const ShapeHandle& shape, + const ShapeHandle& tensor_as_shape, + const DataType& dtype) { + TensorProto tensor_proto; + tensor_proto.set_dtype(dtype); + auto* shape_proto = tensor_proto.mutable_tensor_shape(); + if (ic->Rank(shape) == 1) { + shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape)); + } + // For a scalar tensor, tensor_shape field will be left empty; no dim. + for (int i = 0; i < ic->Rank(tensor_as_shape); i++) { + int64 value = ic->Value(ic->Dim(tensor_as_shape, i)); + if (dtype == DT_INT32) { + tensor_proto.add_int_val(value); + } else { + tensor_proto.add_int64_val(value); + } + } + return tensor_proto; +} + +// Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`, +// and dtype = `dtype`. +NodeDef MakeConstNodeDefFromShape(InferenceContext* ic, + const ShapeHandle& shape, + const ShapeHandle& tensor_as_shape, + const DataType& dtype) { + NodeDef const_node; + const_node.set_name("const_from_shape"); + const_node.set_op("Const"); + auto* attr = const_node.mutable_attr(); + (*attr)["dtype"].set_type(dtype); + auto* tensor = (*attr)["value"].mutable_tensor(); + *tensor = MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype); + return const_node; +} } // namespace // Queue of nodes to process. Nodes can be enqueued in any order, but will be @@ -429,18 +479,22 @@ class SymbolicShapeRefiner { // perform shape inference on the function body. // // Propagate shape information of final function body node - // to function node `node`. + // to function node `function_node`. // - // In the event of an error, UpdateNode will simply set `node`'s + // In the event of an error, UpdateNode will simply set `function_node`'s // output shape to be Unknown. 
- Status UpdateFunction(const NodeDef* node) { - auto it = fun_to_grappler_function_item_.find(node->op()); + Status UpdateFunction(const NodeDef* function_node) { + auto it = fun_to_grappler_function_item_.find(function_node->op()); if (it == fun_to_grappler_function_item_.end()) { return errors::InvalidArgument( - node->op(), " was not previously added to SymbolicShapeRefiner."); + function_node->op(), + " was not previously added to SymbolicShapeRefiner."); } - GrapplerFunctionItem& grappler_function_item = it->second; + // Copy (not reference) so that changes we make here (e.g., replacing + // Placeholder with Const) don't affect the one in + // fun_to_grappler_function_item_. + GrapplerFunctionItem grappler_function_item = it->second; GraphView gv(&grappler_function_item.graph); // Forward shapes from function input nodes to argument nodes. @@ -453,7 +507,7 @@ class SymbolicShapeRefiner { "supported."); } NodeDef* fun_node = gv.GetNode(fun_input.input_name); - const string& input = node->input(i); + const string& input = function_node->input(i); const string& node_name = NodeName(input); if (IsControlInput(input)) { @@ -478,17 +532,48 @@ class SymbolicShapeRefiner { TensorShapeProto proto; const auto& handle = input_inference_context->output(output_port_num); input_inference_context->ShapeHandleToProto(handle, &proto); + // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1. + for (int i = 0; i < proto.dim_size(); i++) { + if (proto.dim(i).size() < -1) { + proto.mutable_dim(i)->set_size(-1); + } + } *attr_output_shape.mutable_shape() = proto; (*fun_node->mutable_attr())["shape"] = attr_output_shape; } + // Replace input Placeholders with Consts, if values are known. Note that + // we don't re-check the error conditions here; that's done in the loop above. + auto* ctx = GetNodeContext(function_node); + auto* ic = ctx->inference_context.get(); + for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) { + const string& input = function_node->input(i); + const string& node_name = NodeName(input); + NodeDef* input_node = graph_.GetNode(node_name); + if (IsConstant(*input_node)) { + TF_CHECK_OK( + ReplaceInputWithConst(*input_node, i, &grappler_function_item)); + } else if (ic->input_tensors_as_shapes().size() > i && + IsShapeFullyDefinedIntegerVectorOrScalar( + ic, ic->input(i), ic->input_tensors_as_shapes()[i], + ctx->input_types[i])) { + // We have fully defined input_tensors_as_shapes for this input; use it + // as a const input to the function node. + NodeDef const_input_node = MakeConstNodeDefFromShape( + ic, ic->input(i), ic->input_tensors_as_shapes()[i], + ctx->input_types[i]); + TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i, + &grappler_function_item)); + } + } + // Perform inference on function body. GraphProperties gp(grappler_function_item); TF_RETURN_IF_ERROR(gp.InferStatically(true)); // Add return nodes for output shapes.
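[Editor's sketch] The "copy, not reference" comment above is load-bearing: grappler_function_item is specialized per call site (Placeholders replaced with Consts), and that must not leak back into the cached entry. A toy illustration, with a plain struct in place of GrapplerFunctionItem:

    #include <cassert>
    #include <map>
    #include <string>
    #include <vector>

    struct Item {  // stand-in for GrapplerFunctionItem
      std::vector<std::string> nodes;
    };

    int main() {
      std::map<std::string, Item> cache;  // stand-in for fun_to_grappler_function_item_
      cache["fn"] = Item{{"Placeholder"}};

      Item copy = cache.at("fn");  // copy, as UpdateFunction now does
      copy.nodes[0] = "Const";     // specialize for this call site only
      assert(cache.at("fn").nodes[0] == "Placeholder");  // cached entry unchanged

      Item& ref = cache.at("fn");  // a reference here would have corrupted the
      (void)ref;                   // cache for every later call site
    }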
- auto ic = GetContext(node); int output = 0; + ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size()); for (auto const& out_arg : grappler_function_item.outputs()) { if (out_arg.output_tensors.size() > 1) { // TODO(jmdecker): Handle case of multiple output tensors @@ -505,8 +590,9 @@ class SymbolicShapeRefiner { const NodeDef* retnode = gv.GetNode(node_name); if (retnode == nullptr) { - return errors::FailedPrecondition("Unable to find return node ", - node_name, " for ", node->name()); + return errors::FailedPrecondition( + "Unable to find return node ", node_name, " for ", + function_node->name()); } auto output_properties = gp.GetOutputProperties(retnode->name()); @@ -520,6 +606,14 @@ class SymbolicShapeRefiner { ShapeHandle out; TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out)); ic->set_output(output, out); + if (outprop.has_value()) { + // Forward the tensor value to output_tensors_as_shapes. + Tensor tensor; + if (tensor.FromProto(outprop.value())) { + MaybeSetTensorValueToShape(ic, tensor, + &ctx->output_tensors_as_shapes[output]); + } + } output++; } @@ -562,21 +656,9 @@ class SymbolicShapeRefiner { if (const_values[dst_input].FromProto( input->attr().at("value").tensor())) { input_tensors[dst_input] = &const_values[dst_input]; - // Integer tensors of rank one can also be interpreted as a shape - // provided all their values are >= -1. - if (const_values[dst_input].dims() == 1 && - (const_values[dst_input].dtype() == DT_INT32 || - const_values[dst_input].dtype() == DT_INT64)) { - ShapeHandle tensor_shape = inference_context->Vector( - const_values[dst_input].NumElements()); - ShapeHandle shp; - if (inference_context - ->MakeShapeFromTensor(input_tensors[dst_input], - tensor_shape, &shp) - .ok()) { - input_tensors_as_shapes[dst_input] = shp; - } - } + MaybeSetTensorValueToShape(inference_context, + const_values[dst_input], + &input_tensors_as_shapes[dst_input]); } } else if (IsRank(*input)) { if (c->inference_context->RankKnown(c->inference_context->input(0))) { @@ -671,11 +753,13 @@ class SymbolicShapeRefiner { // true, as the updates to the call node will have changed, even if it's // the same function being called twice with the same input shapes. // Example: simple_function.pbtxt - if (UpdateFunction(node).ok()) { + auto s = UpdateFunction(node); + if (s.ok()) { return Status::OK(); } else { VLOG(1) << "UpdateFunction failed for " << node->op() - << ". Defaulting to ShapeUnknown."; + << ". Defaulting to ShapeUnknown.\n" + << s.ToString(); } } @@ -942,13 +1026,25 @@ class SymbolicShapeRefiner { : t->scalar<int64>()(); dims.push_back(size < 0 ? ic->UnknownDim() : ic->MakeDim(size)); } else { - dims.push_back(ic->UnknownDim()); + // We don't have the tensor value, but use input_tensors_as_shapes, if + // possible. + const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i]; + if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 && + ic->ValueKnown(ic->Dim(shape_handle, 0))) { + dims.push_back(ic->Dim(shape_handle, 0)); + } else { + dims.push_back(ic->UnknownDim()); + } } } if (valid) { c->output_tensors_as_shapes.resize(1); c->output_tensors_as_shapes[0] = ic->MakeShape(dims); } + } else if (IsIdentity(node)) { + // Pass input_tensors_as_shapes to output_tensors_as_shapes.
+ c->output_tensors_as_shapes.resize(1); + c->output_tensors_as_shapes[0] = ic->input_tensors_as_shapes()[0]; } else if (IsSlice(node)) { ShapeHandle input = ic->input_tensors_as_shapes()[0]; bool valid = ic->RankKnown(input); @@ -1053,6 +1149,46 @@ class SymbolicShapeRefiner { } private: + bool IsIntegerVector(const Tensor& tensor) { + if (tensor.dims() == 1 && + (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) { + return true; + } + return false; + } + + bool IsIntegerScalar(const Tensor& tensor) { + if (tensor.dims() == 0 && + (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) && + tensor.NumElements() == 1) { + return true; + } + return false; + } + + void MaybeSetTensorValueToShape(InferenceContext* ic, const Tensor& tensor, + ShapeHandle* tensors_as_shapes) { + // Integer tensors of rank one can also be interpreted as a shape + // provided all their values are >= -1. + if (IsIntegerVector(tensor)) { + ShapeHandle tensor_shape = ic->Vector(tensor.NumElements()); + ShapeHandle shp; + // Note that MakeShapeFromTensor filters out invalid values (e.g., < -1). + if (ic->MakeShapeFromTensor(&tensor, tensor_shape, &shp).ok()) { + *tensors_as_shapes = shp; + } + } else if (IsIntegerScalar(tensor)) { + // Scalar constant. + int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0) + : tensor.flat<int64>()(0); + // Ideally, values can be < -1, but MakeDim() fails with a value < -1. + // It's a limitation as we use ShapeHandle as a means to pass values. + if (value >= -1) { + *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)}); + } + } + } + const GraphView& graph_; int graph_def_version_; std::unordered_map<const NodeDef*, NodeContext> node_to_context_; @@ -1528,6 +1664,8 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) { continue; } + auto* ic = ctx->inference_context.get(); + // Fill input properties. { auto& input_properties = input_properties_[node.name()]; @@ -1535,19 +1673,26 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) { // Should always be empty, node names in graph are supposed to be unique. CHECK_EQ(input_properties.size(), 0); - input_properties.resize(ctx->inference_context->num_inputs()); + input_properties.resize(ic->num_inputs()); GraphView::InputPort input(&node, -1); - for (int i = 0; i < ctx->inference_context->num_inputs(); ++i) { - shape_manager.AsTensorProperties(ctx->inference_context->input(i), - ctx->input_types[i], + for (int i = 0; i < ic->num_inputs(); ++i) { + shape_manager.AsTensorProperties(ic->input(i), ctx->input_types[i], &input_properties[i]); input.port_id = i; GraphView::OutputPort fanin = graph_view.GetRegularFanin(input); - if (!IsConstant(*fanin.node)) { - continue; + // Export tensor value (either const tensor or input_tensors_as_shapes) + // to input_properties.value. 
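[Editor's sketch] MaybeSetTensorValueToShape above encodes a small int32/int64 tensor as a ShapeHandle so the value can ride the existing shape-propagation machinery. A plain-C++ sketch of the acceptance rule (vectors standing in for Tensor and ShapeHandle; this is an illustration, not the TensorFlow code):

    #include <cstdint>
    #include <optional>
    #include <vector>

    // Accepts a rank-0 or rank-1 integer tensor whose values are all >= -1,
    // mirroring IsIntegerVector/IsIntegerScalar plus the MakeDim limitation.
    std::optional<std::vector<int64_t>> ValueAsShape(
        const std::vector<int64_t>& values, int rank) {
      if (rank != 0 && rank != 1) return std::nullopt;  // vectors and scalars only
      for (int64_t v : values) {
        if (v < -1) return std::nullopt;  // not representable as a dimension
      }
      return values;  // usable as the dimensions of a shape
    }

Rank >= 2 tensors are rejected, which is why the c1 output in the TensorAsShapesPropagation test further below carries no value.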
+ if (IsConstant(*fanin.node)) { + const TensorProto& raw_val = fanin.node->attr().at("value").tensor(); + *input_properties[i].mutable_value() = raw_val; + } else if (ic->input_tensors_as_shapes().size() > i && + IsShapeFullyDefinedIntegerVectorOrScalar( + ic, ic->input(i), ic->input_tensors_as_shapes()[i], + ctx->input_types[i])) { + *input_properties[i].mutable_value() = MakeTensorProtoFromShape( + ic, ic->input(i), ic->input_tensors_as_shapes()[i], + ctx->input_types[i]); } - const TensorProto& raw_val = fanin.node->attr().at("value").tensor(); - *input_properties[i].mutable_value() = raw_val; } } @@ -1558,11 +1703,23 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) { // Should always be empty, node names in graph are supposed to be unique. CHECK_EQ(output_properties.size(), 0); - output_properties.resize(ctx->inference_context->num_outputs()); - for (int i = 0; i < ctx->inference_context->num_outputs(); ++i) { - shape_manager.AsTensorProperties(ctx->inference_context->output(i), - ctx->output_types[i], + output_properties.resize(ic->num_outputs()); + for (int i = 0; i < ic->num_outputs(); ++i) { + shape_manager.AsTensorProperties(ic->output(i), ctx->output_types[i], &output_properties[i]); + // Export tensor value (either const tensor or input_tensors_as_shapes) + // to output_properties.value. + if (IsConstant(node)) { + const TensorProto& raw_val = node.attr().at("value").tensor(); + *output_properties[i].mutable_value() = raw_val; + } else if (ctx->output_tensors_as_shapes.size() > i && + IsShapeFullyDefinedIntegerVectorOrScalar( + ic, ic->output(i), ctx->output_tensors_as_shapes[i], + ctx->output_types[i])) { + *output_properties[i].mutable_value() = MakeTensorProtoFromShape( + ic, ic->output(i), ctx->output_tensors_as_shapes[i], + ctx->output_types[i]); + } } } } diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 8938b7c32e..362092a6cf 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -44,6 +44,30 @@ class GraphPropertiesTest : public ::testing::Test { // Provision a single machine with 3 cpu cores cluster_.reset(new SingleMachine(5 * 60, 3, 0)); TF_CHECK_OK(cluster_->Provision()); + + // This function is simply + // out = Fill(shape, value), but + // Fill requires values in the shape input, not just shape of it, to infer + // output shape. + auto f = FunctionDefHelper::Create( + // Name + "MyFillFunc", + // Inputs + {"shape: int32", "value: float"}, + // Outputs + {"out: float"}, + // Attrs + {}, + // Nodes + { + {{"a"}, + "Fill", + {"shape", "value"}, + {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}}, + }, + // Returns + {{"out", "a:output:0"}}); + function_lib_.add_function()->Swap(&f); } void TearDown() override { @@ -69,7 +93,29 @@ class GraphPropertiesTest : public ::testing::Test { return s; } + // Compare values of integer (DT_INT32 or DT_INT64) tensor against expected + // ones. + void ExpectTensorValues(const std::vector<int64>& expected, + const TensorProto& tensor_proto_to_compare) { + Tensor tensor; + EXPECT_TRUE(tensor.FromProto(tensor_proto_to_compare)); + EXPECT_EQ(expected.size(), tensor.NumElements()); + // We're interested in only integer tensors as only shapes are exported as + // graph properties values. 
+ CHECK(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64); + if (tensor.dtype() == DT_INT32) { + for (int i = 0; i < tensor.NumElements(); i++) { + EXPECT_EQ(expected[i], tensor.flat<int32>()(i)); + } + } else { + for (int i = 0; i < tensor.NumElements(); i++) { + EXPECT_EQ(expected[i], tensor.flat<int64>()(i)); + } + } + } + + std::unique_ptr<SingleMachine> cluster_; + FunctionDefLibrary function_lib_; }; TEST_F(GraphPropertiesTest, StaticProperties) { @@ -785,7 +831,220 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) { EXPECT_EQ("float: [128,256]", PropToString(prop)); } -TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) { +TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2}); + Output a1 = ops::Identity(s.WithOpName("a1"), a); + Output b = ops::Const(s.WithOpName("b"), 99, {}); + Output b1 = ops::Identity(s.WithOpName("b1"), b); + Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4}); + Output c1 = ops::Identity(s.WithOpName("c1"), c); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + + // Check output shapes. + EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0])); + EXPECT_EQ("int32: [2]", + PropToString(properties.GetOutputProperties("a1")[0])); + EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0])); + EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0])); + EXPECT_EQ("int32: [4,4,4]", + PropToString(properties.GetOutputProperties("c")[0])); + EXPECT_EQ("int32: [4,4,4]", + PropToString(properties.GetOutputProperties("c1")[0])); + + // Check has_value. + EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value()); + EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value()); + EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value()); + EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value()); + EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value()); + EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value()); + EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value()); + EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value()); + // Note that we propagate tensor values only for 1D vectors and scalars. + EXPECT_FALSE(properties.GetOutputProperties("c1")[0].has_value()); + + // Check values. + ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value()); + ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value()); + ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value()); + ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value()); + ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value()); + ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value()); + std::vector<int64> c_values; + for (int i = 0; i < 4 * 4 * 4; i++) { + c_values.push_back(1); + } + ExpectTensorValues({c_values}, + properties.GetOutputProperties("c")[0].value()); + ExpectTensorValues({c_values}, + properties.GetInputProperties("c1")[0].value()); + // No output value for c1, as it's neither a 1D vector nor a scalar.
+} + +TEST_F(GraphPropertiesTest, IdentityPassingShape) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 5, {2}); + Output b = ops::Identity(s.WithOpName("b"), a); + Output c = ops::Const(s.WithOpName("const"), 0.1f, {}); + // Fill needs not only b's shape but also the value of b to figure out the + // output shape; hence, the Identity op (b) should pass a's value along as + // output_tensors_as_shapes. + Output d = ops::Fill(s.WithOpName("fill"), b, c); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto out_props = properties.GetOutputProperties("fill"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ("float: [5,5]", PropToString(out_prop0)); +} + +TEST_F(GraphPropertiesTest, PackWithConstInput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output a = ops::Const(s.WithOpName("a"), 1, {}); + Output b = ops::Const(s.WithOpName("b"), 2, {}); + Output c = ops::Const(s.WithOpName("c"), 3, {}); + Output d = ops::Const(s.WithOpName("d"), 4, {}); + // Note that ops::Stack instantiates a Pack op. + Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d}); + // e is a rank-1 tensor: shape = {4}, and its value is {1, 2, 3, 4} + Output f = ops::Const(s.WithOpName("const"), 0.1f, {}); + // Fill needs not only e's shape but also its value to figure out its output + // shape. + Output g = ops::Fill(s.WithOpName("fill"), e, f); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto out_props = properties.GetOutputProperties("fill"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0)); +} + +TEST_F(GraphPropertiesTest, PackWithIdentityInput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + // Same as the PackWithConstInput test case, but a, b, c, and d are Identity + // ops from Const. + // If output_tensors_as_shapes is not set for those Identity ops, or if the + // Pack op doesn't consume input_tensors_as_shapes, the Fill op's input has + // no value; hence, its output shape becomes unknown. + Output a0 = ops::Const(s.WithOpName("a0"), 1, {}); + Output b0 = ops::Const(s.WithOpName("b0"), 2, {}); + Output c0 = ops::Const(s.WithOpName("c0"), 3, {}); + Output d0 = ops::Const(s.WithOpName("d0"), 4, {}); + Output a = ops::Identity(s.WithOpName("a"), a0); + Output b = ops::Identity(s.WithOpName("b"), b0); + Output c = ops::Identity(s.WithOpName("c"), c0); + Output d = ops::Identity(s.WithOpName("d"), d0); + // Note that ops::Stack instantiates a Pack op. + Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d}); + // e is a rank-1 tensor: shape = {4}, and its value is {1, 2, 3, 4} + Output f = ops::Const(s.WithOpName("const"), 0.1f, {}); + // Fill needs not only e's shape but also its value to figure out its output + // shape.
+ Output g = ops::Fill(s.WithOpName("fill"), e, f); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto out_props = properties.GetOutputProperties("fill"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0)); +} + +TEST_F(GraphPropertiesTest, FunctionWithConstInput) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_)); + Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4}); + Output value = ops::Const(s.WithOpName("value"), 0.1f, {}); + auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc", + s.graph()->op_registry()); + tensorflow::Node* func_op; + auto _shape = tensorflow::ops::AsNodeOut(s, shape); + auto _value = tensorflow::ops::AsNodeOut(s, value); + TF_CHECK_OK( + builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op)); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto out_props = properties.GetOutputProperties("MyFillFunc"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0)); +} + +TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) { + // Same as FunctionWithConstInput, but the function inputs are Identity of + // Const, so tensor values carried as shapes, not const tensors, should be + // used as the Const input to the function. + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_)); + Output shape_ = ops::Const(s.WithOpName("shape_"), {1, 2, 3, 4}); + Output shape = ops::Identity(s.WithOpName("shape"), shape_); + Output value = ops::Const(s.WithOpName("value"), 0.1f, {}); + auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc", + s.graph()->op_registry()); + tensorflow::Node* func_op; + auto _shape = tensorflow::ops::AsNodeOut(s, shape); + auto _value = tensorflow::ops::AsNodeOut(s, value); + TF_CHECK_OK( + builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op)); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(false)); + const auto out_props = properties.GetOutputProperties("MyFillFunc"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0)); +} + +TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) { + FunctionDefLibrary library; + *library.add_function() = FunctionDefHelper::Create( + "MyFunc", // Name + {"x: int32"}, // Inputs + {"out: int32"}, // Outputs + {}, // Attrs + {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}}, // Nodes + {{"out", "a:output:0"}}); // Returns + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + TF_CHECK_OK(s.graph()->AddFunctionLibrary(library)); + + // MyFunc takes a Const (shape) and passes it through Identity. Expect the + // function output to have the same shape and value (via + // output_tensors_as_shapes) as the input Const tensor.
+ Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2}); + auto _shape = tensorflow::ops::AsNodeOut(s, shape); + auto builder = + tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry()); + tensorflow::Node* func_op; + TF_CHECK_OK(builder.Input(_shape).Finalize(s.graph(), &func_op)); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphProperties properties(item); + TF_CHECK_OK(properties.InferStatically(true)); + const auto out_props = properties.GetOutputProperties("MyFunc"); + const OpInfo::TensorProperties out_prop0 = out_props[0]; + EXPECT_EQ("int32: [2]", PropToString(out_prop0)); + EXPECT_TRUE(out_prop0.has_value()); + ExpectTensorValues({5, 7}, out_prop0.value()); + ExpectTensorValues({5, 7}, + properties.GetInputProperties("MyFunc")[0].value()); +} + +TEST_F(GraphPropertiesTest, FunctionWithScalarInput) { // Create graph with a function that takes a scalar value so that we use // Placeholder with scalar as for input to the function shape inference. // Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of @@ -818,7 +1077,7 @@ TEST_F(GraphPropertiesTest, FunctionWithScalarInputTest) { // MyFunc output shouldn't be unknown rank. GraphProperties properties(item); - TF_CHECK_OK(properties.InferStatically(false)); + TF_CHECK_OK(properties.InferStatically(true)); const auto out_props = properties.GetOutputProperties("MyFunc"); const OpInfo::TensorProperties out_prop0 = out_props[0]; EXPECT_EQ(DT_FLOAT, out_prop0.dtype()); @@ -856,18 +1115,10 @@ TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) { EXPECT_EQ(2, in_props.size()); const OpInfo::TensorProperties& in_prop = in_props[0]; - EXPECT_EQ(DT_FLOAT, in_prop.dtype()); - EXPECT_FALSE(in_prop.shape().unknown_rank()); - EXPECT_EQ(2, in_prop.shape().dim_size()); - EXPECT_EQ(1, in_prop.shape().dim(0).size()); - EXPECT_EQ(2, in_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop)); const OpInfo::TensorProperties& in_prop1 = in_props[1]; - EXPECT_EQ(DT_FLOAT, in_prop1.dtype()); - EXPECT_FALSE(in_prop1.shape().unknown_rank()); - EXPECT_EQ(2, in_prop1.shape().dim_size()); - EXPECT_EQ(1, in_prop1.shape().dim(0).size()); - EXPECT_EQ(2, in_prop1.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop1)); } TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) { @@ -882,51 +1133,25 @@ TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) { EXPECT_EQ(2, out_props.size()); const OpInfo::TensorProperties& out_prop0 = out_props[0]; - EXPECT_EQ(DT_FLOAT, out_prop0.dtype()); - EXPECT_EQ(4, out_prop0.shape().dim_size()); - EXPECT_EQ(128, out_prop0.shape().dim(0).size()); - EXPECT_EQ(112, out_prop0.shape().dim(1).size()); - EXPECT_EQ(112, out_prop0.shape().dim(2).size()); - EXPECT_EQ(64, out_prop0.shape().dim(3).size()); + EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0)); const OpInfo::TensorProperties& out_prop1 = out_props[1]; - EXPECT_EQ(DT_FLOAT, out_prop1.dtype()); - EXPECT_EQ(128, out_prop1.shape().dim(0).size()); - EXPECT_EQ(112, out_prop1.shape().dim(1).size()); - EXPECT_EQ(112, out_prop1.shape().dim(2).size()); - EXPECT_EQ(24, out_prop1.shape().dim(3).size()); + EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1)); const auto in_props = properties.GetInputProperties("y0"); EXPECT_EQ(4, in_props.size()); const OpInfo::TensorProperties& in_prop0 = in_props[0]; - EXPECT_EQ(DT_FLOAT, in_prop0.dtype()); - EXPECT_EQ(1, in_prop0.shape().dim_size()); - EXPECT_EQ(64, in_prop0.shape().dim(0).size()); + 
EXPECT_EQ("float: [64]", PropToString(in_prop0)); const OpInfo::TensorProperties& in_prop1 = in_props[1]; - EXPECT_EQ(DT_FLOAT, in_prop1.dtype()); - EXPECT_EQ(4, in_prop1.shape().dim_size()); - EXPECT_EQ(1, in_prop1.shape().dim(0).size()); - EXPECT_EQ(1, in_prop1.shape().dim(1).size()); - EXPECT_EQ(24, in_prop1.shape().dim(2).size()); - EXPECT_EQ(64, in_prop1.shape().dim(3).size()); + EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1)); const OpInfo::TensorProperties& in_prop2 = in_props[2]; - EXPECT_EQ(DT_FLOAT, in_prop2.dtype()); - EXPECT_EQ(4, in_prop2.shape().dim_size()); - EXPECT_EQ(128, in_prop2.shape().dim(0).size()); - EXPECT_EQ(224, in_prop2.shape().dim(1).size()); - EXPECT_EQ(224, in_prop2.shape().dim(2).size()); - EXPECT_EQ(3, in_prop2.shape().dim(3).size()); + EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2)); const OpInfo::TensorProperties& in_prop3 = in_props[3]; - EXPECT_EQ(DT_FLOAT, in_prop3.dtype()); - EXPECT_EQ(4, in_prop3.shape().dim_size()); - EXPECT_EQ(7, in_prop3.shape().dim(0).size()); - EXPECT_EQ(7, in_prop3.shape().dim(1).size()); - EXPECT_EQ(3, in_prop3.shape().dim(2).size()); - EXPECT_EQ(8, in_prop3.shape().dim(3).size()); + EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3)); } TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) { @@ -986,18 +1211,10 @@ TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) { EXPECT_EQ(2, in_props.size()); const OpInfo::TensorProperties& in_prop = in_props[0]; - EXPECT_EQ(DT_FLOAT, in_prop.dtype()); - EXPECT_FALSE(in_prop.shape().unknown_rank()); - EXPECT_EQ(2, in_prop.shape().dim_size()); - EXPECT_EQ(1, in_prop.shape().dim(0).size()); - EXPECT_EQ(2, in_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop)); const OpInfo::TensorProperties& in_prop1 = in_props[1]; - EXPECT_EQ(DT_FLOAT, in_prop1.dtype()); - EXPECT_FALSE(in_prop1.shape().unknown_rank()); - EXPECT_EQ(2, in_prop1.shape().dim_size()); - EXPECT_EQ(1, in_prop1.shape().dim(0).size()); - EXPECT_EQ(2, in_prop1.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop1)); } TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) { @@ -1022,27 +1239,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) { const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o"); const OpInfo::TensorProperties& out_prop = out_props[0]; EXPECT_EQ(DT_FLOAT, out_prop.dtype()); - EXPECT_FALSE(out_prop.shape().unknown_rank()); - EXPECT_EQ(2, out_prop.shape().dim_size()); - EXPECT_EQ(1, out_prop.shape().dim(0).size()); - EXPECT_EQ(2, out_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(out_prop)); const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o"); EXPECT_EQ(2, in_props.size()); const OpInfo::TensorProperties& in_prop = in_props[0]; - EXPECT_EQ(DT_FLOAT, in_prop.dtype()); - EXPECT_FALSE(in_prop.shape().unknown_rank()); - EXPECT_EQ(2, in_prop.shape().dim_size()); - EXPECT_EQ(1, in_prop.shape().dim(0).size()); - EXPECT_EQ(2, in_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop)); const OpInfo::TensorProperties& in_prop1 = in_props[1]; - EXPECT_EQ(DT_FLOAT, in_prop1.dtype()); - EXPECT_FALSE(in_prop1.shape().unknown_rank()); - EXPECT_EQ(2, in_prop1.shape().dim_size()); - EXPECT_EQ(1, in_prop1.shape().dim(0).size()); - EXPECT_EQ(2, in_prop1.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop1)); } TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) { @@ -1066,28 +1272,16 @@ 
TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) { TF_CHECK_OK(properties.InferStatically(false)); const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o"); const OpInfo::TensorProperties& out_prop = out_props[0]; - EXPECT_EQ(DT_FLOAT, out_prop.dtype()); - EXPECT_FALSE(out_prop.shape().unknown_rank()); - EXPECT_EQ(2, out_prop.shape().dim_size()); - EXPECT_EQ(1, out_prop.shape().dim(0).size()); - EXPECT_EQ(2, out_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(out_prop)); const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o"); EXPECT_EQ(2, in_props.size()); const OpInfo::TensorProperties& in_prop = in_props[0]; - EXPECT_EQ(DT_FLOAT, in_prop.dtype()); - EXPECT_FALSE(in_prop.shape().unknown_rank()); - EXPECT_EQ(2, in_prop.shape().dim_size()); - EXPECT_EQ(1, in_prop.shape().dim(0).size()); - EXPECT_EQ(2, in_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop)); const OpInfo::TensorProperties& in_prop1 = in_props[1]; - EXPECT_EQ(DT_FLOAT, in_prop1.dtype()); - EXPECT_FALSE(in_prop1.shape().unknown_rank()); - EXPECT_EQ(2, in_prop1.shape().dim_size()); - EXPECT_EQ(1, in_prop1.shape().dim(0).size()); - EXPECT_EQ(2, in_prop1.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop1)); } TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) { @@ -1115,28 +1309,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) { TF_CHECK_OK(properties.InferStatically(false)); const auto out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I"); const OpInfo::TensorProperties& out_prop = out_props[0]; - EXPECT_EQ(DT_FLOAT, out_prop.dtype()); - EXPECT_FALSE(out_prop.shape().unknown_rank()); - EXPECT_EQ(2, out_prop.shape().dim_size()); - EXPECT_EQ(1, out_prop.shape().dim(0).size()); - EXPECT_EQ(2, out_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(out_prop)); const auto in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I"); EXPECT_EQ(2, in_props.size()); const OpInfo::TensorProperties& in_prop = in_props[0]; - EXPECT_EQ(DT_FLOAT, in_prop.dtype()); - EXPECT_FALSE(in_prop.shape().unknown_rank()); - EXPECT_EQ(2, in_prop.shape().dim_size()); - EXPECT_EQ(1, in_prop.shape().dim(0).size()); - EXPECT_EQ(2, in_prop.shape().dim(1).size()); + EXPECT_EQ("float: [1,2]", PropToString(in_prop)); const OpInfo::TensorProperties& in_prop1 = in_props[1]; - EXPECT_EQ(DT_FLOAT, in_prop1.dtype()); - EXPECT_FALSE(in_prop1.shape().unknown_rank()); - EXPECT_EQ(2, in_prop1.shape().dim_size()); - EXPECT_EQ(1, in_prop1.shape().dim(0).size()); - EXPECT_EQ(3, in_prop1.shape().dim(1).size()); + EXPECT_EQ("float: [1,3]", PropToString(in_prop1)); } TEST_F(GraphPropertiesTest, SymbolicShapes) { diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc index 7691f25327..5415324b48 100644 --- a/tensorflow/core/grappler/costs/utils.cc +++ b/tensorflow/core/grappler/costs/utils.cc @@ -127,7 +127,7 @@ static void ExtractExtraProperties( // For filename input, the file size can also be useful. if (op_def && i < op_def->input_arg_size() && - op_def->input_arg(i).name().find("filename") != std::string::npos) { + op_def->input_arg(i).name().find("filename") != string::npos) { Tensor tensor; if (!tensor.FromProto(t)) { continue; @@ -153,7 +153,7 @@ static void ExtractExtraProperties( // When the input is a handle (e.g. look up table handle), the information // in the op itself is not sufficient to predict the op memory. 
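[Editor's sketch] For context on the "filename" branch in ExtractExtraProperties above: the extractor attaches the input file's size as an extra property, since file size is a strong cost signal for file-reading ops. A plain-C++ sketch of that step, with std::filesystem in place of TensorFlow's Env and a hypothetical helper name and key format:

    #include <filesystem>
    #include <map>
    #include <string>

    void MaybeAttachFileSize(const std::string& arg_name, const std::string& path,
                             int input_index,
                             std::map<std::string, long long>* extra_properties) {
      if (arg_name.find("filename") == std::string::npos) return;
      std::error_code ec;
      const auto size = std::filesystem::file_size(path, ec);
      if (!ec) {  // attach only if the file is actually readable
        (*extra_properties)["input_" + std::to_string(input_index) + "_filesize"] =
            static_cast<long long>(size);
      }
    }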
if (op_def && i < op_def->input_arg_size() && - op_def->input_arg(i).name().find("handle") != std::string::npos) { + op_def->input_arg(i).name().find("handle") != string::npos) { string new_key = strings::StrCat("parent_", i, "_op"); AttrValue attr; attr.set_s(input_node->op()); @@ -320,8 +320,8 @@ void TensorSizeHistogram::Merge(const TensorSizeHistogram& src) { buckets_.begin(), std::plus<uint64>()); } -std::string TensorSizeHistogram::ToString() const { - std::string r; +string TensorSizeHistogram::ToString() const { + string r; char buf[200]; snprintf(buf, sizeof(buf), "Count: %lld, Average: ", num_elem_); r.append(buf); diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h index d2c7c67666..5fd6717712 100644 --- a/tensorflow/core/grappler/costs/utils.h +++ b/tensorflow/core/grappler/costs/utils.h @@ -80,7 +80,7 @@ class TensorSizeHistogram { uint64 Max() const { return max_; } uint64 NumElem() const { return num_elem_; } uint64 SumElem() const { return sum_elem_; } - std::string ToString() const; + string ToString() const; protected: const int Index(const uint64 value) const; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc index 02a379fca8..80889afc86 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc @@ -1999,13 +1999,13 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) { // Helper lambda to extract port num from _Send and _Recv op name. auto get_port_num = [](const string& name) -> int { - if (name.find("bn_0") != std::string::npos) { + if (name.find("bn_0") != string::npos) { return 0; - } else if (name.find("bn_1") != std::string::npos) { + } else if (name.find("bn_1") != string::npos) { return 1; - } else if (name.find("bn_2") != std::string::npos) { + } else if (name.find("bn_2") != string::npos) { return 2; - } else if (name.find("bn_minus1") != std::string::npos) { + } else if (name.find("bn_minus1") != string::npos) { return -1; } return -999; diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h index 26d38a4931..97626346c7 100644 --- a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h +++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h @@ -138,7 +138,7 @@ class GraphAnalyzer { // The entries are owned by collation_map_, so must be removed from // ordered_collation_ before removing them from collation_map_. struct ReverseLessByCount { - bool operator()(CollationEntry* left, CollationEntry* right) { + bool operator()(CollationEntry* left, CollationEntry* right) const { return left->count > right->count; // Reverse order. } }; diff --git a/tensorflow/core/grappler/inputs/utils.cc b/tensorflow/core/grappler/inputs/utils.cc index 5029dff877..def9198a69 100644 --- a/tensorflow/core/grappler/inputs/utils.cc +++ b/tensorflow/core/grappler/inputs/utils.cc @@ -14,10 +14,11 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/grappler/inputs/utils.h" -#include "tensorflow/core/platform/env.h" #include <vector> +#include "tensorflow/core/platform/env.h" + namespace tensorflow { namespace grappler { @@ -29,12 +30,12 @@ bool FilesExist(const std::set<string>& files) { return FilesExist(std::vector<string>(files.begin(), files.end()), nullptr); } -bool FileExists(const std::string& file, Status* status) { +bool FileExists(const string& file, Status* status) { *status = Env::Default()->FileExists(file); return status->ok(); } -Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path, +Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path, GraphDef* result) { Status status; if (FileExists(graph_def_pbtxt_path, &status)) { diff --git a/tensorflow/core/grappler/inputs/utils.h b/tensorflow/core/grappler/inputs/utils.h index 627dd5359f..4b9cb0a9ad 100644 --- a/tensorflow/core/grappler/inputs/utils.h +++ b/tensorflow/core/grappler/inputs/utils.h @@ -29,9 +29,9 @@ bool FilesExist(const std::vector<string>& files, std::vector<Status>* status = nullptr); bool FilesExist(const std::set<string>& files); -bool FileExists(const std::string& file, Status* status); +bool FileExists(const string& file, Status* status); -Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path, +Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path, GraphDef* result); } // end namespace grappler diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 653b088b1d..3521669b63 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -135,16 +135,37 @@ bool IsDequeueOp(const NodeDef& node) { bool IsDiv(const NodeDef& node) { return node.op() == "Div"; } -bool IsElementWiseMonotonic(const NodeDef& node) { - static const std::unordered_set<string>* element_wise_monotonic_ops = +// Returns true if node represents a unary elementwise function that is +// monotonic. If *is_non_decreasing is true, the function is non-decreasing, +// e.g. sqrt, exp. If *is_non_decreasing is false, the function is +// non-increasing, e.g. inv. +bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) { + static const std::unordered_set<string>* monotonic_non_decreasing_ops = CHECK_NOTNULL((new std::unordered_set<string>{ - "Relu", - "Relu6", - "Sigmoid", - "Sqrt", - "Tanh", + "Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1", + "Floor", "Log", "Log1p", "Relu", "Relu6", "Rint", + "Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh", + })); + static const std::unordered_set<string>* monotonic_non_increasing_ops = + CHECK_NOTNULL((new std::unordered_set<string>{ + "Inv", + "Reciprocal", + "Erfc", + "Rsqrt", + "Neg", })); - return element_wise_monotonic_ops->count(node.op()) > 0; + if (monotonic_non_decreasing_ops->count(node.op()) > 0) { + if (is_non_decreasing) { + *is_non_decreasing = true; + } + return true; + } else if (monotonic_non_increasing_ops->count(node.op()) > 0) { + if (is_non_decreasing) { + *is_non_decreasing = false; + } + return true; + } + return false; } bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; } @@ -470,7 +491,7 @@ bool IsFreeOfSideEffect(const NodeDef& node) { } } // Queue ops modify the queue which is a side effect.
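[Editor's sketch] The direction flag added to IsElementWiseMonotonic matters because extrema flip under a non-increasing function: max over f(x) equals f(min over x). A standalone numeric check of that identity for f(x) = -x (the Neg case):

    #include <algorithm>
    #include <cassert>
    #include <vector>

    int main() {
      std::vector<int> xs = {-2, 1, 3};
      std::vector<int> fx;
      for (int x : xs) fx.push_back(-x);  // f(x) = -x is non-increasing
      int max_fx = *std::max_element(fx.begin(), fx.end());  // Max(Neg(x)) == 2
      int min_x = *std::min_element(xs.begin(), xs.end());   // Min(x) == -2
      assert(max_fx == -min_x);  // Max(Neg(x)) == Neg(Min(x))
    }

This is exactly the Min<->Max swap performed by OptimizeMaxOrMinOfMonotonicStage further below when is_non_decreasing comes back false.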
- if (node.op().find("Queue") != std::string::npos) { + if (node.op().find("Queue") != string::npos) { return false; } return !ModifiesInputsInPlace(node); diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 94439265c9..25ab6b65ac 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -55,7 +55,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node); bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node); bool IsDequeueOp(const NodeDef& node); bool IsDiv(const NodeDef& node); -bool IsElementWiseMonotonic(const NodeDef& node); +bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing); bool IsEluGrad(const NodeDef& node); bool IsEnter(const NodeDef& node); bool IsEqual(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 70ad9f9a9b..029205248b 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -110,12 +110,13 @@ cc_library( ], ) -tf_cuda_cc_test( +tf_cc_test( name = "constant_folding_test", srcs = ["constant_folding_test.cc"], - tags = ["requires-gpu-sm35"], + shard_count = 5, deps = [ ":constant_folding", + ":dependency_optimizer", "//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops_internal", "//tensorflow/core:all_kernels", @@ -514,6 +515,7 @@ cc_library( ":custom_graph_optimizer_registry", ":debug_stripper", ":dependency_optimizer", + ":experimental_implementation_selector", ":function_optimizer", ":graph_optimizer", ":layout_optimizer", @@ -845,3 +847,68 @@ tf_cc_test( "//third_party/eigen3", ], ) + +cc_library( + name = "function_api_info", + srcs = ["function_api_info.cc"], + hdrs = ["function_api_info.h"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "function_api_info_test", + size = "small", + srcs = ["function_api_info_test.cc"], + deps = [ + ":function_api_info", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +cc_library( + name = "experimental_implementation_selector", + srcs = ["experimental_implementation_selector.cc"], + hdrs = ["experimental_implementation_selector.h"], + deps = [ + ":custom_graph_optimizer", + ":custom_graph_optimizer_registry", + ":function_api_info", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/costs:graph_properties", + ], +) + +tf_cc_test( + name = "experimental_implementation_selector_test", + size = "small", + srcs = ["experimental_implementation_selector_test.cc"], + deps = [ + ":custom_graph_optimizer", + ":custom_graph_optimizer_registry", + ":experimental_implementation_selector", + ":function_api_info", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder", + "//tensorflow/core/grappler/utils:grappler_test", + ], +) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc 
b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 4fed88d536..992e85d2c6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1121,11 +1121,8 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage { Status TrySimplify(NodeDef* node, string* simplified_node_name) override { TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node)); NodeDef* tail = node; - // TODO(rmlarsen): Enable after debugging breakage in Bayesflow. - if (ctx().opt_level == RewriterConfig::AGGRESSIVE) { - tail = GetTailOfIdempotentChain(*tail, *ctx().node_map, - *ctx().nodes_to_preserve); - } + tail = GetTailOfIdempotentChain(*tail, *ctx().node_map, + *ctx().nodes_to_preserve); NodeDef* first_transpose; TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose)); @@ -1328,38 +1325,26 @@ class RemoveNegationStage : public ArithmeticOptimizerStage { } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - const string node_name = node->name(); NodeDef* x; NodeDef* y; TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x)); TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y)); bool updated = false; - if (IsAdd(*node)) { - if (IsNeg(*x)) { - // (-a) + b = b - a - node->set_op("Sub"); - node->mutable_input()->SwapElements(0, 1); - node->set_input(1, x->input(0)); - node->add_input(AsControlDependency(x->name())); - ctx().node_map->AddOutput(NodeName(x->input(0)), node_name); - updated = true; - } else if (IsNeg(*y)) { - // a + (-b) = a - b - node->set_op("Sub"); - node->set_input(1, y->input(0)); - node->add_input(AsControlDependency(y->name())); - ctx().node_map->AddOutput(NodeName(y->input(0)), node_name); - updated = true; - } - } else if (IsSub(*node)) { - if (IsNeg(*y)) { - // a - (-b) = a + b - node->set_op("Add"); - node->set_input(1, y->input(0)); - node->add_input(AsControlDependency(y->name())); - ctx().node_map->AddOutput(NodeName(y->input(0)), node_name); - updated = true; - } + if (IsNeg(*y)) { + // a - (-b) = a + b or a + (-b) = a - b + ForwardControlDependencies(node, {y}); + ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0)); + node->set_op(IsAdd(*node) ? "Sub" : "Add"); + node->set_input(1, y->input(0)); + updated = true; + } else if (IsAdd(*node) && IsNeg(*x)) { + // (-a) + b = b - a + ForwardControlDependencies(node, {x}); + ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0)); + node->set_op("Sub"); + node->mutable_input()->SwapElements(0, 1); + node->set_input(1, x->input(0)); + updated = true; } if (updated) { AddToOptimizationQueue(node); @@ -2706,8 +2691,9 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { // 0. inner_function is not in the preserve set, // 1. inner_function's Op is element-wise monotonic // 2. inner_function's output is not being consumed elsewhere. + bool is_non_decreasing = false; if (!IsInPreserveSet(*inner_function) && - IsElementWiseMonotonic(*inner_function) && + IsElementWiseMonotonic(*inner_function, &is_non_decreasing) && ctx().node_map->GetOutputs(inner_function->name()).size() == 1) { // Swap the first inputs of the inner function Op & the reduction Op. 
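The swap is sound because a non-decreasing f commutes with the reduction: max_i f(x_i) = f(max_i x_i), and likewise for Min. A standalone numeric check of that identity (plain C++, unrelated to the optimizer code itself):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

int main() {
  const std::vector<float> x = {1.0f, 4.0f, 9.0f};
  // Max(Sqrt(x)) == Sqrt(Max(x)) because sqrt is non-decreasing.
  float max_of_sqrt = 0.0f;
  for (float v : x) max_of_sqrt = std::max(max_of_sqrt, std::sqrt(v));
  const float sqrt_of_max = std::sqrt(*std::max_element(x.begin(), x.end()));
  assert(max_of_sqrt == sqrt_of_max);  // Both are exactly 3.0f.
  return 0;
}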
NodeDef* inner_input; @@ -2719,7 +2705,12 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage { UpdateConsumers(reduction_node, inner_function->name()); ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(), reduction_node->name()); - + if (!is_non_decreasing) { + // Flip Min<->Max if the function is non-increasing, e.g. + // Max(Neg(x)) = Neg(Min(x)). + const string opposite = IsMax(*reduction_node) ? "Min" : "Max"; + reduction_node->set_op(opposite); + } AddToOptimizationQueue(reduction_node); AddToOptimizationQueue(inner_function); AddToOptimizationQueue(inner_input); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 551c3652bf..d457eb6d21 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -61,7 +61,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool fold_multiply_into_conv = true; bool fold_transpose_into_matmul = true; bool hoist_common_factor_out_of_aggregation = true; - bool hoist_cwise_unary_chains = false; + bool hoist_cwise_unary_chains = true; bool minimize_broadcasts = true; bool optimize_max_or_min_of_monotonic = true; bool remove_idempotent = true; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index bfccc0affd..88839d944c 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -581,7 +581,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) { const NodeDef* new_const = node_map.GetNode(optimized_const_name); ASSERT_NE(new_const, nullptr); EXPECT_EQ("^x", new_const->input(0)); - EXPECT_EQ(std::string("\0\0\0@", 4), + EXPECT_EQ(string("\0\0\0@", 4), new_const->attr().at("value").tensor().tensor_content()); const NodeDef* new_mul = node_map.GetNode(optimized_mul_name); @@ -625,7 +625,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) { const NodeDef* new_const = node_map.GetNode(optimized_const_name); ASSERT_NE(new_const, nullptr); EXPECT_EQ("^x", new_const->input(0)); - EXPECT_EQ(std::string("\0\0\0@", 4), + EXPECT_EQ(string("\0\0\0@", 4), new_const->attr().at("value").tensor().tensor_content()); const NodeDef* new_mul = node_map.GetNode(optimized_mul_name); @@ -2353,9 +2353,14 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) { Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y); Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y); Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y); - auto add_all = ops::AddN(s.WithOpName("add_all"), - {add_x_y, add_negx_y, add_x_negy, add_negx_negy, - sub_x_y, sub_negx_y, sub_x_negy, sub_negx_negy}); + Output neg_x_with_dep = ops::Neg( + s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x); + Output add_negx_with_dep_y = + ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y); + auto add_all = + ops::AddN(s.WithOpName("add_all"), + {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y, + sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y}); GrapplerItem item; item.fetch = {"add_all"}; @@ -2370,7 +2375,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) { GraphDef output; ArithmeticOptimizer optimizer; EnableOnlyRemoveNegation(&optimizer); - OptimizeAndPrune(&optimizer, &item, &output); + 
OptimizeTwice(&optimizer, &item, &output); EXPECT_EQ(item.graph.node_size(), output.node_size()); int found = 0; @@ -2379,42 +2384,43 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) { if (node.name() == "Add_negx_y") { ++found; EXPECT_EQ("Sub", node.op()); - EXPECT_EQ(3, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("y", node.input(0)); EXPECT_EQ("x", node.input(1)); - EXPECT_EQ("^Neg_x", node.input(2)); } else if (node.name() == "Add_x_negy") { ++found; EXPECT_EQ("Sub", node.op()); - EXPECT_EQ(3, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("y", node.input(1)); - EXPECT_EQ("^Neg_y", node.input(2)); } else if (node.name() == "Add_negx_negy") { ++found; EXPECT_EQ("Sub", node.op()); - EXPECT_EQ(3, node.input_size()); - EXPECT_EQ("Neg_y", node.input(0)); - EXPECT_EQ("x", node.input(1)); - EXPECT_EQ("^Neg_x", node.input(2)); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("Neg_x", node.input(0)); + EXPECT_EQ("y", node.input(1)); } else if (node.name() == "Sub_x_negy") { ++found; EXPECT_EQ("Add", node.op()); - EXPECT_EQ(3, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("y", node.input(1)); - EXPECT_EQ("^Neg_y", node.input(2)); } else if (node.name() == "Sub_negx_negy") { ++found; EXPECT_EQ("Sub", node.op()); - EXPECT_EQ(4, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("y", node.input(0)); EXPECT_EQ("x", node.input(1)); - EXPECT_EQ("^Neg_y", node.input(2)); - EXPECT_EQ("^Neg_x", node.input(3)); + } else if (node.name() == "Add_negx_with_dep_y") { + ++found; + EXPECT_EQ("Sub", node.op()); + EXPECT_EQ(3, node.input_size()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("x", node.input(1)); + EXPECT_EQ("^Add_x_y", node.input(2)); } } - EXPECT_EQ(5, found); + EXPECT_EQ(6, found); auto tensors = EvaluateNodes(output, item.fetch, feed); EXPECT_EQ(1, tensors.size()); @@ -3248,6 +3254,48 @@ TEST_F(ArithmeticOptimizerTest, VerifyGraphsMatch(item.graph, output, __LINE__); } +TEST_F(ArithmeticOptimizerTest, + OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2}); + Output neg = ops::Neg(s.WithOpName("neg"), x); + Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0}); + Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max); + + GrapplerItem item; + item.fetch = {"final_out"}; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + auto tensors_expected = EvaluateNodes(item.graph, item.fetch); + EXPECT_EQ(1, tensors_expected.size()); + + GraphDef output; + ArithmeticOptimizer optimizer; + EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer); + OptimizeAndPrune(&optimizer, &item, &output); + auto tensors = EvaluateNodes(output, item.fetch); + EXPECT_EQ(1, tensors.size()); + + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6); + EXPECT_EQ(item.graph.node_size(), output.node_size()); + // Check if the inputs are switched + int required_node_count = 0; + for (int i = 0; i < output.node_size(); ++i) { + const NodeDef& node = output.node(i); + if (node.name() == "neg") { + EXPECT_EQ("Neg", node.op()); + EXPECT_EQ(1, node.input_size()); + EXPECT_EQ("reduce_max", node.input(0)); + ++required_node_count; + } else if (node.name() == "reduce_max") { + EXPECT_EQ("Min", node.op()); + EXPECT_EQ(2, node.input_size()); + EXPECT_EQ("x", node.input(0)); + ++required_node_count; + } + } + EXPECT_EQ(2, required_node_count); +} + 
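For the non-increasing case the reduction must also be flipped, which is what the new test pins down; the underlying identity Max(Neg(x)) == Neg(Min(x)) can be checked in isolation (a sketch, not part of the test file):

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  const std::vector<float> x = {1.0f, 2.0f};
  // Negation is non-increasing, so Max and Min trade places.
  const float max_of_neg = std::max(-x[0], -x[1]);                  // -1.0f
  const float neg_of_min = -*std::min_element(x.begin(), x.end());  // -1.0f
  assert(max_of_neg == neg_of_min);
  return 0;
}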
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 815bd23307..99737a71eb 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -136,6 +136,27 @@ bool MaybeRemoveControlInput(const string& old_input, NodeDef* node, return removed_input; } +bool GetConcatAxis(const GraphProperties& properties, NodeDef* node, + int* axis) { + if (node->op() != "ConcatV2" || + properties.GetInputProperties(node->name()).empty()) { + return false; + } + const auto& axis_input = properties.GetInputProperties(node->name()).back(); + if (!TensorShape::IsValid(axis_input.shape()) || !axis_input.has_value()) { + return false; + } + + Tensor axis_tensor(axis_input.dtype(), axis_input.shape()); + if (!axis_tensor.FromProto(axis_input.value())) { + return false; + } + *axis = axis_input.dtype() == DT_INT64 + ? static_cast<int>(axis_tensor.scalar<int64>()()) + : axis_tensor.scalar<int32>()(); + return true; +} + } // namespace ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level, @@ -852,19 +873,7 @@ DataType GetDataTypeFromNodeOrProps(const NodeDef& node, } return dtype; } -bool IsValidConstShapeForNCHW(const TensorShapeProto& shape) { - if (shape.dim_size() != 4) { - return false; - } - int num_dim_larger_than_one = 0; - for (const auto& dim : shape.dim()) { - if (dim.size() > 1) ++num_dim_larger_than_one; - } - return num_dim_larger_than_one <= 1; -} -const string& GetShape(const NodeDef& node) { - return node.attr().at("data_format").s(); -} + } // namespace // static @@ -1711,7 +1720,7 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, return Status::OK(); } - if (MulConvPushDown(*properties, optimized_graph, node)) { + if (MulConvPushDown(node, *properties)) { graph_modified_ = true; return Status::OK(); } @@ -1731,6 +1740,11 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, return Status::OK(); } + if (MergeConcat(*properties, use_shape_info, optimized_graph, node)) { + graph_modified_ = true; + return Status::OK(); + } + return Status::OK(); } @@ -2553,9 +2567,8 @@ bool ConstantFolding::ConstantPushDown(NodeDef* node) { return false; } -bool ConstantFolding::MulConvPushDown(const GraphProperties& properties, - GraphDef* optimized_graph, - NodeDef* node) { +bool ConstantFolding::MulConvPushDown(NodeDef* node, + const GraphProperties& properties) { // Push down multiplication on ConvND. 
// * ConvND // / \ / \ @@ -2631,14 +2644,12 @@ bool ConstantFolding::MulConvPushDown(const GraphProperties& properties, } const auto& const_shape = const_props[0].shape(); - if (GetShape(*conv_node) == "NHWC") { - TensorShapeProto new_filter_shape; - if (!ShapeAfterBroadcast(filter_shape, const_shape, &new_filter_shape)) { - return false; - } - if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) { - return false; - } + TensorShapeProto new_filter_shape; + if (!ShapeAfterBroadcast(filter_shape, const_shape, &new_filter_shape)) { + return false; + } + if (!ShapesSymbolicallyEqual(filter_shape, new_filter_shape)) { + return false; } string mul_new_name = @@ -2672,69 +2683,6 @@ bool ConstantFolding::MulConvPushDown(const GraphProperties& properties, } node_map_->AddNode(mul_new_name, node); - if (GetShape(*conv_node) == "NCHW") { - if (const_node->attr().at("value").tensor().tensor_shape().dim_size() <= - 1) { - // Broadcast should work for scalar or 1D. No need to reshape. - return true; - } - if (!IsValidConstShapeForNCHW( - const_node->attr().at("value").tensor().tensor_shape())) { - return false; - } - // Adds Const node for Reshape. - auto* shape_const_node = optimized_graph->add_node(); - const string shape_const_node_name = - OptimizedNodeName(*const_node, "_new_shape"); - shape_const_node->set_name(shape_const_node_name); - shape_const_node->set_op("Const"); - shape_const_node->set_device(const_node->device()); - (*shape_const_node->mutable_attr())["dtype"].set_type(DT_INT32); - Tensor t(DT_INT32, {4}); - t.flat<int32>()(0) = 1; - t.flat<int32>()(1) = 1; - t.flat<int32>()(2) = 1; - t.flat<int32>()(3) = const_node->attr() - .at("value") - .tensor() - .tensor_shape() - .dim(1) // IsValidConstShapeForNCHW guarantees - // dim 1 is the dim to reshape - .size(); - t.AsProtoTensorContent( - (*shape_const_node->mutable_attr())["value"].mutable_tensor()); - node_map_->AddNode(shape_const_node_name, shape_const_node); - - // Adds Reshape node. - auto* reshape_node = optimized_graph->add_node(); - const string reshape_node_name = - OptimizedNodeName(*const_node, "_reshape"); - reshape_node->set_op("Reshape"); - reshape_node->set_name(reshape_node_name); - reshape_node->set_device(const_node->device()); - (*reshape_node->mutable_attr())["T"].set_type( - const_node->attr().at("dtype").type()); - (*reshape_node->mutable_attr())["Tshape"].set_type(DT_INT32); - node_map_->AddNode(reshape_node_name, reshape_node); - - // const_node -> reshape_node - node_map_->RemoveOutput(const_node->name(), node->name()); - *reshape_node->add_input() = const_node->name(); - node_map_->AddOutput(const_node->name(), reshape_node_name); - - // shape_const_node -> reshape_node - *reshape_node->add_input() = shape_const_node_name; - node_map_->AddOutput(shape_const_node_name, reshape_node_name); - - // reshape_node -> node (Mul) - node_map_->AddOutput(reshape_node_name, node->name()); - if (left_child_is_constant) { - node->set_input(0, reshape_node_name); - } else { - node->set_input(1, reshape_node_name); - } - } - return true; } return false; @@ -2988,6 +2936,55 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph, return false; } +bool ConstantFolding::MergeConcat(const GraphProperties& properties, + bool use_shape_info, + GraphDef* optimized_graph, NodeDef* node) { + // We only optimize for ConcatV2. 
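In miniature, the rewrite implemented below collapses ConcatV2(ConcatV2(a, b, axis), c, axis) into ConcatV2(a, b, c, axis) when both nodes use the same axis and the inner concat has no other consumers. On flat vectors (an illustrative sketch, not optimizer code):

#include <vector>

// Stand-in for ConcatV2 along a single axis.
std::vector<int> Concat(const std::vector<std::vector<int>>& parts) {
  std::vector<int> out;
  for (const auto& p : parts) out.insert(out.end(), p.begin(), p.end());
  return out;
}

int main() {
  const std::vector<int> a = {1}, b = {2}, c = {3};
  const auto nested = Concat({Concat({a, b}), c});  // two concat nodes
  const auto merged = Concat({a, b, c});            // one concat node
  // nested == merged == {1, 2, 3}; merging removes one node from the graph.
  return nested == merged ? 0 : 1;
}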
+ int axis; + if (!use_shape_info || !GetConcatAxis(properties, node, &axis) || + nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() || + node_map_->GetOutputs(node->name()).size() != 1) { + return false; + } + + NodeDef* parent = *node_map_->GetOutputs(node->name()).begin(); + int parent_axis; + if (!GetConcatAxis(properties, parent, &parent_axis) || axis != parent_axis) { + return false; + } + + const int index = NumNonControlInputs(*node) - 1; + auto inputs = parent->input(); + parent->clear_input(); + for (int i = 0; i < inputs.size(); ++i) { + if (IsSameInput(inputs.Get(i), node->name())) { + for (int j = 0; j < node->input_size(); ++j) { + if (j < index) { + // Input tensors (non axis), add to input list of parent. + parent->add_input(node->input(j)); + node_map_->RemoveOutput(node->input(j), node->name()); + node_map_->AddOutput(node->input(j), parent->name()); + } + // Skip j == index, which means axis tensor. + if (j > index) { + // Control Dependencies, push back to inputs so they can be forwarded + // to parent. + *inputs.Add() = node->input(j); + } + } + } else { + parent->add_input(inputs.Get(i)); + } + } + node->clear_input(); + node->set_op("NoOp"); + node->clear_attr(); + node_map_->RemoveNode(node->name()); + (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1); + + return true; +} + Status ConstantFolding::RunOptimizationPass(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 051dfb681e..8593b3e0b8 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -125,8 +125,7 @@ class ConstantFolding : public GraphOptimizer { // Aggregate constants present around a conv operator. Returns true if the // transformation was applied successfully. - bool MulConvPushDown(const GraphProperties& properties, - GraphDef* optimized_graph, NodeDef* node); + bool MulConvPushDown(NodeDef* node, const GraphProperties& properties); // Strength reduces floating point division by a constant Div(x, const) to // multiplication by the reciprocal Mul(x, Reciprocal(const)). @@ -210,6 +209,10 @@ class ConstantFolding : public GraphOptimizer { // Removes Split or SplitV node if possible. 
bool RemoveSplitOrSplitV(const GraphProperties& properties, GraphDef* optimized_graph, NodeDef* node); + + bool MergeConcat(const GraphProperties& properties, bool use_shape_info, + GraphDef* optimized_graph, NodeDef* node); + // Points to an externally provided device or to owned_device_; RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index 0683572dcc..2a19b3f95a 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -240,7 +240,7 @@ TEST_F(ConstantFoldingTest, AddTree) { } } -TEST_F(ConstantFoldingTest, ConvPushDownTestNHWC) { +TEST_F(ConstantFoldingTest, ConvPushDownTest) { // Tests if the following rewrite is performed: // // * Conv2D @@ -2030,6 +2030,130 @@ TEST_F(ConstantFoldingTest, TileWithMultipliesBeingOne) { CompareGraphs(want, got); } +TEST_F(ConstantFoldingTest, MergeConcat) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT); + Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT); + Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT); + Output axis = ops::Const(scope.WithOpName("axis"), 0, {}); + + ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis); + ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis); + + GrapplerItem item; + item.fetch = {"c2"}; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding optimizer(nullptr /* cpu_device */); + GraphDef got; + Status status = optimizer.Optimize(nullptr, item, &got); + TF_EXPECT_OK(status); + + GraphDef want; + AddNode("in1", "VariableV2", {}, {}, &want); + AddNode("in2", "VariableV2", {}, {}, &want); + AddNode("in3", "VariableV2", {}, {}, &want); + AddNode("axis", "Const", {}, {}, &want); + AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want); + + CompareGraphs(want, got); +} + +TEST_F(ConstantFoldingTest, MergeConcat_SameInput) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT); + Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT); + Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT); + Output axis = ops::Const(scope.WithOpName("axis"), 0, {}); + + ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis); + ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3, Output(c1)}, axis); + + GrapplerItem item; + item.fetch = {"c2"}; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding optimizer(nullptr /* cpu_device */); + GraphDef got; + Status status = optimizer.Optimize(nullptr, item, &got); + TF_EXPECT_OK(status); + + GraphDef want; + AddNode("in1", "VariableV2", {}, {}, &want); + AddNode("in2", "VariableV2", {}, {}, &want); + AddNode("in3", "VariableV2", {}, {}, &want); + AddNode("axis", "Const", {}, {}, &want); + AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "in1", "in2", "axis"}, {}, + &want); + + CompareGraphs(want, got); +} + +TEST_F(ConstantFoldingTest, MergeConcat_ConcatWithConst) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 6}, DT_FLOAT); + Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT); + Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT); + Output axis = 
ops::Const(scope.WithOpName("axis"), 0, {}); + + ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis); + ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis); + + GrapplerItem item; + item.fetch = {"c2"}; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding optimizer(nullptr /* cpu_device */); + GraphDef got; + Status status = optimizer.Optimize(nullptr, item, &got); + TF_EXPECT_OK(status); + + GraphDef want; + AddNode("in1", "VariableV2", {}, {}, &want); + AddNode("in2", "VariableV2", {}, {}, &want); + AddNode("in3", "VariableV2", {}, {}, &want); + AddNode("axis", "Const", {}, {}, &want); + AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want); + + CompareGraphs(want, got); +} + +TEST_F(ConstantFoldingTest, MergeConcat_AxisMismatch) { + tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); + + Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 5}, DT_FLOAT); + Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT); + Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT); + Output axis1 = ops::Const(scope.WithOpName("axis1"), 0, {}); + Output axis2 = ops::Const(scope.WithOpName("axis2"), 1, {}); + + ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis2); + ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis1); + + GrapplerItem item; + item.fetch = {"c2"}; + TF_CHECK_OK(scope.ToGraphDef(&item.graph)); + + ConstantFolding optimizer(nullptr /* cpu_device */); + GraphDef got; + Status status = optimizer.Optimize(nullptr, item, &got); + TF_EXPECT_OK(status); + + GraphDef want; + AddNode("in1", "VariableV2", {}, {}, &want); + AddNode("in2", "VariableV2", {}, {}, &want); + AddNode("in3", "VariableV2", {}, {}, &want); + AddNode("axis1", "Const", {}, {}, &want); + AddNode("axis2", "Const", {}, {}, &want); + AddNode("c1", "ConcatV2", {"in1", "in2", "axis2"}, {}, &want); + AddNode("c2", "ConcatV2", {"c1", "in3", "axis1"}, {}, &want); + + CompareGraphs(want, got); +} + TEST_F(ConstantFoldingTest, PaddingWithZeroSize) { tensorflow::Scope scope = tensorflow::Scope::NewRootScope(); @@ -3080,110 +3204,6 @@ TEST_F(ConstantFoldingTest, FoldingPreservesDenormalFlushing) { test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]); } -#if GOOGLE_CUDA -TEST_F(ConstantFoldingTest, ConvPushDownTestNCHW) { - // Tests if the following rewrite is performed: - // - // * Conv2D - // / \ / \ - // c Conv2D --> x (c * filter) - // / \ - // x filter - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - - int input_channel = 1; - int output_channel = 2; - int filter_size = 1; - - TensorShape filter_shape( - {filter_size, filter_size, input_channel, output_channel}); - - // Filter shape: [1, 1, 1, 2] - // Filter for output channel 0 = {2.f} - // Filter for output channel 1 = {-2.f} - // clang-format off - Output filter = - ops::Const(s.WithOpName("filter"), { - { - {{2.f, -2.f}} - } - }); - // clang-format on - - int batch_size = 1; - int matrix_size = 3; - // input shape: [1,1,3,3] - TensorShape input_shape( - {batch_size, input_channel, matrix_size, matrix_size}); - Output input = ops::Placeholder(s.WithOpName("x"), DT_FLOAT, - ops::Placeholder::Shape(input_shape)); - - Output conv = ops::Conv2D(s.WithOpName("conv"), input, filter, {1, 1, 1, 1}, - "VALID", ops::Conv2D::DataFormat("NCHW")); - Output c = ops::Const(s.WithOpName("c"), 2.0f, /* shape */ {1, 2, 1, 1}); - Output mul = ops::Mul(s.WithOpName("mul"), c, conv); - - GrapplerItem item; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - - ConstantFolding fold(nullptr); 
- GraphDef output; - Status status = fold.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - - // Here only op/IO are checked. The values are verified by EvaluateNodes - // below. - int found = 0; - for (const auto& node : output.node()) { - if (node.name() == "mul") { - ++found; - EXPECT_EQ("Conv2D", node.op()); - EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("x", node.input(0)); - EXPECT_EQ("conv/merged_input", node.input(1)); - } else if (node.name() == "conv/merged_input") { - ++found; - EXPECT_EQ("Const", node.op()); - EXPECT_EQ(0, node.input_size()); - } - } - EXPECT_EQ(2, found); - - // Check that const folded multiplication node has the expected value. - std::vector<string> fetch = {"mul"}; - // Input shape (NCHW) is [1,1,3,3], filter is [1,1,1,2] output shape should be - // (NCHW) [1,2,3,3] - ::tensorflow::Input::Initializer x{ - { - { - {1.f, 2.f, 3.f}, // H = 0 - {4.f, 5.f, 6.f}, // H = 1 - {7.f, 8.f, 9.f} // H = 2 - } // C = 0 - } // N = 0 - }; - - // |1,2,3| - // conv( |4,5,6|, // input - // |7,8,9| - // [[[2,-2]]]) // filter - // * [1,2,1,1] // mul by const - // = - // [ - // |4, 8, 12| - // |16,20,24| ==> output channel 0 - // |28,32,36| - // - // | -4, -8,-12| - // |-16,-20,-24| ==> output channel 1 - // |-28,-32,-36| - // ] - auto actual = EvaluateNodes(output, fetch, {{"x", x.tensor}}); - auto expected = EvaluateNodes(item.graph, fetch, {{"x", x.tensor}}); - test::ExpectTensorEqual<float>(expected[0], actual[0]); -} -#endif - } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index 530c957068..e84df10778 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -19,7 +19,6 @@ cc_library( "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/kernels:cast_op", "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", @@ -56,8 +55,8 @@ cc_library( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", - "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:functional_ops", + "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", "//tensorflow/core:lib_internal", ] + tf_protos_all(), @@ -107,7 +106,6 @@ tf_cc_test( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "//tensorflow/core/kernels:cast_op", ], ) @@ -164,7 +162,6 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/grappler:grappler_item", - "//tensorflow/core/kernels:cast_op", # Must be linked for the testlib functions to work. 
], ) @@ -256,7 +253,6 @@ cc_library( "//tensorflow/core/grappler:op_types", "//tensorflow/core/grappler:utils", "//tensorflow/core/grappler/clusters:cluster", - "//tensorflow/core/kernels:cast_op", "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", @@ -275,6 +271,43 @@ tf_cc_test( "//tensorflow/core:test_main", "//tensorflow/core:testlib", "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/kernels:control_flow_ops", + ], +) + +cc_library( + name = "map_parallelization", + srcs = ["map_parallelization.cc"], + hdrs = [ + "map_parallelization.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":graph_utils", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/utils:topological_sort", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + ] + tf_protos_all(), +) + +tf_cc_test( + name = "map_parallelization_test", + srcs = ["map_parallelization_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":graph_utils", + ":map_parallelization", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/grappler:grappler_item", ], ) @@ -355,6 +388,7 @@ cc_library( ":map_and_batch_fusion", ":map_and_filter_fusion", ":map_fusion", + ":map_parallelization", ":map_vectorization", ":noop_elimination", ":shuffle_and_repeat_fusion", diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index 5a7fe19265..d4ab444036 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -273,7 +273,7 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, string name = string(prefix); int id = graph->node_size(); while (ContainsGraphNodeWithName(name, *graph)) { - if (name.rfind("_generated") != std::string::npos && + if (name.rfind("_generated") != string::npos && (name.rfind("_generated") == (name.size() - strlen("_generated")))) { name.insert(name.rfind("_generated"), strings::StrCat("/_", id)); } else { diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc new file mode 100644 index 0000000000..305325e434 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc @@ -0,0 +1,106 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h" + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/grappler/utils.h" + +namespace tensorflow { +namespace grappler { +namespace { + +bool CanParallelize(const FunctionDef& function, + const FunctionLibraryDefinition& library) { + if (!function.signature().is_stateful()) return true; + + for (const auto& node : function.node_def()) { + const OpDef* op_def; + TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def)); + // Assert is marked as stateful, but it does not have any state (except + // for performing I/O). Similarly to CUDA, we do not guarantee that the + // first assert to fail is the first assert issued, which is what allows + // us to parallelize it. + if (op_def->is_stateful() && op_def->name() != "Assert") return false; + } + + return true; +} + +NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) { + NodeDef parallel_map = map_node; + graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(), + &parallel_map); + parallel_map.set_op("ParallelMapDataset"); + // TODO(b/114475558): We want to set `num_parallel_calls` to a special value, + // so that dynamic tuning will pick the optimal value at runtime. Because + // this feature is not yet implemented, we set it to 2, which is the smallest + // value that introduces parallelism. + auto* num_parallel_calls = graph_utils::AddScalarConstNode(2, graph); + parallel_map.add_input(num_parallel_calls->name()); + + return parallel_map; +} + +} // namespace + +Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + *output = item.graph; + MutableGraphView graph(output); + std::set<string> nodes_to_delete; + FunctionLibraryDefinition function_library(OpRegistry::Global(), + item.graph.library()); + auto get_map_node = [](const NodeDef& node) -> const NodeDef* { + if (node.op() == "MapDataset") return &node; + return nullptr; + }; + + for (const NodeDef& node : item.graph.node()) { + const NodeDef* map_node = get_map_node(node); + if (!map_node) continue; + + auto* function = + function_library.Find(map_node->attr().at("f").func().name()); + if (!CanParallelize(*function, function_library)) continue; + + auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph)); + graph.ReplaceInput(*map_node, *parallel_map); + + // TODO(prazek): we could also remove map functions from the library if + // they are not used anymore.
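Since the pass is registered under a fixed name (see REGISTER_GRAPH_OPTIMIZER_AS below), it can be requested through the standard custom-optimizer plumbing; a minimal sketch, assuming the usual RewriterConfig proto:

#include "tensorflow/core/protobuf/rewriter_config.pb.h"

void EnableMapParallelization(tensorflow::RewriterConfig* config) {
  // Ask Grappler to run the optimizer registered as "map_parallelization".
  config->add_custom_optimizers()->set_name("map_parallelization");
}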
+ nodes_to_delete.insert(map_node->name()); + } + + graph.DeleteNodes(nodes_to_delete); + return Status::OK(); +} + +void MapParallelization::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, + double result) { + // no-op +} + +REGISTER_GRAPH_OPTIMIZER_AS(MapParallelization, "map_parallelization"); + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.h b/tensorflow/core/grappler/optimizers/data/map_parallelization.h new file mode 100644 index 0000000000..ac9cf7e12a --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.h @@ -0,0 +1,47 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// This optimization parallelizes MapDataset when the function is stateless. +class MapParallelization : public CustomGraphOptimizer { + public: + MapParallelization() = default; + ~MapParallelization() override = default; + + string name() const override { return "map_parallelization"; } + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_ diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc new file mode 100644 index 0000000000..b2a5d9b6af --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc @@ -0,0 +1,94 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h" + +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name) { + return test::function::NDef( + name, "MapDataset", {string(input_node_name)}, + {{"f", FunctionDefHelper::FunctionRef(string(function_name))}, + {"Targuments", {}}, + {"output_shapes", {}}, + {"output_types", {}}}); +} + +const char stateless_fun_name[] = "XTimesTwo"; +const char stateful_fun_name[] = "RandomUniform"; + +TEST(MapParallelizationTest, ParallelizeSimpleMap) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + MakeMapNode("map1", "range", stateless_fun_name)}, + // FunctionLib + { + test::function::XTimesTwo(), + }); + + MapParallelization optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output)); +} + +TEST(MapParallelization, ParallelizeAssert) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, {}), + MakeMapNode("map1", "range", stateful_fun_name), + MakeMapNode("map2", "map1", stateless_fun_name), + NDef("cache", "CacheDataset", {"map2", "filename"}, {})}, + // FunctionLib + { + test::function::XTimesTwo(), + test::function::RandomUniform(), + }); + + MapParallelization optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output)); + EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map1", output)); + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output)); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc new file mode 100644 index 0000000000..2c36c9b7b3 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc @@ -0,0 +1,111 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h" + +#include <string> + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/function_api_info.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace grappler { + +Status ExperimentalImplementationSelector::LoadFunctions( + const GraphDef& graph) { + lib_info_.reset(new FunctionLibraryApiInfo); + TF_RETURN_IF_ERROR(lib_info_->Init(graph.library())); + return Status::OK(); +} + +Status ExperimentalImplementationSelector::MaybeOptimizeFunctionCall( + NodeDef* node_def) const { + // There are two ways of calling functions: + // 1. By specifying an op name as a function name, or + // 2. Via the @defun functional interface, where the real function name + // appears as an attribute of type func. + std::vector<string> function_attribute_names; + for (const auto& attr : node_def->attr()) { + if (attr.second.has_func() && + lib_info_->GetApiInfo(attr.second.func().name()) != nullptr) { + function_attribute_names.emplace_back(attr.first); + } + } + + if (function_attribute_names.empty() && + lib_info_->GetApiInfo(node_def->op()) == nullptr) { + // A regular op, or a function which has no interface.
+ return Status::OK(); + } + + string task, device; + if (!DeviceNameUtils::SplitDeviceName(node_def->device(), &task, &device)) { + return errors::Internal("Could not split device name: ", node_def->device()); + } + VLOG(2) << "Op " << node_def->name() << " runs on " << node_def->device() + << " = (" << task << ", " << device << ")"; + DeviceNameUtils::ParsedName parsed_name; + DeviceNameUtils::ParseLocalName(device, &parsed_name); + + for (const auto& attr_name : function_attribute_names) { + string function_name = node_def->attr().at(attr_name).func().name(); + string best_function_name; + lib_info_->GetBestImplementation(function_name, parsed_name.type, + &best_function_name); + if (function_name != best_function_name) { + node_def->mutable_attr() + ->find(attr_name) + ->second.mutable_func() + ->set_name(best_function_name); + } + } + if (lib_info_->GetApiInfo(node_def->op()) != nullptr) { + string best_function_name; + lib_info_->GetBestImplementation(node_def->op(), parsed_name.type, + &best_function_name); + if (node_def->op() != best_function_name) { + node_def->set_op(best_function_name); + } + } + return Status::OK(); +} + +Status ExperimentalImplementationSelector::SelectImplementation( + GraphDef* graph) const { + for (int k = 0; k < graph->node_size(); ++k) + TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph->mutable_node(k))); + + return Status::OK(); +} + +Status ExperimentalImplementationSelector::Optimize(Cluster* cluster, + const GrapplerItem& item, + GraphDef* optimized_graph) { + *optimized_graph = item.graph; + TF_RETURN_IF_ERROR(LoadFunctions(*optimized_graph)); + return SelectImplementation(optimized_graph); +} + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h new file mode 100644 index 0000000000..82f7473a14 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h @@ -0,0 +1,115 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_ + +#include <string> + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/function_api_info.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace grappler { + +// -- EXPERIMENTAL -- +// This transformation replaces function calls by the appropriate function +// definition based on properties of the runtime system. For instance, +// we may choose one implementation over another if we have a GPU with +// enough memory available. +// +// It is a way for the programmer to specify alternative implementations +// of the same functionality in the graph, and let TensorFlow pick the +// most appropriate one at runtime. +// +// For instance, the python code might specify: +// @Defun(tf.float32, +// experimental_api_implements='plus_one', +// experimental_api_preferred_device='GPU') +// def plus_one_gpu(x): return x + 1.0 +// +// @Defun(tf.float32, +// experimental_api_implements='plus_one') +// def plus_one_reference_implementation(x): return x + 1.0 +// input = tf.constant(2.0, dtype=tf.float32) +// +// z = plus_one_reference_implementation(input) +// z = plus_one_gpu(input) +// print(sess.run(z)) +// +// At runtime, we will trim either `plus_one_gpu` or +// `plus_one_reference_implementation` based on the availability of the GPU. +// +// Available annotations: +// - experimental_api_implements(string): all functions mapping to the same +// string can be interchanged. For now, all functions must have the same +// signature and overloads are not allowed. Defuns within defuns are +// allowed. +// - experimental_api_preferred_device(string): sets which device is preferred. +class ExperimentalImplementationSelector : public CustomGraphOptimizer { + public: + ExperimentalImplementationSelector() = default; + ~ExperimentalImplementationSelector() override = default; + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + string name() const override { + return "experimental_implementation_selector"; + } + + // This call is not thread-safe. + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override; + + // Does not take any feedback. + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} + + private: + Status LoadFunctions(const GraphDef& graph); + Status MaybeOptimizeFunctionCall(NodeDef* node_def) const; + + // Finds all call sites for functions, then replace with the appropriate + // implementation. + // There are two ways of calling functions: + // 1. By specifying an op name as a function name, and + // 2. Via the functional interface, where the function name appears as an + // Attr. + // + // There may be multiple call sites for a given function. 
The function body + // may call into another function, so a function might have to be duplicated. + // For simplicity, we do not change function bodies. Also, we do not change + // gradients. + Status SelectImplementation(GraphDef* graph) const; + + std::unique_ptr<FunctionLibraryApiInfo> lib_info_; + + TF_DISALLOW_COPY_AND_ASSIGN(ExperimentalImplementationSelector); +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_ diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc new file mode 100644 index 0000000000..3f1ebefac6 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc @@ -0,0 +1,138 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h" + +#include <algorithm> +#include <memory> +#include <string> +#include <vector> + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/utils/grappler_test.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +constexpr char CpuDevice[] = "/device:CPU:0"; +constexpr char GpuDevice[] = "/device:GPU:0"; + +class ExperimentalImplementationSelectorTest : public GrapplerTest {}; + +TEST_F(ExperimentalImplementationSelectorTest, NoUpdate) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {CpuDevice}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + std::unique_ptr<CustomGraphOptimizer> optimizer( + new ExperimentalImplementationSelector); + ASSERT_NE(nullptr, optimizer); + TF_ASSERT_OK(optimizer->Init()); + + GraphDef output; + const Status status = optimizer->Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + // This is a trivial graph so there is nothing to update. 
+ EXPECT_EQ(item.graph.node_size(), output.node_size()); +} + +TEST_F(ExperimentalImplementationSelectorTest, SwapImplementation) { + using test::function::NDef; + auto cpu_def = test::function::XTimesTwo(); + auto* func_attr = cpu_def.mutable_attr(); + (*func_attr)["experimental_api_implements"].set_s("times_two"); + (*func_attr)["experimental_api_preferred_device"].set_s("CPU"); + + auto gpu_def = test::function::XAddX(); + auto* func2_attr = gpu_def.mutable_attr(); + (*func2_attr)["experimental_api_implements"].set_s("times_two"); + (*func2_attr)["experimental_api_preferred_device"].set_s("GPU"); + + ExperimentalImplementationSelector optimizer; + GraphDef output; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, GpuDevice), + NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice), + NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice), + NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice), + NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)}, + // FunctionLib + {cpu_def, gpu_def}); + + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_EQ(output.node_size(), 5); + for (const NodeDef& node : output.node()) { + if (node.name() == "y1") { + // Make sure the implementation has been swapped to use the GPU version. + EXPECT_EQ("XAddX", node.op()); + } else if (node.name() == "y2") { + // Make sure the implementation is not changed. + EXPECT_EQ("XTimesTwo", node.op()); + } + } +} + +TEST_F(ExperimentalImplementationSelectorTest, SwapImplementationEval) { + using test::function::NDef; + auto cpu_def = test::function::XTimesTwo(); + auto* func_attr = cpu_def.mutable_attr(); + (*func_attr)["experimental_api_implements"].set_s("random_boost"); + (*func_attr)["experimental_api_preferred_device"].set_s("CPU"); + + auto gpu_def = test::function::XTimesFour(); + auto* func2_attr = gpu_def.mutable_attr(); + (*func2_attr)["experimental_api_implements"].set_s("random_boost"); + (*func2_attr)["experimental_api_preferred_device"].set_s("GPU"); + + ExperimentalImplementationSelector optimizer; + GraphDef output; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, CpuDevice), + NDef("y", "XTimesFour", {"x"}, {{"T", DT_FLOAT}}, CpuDevice), + NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, CpuDevice)}, + // FunctionLib + {cpu_def, gpu_def}); + + const Tensor input = test::AsScalar<float>(1.0f); + item.fetch = {"z"}; + item.feed.emplace_back("x", input); + + const auto four_times_boosted_tensor = EvaluateFetchNodes(item); + test::ExpectTensorEqual<float>(four_times_boosted_tensor[0], + test::AsScalar<float>(4.0f)); + + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + GrapplerItem optimized(item, std::move(output)); + const auto twice_boosted_tensor = EvaluateFetchNodes(optimized); + test::ExpectTensorEqual<float>(twice_boosted_tensor[0], + test::AsScalar<float>(2.0f)); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/function_api_info.cc b/tensorflow/core/grappler/optimizers/function_api_info.cc new file mode 100644 index 0000000000..798e0f6fd5 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/function_api_info.cc @@ -0,0 +1,167 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/function_api_info.h" + +#include <string> +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { +namespace grappler { +FunctionApiInfo::FunctionApiInfo() {} +FunctionApiInfo::~FunctionApiInfo() {} + +Status FunctionApiInfo::Init(const FunctionDef& function_def) { + for (const auto& attr : function_def.attr()) { + if (attr.first == "experimental_api_preferred_device") { + preferred_device_ = attr.second.s(); + } + if (attr.first == "experimental_api_implements") { + interface_name_ = attr.second.s(); + } + } + if (interface_name_.empty() && !preferred_device_.empty()) { + return errors::InvalidArgument( + "Function '", function_def.signature().name(), + "' has a preferred device, but does not implement an interface"); + } + return Status::OK(); +} + +const string& FunctionApiInfo::preferred_device() const { + return preferred_device_; +} + +const string& FunctionApiInfo::interface_name() const { + return interface_name_; +} + +FunctionLibraryApiInfo::FunctionLibraryApiInfo() {} +FunctionLibraryApiInfo::~FunctionLibraryApiInfo() {} + +namespace { +bool IsSameSignature(const FunctionDef& f1, const FunctionDef& f2) { + if (f1.ret().size() != f2.ret().size()) return false; + const auto& sig1 = f1.signature(); + const auto& sig2 = f2.signature(); + // Functions have positional semantics, so we don't check for names. 
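+  // For example, (in1: float, in2: int32) matches (x: float, y: int32),
+  // while (in1: float) does not match (in1: int32).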
+  if (sig1.input_arg_size() != sig2.input_arg_size()) return false;
+  for (int k = 0; k < sig1.input_arg_size(); ++k) {
+    const OpDef::ArgDef& arg1 = sig1.input_arg(k);
+    const OpDef::ArgDef& arg2 = sig2.input_arg(k);
+    if (arg1.type() != arg2.type()) return false;
+    if (arg1.type_attr() != arg2.type_attr()) return false;
+    if (arg1.number_attr() != arg2.number_attr()) return false;
+    if (arg1.type_list_attr() != arg2.type_list_attr()) return false;
+    if (arg1.is_ref() != arg2.is_ref()) return false;
+  }
+  return true;
+}
+
+Status ValidateSignature(const string& interface_name,
+                         const std::vector<const FunctionDef*>& equiv_funcs) {
+  if (equiv_funcs.size() < 2) return Status::OK();
+  for (size_t k = 1; k < equiv_funcs.size(); ++k) {
+    if (!IsSameSignature(*equiv_funcs[0], *equiv_funcs[k]))
+      return errors::InvalidArgument(
+          "Functions '", equiv_funcs[0]->signature().name(), "' and '",
+          equiv_funcs[k]->signature().name(), "' both implement '",
+          interface_name, "' but their signatures do not match.");
+  }
+  return Status::OK();
+}
+
+Status ValidateSignatures(
+    const std::unordered_map<string, std::vector<const FunctionDef*>>&
+        intf_to_func) {
+  for (const auto& item : intf_to_func)
+    TF_RETURN_IF_ERROR(ValidateSignature(item.first, item.second));
+  return Status::OK();
+}
+}  // namespace
+
+Status FunctionLibraryApiInfo::Init(
+    const FunctionDefLibrary& function_library) {
+  std::unordered_map<string, std::vector<const FunctionDef*>> intf_to_func;
+  for (const auto& function : function_library.function()) {
+    std::unique_ptr<FunctionApiInfo> func_info(new FunctionApiInfo);
+    TF_RETURN_IF_ERROR(func_info->Init(function));
+    // Ignore the function if it does not implement any interface.
+    if (func_info->interface_name().empty()) continue;
+
+    const string& function_name = function.signature().name();
+    const string& interface_name = func_info->interface_name();
+    func_to_intf_[function_name] = interface_name;
+    intf_to_funcs_[interface_name].emplace_back(function_name);
+    intf_to_func[interface_name].emplace_back(&function);
+    func_info_[function_name] = std::move(func_info);
+  }
+  TF_RETURN_IF_ERROR(ValidateSignatures(intf_to_func));
+  return Status::OK();
+}
+
+void FunctionLibraryApiInfo::GetEquivalentImplementations(
+    const string& function_name, std::vector<string>* other_names) const {
+  const auto intf_it = func_to_intf_.find(function_name);
+  // The function does not implement any interface.
+  if (intf_it == func_to_intf_.end()) return;
+  CHECK(!intf_it->second.empty())
+      << "Function " << function_name
+      << " should implement at least one interface.";
+  const auto it = intf_to_funcs_.find(intf_it->second);
+  CHECK(it != intf_to_funcs_.end())
+      << "Function " << function_name << " maps to " << intf_it->second
+      << " but no reverse mapping was found";
+  CHECK_GE(it->second.size(), 1) << "Class " << it->first << " is empty";
+  other_names->reserve(it->second.size() - 1);
+  for (const auto& other_name : it->second) {
+    if (other_name == function_name) continue;
+    other_names->emplace_back(other_name);
+  }
+}
+
+void FunctionLibraryApiInfo::GetBestImplementation(
+    const string& function_name, const string& device,
+    string* best_func_name) const {
+  CHECK(best_func_name != nullptr);
+  const auto func_it = func_to_intf_.find(function_name);
+  if (func_it == func_to_intf_.end()) return;
+
+  const auto it = intf_to_funcs_.find(func_it->second);
+  // No function found for the given interface.
+  if (it == intf_to_funcs_.end()) return;
+  for (const auto& func_name : it->second) {
+    const auto func_api_info = func_info_.find(func_name)->second.get();
+    if (func_api_info->preferred_device() == device) {
+      best_func_name->assign(func_name);
+      return;
+    }
+  }
+  // Didn't find a function with a matching device name; choose the first one
+  // among all the available functions.
+  best_func_name->assign(it->second.front());
+}
+
+const FunctionApiInfo* FunctionLibraryApiInfo::GetApiInfo(
+    const string& function_name) const {
+  const auto it = func_info_.find(function_name);
+  if (it == func_info_.end()) return nullptr;
+  return it->second.get();
+}
+
+}  // end namespace grappler
+}  // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_api_info.h b/tensorflow/core/grappler/optimizers/function_api_info.h
new file mode 100644
index 0000000000..412687c58c
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info.h
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+class FunctionApiInfo {
+ public:
+  FunctionApiInfo();
+  virtual ~FunctionApiInfo();
+
+  Status Init(const FunctionDef& function_def);
+
+  const string& interface_name() const;
+  const string& preferred_device() const;
+
+ private:
+  string interface_name_;
+  string preferred_device_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(FunctionApiInfo);
+};
+
+// A collection of information about a function and the interface it
+// implements. An interface is a well-defined math operation, e.g.
+// I1 = 2 * x + y. Multiple functions could implement the same interface with
+// different behavior based on different hardware conditions and limits, e.g.
+// F1 = math_ops.add(math_ops.add(x, x), y), or
+// F2 = math_ops.add(math_ops.matmul(x, 2), y).
+class FunctionLibraryApiInfo {
+ public:
+  FunctionLibraryApiInfo();
+  virtual ~FunctionLibraryApiInfo();
+  // Populate the internal field for the functions within the function_library.
+  Status Init(const FunctionDefLibrary& function_library);
+
+  void GetEquivalentImplementations(const string& function_name,
+                                    std::vector<string>* other_names) const;
+
+  void GetBestImplementation(const string& function_name, const string& device,
+                             string* best_func_name) const;
+
+  const FunctionApiInfo* GetApiInfo(const string& function_name) const;
+
+ private:
+  // Maps from function name to function details.
+  std::unordered_map<string, std::unique_ptr<FunctionApiInfo>> func_info_;
+  // Maps from function name to interface name.
+  std::unordered_map<string, string> func_to_intf_;
+  // Maps from interface name to function names.
+  std::unordered_map<string, std::vector<string>> intf_to_funcs_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryApiInfo);
+};
+
+}  // end namespace grappler
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
diff --git a/tensorflow/core/grappler/optimizers/function_api_info_test.cc b/tensorflow/core/grappler/optimizers/function_api_info_test.cc
new file mode 100644
index 0000000000..582890d3e3
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+void SetArg(const string& name, const string& type_name,
+            OpDef::ArgDef* arg_def) {
+  arg_def->set_name(name);
+  arg_def->set_type_attr(type_name);
+}
+
+typedef std::pair<string, string> ArgSpec;  // name, type.
+
+void SetArgs(const std::vector<ArgSpec>& args_spec, OpDef* sig) {
+  for (const auto& arg_spec : args_spec)
+    SetArg(arg_spec.first, arg_spec.second, sig->add_input_arg());
+  SetArg("output", "float32", sig->add_output_arg());
+}
+
+void PopulateFunction(const string& name, const string& api_interface_name,
+                      const string& preferred_device,
+                      const std::vector<ArgSpec>& input_args,
+                      FunctionDef* func_def) {
+  OpDef* sig = func_def->mutable_signature();
+  sig->set_name(name);
+
+  SetArgs(input_args, sig);
+
+  if (!api_interface_name.empty() || !preferred_device.empty()) {
+    auto* func_attr = func_def->mutable_attr();
+    if (!api_interface_name.empty())
+      (*func_attr)["experimental_api_implements"].set_s(api_interface_name);
+    if (!preferred_device.empty())
+      (*func_attr)["experimental_api_preferred_device"].set_s(preferred_device);
+  }
+}
+
+void PopulateSampleLibrary(const bool mismatch_args,
+                           FunctionDefLibrary* func_lib) {
+  const std::vector<ArgSpec> func_args{{"in1", "float32"}, {"in2", "int32"}};
+  const std::vector<ArgSpec> func_wrong_args{{"in1", "int32"},
+                                             {"in2", "int32"}};
+  PopulateFunction("DoStuffCpu", "DoStuff", "CPU", func_args,
+                   func_lib->add_function());
+  PopulateFunction("DoStuffGpu", "DoStuff", "GPU",
+                   mismatch_args ? func_wrong_args : func_args,
+                   func_lib->add_function());
+  PopulateFunction("DoThings", "DoThings", "", func_args,
+                   func_lib->add_function());
+  PopulateFunction("OneOff", "", "", func_args, func_lib->add_function());
+  PopulateFunction("AnotherOneOff", "", "", func_args,
+                   func_lib->add_function());
+}
+
+bool CheckEquivImpl(const FunctionLibraryApiInfo& lib_api_info,
+                    const string& func_name,
+                    const std::vector<string>& expected_other) {
+  std::vector<string> other_impl;
+  lib_api_info.GetEquivalentImplementations(func_name, &other_impl);
+  const std::unordered_set<string> actual(other_impl.begin(), other_impl.end());
+  const std::unordered_set<string> expected(expected_other.begin(),
+                                            expected_other.end());
+  return actual == expected;
+}
+
+bool CheckGetBestImpl(const FunctionLibraryApiInfo& lib_api_info,
+                      const string& function_name, const string& device,
+                      const string& expected_function_name) {
+  string best_function_name;
+  lib_api_info.GetBestImplementation(function_name, device,
+                                     &best_function_name);
+
+  return best_function_name == expected_function_name;
+}
+
+string GetInterfaceName(const FunctionLibraryApiInfo& lib_api_info,
+                        const string& func_name) {
+  auto* info = lib_api_info.GetApiInfo(func_name);
+  CHECK_NOTNULL(info);
+  return info->interface_name();
+}
+
+string GetPreferredDevice(const FunctionLibraryApiInfo& lib_api_info,
+                          const string& func_name) {
+  auto* info = lib_api_info.GetApiInfo(func_name);
+  CHECK_NOTNULL(info);
+  return info->preferred_device();
+}
+
+TEST(FunctionApiInfoTest, ParseTags) {
+  FunctionDefLibrary func_lib;
+  PopulateSampleLibrary(/* mismatch_args */ false, &func_lib);
+  FunctionLibraryApiInfo lib_api_info;
+  TF_ASSERT_OK(lib_api_info.Init(func_lib));
+  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffCpu", {"DoStuffGpu"}));
+  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffGpu", {"DoStuffCpu"}));
+  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "Undefined", {}));
+  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "OneOff", {}));
+  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "AnotherOneOff", {}));
+  EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoThings", {}));
+
+  EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffCpu"));
+  EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffGpu"));
+  EXPECT_EQ("DoThings", GetInterfaceName(lib_api_info, "DoThings"));
+
+  EXPECT_EQ("CPU", GetPreferredDevice(lib_api_info, "DoStuffCpu"));
+  EXPECT_EQ("GPU", GetPreferredDevice(lib_api_info, "DoStuffGpu"));
+  EXPECT_EQ("", GetPreferredDevice(lib_api_info, "DoThings"));
+
+  EXPECT_TRUE(
+      CheckGetBestImpl(lib_api_info, "DoStuffCpu", "CPU", "DoStuffCpu"));
+  EXPECT_TRUE(
+      CheckGetBestImpl(lib_api_info, "DoStuffCpu", "GPU", "DoStuffGpu"));
+  EXPECT_TRUE(
+      CheckGetBestImpl(lib_api_info, "DoStuffGpu", "CPU", "DoStuffCpu"));
+  EXPECT_TRUE(
+      CheckGetBestImpl(lib_api_info, "DoStuffGpu", "GPU", "DoStuffGpu"));
+
+  EXPECT_TRUE(CheckGetBestImpl(lib_api_info, "DoThings", "GPU", "DoThings"));
+  // The TPU implementation is not available, so choose the first available
+  // one, which is the CPU implementation.
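+  // (GetBestImplementation falls back to the first function registered for
+  // the interface when no preferred_device matches.)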
+ EXPECT_TRUE( + CheckGetBestImpl(lib_api_info, "DoStuffGpu", "TPU", "DoStuffCpu")); +} + +TEST(FunctionApiInfoTest, MismatchedArguments) { + FunctionDefLibrary func_lib; + PopulateSampleLibrary(/* mismatch_args */ true, &func_lib); + FunctionLibraryApiInfo lib_api_info; + const Status ret = lib_api_info.Init(func_lib); + EXPECT_FALSE(ret.ok()); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc index 91794cefe5..c775a26914 100644 --- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc @@ -1071,11 +1071,13 @@ static bool IdentifySwappingCandidates( // ensure that swapping the tensor back in won't recreate the memory // bottleneck. Last but not least, we want the tensor to have as few // remaining uses as possible. + // + // Note that we must perform the arithmetic inexactly as "double", since + // the values do not fit into any integral type. mem_info.fitness = - MathUtil::IPow((earliest_use - peak_time).count(), 2); - mem_info.fitness /= MathUtil::IPow(mem_info.uses_left.size(), 2); - mem_info.fitness += - MathUtil::IPow((allocation_time - peak_time).count(), 2); + MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) / + MathUtil::IPow<double>(mem_info.uses_left.size(), 2) + + MathUtil::IPow<double>((allocation_time - peak_time).count(), 2); mem_info.fitness = -mem_info.fitness; mem_state.push_back(mem_info); } diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index 5fd34efeb1..1ed1b22931 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" #include "tensorflow/core/grappler/optimizers/debug_stripper.h" #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h" +#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h" #include "tensorflow/core/grappler/optimizers/function_optimizer.h" #include "tensorflow/core/grappler/optimizers/layout_optimizer.h" #include "tensorflow/core/grappler/optimizers/loop_optimizer.h" @@ -72,6 +73,16 @@ bool IsRunOnceOptimizer(const string& name) { name == "loop_optimizer"; } +// Check if the graphdef contains nodes that indicate TPU execution. 
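+// Such nodes are introduced by the pre-placement TPU graph rewriting passes,
+// and their presence is used below to skip function optimization.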
+bool IsTPUGraphDef(const GraphDef& def) {
+  for (auto node : def.node()) {
+    if (node.op() == "TPUCompile" || node.op() == "TPUPartitionedCall") {
+      return true;
+    }
+  }
+  return false;
+}
+
 }  // namespace
 
 #define MK_OPT(NAME, VALUE) \
@@ -156,7 +167,7 @@ Status MetaOptimizer::InitializeOptimizers(
     optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
         cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
   }
-  return Status::OK();
+  return InitializeCustomGraphOptimizers(optimizers);
 }
 
 Status MetaOptimizer::InitializeOptimizersByName(
@@ -180,9 +191,24 @@ Status MetaOptimizer::InitializeOptimizersByName(
       VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
     }
   }
+  return InitializeCustomGraphOptimizers(optimizers);
+}
+
+Status MetaOptimizer::InitializeCustomGraphOptimizers(
+    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
   for (const auto& optimizer_config : cfg_.custom_optimizers()) {
-    auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
-        optimizer_config.name());
+    // Initialize the ExperimentalImplementationSelector here instead of
+    // through the CustomGraphOptimizer registry, due to a static linking issue
+    // in TensorRT that causes double registration.
+    // TODO(laigd): Remove this hack and change it back to use the registry
+    // once the duplicate static import issue is fixed.
+    std::unique_ptr<CustomGraphOptimizer> custom_optimizer;
+    if (optimizer_config.name() == "ExperimentalImplementationSelector") {
+      custom_optimizer.reset(new ExperimentalImplementationSelector());
+    } else {
+      custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
+          optimizer_config.name());
+    }
     if (custom_optimizer) {
       VLOG(2) << "Registered custom configurable graph optimizer: "
               << optimizer_config.name();
@@ -208,7 +234,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
   }
 
   std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
-  if (cfg_.optimizers().empty() && cfg_.custom_optimizers().empty()) {
+  if (cfg_.optimizers().empty()) {
    TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
   } else {
     TF_RETURN_IF_ERROR(InitializeOptimizersByName(&optimizers));
@@ -326,10 +352,25 @@ Status MetaOptimizer::RunOptimizer(
 
 Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                GraphDef* optimized_graph) {
+  LOG(INFO) << "Starting optimization for grappler item: " << item.id;
   optimization_results_.clear();
 
   // 1. Optimize main graph
   TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
+  VLOG(1) << "Optimized main graph.";
+
+  // Skip optimizing functions if this is a TPU graph. Currently, Grappler
+  // passes do not handle TPU functions correctly in a variety of ways (note
+  // that due to the pre-placement TPU graph rewriting passes, the TPU-related
+  // ops are encapsulated away into functions). For example, TPU graphs contain
+  // a TPUReplicateMetadata node that carries relevant TPU metadata, and
+  // Grappler passes could prune it away. Grappler passes could also cause
+  // issues around shape inference. Since the desired and existing behavior is
+  // to not optimize TPU functions with Grappler, this check preserves that.
+  if (IsTPUGraphDef(*optimized_graph)) {
+    VLOG(2) << "Skipping optimizing funcs for TPU graphs";
+    return Status::OK();
+  }
 
   // 2.
Optimize function library FunctionLibraryDefinition flib(OpRegistry::Global(), @@ -393,7 +434,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } } - VLOG(3) << "Optimized " << optimized_funcs.size() + VLOG(1) << "Optimized " << optimized_funcs.size() << " functions: " << str_util::Join(optimized_funcs, ", "); return Status::OK(); diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h index 151a54cbdf..831c5e37c0 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.h +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h @@ -52,6 +52,9 @@ class MetaOptimizer : public GraphOptimizer { // Initialize active optimizers from RewriterConfig optimizer names. Status InitializeOptimizersByName( std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; + // Initialize active optimizers from RewriterConfig.custom_optimizers. + Status InitializeCustomGraphOptimizers( + std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; // Run optimization pass over a single GrapplerItem. Meta optimizer might run // multiple such passes: 1) for the main graph 2) for the function library diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index 9a03c7dfef..e74e0f7501 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -64,6 +64,13 @@ bool TestOptimizer::optimized_; REGISTER_GRAPH_OPTIMIZER(TestOptimizer); +class TestGraphOptimizer : public TestOptimizer { + public: + string name() const override { return "test_graph_optimizer"; } +}; + +REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer); + class MetaOptimizerTest : public GrapplerTest {}; TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { @@ -83,6 +90,27 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { EXPECT_TRUE(TestOptimizer::IsOptimized()); } +TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + TestOptimizer::SetOptimized(false); + TestGraphOptimizer::SetOptimized(false); + RewriterConfig rewriter_config; + rewriter_config.add_optimizers("TestOptimizer"); + auto customGraphOptimizer = rewriter_config.add_custom_optimizers(); + customGraphOptimizer->set_name("TestGraphOptimizer"); + rewriter_config.set_min_graph_nodes(-1); + + MetaOptimizer optimizer(nullptr, rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_TRUE(TestOptimizer::IsOptimized()); + EXPECT_TRUE(TestGraphOptimizer::IsOptimized()); +} + TEST_F(MetaOptimizerTest, RunOptimizersTwice) { TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); GrapplerItem item; @@ -98,6 +126,24 @@ TEST_F(MetaOptimizerTest, RunOptimizersTwice) { TF_EXPECT_OK(status); } +TEST_F(MetaOptimizerTest, RunToggleOptimizersAndCustomGraphOptimizerTwice) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + RewriterConfig rewriter_config; + auto customGraphOptimizer = rewriter_config.add_custom_optimizers(); + customGraphOptimizer->set_name("TestGraphOptimizer"); + rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO); + rewriter_config.set_min_graph_nodes(-1); + + MetaOptimizer optimizer(nullptr, 
rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_TRUE(TestGraphOptimizer::IsOptimized()); +} + TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) { using test::function::NDef; diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index a2c363ea6e..a428aea7f5 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -304,21 +304,21 @@ Status GrapplerFunctionItemInstantiation::GetArgType( } GrapplerFunctionItem::GrapplerFunctionItem( - const string& func_name, const string& description, - const AttrValueMap& func_attr, - const std::vector<InputArgExpansion>& input_arg_expansions, - const std::vector<OutputArgExpansion>& output_arg_expansions, - const std::vector<string>& keep_nodes, const int graph_def_version, - bool is_stateful, GraphDef&& function_body) - : description_(description), - func_attr_(func_attr), - input_arg_expansions_(input_arg_expansions), - output_arg_expansions_(output_arg_expansions), + string func_name, string description, AttrValueMap func_attr, + std::vector<InputArgExpansion> input_arg_expansions, + std::vector<OutputArgExpansion> output_arg_expansions, + std::vector<string> keep_nodes, const int graph_def_version, + const bool is_stateful, GraphDef&& function_body) + : description_(std::move(description)), + func_attr_(std::move(func_attr)), + input_arg_expansions_(std::move(input_arg_expansions)), + output_arg_expansions_(std::move(output_arg_expansions)), is_stateful_(is_stateful) { - id = func_name; - keep_ops = keep_nodes; - // Swap the graph body. - graph.Swap(&function_body); + // Move assign GrapplerItem members. + keep_ops = std::move(keep_nodes); + id = std::move(func_name); + graph = std::move(function_body); + graph.mutable_versions()->set_producer(graph_def_version); // Fill the feed nodes with input placeholders. 
for (const InputArgExpansion& input_arg : input_arg_expansions_) { @@ -598,8 +598,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, *item = GrapplerFunctionItem( /*func_name=*/signature.name(), /*description=*/signature.description(), /*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()), - inputs, outputs, keep_nodes, graph_def_version, is_stateful, - std::move(function_body)); + std::move(inputs), std::move(outputs), std::move(keep_nodes), + graph_def_version, is_stateful, std::move(function_body)); return Status::OK(); } diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index 61588ceb83..733caf325f 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -136,13 +136,12 @@ class GrapplerFunctionItemInstantiation { class GrapplerFunctionItem : public GrapplerItem { public: GrapplerFunctionItem() = default; - GrapplerFunctionItem( - const string& func_name, const string& description, - const AttrValueMap& func_attr, - const std::vector<InputArgExpansion>& input_arg_expansions, - const std::vector<OutputArgExpansion>& output_arg_expansions, - const std::vector<string>& keep_nodes, const int versions, - bool is_stateful, GraphDef&& function_body); + GrapplerFunctionItem(string func_name, string description, + AttrValueMap func_attr, + std::vector<InputArgExpansion> input_arg_expansions, + std::vector<OutputArgExpansion> output_arg_expansions, + std::vector<string> keep_nodes, int graph_def_version, + bool is_stateful, GraphDef&& function_body); const string& description() const; diff --git a/tensorflow/core/grappler/utils/scc.h b/tensorflow/core/grappler/utils/scc.h index 4fb7aab647..ceb9f5dbf2 100644 --- a/tensorflow/core/grappler/utils/scc.h +++ b/tensorflow/core/grappler/utils/scc.h @@ -24,15 +24,16 @@ limitations under the License. namespace tensorflow { namespace grappler { -// Compute modified strongly connected components: +// Computes modified strongly connected components: // All nodes that are not part of a loop are assigned the special -1 id // All nodes that are part of at least one loop are assigned a positive // component id: if 2 nodes v and w are reachable from one another (i.e. if they // belong to the same scc), they'll be assigned the same id, otherwise they'll -// be assigned distinct ids. Returns the number of distinct ids. +// be assigned distinct ids. *num_components is set to the number of distinct +// ids. 
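+// For example, in a graph with edges a->b, b->a and c->a, nodes a and b
+// belong to the same loop and share one positive id, while c is assigned -1.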
void StronglyConnectedComponents( const GraphDef& graph, std::unordered_map<const NodeDef*, int>* components, - int* num_ids); + int* num_components); // Returns the number of individual loops present in the graph, and populate the // 'loops' argument with the collection of loops (denoted by their loop ids) a diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 633fe9ab77..7aa1169061 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -30,6 +30,7 @@ load( "//tensorflow:tensorflow.bzl", "if_android", "tf_cc_test", + "tf_cc_test_mkl", "tf_cc_tests", "tf_cc_binary", "tf_copts", @@ -50,6 +51,10 @@ load( "tf_kernel_tests_linkstatic", ) load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) +load( "//third_party/mkl:build_defs.bzl", "if_mkl", "if_mkl_ml", @@ -643,14 +648,7 @@ cc_library( ":split_v_op", ":strided_slice_op", ":tile_ops", - ] + if_mkl( - [ - ":mkl_transpose_op", - ], - [ - ":transpose_op", - ], - ) + [ + ":transpose_op", ":unique_op", ":unpack_op", ":unravel_index_op", @@ -893,24 +891,13 @@ tf_kernel_library( deps = ARRAY_DEPS, ) -if_mkl( - [tf_mkl_kernel_library( - name = "mkl_transpose_op", - srcs = [ - "mkl_transpose_op.cc", - "transpose_op.cc", - ], - hdrs = ["transpose_op.h"], - deps = ARRAY_DEPS + mkl_deps(), - )], - [tf_kernel_library( - name = "transpose_op", - srcs = [ - "transpose_op.cc", - ], - hdrs = ["transpose_op.h"], - deps = ARRAY_DEPS, - )], +tf_kernel_library( + name = "transpose_op", + srcs = [ + "transpose_op.cc", + ], + hdrs = ["transpose_op.h"], + deps = ARRAY_DEPS + if_mkl([":mkl_transpose_op"]), ) tf_kernel_library( @@ -1123,7 +1110,7 @@ tf_cuda_cc_test( name = "depthwise_conv_ops_test", size = "small", srcs = ["depthwise_conv_ops_test.cc"], - tags = ["requires-gpu-sm35"], + tags = tf_cuda_tests_tags(), deps = [ ":conv_ops", ":image", @@ -2296,6 +2283,31 @@ tf_cc_tests( ], ) +cc_library( + name = "eigen_benchmark", + testonly = 1, + hdrs = [ + "eigen_benchmark.h", + ":eigen_helpers", + ], + deps = [ + "//tensorflow/core:framework", + "//third_party/eigen3", + ], +) + +tf_cc_test( + name = "eigen_benchmark_cpu_test", + srcs = ["eigen_benchmark_cpu_test.cc"], + deps = [ + ":eigen_benchmark", + ":eigen_helpers", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//third_party/eigen3", + ], +) + tf_cc_tests( name = "basic_ops_benchmark_test", size = "small", @@ -4196,6 +4208,7 @@ cc_library( "hinge-loss.h", "logistic-loss.h", "loss.h", + "poisson-loss.h", "smooth-hinge-loss.h", "squared-loss.h", ], @@ -4496,6 +4509,25 @@ tf_kernel_library( deps = STRING_DEPS, ) +tf_cc_test( + name = "substr_op_test", + size = "small", + srcs = ["substr_op_test.cc"], + deps = [ + ":substr_op", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", + "//tensorflow/core/kernels:ops_util", + ], +) + tf_kernel_library( name = "as_string_op", prefix = "as_string_op", @@ -5176,6 +5208,7 @@ filegroup( "fifo_queue.cc", "fifo_queue_op.cc", "fused_batch_norm_op.cc", + "listdiff_op.cc", "population_count_op.cc", "population_count_op.h", "winograd_transform.h", @@ -6200,6 +6233,26 @@ tf_mkl_kernel_library( ] + mkl_deps(), ) +tf_cc_test_mkl( + name = "mkl_conv_ops_test", + size = "small", + srcs = ["mkl_conv_ops_test.cc"], + deps = [ + 
":ops_testutil", + ":ops_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + tf_mkl_kernel_library( name = "mkl_tfconv_op", prefix = "mkl_tfconv", @@ -6325,6 +6378,15 @@ tf_mkl_kernel_library( deps = NN_DEPS + mkl_deps() + [":cwise_op"], ) +tf_mkl_kernel_library( + name = "mkl_transpose_op", + srcs = [ + "mkl_transpose_op.cc", + ], + hdrs = ["transpose_op.h"], + deps = ARRAY_DEPS + mkl_deps(), +) + # NOTE(lespeholt): This rule is deprecated, please use: # tensorflow/core/util/batch_util.h cc_library( diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc index 7b28c8e91f..e15ea82e7d 100644 --- a/tensorflow/core/kernels/bias_op.cc +++ b/tensorflow/core/kernels/bias_op.cc @@ -134,8 +134,8 @@ class BiasOp : public BinaryOp<T> { if (data_format_ == FORMAT_NCHW) { int32 batch, height, width, channel; GetBiasValueDims(input, data_format_, &batch, &height, &width, &channel); - Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1); - Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width); + Eigen::DSizes<Eigen::Index, 4> four_dims(1, channel, 1, 1); + Eigen::DSizes<Eigen::Index, 4> broad_cast_dims(batch, 1, height, width); const Device& d = context->eigen_device<Device>(); output->tensor<T, 4>().device(d) = input.tensor<T, 4>() + @@ -247,14 +247,14 @@ class BiasGradOp : public OpKernel { OP_REQUIRES(context, output_backprop.dims() == 4, errors::InvalidArgument( "NCHW format supports only 4D input/output tensor.")); - Eigen::DSizes<int, 4> four_dims(batch, channel, height, width); + Eigen::DSizes<Eigen::Index, 4> four_dims(batch, channel, height, width); #ifdef EIGEN_HAS_INDEX_LIST using idx0 = Eigen::type2index<0>; using idx2 = Eigen::type2index<2>; using idx3 = Eigen::type2index<3>; Eigen::IndexList<idx0, idx2, idx3> reduction_axes; #else - Eigen::array<int, 3> reduction_axes = {0, 2, 3}; + Eigen::array<Eigen::Index, 3> reduction_axes = {0, 2, 3}; #endif output->template flat<T>().device(context->eigen_device<Device>()) = output_backprop.flat<T>() @@ -263,11 +263,12 @@ class BiasGradOp : public OpKernel { .sum(reduction_axes) .template cast<T>(); // End of code by intel_tf. 
} else { - Eigen::DSizes<int, 2> two_dims(batch * height * width, channel); + Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width, + channel); #ifdef EIGEN_HAS_INDEX_LIST Eigen::IndexList<Eigen::type2index<0> > reduction_axis; #else - Eigen::array<int, 1> reduction_axis = {0}; + Eigen::array<Eigen::Index, 1> reduction_axis = {0}; #endif output->template flat<T>().device(context->eigen_device<Device>()) = output_backprop.flat<T>() diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD index 4910021c63..4e8bfa02fc 100644 --- a/tensorflow/core/kernels/boosted_trees/BUILD +++ b/tensorflow/core/kernels/boosted_trees/BUILD @@ -15,7 +15,9 @@ load( tf_proto_library( name = "boosted_trees_proto", - srcs = ["boosted_trees.proto"], + srcs = [ + "boosted_trees.proto", + ], cc_api_version = 2, visibility = ["//visibility:public"], ) @@ -87,9 +89,21 @@ tf_kernel_library( ) tf_kernel_library( + name = "quantile_ops", + srcs = ["quantile_ops.cc"], + deps = [ + "//tensorflow/core:boosted_trees_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles", + ], +) + +tf_kernel_library( name = "boosted_trees_ops", deps = [ ":prediction_ops", + ":quantile_ops", ":resource_ops", ":stats_ops", ":training_ops", diff --git a/tensorflow/core/kernels/boosted_trees/quantile_ops.cc b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc new file mode 100644 index 0000000000..d1840941c1 --- /dev/null +++ b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc @@ -0,0 +1,453 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// =============================================================================
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+const char* const kExampleWeightsName = "example_weights";
+const char* const kMaxElementsName = "max_elements";
+const char* const kGenerateQuantiles = "generate_quantiles";
+const char* const kNumBucketsName = "num_buckets";
+const char* const kEpsilonName = "epsilon";
+const char* const kBucketBoundariesName = "bucket_boundaries";
+const char* const kBucketsName = "buckets";
+const char* const kSummariesName = "summaries";
+const char* const kNumStreamsName = "num_streams";
+const char* const kNumFeaturesName = "num_features";
+const char* const kFloatFeaturesName = "float_values";
+const char* const kResourceHandleName = "quantile_stream_resource_handle";
+
+using QuantileStreamResource = BoostedTreesQuantileStreamResource;
+using QuantileStream =
+    boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+using QuantileSummary =
+    boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
+using QuantileSummaryEntry =
+    boosted_trees::quantiles::WeightedQuantilesSummary<float,
+                                                       float>::SummaryEntry;
+
+// Generates de-duplicated bucket boundaries on a finalized QuantileStream.
+std::vector<float> GenerateBoundaries(const QuantileStream& stream,
+                                      const int64 num_boundaries) {
+  std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);
+
+  // Uniquify elements as we may get dupes.
+  auto end_it = std::unique(boundaries.begin(), boundaries.end());
+  boundaries.resize(std::distance(boundaries.begin(), end_it));
+  return boundaries;
+}
+
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateQuantiles(const QuantileStream& stream,
+                                     const int64 num_quantiles) {
+  // Do not de-dup boundaries. Exactly num_quantiles boundary values will be
+  // returned, since stream.GenerateQuantiles(n) produces n + 1 values.
+  std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles - 1);
+  CHECK_EQ(boundaries.size(), num_quantiles);
+  return boundaries;
+}
+
+std::vector<float> GetBuckets(const int32 feature,
+                              const OpInputList& buckets_list) {
+  const auto& buckets = buckets_list[feature].flat<float>();
+  std::vector<float> buckets_vector(buckets.data(),
+                                    buckets.data() + buckets.size());
+  return buckets_vector;
+}
+
+REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesQuantileStreamResource);
+
+REGISTER_KERNEL_BUILDER(
+    Name("IsBoostedTreesQuantileStreamResourceInitialized").Device(DEVICE_CPU),
+    IsResourceInitialized<BoostedTreesQuantileStreamResource>);
+
+class BoostedTreesCreateQuantileStreamResourceOp : public OpKernel {
+ public:
+  explicit BoostedTreesCreateQuantileStreamResourceOp(
+      OpKernelConstruction* const context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context,
+                   context->GetAttr(kMaxElementsName, &max_elements_));
+  }
+
+  void Compute(OpKernelContext* context) override {
+    // Only create one, if one does not exist already. Report status for all
+    // other exceptions. If one already exists, it unrefs the new one.
+    // An epsilon value of zero could cause performance issues and is
+    // therefore disallowed.
+    const Tensor* epsilon_t;
+    OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+    float epsilon = epsilon_t->scalar<float>()();
+    OP_REQUIRES(
+        context, epsilon > 0,
+        errors::InvalidArgument("An epsilon value of zero is not allowed."));
+
+    const Tensor* num_streams_t;
+    OP_REQUIRES_OK(context, context->input(kNumStreamsName, &num_streams_t));
+    int64 num_streams = num_streams_t->scalar<int64>()();
+
+    auto result =
+        new QuantileStreamResource(epsilon, max_elements_, num_streams);
+    auto status = CreateResource(context, HandleFromInput(context, 0), result);
+    if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
+      OP_REQUIRES(context, false, status);
+    }
+  }
+
+ private:
+  // An upper bound on the number of entries that the summaries might have
+  // for a feature.
+  int64 max_elements_;
+};
+
+REGISTER_KERNEL_BUILDER(
+    Name("BoostedTreesCreateQuantileStreamResource").Device(DEVICE_CPU),
+    BoostedTreesCreateQuantileStreamResourceOp);
+
+class BoostedTreesMakeQuantileSummariesOp : public OpKernel {
+ public:
+  explicit BoostedTreesMakeQuantileSummariesOp(
+      OpKernelConstruction* const context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context,
+                   context->GetAttr(kNumFeaturesName, &num_features_));
+  }
+
+  void Compute(OpKernelContext* const context) override {
+    // Read the float features list.
+    OpInputList float_features_list;
+    OP_REQUIRES_OK(
+        context, context->input_list(kFloatFeaturesName, &float_features_list));
+
+    // Parse example weights and get batch size.
+    const Tensor* example_weights_t;
+    OP_REQUIRES_OK(context,
+                   context->input(kExampleWeightsName, &example_weights_t));
+    auto example_weights = example_weights_t->flat<float>();
+    const int64 batch_size = example_weights.size();
+    const Tensor* epsilon_t;
+    OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+    float epsilon = epsilon_t->scalar<float>()();
+
+    OpOutputList summaries_output_list;
+    OP_REQUIRES_OK(
+        context, context->output_list(kSummariesName, &summaries_output_list));
+
+    auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) {
+      // Iterate over the features.
+ for (int64 index = begin; index < end; index++) { + const auto feature_values = float_features_list[index].flat<float>(); + QuantileStream stream(epsilon, batch_size + 1); + // Run quantile summary generation. + for (int64 j = 0; j < batch_size; j++) { + stream.PushEntry(feature_values(j), example_weights(j)); + } + stream.Finalize(); + const auto summary_entry_list = stream.GetFinalSummary().GetEntryList(); + Tensor* output_t; + OP_REQUIRES_OK( + context, + summaries_output_list.allocate( + index, + TensorShape({static_cast<int64>(summary_entry_list.size()), 4}), + &output_t)); + auto output = output_t->matrix<float>(); + for (auto row = 0; row < summary_entry_list.size(); row++) { + const auto& entry = summary_entry_list[row]; + output(row, 0) = entry.value; + output(row, 1) = entry.weight; + output(row, 2) = entry.min_rank; + output(row, 3) = entry.max_rank; + } + } + }; + // TODO(tanzheny): comment on the magic number. + const int64 kCostPerUnit = 500 * batch_size; + const DeviceBase::CpuWorkerThreads& worker_threads = + *context->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, num_features_, + kCostPerUnit, do_quantile_summary_gen); + } + + private: + int64 num_features_; +}; + +REGISTER_KERNEL_BUILDER( + Name("BoostedTreesMakeQuantileSummaries").Device(DEVICE_CPU), + BoostedTreesMakeQuantileSummariesOp); + +class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel { + public: + explicit BoostedTreesQuantileStreamResourceAddSummariesOp( + OpKernelConstruction* const context) + : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + ResourceHandle handle; + OP_REQUIRES_OK(context, + HandleFromInput(context, kResourceHandleName, &handle)); + QuantileStreamResource* stream_resource; + // Create a reference to the underlying resource using the handle. + OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource)); + // Remove the reference at the end of this scope. + mutex_lock l(*stream_resource->mutex()); + core::ScopedUnref unref_me(stream_resource); + + OpInputList summaries_list; + OP_REQUIRES_OK(context, + context->input_list(kSummariesName, &summaries_list)); + int32 num_streams = stream_resource->num_streams(); + CHECK_EQ(static_cast<int>(num_streams), summaries_list.size()); + + auto do_quantile_add_summary = [&](const int64 begin, const int64 end) { + // Iterating all features. + for (int64 feature_idx = begin; feature_idx < end; ++feature_idx) { + const Tensor& summaries = summaries_list[feature_idx]; + const auto summary_values = summaries.matrix<float>(); + const auto& tensor_shape = summaries.shape(); + const int64 entries_size = tensor_shape.dim_size(0); + CHECK_EQ(tensor_shape.dim_size(1), 4); + std::vector<QuantileSummaryEntry> summary_entries; + summary_entries.reserve(entries_size); + for (int64 i = 0; i < entries_size; i++) { + float value = summary_values(i, 0); + float weight = summary_values(i, 1); + float min_rank = summary_values(i, 2); + float max_rank = summary_values(i, 3); + QuantileSummaryEntry entry(value, weight, min_rank, max_rank); + summary_entries.push_back(entry); + } + stream_resource->stream(feature_idx)->PushSummary(summary_entries); + } + }; + + // TODO(tanzheny): comment on the magic number. 
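+    // (Presumably an estimate of cycles per stream: Shard uses kCostPerUnit
+    // to decide how many worker threads to spread the loop across.)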
+ const int64 kCostPerUnit = 500 * num_streams; + const DeviceBase::CpuWorkerThreads& worker_threads = + *context->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, num_streams, + kCostPerUnit, do_quantile_add_summary); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("BoostedTreesQuantileStreamResourceAddSummaries").Device(DEVICE_CPU), + BoostedTreesQuantileStreamResourceAddSummariesOp); + +class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel { + public: + explicit BoostedTreesQuantileStreamResourceFlushOp( + OpKernelConstruction* const context) + : OpKernel(context) { + OP_REQUIRES_OK(context, + context->GetAttr(kGenerateQuantiles, &generate_quantiles_)); + } + + void Compute(OpKernelContext* context) override { + ResourceHandle handle; + OP_REQUIRES_OK(context, + HandleFromInput(context, kResourceHandleName, &handle)); + QuantileStreamResource* stream_resource; + // Create a reference to the underlying resource using the handle. + OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource)); + // Remove the reference at the end of this scope. + mutex_lock l(*stream_resource->mutex()); + core::ScopedUnref unref_me(stream_resource); + + const Tensor* num_buckets_t; + OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t)); + const int64 num_buckets = num_buckets_t->scalar<int64>()(); + const int64 num_streams = stream_resource->num_streams(); + + auto do_quantile_flush = [&](const int64 begin, const int64 end) { + // Iterating over all streams. + for (int64 stream_idx = begin; stream_idx < end; ++stream_idx) { + QuantileStream* stream = stream_resource->stream(stream_idx); + stream->Finalize(); + stream_resource->set_boundaries( + generate_quantiles_ ? GenerateQuantiles(*stream, num_buckets) + : GenerateBoundaries(*stream, num_buckets), + stream_idx); + } + }; + + // TODO(tanzheny): comment on the magic number. + const int64 kCostPerUnit = 500 * num_streams; + const DeviceBase::CpuWorkerThreads& worker_threads = + *context->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, num_streams, + kCostPerUnit, do_quantile_flush); + + stream_resource->set_buckets_ready(true); + } + + private: + bool generate_quantiles_; +}; + +REGISTER_KERNEL_BUILDER( + Name("BoostedTreesQuantileStreamResourceFlush").Device(DEVICE_CPU), + BoostedTreesQuantileStreamResourceFlushOp); + +class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp + : public OpKernel { + public: + explicit BoostedTreesQuantileStreamResourceGetBucketBoundariesOp( + OpKernelConstruction* const context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_)); + } + + void Compute(OpKernelContext* const context) override { + ResourceHandle handle; + OP_REQUIRES_OK(context, + HandleFromInput(context, kResourceHandleName, &handle)); + QuantileStreamResource* stream_resource; + // Create a reference to the underlying resource using the handle. + OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource)); + // Remove the reference at the end of this scope. 
+ mutex_lock l(*stream_resource->mutex()); + core::ScopedUnref unref_me(stream_resource); + + const int64 num_streams = stream_resource->num_streams(); + CHECK_EQ(num_features_, num_streams); + OpOutputList bucket_boundaries_list; + OP_REQUIRES_OK(context, context->output_list(kBucketBoundariesName, + &bucket_boundaries_list)); + + auto do_quantile_get_buckets = [&](const int64 begin, const int64 end) { + // Iterating over all streams. + for (int64 stream_idx = begin; stream_idx < end; stream_idx++) { + const auto& boundaries = stream_resource->boundaries(stream_idx); + Tensor* bucket_boundaries_t = nullptr; + OP_REQUIRES_OK(context, + bucket_boundaries_list.allocate( + stream_idx, {static_cast<int64>(boundaries.size())}, + &bucket_boundaries_t)); + auto* quantiles_flat = bucket_boundaries_t->flat<float>().data(); + memcpy(quantiles_flat, boundaries.data(), + sizeof(float) * boundaries.size()); + } + }; + + // TODO(tanzheny): comment on the magic number. + const int64 kCostPerUnit = 500 * num_streams; + const DeviceBase::CpuWorkerThreads& worker_threads = + *context->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, num_streams, + kCostPerUnit, do_quantile_get_buckets); + } + + private: + int64 num_features_; +}; + +REGISTER_KERNEL_BUILDER( + Name("BoostedTreesQuantileStreamResourceGetBucketBoundaries") + .Device(DEVICE_CPU), + BoostedTreesQuantileStreamResourceGetBucketBoundariesOp); + +// Given the calculated quantiles thresholds and input data, this operation +// converts the input features into the buckets (categorical values), depending +// on which quantile they fall into. +class BoostedTreesBucketizeOp : public OpKernel { + public: + explicit BoostedTreesBucketizeOp(OpKernelConstruction* const context) + : OpKernel(context) { + OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_)); + } + + void Compute(OpKernelContext* const context) override { + // Read float features list; + OpInputList float_features_list; + OP_REQUIRES_OK( + context, context->input_list(kFloatFeaturesName, &float_features_list)); + OpInputList bucket_boundaries_list; + OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName, + &bucket_boundaries_list)); + OP_REQUIRES(context, + tensorflow::TensorShapeUtils::IsVector( + bucket_boundaries_list[0].shape()), + errors::InvalidArgument( + strings::Printf("Buckets should be flat vectors."))); + OpOutputList buckets_list; + OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list)); + + auto do_quantile_get_quantiles = [&](const int64 begin, const int64 end) { + // Iterating over all resources + for (int64 feature_idx = begin; feature_idx < end; feature_idx++) { + const Tensor& values_tensor = float_features_list[feature_idx]; + const int64 num_values = values_tensor.dim_size(0); + + Tensor* output_t = nullptr; + OP_REQUIRES_OK( + context, buckets_list.allocate( + feature_idx, TensorShape({num_values, 1}), &output_t)); + auto output = output_t->matrix<int32>(); + + const std::vector<float>& bucket_boundaries_vector = + GetBuckets(feature_idx, bucket_boundaries_list); + CHECK(!bucket_boundaries_vector.empty()) + << "Got empty buckets for feature " << feature_idx; + auto flat_values = values_tensor.flat<float>(); + for (int64 instance = 0; instance < num_values; instance++) { + const float value = flat_values(instance); + auto bucket_iter = + std::lower_bound(bucket_boundaries_vector.begin(), + bucket_boundaries_vector.end(), value); + if (bucket_iter == 
bucket_boundaries_vector.end()) { + --bucket_iter; + } + const int32 bucket = static_cast<int32>( + bucket_iter - bucket_boundaries_vector.begin()); + // Bucket id. + output(instance, 0) = bucket; + } + } + }; + + // TODO(tanzheny): comment on the magic number. + const int64 kCostPerUnit = 500 * num_features_; + const DeviceBase::CpuWorkerThreads& worker_threads = + *context->device()->tensorflow_cpu_worker_threads(); + Shard(worker_threads.num_threads, worker_threads.workers, num_features_, + kCostPerUnit, do_quantile_get_quantiles); + } + + private: + int64 num_features_; +}; + +REGISTER_KERNEL_BUILDER(Name("BoostedTreesBucketize").Device(DEVICE_CPU), + BoostedTreesBucketizeOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD index 3163c63949..12d9473776 100644 --- a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD +++ b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD @@ -1,5 +1,5 @@ # Description: -# This directory contains common utilities used in boosted_trees. +# This directory contains common quantile utilities used in boosted_trees. package( default_visibility = ["//tensorflow:internal"], ) @@ -16,6 +16,7 @@ cc_library( name = "weighted_quantiles", srcs = [], hdrs = [ + "quantile_stream_resource.h", "weighted_quantiles_buffer.h", "weighted_quantiles_stream.h", "weighted_quantiles_summary.h", @@ -23,6 +24,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", ], ) diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h new file mode 100644 index 0000000000..1c31724272 --- /dev/null +++ b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h @@ -0,0 +1,96 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_ +#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_ + +#include <vector> +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/mutex.h" + +namespace tensorflow { + +using QuantileStream = + boosted_trees::quantiles::WeightedQuantilesStream<float, float>; + +// Quantile Stream Resource for a list of streams sharing the same number of +// quantiles, maximum elements, and epsilon. 
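+// A rough usage sketch (illustrative only):
+//   auto* res = new BoostedTreesQuantileStreamResource(
+//       /*epsilon=*/0.01f, /*max_elements=*/1 << 16, /*num_streams=*/1);
+//   res->stream(0)->PushEntry(/*value=*/3.5f, /*weight=*/1.0f);
+//   res->stream(0)->Finalize();
+//   res->set_boundaries(res->stream(0)->GenerateBoundaries(10), 0);
+//   res->set_buckets_ready(true);
+//   res->Unref();  // ResourceBase objects are reference counted.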
+class BoostedTreesQuantileStreamResource : public ResourceBase {
+ public:
+  BoostedTreesQuantileStreamResource(const float epsilon,
+                                     const int64 max_elements,
+                                     const int64 num_streams)
+      : are_buckets_ready_(false),
+        epsilon_(epsilon),
+        num_streams_(num_streams),
+        max_elements_(max_elements) {
+    streams_.reserve(num_streams_);
+    boundaries_.reserve(num_streams_);
+    for (int64 idx = 0; idx < num_streams; ++idx) {
+      streams_.push_back(QuantileStream(epsilon, max_elements));
+      boundaries_.push_back(std::vector<float>());
+    }
+  }
+
+  string DebugString() override { return "QuantileStreamResource"; }
+
+  tensorflow::mutex* mutex() { return &mu_; }
+
+  QuantileStream* stream(const int64 index) { return &streams_[index]; }
+
+  const std::vector<float>& boundaries(const int64 index) {
+    return boundaries_[index];
+  }
+
+  void set_boundaries(const std::vector<float>& boundaries, const int64 index) {
+    boundaries_[index] = boundaries;
+  }
+
+  float epsilon() const { return epsilon_; }
+  int64 num_streams() const { return num_streams_; }
+
+  bool are_buckets_ready() const { return are_buckets_ready_; }
+  void set_buckets_ready(const bool are_buckets_ready) {
+    are_buckets_ready_ = are_buckets_ready;
+  }
+
+ private:
+  ~BoostedTreesQuantileStreamResource() override {}
+
+  // Mutex for the whole resource.
+  tensorflow::mutex mu_;
+
+  // Quantile streams.
+  std::vector<QuantileStream> streams_;
+
+  // Stores the boundaries. Same size as streams_.
+  std::vector<std::vector<float>> boundaries_;
+
+  // Whether boundaries have been created. Initially the boundaries are empty
+  // until set_boundaries is called.
+  bool are_buckets_ready_;
+
+  const float epsilon_;
+  const int64 num_streams_;
+  // An upper bound on the number of elements.
+  int64 max_elements_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(BoostedTreesQuantileStreamResource);
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h
index a7836896c7..390db8fe5a 100644
--- a/tensorflow/core/kernels/conditional_accumulator.h
+++ b/tensorflow/core/kernels/conditional_accumulator.h
@@ -51,9 +51,11 @@ class ConditionalAccumulator
   // dtype: The datatype of the gradients to be accumulated.
   // shape: The shape of the accumulated gradients.
   // name: A name to use for the ConditionalAccumulator.
+  // reduction_type: The reduction type, i.e., MEAN or SUM.
   ConditionalAccumulator(const DataType& dtype,
                          const PartialTensorShape& shape,
-                         const string& name)
-      : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {}
+                         const string& name, const string& reduction_type)
+      : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name,
+                                                      reduction_type) {}
   ~ConditionalAccumulator() override{};
 
  protected:
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.cc b/tensorflow/core/kernels/conditional_accumulator_base.cc
index 90593c56b8..292cf0cd64 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_base.cc
@@ -14,12 +14,17 @@ limitations under the License.
==============================================================================*/ #include "tensorflow/core/kernels/conditional_accumulator_base.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { ConditionalAccumulatorBase::ConditionalAccumulatorBase( - const DataType& dtype, const PartialTensorShape& shape, const string& name) - : dtype_(dtype), shape_(shape), name_(name) { + const DataType& dtype, const PartialTensorShape& shape, const string& name, + const string& reduction_type) + : dtype_(dtype), + shape_(shape), + name_(name), + reduction_type_(reduction_type) { counter_ = 0; current_global_step_ = 0; } @@ -190,7 +195,9 @@ bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx, current_global_step_++; // Average the accumulated gradient - DivideAccumGradByCounter(ctx); + if (reduction_type_ == "MEAN") { + DivideAccumGradByCounter(ctx); + } // Set output for accumulated gradient tensor bool successful_set_output = SetOutput(ctx); diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h index b7b7482a00..4a5ec6f0fb 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base.h +++ b/tensorflow/core/kernels/conditional_accumulator_base.h @@ -52,7 +52,7 @@ class ConditionalAccumulatorBase : public ResourceBase { // name: A name to use for the ConditionalAccumulator. ConditionalAccumulatorBase(const DataType& dtype, const PartialTensorShape& shape, - const string& name); + const string& name, const string& reduction_type); typedef AsyncOpKernel::DoneCallback DoneCallback; @@ -125,6 +125,7 @@ class ConditionalAccumulatorBase : public ResourceBase { const DataType dtype_; const PartialTensorShape shape_; const string name_; + const string reduction_type_; mutex mu_; int counter_ GUARDED_BY(mu_); int64 current_global_step_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h index 012a0dcc12..ca24d690f8 100644 --- a/tensorflow/core/kernels/conditional_accumulator_base_op.h +++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h @@ -51,6 +51,8 @@ class ConditionalAccumulatorBaseOp : public OpKernel { &accumulator_handle_, nullptr)); OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_)); OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_)); + OP_REQUIRES_OK(context, + context->GetAttr("reduction_type", &reduction_type_)); } void Compute(OpKernelContext* ctx) override { @@ -81,6 +83,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel { DataType dtype_; PartialTensorShape shape_; ContainerInfo cinfo_; + string reduction_type_; private: Status SetAccumulatorHandle(OpKernelContext* ctx) diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc index e13bf8a4c6..52ac51a9b6 100644 --- a/tensorflow/core/kernels/conditional_accumulator_op.cc +++ b/tensorflow/core/kernels/conditional_accumulator_op.cc @@ -34,7 +34,8 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { Creator GetCreator() const override { return [this](ConditionalAccumulatorBase** ret) { ConditionalAccumulator<Device, T>* accumulator = - new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name()); + new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(), + reduction_type_); *ret = accumulator; return Status::OK(); }; diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h 
index de9b69828e..639c3062cc 100644 --- a/tensorflow/core/kernels/conv_2d.h +++ b/tensorflow/core/kernels/conv_2d.h @@ -137,17 +137,16 @@ struct MatMulConvFunctor { } }; -// Shuffles a filter tensor from: -// [<spatial_dims>, in, out] -// to: -// [out, in, <spatial_dims>] +// Shuffles a filter tensor from TensorFlow format HWIO to dst_filter_format. +// +// Note: Currently OIHW is the only supported destination format. Support for +// OHWI format will be added in a follow-up change. template <typename Device, typename T, typename IndexType, int NDIMS> struct TransformFilter { - void operator()(const Device& d, + void operator()(const Device& d, FilterTensorFormat dst_filter_format, typename TTypes<T, NDIMS, IndexType>::ConstTensor in, typename TTypes<T, NDIMS, IndexType>::Tensor out) { - // We want a 3, 2, 0, 1 shuffle. Merge the spatial dimensions together - // to speed up the shuffle operation. + // Merge the spatial dimensions together to speed up the shuffle operation. Eigen::DSizes<IndexType, 3> merged_dims; merged_dims[0] = in.dimension(0); // spatial dimensions for (int i = 1; i < NDIMS - 2; ++i) { @@ -156,16 +155,30 @@ struct TransformFilter { merged_dims[1] = in.dimension(NDIMS - 2); // input filters merged_dims[2] = in.dimension(NDIMS - 1); // output filters + CHECK(dst_filter_format == FORMAT_OIHW) + << "Unsupported destination filter format: " + << ToString(dst_filter_format); + // Source filter format is FORMAT_HWIO and spatial dimensions HW are merged + // in the beginning. + Eigen::DSizes<IndexType, 3> shuffling_perm = + Eigen::DSizes<IndexType, 3>(2, 1, 0); + Eigen::DSizes<IndexType, NDIMS> expanded_dims; - expanded_dims[0] = in.dimension(NDIMS - 1); // output filters - expanded_dims[1] = in.dimension(NDIMS - 2); // input filters - for (int i = 0; i < NDIMS - 2; ++i) { // spatial dimensions - expanded_dims[i + 2] = in.dimension(i); + int out_index = 0; + for (int merged_dim = 0; merged_dim < merged_dims.rank(); ++merged_dim) { + if (shuffling_perm[merged_dim] == 0) { + for (int spatial_dim = 0; spatial_dim < NDIMS - 2; ++spatial_dim) { + expanded_dims[out_index++] = in.dimension(spatial_dim); + } + } else { + constexpr int kLastSpatialDim = NDIMS - 3; + expanded_dims[out_index++] = + in.dimension(kLastSpatialDim + shuffling_perm[merged_dim]); + } } - out.device(d) = in.reshape(merged_dims) - .shuffle(Eigen::DSizes<IndexType, 3>(2, 1, 0)) - .reshape(expanded_dims); + out.device(d) = + in.reshape(merged_dims).shuffle(shuffling_perm).reshape(expanded_dims); } }; @@ -282,7 +295,9 @@ struct SwapDimension0And2InTensor3 { const gtl::ArraySlice<int64>& input_dims, T* out); }; -// Reverses the effect of TransformFilter above. +// Transforms the filter back from OIHW to HWIO format to reverse the effect +// of TransformFilter above. +// TODO(hinsu): Support reverse transformation from filter format OHWI as well. template <typename Device, typename T, int NDIMS> struct ReverseTransformFilter { void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in, diff --git a/tensorflow/core/kernels/conv_3d.h b/tensorflow/core/kernels/conv_3d.h index 02e3655ad1..b819c6f910 100644 --- a/tensorflow/core/kernels/conv_3d.h +++ b/tensorflow/core/kernels/conv_3d.h @@ -19,6 +19,7 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_CONV_3D_H_ #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h" #include "tensorflow/core/kernels/eigen_cuboid_convolution.h" namespace tensorflow { @@ -28,6 +29,14 @@ namespace functor { template <typename Device, typename T> struct CuboidConvolution; +// Backward input pass for the cuboid convolution. +template <typename Device, typename T> +struct CuboidConvolutionBackwardInput; + +// Backward filter pass for the cuboid convolution. +template <typename Device, typename T> +struct CuboidConvolutionBackwardFilter; + typedef Eigen::ThreadPoolDevice CPUDevice; template <typename T> @@ -42,6 +51,40 @@ struct CuboidConvolution<CPUDevice, T> { } }; +template <typename T> +struct CuboidConvolutionBackwardInput<CPUDevice, T> { + void operator()(const CPUDevice& d, + typename TTypes<T, 5>::Tensor input_backward, + typename TTypes<T, 5>::ConstTensor filter, + typename TTypes<T, 5>::ConstTensor output_backward, + int stride_planes, int stride_rows, int stride_cols) { + // Need to swap the order of plane/row/col strides when calling Eigen. + input_backward.device(d) = Eigen::CuboidConvolutionBackwardInput( + filter, output_backward, + input_backward.dimension(3), // input_planes + input_backward.dimension(2), // input_rows + input_backward.dimension(1), // input_cols + stride_cols, stride_rows, stride_planes); + } +}; + +template <typename T> +struct CuboidConvolutionBackwardFilter<CPUDevice, T> { + void operator()(const CPUDevice& d, + typename TTypes<T, 5>::Tensor filter_backward, + typename TTypes<T, 5>::ConstTensor input, + typename TTypes<T, 5>::ConstTensor output_backward, + int stride_planes, int stride_rows, int stride_cols) { + // Need to swap the order of plane/row/col strides when calling Eigen. 
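+ // The reversal falls out of tensor layouts: these tensors are row-major
+ // NDHWC, while Eigen's cuboid convolution helpers index spatial dimensions
+ // in column-major order, so TF's planes/rows/cols are seen by Eigen as
+ // cols/rows/planes and both the sizes and the strides are passed reversed.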
+ filter_backward.device(d) = Eigen::CuboidConvolutionBackwardKernel( + input, output_backward, + filter_backward.dimension(2), // kernel_planes + filter_backward.dimension(1), // kernel_rows + filter_backward.dimension(0), // kernel_cols + stride_cols, stride_rows, stride_planes); + } +}; + } // namespace functor } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc index 63b1bcda43..9e86a16b66 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc @@ -1018,7 +1018,8 @@ namespace functor { extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \ template <> \ void TransformFilter<GPUDevice, T, int, 4>::operator()( \ - const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \ + const GPUDevice& d, FilterTensorFormat dst_filter_format, \ + typename TTypes<T, 4, int>::ConstTensor in, \ typename TTypes<T, 4, int>::Tensor out); \ extern template struct TransformFilter<GPUDevice, T, int, 4>; \ template <> \ diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index d664a11e73..43bb5ea56c 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -901,7 +901,8 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()( &transformed_filter)); functor::TransformFilter<GPUDevice, T, int, 4>()( - ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()), + ctx->eigen_device<GPUDevice>(), FORMAT_OIHW, + To32Bit(filter.tensor<T, 4>()), To32Bit(transformed_filter.tensor<T, 4>())); Tensor transformed_out_backprop; @@ -1090,7 +1091,8 @@ namespace functor { extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \ template <> \ void TransformFilter<GPUDevice, T, int, 4>::operator()( \ - const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \ + const GPUDevice& d, FilterTensorFormat dst_filter_format, \ + typename TTypes<T, 4, int>::ConstTensor in, \ typename TTypes<T, 4, int>::Tensor out); \ extern template struct TransformFilter<GPUDevice, T, int, 4>; \ template <> \ diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc index fc0a2f123f..507720c998 100644 --- a/tensorflow/core/kernels/conv_grad_ops.cc +++ b/tensorflow/core/kernels/conv_grad_ops.cc @@ -41,6 +41,17 @@ limitations under the License. namespace tensorflow { +// Compute padding for the given spatial dimension. +int ConvBackpropDimensions::SpatialPadding(const Padding& padding, + int dim) const { + return (padding == VALID) + ? 0 + : std::max<int>( + 0, static_cast<int>((output_size(dim) - 1) * stride(dim) + + (filter_size(dim) - 1) * dilation(dim) + + 1 - input_size(dim))); +} + // The V2 version computes windowed output size with arbitrary dilation_rate, // while the original version only handles the cases where dilation_rates equal // to 1. diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h index 535586d53a..9551959463 100644 --- a/tensorflow/core/kernels/conv_grad_ops.h +++ b/tensorflow/core/kernels/conv_grad_ops.h @@ -234,6 +234,16 @@ struct ConvBackpropDimensions { // Input and output feature depth. int64 in_depth, out_depth; + + // Convenience access methods for spatial dimensions properties. 
+ int64 input_size(int dim) const { return spatial_dims[dim].input_size; } + int64 filter_size(int dim) const { return spatial_dims[dim].filter_size; } + int64 output_size(int dim) const { return spatial_dims[dim].output_size; } + int64 stride(int dim) const { return spatial_dims[dim].stride; } + int64 dilation(int dim) const { return spatial_dims[dim].dilation; } + + // Compute padding for the given spatial dimension. + int SpatialPadding(const Padding& padding, int dim) const; }; // Common code between implementations of Conv?DBackpropInput and diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc index 15f1bf9aba..bab91f5e86 100644 --- a/tensorflow/core/kernels/conv_grad_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/kernels/conv_2d.h" +#include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/conv_ops_gpu.h" #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/core/errors.h" @@ -32,111 +33,130 @@ limitations under the License. #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/use_cudnn.h" +#include "tensorflow/core/util/work_sharder.h" #if GOOGLE_CUDA #include "tensorflow/core/platform/stream_executor.h" using stream_executor::dnn::DimIndex; #endif +namespace { + +// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and +// conv_grad_input_ops_3d.cc. + +// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels. + +// "Depth" is already used for the channel dimension, so for the third spatial +// dimension in this file we use "plane", although in NDHWC layout it's +// indicated with a "D". + +// Returns in 'im_data' (assumed to be zero-initialized) image patch in storage +// order (planes, height, width, depth), constructed from patches in 'col_data', +// which is required to be in storage order (out_planes * out_height * +// out_width, filter_planes, filter_height, filter_width, in_depth). +// +// Based on 2-dimensional implementation written by Yangqing Jia (jiayq). 
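+//
+// As a quick check of the geometry used below (illustrative numbers): with
+// planes = 5, filter_p = 3, pad_pt = pad_pb = 0, and stride_p = 1,
+// planes_col = (5 + 0 + 0 - 3) / 1 + 1 = 3 patch positions along the plane
+// dimension; height_col and width_col follow the same formula.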
+template <typename T> +void Col2im(const T* col_data, const int depth, const int planes, + const int height, const int width, const int filter_p, + const int filter_h, const int filter_w, const int pad_pt, + const int pad_t, const int pad_l, const int pad_pb, const int pad_b, + const int pad_r, const int stride_p, const int stride_h, + const int stride_w, T* im_data) { + const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1; + const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1; + const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1; + int p_pad = -pad_pt; + for (int p = 0; p < planes_col; ++p) { + int h_pad = -pad_t; + for (int h = 0; h < height_col; ++h) { + int w_pad = -pad_l; + for (int w = 0; w < width_col; ++w) { + T* im_patch_data = + im_data + (p_pad * height * width + h_pad * width + w_pad) * depth; + for (int ip = p_pad; ip < p_pad + filter_p; ++ip) { + for (int ih = h_pad; ih < h_pad + filter_h; ++ih) { + for (int iw = w_pad; iw < w_pad + filter_w; ++iw) { + if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 && + iw < width) { + for (int i = 0; i < depth; ++i) { + im_patch_data[i] += col_data[i]; + } + } + im_patch_data += depth; + col_data += depth; + } + // Jump over remaining number of depth. + im_patch_data += depth * (width - filter_w); + } + // Jump over remaining number of (depth * width). + im_patch_data += (depth * width) * (height - filter_h); + } + w_pad += stride_w; + } + h_pad += stride_h; + } + p_pad += stride_p; + } +} + +// Returns in 'col_data', image patches in storage order (planes, height, width, +// depth) extracted from image at 'input_data', which is required to be in +// storage order (batch, planes, height, width, depth). +// +// Based on 2-dimensional implementation written by Yangqing Jia (jiayq). +template <typename T> +void Im2col(const T* input_data, const int depth, const int planes, + const int height, const int width, const int filter_p, + const int filter_h, const int filter_w, const int pad_pt, + const int pad_t, const int pad_l, const int pad_pb, const int pad_b, + const int pad_r, const int stride_p, const int stride_h, + const int stride_w, T* col_data) { + const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1; + const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1; + const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1; + + int p_pad = -pad_pt; + for (int p = 0; p < planes_col; ++p) { + int h_pad = -pad_t; + for (int h = 0; h < height_col; ++h) { + int w_pad = -pad_l; + for (int w = 0; w < width_col; ++w) { + for (int ip = p_pad; ip < p_pad + filter_p; ++ip) { + for (int ih = h_pad; ih < h_pad + filter_h; ++ih) { + for (int iw = w_pad; iw < w_pad + filter_w; ++iw) { + if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 && + iw < width) { + memcpy(col_data, + input_data + + (ip * height * width + ih * width + iw) * depth, + sizeof(T) * depth); + } else { + // This should be simply padded with zero. + memset(col_data, 0, sizeof(T) * depth); + } + col_data += depth; + } + } + } + w_pad += stride_w; + } + h_pad += stride_h; + } + p_pad += stride_p; + } +} + +} // namespace + namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; -// TODO(mjanusz): Get rid of the macro and return shapes directly. 
-#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \ - const Tensor& out_backprop = context->input(2); \ - OP_REQUIRES( \ - context, input_shape.dims() == 5, \ - errors::InvalidArgument(label, ": input must be 5-dimensional")); \ - OP_REQUIRES( \ - context, filter_shape.dims() == 5, \ - errors::InvalidArgument(label, ": filter must be 5-dimensional")); \ - OP_REQUIRES( \ - context, out_backprop.dims() == 5, \ - errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \ - const int64 batch = input_shape.dim_size(0); \ - OP_REQUIRES( \ - context, batch == out_backprop.dim_size(0), \ - errors::InvalidArgument( \ - label, ": input and out_backprop must have the same batch size")); \ - const std::array<int64, 3> input_size = { \ - {GetTensorDim(input_shape, data_format_, '0'), \ - GetTensorDim(input_shape, data_format_, '1'), \ - GetTensorDim(input_shape, data_format_, '2')}}; \ - const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C'); \ - const std::array<int64, 3> filter_size = {{filter_shape.dim_size(0), \ - filter_shape.dim_size(1), \ - filter_shape.dim_size(2)}}; \ - const int64 output_cols = GetTensorDim(out_backprop, data_format_, '2'); \ - const int64 output_rows = GetTensorDim(out_backprop, data_format_, '1'); \ - const int64 output_planes = GetTensorDim(out_backprop, data_format_, '0'); \ - OP_REQUIRES(context, in_depth == filter_shape.dim_size(3), \ - errors::InvalidArgument( \ - label, ": input and filter must have the same depth")); \ - const int64 out_depth = filter_shape.dim_size(4); \ - OP_REQUIRES( \ - context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'), \ - errors::InvalidArgument( \ - label, ": filter and out_backprop must have the same out_depth")); \ - const std::array<int64, 3> dilations = { \ - {GetTensorDim(dilation_, data_format_, '0'), \ - GetTensorDim(dilation_, data_format_, '1'), \ - GetTensorDim(dilation_, data_format_, '2')}}; \ - const std::array<int64, 3> strides = { \ - {GetTensorDim(stride_, data_format_, '0'), \ - GetTensorDim(stride_, data_format_, '1'), \ - GetTensorDim(stride_, data_format_, '2')}}; \ - std::array<int64, 3> out, padding; \ - OP_REQUIRES_OK( \ - context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides, \ - padding_, &out, &padding)); \ - OP_REQUIRES(context, output_planes == out[0], \ - errors::InvalidArgument( \ - label, \ - ": Number of planes of out_backprop doesn't match " \ - "computed: actual = ", \ - output_planes, ", computed = ", out[0])); \ - OP_REQUIRES( \ - context, output_rows == out[1], \ - errors::InvalidArgument( \ - label, ": Number of rows of out_backprop doesn't match computed: ", \ - "actual = ", output_rows, ", computed = ", out[1])); \ - OP_REQUIRES( \ - context, output_cols == out[2], \ - errors::InvalidArgument( \ - label, ": Number of cols of out_backprop doesn't match computed: ", \ - "actual = ", output_cols, ", computed = ", out[2])); \ - const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1; \ - const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1; \ - const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1; \ - const auto padded_out_planes = input_size[0] + filter_size[0] - 1; \ - const auto padded_out_rows = input_size[1] + filter_size[1] - 1; \ - const auto padded_out_cols = input_size[2] + filter_size[2] - 1; \ - const auto top_pad_planes = filter_size[0] - 1 - padding[0]; \ - const auto top_pad_rows = filter_size[1] - 1 - padding[1]; \ - const auto left_pad_cols = filter_size[2] - 1 - padding[2]; \ - const auto 
bottom_pad_planes = \ - padded_out_planes - expanded_out_planes - top_pad_planes; \ - const auto bottom_pad_rows = \ - padded_out_rows - expanded_out_rows - top_pad_rows; \ - const auto right_pad_cols = \ - padded_out_cols - expanded_out_cols - left_pad_cols; \ - VLOG(2) << "Conv3d: " << label \ - << ": expanded_out_planes = " << expanded_out_planes \ - << ": expanded_out_rows = " << expanded_out_rows \ - << ", expanded_out_cols = " << expanded_out_cols \ - << ", padded_out_planes = " << padded_out_planes \ - << ", padded_out_rows = " << padded_out_rows \ - << ", padded_out_cols = " << padded_out_cols \ - << ", top_pad_planes = " << top_pad_planes \ - << ", top_pad_rows = " << top_pad_rows \ - << ", left_pad_cols = " << left_pad_cols \ - << ", bottom_pad_planes = " << bottom_pad_planes \ - << ", bottom_pad_rows = " << bottom_pad_rows \ - << ", right_pad_cols = " << right_pad_cols - -// Backprop for input. +// Backprop for input that offloads computation to +// Eigen::CuboidConvolutionBackwardInput. template <typename Device, class T> class Conv3DBackpropInputOp : public OpKernel { public: @@ -192,6 +212,116 @@ class Conv3DBackpropInputOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& filter = context->input(1); const TensorShape& filter_shape = filter.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + + TensorShape input_shape; + if (takes_shape_) { + const Tensor& input_sizes = context->input(0); + // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes. + OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape)); + } else { + input_shape = context->input(0).shape(); + } + + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, ConvBackpropComputeDimensions( + "Conv3DBackpropInputOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, + stride_, padding_, data_format_, &dims)); + + Tensor* in_backprop; + OP_REQUIRES_OK(context, + context->allocate_output(0, input_shape, &in_backprop)); + + functor::CuboidConvolutionBackwardInput<Device, T>()( + context->eigen_device<Device>(), + in_backprop->tensor<T, 5>(), // input_backward + filter.tensor<T, 5>(), // filter + out_backprop.tensor<T, 5>(), // output_backward + static_cast<int>(dims.spatial_dims[0].stride), // stride_planes + static_cast<int>(dims.spatial_dims[1].stride), // stride_rows + static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols + } + + private: + std::vector<int32> dilation_; + std::vector<int32> stride_; + Padding padding_; + TensorFormat data_format_; + bool takes_shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp); +}; + +// Custom backprop for input that explicitly does the work sharding and calls +// Eigen only to multiply matrices. +template <typename Device, class T> +class Conv3DCustomBackpropInputOp : public OpKernel { + // Limit the maximum size of allocated temporary buffer to + // kMaxTempAllocationOverhead times the size of the input tensors (input, + // filter, out_backprop). If the size of the temporary buffer exceeds this + // limit, fallback on Eigen implementation. + static constexpr int kMaxTempAllocationOverhead = 25; + + public: + explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context) + : OpKernel(context), + data_format_(FORMAT_NHWC), + takes_shape_(type_string().find("V2") != std::string::npos) { + // data_format is only available in V2. 
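+ // (The V2 ops are also the variants that take input_sizes as a shape
+ // tensor instead of inferring the shape from a reference input, which is
+ // what takes_shape_ tracks above.)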
+ if (takes_shape_) { + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES( + context, data_format_ == FORMAT_NHWC, + errors::InvalidArgument( + "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU.")); + } + + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_)); + OP_REQUIRES(context, dilation_.size() == 5, + errors::InvalidArgument("Dilation rates field must " + "specify 5 dimensions")); + OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, 'C') == 1 && + GetTensorDim(dilation_, data_format_, 'N') == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilation rates in the batch and depth dimensions.")); + + // TODO(yangzihao): Add CPU version of dilated conv 3D. + OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, '0') == 1 && + GetTensorDim(dilation_, data_format_, '1') == 1 && + GetTensorDim(dilation_, data_format_, '2') == 1), + errors::InvalidArgument( + "Current CPU implementation does not yet support " + "dilation rates larger than 1.")); + + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES( + context, + (GetTensorDim(stride_, data_format_, 'C') == 1 && + GetTensorDim(stride_, data_format_, 'N') == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& filter = context->input(1); + const TensorShape& filter_shape = filter.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + TensorShape input_shape; if (takes_shape_) { const Tensor& input_sizes = context->input(0); @@ -200,51 +330,239 @@ class Conv3DBackpropInputOp : public OpKernel { } else { input_shape = context->input(0).shape(); } - EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput"); - Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{ - {0, 0}, - {top_pad_planes, bottom_pad_planes}, - {top_pad_rows, bottom_pad_rows}, - {left_pad_cols, right_pad_cols}, - {0, 0}}; + + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, ConvBackpropComputeDimensions( + "Conv3DBackpropInputOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, + stride_, padding_, data_format_, &dims)); + Tensor* in_backprop; OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); - // Fill out a padded out_backprop. 
- TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows, - padded_out_cols, out_depth}); - Tensor padded_output; + int64 top_pad_planes, bottom_pad_planes; + int64 top_pad_rows, bottom_pad_rows; + int64 left_pad_cols, right_pad_cols; + + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, + dims.spatial_dims[0].filter_size, + dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, + &top_pad_planes, &bottom_pad_planes)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, + dims.spatial_dims[1].filter_size, + dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, + &top_pad_rows, &bottom_pad_rows)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[2].input_size, + dims.spatial_dims[2].filter_size, + dims.spatial_dims[2].stride, padding_, + &dims.spatial_dims[2].output_size, + &left_pad_cols, &right_pad_cols)); + + // TODO(ezhulenev): Extract work size and shard estimation to shared + // functions in conv_grad_ops, and update 2d convolution backprop. + + // The total dimension size of each kernel. + const int64 filter_total_size = + dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size * + dims.spatial_dims[2].filter_size * dims.in_depth; + + // The output image size is the spatial size of the output. + const int64 output_image_size = dims.spatial_dims[0].output_size * + dims.spatial_dims[1].output_size * + dims.spatial_dims[2].output_size; + + const auto cache_sizes = Eigen::internal::CacheSizes(); + const ptrdiff_t l3_cache_size = cache_sizes.m_l3; + + // Use L3 cache size as target working set size. + const size_t target_working_set_size = l3_cache_size / sizeof(T); + + // Calculate size of matrices involved in MatMul: C = A x B. + const int64 size_A = output_image_size * dims.out_depth; + + const int64 size_B = filter_total_size * dims.out_depth; + + const int64 size_C = output_image_size * filter_total_size; + + const int64 work_unit_size = size_A + size_B + size_C; + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + // Use parallel tensor contractions if there is no batching. + // + // Compared to the Conv2D code, this version is missing the work size + // estimation. In benchmarks we did not find a case where running a parallel + // contraction was beneficial compared to sharding and matmuls. + const bool use_parallel_contraction = dims.batch_size == 1; + + const size_t shard_size = + use_parallel_contraction + ? 1 + : (target_working_set_size + work_unit_size - 1) / work_unit_size; + + // Total number of elements in all the tensors used by this kernel. + int64 total_tensor_elements = input_shape.num_elements() + + filter_shape.num_elements() + + out_backprop_shape.num_elements(); + + // Shape of the temporary workspace buffer. + TensorShape col_buffer_shape = {static_cast<int64>(shard_size), + static_cast<int64>(output_image_size), + static_cast<int64>(filter_total_size)}; + int64 col_buffer_elements = col_buffer_shape.num_elements(); + + // If the temporary allocation overhead is too large, fall back on the + // Eigen implementation, which requires much less memory.
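+ //
+ // Concretely: with kMaxTempAllocationOverhead = 25, the col_buffer scratch
+ // space may hold at most ~25x as many elements as the input, filter, and
+ // out_backprop tensors combined before this kernel falls back to the
+ // Eigen path.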
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements; + if (col_buffer_overhead > kMaxTempAllocationOverhead) { + VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: " + "col_buffer_overhead=" + << col_buffer_overhead; + + functor::CuboidConvolutionBackwardInput<Device, T>()( + context->eigen_device<Device>(), + in_backprop->tensor<T, 5>(), // input_backward + filter.tensor<T, 5>(), // filter + out_backprop.tensor<T, 5>(), // output_backward + static_cast<int>(dims.spatial_dims[0].stride), // stride_planes + static_cast<int>(dims.spatial_dims[1].stride), // stride_rows + static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols + + return; + } + + Tensor col_buffer; OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::v(), - padded_out_shape, &padded_output)); - Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4}; - Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1], - strides[2], 1}; - functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), out_backprop.tensor<T, 5>(), - eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>()); - const Tensor& padded_output_cref = padded_output; - - // Fill a new "reverted" filter. We need to transpose the in_depth and - // out_depth for the filter and reverse the planes, rows and cols. - TensorShape r_filter_shape( - {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth}); - Tensor r_filter; - OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(), - r_filter_shape, &r_filter)); - Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3}; - Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false}; - functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order, - filter_rev_dims, r_filter.tensor<T, 5>()); - const Tensor& r_filter_cref = r_filter; - - // Now we can call conv_3d directly. - functor::CuboidConvolution<Device, T>()( - context->eigen_device<Device>(), in_backprop->tensor<T, 5>(), - padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1, - 1, BrainPadding2EigenPadding(VALID)); + context->allocate_temp(DataTypeToEnum<T>::value, + col_buffer_shape, &col_buffer)); + + // The input offset corresponding to a single input image. + const int64 input_offset = dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size * + dims.spatial_dims[2].input_size * dims.in_depth; + + // The output offset corresponding to a single output image. + const int64 output_offset = + dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size * + dims.spatial_dims[2].output_size * dims.out_depth; + + const T* filter_data = filter.template flat<T>().data(); + T* col_buffer_data = col_buffer.template flat<T>().data(); + const T* out_backprop_data = out_backprop.template flat<T>().data(); + + auto in_backprop_flat = in_backprop->template flat<T>(); + T* input_backprop_data = in_backprop_flat.data(); + in_backprop_flat.device(context->eigen_device<Device>()) = + in_backprop_flat.constant(T(0)); + + if (use_parallel_contraction) { + typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, + Eigen::Unaligned> + TensorMap; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, + Eigen::Unaligned> + ConstTensorMap; + + // Initialize contraction dims (we need to transpose 'B' below). 
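+ // The per-image contraction below computes C = A * B^T: A is the incoming
+ // gradient [output_image_size, out_depth], B is the filter
+ // [filter_total_size, out_depth], and pairing dimension 1 of A with
+ // dimension 1 of B yields C = [output_image_size, filter_total_size], the
+ // col_buffer that Col2im then scatters back into the input gradient.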
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims; + contract_dims[0].first = 1; + contract_dims[0].second = 1; + + for (int image_id = 0; image_id < dims.batch_size; ++image_id) { + // Compute gradient into col_buffer. + TensorMap C(col_buffer_data, output_image_size, filter_total_size); + + ConstTensorMap A(out_backprop_data + output_offset * image_id, + output_image_size, dims.out_depth); + ConstTensorMap B(filter_data, filter_total_size, dims.out_depth); + + C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims); + + Col2im<T>(col_buffer_data, dims.in_depth, + // Input spatial dimensions. + dims.spatial_dims[0].input_size, // input planes + dims.spatial_dims[1].input_size, // input rows + dims.spatial_dims[2].input_size, // input cols + // Filter spatial dimensions. + dims.spatial_dims[0].filter_size, // filter planes + dims.spatial_dims[1].filter_size, // filter rows + dims.spatial_dims[2].filter_size, // filter cols + // Spatial padding. + top_pad_planes, top_pad_rows, left_pad_cols, + bottom_pad_planes, bottom_pad_rows, right_pad_cols, + // Spatial striding. + dims.spatial_dims[0].stride, // stride planes + dims.spatial_dims[1].stride, // stride rows + dims.spatial_dims[2].stride, // stride cols + input_backprop_data); + + input_backprop_data += input_offset; + } + } else { + typedef Eigen::Map< + Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>> + MatrixMap; + typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, + Eigen::RowMajor>> + ConstMatrixMap; + + for (int image_id = 0; image_id < dims.batch_size; + image_id += shard_size) { + const int shard_limit = + std::min(static_cast<int>(shard_size), + static_cast<int>(dims.batch_size) - image_id); + + auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols, + &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols, + &output_image_size, &filter_total_size, + &input_backprop_data, &col_buffer_data, + &out_backprop_data, &filter_data, &input_offset, + &output_offset, &size_C](int64 start, int64 limit) { + for (int shard_id = start; shard_id < limit; ++shard_id) { + T* im2col_buf = col_buffer_data + shard_id * size_C; + T* input_data = input_backprop_data + shard_id * input_offset; + const T* out_data = out_backprop_data + shard_id * output_offset; + + // Compute gradient into 'im2col_buf'. + MatrixMap C(im2col_buf, output_image_size, filter_total_size); + + ConstMatrixMap A(out_data, output_image_size, dims.out_depth); + ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth); + + C.noalias() = A * B.transpose(); + + Col2im<T>(im2col_buf, dims.in_depth, + // Input spatial dimensions. + dims.spatial_dims[0].input_size, // input planes + dims.spatial_dims[1].input_size, // input rows + dims.spatial_dims[2].input_size, // input cols + // Filter spatial dimensions. + dims.spatial_dims[0].filter_size, // filter planes + dims.spatial_dims[1].filter_size, // filter rows + dims.spatial_dims[2].filter_size, // filter cols + // Spatial padding. + top_pad_planes, top_pad_rows, left_pad_cols, + bottom_pad_planes, bottom_pad_rows, right_pad_cols, + // Spatial striding. 
+ dims.spatial_dims[0].stride, // stride planes + dims.spatial_dims[1].stride, // stride rows + dims.spatial_dims[2].stride, // stride cols + input_data); + } + }; + Shard(worker_threads.num_threads, worker_threads.workers, shard_limit, + work_unit_size, shard); + + input_backprop_data += input_offset * shard_limit; + out_backprop_data += output_offset * shard_limit; + } + } } private: @@ -253,21 +571,48 @@ class Conv3DBackpropInputOp : public OpKernel { Padding padding_; TensorFormat data_format_; bool takes_shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp); }; +// Custom backprop input kernel is 30% - 4x faster when compiled with AVX2 +// than the default Eigen implementation (at the cost of ~2x-8x peak memory +// usage). + #define REGISTER_CPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - Conv3DBackpropInputOp<CPUDevice, T>); \ + Conv3DCustomBackpropInputOp<CPUDevice, T>); \ REGISTER_KERNEL_BUILDER( \ Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - Conv3DBackpropInputOp<CPUDevice, T>); + Conv3DCustomBackpropInputOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \ + .Device(DEVICE_CPU) \ + .Label("custom") \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropInputOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ + .Device(DEVICE_CPU) \ + .Label("custom") \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropInputOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \ + .Device(DEVICE_CPU) \ + .Label("eigen_tensor") \ + .TypeConstraint<T>("T"), \ + Conv3DBackpropInputOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \ + .Device(DEVICE_CPU) \ + .Label("eigen_tensor") \ + .TypeConstraint<T>("T"), \ + Conv3DBackpropInputOp<CPUDevice, T>); + TF_CALL_half(REGISTER_CPU_KERNEL); TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL -// Backprop for filter. +// Backprop for filter that offloads computation to +// Eigen::CuboidConvolutionBackwardFilter.
template <typename Device, class T> class Conv3DBackpropFilterOp : public OpKernel { public: @@ -323,8 +668,11 @@ class Conv3DBackpropFilterOp : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); const TensorShape& input_shape = input.shape(); - TensorShape filter_shape; + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + + TensorShape filter_shape; if (takes_shape_) { const Tensor& filter_sizes = context->input(1); OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( @@ -333,13 +681,13 @@ class Conv3DBackpropFilterOp : public OpKernel { filter_shape = context->input(1).shape(); } - EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter"); - Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{ - {0, 0}, - {top_pad_planes, bottom_pad_planes}, - {top_pad_rows, bottom_pad_rows}, - {left_pad_cols, right_pad_cols}, - {0, 0}}; + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, + ConvBackpropComputeDimensions( + "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, stride_, + padding_, data_format_, &dims)); + Tensor* filter_backprop; OP_REQUIRES_OK(context, context->allocate_output(0, filter_shape, &filter_backprop)); @@ -349,70 +697,292 @@ class Conv3DBackpropFilterOp : public OpKernel { return; } - // For the backprop of the filter, we need to also transpose the - // out_backprop. - // The shape of backprop is - // [batch, out_z, out_y, out_x, out_depth] - // And we need to change it to - // [out_depth, out_x, out_y, out_z, batch] - Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0}; - TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows, - padded_out_cols, batch}); - Tensor padded_output; + functor::CuboidConvolutionBackwardFilter<Device, T>()( + context->eigen_device<Device>(), + filter_backprop->tensor<T, 5>(), // filter_backward + input.tensor<T, 5>(), // input + out_backprop.tensor<T, 5>(), // output_backward + static_cast<int>(dims.spatial_dims[0].stride), // stride_planes + static_cast<int>(dims.spatial_dims[1].stride), // stride_rows + static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols + } + + private: + std::vector<int32> dilation_; + std::vector<int32> stride_; + Padding padding_; + TensorFormat data_format_; + bool takes_shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp); +}; + +// Custom backprop for filter that explicitly does the work sharding and calls +// Eigen only to multiply matrices. +template <typename Device, class T> +class Conv3DCustomBackpropFilterOp : public OpKernel { + // Limit the maximum size of allocated temporary buffer to + // kMaxTempAllocationOverhead times the size of the input tensors (input, + // filter, out_backprop). If the size of the temporary buffer exceeds this + // limit, fallback on Eigen implementation. + static constexpr int kMaxTempAllocationOverhead = 25; + + public: + explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context) + : OpKernel(context), + data_format_(FORMAT_NHWC), + takes_shape_(type_string().find("V2") != std::string::npos) { + // data_format is only available in V2. 
+ if (takes_shape_) { + string data_format; + OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); + OP_REQUIRES(context, FormatFromString(data_format, &data_format_), + errors::InvalidArgument("Invalid data format")); + OP_REQUIRES( + context, data_format_ == FORMAT_NHWC, + errors::InvalidArgument( + "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU.")); + } + + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_)); + OP_REQUIRES(context, dilation_.size() == 5, + errors::InvalidArgument("Dilation rates field must " + "specify 5 dimensions")); + OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, 'C') == 1 && + GetTensorDim(dilation_, data_format_, 'N') == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilation rates in the batch and depth dimensions.")); + + // TODO(yangzihao): Add CPU version of dilated conv 3D. + OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, '0') == 1 && + GetTensorDim(dilation_, data_format_, '1') == 1 && + GetTensorDim(dilation_, data_format_, '2') == 1), + errors::InvalidArgument( + "Current CPU implementation does not yet support " + "dilation rates larger than 1.")); + + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES( + context, + (GetTensorDim(stride_, data_format_, 'C') == 1 && + GetTensorDim(stride_, data_format_, 'N') == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const TensorShape& input_shape = input.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + + TensorShape filter_shape; + if (takes_shape_) { + const Tensor& filter_sizes = context->input(1); + OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape( + filter_sizes.vec<int32>(), &filter_shape)); + } else { + filter_shape = context->input(1).shape(); + } + + ConvBackpropDimensions dims; OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::v(), - padded_out_shape, &padded_output)); - Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1], - strides[2], 1}; - functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), out_backprop.tensor<T, 5>(), - eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>()); - const Tensor& padded_output_cref = padded_output; - - // For the backprop of the filter, we need to transpose the input. - // The shape of input is - // [batch, in_z, in_y, in_x, in_depth] - // And we need to change it to - // [in_z, in_y, in_x, batch, in_depth] - Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4}; - TensorShape in_shuffle_shape( - {input_size[0], input_size[1], input_size[2], batch, in_depth}); - Tensor in_shuffle; + ConvBackpropComputeDimensions( + "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, stride_, + padding_, data_format_, &dims)); + + Tensor* filter_backprop; OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::v(), - in_shuffle_shape, &in_shuffle)); - // No need for reversing this time. 
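+ //
+ // For example (illustrative numbers only): with an 8 MiB L3 and T = float,
+ // target_working_set_size = 8 MiB / 4 B = 2M elements, so a work_unit_size
+ // of 512K elements gives shard_size = 2M / 512K = 4 images per shard.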
- Eigen::array<bool, 5> no_reverse{false, false, false, false, false}; - functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), input.tensor<T, 5>(), in_order, - no_reverse, in_shuffle.tensor<T, 5>()); - const Tensor& in_shuffle_cref = in_shuffle; - - // The output of the conv_3d would be - // [out_depth, filter_size[2], filter_size[1], filter_size[0], in_depth] - // and we need to shuffle it back to - // [filter_size[2], filter_size[1], filter_size[0], in_depth, out_depth]; - // And we need to reverse the filter backprops. - // So we need to allocate (sigh) yet another piece of memory to hold the - // output. - TensorShape filter_shuffle_shape( - {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth}); - Tensor filter_shuffle; - OP_REQUIRES_OK( - context, context->allocate_temp(DataTypeToEnum<T>::v(), - filter_shuffle_shape, &filter_shuffle)); - functor::CuboidConvolution<Device, T>()( - context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(), - padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1, - 1, BrainPadding2EigenPadding(VALID)); - - // Now copy the filter_backprop back to the destination. - Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0}; - Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false}; - const Tensor& filter_shuffle_cref = filter_shuffle; - functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()( - context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(), - filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>()); + context->allocate_output(0, filter_shape, &filter_backprop)); + + if (input_shape.num_elements() == 0) { + filter_backprop->template flat<T>().setZero(); + return; + } + + int64 top_pad_planes, bottom_pad_planes; + int64 top_pad_rows, bottom_pad_rows; + int64 left_pad_cols, right_pad_cols; + + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[0].input_size, + dims.spatial_dims[0].filter_size, + dims.spatial_dims[0].stride, padding_, + &dims.spatial_dims[0].output_size, + &top_pad_planes, &bottom_pad_planes)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[1].input_size, + dims.spatial_dims[1].filter_size, + dims.spatial_dims[1].stride, padding_, + &dims.spatial_dims[1].output_size, + &top_pad_rows, &bottom_pad_rows)); + OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose( + dims.spatial_dims[2].input_size, + dims.spatial_dims[2].filter_size, + dims.spatial_dims[2].stride, padding_, + &dims.spatial_dims[2].output_size, + &left_pad_cols, &right_pad_cols)); + + // TODO(ezhulenev): Extract work size and shard estimation to shared + // functions in conv_grad_ops, and update 2d convolution backprop. + + // The total dimension size of each kernel. + const int64 filter_total_size = + dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size * + dims.spatial_dims[2].filter_size * dims.in_depth; + // The output image size is the spatial size of the output. + const int64 output_image_size = dims.spatial_dims[0].output_size * + dims.spatial_dims[1].output_size * + dims.spatial_dims[2].output_size; + + // Shard 'batch' images (volumes) into 'shard_size' groups of images + // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by + // dividing the L3 cache size ('target_working_set_size') by the matmul size + // of an individual image ('work_unit_size'). 
+ + const auto cache_sizes = Eigen::internal::CacheSizes(); + const ptrdiff_t l3_cache_size = cache_sizes.m_l3; + + // TODO(andydavis) + // *) Consider reducing 'target_working_set_size' if L3 is shared by + // other concurrently running tensorflow ops. + const size_t target_working_set_size = l3_cache_size / sizeof(T); + + const int64 size_A = output_image_size * filter_total_size; + + const int64 size_B = output_image_size * dims.out_depth; + + const int64 size_C = filter_total_size * dims.out_depth; + + const int64 work_unit_size = size_A + size_B + size_C; + + const size_t shard_size = + (target_working_set_size + work_unit_size - 1) / work_unit_size; + + // Total number of elements in all the tensors used by this kernel. + int64 total_tensor_elements = input_shape.num_elements() + + filter_shape.num_elements() + + out_backprop_shape.num_elements(); + + // Shape of the temporary workspace buffer. + TensorShape col_buffer_shape = {static_cast<int64>(shard_size), + static_cast<int64>(output_image_size), + static_cast<int64>(filter_total_size)}; + int64 col_buffer_elements = col_buffer_shape.num_elements(); + + // If the temporary allocation overhead is too large, fall back on the + // Eigen implementation, which requires much less memory. + int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements; + if (col_buffer_overhead > kMaxTempAllocationOverhead) { + VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: " + "col_buffer_overhead=" + << col_buffer_overhead; + + functor::CuboidConvolutionBackwardFilter<Device, T>()( + context->eigen_device<Device>(), + filter_backprop->tensor<T, 5>(), // filter_backward + input.tensor<T, 5>(), // input + out_backprop.tensor<T, 5>(), // output_backward + static_cast<int>(dims.spatial_dims[0].stride), // stride_planes + static_cast<int>(dims.spatial_dims[1].stride), // stride_rows + static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols + + return; + } + + Tensor col_buffer; + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum<T>::value, + col_buffer_shape, &col_buffer)); + + // The input offset corresponding to a single input image. + const int64 input_offset = dims.spatial_dims[0].input_size * + dims.spatial_dims[1].input_size * + dims.spatial_dims[2].input_size * dims.in_depth; + // The output offset corresponding to a single output image. + const int64 output_offset = + dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size * + dims.spatial_dims[2].output_size * dims.out_depth; + + const T* input_data = input.template flat<T>().data(); + T* col_buffer_data = col_buffer.template flat<T>().data(); + const T* out_backprop_data = out_backprop.template flat<T>().data(); + T* filter_backprop_data = filter_backprop->template flat<T>().data(); + + typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>, + Eigen::Unaligned> + TensorMap; + typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>, + Eigen::Unaligned> + ConstTensorMap; + + TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth); + C.setZero(); + + // Initialize contraction dims (we need to transpose 'A' below).
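+ // The contraction computes C += A^T * B: A is the im2col buffer
+ // [output_image_size * shard_limit, filter_total_size], B is the matching
+ // slice of out_backprop [output_image_size * shard_limit, out_depth], and
+ // pairing dimension 0 of A with dimension 0 of B accumulates the
+ // [filter_total_size, out_depth] filter gradient across shards.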
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims; + contract_dims[0].first = 0; + contract_dims[0].second = 0; + + auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); + + for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) { + const int shard_limit = + std::min(static_cast<int>(shard_size), + static_cast<int>(dims.batch_size) - image_id); + + auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes, + &top_pad_rows, &left_pad_cols, &bottom_pad_planes, + &bottom_pad_rows, &right_pad_cols, &input_offset, + &size_A](int64 start, int64 limit) { + for (int shard_id = start; shard_id < limit; ++shard_id) { + const T* input_data_shard = input_data + shard_id * input_offset; + T* col_data_shard = col_buffer_data + shard_id * size_A; + + // When we compute the gradient with respect to the filters, we need + // to do im2col to allow gemm-type computation. + Im2col<T>(input_data_shard, dims.in_depth, + // Input spatial dimensions. + dims.spatial_dims[0].input_size, // input planes + dims.spatial_dims[1].input_size, // input rows + dims.spatial_dims[2].input_size, // input cols + // Filter spatial dimensions. + dims.spatial_dims[0].filter_size, // filter planes + dims.spatial_dims[1].filter_size, // filter rows + dims.spatial_dims[2].filter_size, // filter cols + // Spatial padding. + top_pad_planes, top_pad_rows, left_pad_cols, + bottom_pad_planes, bottom_pad_rows, right_pad_cols, + // Spatial striding. + dims.spatial_dims[0].stride, // stride planes + dims.spatial_dims[1].stride, // stride rows + dims.spatial_dims[2].stride, // stride cols + col_data_shard); + } + }; + Shard(worker_threads.num_threads, worker_threads.workers, shard_limit, + size_A, shard); + + ConstTensorMap A(col_buffer_data, output_image_size * shard_limit, + filter_total_size); + ConstTensorMap B(out_backprop_data, output_image_size * shard_limit, + dims.out_depth); + + // Gradient with respect to filter. + C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims); + + input_data += input_offset * shard_limit; + out_backprop_data += output_offset * shard_limit; + } } private: @@ -421,21 +991,60 @@ class Conv3DBackpropFilterOp : public OpKernel { Padding padding_; TensorFormat data_format_; bool takes_shape_; + + TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp); }; +// Custom backprop filter kernel is 30% - 4x faster when compiled with AVX2 +// than the default Eigen implementation (at the cost of ~2x-8x peak memory +// usage).
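+// Both variants are registered below: the unlabeled kernels resolve to the
+// custom implementation by default, and callers can pin either variant by
+// setting a node's "_kernel" attribute to "custom" or "eigen_tensor".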
+ #define REGISTER_CPU_KERNEL(T) \ REGISTER_KERNEL_BUILDER( \ Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ - Conv3DBackpropFilterOp<CPUDevice, T>); \ + Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \ + .Device(DEVICE_CPU) \ + .Label("custom") \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ + .Device(DEVICE_CPU) \ + .Label("custom") \ + .TypeConstraint<T>("T"), \ + Conv3DCustomBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \ + .Device(DEVICE_CPU) \ + .Label("eigen_tensor") \ + .TypeConstraint<T>("T"), \ + Conv3DBackpropFilterOp<CPUDevice, T>); \ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ .Device(DEVICE_CPU) \ + .Label("eigen_tensor") \ .TypeConstraint<T>("T"), \ Conv3DBackpropFilterOp<CPUDevice, T>); -TF_CALL_half(REGISTER_CPU_KERNEL); + TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); #undef REGISTER_CPU_KERNEL +// WARNING: Eigen::half is not trivially copyable and can't be used in +// custom backprop filter kernel because of memcpy and memset in Im2col. +#define REGISTER_CPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ + Conv3DBackpropFilterOp<CPUDevice, T>); \ + REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<T>("T"), \ + Conv3DBackpropFilterOp<CPUDevice, T>); + +TF_CALL_half(REGISTER_CPU_KERNEL); +#undef REGISTER_CPU_KERNEL + // GPU definitions of both ops. #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. 
@@ -445,7 +1054,8 @@ namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ void TransformFilter<GPUDevice, T, int, 5>::operator()( \ - const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \ + const GPUDevice& d, FilterTensorFormat dst_filter_format, \ + typename TTypes<T, 5, int>::ConstTensor in, \ typename TTypes<T, 5, int>::Tensor out); \ template <> \ void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \ @@ -523,6 +1133,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& filter = context->input(1); const TensorShape& filter_shape = filter.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + TensorShape input_shape; if (takes_shape_) { const Tensor& input_sizes = context->input(0); @@ -531,7 +1145,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { } else { input_shape = context->input(0).shape(); } - EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput"); + + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, + ConvBackpropComputeDimensionsV2( + "Conv3DBackpropInputOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, dilation_, + stride_, padding_, data_format_, &dims)); + Tensor* in_backprop; OP_REQUIRES_OK(context, context->allocate_output(0, input_shape, &in_backprop)); @@ -539,13 +1160,15 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 && - dilation_[0] == 1 && dilation_[1] == 1 && dilation_[2] == 1 && - stride_[0] == 1 && stride_[1] == 1 && stride_[2] == 1 && + if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 && + dims.filter_size(2) == 1 && dims.dilation(0) == 1 && + dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 && + dims.stride(1) == 1 && dims.stride(2) == 1 && data_format_ == FORMAT_NHWC) { - const uint64 m = batch * input_size[0] * input_size[1] * input_size[2]; - const uint64 k = out_depth; - const uint64 n = in_depth; + const uint64 m = dims.batch_size * dims.input_size(0) * + dims.input_size(1) * dims.input_size(2); + const uint64 k = dims.out_depth; + const uint64 n = dims.in_depth; auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), out_backprop.template flat<T>().size()); @@ -567,13 +1190,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { ", n=", n, ", k=", k)); } return; - } else if (filter_size[0] == input_size[0] && - filter_size[1] == input_size[1] && - filter_size[2] == input_size[2] && padding_ == Padding::VALID && - data_format_ == FORMAT_NHWC) { - const uint64 m = batch; - const uint64 k = out_depth; - const uint64 n = input_size[0] * input_size[1] * input_size[2] * in_depth; + } else if (dims.filter_size(0) == dims.input_size(0) && + dims.filter_size(1) == dims.input_size(1) && + dims.filter_size(2) == dims.input_size(2) && + padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) { + const uint64 m = dims.batch_size; + const uint64 k = dims.out_depth; + const uint64 n = dims.input_size(0) * dims.input_size(1) * + dims.input_size(2) * dims.in_depth; auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(), out_backprop.template flat<T>().size()); @@ -597,65 +1221,59 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { return; } - int 
padding_rows = 0, padding_cols = 0, padding_planes = 0; - - if (padding_ == Padding::SAME) { - padding_planes = std::max<int>( - 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]); - padding_cols = std::max<int>( - 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]); - padding_rows = std::max<int>( - 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]); - } + int padding_planes = dims.SpatialPadding(padding_, 0); + int padding_rows = dims.SpatialPadding(padding_, 1); + int padding_cols = dims.SpatialPadding(padding_, 2); + const bool planes_odd = (padding_planes % 2 != 0); const bool rows_odd = (padding_rows % 2 != 0); const bool cols_odd = (padding_cols % 2 != 0); - const bool planes_odd = (padding_planes % 2 != 0); TensorShape compatible_input_shape; if (rows_odd || cols_odd || planes_odd) { // cuDNN only supports the same amount of padding on both sides. compatible_input_shape = { - batch, - in_depth, - input_size[0] + planes_odd, - input_size[1] + rows_odd, - input_size[2] + cols_odd, + dims.batch_size, + dims.in_depth, + dims.input_size(0) + planes_odd, + dims.input_size(1) + rows_odd, + dims.input_size(2) + cols_odd, }; } else { - compatible_input_shape = {batch, in_depth, input_size[0], input_size[1], - input_size[2]}; + compatible_input_shape = {dims.batch_size, dims.in_depth, + dims.input_size(0), dims.input_size(1), + dims.input_size(2)}; } CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0) << "Negative paddings: (" << padding_rows << ", " << padding_cols << ", " << padding_planes << ")"; se::dnn::BatchDescriptor input_desc(3); - input_desc.set_count(batch) + input_desc.set_count(dims.batch_size) .set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4)) .set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3)) .set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2)) - .set_feature_map_count(in_depth) + .set_feature_map_count(dims.in_depth) .set_layout(se::dnn::DataLayout::kBatchDepthYX); se::dnn::BatchDescriptor output_desc(3); - output_desc.set_count(batch) - .set_spatial_dim(DimIndex::X, output_cols) - .set_spatial_dim(DimIndex::Y, output_rows) - .set_spatial_dim(DimIndex::Z, output_planes) - .set_feature_map_count(out_depth) + output_desc.set_count(dims.batch_size) + .set_spatial_dim(DimIndex::X, dims.output_size(2)) + .set_spatial_dim(DimIndex::Y, dims.output_size(1)) + .set_spatial_dim(DimIndex::Z, dims.output_size(0)) + .set_feature_map_count(dims.out_depth) .set_layout(se::dnn::DataLayout::kBatchDepthYX); se::dnn::FilterDescriptor filter_desc(3); - filter_desc.set_spatial_dim(DimIndex::X, filter_size[2]) - .set_spatial_dim(DimIndex::Y, filter_size[1]) - .set_spatial_dim(DimIndex::Z, filter_size[0]) - .set_input_feature_map_count(in_depth) - .set_output_feature_map_count(out_depth); + filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2)) + .set_spatial_dim(DimIndex::Y, dims.filter_size(1)) + .set_spatial_dim(DimIndex::Z, dims.filter_size(0)) + .set_input_feature_map_count(dims.in_depth) + .set_output_feature_map_count(dims.out_depth); se::dnn::ConvolutionDescriptor conv_desc(3); - conv_desc.set_dilation_rate(DimIndex::X, dilations[2]) - .set_dilation_rate(DimIndex::Y, dilations[1]) - .set_dilation_rate(DimIndex::Z, dilations[0]) - .set_filter_stride(DimIndex::X, strides[2]) - .set_filter_stride(DimIndex::Y, strides[1]) - .set_filter_stride(DimIndex::Z, strides[0]) + conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2)) + .set_dilation_rate(DimIndex::Y, dims.dilation(1)) 
+ .set_dilation_rate(DimIndex::Z, dims.dilation(0)) + .set_filter_stride(DimIndex::X, dims.stride(2)) + .set_filter_stride(DimIndex::Y, dims.stride(1)) + .set_filter_stride(DimIndex::Z, dims.stride(0)) .set_zero_padding(DimIndex::X, padding_cols / 2) .set_zero_padding(DimIndex::Y, padding_rows / 2) .set_zero_padding(DimIndex::Z, padding_planes / 2); @@ -664,20 +1282,23 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { Tensor transformed_filter; OP_REQUIRES_OK( context, - context->allocate_temp(DataTypeToEnum<T>::value, - TensorShape({out_depth, in_depth, filter_size[0], - filter_size[1], filter_size[2]}), - &transformed_filter)); + context->allocate_temp( + DataTypeToEnum<T>::value, + TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0), + dims.filter_size(1), dims.filter_size(2)}), + &transformed_filter)); functor::TransformFilter<GPUDevice, T, int, 5>()( - context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()), + context->eigen_device<GPUDevice>(), FORMAT_OIHW, + To32Bit(filter.tensor<T, 5>()), To32Bit(transformed_filter.tensor<T, 5>())); // Shape: batch, filters, z, y, x. Tensor transformed_out_backprop; if (data_format_ == FORMAT_NHWC) { - TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows, - output_cols}; - if (out_depth > 1) { + TensorShape nchw_shape = {dims.batch_size, dims.out_depth, + dims.output_size(0), dims.output_size(1), + dims.output_size(2)}; + if (dims.out_depth > 1) { OP_REQUIRES_OK(context, context->allocate_temp( DataTypeToEnum<T>::value, nchw_shape, &transformed_out_backprop)); @@ -713,14 +1334,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { const int device_id = stream->parent()->device_ordinal(); DataType dtype = context->input(0).dtype(); const ConvParameters conv_parameters = { - batch, - in_depth, - {{input_size[0], input_size[1], input_size[2]}}, + dims.batch_size, + dims.in_depth, + {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}}, FORMAT_NCHW, - out_depth, - {{filter_size[0], filter_size[1], filter_size[2]}}, - {{dilations[0], dilations[1], dilations[2]}}, - {{strides[0], strides[1], strides[2]}}, + dims.out_depth, + {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}}, + {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}}, + {{dims.stride(0), dims.stride(1), dims.stride(2)}}, {{padding_planes, padding_rows, padding_cols}}, dtype, device_id, @@ -799,10 +1420,11 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel { if (rows_odd || cols_odd || planes_odd) { Tensor in_backprop_remove_padding; OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum<T>::value, - {batch, in_depth, input_size[0], - input_size[1], input_size[2]}, - &in_backprop_remove_padding)); + context->allocate_temp( + DataTypeToEnum<T>::value, + {dims.batch_size, dims.in_depth, dims.input_size(0), + dims.input_size(1), dims.input_size(2)}, + &in_backprop_remove_padding)); // Remove the padding for odd spatial dimensions. 
functor::PadInput<GPUDevice, T, int, 5>()( @@ -896,6 +1518,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { void Compute(OpKernelContext* context) override { const Tensor& input = context->input(0); const TensorShape& input_shape = input.shape(); + + const Tensor& out_backprop = context->input(2); + const TensorShape& out_backprop_shape = out_backprop.shape(); + TensorShape filter_shape; if (takes_shape_) { const Tensor& filter_sizes = context->input(1); @@ -905,7 +1531,12 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { filter_shape = context->input(1).shape(); } - EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter"); + ConvBackpropDimensions dims; + OP_REQUIRES_OK(context, + ConvBackpropComputeDimensionsV2( + "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3, + input_shape, filter_shape, out_backprop_shape, dilation_, + stride_, padding_, data_format_, &dims)); Tensor* filter_backprop; OP_REQUIRES_OK(context, @@ -914,13 +1545,15 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { auto* stream = context->op_device_context()->stream(); OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); - if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 && - dilations[2] == 1 && dilations[1] == 1 && dilations[0] == 1 && - strides[2] == 1 && strides[1] == 1 && strides[0] == 1 && + if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 && + dims.filter_size(0) == 1 && dims.dilation(2) == 1 && + dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 && + dims.stride(1) == 1 && dims.stride(0) == 1 && data_format_ == FORMAT_NHWC) { - const uint64 m = in_depth; - const uint64 k = batch * input_size[1] * input_size[2] * input_size[0]; - const uint64 n = out_depth; + const uint64 m = dims.in_depth; + const uint64 k = dims.batch_size * dims.input_size(1) * + dims.input_size(2) * dims.input_size(0); + const uint64 n = dims.out_depth; // The shape of output backprop is // [batch, out_z, out_y, out_x, out_depth] @@ -951,13 +1584,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { ", n=", n, ", k=", k)); } return; - } else if (filter_size[0] == input_size[0] && - filter_size[1] == input_size[1] && - filter_size[2] == input_size[2] && padding_ == Padding::VALID && - data_format_ == FORMAT_NHWC) { - const uint64 m = input_size[0] * input_size[1] * input_size[2] * in_depth; - const uint64 k = batch; - const uint64 n = out_depth; + } else if (dims.filter_size(0) == dims.input_size(0) && + dims.filter_size(1) == dims.input_size(1) && + dims.filter_size(2) == dims.input_size(2) && + padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) { + const uint64 m = dims.input_size(0) * dims.input_size(1) * + dims.input_size(2) * dims.in_depth; + const uint64 k = dims.batch_size; + const uint64 n = dims.out_depth; auto a_ptr = AsDeviceMemory(input.template flat<T>().data(), input.template flat<T>().size()); @@ -979,30 +1613,24 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { return; } - int padding_rows = 0, padding_cols = 0, padding_planes = 0; - - if (padding_ == Padding::SAME) { - padding_planes = std::max<int>( - 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]); - padding_cols = std::max<int>( - 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]); - padding_rows = std::max<int>( - 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]); - } - bool rows_odd = (padding_rows % 2 != 0); - bool cols_odd = (padding_cols % 2 != 0); - bool 
planes_odd = (padding_planes % 2 != 0); + int padding_planes = dims.SpatialPadding(padding_, 0); + int padding_rows = dims.SpatialPadding(padding_, 1); + int padding_cols = dims.SpatialPadding(padding_, 2); + const bool planes_odd = (padding_planes % 2 != 0); + const bool rows_odd = (padding_rows % 2 != 0); + const bool cols_odd = (padding_cols % 2 != 0); Tensor compatible_input; if (rows_odd || cols_odd || planes_odd) { - OP_REQUIRES_OK(context, context->allocate_temp( - DataTypeToEnum<T>::value, - ShapeFromFormat(data_format_, batch, - {{input_size[0] + planes_odd, - input_size[1] + rows_odd, - input_size[2] + cols_odd}}, - in_depth), - &compatible_input)); + OP_REQUIRES_OK(context, + context->allocate_temp( + DataTypeToEnum<T>::value, + ShapeFromFormat(data_format_, dims.batch_size, + {{dims.input_size(0) + planes_odd, + dims.input_size(1) + rows_odd, + dims.input_size(2) + cols_odd}}, + dims.in_depth), + &compatible_input)); functor::PadInput<GPUDevice, T, int, 5>()( context->template eigen_device<GPUDevice>(), To32Bit(input.tensor<T, 5>()), {{0, 0, 0}}, @@ -1016,35 +1644,35 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { << "Negative paddings: (" << padding_rows << ", " << padding_cols << ", " << padding_planes << ")"; se::dnn::BatchDescriptor input_desc(3); - input_desc.set_count(batch) + input_desc.set_count(dims.batch_size) .set_spatial_dim(DimIndex::X, GetTensorDim(compatible_input, data_format_, '2')) .set_spatial_dim(DimIndex::Y, GetTensorDim(compatible_input, data_format_, '1')) .set_spatial_dim(DimIndex::Z, GetTensorDim(compatible_input, data_format_, '0')) - .set_feature_map_count(in_depth) + .set_feature_map_count(dims.in_depth) .set_layout(se::dnn::DataLayout::kBatchDepthYX); se::dnn::BatchDescriptor output_desc(3); - output_desc.set_count(batch) - .set_spatial_dim(DimIndex::X, output_cols) - .set_spatial_dim(DimIndex::Y, output_rows) - .set_spatial_dim(DimIndex::Z, output_planes) - .set_feature_map_count(out_depth) + output_desc.set_count(dims.batch_size) + .set_spatial_dim(DimIndex::X, dims.output_size(2)) + .set_spatial_dim(DimIndex::Y, dims.output_size(1)) + .set_spatial_dim(DimIndex::Z, dims.output_size(0)) + .set_feature_map_count(dims.out_depth) .set_layout(se::dnn::DataLayout::kBatchDepthYX); se::dnn::FilterDescriptor filter_desc(3); - filter_desc.set_spatial_dim(DimIndex::X, filter_size[2]) - .set_spatial_dim(DimIndex::Y, filter_size[1]) - .set_spatial_dim(DimIndex::Z, filter_size[0]) - .set_input_feature_map_count(in_depth) - .set_output_feature_map_count(out_depth); + filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2)) + .set_spatial_dim(DimIndex::Y, dims.filter_size(1)) + .set_spatial_dim(DimIndex::Z, dims.filter_size(0)) + .set_input_feature_map_count(dims.in_depth) + .set_output_feature_map_count(dims.out_depth); se::dnn::ConvolutionDescriptor conv_desc(3); - conv_desc.set_dilation_rate(DimIndex::X, dilations[2]) - .set_dilation_rate(DimIndex::Y, dilations[1]) - .set_dilation_rate(DimIndex::Z, dilations[0]) - .set_filter_stride(DimIndex::X, strides[2]) - .set_filter_stride(DimIndex::Y, strides[1]) - .set_filter_stride(DimIndex::Z, strides[0]) + conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2)) + .set_dilation_rate(DimIndex::Y, dims.dilation(1)) + .set_dilation_rate(DimIndex::Z, dims.dilation(0)) + .set_filter_stride(DimIndex::X, dims.stride(2)) + .set_filter_stride(DimIndex::Y, dims.stride(1)) + .set_filter_stride(DimIndex::Z, dims.stride(0)) .set_zero_padding(DimIndex::X, padding_cols / 2) .set_zero_padding(DimIndex::Y, 
padding_rows / 2) .set_zero_padding(DimIndex::Z, padding_planes / 2); @@ -1052,19 +1680,21 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { Tensor pre_transformed_filter_backprop; OP_REQUIRES_OK( context, - context->allocate_temp(DataTypeToEnum<T>::value, - TensorShape({out_depth, in_depth, filter_size[0], - filter_size[1], filter_size[2]}), - &pre_transformed_filter_backprop)); + context->allocate_temp( + DataTypeToEnum<T>::value, + TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0), + dims.filter_size(1), dims.filter_size(2)}), + &pre_transformed_filter_backprop)); Tensor transformed_out_backprop; if (data_format_ == FORMAT_NHWC) { - TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows, - output_cols}; + TensorShape nchw_shape = {dims.batch_size, dims.out_depth, + dims.output_size(0), dims.output_size(1), + dims.output_size(2)}; OP_REQUIRES_OK( context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape, &transformed_out_backprop)); - if (out_depth > 1) { + if (dims.out_depth > 1) { functor::NHWCToNCHW<GPUDevice, T, 5>()( context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(), transformed_out_backprop.tensor<T, 5>()); @@ -1076,10 +1706,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { } Tensor transformed_input; if (data_format_ == FORMAT_NHWC) { - TensorShape nchw_shape = {batch, in_depth, compatible_input.dim_size(1), - compatible_input.dim_size(2), - compatible_input.dim_size(3)}; - if (in_depth > 1) { + TensorShape nchw_shape = { + dims.batch_size, dims.in_depth, compatible_input.dim_size(1), + compatible_input.dim_size(2), compatible_input.dim_size(3)}; + if (dims.in_depth > 1) { OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape, &transformed_input)); @@ -1110,14 +1740,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel { const int device_id = stream->parent()->device_ordinal(); DataType dtype = input.dtype(); const ConvParameters conv_parameters = { - batch, - in_depth, - {{input_size[0], input_size[1], input_size[2]}}, + dims.batch_size, + dims.in_depth, + {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}}, FORMAT_NCHW, - out_depth, - {{filter_size[0], filter_size[1], filter_size[2]}}, - {{dilations[0], dilations[1], dilations[2]}}, - {{strides[0], strides[1], strides[2]}}, + dims.out_depth, + {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}}, + {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}}, + {{dims.stride(0), dims.stride(1), dims.stride(2)}}, {{padding_planes, padding_rows, padding_cols}}, dtype, device_id, diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index ef692418d6..717a9f40a9 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -680,9 +680,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()( TensorShape({filter.dim_size(3), filter.dim_size(2), filter.dim_size(0), filter.dim_size(1)}), &transformed_filter)); - functor::TransformFilter<GPUDevice, T, int, 4>()( - ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()), + ctx->eigen_device<GPUDevice>(), FORMAT_OIHW, + To32Bit(filter.tensor<T, 4>()), To32Bit(transformed_filter.tensor<T, 4>())); Tensor transformed_output; @@ -731,9 +731,15 @@ void LaunchConv2DOp<GPUDevice, T>::operator()( if (cudnn_use_autotune && !AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) { std::vector<AlgorithmDesc> algorithms; - CHECK(stream->parent()->GetConvolveAlgorithms( 
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()), - &algorithms)); + OP_REQUIRES( + ctx, + stream->parent()->GetConvolveAlgorithms( + conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>( + stream->parent()), + &algorithms), + errors::Unknown("Failed to get convolution algorithm. This is probably " + "because cuDNN failed to initialize, so try looking to " + "see if a warning log message was printed above.")); ProfileResult best_result; ProfileResult best_result_no_scratch; for (auto profile_algorithm : algorithms) { @@ -823,7 +829,8 @@ namespace functor { extern template struct MatMulConvFunctor<GPUDevice, T>; \ template <> \ void TransformFilter<GPUDevice, T, int, 4>::operator()( \ - const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \ + const GPUDevice& d, FilterTensorFormat dst_filter_format, \ + typename TTypes<T, 4, int>::ConstTensor in, \ typename TTypes<T, 4, int>::Tensor out); \ extern template struct TransformFilter<GPUDevice, T, int, 4>; \ template <> \ diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc index a1eed4e68c..83df4dce38 100644 --- a/tensorflow/core/kernels/conv_ops_3d.cc +++ b/tensorflow/core/kernels/conv_ops_3d.cc @@ -386,7 +386,8 @@ struct LaunchConvOp<GPUDevice, T> { // filter: [x, y, z, in, out] // t_filter: [out, in, x, y, z] functor::TransformFilter<GPUDevice, T, int, 5>()( - ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()), + ctx->eigen_device<GPUDevice>(), FORMAT_OIHW, + To32Bit(filter.tensor<T, 5>()), To32Bit(transformed_filter.tensor<T, 5>())); Tensor transformed_output; @@ -434,10 +435,16 @@ struct LaunchConvOp<GPUDevice, T> { if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find( conv_parameters, &algorithm_config)) { std::vector<AlgorithmDesc> algorithms; - CHECK(stream->parent()->GetConvolveAlgorithms( - conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>( - stream->parent()), - &algorithms)); + OP_REQUIRES(ctx, + stream->parent()->GetConvolveAlgorithms( + conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>( + stream->parent()), + &algorithms), + errors::Unknown( + "Failed to get convolution algorithm. This is probably " + "because cuDNN failed to initialize, so try looking to " + "see if a warning log message was printed above.")); + ProfileResult best_result; ProfileResult best_result_no_scratch; for (auto profile_algorithm : algorithms) { @@ -514,7 +521,8 @@ namespace functor { #define DECLARE_GPU_SPEC(T) \ template <> \ void TransformFilter<GPUDevice, T, int, 5>::operator()( \ - const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \ + const GPUDevice& d, FilterTensorFormat dst_filter_format, \ + typename TTypes<T, 5, int>::ConstTensor in, \ typename TTypes<T, 5, int>::Tensor out); \ template <> \ void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \ diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h index afc611f277..21d135decd 100644 --- a/tensorflow/core/kernels/conv_ops_gpu.h +++ b/tensorflow/core/kernels/conv_ops_gpu.h @@ -142,8 +142,12 @@ class ConvParameters { template <typename T> bool ShouldIncludeWinogradNonfusedAlgo( se::StreamExecutor* stream_exec) const { + auto* dnn_support = stream_exec->AsDnn(); + if (!dnn_support) { + return false; + } // Skip this check for cuDNN 7 and newer. 
- auto version = stream_exec->AsDnn()->GetVersion(); + auto version = dnn_support->GetVersion(); if (version.ok() && version.ValueOrDie().major_version() >= 7) { return true; } diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc index a5fa48f85e..46167db3a2 100644 --- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc +++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc @@ -170,51 +170,33 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex( return tensor_index; } -// A Cuda custom kernel that swaps dimension-0 and dimension-2 of a 3D tensor. -template <typename T, bool conjugate = false> -__global__ void SwapDimension0And2InTensor3Simple(int nthreads, const T* input, - Dimension<3> input_dims, - T* output) { - Dimension<3> output_dims; - output_dims[0] = input_dims[2]; - output_dims[1] = input_dims[1]; - output_dims[2] = input_dims[0]; - - CUDA_1D_KERNEL_LOOP(index, nthreads) { - int output_index = index; - - Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims); - - Index<3> input_tensor_index; - input_tensor_index[0] = output_tensor_index[2]; - input_tensor_index[1] = output_tensor_index[1]; - input_tensor_index[2] = output_tensor_index[0]; - - int input_index = TensorIndexToFlat(input_tensor_index, input_dims); - - output[output_index] = - maybe_conj<T, conjugate>::run(ldg(input + input_index)); - } -} - -// A Cuda custom kernel that swaps dimension-1 and dimension-2 of a 3D tensor. -template <typename T, bool conjugate = false> -__global__ void SwapDimension1And2InTensor3Simple(int nthreads, const T* input, - Dimension<3> input_dims, - T* output) { +// A simple CUDA custom kernel to shuffle dimensions of a 3D tensor according to +// the shuffle permutation given in the template parameters. Shuffle permutation +// <sp0, sp1, sp2> shuffles dimensions such that input dimension 0 goes to sp0, +// 1 goes to sp1 and 2 goes to sp2. For example, shuffle permutation <2, 0, 1> +// will populate output so that input[x][y][z] is equal to (*output)[y][z][x]. +// +// Requires that nthreads is equal to the total number of elements in the input +// tensor. +template <typename T, int sp0, int sp1, int sp2, bool conjugate = false> +__global__ void ShuffleInTensor3Simple(int nthreads, const T* input, Dimension<3> input_dims, T* output) { Dimension<3> output_dims; - output_dims[0] = input_dims[0]; - output_dims[1] = input_dims[2]; - output_dims[2] = input_dims[1]; - - CUDA_1D_KERNEL_LOOP(index, nthreads) { - int output_index = index; + output_dims[sp0] = input_dims[0]; + output_dims[sp1] = input_dims[1]; + output_dims[sp2] = input_dims[2]; + + // Iterate over output as opposed to iterating over input for better + // performance. Iterating over output will generate sequential writes and + // random reads, which perform better than sequential reads and random + // writes.
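+ // For example, ShuffleInTensor3Simple<T, 2, 1, 0> (used below in place of + // the old swap kernels) writes output[z][y][x] = input[x][y][z]: one + // gathered read per sequentially written output element.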
+ CUDA_1D_KERNEL_LOOP(output_index, nthreads) { Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims); Index<3> input_tensor_index; - input_tensor_index[0] = output_tensor_index[0]; - input_tensor_index[1] = output_tensor_index[2]; - input_tensor_index[2] = output_tensor_index[1]; + input_tensor_index[0] = output_tensor_index[sp0]; + input_tensor_index[1] = output_tensor_index[sp1]; + input_tensor_index[2] = output_tensor_index[sp2]; int input_index = TensorIndexToFlat(input_tensor_index, input_dims); @@ -439,7 +421,7 @@ __global__ void PadInputCustomKernelNCHW(int nthreads, const T* input, template <typename T, int NDIMS> struct TransformFilter<GPUDevice, T, int, NDIMS> { typedef GPUDevice Device; - void operator()(const Device& d, + void operator()(const Device& d, FilterTensorFormat dst_filter_format, typename TTypes<T, NDIMS, int>::ConstTensor in, typename TTypes<T, NDIMS, int>::Tensor out) { Dimension<3> combined_dims; @@ -450,13 +432,18 @@ struct TransformFilter<GPUDevice, T, int, NDIMS> { combined_dims[1] = in.dimension(NDIMS - 2); // input filters combined_dims[2] = in.dimension(NDIMS - 1); // output filters CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d); - SwapDimension0And2InTensor3Simple<T> + + CHECK(dst_filter_format == FORMAT_OIHW) + << "Unsupported output layout: " << ToString(dst_filter_format); + + ShuffleInTensor3Simple<T, 2, 1, 0> <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( config.virtual_thread_count, in.data(), combined_dims, out.data()); } }; -// Converts Cudnn filter format back to TensorFlow filter format. +// Converts Cudnn filter format OIHW back to TensorFlow filter format HWIO. +// TODO(hinsu): Support reverse transformation from filter format OHWI as well. template <typename T, int NDIMS> struct ReverseTransformFilter<GPUDevice, T, NDIMS> { typedef GPUDevice Device; @@ -470,7 +457,7 @@ struct ReverseTransformFilter<GPUDevice, T, NDIMS> { combined_dims[2] *= in.dimension(i); } CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d); - SwapDimension0And2InTensor3Simple<T> + ShuffleInTensor3Simple<T, 2, 1, 0> <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( config.virtual_thread_count, in.data(), combined_dims, out.data()); } @@ -937,7 +924,7 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input, } else { int total_element_count = input_dims[0] * input_dims[1] * input_dims[2]; CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d); - SwapDimension1And2InTensor3Simple<T, conjugate> + ShuffleInTensor3Simple<T, 0, 2, 1, conjugate> <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( config.virtual_thread_count, input, input_dims, output); } @@ -969,7 +956,7 @@ struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> { static_cast<int>(combined_dims[2])}; size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2]; CudaLaunchConfig config = GetCudaLaunchConfig(total_size, d); - SwapDimension0And2InTensor3Simple<T, conjugate> + ShuffleInTensor3Simple<T, 2, 1, 0, conjugate> <<<config.block_count, config.thread_per_block, 0, d.stream()>>>( config.virtual_thread_count, in, input_dims, out); } diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 8d867455e7..b3c359010d 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -51,6 +51,7 @@ cc_library( hdrs = ["captured_function.h"], deps = [ ":dataset", + ":single_threaded_executor", 
"//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -61,6 +62,42 @@ cc_library( ) cc_library( + name = "single_threaded_executor", + srcs = ["single_threaded_executor.cc"], + hdrs = ["single_threaded_executor.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + ], + alwayslink = 1, +) + +tf_cc_test( + name = "single_threaded_executor_test", + srcs = ["single_threaded_executor_test.cc"], + deps = [ + ":single_threaded_executor", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + "//tensorflow/core/kernels:array", + "//tensorflow/core/kernels:control_flow_ops", + "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:math", + "//tensorflow/core/kernels:random_ops", + "//tensorflow/core/kernels:state", + ], +) + +cc_library( name = "window_dataset", srcs = ["window_dataset.cc"], hdrs = ["window_dataset.h"], @@ -481,8 +518,7 @@ tf_kernel_library( ":dataset", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", + "//tensorflow/core:graph", ], ) @@ -505,8 +541,7 @@ tf_kernel_library( ":dataset", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:lib_internal", + "//tensorflow/core:graph", ], ) @@ -640,6 +675,19 @@ tf_kernel_library( ) tf_kernel_library( + name = "model_dataset_op", + srcs = ["model_dataset_op.cc"], + deps = [ + ":dataset", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_kernel_library( name = "dataset_ops", srcs = ["dataset_ops.cc"], deps = [ @@ -673,6 +721,7 @@ tf_kernel_library( ":map_and_batch_dataset_op", ":map_dataset_op", ":map_defun_op", + ":model_dataset_op", ":optimize_dataset_op", ":optional_ops", ":padded_batch_dataset_op", diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc index f9b5353724..d1db1d7bec 100644 --- a/tensorflow/core/kernels/data/batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/batch_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -117,6 +117,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator<Dataset>(params) {} Status Initialize(IteratorContext* ctx) override { + AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); } @@ -241,5 +242,5 @@ REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU), BatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc index 6ca0bcd37d..34c6c86538 100644 --- a/tensorflow/core/kernels/data/cache_dataset_ops.cc +++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level description of @@ -69,7 +69,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr<IteratorBase>( - new FileIterator({this, strings::StrCat(prefix, "::FileIterator")})); + new FileIterator({this, strings::StrCat(prefix, "::FileCache")})); } const DataTypeVector& output_dtypes() const override { @@ -553,7 +553,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { return std::unique_ptr<IteratorBase>(new MemoryIterator( - {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_)); + {this, strings::StrCat(prefix, "::MemoryCache")}, cache_)); } const DataTypeVector& output_dtypes() const override { @@ -891,5 +891,5 @@ REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU), CacheDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index abdf6ee4e8..8a5d30a27c 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -17,33 +17,101 @@ limitations under the License. #include <utility> #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/step_stats_collector.h" #include "tensorflow/core/framework/cancellation.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { +namespace data { + +namespace { + +// Simplistic implementation of the `StepStatsCollectorInterface` that only +// cares about collecting the CPU time needed to execute a captured function. 
+class SimpleStepStatsCollector : public StepStatsCollectorInterface { + public: + void IncrementProcessingTime(int64 delta) { + mutex_lock l(mu_); + processing_time_ += delta; + } + + NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override { + return new SimpleNodeExecStats(this); + } + + string ReportAllocsOnResourceExhausted(const string& err) override { + return ""; + } + + int64 processing_time() { + tf_shared_lock l(mu_); + return processing_time_; + } + + private: + class SimpleNodeExecStats : public NodeExecStatsInterface { + public: + explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector) + : step_stats_collector_(step_stats_collector) {} + + void Done(const string& device) override { + step_stats_collector_->IncrementProcessingTime(end_time_ns_ - + start_time_ns_); + delete this; + } + + void RecordExecutorStarted() override { + start_time_ns_ = Env::Default()->NowNanos(); + } + + void RecordComputeStarted() override {} + + void RecordComputeEnded() override {} + + void RecordExecutorEnded() override { + end_time_ns_ = Env::Default()->NowNanos(); + } + + void SetMemory(OpKernelContext* ctx) override {} + + void SetOutput(int slot, const Tensor* tensor) override {} + + void SetReferencedTensors(const TensorReferenceVector& tensors) override {} + + void SetScheduled(int64 nanos) override {} + + private: + int64 start_time_ns_ = 0; + int64 end_time_ns_ = 0; + SimpleStepStatsCollector* step_stats_collector_; // Not owned. + }; + + mutex mu_; + int64 processing_time_ GUARDED_BY(mu_) = 0; +}; + +} // namespace /* static */ Status CapturedFunction::Create( - const NameAttrList& func, std::vector<Tensor> captured_inputs, + const NameAttrList& func, OpKernelContext* ctx, const string& argument, std::unique_ptr<CapturedFunction>* out_function) { - out_function->reset(new CapturedFunction(func, std::move(captured_inputs))); - return Status::OK(); + return CapturedFunction::Create(func, ctx, argument, true, out_function); } -/* static */ Status CapturedFunction::Create( const NameAttrList& func, OpKernelContext* ctx, const string& argument, + bool use_inter_op_parallelism, std::unique_ptr<CapturedFunction>* out_function) { - OpInputList argument_inputs; - TF_RETURN_IF_ERROR(ctx->input_list(argument, &argument_inputs)); - std::vector<Tensor> arguments_t; - arguments_t.reserve(argument_inputs.size()); - for (const Tensor& t : argument_inputs) { - arguments_t.push_back(t); - } - return CapturedFunction::Create(func, std::move(arguments_t), out_function); + OpInputList inputs; + TF_RETURN_IF_ERROR(ctx->input_list(argument, &inputs)); + std::vector<Tensor> arguments(inputs.begin(), inputs.end()); + *out_function = WrapUnique(new CapturedFunction(func, std::move(arguments), + use_inter_op_parallelism)); + return Status::OK(); } CapturedFunction::~CapturedFunction() { @@ -272,6 +340,9 @@ Status CapturedFunction::Instantiate(IteratorContext* ctx) { inst_opts.overlay_lib = ctx->function_library().get(); inst_opts.state_handle = std::to_string(random::New64()); inst_opts.create_kernels_eagerly = true; + if (!use_inter_op_parallelism_) { + inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR"; + } Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()), inst_opts, &f_handle_)); TF_RETURN_IF_ERROR(s); @@ -345,7 +416,8 @@ Status CapturedFunction::RunInstantiated(const std::vector<Tensor>& args, void CapturedFunction::RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets, - FunctionLibraryRuntime::DoneCallback done) { + 
FunctionLibraryRuntime::DoneCallback done, + const string& prefix) { // NOTE(mrry): This method does not transfer ownership of `ctx`, and it may // be deleted before `done` is called. Take care not to capture `ctx` in any // code that may execute asynchronously in this function. @@ -355,17 +427,17 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, done(s); return; } - auto frame = - new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_); + std::shared_ptr<OwnedArgsCallFrame> frame( + new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_)); FunctionLibraryRuntime::Options f_opts; f_opts.step_id = CapturedFunction::generate_step_id(); ResourceMgr* resource_mgr = ctx->lib()->device()->resource_manager(); - auto step_container = new ScopedStepContainer( + std::shared_ptr<ScopedStepContainer> step_container(new ScopedStepContainer( f_opts.step_id, [resource_mgr](const string& name) { resource_mgr->Cleanup(name).IgnoreError(); - }); - f_opts.step_container = step_container; + })); + f_opts.step_container = step_container.get(); f_opts.runner = ctx->runner(); if (ctx->lib()->device()->device_type() != DEVICE_CPU) { f_opts.create_rendezvous = true; @@ -376,32 +448,55 @@ void CapturedFunction::RunAsync(IteratorContext* ctx, // (such as queue kernels) that depend on the non-nullness of // `OpKernelContext::cancellation_manager()`, but additional effort // will be required to plumb it through the `IteratorContext`. - auto c_mgr = new CancellationManager; - f_opts.cancellation_manager = c_mgr; - - tf_shared_lock l(mu_); - ctx->lib()->Run(f_opts, handle, frame, - std::bind( - [rets, step_container, c_mgr, frame]( - FunctionLibraryRuntime::DoneCallback done, - // Begin unbound arguments. - Status s) { - delete step_container; - delete c_mgr; - if (s.ok()) { - s = frame->ConsumeRetvals(rets); - } - delete frame; - done(s); - }, - std::move(done), std::placeholders::_1)); + std::shared_ptr<CancellationManager> c_mgr(new CancellationManager); + f_opts.cancellation_manager = c_mgr.get(); + std::shared_ptr<SimpleStepStatsCollector> stats_collector; + std::shared_ptr<model::Node> node; + if (ctx->model()) { + node = ctx->model()->LookupNode(prefix); + if (node) { + stats_collector = MakeUnique<SimpleStepStatsCollector>(); + } + } + f_opts.stats_collector = stats_collector.get(); + + OwnedArgsCallFrame* raw_frame = frame.get(); + auto callback = std::bind( + [rets](const std::shared_ptr<CancellationManager>& c_mgr, + const FunctionLibraryRuntime::DoneCallback& done, + const std::shared_ptr<OwnedArgsCallFrame>& frame, + const std::shared_ptr<model::Node>& node, + const std::shared_ptr<SimpleStepStatsCollector>& stats_collector, + const std::shared_ptr<ScopedStepContainer>& step_container, + // Begin unbound arguments. 
+ Status s) { + if (s.ok()) { + s = frame->ConsumeRetvals(rets); + } + if (node) { + node->add_processing_time(stats_collector->processing_time()); + node->start_work(); + } + done(s); + if (node) { + node->stop_work(); + } + }, + std::move(c_mgr), std::move(done), std::move(frame), std::move(node), + std::move(stats_collector), std::move(step_container), + std::placeholders::_1); + + ctx->lib()->Run(f_opts, handle, raw_frame, std::move(callback)); } CapturedFunction::CapturedFunction(const NameAttrList& func, - std::vector<Tensor> captured_inputs) + std::vector<Tensor> captured_inputs, + bool use_inter_op_parallelism) : func_(func), lib_(nullptr), f_handle_(kInvalidHandle), - captured_inputs_(std::move(captured_inputs)) {} + captured_inputs_(std::move(captured_inputs)), + use_inter_op_parallelism_(use_inter_op_parallelism) {} +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h index c95f2b1c01..a10376bf97 100644 --- a/tensorflow/core/kernels/data/captured_function.h +++ b/tensorflow/core/kernels/data/captured_function.h @@ -32,6 +32,8 @@ class Device; class OpKernelContext; class ResourceMgr; +namespace data { + // A `CapturedFunction` encapsulates a TensorFlow function and all of // the runtime support required to execute it. // @@ -40,18 +42,19 @@ class ResourceMgr; // context. class CapturedFunction { public: - // Creates a new instance from a list of named attributes and captured inputs. - // - // NOTE(mrry): The `captured_inputs` are passed by value. For - // efficiency, you are recommended to move this argument into the call. - static Status Create(const NameAttrList& func, - std::vector<Tensor> captured_inputs, + // Creates a new instance using a list of named attributes, fetching captured + // inputs from a context argument. + static Status Create(const NameAttrList& func, OpKernelContext* ctx, + const string& argument, std::unique_ptr<CapturedFunction>* out_function); // Creates a new instance using a list of named attributes, fetching captured // inputs from a context argument. + // + // If `use_inter_op_parallelism` is false, the runtime may use an executor + // that is optimized for small functions. static Status Create(const NameAttrList& func, OpKernelContext* ctx, - const string& argument, + const string& argument, bool use_inter_op_parallelism, std::unique_ptr<CapturedFunction>* out_function); ~CapturedFunction(); @@ -93,7 +96,8 @@ class CapturedFunction { // in order to be able to deallocate them as early as possible. void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets, - FunctionLibraryRuntime::DoneCallback done); + FunctionLibraryRuntime::DoneCallback done, + const string& prefix); // Returns the named list of function arguments. 
const NameAttrList& func() { return func_; } @@ -114,7 +118,8 @@ class CapturedFunction { private: CapturedFunction(const NameAttrList& func, - std::vector<Tensor> captured_inputs); + std::vector<Tensor> captured_inputs, + bool use_inter_op_parallelism); Status GetHandle(IteratorContext* ctx, FunctionLibraryRuntime::Handle* out_handle); @@ -126,10 +131,17 @@ class CapturedFunction { const std::vector<Tensor> captured_inputs_; DataTypeSlice ret_types_; std::function<void(std::function<void()>)> captured_runner_ = nullptr; + const bool use_inter_op_parallelism_; TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction); }; +} // namespace data + +// TODO(b/114112161): Remove these aliases when all users have moved over to the +// `tensorflow::data` namespace. +using data::CapturedFunction; + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_ diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index c361a9adcb..a04f150e71 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -195,5 +195,5 @@ REGISTER_KERNEL_BUILDER(Name("ConcatenateDataset").Device(DEVICE_CPU), ConcatenateDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc index c71d027f23..bd1ccd5b5d 100644 --- a/tensorflow/core/kernels/data/dataset_ops.cc +++ b/tensorflow/core/kernels/data/dataset_ops.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { +namespace data { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. @@ -48,4 +49,5 @@ class DatasetToGraphOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU), DatasetToGraphOp); +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index d85ef1cbab..e7ac368ae3 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -17,8 +17,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device.h" namespace tensorflow { - -namespace dataset { +namespace data { Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector<Tensor>& input_element, @@ -45,6 +44,5 @@ Status MakeIteratorFromInputElement( ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator); } -} // namespace dataset - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index 6c4191c2be..234856ea39 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -20,16 +20,14 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - -namespace dataset { +namespace data { Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector<Tensor>& input_element, int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator); -} // namespace dataset - +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_ diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc index 9770bc025d..237511a07d 100644 --- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -301,5 +301,5 @@ REGISTER_KERNEL_BUILDER(Name("DenseToSparseBatchDataset").Device(DEVICE_CPU), DenseToSparseBatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc index ce577397c5..a7e3a56727 100644 --- a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -166,5 +166,5 @@ REGISTER_KERNEL_BUILDER(Name("FilterByLastComponentDataset").Device(DEVICE_CPU), FilterByLastComponentDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index bbce001eaf..19c35f94a6 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -37,14 +37,6 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - FunctionLibraryRuntime::Handle pred_handle; OP_REQUIRES_OK(ctx, ctx->function_library()->Instantiate( @@ -61,9 +53,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { Node* ret_node = pred_body->ret_nodes[0]; Node* ret_input_node; OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node)); + std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); if (ret_input_node->def().op() == "_Arg") { int32 index = -1; @@ -280,5 +273,5 @@ REGISTER_KERNEL_BUILDER(Name("FilterDataset").Device(DEVICE_CPU), FilterDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc index b1eb2fd849..2fada22a21 100644 --- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -39,18 +39,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); - + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, func_, std::move(captured_func), output_types_, output_shapes_); } @@ -245,7 +236,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel { private: Status BuildCurrentElementIteratorLocked(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return dataset::MakeIteratorFromInputElement( + return MakeIteratorFromInputElement( ctx, captured_func_inputs_, element_index_++, dataset()->captured_func_.get(), prefix(), ¤t_element_iterator_); @@ -285,5 +276,5 @@ REGISTER_KERNEL_BUILDER(Name("FlatMapDataset").Device(DEVICE_CPU), FlatMapDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index ccee690d7e..71a36314a0 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. @@ -144,54 +145,31 @@ GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx) void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase** output) { - OpInputList init_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args", - &init_func_other_args_input)); - std::vector<Tensor> init_func_other_args; - init_func_other_args.reserve(init_func_other_args_input.size()); - for (const Tensor& t : init_func_other_args_input) { - init_func_other_args.push_back(t); - } std::unique_ptr<CapturedFunction> init_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args), - &init_func)); - - OpInputList next_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args", - &next_func_other_args_input)); - std::vector<Tensor> next_func_other_args; - next_func_other_args.reserve(next_func_other_args_input.size()); - for (const Tensor& t : next_func_other_args_input) { - next_func_other_args.push_back(t); - } + OP_REQUIRES_OK(ctx, CapturedFunction::Create( + init_func_, ctx, "init_func_other_args", &init_func)); + std::unique_ptr<CapturedFunction> next_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args), - &next_func)); - - OpInputList finalize_func_other_args_input; - OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args", - &finalize_func_other_args_input)); - std::vector<Tensor> finalize_func_other_args; - finalize_func_other_args.reserve(finalize_func_other_args_input.size()); - for (const Tensor& t : finalize_func_other_args_input) { - finalize_func_other_args.push_back(t); - } - std::unique_ptr<CapturedFunction> finalize_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create( - finalize_func_, std::move(finalize_func_other_args), - &finalize_func)); + next_func_, ctx, "next_func_other_args", &next_func)); + + std::unique_ptr<CapturedFunction> finalize_func; + OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx, + "finalize_func_other_args", + &finalize_func)); *output = new Dataset(ctx, std::move(init_func), std::move(next_func), std::move(finalize_func), output_types_, output_shapes_); } +namespace { REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU), GeneratorDatasetOp); REGISTER_KERNEL_BUILDER( Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"), GeneratorDatasetOp); +} // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/generator_dataset_op.h b/tensorflow/core/kernels/data/generator_dataset_op.h index 8407543136..d23ed97ec3 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.h +++ b/tensorflow/core/kernels/data/generator_dataset_op.h @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/dataset.h" namespace tensorflow { +namespace data { class GeneratorDatasetOp : public DatasetOpKernel { public: @@ -36,5 +37,6 @@ class GeneratorDatasetOp : public DatasetOpKernel { NameAttrList finalize_func_; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc index 130f04da3e..d6ee42a7c6 100644 --- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -433,4 +434,5 @@ REGISTER_KERNEL_BUILDER(Name("GroupByReducerDataset").Device(DEVICE_CPU), GroupByReducerDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index 46a3185b49..8b417bb1c2 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -41,50 +42,19 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - // Get captured inputs for the key, reduce, and window_size functions. - OpInputList key_func_other_argument_inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("key_func_other_arguments", - &key_func_other_argument_inputs)); - std::vector<Tensor> key_func_other_arguments; - key_func_other_arguments.reserve(key_func_other_argument_inputs.size()); - for (const Tensor& t : key_func_other_argument_inputs) { - key_func_other_arguments.push_back(t); - } - OpInputList reduce_func_other_argument_inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("reduce_func_other_arguments", - &reduce_func_other_argument_inputs)); - std::vector<Tensor> reduce_func_other_arguments; - reduce_func_other_arguments.reserve( - reduce_func_other_argument_inputs.size()); - for (const Tensor& t : reduce_func_other_argument_inputs) { - reduce_func_other_arguments.push_back(t); - } - OpInputList window_size_func_other_argument_inputs; - OP_REQUIRES_OK(ctx, - ctx->input_list("window_size_func_other_arguments", - &window_size_func_other_argument_inputs)); - std::vector<Tensor> window_size_func_other_arguments; - window_size_func_other_arguments.reserve( - window_size_func_other_argument_inputs.size()); - for (const Tensor& t : window_size_func_other_argument_inputs) { - window_size_func_other_arguments.push_back(t); - } - // TODO(mrry): Refactor CapturedFunction to share the runtime - // state between multiple functions? 
std::unique_ptr<CapturedFunction> captured_key_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - key_func_, std::move(key_func_other_arguments), - &captured_key_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx, + "key_func_other_arguments", + &captured_key_func)); std::unique_ptr<CapturedFunction> captured_reduce_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create(reduce_func_, - std::move(reduce_func_other_arguments), - &captured_reduce_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx, + "reduce_func_other_arguments", + &captured_reduce_func)); std::unique_ptr<CapturedFunction> captured_window_size_func; - OP_REQUIRES_OK( - ctx, CapturedFunction::Create( - window_size_func_, std::move(window_size_func_other_arguments), - &captured_window_size_func)); + OP_REQUIRES_OK(ctx, + CapturedFunction::Create(window_size_func_, ctx, + "window_size_func_other_arguments", + &captured_window_size_func)); *output = new Dataset( ctx, input, key_func_, reduce_func_, window_size_func_, @@ -549,4 +519,5 @@ REGISTER_KERNEL_BUILDER(Name("GroupByWindowDataset").Device(DEVICE_CPU), GroupByWindowDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc index 716e040277..0aa802b874 100644 --- a/tensorflow/core/kernels/data/interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -39,14 +39,6 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - const Tensor* cycle_length_t; OP_REQUIRES_OK(ctx, ctx->input("cycle_length", &cycle_length_t)); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(cycle_length_t->shape()), @@ -66,8 +58,8 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { errors::InvalidArgument("block_length must be greater than zero.")); std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, func_, std::move(captured_func), cycle_length, @@ -201,7 +193,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(input_impl_->GetNext( ctx, &args_list_[cycle_index_], &end_of_input_)); if (!end_of_input_) { - TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( + TF_RETURN_IF_ERROR(MakeIteratorFromInputElement( ctx, args_list_[cycle_index_], cycle_index_, dataset()->captured_func_.get(), prefix(), &current_elements_[cycle_index_])); @@ -288,7 +280,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel { full_name(strings::StrCat("args_list_[", idx, "][", i, "]")), &args_list_[idx][i])); } - TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement( + TF_RETURN_IF_ERROR(MakeIteratorFromInputElement( ctx, args_list_[idx], idx, dataset()->captured_func_.get(), prefix(), &current_elements_[idx]));
TF_RETURN_IF_ERROR( @@ -330,5 +322,5 @@ REGISTER_KERNEL_BUILDER(Name("InterleaveDataset").Device(DEVICE_CPU), InterleaveDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index 4e9b280968..30c6585ba2 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -36,7 +36,7 @@ limitations under the License. #include "tensorflow/core/public/session_options.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -236,6 +236,8 @@ class IteratorResource : public ResourceBase { const std::vector<PartialTensorShape> output_shapes_; }; +namespace { + // Helper class for reading data from a VariantTensorData object. class VariantTensorDataReader : public IteratorStateReader { public: @@ -401,12 +403,12 @@ class IteratorStateVariant { } string TypeName() const { return kIteratorVariantTypeName; } void Encode(VariantTensorData* data) const { *data = *data_; } - bool Decode(const VariantTensorData& data) { + bool Decode(VariantTensorData data) { if (data.type_name() != TypeName()) { return false; } std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData); - *tensor_data = data; + std::swap(*tensor_data, data); std::unique_ptr<VariantTensorDataReader> reader( new VariantTensorDataReader(tensor_data.get())); status_ = reader->status(); @@ -443,6 +445,8 @@ class IteratorStateVariant { REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant, kIteratorVariantTypeName); +} // namespace + // Note that IteratorHandleOp holds a reference to the resource it creates. If // cleaning up resources with DestroyResourceOp is important, consider creating // resource containers with AnonymousIteratorHandleOp instead. 
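The IteratorStateVariant::Decode change above is the pass-by-value-and-swap idiom: taking VariantTensorData by value lets callers move their buffer in, and std::swap then steals the payload instead of deep-copying the serialized tensors. A self-contained sketch of the idiom with a hypothetical Payload type (not TensorFlow code):

#include <string>
#include <utility>
#include <vector>

// Stand-in for VariantTensorData: a type whose copy is expensive.
struct Payload {
  std::string type_name;
  std::vector<std::string> tensors;
};

class Holder {
 public:
  // By value: the caller chooses copy or move at the call site.
  bool Decode(Payload data) {
    if (data.type_name != "expected") return false;
    std::swap(data_, data);  // steal the buffers instead of deep-copying
    return true;
  }

 private:
  Payload data_;
};

int main() {
  Payload p{"expected", {"t0", "t1"}};
  Holder h;
  h.Decode(std::move(p));  // moved into the parameter, then swapped: no copy
}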
@@ -622,6 +626,8 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) { OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator))); } +namespace { + class ToSingleElementOp : public AsyncOpKernel { public: explicit ToSingleElementOp(OpKernelConstruction* ctx) @@ -887,6 +893,8 @@ class OneShotIteratorOp : public AsyncOpKernel { const int graph_def_version_; }; +} // namespace + void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) { IteratorResource* iterator; OP_REQUIRES_OK_ASYNC( @@ -957,6 +965,8 @@ void IteratorGetNextSyncOp::Compute(OpKernelContext* ctx) { } } +namespace { + class IteratorGetNextAsOptionalOp : public AsyncOpKernel { public: explicit IteratorGetNextAsOptionalOp(OpKernelConstruction* ctx) @@ -1037,6 +1047,8 @@ class IteratorGetNextAsOptionalOp : public AsyncOpKernel { std::vector<PartialTensorShape> output_shapes_; }; +} // namespace + void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) { const Tensor& resource_handle_t = ctx->input(0); OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()), @@ -1108,6 +1120,8 @@ void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) { resource_handle_t->scalar<ResourceHandle>()() = resource_handle; } +namespace { + class SerializeIteratorOp : public OpKernel { public: explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} @@ -1202,4 +1216,7 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU), REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU), DeserializeIteratorOp); +} // namespace + +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h index 723564286c..8a2b2639a7 100644 --- a/tensorflow/core/kernels/data/iterator_ops.h +++ b/tensorflow/core/kernels/data/iterator_ops.h @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" namespace tensorflow { +namespace data { class IteratorResource; @@ -142,6 +143,7 @@ class IteratorFromStringHandleOp : public OpKernel { std::vector<PartialTensorShape> output_shapes_; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_ diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 8b0c9ad6b2..83896219a3 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -26,10 +26,11 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/tracing.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -39,7 +40,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { public: explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()), op_version_(ctx->def().op() == "MapAndBatchDataset" ? 
1 : 2) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); @@ -49,14 +49,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - int64 batch_size; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size)); OP_REQUIRES( @@ -77,7 +69,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { case 2: OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", &num_parallel_calls)); - OP_REQUIRES(ctx, num_parallel_calls > 0, + OP_REQUIRES(ctx, + num_parallel_calls > 0 || num_parallel_calls == kAutoTune, errors::InvalidArgument( "num_parallel_calls must be greater than zero.")); break; @@ -92,8 +85,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ParseScalarArgument(ctx, "drop_remainder", &drop_remainder)); std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, batch_size, num_parallel_calls, drop_remainder, output_types_, output_shapes_, func_, @@ -190,7 +183,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { class Iterator : public DatasetIterator<Dataset> { public: explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} + : DatasetIterator<Dataset>(params), + num_parallel_calls_(params.dataset->num_parallel_calls_) {} ~Iterator() override { mutex_lock l(mu_); @@ -204,6 +198,24 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); + if (num_parallel_calls_ == kAutoTune) { + num_parallel_calls_ = 1; + std::function<void(int64)> set_fn = [this](int64 value) { + { + mutex_lock l(mu_); + num_parallel_calls_ = value; + } + VLOG(2) << "setting parallelism knob to " << value; + cond_var_.notify_all(); + }; + AddTunableParameter( + ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */, + port::NumSchedulableCPUs() /* max */, std::move(set_fn)); + } else { + AddConstantParameter(ctx, "parallelism", num_parallel_calls_); + } TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate(ctx); @@ -218,7 +230,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { EnsureRunnerThreadStarted(ctx); while (batch_results_.empty() || batch_results_.front()->num_calls > 0) { + StopWork(ctx); cond_var_.wait(l); + StartWork(ctx); } std::swap(result, batch_results_.front()); batch_results_.pop_front(); @@ -365,7 +379,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { ctx.get(), std::move(input_element), return_values.get(), [this, ctx, result, return_values, offset](Status status) { Callback(ctx, result, return_values, offset, status); - }); + }, + prefix()); }, ctx, std::move(input_element))); } @@ -423,7 +438,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return 
(dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) / + return (num_parallel_calls_ + dataset()->batch_size_ - 1) / dataset()->batch_size_; } @@ -475,23 +490,31 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) LOCKS_EXCLUDED(mu_) { std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls; - new_calls.reserve(dataset()->num_parallel_calls_); + StartWork(ctx.get()); + auto stop_cleanup = + gtl::MakeCleanup([this, &ctx]() { StopWork(ctx.get()); }); + { + tf_shared_lock l(mu_); + new_calls.reserve(num_parallel_calls_); + } while (true) { { mutex_lock l(mu_); while (!cancelled_ && - (num_calls_ >= dataset()->num_parallel_calls_ || + (num_calls_ >= num_parallel_calls_ || batch_results_.size() > MaxBatchResults() || (batch_results_.size() == MaxBatchResults() && call_counter_ % dataset()->batch_size_ == 0))) { + StopWork(ctx.get()); cond_var_.wait(l); + StartWork(ctx.get()); } if (cancelled_) { return; } - while (num_calls_ < dataset()->num_parallel_calls_ && + while (num_calls_ < num_parallel_calls_ && (batch_results_.size() < MaxBatchResults() || (batch_results_.size() == MaxBatchResults() && call_counter_ % dataset()->batch_size_ != 0))) { @@ -638,6 +661,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { // user specified level of parallelism and there are slots available in // the `batch_results_` buffer. condition_variable cond_var_; + // Identifies the maximum number of parallel calls. + int64 num_parallel_calls_ GUARDED_BY(mu_) = 0; // Counts the number of outstanding calls for this batch. int64 num_calls_ GUARDED_BY(mu_) = 0; // Counts the total number of calls. @@ -661,7 +686,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const Eigen::ThreadPoolDevice* device_; // not owned }; - const int graph_def_version_; const int op_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; @@ -675,5 +699,5 @@ REGISTER_KERNEL_BUILDER(Name("MapAndBatchDatasetV2").Device(DEVICE_CPU), MapAndBatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index 7f8182d917..f112e1dc43 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -20,7 +20,7 @@ limitations under the License. 
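The Initialize method above is the autotuning hook-up pattern this change introduces (ParallelInterleaveDatasetV2 below repeats it): the kAutoTune sentinel starts the knob at 1 and registers a callback through which the optimization model can adjust the value at runtime. A reduced sketch of the contract, under the assumption that AddTunableParameter publishes value/min/max and retains set_fn for the optimizer to call from its own thread:

// Reduced sketch of the tunable-knob contract used by Initialize above.
// The optimizer may invoke set_fn at any time, so the callback takes the
// iterator's lock and wakes any threads whose wait condition depends on
// the knob's value.
std::function<void(int64)> set_fn = [this](int64 value) {
  {
    mutex_lock l(mu_);
    num_parallel_calls_ = value;  // GUARDED_BY(mu_)
  }
  cond_var_.notify_all();  // re-evaluate wait conditions with the new value
};
AddTunableParameter(ctx, "parallelism", /*value=*/1, /*min=*/1,
                    /*max=*/port::NumSchedulableCPUs(), std::move(set_fn));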
#include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -28,27 +28,20 @@ namespace { class MapDatasetOp : public UnaryDatasetOpKernel { public: - explicit MapDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", + &use_inter_op_parallelism_)); } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + use_inter_op_parallelism_, + &captured_func)); *output = new Dataset(ctx, input, func_, std::move(captured_func), output_types_, output_shapes_); @@ -183,14 +176,14 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const std::vector<PartialTensorShape> output_shapes_; }; - const int graph_def_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; NameAttrList func_; + bool use_inter_op_parallelism_; }; REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index 607d0ca028..6657f2b2b3 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -18,18 +18,20 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_util.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/util/batch_util.h" #include "tensorflow/core/util/reffed_status_callback.h" namespace tensorflow { +namespace data { namespace { void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts, bool always_collect_stats) { opts->step_id = ctx->step_id(); opts->rendezvous = ctx->rendezvous(); - opts->cancellation_manager = ctx->cancellation_manager(); if (always_collect_stats) { opts->stats_collector = ctx->stats_collector(); } @@ -60,103 +62,186 @@ class MapDefunOp : public AsyncOpKernel { ~MapDefunOp() override {} - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { - int64 batch_size = ctx->input(0).dim_size(0); - // Inputs - auto* args = new std::vector<Tensor>; - auto* arg_shapes = new std::vector<TensorShape>; - arg_shapes->reserve(ctx->num_inputs()); - args->reserve(ctx->num_inputs()); - + Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) { + // Validates inputs and gets the size of their leading dimension. + *batch_size = ctx->input(0).dims() > 0 ? 
ctx->input(0).dim_size(0) : -1; for (size_t i = 0; i < ctx->num_inputs(); ++i) { - args->push_back(ctx->input(i)); - arg_shapes->push_back(ctx->input(i).shape()); - arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension - OP_REQUIRES_ASYNC( - ctx, batch_size == ctx->input(i).dim_size(0), - errors::InvalidArgument( - "All inputs must have the same dimension 0. Input ", i, - " has leading dimension ", ctx->input(i).dim_size(0), - ", while all previous inputs have leading dimension ", batch_size, - "."), - done); + if (ctx->input(i).dims() == 0) { + return errors::InvalidArgument( + "All inputs must have rank at least 1. Input ", i, + " has a rank of 0."); + } else if (ctx->input(i).dim_size(0) != *batch_size) { + return errors::InvalidArgument( + "All inputs must have the same dimension 0. Input ", i, + " has leading dimension ", ctx->input(i).dim_size(0), + ", while all previous inputs have leading dimension ", *batch_size); + } } + return Status::OK(); + } - // Outputs - auto* output = new OpOutputList; - OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done); + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + ComputeOptions* compute_opts = nullptr; - for (size_t i = 0; i < output_types().size(); ++i) { - Tensor* out = nullptr; - TensorShape output_shape = output_shapes_.at(i); - output_shape.InsertDim(0, batch_size); - OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), done); - } + OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done); - SetRunOptions(ctx, &opts_, false); + Status s = SetupOutputs(ctx, compute_opts); + if (!s.ok()) delete compute_opts; + OP_REQUIRES_OK_ASYNC(ctx, s, done); + + FunctionLibraryRuntime::Options opts; + SetRunOptions(ctx, &opts, false); // Run loop StatusCallback callback = std::bind( - [](OpKernelContext* ctx, std::vector<Tensor>* args, - std::vector<TensorShape>* arg_shapes, OpOutputList* output, + [](OpKernelContext* ctx, ComputeOptions* compute_opts, DoneCallback& done, const Status& status) { - delete args; - delete arg_shapes; - delete output; + delete compute_opts; ctx->SetStatus(status); done(); }, - ctx, args, arg_shapes, output, std::move(done), std::placeholders::_1); + ctx, compute_opts, std::move(done), std::placeholders::_1); auto* refcounted = new ReffedStatusCallback(std::move(callback)); - for (size_t i = 1; i < static_cast<size_t>(batch_size); ++i) { - // Start from i = 1 because refcounted is initialized with refcount = 1 + CancellationManager* parent_mgr = ctx->cancellation_manager(); + + for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) { + // We use a different cancellation manager each time the function is run + // to avoid the race condition between a function run error and other + // functions being cancelled as a result.
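The loop that follows gives every function invocation its own CancellationManager chained to the parent's, so a failed invocation cannot cancel its siblings through a shared manager while graph-level cancellation still reaches all of them. A reduced, synchronous sketch of the chaining pattern (the kernel itself deregisters and deletes in the Run completion callback instead):

// Sketch: chain a short-lived child manager to the long-lived parent one.
Status RunOneCancellable(CancellationManager* parent_mgr,
                         const std::function<void(CancellationManager*)>& run) {
  auto* child = new CancellationManager;
  CancellationToken token = parent_mgr->get_cancellation_token();
  if (!parent_mgr->RegisterCallback(token,
                                    [child]() { child->StartCancel(); })) {
    delete child;  // parent already cancelled; nothing was registered
    return errors::Cancelled("parent graph cancelled");
  }
  run(child);  // e.g. flib->Run() with opts.cancellation_manager = child
  // Sever the chain before deleting: DeregisterCallback blocks until any
  // in-flight cancellation callback has finished executing.
  parent_mgr->DeregisterCallback(token);
  delete child;
  return Status::OK();
}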
+ CancellationManager* c_mgr = new CancellationManager; + CancellationToken token = parent_mgr->get_cancellation_token(); + const bool success = parent_mgr->RegisterCallback( + token, [c_mgr]() { c_mgr->StartCancel(); }); + + opts.cancellation_manager = c_mgr; + if (!success) { + delete c_mgr; + refcounted->UpdateStatus(errors::Cancelled( + "MapDefunOp functions cancelled because parent graph cancelled")); + break; + } + + auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i); + refcounted->Ref(); + ctx->function_library()->Run(opts, func_handle_, call_frame, + [call_frame, refcounted, c_mgr, parent_mgr, + token](const Status& func_status) { + parent_mgr->DeregisterCallback(token); + delete c_mgr; + delete call_frame; + refcounted->UpdateStatus(func_status); + refcounted->Unref(); + }); } - for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) { - auto* call_frame = - new MapFunctionCallFrame(*args, *arg_shapes, output, this, i); - ctx->function_library()->Run( - opts_, func_handle_, call_frame, - [call_frame, refcounted](const Status& func_status) { - delete call_frame; - refcounted->UpdateStatus(func_status); - refcounted->Unref(); - }); - } + + // Unref 1 because refcounted is initialized with refcount = 1 + refcounted->Unref(); } private: FunctionLibraryRuntime::Handle func_handle_; - FunctionLibraryRuntime::Options opts_; - std::vector<TensorShape> output_shapes_; + std::vector<PartialTensorShape> output_shapes_; + + struct ComputeOptions { + // These vary per MapDefunOp::ComputeAsync call, but must persist until + // all calls to the function are complete. This struct also encapsulates + // all the components that need to be passed to each MapFunctionCallFrame. + + const std::vector<Tensor> args; + const std::vector<TensorShape> arg_shapes; + const int64 batch_size; + + // Output of a compute call + std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu); + OpOutputList output GUARDED_BY(mu); + mutex mu; + + // Create a copy of output_shapes because every `Compute` may expect a + // different output shape. + ComputeOptions(std::vector<Tensor> args, + std::vector<TensorShape> arg_shapes, int64 batch_size, + const std::vector<PartialTensorShape>& output_shapes_attr) + : args(std::move(args)), + arg_shapes(std::move(arg_shapes)), + batch_size(batch_size), + output_shapes(output_shapes_attr) {} + }; + + // Get inputs to Compute and check that they are valid. + Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) { + int64 batch_size = + ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1; + + for (size_t i = 0; i < ctx->num_inputs(); ++i) { + if (ctx->input(i).dims() == 0) { + return errors::InvalidArgument( + "All inputs must have rank at least 1. Input ", i, + " has a rank of 0."); + } else if (ctx->input(i).dim_size(0) != batch_size) { + return errors::InvalidArgument( + "All inputs must have the same dimension 0. 
Input ", i, + " has leading dimension ", ctx->input(i).dim_size(0), + ", while all previous inputs have leading dimension ", batch_size); + } + } + + std::vector<Tensor> args; + std::vector<TensorShape> arg_shapes; + args.reserve(ctx->num_inputs()); + arg_shapes.reserve(ctx->num_inputs()); + + for (size_t i = 0; i < ctx->num_inputs(); ++i) { + args.push_back(ctx->input(i)); + arg_shapes.push_back(ctx->input(i).shape()); + arg_shapes.at(i).RemoveDim(0); + } + + *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes), + batch_size, output_shapes_); + return Status::OK(); + } + + Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) { + mutex_lock l(opts->mu); + TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output)); + + for (size_t i = 0; i < output_types().size(); ++i) { + if (output_shapes_.at(i).IsFullyDefined()) { + Tensor* out = nullptr; + TensorShape output_shape; + output_shapes_.at(i).AsTensorShape(&output_shape); + output_shape.InsertDim(0, opts->batch_size); + TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out)); + } + } + return Status::OK(); + } class MapFunctionCallFrame : public CallFrameInterface { public: - MapFunctionCallFrame(const std::vector<Tensor>& args, - const std::vector<TensorShape>& arg_shapes, - OpOutputList* output, OpKernel* kernel, size_t iter) - : args_(args), - arg_shapes_(arg_shapes), - output_(output), - kernel_(kernel), - iter_(iter) {} + MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel, + size_t iter) + : compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {} ~MapFunctionCallFrame() override {} - size_t num_args() const override { return args_.size(); } + size_t num_args() const override { return compute_opts_->args.size(); } + size_t num_retvals() const override { return static_cast<size_t>(kernel_->num_outputs()); } Status GetArg(int index, Tensor* val) const override { - if (index < 0 || index >= args_.size()) { + if (index < 0 || index >= compute_opts_->args.size()) { return errors::InvalidArgument( "Mismatch in number of function inputs."); } - bool result = val->CopyFrom(args_.at(index).Slice(iter_, iter_ + 1), - arg_shapes_.at(index)); + bool result = + val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1), + compute_opts_->arg_shapes.at(index)); if (!result) { return errors::Internal("GetArg failed."); } else if (!val->IsAligned()) { @@ -179,18 +264,39 @@ class MapDefunOp : public AsyncOpKernel { "output: ", index); } - return batch_util::CopyElementToSlice(val, (*output_)[index], iter_); + { // Locking scope + mutex_lock l(compute_opts_->mu); + if (!compute_opts_->output_shapes.at(index).IsCompatibleWith( + val.shape())) { + return errors::InvalidArgument( + "Mismatch in function retval shape, ", val.shape(), + ", and expected output shape, ", + compute_opts_->output_shapes.at(index).DebugString(), "."); + } + if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) { + // Given val, we have new information about the output shape at + // this index. Store the shape and allocate the output accordingly. 
+ compute_opts_->output_shapes.at(index) = val.shape(); + + Tensor* out = nullptr; + TensorShape actual_shape = val.shape(); + actual_shape.InsertDim(0, compute_opts_->batch_size); + TF_RETURN_IF_ERROR( + compute_opts_->output.allocate(index, actual_shape, &out)); + } + return batch_util::CopyElementToSlice( + val, (compute_opts_->output)[index], iter_); + } } private: - const std::vector<Tensor>& args_; - const std::vector<TensorShape>& arg_shapes_; - OpOutputList* output_; + ComputeOptions* const compute_opts_; // Not owned const OpKernel* kernel_; const size_t iter_; }; -}; // namespace +}; REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc new file mode 100644 index 0000000000..63025d3371 --- /dev/null +++ b/tensorflow/core/kernels/data/model_dataset_op.cc @@ -0,0 +1,146 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/platform/cpu_info.h" + +namespace tensorflow { +namespace data { +namespace { + +const int kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMicros; + +class ModelDatasetOp : public UnaryDatasetOpKernel { + public: + explicit ModelDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + *output = new Dataset(ctx, input); + } + + private: + class Dataset : public DatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, const DatasetBase* input) + : DatasetBase(DatasetContext(ctx)), input_(input) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Model")})); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + const std::vector<PartialTensorShape>& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { return "ModelDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : 
DatasetIterator<Dataset>(params), + model_(std::make_shared<model::Model>()) {} + + Status Initialize(IteratorContext* ctx) override { + IteratorContext ctx_with_model(CreateParams(ctx)); + return dataset()->input_->MakeIterator(&ctx_with_model, prefix(), + &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + int64 now = ctx->env()->NowMicros() / EnvTime::kMillisToMicros; + if (last_optimization_ms_ + optimization_period_ms_ < now) { + model_->Optimize(port::NumSchedulableCPUs()); + // Exponentially increase the period of running the optimization until + // a threshold is reached. + if (optimization_period_ms_ < kOptimizationPeriodThresholdMs) { + if (optimization_period_ms_ << 1 < kOptimizationPeriodThresholdMs) { + optimization_period_ms_ <<= 1; + } else { + optimization_period_ms_ = kOptimizationPeriodThresholdMs; + } + } + last_optimization_ms_ = + ctx->env()->NowMicros() / EnvTime::kMillisToMicros; + } + IteratorContext ctx_with_model(CreateParams(ctx)); + return input_impl_->GetNext(&ctx_with_model, out_tensors, + end_of_sequence); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + return Status::OK(); + } + + IteratorContext::Params CreateParams(IteratorContext* ctx) { + IteratorContext::Params params = ctx->params(); + params.model = model_; + return params; + } + + private: + mutex mu_; + std::shared_ptr<model::Model> model_; + int64 last_optimization_ms_ GUARDED_BY(mu_) = 0; + int64 optimization_period_ms_ GUARDED_BY(mu_) = 10; + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + }; + + const DatasetBase* input_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU), + ModelDatasetOp); +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index 831e7252da..d5b725eac9 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -33,6 +33,7 @@ limitations under the License. 
#include "tensorflow/core/protobuf/rewriter_config.pb.h" namespace tensorflow { +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -92,8 +93,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { DatasetGraphDefBuilder db(&b); Node* input_node = nullptr; SerializationContext::Params params; + std::vector<std::pair<string, Tensor>> input_list; params.allow_stateful_functions = true; params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition(); + params.input_list = &input_list; SerializationContext serialization_ctx(params); TF_RETURN_IF_ERROR( db.AddInputDataset(&serialization_ctx, input_, &input_node)); @@ -118,7 +121,7 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { GraphRunner graph_runner(ctx->function_library()->device()); TF_RETURN_IF_ERROR( - graph_runner.Run(&graph, lib_, {}, {output_node}, &outputs)); + graph_runner.Run(&graph, lib_, input_list, {output_node}, &outputs)); TF_RETURN_IF_ERROR( GetDatasetFromVariantTensor(outputs[0], &optimized_input_)); optimized_input_->Ref(); @@ -268,4 +271,5 @@ REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU), OptimizeDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc index cfac45dbc7..346e4ceebd 100644 --- a/tensorflow/core/kernels/data/optional_ops.cc +++ b/tensorflow/core/kernels/data/optional_ops.cc @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_op_registry.h" namespace tensorflow { +namespace data { namespace { const char kOptionalVariantTypeName[] = "tensorflow::data::Optional"; @@ -107,11 +108,8 @@ class OptionalFromValueOp : public OpKernel { void Compute(OpKernelContext* ctx) override { OpInputList components_input; OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input)); - std::vector<Tensor> components; - components.reserve(components_input.size()); - for (const Tensor& component_t : components_input) { - components.push_back(component_t); - } + std::vector<Tensor> components(components_input.begin(), + components_input.end()); OP_REQUIRES_OK( ctx, WriteOptionalWithValueToOutput(ctx, 0, std::move(components))); } @@ -230,10 +228,9 @@ static Status OptionalDeviceCopy( return Status::OK(); } -#define REGISTER_OPTIONAL_COPY(DIRECTION) \ - INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ - OptionalVariant, DIRECTION, kOptionalVariantTypeName, \ - OptionalDeviceCopy) +#define REGISTER_OPTIONAL_COPY(DIRECTION) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ + OptionalVariant, DIRECTION, OptionalDeviceCopy) REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); @@ -267,4 +264,5 @@ Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) { return Status::OK(); } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/optional_ops.h b/tensorflow/core/kernels/data/optional_ops.h index 6f25567678..2cbf2933f5 100644 --- a/tensorflow/core/kernels/data/optional_ops.h +++ b/tensorflow/core/kernels/data/optional_ops.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/variant_tensor_data.h" namespace tensorflow { +namespace data { // Stores a DT_VARIANT value representing an Optional with the given value // in the `output_index`^th output of the given kernel execution context. 
@@ -31,6 +32,7 @@ Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index, // in the `output_index`^th output of the given kernel execution context. Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index); +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_ diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc index be45eac46e..7b01c3b4e0 100644 --- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -207,6 +207,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator<Dataset>(params) {} Status Initialize(IteratorContext* ctx) override { + AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); } @@ -382,5 +383,5 @@ REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU), PaddedBatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index f6b3fd97e3..9cd46bf5dd 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include <deque> +#include <utility> #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/partial_tensor_shape.h" @@ -21,11 +22,12 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -34,8 +36,7 @@ namespace { class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { public: explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); @@ -43,14 +44,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - int64 cycle_length = 0; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "cycle_length", &cycle_length)); @@ -82,8 +75,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<CapturedFunction> captured_func; OP_REQUIRES_OK( - ctx, CapturedFunction::Create( - interleave_func_, std::move(other_arguments), &captured_func)); + ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, interleave_func_, std::move(captured_func), @@ -125,6 +118,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { const DataTypeVector& output_dtypes() const override { return output_types_; } + const std::vector<PartialTensorShape>& output_shapes() const override { return output_shapes_; } @@ -250,6 +244,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } Status Initialize(IteratorContext* ctx) override { + AddConstantParameter(ctx, "parallelism", dataset()->cycle_length_); TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); return dataset()->captured_func_->Instantiate(ctx); @@ -349,11 +344,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { if (must_wait_for_input) { // Wait for elements to become available. 
+ StopWork(ctx); if (dataset()->sloppy_) { sloppy_cond_var_.wait(l); } else { workers_[interleave_indices_[next_index_]].cond_var.wait(l); } + StartWork(ctx); } } return errors::Cancelled( @@ -482,10 +479,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { if (reader->Contains(full_name("worker_threads_running"))) { worker_threads_.reserve(dataset()->num_threads()); for (size_t i = 0; i < dataset()->num_threads(); ++i) { + std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); worker_threads_.emplace_back(ctx->env()->StartThread( {}, "worker_thread", - std::bind(&Iterator::WorkerThread, this, - new IteratorContext(*ctx), i))); + [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); } } return Status::OK(); @@ -581,10 +578,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } workers_[i].SetInputs(s, std::move(args)); + std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); worker_threads_.emplace_back(ctx->env()->StartThread( {}, "worker_thread", - std::bind(&Iterator::WorkerThread, this, - new IteratorContext(*ctx), i))); + [this, new_ctx, i]() { WorkerThread(new_ctx, i); })); if (i < dataset()->cycle_length_) { interleave_indices_.push_back(i); } else { @@ -599,7 +596,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { } // Produces elements into the worker's output buffers. - void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) { + void WorkerThread(const std::shared_ptr<IteratorContext>& ctx, + const int64 thread_index) { // Notes on checkpointing thread local state, i.e., `WorkerThreadState`: // // 1. Any local state that may need to be checkpointed should be kept @@ -620,10 +618,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // std::function arguments are copy-constructible, so we pass raw // pointers, and then immediately wrap them to ensure correct ownership. - std::unique_ptr<IteratorContext> ctx(ctx_ptr); - auto cleanup = gtl::MakeCleanup([this, thread_index] { + StartWork(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] { mutex_lock l(mu_); workers_[thread_index].cond_var.notify_all(); + StopWork(ctx.get()); }); bool make_new_iterator; { @@ -649,9 +648,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // 1. Build a new iterator or use the existing one. if (make_new_iterator) { // 1a. Get new input tensors or use the existing ones. - bool read_new_input; - { tf_shared_lock l(ckpt_mu_); // worker_thread_states_[thread_index].input will be non-empty @@ -663,7 +660,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { if (read_new_input) { mutex_lock l(mu_); while (!cancelled_ && !workers_[thread_index].is_producing) { + StopWork(ctx.get()); workers_[thread_index].cond_var.wait(l); + StartWork(ctx.get()); } if (cancelled_) return; // Copy the input tensors so that we do not need to block on `mu_` @@ -684,7 +683,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { { tf_shared_lock l(ckpt_mu_); worker_thread_states_[thread_index].iterator_creation_status = - dataset::MakeIteratorFromInputElement( + MakeIteratorFromInputElement( ctx.get(), worker_thread_states_[thread_index].input, thread_index, dataset()->captured_func_.get(), prefix(), &worker_thread_states_[thread_index].iterator); @@ -713,7 +712,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() == dataset()->buffer_output_elements_) { + StopWork(ctx.get()); workers_[thread_index].cond_var.wait(l); + StartWork(ctx.get()); } if (cancelled_) return; tf_shared_lock ckpt_l(ckpt_mu_); @@ -762,7 +763,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { // Wait for space in the prefetch queue. while (!cancelled_ && workers_[thread_index].outputs.size() == dataset()->buffer_output_elements_) { + StopWork(ctx.get()); workers_[thread_index].cond_var.wait(l); + StartWork(ctx.get()); } if (cancelled_) return; @@ -914,7 +917,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { worker_thread_states_[index].iterator.reset(); } else { std::unique_ptr<IteratorBase> iterator; - Status s = dataset::MakeIteratorFromInputElement( + Status s = MakeIteratorFromInputElement( ctx, worker_thread_states_[index].input, index, dataset()->captured_func_.get(), prefix(), &iterator); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator)); @@ -1058,7 +1061,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { const std::vector<PartialTensorShape> output_shapes_; }; - const int graph_def_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; NameAttrList interleave_func_; @@ -1067,6 +1069,616 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU), ParallelInterleaveDatasetOp); -} // namespace +// The motivation for creating an alternative implementation of parallel +// interleave is to decouple the degree of parallelism from the cycle length. +// This makes it possible to change the degree of parallelism (e.g. through +// auto-tuning) without changing the cycle length (which would change the order +// in which elements are produced). +// +// Furthermore, this class favors modularity over extended functionality. In +// particular, it refrains from implementing configurable buffering of output +// elements and prefetching of input iterators, relying on other parts of +// tf.data to provide this functionality if necessary. +// +// The above design choices were made with automated optimizations in mind, +// isolating the degree of parallelism as the single tunable knob of this +// implementation. 
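The design comment above is the crux of the V2 rewrite: cycle_length fixes which input element each output position is drawn from (and hence the output order), while num_parallel_calls only bounds how many of those elements are fetched from concurrently. A toy standalone illustration of that separation (not TensorFlow code):

#include <cstdio>
#include <vector>

// Toy model of the V2 scheduling rule mirrored by RunnerThread below:
// parallelism caps concurrent fetches, but never changes output order.
struct CycleState {
  std::vector<bool> element_in_use;
  int num_calls = 0;  // fetches currently in flight
};

bool CanSchedule(const CycleState& s, int slot, int parallelism) {
  return !s.element_in_use[slot] && s.num_calls < parallelism;
}

int main() {
  CycleState s;
  s.element_in_use.assign(4, false);  // cycle_length = 4
  // With parallelism = 2, only two cycle elements may fetch at once...
  for (int slot = 0; slot < 4; ++slot) {
    if (CanSchedule(s, slot, /*parallelism=*/2)) {
      s.element_in_use[slot] = true;
      ++s.num_calls;
      std::printf("fetching from cycle element %d\n", slot);
    }
  }
  // ...but the consumer always reads results in cycle order, so tuning
  // parallelism up or down never reorders the produced elements.
  return 0;
}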
+class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { + public: + explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + int64 cycle_length = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "cycle_length", &cycle_length)); + OP_REQUIRES(ctx, cycle_length > 0, + errors::InvalidArgument("`cycle_length` must be > 0")); + + int64 block_length = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument(ctx, "block_length", &block_length)); + OP_REQUIRES(ctx, block_length > 0, + errors::InvalidArgument("`block_length` must be > 0")); + + int64 num_parallel_calls; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", + &num_parallel_calls)); + OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune, + errors::InvalidArgument( + "num_parallel_calls must be greater than zero.")); + OP_REQUIRES( + ctx, num_parallel_calls <= cycle_length, + errors::InvalidArgument( + "num_parallel_calls must be less than or equal to cycle_length.")); + + std::unique_ptr<CapturedFunction> captured_func; + OP_REQUIRES_OK( + ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments", + &captured_func)); + + *output = new Dataset(ctx, input, interleave_func_, + std::move(captured_func), cycle_length, block_length, + num_parallel_calls, output_types_, output_shapes_); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, + std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length, + int64 block_length, int64 num_parallel_calls, + const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes) + : DatasetBase(DatasetContext(ctx)), + input_(input), + interleave_func_(func), + captured_func_(std::move(captured_func)), + cycle_length_(cycle_length), + block_length_(block_length), + num_parallel_calls_(num_parallel_calls), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>(new Iterator( + {this, strings::StrCat(prefix, "::ParallelInterleaveV2")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "ParallelInterleaveDatasetV2Op::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name())); + Node* input_node; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node)); + Node* cycle_length_node; + TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node)); + Node* block_length_node; + TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node)); + Node* num_parallel_calls_node; + TF_RETURN_IF_ERROR( + b->AddScalar(num_parallel_calls_, &num_parallel_calls_node)); + DataTypeVector other_arguments_types; +
other_arguments_types.reserve(captured_func_->captured_inputs().size()); + std::vector<Node*> other_arguments; + other_arguments.reserve(captured_func_->captured_inputs().size()); + for (const Tensor& t : captured_func_->captured_inputs()) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + other_arguments.emplace_back(node); + other_arguments_types.emplace_back(t.dtype()); + } + AttrValue f; + b->BuildAttrValue(interleave_func_, &f); + AttrValue other_arguments_types_attr; + b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, + {{0, input_node}, + {2, cycle_length_node}, + {3, block_length_node}, + {4, num_parallel_calls_node}}, + {{1, other_arguments}}, + {{"f", f}, {"Targuments", other_arguments_types_attr}}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params), + args_list_(params.dataset->cycle_length_), + current_elements_(params.dataset->cycle_length_), + element_in_use_(params.dataset->cycle_length_, false), + num_parallel_calls_(params.dataset->num_parallel_calls_), + thread_pool_(new thread::ThreadPool( + Env::Default(), ThreadOptions(), "parallel_interleave", + dataset()->cycle_length_ /* num_threads */, + false /* low_latency_hint */)) {} + + ~Iterator() override { + mutex_lock l(mu_); + // Cancel the runner thread. + cancelled_ = true; + cond_var_.notify_all(); + // Wait for all in-flight calls to complete. + while (num_calls_ > 0) { + cond_var_.wait(l); + } + } + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + if (num_parallel_calls_ == kAutoTune) { + num_parallel_calls_ = 1; + auto set_fn = [this](int64 value) { + { + mutex_lock l(mu_); + num_parallel_calls_ = value; + } + VLOG(2) << "setting parallelism knob to " << value; + cond_var_.notify_all(); + }; + AddTunableParameter( + ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */, + dataset()->cycle_length_ /* max */, std::move(set_fn)); + } else { + AddConstantParameter(ctx, "parallelism", num_parallel_calls_); + } + AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_); + TF_RETURN_IF_ERROR( + dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); + return dataset()->captured_func_->Instantiate(ctx); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + std::shared_ptr<InvocationResult> result; + do { + { + mutex_lock l(mu_); + EnsureRunnerThreadStarted(ctx); + while (invocation_results_.empty() && + (!end_of_input_ || num_open_ > 0)) { + StopWork(ctx); + cond_var_.wait(l); + StartWork(ctx); + } + if (!invocation_results_.empty()) { + std::swap(result, invocation_results_.front()); + invocation_results_.pop_front(); + } else { + *end_of_sequence = true; + return Status::OK(); + } + } + cond_var_.notify_all(); + StopWork(ctx); + result->notification.WaitForNotification(); + StartWork(ctx); + } while (result->skip); + + if (result->status.ok()) { + *out_tensors = std::move(result->return_values); + } + *end_of_sequence = false; + return result->status; + } + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + // Wait for all in-flight calls to complete. 
+ while (num_calls_ > 0) { + cond_var_.wait(l); + } + CHECK_EQ(num_calls_, 0); + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name("invocation_results.size"), invocation_results_.size())); + for (size_t i = 0; i < invocation_results_.size(); i++) { + std::shared_ptr<InvocationResult> result = invocation_results_[i]; + TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status)); + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("invocation_results[", i, "].size")), + result->return_values.size())); + for (size_t j = 0; j < result->return_values.size(); j++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name( + strings::StrCat("invocation_results[", i, "][", j, "]")), + result->return_values[j])); + } + if (result->skip) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("invocation_results[", i, "].skip")), + "")); + } + } + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("cycle_index"), cycle_index_)); + if (end_of_input_) { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("end_of_input"), "")); + } + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("num_open"), num_open_)); + TF_RETURN_IF_ERROR(WriteCurrentElements(writer)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + int64 invocation_results_size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name("invocation_results.size"), &invocation_results_size)); + for (size_t i = 0; i < invocation_results_size; i++) { + std::shared_ptr<InvocationResult> result(new InvocationResult()); + invocation_results_.push_back(result); + TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status)); + size_t num_return_values; + { + int64 size; + TF_RETURN_IF_ERROR(reader->ReadScalar( + full_name(strings::StrCat("invocation_results[", i, "].size")), + &size)); + num_return_values = static_cast<size_t>(size); + if (num_return_values != size) { + return errors::InvalidArgument(strings::StrCat( + full_name( + strings::StrCat("invocation_results[", i, "].size")), + ": ", size, " is not a valid value of type size_t.")); + } + } + result->return_values.reserve(num_return_values); + for (size_t j = 0; j < num_return_values; j++) { + result->return_values.emplace_back(); + TF_RETURN_IF_ERROR( + reader->ReadTensor(full_name(strings::StrCat( + "invocation_results[", i, "][", j, "]")), + &result->return_values.back())); + } + result->skip = reader->Contains( + full_name(strings::StrCat("invocation_results[", i, "].skip"))); + result->notification.Notify(); + } + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("cycle_index"), &cycle_index_)); + if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true; + TF_RETURN_IF_ERROR( + reader->ReadScalar(full_name("num_open"), &num_open_)); + TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader)); + return Status::OK(); + } + + private: + struct InvocationResult { + Notification notification; // used for coordination with the consumer + Status status; // the invocation status + std::vector<Tensor> return_values; // the invocation result values + bool skip; // if set the result should be skipped + }; + + void EnsureRunnerThreadStarted(IteratorContext* ctx) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!runner_thread_) { + std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); + runner_thread_.reset(ctx->env()->StartThread( + {}, 
"runner_thread", + [this, new_ctx]() { RunnerThread(new_ctx); })); + } + } + + // Fetches up to `results.size()` outputs from the cycle element at + // position `cycle_index`. + // + // If end of input is encountered, the `skip` field of the invocation + // result is used to identify results that should be skipped. + void FetchOutputs( + const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index, + const std::vector<std::shared_ptr<InvocationResult>>& results) + LOCKS_EXCLUDED(mu_) { + StartWork(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); }); + bool end_of_input = false; + for (auto& result : results) { + if (!end_of_input) { + result->status = current_elements_[cycle_index]->GetNext( + ctx.get(), &result->return_values, &end_of_input); + } + if (end_of_input) { + result->skip = true; + } + result->notification.Notify(); + if (!result->status.ok()) { + break; + } + } + + // Release the ownership of the cycle element iterator, closing the + // iterator if end of input was encountered. + { + if (end_of_input) { + current_elements_[cycle_index].reset(); + } + mutex_lock l(mu_); + element_in_use_[cycle_index] = false; + num_calls_--; + if (end_of_input) { + args_list_[cycle_index].clear(); + num_open_--; + } + } + cond_var_.notify_all(); + } + + int64 MaxInvocationResults() { + return dataset()->cycle_length_ * dataset()->block_length_; + } + + // Method responsible for 1) creating iterators out of input elements, 2) + // determining the order in which elements are fetched from the iterators, + // and 3) scheduling the fetching of the elements to a threadpool. + // + // This method runs in the `runner_thread` background thread. + void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { + StartWork(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); }); + while (true) { + { + mutex_lock l(mu_); + // Wait until this thread is cancelled, the end of input has been + // reached, or the cycle element at the `cycle_index_` position is + // not in use and there is space in the `invocation_results_` queue. + while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && + (element_in_use_[cycle_index_] || + num_calls_ >= num_parallel_calls_ || + invocation_results_.size() >= MaxInvocationResults())) { + StopWork(ctx.get()); + cond_var_.wait(l); + StartWork(ctx.get()); + } + + if (cancelled_ || (end_of_input_ && num_open_ == 0)) { + return; + } + + while (!element_in_use_[cycle_index_] && + (!end_of_input_ || num_open_ > 0) && + num_calls_ < num_parallel_calls_ && + invocation_results_.size() < MaxInvocationResults()) { + if (!current_elements_[cycle_index_]) { + // Try to create a new iterator from the next input element. 
+              Status status = input_impl_->GetNext(
+                  ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+              if (!status.ok()) {
+                invocation_results_.emplace_back(new InvocationResult());
+                std::shared_ptr<InvocationResult>& result =
+                    invocation_results_.back();
+                result->status.Update(status);
+                result->notification.Notify();
+                break;
+              }
+              if (!end_of_input_) {
+                Status status = MakeIteratorFromInputElement(
+                    ctx.get(), args_list_[cycle_index_], cycle_index_,
+                    dataset()->captured_func_.get(), prefix(),
+                    &current_elements_[cycle_index_]);
+                if (!status.ok()) {
+                  invocation_results_.emplace_back(new InvocationResult());
+                  std::shared_ptr<InvocationResult>& result =
+                      invocation_results_.back();
+                  result->status.Update(status);
+                  result->notification.Notify();
+                  break;
+                }
+                ++num_open_;
+              }
+            }
+            if (current_elements_[cycle_index_]) {
+              // Pre-allocate invocation results for outputs to be fetched
+              // and then fetch the outputs asynchronously.
+              std::vector<std::shared_ptr<InvocationResult>> results;
+              results.reserve(dataset()->block_length_);
+              for (int i = 0; i < dataset()->block_length_; ++i) {
+                invocation_results_.emplace_back(new InvocationResult());
+                results.push_back(invocation_results_.back());
+              }
+              num_calls_++;
+              element_in_use_[cycle_index_] = true;
+              thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
+                                               ctx, cycle_index_,
+                                               std::move(results)));
+            }
+            cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
+          }
+        }
+        cond_var_.notify_all();
+      }
+    }
+
+    Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+                             const Status& status)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      TF_RETURN_IF_ERROR(writer->WriteScalar(
+          CodeKey(index), static_cast<int64>(status.code())));
+      if (!status.ok()) {
+        TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+                                               status.error_message()));
+      }
+      return Status::OK();
+    }
+
+    Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+                            Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      int64 code_int;
+      TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+      error::Code code = static_cast<error::Code>(code_int);
+
+      if (code != error::Code::OK) {
+        string error_message;
+        TF_RETURN_IF_ERROR(
+            reader->ReadScalar(ErrorMessageKey(index), &error_message));
+        *status = Status(code, error_message);
+      } else {
+        *status = Status::OK();
+      }
+      return Status::OK();
+    }
+
+    string CodeKey(size_t index) {
+      return full_name(
+          strings::StrCat("invocation_results[", index, "].code"));
+    }
+
+    string ErrorMessageKey(size_t index) {
+      return full_name(
+          strings::StrCat("invocation_results[", index, "].error_message"));
+    }
+
+    Status WriteCurrentElements(IteratorStateWriter* writer)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      for (int idx = 0; idx < current_elements_.size(); idx++) {
+        if (current_elements_[idx]) {
+          TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
+          TF_RETURN_IF_ERROR(writer->WriteScalar(
+              full_name(strings::StrCat("args_size[", idx, "]")),
+              args_list_[idx].size()));
+          for (int i = 0; i < args_list_[idx].size(); i++) {
+            TF_RETURN_IF_ERROR(writer->WriteTensor(
+                full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+                args_list_[idx][i]));
+          }
+        }
+      }
+      return Status::OK();
+    }
+
+    Status ReadCurrentElements(IteratorContext* ctx,
+                               IteratorStateReader* reader)
+        EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+      for (int idx = 0; idx < current_elements_.size(); idx++) {
+        if (reader->Contains(
+                full_name(strings::StrCat("args_size[", idx, "]")))) {
+          int64 args_size;
+
+          TF_RETURN_IF_ERROR(reader->ReadScalar(
+              full_name(strings::StrCat("args_size[", idx, "]")),
+              &args_size));
+          args_list_[idx].resize(args_size);
+          for (int i = 0; i < args_size; i++) {
+            TF_RETURN_IF_ERROR(reader->ReadTensor(
+                full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+                &args_list_[idx][i]));
+          }
+          TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
+              ctx, args_list_[idx], idx, dataset()->captured_func_.get(),
+              prefix(), &current_elements_[idx]));
+          TF_RETURN_IF_ERROR(
+              RestoreInput(ctx, reader, current_elements_[idx]));
+        } else {
+          current_elements_[idx].reset();
+        }
+      }
+      return Status::OK();
+    }
+
+    // Used for coordination between the main thread, the runner thread, and
+    // the worker threads.
+    mutex mu_;
+
+    // Used for coordination between the main thread, the runner thread, and
+    // the worker threads. In particular, the runner thread should only
+    // schedule new calls when the number of in-flight calls is less than the
+    // user-specified level of parallelism, there are slots available in the
+    // `invocation_results_` buffer, the current cycle element is not in use,
+    // and there are elements left to be fetched.
+    condition_variable cond_var_;
+
+    // Iterator for input elements.
+    std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+
+    // Identifies current cycle element.
+    int64 cycle_index_ = 0;
+
+    // Arguments for creating an iterator for cycle elements.
+    std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
+
+    // Iterators for the current cycle elements. Concurrent access is
+    // protected by `element_in_use_`.
+    std::vector<std::unique_ptr<IteratorBase>> current_elements_;
+
+    // Identifies cycle elements that are in use by worker threads.
+    std::vector<bool> element_in_use_ GUARDED_BY(mu_);
+
+    // Buffer for storing the invocation results.
+    std::deque<std::shared_ptr<InvocationResult>> invocation_results_
+        GUARDED_BY(mu_);
+
+    // Identifies whether end of input has been reached.
+    bool end_of_input_ GUARDED_BY(mu_) = false;
+
+    // Identifies the number of open iterators.
+    int64 num_open_ GUARDED_BY(mu_) = 0;
+
+    // Identifies the maximum number of parallel calls.
+    int64 num_parallel_calls_ GUARDED_BY(mu_) = 0;
+
+    // Identifies the number of outstanding calls.
+    int64 num_calls_ GUARDED_BY(mu_) = 0;
+
+    std::unique_ptr<thread::ThreadPool> thread_pool_;
+    std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+
+    // Identifies whether background activity should be cancelled.
+    bool cancelled_ GUARDED_BY(mu_) = false;
+  };
+
+  const DatasetBase* const input_;
+  const NameAttrList interleave_func_;
+  const std::unique_ptr<CapturedFunction> captured_func_;
+  const int64 cycle_length_;
+  const int64 block_length_;
+  const int64 num_parallel_calls_;
+  const DataTypeVector output_types_;
+  const std::vector<PartialTensorShape> output_shapes_;
+  };
+
+  DataTypeVector output_types_;
+  std::vector<PartialTensorShape> output_shapes_;
+  NameAttrList interleave_func_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
+                        ParallelInterleaveDatasetV2Op);
+
+} // namespace
+} // namespace data
 } // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index bff54813d6..6abe6c8338 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -33,37 +33,32 @@ namespace { class ParallelMapDatasetOp : public UnaryDatasetOpKernel { public: explicit ParallelMapDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", + &use_inter_op_parallelism_)); } protected: void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } - int32 num_parallel_calls; OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls", &num_parallel_calls)); - OP_REQUIRES(ctx, num_parallel_calls > 0, + OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune, errors::InvalidArgument( "num_parallel_calls must be greater than zero.")); std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + use_inter_op_parallelism_, + &captured_func)); *output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_, - output_shapes_, std::move(captured_func)); + output_shapes_, use_inter_op_parallelism_, + std::move(captured_func)); } private: @@ -73,6 +68,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const NameAttrList& func, int32 num_parallel_calls, const DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes, + bool use_inter_op_parallelism, std::unique_ptr<CapturedFunction> captured_func) : DatasetBase(DatasetContext(ctx)), input_(input), @@ -80,6 +76,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { num_parallel_calls_(num_parallel_calls), output_types_(output_types), output_shapes_(output_shapes), + use_inter_op_parallelism_(use_inter_op_parallelism), captured_func_(std::move(captured_func)) { input_->Ref(); } @@ -92,16 +89,26 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { return captured_func_->Instantiate(ctx); }; - auto map_func = [this](IteratorContext* ctx, + const string& new_prefix = strings::StrCat(prefix, "::ParallelMap"); + ParallelMapIteratorFunction map_func = + [this, new_prefix](IteratorContext* ctx, std::vector<Tensor> input_element, std::vector<Tensor>* result, StatusCallback done) { - captured_func_->RunAsync(ctx, std::move(input_element), result, - std::move(done)); - }; + captured_func_->RunAsync(ctx, std::move(input_element), result, + std::move(done), new_prefix); + }; + if (!use_inter_op_parallelism_) { + map_func = [map_func]( + IteratorContext* ctx, std::vector<Tensor> input_element, + std::vector<Tensor>* result, StatusCallback done) { + (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element), + result, std::move(done))); + }; + } - return NewParallelMapIterator( - {this, strings::StrCat(prefix, "::ParallelMap")}, input_, - std::move(init_func), std::move(map_func), num_parallel_calls_); + 
return NewParallelMapIterator({this, new_prefix}, input_, + std::move(init_func), std::move(map_func), + num_parallel_calls_); } const DataTypeVector& output_dtypes() const override { @@ -167,12 +174,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const int32 num_parallel_calls_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; + const bool use_inter_op_parallelism_; const std::unique_ptr<CapturedFunction> captured_func_; }; - const int graph_def_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; + bool use_inter_op_parallelism_; NameAttrList func_; }; @@ -180,5 +188,5 @@ REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU), ParallelMapDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index 61f8139b9e..5f6052ce83 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -19,7 +19,11 @@ limitations under the License. #include <utility> #include <vector> +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/platform/cpu_info.h" + namespace tensorflow { +namespace data { namespace { class ParallelMapIterator : public DatasetBaseIterator { @@ -52,6 +56,25 @@ class ParallelMapIterator : public DatasetBaseIterator { } Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + if (num_parallel_calls_ == kAutoTune) { + num_parallel_calls_ = 1; + auto set_fn = [this](int64 value) { + { + mutex_lock l(mu_); + num_parallel_calls_ = value; + } + VLOG(2) << "setting parallelism knob to " << value; + cond_var_.notify_all(); + }; + // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and + // use it here for the maximum. 
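+    // Illustrative note (assumption): the tuner may later invoke `set_fn(k)`
+    // for any k between the min and max given below; taking the lock inside
+    // `set_fn` makes the new value visible before the waiters are notified.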
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_ /* value */, + 1 /* min */, port::NumSchedulableCPUs() /* max */, + std::move(set_fn)); + } else { + AddConstantParameter(ctx, "parallelism", num_parallel_calls_); + } TF_RETURN_IF_ERROR( input_dataset_->MakeIterator(ctx, prefix(), &input_impl_)); if (init_func_) { @@ -67,13 +90,17 @@ class ParallelMapIterator : public DatasetBaseIterator { mutex_lock l(mu_); EnsureRunnerThreadStarted(ctx); while (invocation_results_.empty()) { + StopWork(ctx); cond_var_.wait(l); + StartWork(ctx); } std::swap(result, invocation_results_.front()); invocation_results_.pop_front(); } cond_var_.notify_all(); + StopWork(ctx); result->notification.WaitForNotification(); + StartWork(ctx); return ProcessResult(result, out_tensors, end_of_sequence); } @@ -86,9 +113,8 @@ class ParallelMapIterator : public DatasetBaseIterator { } CHECK_EQ(num_calls_, 0); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); - TF_RETURN_IF_ERROR( - writer->WriteScalar(full_name("invocation_results.size"), - invocation_results_.size())); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"), + invocation_results_.size())); for (size_t i = 0; i < invocation_results_.size(); i++) { std::shared_ptr<InvocationResult> result = invocation_results_[i]; TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status)); @@ -204,8 +230,6 @@ class ParallelMapIterator : public DatasetBaseIterator { std::move(done)); } - int64 MaxInvocationResults() { return num_parallel_calls_; } - Status ProcessResult(const std::shared_ptr<InvocationResult>& result, std::vector<Tensor>* out_tensors, bool* end_of_sequence) { @@ -225,21 +249,28 @@ class ParallelMapIterator : public DatasetBaseIterator { } void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { + StartWork(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); }); std::vector<std::shared_ptr<InvocationResult>> new_calls; - new_calls.reserve(num_parallel_calls_); + { + tf_shared_lock l(mu_); + new_calls.reserve(num_parallel_calls_); + } while (true) { { mutex_lock l(mu_); while (!cancelled_ && (num_calls_ >= num_parallel_calls_ || - invocation_results_.size() >= MaxInvocationResults())) { + invocation_results_.size() >= num_parallel_calls_)) { + StopWork(ctx.get()); cond_var_.wait(l); + StartWork(ctx.get()); } if (cancelled_) { return; } while (num_calls_ < num_parallel_calls_ && - invocation_results_.size() < MaxInvocationResults()) { + invocation_results_.size() < num_parallel_calls_) { invocation_results_.emplace_back(new InvocationResult()); new_calls.push_back(invocation_results_.back()); num_calls_++; @@ -294,7 +325,6 @@ class ParallelMapIterator : public DatasetBaseIterator { const DatasetBase* const input_dataset_; // Not owned. const std::function<Status(IteratorContext*)> init_func_; const ParallelMapIteratorFunction map_func_; - const int32 num_parallel_calls_; // Used for coordination between the main thread and the runner thread. mutex mu_; // Used for coordination between the main thread and the runner thread. In @@ -303,6 +333,8 @@ class ParallelMapIterator : public DatasetBaseIterator { // parallelism and there are slots available in the `invocation_results_` // buffer. condition_variable cond_var_; + // Identifies the maximum number of parallel calls. + int64 num_parallel_calls_ GUARDED_BY(mu_) = 0; // Counts the number of outstanding calls. 
int64 num_calls_ GUARDED_BY(mu_) = 0; std::unique_ptr<IteratorBase> input_impl_; @@ -333,4 +365,5 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator( std::move(map_func), num_parallel_calls)); } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h index 7e6cc586f3..dc26c5cf25 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.h +++ b/tensorflow/core/kernels/data/parallel_map_iterator.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/dataset.h" namespace tensorflow { +namespace data { // A function that transforms elements of one dataset into another // asynchronously. The arguments are: @@ -47,6 +48,7 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator( const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func, int32 num_parallel_calls); +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_ITERATOR_H_ diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc index 9057800d94..c28c06da62 100644 --- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/core/util/example_proto_fast_parsing.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -87,11 +87,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { "Expected len(dense_defaults) == len(dense_keys) but got: ", dense_default_tensors.size(), " vs. ", dense_keys_.size())); - std::vector<Tensor> dense_defaults; - dense_defaults.reserve(dense_default_tensors.size()); - for (const Tensor& dense_default_t : dense_default_tensors) { - dense_defaults.push_back(dense_default_t); - } + std::vector<Tensor> dense_defaults(dense_default_tensors.begin(), + dense_default_tensors.end()); for (int d = 0; d < dense_keys_.size(); ++d) { const Tensor& def_value = dense_defaults[d]; @@ -368,5 +365,5 @@ REGISTER_KERNEL_BUILDER(Name("ParseExampleDataset").Device(DEVICE_CPU), ParseExampleDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc index b3272f6bcd..da357339c9 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner.cc +++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/prefetch_autotuner.h" namespace tensorflow { +namespace data { PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size) : buffer_limit_(initial_buffer_size) { @@ -25,6 +26,13 @@ PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size) } } +namespace { +// Determines what strategy to use for increasing the buffer size limit. For +// limits less than the threshold, an exponential increase is used, while for +// limits greater than or equal to the threshold, a linear increase is used. 
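+// For example, a limit that starts at 512 grows 512 -> 1024 -> 2048 -> 4096
+// -> 6144 -> ..., doubling below the threshold and then adding 2048 per step.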
+size_t kBufferLimitThreshold = 2048; +} // namespace + void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) { switch (mode_) { case Mode::kDisabled: @@ -36,11 +44,16 @@ void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) { return; case Mode::kDownswing: if (current_buffer_size == 0) { - buffer_limit_ *= 2; // Increase the buffer size. + if (buffer_limit_ >= kBufferLimitThreshold) { + buffer_limit_ += kBufferLimitThreshold; + } else { + buffer_limit_ *= 2; + } mode_ = Mode::kUpswing; } return; } } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.h b/tensorflow/core/kernels/data/prefetch_autotuner.h index fa8a184072..8693205512 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner.h +++ b/tensorflow/core/kernels/data/prefetch_autotuner.h @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace tensorflow { +namespace data { // PrefetchAutotuner dynamically adjusts the buffer size of a prefetch iterator. // @@ -66,6 +67,7 @@ class PrefetchAutotuner { Mode mode_ = Mode::kDisabled; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_AUTOTUNER_H_ diff --git a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc index 29a8cc50cd..cfc324fc7e 100644 --- a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc +++ b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/platform/test.h" namespace tensorflow { +namespace data { namespace { TEST(PrefetchAutotuner, Disabled) { @@ -79,4 +80,5 @@ TEST(PrefetchAutotuner, EnabledSteady) { } } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc index 50efbcbe2a..52c421caee 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc @@ -12,15 +12,19 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include <deque> - #include "tensorflow/core/kernels/data/prefetch_dataset_op.h" +#include <deque> + #include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { +namespace data { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. @@ -70,7 +74,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { public: explicit Iterator(const Params& params) : DatasetIterator<Dataset>(params), - auto_tuner_(params.dataset->buffer_size_) {} + auto_tuner_(params.dataset->buffer_size_) { + std::vector<string> components = + str_util::Split(params.prefix, "::", str_util::SkipEmpty()); + prefix_end_ = components.back(); + } ~Iterator() override { // Signal the prefetch thread to terminate it. 
We will then @@ -97,13 +105,16 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { bool* end_of_sequence) override { { mutex_lock l(mu_); + auto stats_aggregator = ctx->stats_aggregator(); TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx)); // Wait until the next element in the buffer has been // produced, or we are shutting down. while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ && auto_tuner_.buffer_limit() != 0) { auto_tuner_.RecordEmpty(); + StopWork(ctx); cond_var_.wait(l); + StartWork(ctx); } if (cancelled_) { @@ -112,7 +123,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { } if (!buffer_.empty()) { - return Consume(out_tensors, end_of_sequence); + return Consume(out_tensors, end_of_sequence, stats_aggregator); } if (prefetch_thread_finished_) { @@ -200,14 +211,22 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { std::vector<Tensor> value; }; - Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence) + Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence, + const std::shared_ptr<StatsAggregator>& stats_aggregator) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (stats_aggregator) { + stats_aggregator->AddToHistogram( + strings::StrCat(prefix_end_, "::buffer_utilization"), + {static_cast<float>(buffer_.size()) / + static_cast<float>(auto_tuner_.buffer_limit())}); + } // A new element is available. Forward the status from computing it, and // (if we successfully got an element) the output values. Status s = buffer_.front().status; if (s.ok()) { *out_tensors = std::move(buffer_.front().value); } + auto_tuner_.RecordConsumption(buffer_.size()); buffer_.pop_front(); *end_of_sequence = false; @@ -223,10 +242,10 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { Status EnsurePrefetchThreadStarted(IteratorContext* ctx) EXCLUSIVE_LOCKS_REQUIRED(mu_) { if (!prefetch_thread_) { - prefetch_thread_.reset( - ctx->env()->StartThread({}, "prefetch_thread", - std::bind(&Iterator::PrefetchThread, this, - new IteratorContext(*ctx)))); + std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); + prefetch_thread_.reset(ctx->env()->StartThread( + {}, "prefetch_thread", + [this, new_ctx]() { PrefetchThread(new_ctx); })); } return Status::OK(); } @@ -235,8 +254,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { // buffer. // // It owns the iterator context passed to it. 
- void PrefetchThread(IteratorContext* ctx) { - std::unique_ptr<IteratorContext> cleanup(ctx); + void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) { + StartWork(ctx.get()); + auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); }); while (true) { std::vector<Tensor> value; @@ -244,7 +264,9 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { { mutex_lock l(mu_); while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) { + StopWork(ctx.get()); cond_var_.wait(l); + StartWork(ctx.get()); } if (cancelled_) { @@ -261,8 +283,8 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { mutex_lock parent_l(parent_mu_); bool end_of_sequence; BufferElement buffer_element; - buffer_element.status = - input_impl_->GetNext(ctx, &buffer_element.value, &end_of_sequence); + buffer_element.status = input_impl_->GetNext( + ctx.get(), &buffer_element.value, &end_of_sequence); if (buffer_element.status.ok() && end_of_sequence) { mutex_lock l(mu_); prefetch_thread_finished_ = true; @@ -324,6 +346,7 @@ class PrefetchDatasetOp::Dataset : public DatasetBase { mutex parent_mu_ ACQUIRED_BEFORE(mu_); std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_); condition_variable cond_var_; + string prefix_end_; PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_); std::deque<BufferElement> buffer_ GUARDED_BY(mu_); std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_); @@ -346,6 +369,7 @@ void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input, *output = new Dataset(ctx, input, buffer_size); } +namespace { REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU), PrefetchDatasetOp); REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") @@ -354,4 +378,7 @@ REGISTER_KERNEL_BUILDER(Name("PrefetchDataset") .HostMemory("input_dataset") .HostMemory("handle"), PrefetchDatasetOp); +} // namespace + +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.h b/tensorflow/core/kernels/data/prefetch_dataset_op.h index c40c4b00da..588fb25a06 100644 --- a/tensorflow/core/kernels/data/prefetch_dataset_op.h +++ b/tensorflow/core/kernels/data/prefetch_dataset_op.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/prefetch_autotuner.h" namespace tensorflow { +namespace data { class PrefetchDatasetOp : public UnaryDatasetOpKernel { public: @@ -34,6 +35,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel { class Dataset; }; +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_ diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc index 7817170e73..044a791a3f 100644 --- a/tensorflow/core/kernels/data/random_dataset_op.cc +++ b/tensorflow/core/kernels/data/random_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "tensorflow/core/lib/random/random_distributions.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -151,5 +151,5 @@ REGISTER_KERNEL_BUILDER(Name("RandomDataset").Device(DEVICE_CPU), RandomDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc index aa38775125..89fbaae369 100644 --- a/tensorflow/core/kernels/data/range_dataset_op.cc +++ b/tensorflow/core/kernels/data/range_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -142,5 +142,5 @@ REGISTER_KERNEL_BUILDER(Name("RangeDataset").Device(DEVICE_CPU), RangeDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc index 086b552936..c474cb4773 100644 --- a/tensorflow/core/kernels/data/reader_dataset_ops.cc +++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/core/lib/io/zlib_inputstream.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -691,5 +691,5 @@ REGISTER_KERNEL_BUILDER(Name("TFRecordDataset").Device(DEVICE_CPU), TFRecordDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc index 299949b99f..94e96635ab 100644 --- a/tensorflow/core/kernels/data/repeat_dataset_op.cc +++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -250,5 +250,5 @@ REGISTER_KERNEL_BUILDER(Name("RepeatDataset").Device(DEVICE_CPU), RepeatDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index fccad933d0..dbe31f37b8 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ -23,7 +23,7 @@ limitations under the License. 
#include "tensorflow/core/lib/random/random.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -45,23 +45,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { OpInputList initial_state_inputs; OP_REQUIRES_OK(ctx, ctx->input_list("initial_state", &initial_state_inputs)); - std::vector<Tensor> initial_state; - initial_state.reserve(initial_state_inputs.size()); - for (const Tensor& t : initial_state_inputs) { - initial_state.push_back(t); - } - - OpInputList inputs; - OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs)); - std::vector<Tensor> other_arguments; - other_arguments.reserve(inputs.size()); - for (const Tensor& t : inputs) { - other_arguments.push_back(t); - } + std::vector<Tensor> initial_state(initial_state_inputs.begin(), + initial_state_inputs.end()); std::unique_ptr<CapturedFunction> captured_func; - OP_REQUIRES_OK(ctx, CapturedFunction::Create( - func_, std::move(other_arguments), &captured_func)); + OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", + &captured_func)); *output = new Dataset(ctx, input, func_, std::move(initial_state), std::move(captured_func), state_types_, output_types_, @@ -279,5 +268,5 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("ScanDataset").Device(DEVICE_CPU), ScanDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc index 93a4376836..66466d6a36 100644 --- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc +++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { - +namespace data { namespace { const int64 kLogIntervalMicros = 10 * 1000000; // 10 seconds. @@ -620,5 +620,5 @@ REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU), ShuffleAndRepeatDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc new file mode 100644 index 0000000000..5b084a16f0 --- /dev/null +++ b/tensorflow/core/kernels/data/single_threaded_executor.cc @@ -0,0 +1,380 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/single_threaded_executor.h"
+
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
+typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
+typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
+
+class SingleThreadedExecutorImpl : public Executor {
+ public:
+  explicit SingleThreadedExecutorImpl(const LocalExecutorParams& params)
+      : params_(params) {}
+
+  ~SingleThreadedExecutorImpl() override {
+    for (const KernelState& kernel_state : kernels_) {
+      params_.delete_kernel(kernel_state.kernel);
+    }
+  }
+
+  Status Initialize(const Graph& graph) {
+    // Topologically sort `graph` to get a sequence of OpKernels.
+    std::vector<Node*> ordered_nodes;
+    ordered_nodes.reserve(graph.num_nodes());
+    GetReversePostOrder(graph, &ordered_nodes);
+
+    if (ordered_nodes.size() != graph.num_nodes()) {
+      return errors::InvalidArgument("Graph had ", graph.num_nodes(),
+                                     " nodes but reverse post-order had ",
+                                     ordered_nodes.size());
+    }
+
+    kernels_.resize(ordered_nodes.size());
+
+    std::unordered_map<Node*, size_t> node_to_index_map;
+
+    // Create the kernel and input-related structures for each node in `graph`.
+    for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+      Node* n = ordered_nodes[i];
+      node_to_index_map[n] = i;
+
+      for (DataType dt : n->output_types()) {
+        if (IsRefType(dt)) {
+          return errors::Unimplemented(
+              "Single-threaded executor does not support reference-typed "
+              "edges.");
+        }
+      }
+
+      if (n->IsControlFlow()) {
+        return errors::Unimplemented(
+            "Single-threaded executor does not support control flow.");
+      }
+      if (n->IsSend() || n->IsHostSend() || n->IsRecv() || n->IsHostRecv()) {
+        return errors::Unimplemented(
+            "Single-threaded executor does not support partitioned graphs.");
+      }
+      if (n->IsCollective()) {
+        return errors::Unimplemented(
+            "Single-threaded executor does not support collective ops.");
+      }
+
+      KernelState& kernel_state = kernels_[i];
+      TF_RETURN_IF_ERROR(params_.create_kernel(n->def(), &kernel_state.kernel));
+      kernel_state.num_inputs = n->num_inputs();
+      kernel_state.num_outputs = n->num_outputs();
+
+      if (i == 0) {
+        kernel_state.input_start_index = 0;
+      } else {
+        const KernelState& previous_kernel_state = kernels_[i - 1];
+        kernel_state.input_start_index =
+            previous_kernel_state.input_start_index +
+            previous_kernel_state.num_inputs;
+      }
+    }
+
+    // Build the mapping from each node output to the input slot for the
+    // corresponding destination node.
+    for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+      Node* n = ordered_nodes[i];
+      KernelState& kernel_state = kernels_[i];
+      kernel_state.output_locations.resize(kernel_state.num_outputs);
+      for (const Edge* e : n->out_edges()) {
+        if (!e->IsControlEdge()) {
+          kernel_state.output_locations[e->src_output()].push_back(
+              kernels_[node_to_index_map[e->dst()]].input_start_index +
+              e->dst_input());
+        }
+      }
+
+      // Compute allocator attributes for each node output, and corresponding
+      // node input.
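+      // (Outputs whose memory type is HOST_MEMORY receive host allocator
+      // attributes here; these are later copied to the matching input slots.)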
+ kernel_state.output_alloc_attrs.resize(kernel_state.num_outputs); + AllocatorAttributes* attrs = kernel_state.output_alloc_attrs.data(); + + OpKernel* op_kernel = kernel_state.kernel; + for (int out = 0; out < n->num_outputs(); out++) { + DCHECK_LT(out, op_kernel->output_memory_types().size()); + bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY; + if (on_host) { + AllocatorAttributes h; + h.set_on_host(on_host); + attrs[out].Merge(h); + } + } + } + + if (!kernels_.empty()) { + const KernelState& last_kernel_state = kernels_.back(); + total_num_inputs_ = + last_kernel_state.input_start_index + last_kernel_state.num_inputs; + input_alloc_attrs_.resize(total_num_inputs_); + for (size_t i = 0; i < ordered_nodes.size(); ++i) { + for (size_t j = 0; j < kernels_[i].output_locations.size(); ++j) { + for (size_t output_location : kernels_[i].output_locations[j]) { + input_alloc_attrs_[output_location] = + kernels_[i].output_alloc_attrs[j]; + } + } + } + } else { + total_num_inputs_ = 0; + } + return Status::OK(); + } + + // TODO(mrry): Consider specializing the implementation of Executor::Run() + // instead, to avoid unnecessary atomic operations in the callback when + // running synchronously. + void RunAsync(const Args& args, DoneCallback done) override { + // The inputs to each kernel are stored contiguously in `inputs`. + // + // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to + // determine the range of elements in this vector that correspond to + // the inputs of `kernels_[i]`. + // + // This vector has the following layout: + // + // * Kernel 0, input 0. + // * Kernel 0, input 1. + // * ... + // * Kernel 0, input `kernels_[0].num_inputs - 1`. + // * Kernel 1, input 0. + // * ... + // * Kernel 1, input `kernels_[1].num_inputs - 1`. + // * ... + // * Kernel `kernels_.size() - 1`, input 0. + // * ... + // * Kernel `kernels_.size() - 1`, input `kernels_.back().num_inputs - 1`. + // + // Note that kernels with zero inputs do not correspond to any elements in + // this vector. + // + // We use `ManualConstructor<Tensor>` to avoid the overhead of + // default-constructing an invalid `Tensor` for each slot at the beginning + // of execution: + // * Elements are initialized when the outputs of a kernel execution are + // propagated to the inputs of kernels that depend on them. + // * The elements corresponding to the inputs for kernel `i` are destroyed + // after kernel `i` executes. + // * In an error case (see below), we use the connectivity information in + // `KernelState::output_locations` to determine which locations have been + // initialized, and manually destroy them. + std::vector<ManualConstructor<Tensor>> inputs(total_num_inputs_); + + // TODO(mrry): Can we avoid copying into these vectors? Consider modifying + // OpKernelContext to take the TensorValueVec as a pointer into `inputs`. + TensorValueVec node_inputs; + DeviceContextVec input_device_contexts; + AllocatorAttributeVec input_alloc_attrs; + + // Prepare the parameters that will be the same for all kernels. + OpKernelContext::Params params; + params.step_id = args.step_id; + Device* device = params_.device; + params.device = device; + params.log_memory = false; // TODO(mrry): Too severe? + params.record_tensor_accesses = false; // TODO(mrry): Too severe? 
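+    // (The remaining parameters are filled in once and reused for every
+    // kernel in this run; per-kernel fields such as `op_kernel` and the
+    // input vectors are refreshed inside the execution loop below.)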
+    params.rendezvous = args.rendezvous;
+    params.session_state = args.session_state;
+    params.tensor_store = args.tensor_store;
+    params.cancellation_manager = args.cancellation_manager;
+    // TODO(mrry): ArgOp is a relatively expensive OpKernel due to the Tensor
+    // allocations that it performs. Consider specializing its handling in the
+    // executor.
+    params.call_frame = args.call_frame;
+    params.function_library = params_.function_library;
+    params.resource_manager = device->resource_manager();
+    params.step_container = args.step_container;
+    params.slice_reader_cache = nullptr;  // TODO(mrry): Too severe?
+    params.inputs = &node_inputs;
+    params.input_device_contexts = &input_device_contexts;
+    params.input_alloc_attrs = &input_alloc_attrs;
+
+    Args::Runner runner_copy = args.runner;
+    params.runner = &runner_copy;
+    params.stats_collector = args.stats_collector;
+
+    // NOTE(mrry): We are assuming that the graph is loopless and condless.
+    params.frame_iter = FrameAndIter(0, 0);
+    params.is_input_dead = false;
+
+    // TODO(mrry): Add non-default device context inference.
+    params.op_device_context = nullptr;
+    // TODO(mrry): Consider implementing forwarding.
+    params.forward_from_array = nullptr;
+
+    // Execute the kernels one-at-a-time in topological order.
+    for (size_t i = 0; i < kernels_.size(); ++i) {
+      const KernelState& kernel_state = kernels_[i];
+
+      // Prepare the per-kernel parameters.
+      const size_t input_start_index = kernel_state.input_start_index;
+      const size_t num_inputs = kernel_state.num_inputs;
+      const size_t num_outputs = kernel_state.num_outputs;
+
+      node_inputs.clear();
+      node_inputs.resize(num_inputs);
+      input_alloc_attrs.clear();
+      input_alloc_attrs.resize(num_inputs);
+      for (size_t j = 0; j < num_inputs; ++j) {
+        auto t = inputs[input_start_index + j].get();
+        node_inputs[j].tensor = t;
+        input_alloc_attrs[j] = input_alloc_attrs_[input_start_index + j];
+      }
+      params.op_kernel = kernel_state.kernel;
+      input_device_contexts.clear();
+      input_device_contexts.resize(num_inputs);
+      params.output_attr_array = kernel_state.output_alloc_attrs.data();
+      OpKernelContext ctx(&params, num_outputs);
+
+      // Actually execute the kernel.
+      device->Compute(kernel_state.kernel, &ctx);
+
+      if (!ctx.status().ok()) {
+        // On failure, we must manually free all intermediate tensors. We have
+        // already freed all the inputs for kernels up to (but not including)
+        // the `i`th kernel. We scan through the previously executed kernels and
+        // destroy any tensors that were destined to be the input for a kernel
+        // that has not yet executed.
+        for (size_t j = 0; j < i; ++j) {
+          const KernelState& executed_kernel_state = kernels_[j];
+          for (size_t k = 0; k < executed_kernel_state.num_outputs; ++k) {
+            for (size_t output_location :
+                 executed_kernel_state.output_locations[k]) {
+              if (output_location >= input_start_index) {
+                // Only destroy an output location if it is an input to an
+                // operation that has not yet executed.
+                inputs[output_location].Destroy();
+              }
+            }
+          }
+        }
+        done(ctx.status());
+        return;
+      }
+
+      // Free the inputs to the current kernel.
+      for (size_t j = 0; j < num_inputs; ++j) {
+        inputs[input_start_index + j].Destroy();
+      }
+
+      // Forward the outputs of the kernel to the inputs of subsequent kernels.
+      for (size_t j = 0; j < num_outputs; ++j) {
+        TensorValue val = ctx.release_output(j);
+        // TODO(mrry): Consider flattening the `output_locations` vector
+        // to improve the cache-friendliness of this loop.
+ for (size_t output_location : kernel_state.output_locations[j]) { + // TODO(mrry): Validate that the types match the expected values or + // ensure that the necessary validation has already happened. + inputs[output_location].Init(*val.tensor); + } + delete val.tensor; + } + } + done(Status::OK()); + } + + private: + const LocalExecutorParams params_; + + // All following members are read-only after Initialize(). + + // The sum of the number of inputs for each node in the graph. This determines + // the length of the flat `inputs` vector. See comment at the beginning of + // `RunAsync()` for details. + size_t total_num_inputs_; + + // Represents cached graph structure state for each kernel. + struct KernelState { + // The kernel object. Not owned. + // + // This pointer is managed by `params_.create_kernel()` and + // `params_.delete_kernel()`. + OpKernel* kernel; + + // These fields determine the range of elements in `inputs` that corresponds + // to the inputs of `kernel`. + size_t input_start_index; + size_t num_inputs; + + size_t num_outputs; + + // For the `j`th output of `kernel`, `output_locations[j]` contains the + // locations in the flat `inputs` vector to which that output must be + // copied. See comment at the beginning of `RunAsync()` for details. + std::vector<std::vector<size_t>> + output_locations; // Length = `num_outputs`. + + // Memory space information for each output of `kernel`. + std::vector<AllocatorAttributes> + output_alloc_attrs; // Length = `num_outputs`. + }; + std::vector<KernelState> kernels_; + + // Memory space information for each input. This information is stored in the + // same order as the flat `inputs` vector. See comment at the beginning of + // `RunAsync()` for details. + std::vector<AllocatorAttributes> + input_alloc_attrs_; // Length = `total_num_inputs_`. +}; + +class SingleThreadedExecutorRegistrar { + public: + SingleThreadedExecutorRegistrar() { + ExecutorFactory::Register("SINGLE_THREADED_EXECUTOR", new Factory()); + } + + private: + class Factory : public ExecutorFactory { + Status NewExecutor(const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + std::unique_ptr<Executor>* out_executor) override { + Executor* ret; + TF_RETURN_IF_ERROR( + NewSingleThreadedExecutor(params, std::move(graph), &ret)); + out_executor->reset(ret); + return Status::OK(); + } + }; +}; +static SingleThreadedExecutorRegistrar registrar; + +} // namespace + +Status NewSingleThreadedExecutor(const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + Executor** executor) { + std::unique_ptr<SingleThreadedExecutorImpl> impl( + new SingleThreadedExecutorImpl(params)); + TF_RETURN_IF_ERROR(impl->Initialize(*graph)); + *executor = impl.release(); + return Status::OK(); +} + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/single_threaded_executor.h b/tensorflow/core/kernels/data/single_threaded_executor.h new file mode 100644 index 0000000000..e934352a1d --- /dev/null +++ b/tensorflow/core/kernels/data/single_threaded_executor.h @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_ + +#include "tensorflow/core/common_runtime/executor.h" + +namespace tensorflow { +namespace data { + +// Creates a new `Executor` for executing `graph` synchronously on the caller +// thread. +// +// NOTE(mrry): The returned executor is optimized to impose low overhead on +// graphs that perform a small amount of work (e.g. <15us of work per graph on +// present architectures). It eschews concurrency, because issuing work to +// multiple threads can dominate the cost of executing small ops synchronously, +// and because contention in the executor data structures can reduce throughput +// (in terms of ops executed per unit time). +// +// However, the current implementation has the following limitations: +// +// 1. Reference-typed tensors are not supported and will not be supported in +// future. +// 2. Graphs with control flow (containing "Switch" and "Merge" nodes) are not +// currently supported. The current plan is to extend support to "functional" +// control flow after the TensorFlow APIs transition to building graphs in +// that form (e.g. `tf.cond_v2()`). +// 3. Partitioned graphs (containing "_Recv" nodes) are not currently supported. +// The present implementation executes kernels one at a time in topological +// order, and cannot currently distinguish between disconnected subgraphs +// that are logically connected by subgraphs on a different device. +// 4. Memory logging is not currently supported. +// 5. Allocation forwarding is not currently supported. +// 6. Non-default device contexts are not currently supported. In effect, this +// limits the executor to CPU devices. +// 7. Ops that rely on `OpKernelContext::slice_reader_cache()` being non-null +// are not currently supported. +// +// The single-threaded executor is primarily suitable for executing simple +// TensorFlow functions, such as one might find in a `tf.data` pipeline. +Status NewSingleThreadedExecutor(const LocalExecutorParams& params, + std::unique_ptr<const Graph> graph, + Executor** executor); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_ diff --git a/tensorflow/core/kernels/data/single_threaded_executor_test.cc b/tensorflow/core/kernels/data/single_threaded_executor_test.cc new file mode 100644 index 0000000000..6244e287bb --- /dev/null +++ b/tensorflow/core/kernels/data/single_threaded_executor_test.cc @@ -0,0 +1,332 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/single_threaded_executor.h"
+
+#include <algorithm>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class ExecutorTest : public ::testing::Test {
+ protected:
+  ExecutorTest()
+      : device_(DeviceFactory::NewDevice("CPU", {},
+                                         "/job:localhost/replica:0/task:0")) {}
+
+  ~ExecutorTest() override {
+    // There should always be exactly one Ref left on the Rendezvous
+    // when the test completes.
+    CHECK(rendez_->Unref());
+    delete exec_;
+    delete device_;
+  }
+
+  // Resets executor_ with a new executor based on the given 'graph'.
+  void Create(std::unique_ptr<const Graph> graph) {
+    const int version = graph->versions().producer();
+    LocalExecutorParams params;
+    params.device = device_;
+    params.create_kernel = [this, version](const NodeDef& ndef,
+                                           OpKernel** kernel) {
+      return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
+    };
+    params.delete_kernel = [](OpKernel* kernel) {
+      DeleteNonCachedKernel(kernel);
+    };
+    delete exec_;
+    TF_CHECK_OK(NewSingleThreadedExecutor(params, std::move(graph), &exec_));
+    runner_ = [](std::function<void()> fn) { fn(); };
+    rendez_ = NewLocalRendezvous();
+  }
+
+  Status Run(Rendezvous* rendez) {
+    Executor::Args args;
+    args.rendezvous = rendez;
+    args.runner = runner_;
+    return exec_->Run(args);
+  }
+
+  Status Run(CallFrameInterface* call_frame) {
+    Executor::Args args;
+    args.call_frame = call_frame;
+    args.runner = runner_;
+    return exec_->Run(args);
+  }
+
+  Device* device_ = nullptr;
+  Executor* exec_ = nullptr;
+  Executor::Args::Runner runner_;
+  Rendezvous* rendez_ = nullptr;
+};
+
+// A float val -> Tensor<float>
+Tensor V(const float val) {
+  Tensor tensor(DT_FLOAT, TensorShape({}));
+  tensor.scalar<float>()() = val;
+  return tensor;
+}
+
+// An int32 val -> Tensor<int32>
+Tensor VI(const int32 val) {
+  Tensor tensor(DT_INT32, TensorShape({}));
+  tensor.scalar<int32>()() = val;
+  return tensor;
+}
+
+// A bool val -> Tensor<bool>
+Tensor VB(const bool val) {
+  Tensor tensor(DT_BOOL, TensorShape({}));
+  tensor.scalar<bool>()() = val;
+  return tensor;
+}
+
+// A double val -> Tensor<double>
+Tensor VD(const double val) {
+  Tensor tensor(DT_DOUBLE, TensorShape({}));
+  tensor.scalar<double>()() = val;
+  return tensor;
+}
+
+// Tensor<float> -> a float val.
+// Tensor<float> -> a float val.
+float V(const Tensor& tensor) {
+  CHECK_EQ(tensor.dtype(), DT_FLOAT);
+  CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
+  return tensor.scalar<float>()();
+}
+
+Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
+                          const string& receiver, const string& name) {
+  Rendezvous::ParsedKey result;
+  TF_CHECK_OK(
+      Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
+                                                 name, FrameAndIter(0, 0)),
+                           &result));
+  return result;
+}
+
+TEST_F(ExecutorTest, SimpleAdd) {
+  // c = a + b
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  auto in0 = test::graph::Arg(g.get(), 0, DT_FLOAT);
+  auto in1 = test::graph::Arg(g.get(), 1, DT_FLOAT);
+  auto tmp = test::graph::Add(g.get(), in0, in1);
+  test::graph::Retval(g.get(), 0, tmp);
+  FixupSourceAndSinkEdges(g.get());
+  Create(std::move(g));
+  FunctionCallFrame call_frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT});
+  TF_ASSERT_OK(call_frame.SetArgs({V(1.0), V(1.0)}));
+  TF_ASSERT_OK(Run(&call_frame));
+  std::vector<Tensor> retvals;
+  TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+  EXPECT_EQ(2.0, V(retvals[0]));  // out = 1.0 + 1.0 = 2.0
+}
+
+TEST_F(ExecutorTest, SelfAdd) {
+  // v0 <- a
+  // v1 = v0 + v0
+  // v2 = v1 + v1
+  // ...      ...
+  // v10 = v9 + v9
+  //
+  // b <- v10
+  // All nodes are executed by one thread.
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  auto v = test::graph::Arg(g.get(), 0, DT_FLOAT);
+  const int N = 10;
+  for (int i = 1; i <= N; ++i) {
+    v = test::graph::Add(g.get(), v, v);
+  }
+  // out <- v10
+  test::graph::Retval(g.get(), 0, v);
+  FixupSourceAndSinkEdges(g.get());
+  Create(std::move(g));
+  FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
+  // a = 1.0
+  TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
+  TF_ASSERT_OK(Run(&call_frame));
+  std::vector<Tensor> retvals;
+  TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+  EXPECT_EQ(1024.0, V(retvals[0]));  // b = v10 = 2*v9 = 4*v8 = ... = 1024*a = 1024.0
+}
+
+// Builds a graph which adds N copies of one variable "in". I.e.,
+//     a + a + a + ... + a
+// The returned graph is parenthesized randomly. I.e.,
+//     a + ((a + a) + a)
+//     (a + a) + (a + a)
+//     ((a + a) + a) + a
+// are all possibly generated.
+void BuildTree(int N, Graph* g) {
+  CHECK_GT(N, 1);
+  // A single input node "in".
+  auto in = test::graph::Arg(g, 0, DT_FLOAT);
+  std::vector<Node*> nodes;
+  int i = 0;
+  // Duplicate "in" N times. Each copy is named l0, l1, l2, ....
+  for (; i < N; ++i) {
+    nodes.push_back(test::graph::Identity(g, in, 0));
+  }
+  random::PhiloxRandom philox(0, 17);
+  random::SimplePhilox rnd(&philox);
+  while (nodes.size() > 1) {
+    // Randomly pick two from nodes and add them. The resulting node
+    // is named like n10, n11, ..., and is put back into "nodes".
+    int x = rnd.Uniform(nodes.size());
+    auto in0 = nodes[x];
+    nodes[x] = nodes.back();
+    nodes.resize(nodes.size() - 1);
+    x = rnd.Uniform(nodes.size());
+    auto in1 = nodes[x];
+    // node = in0 + in1.
+    nodes[x] = test::graph::Add(g, in0, in1);
+  }
+  // The final output node "out".
+  test::graph::Retval(g, 0, nodes.back());
+  FixupSourceAndSinkEdges(g);
+}
+
+TEST_F(ExecutorTest, RandomTree) {
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  BuildTree(4096, g.get());
+  Create(std::move(g));
+  FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
+  TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
+  TF_ASSERT_OK(Run(&call_frame));
+  std::vector<Tensor> retvals;
+  TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+  EXPECT_EQ(4096.0, V(retvals[0]));
+}
+
+TEST_F(ExecutorTest, OpError) {
+  std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+  auto zero = test::graph::Constant(g.get(), V(0.0));
+  auto inf = test::graph::Unary(g.get(), "Reciprocal", zero);
+  auto check = test::graph::CheckNumerics(g.get(), inf, "message");
+  auto two = test::graph::Constant(g.get(), V(2.0));
+  test::graph::Binary(g.get(), "Mul", check, two);
+  FixupSourceAndSinkEdges(g.get());
+  Create(std::move(g));
+  FunctionCallFrame call_frame({}, {});
+  // Fails because CheckNumerics reports the Inf produced by Reciprocal(0.0)
+  // as an InvalidArgument error.
+  EXPECT_TRUE(errors::IsInvalidArgument(Run(&call_frame)));
+}
+
+static void BM_executor(int iters, int width, int depth) {
+#ifdef PLATFORM_GOOGLE
+  BenchmarkUseRealTime();
+#endif  // PLATFORM_GOOGLE
+  Graph* g = new Graph(OpRegistry::Global());
+  random::PhiloxRandom philox(1729, 17);
+  random::SimplePhilox rand(&philox);
+  uint64 cur = 0;
+  uint32 r = 1 + rand.Rand32() % width;
+  std::vector<Node*> ready_nodes;
+  for (int i = 0; i < r; ++i) {
+    ready_nodes.push_back(test::graph::NoOp(g, {}));
+    ++cur;
+  }
+  for (int i = 0; i < depth; ++i) {
+    std::random_shuffle(ready_nodes.begin(), ready_nodes.end());
+    r = 1 + rand.Rand32() % (ready_nodes.size());
+    std::vector<Node*> control_inputs;
+    for (int j = 0; j < r; ++j) {
+      control_inputs.push_back(ready_nodes.back());
+      ready_nodes.pop_back();
+    }
+    Node* n = test::graph::NoOp(g, control_inputs);
+    ++cur;
+    r = 1 + rand.Rand32() % width;
+    for (int j = 0; j < r; ++j) {
+      ready_nodes.push_back(test::graph::NoOp(g, {n}));
+      ++cur;
+    }
+  }
+  FixupSourceAndSinkEdges(g);
+#ifdef PLATFORM_GOOGLE
+  SetBenchmarkLabel(strings::StrCat("Nodes = ", cur));
+  SetBenchmarkItemsProcessed(cur * static_cast<int64>(iters));
+#endif  // PLATFORM_GOOGLE
+  test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
+                  "SINGLE_THREADED_EXECUTOR")
+      .Run(iters);
+}
+
+// Tall skinny graphs
+BENCHMARK(BM_executor)->ArgPair(16, 1024);
+BENCHMARK(BM_executor)->ArgPair(32, 8192);
+
+// Short fat graphs
+BENCHMARK(BM_executor)->ArgPair(1024, 16);
+BENCHMARK(BM_executor)->ArgPair(8192, 32);
+
+// Tall fat graph
+BENCHMARK(BM_executor)->ArgPair(1024, 1024);
+
+// TODO(mrry): This benchmark currently crashes with a use-after-free, because
+// test::Benchmark::RunWithArgs() assumes that the executor will take ownership
+// of the given graph, *and* keep its nodes (`x`, `y` and `z`) alive for the
+// duration of the benchmark. Since the single-threaded executor does not
+// retain a copy of the graph, this fails.
+//
+// TODO(mrry): Add support for the Arg/Retval "function call convention" in
+// `test::Benchmark::RunWithArgs()`.
+#if 0
+#define ALICE "/job:j/replica:0/task:0/cpu:0"
+#define BOB "/job:j/replica:0/task:0/gpu:0"
+
+static void BM_FeedInputFetchOutput(int iters) {
+  Graph* g = new Graph(OpRegistry::Global());
+  // z = x + y: x and y are provided as benchmark inputs. z is the
+  // output of the benchmark. Conceptually, the caller is ALICE, the
+  // benchmark is BOB.
+ Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB); + Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB); + Node* sum = test::graph::Add(g, x, y); + Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE); + FixupSourceAndSinkEdges(g); + Tensor val(DT_FLOAT, TensorShape({})); + val.scalar<float>()() = 3.14; + SetBenchmarkItemsProcessed(static_cast<int64>(iters)); + test::Benchmark("cpu", g, nullptr, nullptr, nullptr, + "SINGLE_THREADED_EXECUTOR") + .RunWithArgs({{x, val}, {y, val}}, {z}, iters); +} +BENCHMARK(BM_FeedInputFetchOutput); +#endif + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc index fe7ef38d5f..b8c7fb15f4 100644 --- a/tensorflow/core/kernels/data/skip_dataset_op.cc +++ b/tensorflow/core/kernels/data/skip_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -187,5 +187,5 @@ class SkipDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), SkipDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc index 14df3a6801..1e73cfc753 100644 --- a/tensorflow/core/kernels/data/slide_dataset_op.cc +++ b/tensorflow/core/kernels/data/slide_dataset_op.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -293,5 +293,5 @@ REGISTER_KERNEL_BUILDER(Name("SlideDataset").Device(DEVICE_CPU), SlideDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc index e526578701..85b1e50695 100644 --- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/core/util/sparse/sparse_tensor.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -274,5 +274,5 @@ TF_CALL_DATASET_TYPES(REGISTER_DATASET_KERNEL); #undef REGISTER_DATASET_KERNEL } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/sql/driver_manager.cc b/tensorflow/core/kernels/data/sql/driver_manager.cc index ffabda1a8a..783d1e6cb2 100644 --- a/tensorflow/core/kernels/data/sql/driver_manager.cc +++ b/tensorflow/core/kernels/data/sql/driver_manager.cc @@ -16,7 +16,7 @@ limitations under the License. 
#include "tensorflow/core/kernels/data/sql/sqlite_query_connection.h" namespace tensorflow { - +namespace data { namespace sql { std::unique_ptr<QueryConnection> DriverManager::CreateQueryConnection( @@ -30,5 +30,5 @@ std::unique_ptr<QueryConnection> DriverManager::CreateQueryConnection( } } // namespace sql - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/sql/driver_manager.h b/tensorflow/core/kernels/data/sql/driver_manager.h index a34691b5a2..c5428f396b 100644 --- a/tensorflow/core/kernels/data/sql/driver_manager.h +++ b/tensorflow/core/kernels/data/sql/driver_manager.h @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/sql/query_connection.h" namespace tensorflow { - +namespace data { namespace sql { // A factory class for creating `QueryConnection` instances. @@ -35,7 +35,7 @@ class DriverManager { }; } // namespace sql - +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_ diff --git a/tensorflow/core/kernels/data/sql/query_connection.h b/tensorflow/core/kernels/data/sql/query_connection.h index e9ffca202f..2fd229a9bf 100644 --- a/tensorflow/core/kernels/data/sql/query_connection.h +++ b/tensorflow/core/kernels/data/sql/query_connection.h @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" namespace tensorflow { +namespace data { class IteratorContext; @@ -63,7 +64,7 @@ class QueryConnection { }; } // namespace sql - +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_ diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc index 7cd07bd8ec..5108e83976 100644 --- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc +++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" namespace tensorflow { - +namespace data { namespace sql { SqliteQueryConnection::SqliteQueryConnection() {} @@ -115,5 +115,5 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry( } } // namespace sql - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h index 81b19530b7..175492c49d 100644 --- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h +++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace tensorflow { - +namespace data { namespace sql { class SqliteQueryConnection : public QueryConnection { @@ -50,7 +50,7 @@ class SqliteQueryConnection : public QueryConnection { }; } // namespace sql - +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_ diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc index 2aa153fcfa..6bbe459332 100644 --- a/tensorflow/core/kernels/data/sql_dataset_ops.cc +++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc @@ -24,8 +24,9 @@ limitations under the License. #include "tensorflow/core/lib/strings/stringprintf.h" namespace tensorflow { - +namespace data { namespace { + // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following ops. 
@@ -211,5 +212,5 @@ class SqlDatasetOp : public DatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc index 75af73df54..f5314f7a75 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { namespace { class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { @@ -135,4 +136,5 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU), SetStatsAggregatorDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc index b133cfab54..a7ded67876 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc @@ -26,6 +26,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" namespace tensorflow { +namespace data { namespace { static mutex* get_counters_map_lock() { @@ -145,4 +146,5 @@ REGISTER_KERNEL_BUILDER(Name("StatsAggregatorSummary").Device(DEVICE_CPU), StatsAggregatorSummaryOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc index 8957f5d997..e9e42f05a1 100644 --- a/tensorflow/core/kernels/data/stats_dataset_ops.cc +++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/lib/random/random.h" namespace tensorflow { +namespace data { namespace { // This op defines a `Dataset` that passes through its input elements and @@ -248,4 +249,5 @@ REGISTER_KERNEL_BUILDER(Name("BytesProducedStatsDataset").Device(DEVICE_CPU), BytesProducedStatsDatasetOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc index e5c237dfaa..e5cdfdd732 100644 --- a/tensorflow/core/kernels/data/take_dataset_op.cc +++ b/tensorflow/core/kernels/data/take_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -174,5 +174,5 @@ class TakeDatasetOp : public UnaryDatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc index fc21c3235a..ca4ea25b89 100644 --- a/tensorflow/core/kernels/data/tensor_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc @@ -14,10 +14,11 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -28,17 +29,11 @@ class TensorDatasetOp : public DatasetOpKernel { explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {} void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - // Create a new TensorDatasetOp::Dataset, insert it in the step - // container, and return it as the output. OpInputList inputs; OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs)); // TODO(mrry): Validate that the shapes of the "components" tensors match // the "shapes" attr.; - std::vector<Tensor> components; - components.reserve(inputs.size()); - for (const Tensor& t : inputs) { - components.push_back(t); - } + std::vector<Tensor> components(inputs.begin(), inputs.end()); *output = new Dataset(ctx, std::move(components)); } @@ -74,7 +69,13 @@ class TensorDatasetOp : public DatasetOpKernel { components.reserve(tensors_.size()); for (const Tensor& t : tensors_) { Node* node; - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + std::vector<std::pair<string, Tensor>>* input_list = ctx->input_list(); + if (input_list) { + TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node)); + input_list->emplace_back(node->name(), t); + } else { + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + } components.emplace_back(node); } AttrValue dtypes; @@ -135,5 +136,5 @@ REGISTER_KERNEL_BUILDER(Name("TensorDataset").Device(DEVICE_CPU), TensorDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc index ccd5e60acc..2ed636a400 100644 --- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc @@ -24,7 +24,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { bool IsGreaterEqualToOrCompatibleWith(const PartialTensorShape& a, @@ -648,5 +648,5 @@ REGISTER_KERNEL_BUILDER(Name("EnqueueInQueueDataset").Device(DEVICE_CPU), EnqueueInQueueDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc index 5b051e0e08..7dc64b0a75 100644 --- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc +++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc @@ -14,11 +14,12 @@ limitations under the License. 
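The `AsGraphDefInternal()` branch added above for `TensorDataset` (and repeated below for `TensorSliceDataset`) chooses between two serializations per component tensor: if the `SerializationContext` carries an `input_list`, the tensor becomes an `AddPlaceholder()` node plus a recorded (name, tensor) pair to feed later; otherwise it is embedded via `AddTensor()` as before. A sketch of the caller's side, assuming `SerializationContext::Params` exposes the `input_list` pointer implied by the `ctx->input_list()` call (editorial, not from the patch):

    std::vector<std::pair<string, Tensor>> input_list;
    SerializationContext::Params params;
    params.input_list = &input_list;  // enables the AddPlaceholder() branch
    SerializationContext serialization_ctx(params);
    // ... serialize the dataset using serialization_ctx ...
    // Every (placeholder_name, tensor) pair recorded in input_list must then
    // be fed when the serialized GraphDef is instantiated, instead of being
    // baked into the graph as a Const node.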
==============================================================================*/ #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -30,8 +31,6 @@ class TensorSliceDatasetOp : public DatasetOpKernel { : DatasetOpKernel(ctx) {} void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { - // Create a new TensorDatasetOp::Dataset, insert it in the step - // container, and return it as the output. OpInputList inputs; OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs)); std::vector<Tensor> components; @@ -93,7 +92,13 @@ class TensorSliceDatasetOp : public DatasetOpKernel { components.reserve(tensors_.size()); for (const Tensor& t : tensors_) { Node* node; - TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + std::vector<std::pair<string, Tensor>>* input_list = ctx->input_list(); + if (input_list) { + TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node)); + input_list->emplace_back(node->name(), t); + } else { + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + } components.emplace_back(node); } AttrValue dtypes; @@ -163,5 +168,5 @@ REGISTER_KERNEL_BUILDER(Name("TensorSliceDataset").Device(DEVICE_CPU), TensorSliceDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc index 1a79f72b28..81c432b938 100644 --- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc @@ -18,7 +18,7 @@ limitations under the License. #include "tensorflow/core/util/batch_util.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -204,5 +204,5 @@ REGISTER_KERNEL_BUILDER(Name("UnbatchDataset").Device(DEVICE_CPU), UnbatchDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc index 0ab6beabfc..2ad4711aab 100644 --- a/tensorflow/core/kernels/data/window_dataset.cc +++ b/tensorflow/core/kernels/data/window_dataset.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" namespace tensorflow { +namespace data { namespace { class WindowDataset : public DatasetBase { @@ -107,4 +108,5 @@ Status NewWindowDataset(std::vector<std::vector<Tensor>> elements, return Status::OK(); } +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/window_dataset.h b/tensorflow/core/kernels/data/window_dataset.h index 7bd31a0bc7..84cb3c7860 100644 --- a/tensorflow/core/kernels/data/window_dataset.h +++ b/tensorflow/core/kernels/data/window_dataset.h @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { +namespace data { // Creates a dataset representing an eagerly-collected window of elements. 
// @@ -43,6 +44,7 @@ Status NewWindowDataset(std::vector<std::vector<Tensor>> elements, std::vector<PartialTensorShape> output_shapes, DatasetBase** out_dataset); +} // namespace data } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_ diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc index 41bf9d43fe..ac44623ce2 100644 --- a/tensorflow/core/kernels/data/window_dataset_op.cc +++ b/tensorflow/core/kernels/data/window_dataset_op.cc @@ -19,7 +19,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/window_dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -33,22 +33,44 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { int64 window_size = 0; - OP_REQUIRES_OK( - ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size)); + OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "size", &window_size)); OP_REQUIRES( ctx, window_size > 0, errors::InvalidArgument("Window size must be greater than zero.")); - *output = new Dataset(ctx, window_size, input); + int64 window_shift = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<int64>(ctx, "shift", &window_shift)); + OP_REQUIRES( + ctx, window_shift > 0, + errors::InvalidArgument("Window shift must be greater than zero.")); + + int64 window_stride = 0; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<int64>(ctx, "stride", &window_stride)); + OP_REQUIRES( + ctx, window_stride > 0, + errors::InvalidArgument("Window stride must be greater than zero.")); + + bool drop_remainder; + OP_REQUIRES_OK( + ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder)); + + *output = new Dataset(ctx, input, window_size, window_shift, window_stride, + drop_remainder); } private: class Dataset : public DatasetBase { public: - Dataset(OpKernelContext* ctx, int64 window_size, const DatasetBase* input) + Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 window_size, + int64 window_shift, int64 window_stride, bool drop_remainder) : DatasetBase(DatasetContext(ctx)), + input_(input), window_size_(window_size), - input_(input) { + window_shift_(window_shift), + window_stride_(window_stride), + drop_remainder_(drop_remainder) { input_->Ref(); } @@ -72,7 +94,8 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { } string DebugString() const override { - return strings::StrCat("WindowDatasetOp(", window_size_, ")::Dataset"); + return strings::StrCat("WindowDatasetOp(", window_size_, window_shift_, + window_stride_, drop_remainder_, ")::Dataset"); } protected: @@ -81,10 +104,19 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { Node** output) const override { Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); - Node* window_size = nullptr; - TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size)); + Node* window_size_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node)); + Node* window_shift_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node)); + Node* window_stride_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node)); + Node* drop_remainder_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node)); TF_RETURN_IF_ERROR( - b->AddDataset(this, {input_graph_node, 
window_size}, output)); + b->AddDataset(this, + {input_graph_node, window_size_node, window_shift_node, + window_stride_node, drop_remainder_node}, + output)); return Status::OK(); } @@ -101,37 +133,79 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { - // Each row of `window_elements` is a tuple of tensors from the - // input iterator. + const int64 window_size = dataset()->window_size_; + const int64 window_shift = dataset()->window_shift_; + const int64 window_stride = dataset()->window_stride_; std::vector<std::vector<Tensor>> window_elements; + Status status = Status::OK(); { mutex_lock l(mu_); - if (!input_impl_) { + if (!input_impl_ && buffer_.empty()) { *end_of_sequence = true; return Status::OK(); } - window_elements.reserve(dataset()->window_size_); - *end_of_sequence = false; - for (int i = 0; i < dataset()->window_size_ && !*end_of_sequence; - ++i) { - std::vector<Tensor> window_element_tuple; - TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &window_element_tuple, - end_of_sequence)); - if (!*end_of_sequence) { - window_elements.emplace_back(std::move(window_element_tuple)); - } else { - input_impl_.reset(); + + // Add elements to the buffer. + size_t target_size = TargetBufferSize(window_size, window_stride); + if (input_impl_) { + *end_of_sequence = false; + for (size_t i = buffer_.size(); + i < target_size && !*end_of_sequence; ++i) { + std::vector<Tensor> element; + Status status = + input_impl_->GetNext(ctx, &element, end_of_sequence); + if (!*end_of_sequence) { + buffer_.emplace_back(std::move(element), status); + } else { + input_impl_.reset(); + } + } + } + + // If there are not enough elements and `drop_remainder` is set, we do + // not wish to return a smaller window. + if (buffer_.empty() || + (dataset()->drop_remainder_ && buffer_.size() < target_size)) { + DCHECK(*end_of_sequence); + return Status::OK(); + } + + int num_elements = 1 + (buffer_.size() - 1) / window_stride; + window_elements.reserve(num_elements); + for (size_t i = 0; i < num_elements; ++i) { + status.Update(buffer_[window_stride * i].status); + if (!status.ok()) { + break; + } + window_elements.emplace_back(buffer_[window_stride * i].result); + } + + // Shift the window, discarding elements if necessary. + int buffer_size = buffer_.size(); + if (window_shift >= buffer_size) { + for (size_t i = buffer_size; input_impl_ && i < window_shift; ++i) { + bool end_of_input; + std::vector<Tensor> element; + // Ignore non-error status of discarded elements. + input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError(); + if (end_of_input) { + input_impl_.reset(); + } } + buffer_.clear(); + } else { + buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift); } } - if (window_elements.empty()) { - DCHECK(*end_of_sequence); - return Status::OK(); + if (!status.ok()) { + return status; } + // Construct output tensors. 
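To make the buffering arithmetic above concrete (an editorial worked example with assumed parameters): for window_size = 3 and window_stride = 2,

    target_size  = (window_size - 1) * window_stride + 1 = (3 - 1) * 2 + 1 = 5
    num_elements = 1 + (buffer_.size() - 1) / window_stride = 1 + 4 / 2    = 3
    window       = { buffer_[0], buffer_[2], buffer_[4] }

and with window_shift = 2 (less than the buffer size of 5), the erase at the end of the critical section drops buffer_[0] and buffer_[1], so the next window starts two input elements later and reuses the three elements still in the buffer.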
const size_t num_tuple_components = window_elements[0].size(); const int64 num_window_elements = window_elements.size(); + *end_of_sequence = false; for (size_t idx = 0; idx < num_tuple_components; ++idx) { DatasetBase* window_dataset; std::vector<std::vector<Tensor>> window_component_elements; @@ -154,7 +228,6 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(window_dataset, &out_tensors->back())); } - *end_of_sequence = false; return Status::OK(); } @@ -167,6 +240,20 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { } else { TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); } + // Save buffer. + TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"), + buffer_.size())); + for (int64 i = 0; i < buffer_.size(); i++) { + TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status)); + TF_RETURN_IF_ERROR( + writer->WriteScalar(strings::StrCat("buffer[", i, "].size"), + buffer_[i].result.size())); + for (int64 j = 0; j < buffer_[i].result.size(); j++) { + TF_RETURN_IF_ERROR( + writer->WriteTensor(strings::StrCat("buffer[", i, "][", j, "]"), + buffer_[i].result[j])); + } + } return Status::OK(); } @@ -178,22 +265,92 @@ class WindowDatasetOp : public UnaryDatasetOpKernel { } else { input_impl_.reset(); } + // Restore buffer. + int64 buffer_size; + TF_RETURN_IF_ERROR( + reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size)); + buffer_.resize(buffer_size); + for (int64 i = 0; i < buffer_size; i++) { + int64 vector_size; + TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status)); + TF_RETURN_IF_ERROR(reader->ReadScalar( + strings::StrCat("buffer[", i, "].size"), &vector_size)); + buffer_[i].result.resize(vector_size); + for (int64 j = 0; j < vector_size; j++) { + TF_RETURN_IF_ERROR( + reader->ReadTensor(strings::StrCat("buffer[", i, "][", j, "]"), + &buffer_[i].result[j])); + } + } return Status::OK(); } private: + struct InvocationResult { + InvocationResult() = default; + InvocationResult(std::vector<Tensor>&& result, const Status& status) + : result(result), status(status) {} + + std::vector<Tensor> result; + Status status; + }; + + Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, + const Status& status) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + TF_RETURN_IF_ERROR(writer->WriteScalar( + CodeKey(index), static_cast<int64>(status.code()))); + if (!status.ok()) { + TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index), + status.error_message())); + } + return Status::OK(); + } + + Status ReadStatusLocked(IteratorStateReader* reader, size_t index, + Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + int64 code_int; + TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); + error::Code code = static_cast<error::Code>(code_int); + + if (code != error::Code::OK) { + string error_message; + TF_RETURN_IF_ERROR( + reader->ReadScalar(ErrorMessageKey(index), &error_message)); + *status = Status(code, error_message); + } else { + *status = Status::OK(); + } + return Status::OK(); + } + + string CodeKey(size_t index) { + return full_name(strings::StrCat("buffer[", index, "].code")); + } + + string ErrorMessageKey(size_t index) { + return full_name(strings::StrCat("buffer[", index, "].error_message")); + } + + size_t TargetBufferSize(int64 window_size, int64 window_stride) { + return (window_size - 1) * window_stride + 1; + } + mutex mu_; + std::deque<InvocationResult> buffer_ GUARDED_BY(mu_); std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); }; - const int64 
window_size_; const DatasetBase* const input_; + const int64 window_size_; + const int64 window_shift_; + const int64 window_stride_; + const bool drop_remainder_; }; }; REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU), WindowDatasetOp); - } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc index 1c49874a6a..3f76695bb1 100644 --- a/tensorflow/core/kernels/data/writer_ops.cc +++ b/tensorflow/core/kernels/data/writer_ops.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/core/platform/file_system.h" namespace tensorflow { - +namespace data { namespace { class ToTFRecordOp : public AsyncOpKernel { @@ -104,4 +104,5 @@ REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU), ToTFRecordOp); } // namespace +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc index e4306579ed..61a2078f46 100644 --- a/tensorflow/core/kernels/data/zip_dataset_op.cc +++ b/tensorflow/core/kernels/data/zip_dataset_op.cc @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset.h" namespace tensorflow { - +namespace data { namespace { // See documentation in ../ops/dataset_ops.cc for a high-level @@ -175,5 +175,5 @@ class ZipDatasetOp : public DatasetOpKernel { REGISTER_KERNEL_BUILDER(Name("ZipDataset").Device(DEVICE_CPU), ZipDatasetOp); } // namespace - +} // namespace data } // namespace tensorflow diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h index 33ed5522d0..d705e82b0d 100644 --- a/tensorflow/core/kernels/debug_ops.h +++ b/tensorflow/core/kernels/debug_ops.h @@ -255,7 +255,7 @@ class DebugNanCountOp : public BaseDebugOp { TensorShape shape({1}); OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor)); output_tensor->vec<int64>()(0) = nan_count; - PublishTensor(*output_tensor); + OP_REQUIRES_OK(context, PublishTensor(*output_tensor)); } }; @@ -380,7 +380,7 @@ class DebugNumericSummaryOp : public BaseDebugOp { bool mute = mute_if_healthy_ && nan_count == 0 && negative_inf_count == 0 && positive_inf_count == 0; if (!mute) { - PublishTensor(*output_tensor); + OP_REQUIRES_OK(context, PublishTensor(*output_tensor)); } } diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc index b4dcf0a74b..ae451be7e2 100644 --- a/tensorflow/core/kernels/decode_bmp_op.cc +++ b/tensorflow/core/kernels/decode_bmp_op.cc @@ -91,8 +91,10 @@ class DecodeBmpOp : public OpKernel { errors::InvalidArgument( "Number of channels must be 1, 3 or 4, was ", channels_)); - OP_REQUIRES(context, width > 0 && header_size >= 0, + OP_REQUIRES(context, width > 0, errors::InvalidArgument("Width must be positive")); + OP_REQUIRES(context, height != 0, + errors::InvalidArgument("Height must be nonzero")); OP_REQUIRES(context, header_size >= 0, errors::InvalidArgument("header size must be nonnegative")); @@ -108,8 +110,7 @@ class DecodeBmpOp : public OpKernel { const int32 abs_height = abs(height); // there may be padding bytes when the width is not a multiple of 4 bytes - // 8 * channels == bits per pixel - const int row_size = (8 * channels_ * width + 31) / 32 * 4; + const int row_size = (channels_ * width + 3) / 4 * 4; const int64 last_pixel_offset = static_cast<int64>(header_size) + (abs_height - 1) * row_size + diff --git 
a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc index 3eed847c16..6bfb5bd5bc 100644 --- a/tensorflow/core/kernels/decode_csv_op.cc +++ b/tensorflow/core/kernels/decode_csv_op.cc @@ -61,6 +61,9 @@ class DecodeCSVOp : public OpKernel { OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults)); for (int i = 0; i < record_defaults.size(); ++i) { + OP_REQUIRES(ctx, record_defaults[i].dims() <= 1, + errors::InvalidArgument( + "Each record default should be at most rank 1")); OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2, errors::InvalidArgument( "There should only be 1 default per field but field ", i, diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc index b01db91720..fb2a4cc8ef 100644 --- a/tensorflow/core/kernels/dynamic_stitch_op.cc +++ b/tensorflow/core/kernels/dynamic_stitch_op.cc @@ -247,8 +247,8 @@ class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> { data.shaped<T, 2>({indices_vec.dimension(0), slice_size}); if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) { - T* merged_base = &merged_flat(0, 0); - const T* data_base = &data_flat(0, 0); + T* merged_base = merged_flat.data(); + const T* data_base = data_flat.data(); for (int i = 0; i < indices_vec.size(); i++) { int32 index = internal::SubtleMustCopy(indices_vec(i)); OP_REQUIRES( diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h index e13e548f86..8edf7d4a2c 100644 --- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h @@ -51,48 +51,55 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional< internal::traits<OutputBackward>::NumDimensions>, const TensorContractionOp< const array< - IndexPair<typename internal::traits<OutputBackward>::Index>, 2>, - const TensorReshapingOp< + IndexPair<typename internal::traits<OutputBackward>::Index>, 1>, + const Eigen::TensorForcedEvalOp<const TensorReshapingOp< const DSizes<typename internal::traits<OutputBackward>::Index, - 3>, - const TensorReverseOp<const array<bool, 5>, const Kernel> >, + 2>, + const TensorShufflingOp< + const array< + typename internal::traits<OutputBackward>::Index, 5>, + const TensorReverseOp<const Eigen::array<bool, 5>, + const Kernel>>>>, const TensorReshapingOp< const DSizes<typename internal::traits<OutputBackward>::Index, - 3>, + 2>, const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, - const OutputBackward> > > >, + const OutputBackward>>>>, TensorReshapingOp< const DSizes<typename internal::traits<OutputBackward>::Index, internal::traits<OutputBackward>::NumDimensions>, const TensorContractionOp< const array< - IndexPair<typename internal::traits<OutputBackward>::Index>, 2>, + IndexPair<typename internal::traits<OutputBackward>::Index>, 1>, const TensorReshapingOp< const DSizes<typename internal::traits<OutputBackward>::Index, - 3>, + 2>, const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, - const OutputBackward> >, - const TensorReshapingOp< + const OutputBackward>>, + const Eigen::TensorForcedEvalOp<const TensorReshapingOp< const DSizes<typename internal::traits<OutputBackward>::Index, - 3>, - const TensorReverseOp<const array<bool, 5>, - const Kernel> > > > >::type + 2>, + const TensorShufflingOp< + const array< + typename internal::traits<OutputBackward>::Index, 5>, + const TensorReverseOp<const Eigen::array<bool, 5>, + const Kernel>>>>>>>::type 
CuboidConvolutionBackwardInput( const Kernel& kernel, const OutputBackward& output_backward, typename internal::traits<OutputBackward>::Index inputPlanes, typename internal::traits<OutputBackward>::Index inputRows, typename internal::traits<OutputBackward>::Index inputCols, - const DenseIndex stridePlanes = 1, const DenseIndex strideRows = 1, - const DenseIndex strideCols = 1) { + const DenseIndex plane_stride = 1, const DenseIndex row_stride = 1, + const DenseIndex col_stride = 1) { typedef typename internal::traits<OutputBackward>::Index TensorIndex; const TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar, internal::traits<Kernel>::NumDimensions, - internal::traits<Kernel>::Layout, TensorIndex> > + internal::traits<Kernel>::Layout, TensorIndex>> kern(kernel); const TensorRef< const Tensor<typename internal::traits<OutputBackward>::Scalar, internal::traits<OutputBackward>::NumDimensions, - internal::traits<OutputBackward>::Layout, TensorIndex> > + internal::traits<OutputBackward>::Layout, TensorIndex>> out(output_backward); EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout == @@ -125,58 +132,45 @@ CuboidConvolutionBackwardInput( const TensorIndex outputCols = isColMajor ? out.dimensions()[3] : out.dimensions()[NumDims - 4]; - TensorIndex forward_pad_z, forward_pad_y, forward_pad_x; - const TensorIndex size_z = - Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes)); - const TensorIndex size_y = - Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows)); - const TensorIndex size_x = - Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols)); - - // Infer padding type. - if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) { - // SAME padding. - const TensorIndex dz = numext::maxi<TensorIndex>( - 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes); - const TensorIndex dy = numext::maxi<TensorIndex>( - 0, (size_y - 1) * strideRows + kernelRows - inputRows); - const TensorIndex dx = numext::maxi<TensorIndex>( - 0, (size_x - 1) * strideCols + kernelCols - inputCols); - - forward_pad_z = dz / 2; - forward_pad_y = dy / 2; - forward_pad_x = dx / 2; - } else { - // VALID padding. - forward_pad_z = 0; - forward_pad_y = 0; - forward_pad_x = 0; - } - const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z; - const TensorIndex padding_top = kernelRows - 1 - forward_pad_y; - const TensorIndex padding_left = kernelCols - 1 - forward_pad_x; - - const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 - - (outputPlanes - 1) * stridePlanes - 1 - - padding_ztop; - const TensorIndex padding_bottom = inputRows + kernelRows - 1 - - (outputRows - 1) * strideRows - 1 - - padding_top; - const TensorIndex padding_right = inputCols + kernelCols - 1 - - (outputCols - 1) * strideCols - 1 - - padding_left; - - eigen_assert(padding_ztop >= 0); - eigen_assert(padding_zbottom >= 0); + // TODO(ezhulenev): Add support for inflated strides. Without inflated strides + // effective kernel planes/rows/cols are always the same as the kernel itself + // (see eigen_spatial_convolutions for details). + const TensorIndex kernelPlanesEff = kernelPlanes; + const TensorIndex kernelRowsEff = kernelRows; + const TensorIndex kernelColsEff = kernelCols; + + // Computing the forward padding. 
+ const TensorIndex forward_pad_top_z = numext::maxi<Index>( + 0, + ((outputPlanes - 1) * plane_stride + kernelPlanesEff - inputPlanes) / 2); + const TensorIndex forward_pad_top = numext::maxi<Index>( + 0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2); + const TensorIndex forward_pad_left = numext::maxi<Index>( + 0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2); + + const TensorIndex padding_top_z = kernelPlanesEff - 1 - forward_pad_top_z; + const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top; + const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left; + + const TensorIndex padding_bottom_z = inputPlanes - + (outputPlanes - 1) * plane_stride - 2 - + padding_top_z + kernelPlanesEff; + const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride - + 2 - padding_top + kernelRowsEff; + const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride - + 2 - padding_left + kernelColsEff; + + eigen_assert(padding_top_z >= 0); eigen_assert(padding_top >= 0); eigen_assert(padding_left >= 0); + eigen_assert(padding_bottom_z >= 0); eigen_assert(padding_bottom >= 0); eigen_assert(padding_right >= 0); - // The kernel has dimensions filters X channels X patch_planes X patch_rows X - // patch_cols. + // The kernel has dimensions : + // filters x channels x patch_planes x patch_rows x patch_cols. // We need to reverse the kernel along the spatial dimensions. - array<bool, 5> kernel_reverse; + Eigen::array<bool, 5> kernel_reverse; if (isColMajor) { kernel_reverse[0] = false; kernel_reverse[1] = false; @@ -191,15 +185,35 @@ CuboidConvolutionBackwardInput( kernel_reverse[4] = false; } - DSizes<TensorIndex, 3> kernel_dims; + // Reorder the dimensions to: + // filters x patch_planes x patch_rows x patch_cols x channels + array<TensorIndex, 5> kernel_shuffle; if (isColMajor) { - kernel_dims[0] = kernelFilters; - kernel_dims[1] = kernelChannels; - kernel_dims[2] = kernelRows * kernelCols * kernelPlanes; + // From: filters x channels x planes x rows x cols + // To: filters x planes x rows x cols x channels + kernel_shuffle[0] = 0; + kernel_shuffle[1] = 2; + kernel_shuffle[2] = 3; + kernel_shuffle[3] = 4; + kernel_shuffle[4] = 1; } else { - kernel_dims[0] = kernelRows * kernelCols * kernelPlanes; + // From: cols x rows x planes x channels x filters + // To: channels x cols x rows x planes x filters + kernel_shuffle[0] = 3; + kernel_shuffle[1] = 0; + kernel_shuffle[2] = 1; + kernel_shuffle[3] = 2; + kernel_shuffle[4] = 4; + } + + // Collapse the dims + DSizes<TensorIndex, 2> kernel_dims; + if (isColMajor) { + kernel_dims[0] = kernelFilters * kernelPlanes * kernelRows * kernelCols; kernel_dims[1] = kernelChannels; - kernel_dims[2] = kernelFilters; + } else { + kernel_dims[1] = kernelFilters * kernelPlanes * kernelRows * kernelCols; + kernel_dims[0] = kernelChannels; } // The output_backward has dimensions out_depth X out_planes X out_rows X @@ -208,36 +222,32 @@ CuboidConvolutionBackwardInput( // dimensions: // out_depth X (patch_planes * patch_rows * patch_cols) X (input_planes * // input_rows * input_cols * OTHERS) - DSizes<TensorIndex, 3> pre_contract_dims; + DSizes<TensorIndex, 2> pre_contract_dims; if (isColMajor) { - pre_contract_dims[0] = kernelFilters; - pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[2] = inputRows * inputCols * inputPlanes; + pre_contract_dims[0] = + kernelFilters * kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[1] = inputPlanes * inputRows * 
inputCols; for (int i = 4; i < NumDims; ++i) { - pre_contract_dims[2] *= out.dimension(i); + pre_contract_dims[1] *= out.dimension(i); } } else { - pre_contract_dims[2] = kernelFilters; - pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[0] = inputRows * inputCols * inputPlanes; + pre_contract_dims[1] = + kernelFilters * kernelPlanes * kernelRows * kernelCols; + pre_contract_dims[0] = inputPlanes * inputRows * inputCols; for (int i = 0; i < NumDims - 4; ++i) { pre_contract_dims[0] *= out.dimension(i); } } - // We will contract along dimensions (0, 2) in kernel and (0, 1) in - // output_backward, if this is col-major, and - // dimensions (0, 2) in kernel and (1, 2) in output_backward, if this - // row-major. - array<IndexPair<TensorIndex>, 2> contract_dims; + // We will contract along the collapsed dimension that contains the + // kernelFilters, kernelPlanes, kernelRows and kernelCols. + array<IndexPair<TensorIndex>, 1> contract_dims; if (isColMajor) { // col-major: kernel.contract(output.patches) contract_dims[0] = IndexPair<TensorIndex>(0, 0); - contract_dims[1] = IndexPair<TensorIndex>(2, 1); } else { // row-major: output.patches.contract(kernel) - contract_dims[0] = IndexPair<TensorIndex>(1, 0); - contract_dims[1] = IndexPair<TensorIndex>(2, 2); + contract_dims[0] = IndexPair<TensorIndex>(1, 1); } // Post contraction, the dimensions of the input_backprop is @@ -261,40 +271,31 @@ CuboidConvolutionBackwardInput( } } - DSizes<TensorIndex, NumDims> strides; - for (int i = 0; i < NumDims; i++) { - strides[i] = 1; - } - if (isColMajor) { - strides[1] = stridePlanes; - strides[2] = strideRows; - strides[3] = strideCols; - } else { - strides[NumDims - 2] = stridePlanes; - strides[NumDims - 3] = strideRows; - strides[NumDims - 4] = strideCols; - } - return choose( Cond<internal::traits<OutputBackward>::Layout == ColMajor>(), kernel.reverse(kernel_reverse) + .shuffle(kernel_shuffle) .reshape(kernel_dims) + .eval() .contract(output_backward .extract_volume_patches( kernelPlanes, kernelRows, kernelCols, 1, 1, 1, - stridePlanes, strideRows, strideCols, padding_ztop, - padding_zbottom, padding_top, padding_bottom, + plane_stride, row_stride, col_stride, padding_top_z, + padding_bottom_z, padding_top, padding_bottom, padding_left, padding_right) .reshape(pre_contract_dims), contract_dims) .reshape(post_contract_dims), output_backward .extract_volume_patches(kernelPlanes, kernelRows, kernelCols, 1, 1, 1, - stridePlanes, strideRows, strideCols, - padding_ztop, padding_zbottom, padding_top, + plane_stride, row_stride, col_stride, + padding_top_z, padding_bottom_z, padding_top, padding_bottom, padding_left, padding_right) .reshape(pre_contract_dims) - .contract(kernel.reverse(kernel_reverse).reshape(kernel_dims), + .contract(kernel.reverse(kernel_reverse) + .shuffle(kernel_shuffle) + .reshape(kernel_dims) + .eval(), contract_dims) .reshape(post_contract_dims)); } @@ -322,48 +323,69 @@ CuboidConvolutionBackwardInput( */ template <typename OutputBackward, typename Input> EIGEN_ALWAYS_INLINE static const typename internal::conditional< - internal::traits<OutputBackward>::Layout == ColMajor, - const TensorShufflingOp< - const array<typename internal::traits<OutputBackward>::Index, 5>, - const TensorReverseOp< - const array<bool, 5>, - const TensorReshapingOp< - const DSizes<typename internal::traits<OutputBackward>::Index, - 5>, + internal::traits<Input>::Layout == ColMajor, + const TensorReverseOp< + const Eigen::array<typename internal::traits<Input>::Index, + 
internal::traits<Input>::NumDimensions>, + const Eigen::TensorShufflingOp< + const Eigen::array<typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const Eigen::TensorReshapingOp< + const Eigen::DSizes<typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, const TensorContractionOp< const array< - IndexPair<typename internal::traits<Input>::Index>, 2>, - const TensorReshapingOp< + IndexPair<typename internal::traits<Input>::Index>, 1>, + const Eigen::TensorForcedEvalOp<const TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, - 3>, - const Input>, + 2>, + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const OutputBackward>>>, const TensorReshapingOp< - const DSizes< - typename internal::traits<OutputBackward>::Index, - 4>, + const DSizes<typename internal::traits<Input>::Index, + 2>, const TensorVolumePatchOp< Dynamic, Dynamic, Dynamic, - const OutputBackward> > > > > >, - const TensorShufflingOp< - const array<typename internal::traits<OutputBackward>::Index, 5>, - const TensorReverseOp< - const array<bool, 5>, - const TensorReshapingOp< - const DSizes<typename internal::traits<OutputBackward>::Index, - 5>, + const Eigen::TensorForcedEvalOp< + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const Input>>>>>>>>, + const TensorReverseOp< + const Eigen::array<typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const Eigen::TensorShufflingOp< + const Eigen::array<typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const Eigen::TensorReshapingOp< + const Eigen::DSizes<typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, const TensorContractionOp< const array< - IndexPair<typename internal::traits<Input>::Index>, 2>, - const TensorReshapingOp< - const DSizes< - typename internal::traits<OutputBackward>::Index, - 4>, - const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic, - const OutputBackward> >, + IndexPair<typename internal::traits<Input>::Index>, 1>, const TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, - 3>, - const Input> > > > > >::type + 2>, + const TensorVolumePatchOp< + Dynamic, Dynamic, Dynamic, + const Eigen::TensorForcedEvalOp< + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const Input>>>>, + const Eigen::TensorForcedEvalOp<const TensorReshapingOp< + const DSizes<typename internal::traits<Input>::Index, + 2>, + const Eigen::TensorShufflingOp< + const Eigen::array< + typename internal::traits<Input>::Index, + internal::traits<Input>::NumDimensions>, + const OutputBackward>>>>>>>>::type CuboidConvolutionBackwardKernel( const Input& input, const OutputBackward& output_backward, typename internal::traits<Input>::Index kernelPlanes, @@ -374,11 +396,11 @@ CuboidConvolutionBackwardKernel( typedef typename internal::traits<Input>::Index TensorIndex; TensorRef<Tensor<typename internal::traits<Input>::Scalar, internal::traits<Input>::NumDimensions, - internal::traits<Input>::Layout, TensorIndex> > + internal::traits<Input>::Layout, TensorIndex>> in(input); TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar, internal::traits<OutputBackward>::NumDimensions, - 
internal::traits<OutputBackward>::Layout, TensorIndex> > + internal::traits<OutputBackward>::Layout, TensorIndex>> out(output_backward); EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout == @@ -392,6 +414,13 @@ CuboidConvolutionBackwardKernel( internal::traits<OutputBackward>::NumDimensions, YOU_MADE_A_PROGRAMMING_MISTAKE); + // We do not support higher dimensional backward convolutions, or convolutions + // without batch dimension. + // TODO(ezhulenev): Relax this constraint, and turn on tests without batch + // dimension in eigen_backward_cuboid_convolutions_test.cc. + EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 5, + YOU_MADE_A_PROGRAMMING_MISTAKE); + const TensorIndex inputPlanes = isColMajor ? in.dimension(1) : in.dimension(NumDims - 2); const TensorIndex inputRows = @@ -406,213 +435,174 @@ CuboidConvolutionBackwardKernel( const TensorIndex outputCols = isColMajor ? out.dimension(3) : out.dimension(NumDims - 4); + // Number of filters. This is the same as the output depth. const TensorIndex kernelFilters = isColMajor ? out.dimension(0) : out.dimension(NumDims - 1); + // Number of channels. This is the same as the input depth. const TensorIndex kernelChannels = isColMajor ? in.dimension(0) : in.dimension(NumDims - 1); - TensorIndex forward_pad_z, forward_pad_y, forward_pad_x; - const TensorIndex size_z = - Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes)); - const TensorIndex size_y = - Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows)); - const TensorIndex size_x = - Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols)); - - // Infer padding type. - if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) { - // SAME padding. - const TensorIndex dz = numext::maxi<TensorIndex>( - 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes); - const TensorIndex dy = numext::maxi<TensorIndex>( - 0, (size_y - 1) * strideRows + kernelRows - inputRows); - const TensorIndex dx = numext::maxi<TensorIndex>( - 0, (size_x - 1) * strideCols + kernelCols - inputCols); - - forward_pad_z = dz / 2; - forward_pad_y = dy / 2; - forward_pad_x = dx / 2; + // Number of batches in the input tensor. + const TensorIndex batch = + isColMajor ? in.dimension(4) : in.dimension(NumDims - 5); + + // TODO(ezhulenev): Add support for inflated strides. Without inflated strides + // effective kernel planes/rows/cols are always the same as the kernel itself + // (see eigen_spatial_convolutions for details). + const TensorIndex kernelPlanesEff = kernelPlanes; + const TensorIndex kernelRowsEff = kernelRows; + const TensorIndex kernelColsEff = kernelCols; + + // Compute forward padding from input and output_backward dimensions. + const TensorIndex padPlanes = numext::maxi<Index>( + 0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes); + const TensorIndex padRows = numext::maxi<Index>( + 0, (outputRows - 1) * strideRows + kernelRowsEff - inputRows); + const TensorIndex padCols = numext::maxi<Index>( + 0, (outputCols - 1) * strideCols + kernelColsEff - inputCols); + + const TensorIndex padding_top_z = padPlanes / 2; + const TensorIndex padding_top = padRows / 2; + const TensorIndex padding_left = padCols / 2; + + // Compute paddings for output_backward before extracting patches. 
+ const auto expanded_out_planes = (outputPlanes - 1) * stridePlanes + 1; + const auto expanded_out_rows = (outputRows - 1) * strideRows + 1; + const auto expanded_out_cols = (outputCols - 1) * strideCols + 1; + const auto padded_out_planes = inputPlanes + kernelPlanes - 1; + const auto padded_out_rows = inputRows + kernelRows - 1; + const auto padded_out_cols = inputCols + kernelCols - 1; + const auto top_pad_planes = kernelPlanes - 1 - padding_top_z; + const auto top_pad_rows = kernelRows - 1 - padding_top; + const auto left_pad_cols = kernelCols - 1 - padding_left; + const auto bottom_pad_planes = + padded_out_planes - expanded_out_planes - top_pad_planes; + const auto bottom_pad_rows = + padded_out_rows - expanded_out_rows - top_pad_rows; + const auto right_pad_cols = + padded_out_cols - expanded_out_cols - left_pad_cols; + + // Reorder output_backward dimensions. + array<TensorIndex, 5> output_backward_shuffle; + if (isColMajor) { + // From: [out_depth, out_planes, out_rows, out_cols, batch] + // To: [batch, out_planes, out_rows, out_cols, out_depth] + output_backward_shuffle = {4, 1, 2, 3, 0}; } else { - // VALID padding. - forward_pad_z = 0; - forward_pad_y = 0; - forward_pad_x = 0; + // From: [batch, out_cols, out_rows, out_planes, out_depth] + // To: [out_depth, out_cols, out_rows, out_planes, batch] + output_backward_shuffle = {4, 1, 2, 3, 0}; } - const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z; - const TensorIndex padding_top = kernelRows - 1 - forward_pad_y; - const TensorIndex padding_left = kernelCols - 1 - forward_pad_x; - - const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 - - (outputPlanes - 1) * stridePlanes - 1 - - padding_ztop; - const TensorIndex padding_bottom = inputRows + kernelRows - 1 - - (outputRows - 1) * strideRows - 1 - - padding_top; - const TensorIndex padding_right = inputCols + kernelCols - 1 - - (outputCols - 1) * strideCols - 1 - - padding_left; - - eigen_assert(padding_ztop >= 0); - eigen_assert(padding_zbottom >= 0); - eigen_assert(padding_top >= 0); - eigen_assert(padding_left >= 0); - eigen_assert(padding_bottom >= 0); - eigen_assert(padding_right >= 0); - - // The output_backward has dimensions out_depth X out_plaens X out_rows X - // out_cols X OTHERS - // When we extract the image patches from output_backward (with input as the - // kernel), it will have dimensions - // (out_depth) X (input_planes * input_rows * input_cols) X (kernel_planes * - // kernel_rows * kernel_cols) X OTHERS - DSizes<TensorIndex, 4> pre_contract_dims; + // Reorder input dimensions. 
+ array<TensorIndex, 5> input_shuffle; if (isColMajor) { - pre_contract_dims[0] = kernelFilters; - pre_contract_dims[1] = inputRows * inputCols * inputPlanes; - pre_contract_dims[2] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[3] = 1; - for (int i = 4; i < NumDims; ++i) { - pre_contract_dims[3] *= out.dimension(i); - } + // From: [in_depth, in_planes, in_rows, in_cols, batch] + // To: [in_depth, batch, in_planes, in_rows, in_cols] + input_shuffle = {0, 4, 1, 2, 3}; } else { - pre_contract_dims[3] = kernelFilters; - pre_contract_dims[2] = inputRows * inputCols * inputPlanes; - pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes; - pre_contract_dims[0] = 1; - for (int i = 0; i < NumDims - 4; ++i) { - pre_contract_dims[0] *= out.dimension(i); - } + // From: [batch, in_cols, in_rows, in_planes, in_depth] + // To: [in_cols, in_rows, in_planes, batch, in_depth] + input_shuffle = {1, 2, 3, 0, 4}; } - // The input has dimensions in_depth X (input_planes * input_rows * - // input_cols) X OTHERS - DSizes<TensorIndex, 3> input_dims; + // Input is playing the role of a "kernel" in this convolution. + DSizes<TensorIndex, 2> input_dims; if (isColMajor) { input_dims[0] = kernelChannels; - input_dims[1] = inputRows * inputCols * inputPlanes; - input_dims[2] = 1; - for (int i = 4; i < NumDims; ++i) { - input_dims[2] *= in.dimension(i); - } - eigen_assert(input_dims[2] == pre_contract_dims[3]); + input_dims[1] = batch * inputPlanes * inputRows * inputCols; } else { - input_dims[2] = kernelChannels; - input_dims[1] = inputRows * inputCols * inputPlanes; - input_dims[0] = 1; - for (int i = 0; i < NumDims - 4; ++i) { - input_dims[0] *= in.dimension(i); - } - eigen_assert(input_dims[0] == pre_contract_dims[0]); + input_dims[1] = kernelChannels; + input_dims[0] = inputCols * inputRows * inputPlanes * batch; } - // We will contract along dimensions (1, 2) in and (1, 3) in out, if - // this is col-major. - // For row-major, it's dimensions (0, 1) in and (0, 2) in out. - array<IndexPair<TensorIndex>, 2> contract_dims; + // Molds the output of the patch extraction result into a 2D tensor: + // - the first dimension (dims[0]): the patch values to be multiplied with the + // kernels + // - the second dimension (dims[1]): everything else + DSizes<TensorIndex, 2> pre_contract_dims; if (isColMajor) { - // col-major: in.contract(output.patches) - contract_dims[0] = IndexPair<TensorIndex>(1, 1); - contract_dims[1] = IndexPair<TensorIndex>(2, 3); + pre_contract_dims[0] = batch * inputPlanes * inputRows * inputCols; + pre_contract_dims[1] = + kernelPlanes * kernelRows * kernelCols * kernelFilters; } else { - // row-major: output.patches.contract(in) - contract_dims[0] = IndexPair<TensorIndex>(0, 0); - contract_dims[1] = IndexPair<TensorIndex>(2, 1); + pre_contract_dims[1] = inputCols * inputRows * inputPlanes * batch; + pre_contract_dims[0] = + kernelFilters * kernelCols * kernelRows * kernelPlanes; } - // After the contraction, the kernel will have dimension - // in_depth X out_depth X kernel_patches X kernel_rows X kernel_cols - // We will need to shuffle the first two dimensions and reverse the spatial - // dimensions. - // The end shape is: - // out_depth X in_shape X kernel_planes X kernel_rows X kernel_cols + // We will contract along the collapsed dimension that contains the + // batch, inputPlanes, inputRows and inputCols. + array<IndexPair<TensorIndex>, 1> contract_dims; + contract_dims[0] = IndexPair<TensorIndex>(1, 0); - // This is the shape of the kernel *before* the shuffling. 
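The single contraction pair IndexPair(1, 0) above is an ordinary matrix product: dimension 1 of the left operand is summed against dimension 0 of the right one. A minimal sketch with assumed sizes:

// Contraction over IndexPair(1, 0) is a plain matmul; sizes are assumed.
#include <cassert>
#include <unsupported/Eigen/CXX11/Tensor>

void ContractionAsMatmulSketch() {
  Eigen::Tensor<float, 2> kernel(3, 10);   // [kernelChannels, collapsed dim]
  Eigen::Tensor<float, 2> patches(10, 7);  // [collapsed dim, patch index]
  kernel.setRandom();
  patches.setRandom();
  Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 0)};
  // Equivalent to a 3x10 * 10x7 matrix product, giving a [3, 7] result.
  Eigen::Tensor<float, 2> result = kernel.contract(patches, dims);
  assert(result.dimension(0) == 3 && result.dimension(1) == 7);
}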
- DSizes<TensorIndex, 5> kernel_dims; + // Dimensions after contraction. + DSizes<TensorIndex, NumDims> post_contract_dims; if (isColMajor) { - kernel_dims[0] = kernelChannels; - kernel_dims[1] = kernelFilters; - kernel_dims[2] = kernelPlanes; - kernel_dims[3] = kernelRows; - kernel_dims[4] = kernelCols; + post_contract_dims[0] = kernelChannels; + post_contract_dims[1] = kernelPlanes; + post_contract_dims[2] = kernelRows; + post_contract_dims[3] = kernelCols; + post_contract_dims[4] = kernelFilters; } else { - kernel_dims[0] = kernelCols; - kernel_dims[1] = kernelRows; - kernel_dims[2] = kernelPlanes; - kernel_dims[3] = kernelFilters; - kernel_dims[4] = kernelChannels; + post_contract_dims[0] = kernelFilters; + post_contract_dims[1] = kernelCols; + post_contract_dims[2] = kernelRows; + post_contract_dims[3] = kernelPlanes; + post_contract_dims[4] = kernelChannels; } - // Flip filters and channels. + // Reorder output of contraction to valid filter shape. array<TensorIndex, 5> kernel_shuffle; if (isColMajor) { - kernel_shuffle[0] = 1; - kernel_shuffle[1] = 0; - kernel_shuffle[2] = 2; - kernel_shuffle[3] = 3; - kernel_shuffle[4] = 4; + // From: [in_depth, kernel_planes, kernel_rows, kernel_cols, out_depth] + // To: [out_depth, in_depth, kernel_planes, kernel_rows, kernel_cols] + kernel_shuffle = {4, 0, 1, 2, 3}; } else { - kernel_shuffle[0] = 0; - kernel_shuffle[1] = 1; - kernel_shuffle[2] = 2; - kernel_shuffle[3] = 4; - kernel_shuffle[4] = 3; + // From: [out_depth, kernel_cols, kernel_rows, kernel_planes, in_depth] + // To: [kernel_cols, kernel_rows, kernel_planes, in_depth, out_depth] + kernel_shuffle = {1, 2, 3, 4, 0}; } - // Reverse the spatial dimensions. - array<bool, 5> kernel_reverse; + // Reverse kernel backprop dimensions. + array<TensorIndex, 5> kernel_reverse; if (isColMajor) { - kernel_reverse[0] = false; - kernel_reverse[1] = false; - kernel_reverse[2] = true; - kernel_reverse[3] = true; - kernel_reverse[4] = true; + kernel_reverse = {false, false, true, true, true}; } else { - kernel_reverse[0] = true; - kernel_reverse[1] = true; - kernel_reverse[2] = true; - kernel_reverse[3] = false; - kernel_reverse[4] = false; + kernel_reverse = {true, true, true, false, false}; } - DSizes<TensorIndex, NumDims> strides; - for (int i = 0; i < NumDims; i++) { - strides[i] = 1; - } - if (isColMajor) { - strides[1] = stridePlanes; - strides[2] = strideRows; - strides[3] = strideCols; - } else { - strides[NumDims - 2] = stridePlanes; - strides[NumDims - 3] = strideRows; - strides[NumDims - 4] = strideCols; - } - return choose( - Cond<internal::traits<Input>::Layout == ColMajor>(), - input.reshape(input_dims) - .contract(output_backward + // Create convolution input (aka source of patches) from output backward + // tensor by shuffling dimensions. + const auto the_input = + output_backward.shuffle(output_backward_shuffle).eval(); + + // Create convolution kernel (aka filter) from input by shuffling and + // reshaping. 
+ const auto the_kernel = + input.shuffle(input_shuffle).reshape(input_dims).eval(); + + return choose(Cond<internal::traits<Input>::Layout == ColMajor>(), + the_kernel.contract( + the_input .extract_volume_patches( inputPlanes, inputRows, inputCols, 1, 1, 1, stridePlanes, strideRows, strideCols, - - padding_ztop, padding_zbottom, padding_top, - padding_bottom, padding_left, padding_right) + top_pad_planes, bottom_pad_planes, top_pad_rows, + bottom_pad_rows, left_pad_cols, right_pad_cols) .reshape(pre_contract_dims), - contract_dims) - .reshape(kernel_dims) - .reverse(kernel_reverse) - .shuffle(kernel_shuffle), - output_backward - .extract_volume_patches(inputPlanes, inputRows, inputCols, 1, 1, 1, - stridePlanes, strideRows, strideCols, - padding_ztop, padding_zbottom, padding_top, - padding_bottom, padding_left, padding_right) - .reshape(pre_contract_dims) - .contract(input.reshape(input_dims), contract_dims) - .reshape(kernel_dims) - .reverse(kernel_reverse) - .shuffle(kernel_shuffle)); + contract_dims), + the_input + .extract_volume_patches( + inputPlanes, inputRows, inputCols, 1, 1, 1, + stridePlanes, strideRows, strideCols, top_pad_planes, + bottom_pad_planes, top_pad_rows, bottom_pad_rows, + left_pad_cols, right_pad_cols) + .reshape(pre_contract_dims) + .contract(the_kernel, contract_dims)) + .reshape(post_contract_dims) + .shuffle(kernel_shuffle) + .reverse(kernel_reverse); } } // end namespace Eigen diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h index cb0a76dac4..960920c55b 100644 --- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h +++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h @@ -189,14 +189,19 @@ SpatialConvolutionBackwardInput( } #endif - // Reorder the dimensions to filters X patch_rows X patch_cols X channels + // Reorder the dimensions to: + // filters x patch_rows x patch_cols x channels array<TensorIndex, 4> kernel_shuffle; if (isColMajor) { + // From: filters x channels x rows x cols + // To: filters x rows x cols x channels kernel_shuffle[0] = 0; kernel_shuffle[1] = 2; kernel_shuffle[2] = 3; kernel_shuffle[3] = 1; } else { + // From: cols x rows x channels x filters + // To: channels x cols x rows x filters kernel_shuffle[0] = 2; kernel_shuffle[1] = 0; kernel_shuffle[2] = 1; @@ -233,8 +238,8 @@ SpatialConvolutionBackwardInput( } } - // We will contract along the fused dimension that contains the kernelFilters, - // the kernelRows and the kernelCols. + // We will contract along the collapsed dimension that contains the + // kernelFilters, the kernelRows and the kernelCols. 
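A note on contraction index pairs, which is what lets the backward-kernel path below drop its explicit shuffle: each pair names one dimension in each operand, so IndexPair(1, 1) multiplies against the transpose without ever materializing it. A sketch with assumed sizes:

// IndexPair(1, 1) contracts dimension 1 of both operands, i.e. A * B^T.
#include <cassert>
#include <unsupported/Eigen/CXX11/Tensor>

void ContractSecondDimsSketch() {
  Eigen::Tensor<float, 2> a(3, 10);  // 3 x 10
  Eigen::Tensor<float, 2> b(7, 10);  // 7 x 10; the shared dim is dim 1 of both
  a.setRandom();
  b.setRandom();
  Eigen::array<Eigen::IndexPair<int>, 1> dims = {Eigen::IndexPair<int>(1, 1)};
  Eigen::Tensor<float, 2> c = a.contract(b, dims);  // 3 x 7, like A * B^T
  assert(c.dimension(0) == 3 && c.dimension(1) == 7);
}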
array<IndexPair<TensorIndex>, 1> contract_dims; if (isColMajor) { // col-major: kernel.contract(output.patches) @@ -327,23 +332,16 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional< const TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, 2>, const OutputBackward>, - const TensorShufflingOp< - const array<typename internal::traits<OutputBackward>::Index, - 2>, - const TensorReshapingOp< - const DSizes<typename internal::traits<Input>::Index, 2>, - const TensorImagePatchOp<Dynamic, Dynamic, - const Input> > > > >, + const TensorReshapingOp< + const DSizes<typename internal::traits<Input>::Index, 2>, + const TensorImagePatchOp<Dynamic, Dynamic, const Input> > > >, TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, 4>, const TensorContractionOp< const array<IndexPair<typename internal::traits<Input>::Index>, 1>, - const TensorShufflingOp< - const array<typename internal::traits<OutputBackward>::Index, - 2>, - const TensorReshapingOp< - const DSizes<typename internal::traits<Input>::Index, 2>, - const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >, + const TensorReshapingOp< + const DSizes<typename internal::traits<Input>::Index, 2>, + const TensorImagePatchOp<Dynamic, Dynamic, const Input> >, const TensorReshapingOp< const DSizes<typename internal::traits<Input>::Index, 2>, const OutputBackward> > > >::type @@ -451,12 +449,16 @@ SpatialConvolutionBackwardKernel( eigen_assert(output_dims[0] == pre_contract_dims[0]); } - array<TensorIndex, 2> shuffle_dims; - shuffle_dims[0] = 1; - shuffle_dims[1] = 0; - + // We will contract along the collapsed dimension that contains the + // outputCols, outputRows and OTHERS. array<IndexPair<TensorIndex>, 1> contract_dims; - contract_dims[0] = IndexPair<TensorIndex>(1, 0); + if (isColMajor) { + // col-major: output_backward.contract(input.patches) + contract_dims[0] = IndexPair<TensorIndex>(1, 1); + } else { + // row-major: input.patches.contract(output_backward) + contract_dims[0] = IndexPair<TensorIndex>(0, 0); + } // After the contraction, the kernel will have the desired shape // out_depth X in_shape X kernel_rows X kernel_cols @@ -482,8 +484,7 @@ SpatialConvolutionBackwardKernel( kernelRows, kernelCols, row_stride, col_stride, row_in_stride, col_in_stride, 1, 1, padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)) - .reshape(pre_contract_dims) - .shuffle(shuffle_dims), + .reshape(pre_contract_dims), contract_dims) .reshape(kernel_dims), input @@ -492,7 +493,6 @@ SpatialConvolutionBackwardKernel( padding_top, padding_bottom, padding_left, padding_right, OutScalar(0)) .reshape(pre_contract_dims) - .shuffle(shuffle_dims) .contract(output_backward.reshape(output_dims), contract_dims) .reshape(kernel_dims)); } diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc index 2229ec9659..673ec1458b 100644 --- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc +++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc @@ -1248,11 +1248,14 @@ TEST(EigenBackwardSpatialConvolutionsTest, const int output_cols = input_cols - patch_cols + 1; const int output_planes = input_planes - patch_planes + 1; - Tensor<float, 4> input(input_depth, input_planes, input_rows, input_cols); + // TODO(ezhulenev): Support backward kernel convolution without batch + // dimension. 
+ Tensor<float, 5> input(input_depth, input_planes, input_rows, input_cols, + /*num_batches*/ 1); Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows, patch_cols); - Tensor<float, 4> output_backward(output_depth, output_planes, output_rows, - output_cols); + Tensor<float, 5> output_backward(output_depth, output_planes, output_rows, + output_cols, /*num_batches*/ 1); output_backward = output_backward.constant(11.0f) + output_backward.random(); input = input.constant(2.0f) + input.random(); @@ -1282,9 +1285,9 @@ TEST(EigenBackwardSpatialConvolutionsTest, if (output_i >= 0 && output_i < output_planes && output_j >= 0 && output_j < output_rows && output_k >= 0 && output_k < output_cols) { - expected += - input(id, i, j, k) * - output_backward(od, output_i, output_j, output_k); + expected += input(id, i, j, k, /*batch*/ 0) * + output_backward(od, output_i, output_j, + output_k, /*batch*/ 0); } } } @@ -1311,12 +1314,14 @@ TEST(EigenBackwardSpatialConvolutionsTest, const int output_cols = input_cols - patch_cols + 1; const int output_planes = input_planes - patch_planes + 1; - Tensor<float, 4, RowMajor> input(input_cols, input_rows, input_planes, - input_depth); + // TODO(ezhulenev): Support backward kernel convolution without batch + // dimension. + Tensor<float, 5, RowMajor> input(/*num_batches*/ 1, input_cols, input_rows, + input_planes, input_depth); Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes, input_depth, output_depth); - Tensor<float, 4, RowMajor> output_backward(output_cols, output_rows, - output_planes, output_depth); + Tensor<float, 5, RowMajor> output_backward( + /*num_batches*/ 1, output_cols, output_rows, output_planes, output_depth); output_backward = output_backward.constant(11.0f) + output_backward.random(); input = input.constant(2.0f) + input.random(); @@ -1346,9 +1351,9 @@ TEST(EigenBackwardSpatialConvolutionsTest, if (output_i >= 0 && output_i < output_planes && output_j >= 0 && output_j < output_rows && output_k >= 0 && output_k < output_cols) { - expected += - input(k, j, i, id) * - output_backward(output_k, output_j, output_i, od); + expected += input(/*batch*/ 0, k, j, i, id) * + output_backward(/*batch*/ 0, output_k, output_j, + output_i, od); } } } diff --git a/tensorflow/core/kernels/eigen_benchmark.h b/tensorflow/core/kernels/eigen_benchmark.h new file mode 100644 index 0000000000..87e41b89b3 --- /dev/null +++ b/tensorflow/core/kernels/eigen_benchmark.h @@ -0,0 +1,304 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_ +#define TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_ + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h" +#include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h" +#include "tensorflow/core/kernels/eigen_cuboid_convolution.h" +#include "tensorflow/core/kernels/eigen_spatial_convolutions.h" +#include "tensorflow/core/platform/test_benchmark.h" + +using ::tensorflow::TTypes; + +template <typename Scalar, typename Device> +class SpatialConvolutionBenchmarksSuite { + public: + using Input = TTypes<float, 4>::ConstTensor; + using Filter = TTypes<float, 4>::ConstTensor; + using Output = TTypes<float, 4>::Tensor; + + using Dimensions = Eigen::DSizes<Eigen::Index, 4>; + + SpatialConvolutionBenchmarksSuite(int iters, Device& device) + : iters_(iters), device_(device) {} + + Eigen::Index BufferSize(const Dimensions& dims) { + return dims.TotalSize() * sizeof(Scalar); + } + + void SpatialConvolution(Dimensions input_dims, Dimensions filter_dims) { + Dimensions output_dims(input_dims[0], // batch + input_dims[1], // input_height + input_dims[2], // input_width + filter_dims[3]); // filter_count + + Scalar* input_data = + static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); + Scalar* filter_data = + static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims))); + Scalar* output_data = + static_cast<Scalar*>(device_.allocate(BufferSize(output_dims))); + + device_.memset(input_data, 123, BufferSize(input_dims)); + device_.memset(filter_data, 123, BufferSize(filter_dims)); + + Input input(input_data, input_dims); + Filter filter(filter_data, filter_dims); + Output output(output_data, output_dims); + + ::tensorflow::testing::StartTiming(); + for (int i = 0; i < iters_; ++i) { + output.device(device_) = Eigen::SpatialConvolution(input, filter); + tensorflow::testing::DoNotOptimize(output); + } + ::tensorflow::testing::StopTiming(); + + device_.deallocate(input_data); + device_.deallocate(filter_data); + device_.deallocate(output_data); + } + + void SpatialConvolutionBackwardInput(Dimensions input_dims, + Dimensions filter_dims) { + using OutputBackward = TTypes<float, 4>::ConstTensor; + using InputBackward = TTypes<float, 4>::Tensor; + + Dimensions output_dims(input_dims[0], // batch + input_dims[1], // input_height + input_dims[2], // input_width + filter_dims[3]); // filter_count + + // Assuming that the convolution had SAME padding. 
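A short gloss on the SAME-padding assumption used by these benchmark helpers (editorially inferred from the surrounding code, not stated in the change):

// With SAME padding and unit strides, the output spatial size matches the
// input (output_dims above reuses input_dims[1..2]), so the original input
// rows/cols handed to the backward-input convolution can be read directly
// from input_dims.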
+ Eigen::Index input_rows = input_dims[1];
+ Eigen::Index input_cols = input_dims[2];
+
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* input_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Filter filter(filter_data, filter_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ InputBackward input_backward(input_backward_data, input_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ input_backward.device(device_) = Eigen::SpatialConvolutionBackwardInput(
+ filter, output_backward, input_rows, input_cols);
+ tensorflow::testing::DoNotOptimize(input_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(filter_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(input_backward_data);
+ }
+
+ void SpatialConvolutionBackwardKernel(Dimensions input_dims,
+ Dimensions filter_dims) {
+ using OutputBackward = TTypes<float, 4>::ConstTensor;
+ using FilterBackward = TTypes<float, 4>::Tensor;
+
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ filter_dims[3]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index filter_rows = filter_dims[0];
+ Eigen::Index filter_cols = filter_dims[1];
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* filter_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Input input(input_data, input_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ FilterBackward filter_backward(filter_backward_data, filter_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ filter_backward.device(device_) = Eigen::SpatialConvolutionBackwardKernel(
+ input, output_backward, filter_rows, filter_cols);
+ tensorflow::testing::DoNotOptimize(filter_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(filter_backward_data);
+ }
+
+ private:
+ int iters_;
+ Device& device_;
+};
+
+template <typename Scalar, typename Device>
+class CuboidConvolutionBenchmarksSuite {
+ public:
+ using Input = TTypes<float, 5>::ConstTensor;
+ using Filter = TTypes<float, 5>::ConstTensor;
+ using Output = TTypes<float, 5>::Tensor;
+
+ using Dimensions = Eigen::DSizes<Eigen::Index, 5>;
+
+ CuboidConvolutionBenchmarksSuite(int iters, Device& device)
+ : iters_(iters), device_(device) {}
+
+ Eigen::Index BufferSize(const Dimensions& dims) {
+ return dims.TotalSize() * sizeof(Scalar);
+ }
+
+ void CuboidConvolution(Dimensions input_dims, Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ 
Scalar* filter_data = + static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims))); + Scalar* output_data = + static_cast<Scalar*>(device_.allocate(BufferSize(output_dims))); + + device_.memset(input_data, 123, BufferSize(input_dims)); + device_.memset(filter_data, 123, BufferSize(filter_dims)); + + Input input(input_data, input_dims); + Filter filter(filter_data, filter_dims); + Output output(output_data, output_dims); + + ::tensorflow::testing::StartTiming(); + for (int i = 0; i < iters_; ++i) { + output.device(device_) = Eigen::CuboidConvolution(input, filter); + tensorflow::testing::DoNotOptimize(output); + } + ::tensorflow::testing::StopTiming(); + + device_.deallocate(input_data); + device_.deallocate(filter_data); + device_.deallocate(output_data); + } + + void CuboidConvolutionBackwardInput(Dimensions input_dims, + Dimensions filter_dims) { + Dimensions output_dims(input_dims[0], // batch + input_dims[1], // input_height + input_dims[2], // input_width + input_dims[3], // input_planes + filter_dims[4]); // filter_count + + using OutputBackward = TTypes<float, 5>::ConstTensor; + using InputBackward = TTypes<float, 5>::Tensor; + + // Assuming that the convolution had SAME padding. + Eigen::Index input_rows = input_dims[1]; + Eigen::Index input_cols = input_dims[2]; + Eigen::Index input_planes = input_dims[3]; + + Scalar* filter_data = + static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims))); + Scalar* output_backward_data = + static_cast<Scalar*>(device_.allocate(BufferSize(output_dims))); + Scalar* input_backward_data = + static_cast<Scalar*>(device_.allocate(BufferSize(input_dims))); + + device_.memset(filter_data, 123, BufferSize(filter_dims)); + device_.memset(output_backward_data, 123, BufferSize(output_dims)); + + Filter filter(filter_data, filter_dims); + OutputBackward output_backward(output_backward_data, output_dims); + InputBackward input_backward(input_backward_data, input_dims); + + ::tensorflow::testing::StartTiming(); + for (int i = 0; i < iters_; ++i) { + input_backward.device(device_) = Eigen::CuboidConvolutionBackwardInput( + filter, output_backward, input_planes, input_rows, input_cols); + tensorflow::testing::DoNotOptimize(input_backward); + } + ::tensorflow::testing::StopTiming(); + + device_.deallocate(filter_data); + device_.deallocate(output_backward_data); + device_.deallocate(input_backward_data); + } + + void CuboidConvolutionBackwardKernel(Dimensions input_dims, + Dimensions filter_dims) { + using OutputBackward = TTypes<float, 5>::ConstTensor; + using FilterBackward = TTypes<float, 5>::Tensor; + + Dimensions output_dims(input_dims[0], // batch + input_dims[1], // input_height + input_dims[2], // input_width + input_dims[3], // input_planes + filter_dims[4]); // filter_count + + // Assuming that the convolution had SAME padding. 
+ Eigen::Index filter_rows = filter_dims[0];
+ Eigen::Index filter_cols = filter_dims[1];
+ Eigen::Index filter_planes = filter_dims[2];
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* filter_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Input input(input_data, input_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ FilterBackward filter_backward(filter_backward_data, filter_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ filter_backward.device(device_) = Eigen::CuboidConvolutionBackwardKernel(
+ input, output_backward, filter_planes, filter_rows, filter_cols);
+ tensorflow::testing::DoNotOptimize(filter_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(filter_backward_data);
+ }
+
+ private:
+ int iters_;
+ Device& device_;
+};
+
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_ diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc new file mode 100644 index 0000000000..ec949ddc84 --- /dev/null +++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc @@ -0,0 +1,422 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. 
+==============================================================================*/ +#define EIGEN_USE_CUSTOM_THREAD_POOL +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/eigen_benchmark.h" +#include "tensorflow/core/platform/test_benchmark.h" + +#define CREATE_THREAD_POOL(threads) \ + Eigen::ThreadPool tp(threads); \ + Eigen::ThreadPoolDevice device(&tp, threads) + +// -------------------------------------------------------------------------- // +// Spatial Convolutions // +// -------------------------------------------------------------------------- // + +void SpatialConvolution(int iters, int num_threads, + /* Input dimensions: */ + int input_batches, int input_height, int input_width, + int input_depth, + /* Filter (kernel) dimensions: */ + int filter_count, int filter_height, int filter_width) { + ::tensorflow::testing::StopTiming(); + + CREATE_THREAD_POOL(num_threads); + + using Benchmark = + SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>; + auto benchmark = Benchmark(iters, device); + + typename Benchmark::Dimensions input_dims(input_batches, input_height, + input_width, input_depth); + typename Benchmark::Dimensions filter_dims(filter_height, filter_width, + input_depth, filter_count); + + benchmark.SpatialConvolution(input_dims, filter_dims); + + auto num_computed_elements = + (input_dims.TotalSize() / input_depth) * filter_count; + auto flops = + num_computed_elements * (input_depth * filter_height * filter_width); + ::tensorflow::testing::ItemsProcessed(flops * iters); +} + +void SpatialConvolutionBackwardInput(int iters, int num_threads, + /* Input dimensions: */ + int input_batches, int input_height, + int input_width, int input_depth, + /* Filter (kernel) dimensions: */ + int filter_count, int filter_height, + int filter_width) { + ::tensorflow::testing::StopTiming(); + + CREATE_THREAD_POOL(num_threads); + + using Benchmark = + SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>; + auto benchmark = Benchmark(iters, device); + + typename Benchmark::Dimensions input_dims(input_batches, input_height, + input_width, input_depth); + typename Benchmark::Dimensions filter_dims(filter_height, filter_width, + input_depth, filter_count); + + benchmark.SpatialConvolutionBackwardInput(input_dims, filter_dims); + + auto num_computed_elements = input_dims.TotalSize(); + auto flops = + num_computed_elements * (input_depth * filter_height * filter_width); + ::tensorflow::testing::ItemsProcessed(flops * iters); +} + +void SpatialConvolutionBackwardKernel(int iters, int num_threads, + /* Input dimensions: */ + int input_batches, int input_height, + int input_width, int input_depth, + /* Filter (kernel) dimensions: */ + int filter_count, int filter_height, + int filter_width) { + ::tensorflow::testing::StopTiming(); + + CREATE_THREAD_POOL(num_threads); + + using Benchmark = + SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>; + auto benchmark = Benchmark(iters, device); + + typename Benchmark::Dimensions input_dims(input_batches, input_height, + input_width, input_depth); + typename Benchmark::Dimensions filter_dims(filter_height, filter_width, + input_depth, filter_count); + + benchmark.SpatialConvolutionBackwardKernel(input_dims, filter_dims); + + auto num_computed_elements = filter_dims.TotalSize(); + auto flops = + num_computed_elements * (input_batches * input_height * input_width); + ::tensorflow::testing::ItemsProcessed(flops * iters); +} + +// Macro arguments 
names: --------------------------------------------------- // +// NT: num threads +// N: batch size +// H: height +// W: width +// C: channels +// FC: filter count +// FH: filter height +// FW: filter width + +#define BM_SPATIAL_NAME(prefix, NT, N, H, W, C, FC, FH, FW) \ + BM_##prefix##_CPU_##NT##T_in_##N##_##H##_##W##_##C##_f_##FC##_##FH##_##FW + +#define BM_SpatialConvolution(NT, N, H, W, C, FC, FH, FW, LABEL) \ + static void BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, \ + FW)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ + SpatialConvolution(iters, NT, N, H, W, C, FC, FH, FW); \ + } \ + BENCHMARK(BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, FW)) + +#define BM_SpatialConvolutionBwdInput(NT, N, H, W, C, FC, FH, FW, LABEL) \ + static void BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, \ + FH, FW)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ + SpatialConvolutionBackwardInput(iters, NT, N, H, W, C, FC, FH, FW); \ + } \ + BENCHMARK( \ + BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, FH, FW)) + +#define BM_SpatialConvolutionBwdKernel(NT, N, H, W, C, FC, FH, FW, LABEL) \ + static void BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \ + FH, FW)(int iters) { \ + ::tensorflow::testing::SetLabel(LABEL); \ + SpatialConvolutionBackwardKernel(iters, NT, N, H, W, C, FC, FH, FW); \ + } \ + BENCHMARK(BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \ + FH, FW)) + +#define BM_SpatialConvolutions(N, H, W, C, FC, FH, FW, LABEL) \ + BM_SpatialConvolution(2, N, H, W, C, FC, FH, FW, LABEL); \ + BM_SpatialConvolution(4, N, H, W, C, FC, FH, FW, LABEL); \ + BM_SpatialConvolution(8, N, H, W, C, FC, FH, FW, LABEL); \ + BM_SpatialConvolution(16, N, H, W, C, FC, FH, FW, LABEL); + +#define BM_SpatialConvolutionsBwdInput(N, H, W, C, FC, FH, FW, LABEL) \ + BM_SpatialConvolutionBwdInput(2, N, H, W, C, FC, FH, FW, LABEL); \ + BM_SpatialConvolutionBwdInput(4, N, H, W, C, FC, FH, FW, LABEL); \ + BM_SpatialConvolutionBwdInput(8, N, H, W, C, FC, FH, FW, LABEL); \ + BM_SpatialConvolutionBwdInput(16, N, H, W, C, FC, FH, FW, LABEL); + +#define BM_SpatialConvolutionsBwdKernel(N, H, W, C, FC, FH, FW, LABEL) \ + BM_SpatialConvolutionBwdKernel(2, N, H, W, C, FC, FH, FW, LABEL); \ + BM_SpatialConvolutionBwdKernel(4, N, H, W, C, FC, FH, FW, LABEL); \ + BM_SpatialConvolutionBwdKernel(8, N, H, W, C, FC, FH, FW, LABEL); \ + BM_SpatialConvolutionBwdKernel(16, N, H, W, C, FC, FH, FW, LABEL); + +// ImageNet Forward Convolutions -------------------------------------------- // + +BM_SpatialConvolutions(32, // batch size + 56, 56, 64, // input: height, width, depth + 192, 3, 3, // filter: count, height, width + "conv2_00"); + +BM_SpatialConvolutions(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3"); +BM_SpatialConvolutions(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5"); +BM_SpatialConvolutions(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3"); +BM_SpatialConvolutions(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5"); +BM_SpatialConvolutions(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3"); +BM_SpatialConvolutions(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5"); +BM_SpatialConvolutions(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3"); +BM_SpatialConvolutions(32, 14, 14, 24, 64, 5, 5, + "conv4b_00_5x5 / conv4c_00_5x5"); +BM_SpatialConvolutions(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3"); +BM_SpatialConvolutions(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3"); +BM_SpatialConvolutions(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5"); +BM_SpatialConvolutions(32, 14, 14, 
160, 320, 3, 3, "conv4_00_3x3"); +BM_SpatialConvolutions(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5"); +BM_SpatialConvolutions(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3"); +BM_SpatialConvolutions(32, 7, 7, 48, 128, 5, 5, "conv5a_00_5x5 / conv5_00_5x5"); +BM_SpatialConvolutions(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3"); + +// Benchmarks from https://github.com/soumith/convnet-benchmarks +BM_SpatialConvolutions(128, 128, 128, 3, 96, 11, 11, "convnet-layer1"); +BM_SpatialConvolutions(128, 64, 64, 64, 128, 9, 9, "convnet-layer2"); +BM_SpatialConvolutions(128, 32, 32, 128, 128, 9, 9, "convnet-layer3"); +BM_SpatialConvolutions(128, 16, 16, 128, 128, 7, 7, "convnet-layer4"); +BM_SpatialConvolutions(128, 13, 13, 384, 384, 3, 3, "convnet-layer5"); + +// ImageNet BackwardInput Convolutions -------------------------------------- // + +BM_SpatialConvolutionsBwdInput(32, 56, 56, 64, 192, 3, 3, "conv2_00"); +BM_SpatialConvolutionsBwdInput(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3"); +BM_SpatialConvolutionsBwdInput(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5"); +BM_SpatialConvolutionsBwdInput(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3"); +BM_SpatialConvolutionsBwdInput(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5"); +BM_SpatialConvolutionsBwdInput(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3"); +BM_SpatialConvolutionsBwdInput(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5"); +BM_SpatialConvolutionsBwdInput(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3"); +BM_SpatialConvolutionsBwdInput(32, 14, 14, 24, 64, 5, 5, + "conv4b_00_5x5 / conv4c_00_5x5"); +BM_SpatialConvolutionsBwdInput(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3"); +BM_SpatialConvolutionsBwdInput(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3"); +BM_SpatialConvolutionsBwdInput(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5"); +BM_SpatialConvolutionsBwdInput(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3"); +BM_SpatialConvolutionsBwdInput(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5"); +BM_SpatialConvolutionsBwdInput(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3"); +BM_SpatialConvolutionsBwdInput(32, 7, 7, 48, 128, 5, 5, + "conv5a_00_5x5 / conv5_00_5x5"); +BM_SpatialConvolutionsBwdInput(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3"); + +// ImageNet BackwardKernel Convolutions ------------------------------------- // + +BM_SpatialConvolutionsBwdKernel(32, 56, 56, 64, 192, 3, 3, "conv2_00"); +BM_SpatialConvolutionsBwdKernel(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3"); +BM_SpatialConvolutionsBwdKernel(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5"); +BM_SpatialConvolutionsBwdKernel(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3"); +BM_SpatialConvolutionsBwdKernel(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5"); +BM_SpatialConvolutionsBwdKernel(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3"); +BM_SpatialConvolutionsBwdKernel(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5"); +BM_SpatialConvolutionsBwdKernel(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3"); +BM_SpatialConvolutionsBwdKernel(32, 14, 14, 24, 64, 5, 5, + "conv4b_00_5x5 / conv4c_00_5x5"); +BM_SpatialConvolutionsBwdKernel(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3"); +BM_SpatialConvolutionsBwdKernel(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3"); +BM_SpatialConvolutionsBwdKernel(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5"); +BM_SpatialConvolutionsBwdKernel(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3"); +BM_SpatialConvolutionsBwdKernel(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5"); +BM_SpatialConvolutionsBwdKernel(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3"); +BM_SpatialConvolutionsBwdKernel(32, 7, 7, 48, 128, 5, 5, + "conv5a_00_5x5 / conv5_00_5x5"); 
+BM_SpatialConvolutionsBwdKernel(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3"); + +// -------------------------------------------------------------------------- // +// Cuboid Convolutions // +// -------------------------------------------------------------------------- // + +void CuboidConvolution(int iters, int num_threads, + /* Input dimensions: */ + int input_batches, int input_height, int input_width, + int input_planes, int input_depth, + /* Filter (kernel) dimensions: */ + int filter_count, int filter_height, int filter_width, + int filter_planes) { + ::tensorflow::testing::StopTiming(); + + CREATE_THREAD_POOL(num_threads); + + using Benchmark = + CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>; + auto benchmark = Benchmark(iters, device); + + typename Benchmark::Dimensions input_dims( + input_batches, input_height, input_width, input_planes, input_depth); + typename Benchmark::Dimensions filter_dims( + filter_height, filter_width, filter_planes, input_depth, filter_count); + + benchmark.CuboidConvolution(input_dims, filter_dims); + + auto num_computed_elements = + (input_dims.TotalSize() / input_depth) * filter_count; + auto flops = num_computed_elements * + (input_depth * filter_height * filter_width * filter_planes); + ::tensorflow::testing::ItemsProcessed(flops * iters); +} + +void CuboidConvolutionBackwardInput(int iters, int num_threads, + /* Input dimensions: */ + int input_batches, int input_height, + int input_width, int input_planes, + int input_depth, + /* Filter (kernel) dimensions: */ + int filter_count, int filter_height, + int filter_width, int filter_planes) { + ::tensorflow::testing::StopTiming(); + + CREATE_THREAD_POOL(num_threads); + + using Benchmark = + CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>; + auto benchmark = Benchmark(iters, device); + + typename Benchmark::Dimensions input_dims( + input_batches, input_height, input_width, input_planes, input_depth); + typename Benchmark::Dimensions filter_dims( + filter_height, filter_width, filter_planes, input_depth, filter_count); + + benchmark.CuboidConvolutionBackwardInput(input_dims, filter_dims); + + auto num_computed_elements = input_dims.TotalSize(); + auto flops = num_computed_elements * + (input_depth * filter_height * filter_width * filter_planes); + ::tensorflow::testing::ItemsProcessed(flops * iters); +} + +void CuboidConvolutionBackwardKernel(int iters, int num_threads, + /* Input dimensions: */ + int input_batches, int input_height, + int input_width, int input_planes, + int input_depth, + /* Filter (kernel) dimensions: */ + int filter_count, int filter_height, + int filter_width, int filter_planes) { + ::tensorflow::testing::StopTiming(); + + CREATE_THREAD_POOL(num_threads); + + using Benchmark = + CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>; + auto benchmark = Benchmark(iters, device); + + typename Benchmark::Dimensions input_dims( + input_batches, input_height, input_width, input_planes, input_depth); + typename Benchmark::Dimensions filter_dims( + filter_height, filter_width, filter_planes, input_depth, filter_count); + + benchmark.CuboidConvolutionBackwardKernel(input_dims, filter_dims); + + auto num_computed_elements = filter_dims.TotalSize(); + auto flops = num_computed_elements * + (input_batches * input_height * input_width * input_planes); + ::tensorflow::testing::ItemsProcessed(flops * iters); +} + +// Macro arguments names: --------------------------------------------------- // +// NT: num threads +// N: batch size +// H: height +// W: 
width
+// P: planes
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+// FP: filter planes
+
+#define BM_CONCAT(a, b) a##b
+
+#define BM_CUBOID_NAME(p, NT, N, H, W, P, C, FC, FH, FW, FP) \
+ BM_CONCAT(BM_##p##_CPU_##NT##T_in_##N##_##H##_##W##_##P##_##C, \
+ _f_##FC##_##FH##_##FW##_##FP)
+
+#define BM_CuboidConvolution(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, \
+ FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ CuboidConvolution(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK( \
+ BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, FP))
+
+#define BM_CuboidConvolutionBwdInput(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
+ FH, FW, FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ CuboidConvolutionBackwardInput(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
+ FH, FW, FP))
+
+#define BM_CuboidConvolutionBwdKernel(NT, N, H, W, P, C, FC, FH, FW, FP, \
+ LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, \
+ FC, FH, FW, FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ CuboidConvolutionBackwardKernel(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, FC, \
+ FH, FW, FP))
+
+#define BM_CuboidConvolutions(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolution(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+#define BM_CuboidConvolutionsBwdInput(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolutionBwdInput(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+#define BM_CuboidConvolutionsBwdKernel(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolutionBwdKernel(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+// Random Cuboid Convolutions ----------------------------------------------- //
+// TODO(ezhulenev): find representative dims for cuboid convolutions (find
+// models using Conv3D ops). 
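As a sanity check on the ItemsProcessed() accounting above, the "conv3d_depth4" configuration below works out to roughly 1e9 flops per iteration (editorial arithmetic, not part of the change):

// num_computed_elements = (8 * 25 * 25 * 25 * 4) / 4 * 16 = 2,000,000 output
// coefficients, each costing input_depth * fh * fw * fp = 4 * 5 * 5 * 5 = 500
// multiply-adds, i.e. about 1e9 flops per iteration.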
+
+BM_CuboidConvolutions(8, // batch size
+ 25, 25, 25, 4, // input: height, width, planes, depth
+ 16, 5, 5, 5, // filter: count, height, width, planes
+ "conv3d_depth4");
+BM_CuboidConvolutions(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutions(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutions(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
+
+BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4");
+BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutionsBwdInput(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutionsBwdInput(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
+
+BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4");
+BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutionsBwdKernel(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutionsBwdKernel(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2"); diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h index 62e9f9123d..c41fbc42d3 100644 --- a/tensorflow/core/kernels/eigen_cuboid_convolution.h +++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h @@ -21,6 +21,1362 @@ limitations under the License. namespace Eigen {
+namespace internal {
+
+// WARNING: Most of the code here implicitly assumes that the matrix is in
+// ColMajor layout. This is guaranteed by the tensor contraction (see
+// TensorContraction.h).
+//
+// Inside Eigen a tensor contraction is represented by a matrix multiplication.
+// We don't want to actually extract volume patches and reshape the result into
+// a matrix (this involves allocating huge extra memory), so the patch
+// extraction and reshape operations are implicit.
+//
+// TensorContractionInputMapper takes a matrix index and returns the coefficient
+// (or the packet) of the "virtual tensor" that would be at that index if we
+// were to actually reshape the result of patch extraction.
+//
+// TensorContractionSubMapper provides a similar view into the "virtual matrix"
+// at the given vertical and horizontal offsets.
+//
+// "Virtual matrix" dimensions:
+// *0: kernelChannels * kernelPlanes * kernelRows * kernelCols;
+// 1: out_planes * out_height * out_width; * OTHERS (e.g. batches, etc...) 
+//
+// *) extracted patches are contiguous in memory (innermost dimension assuming
+// col major layout)
+//
+// With these dimensions:
+// row - offset within a single patch (in code: patchId)
+// col - index of the extracted patch (in code: patchIndex)
+// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
+//
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar_,
+ typename Index, typename nocontract_t, typename contract_t, int Side,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment>
+class TensorContractionInputMapper<
+ Scalar_, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<NewDimension,
+ const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment> {
+ public:
+ typedef Scalar_ Scalar;
+ typedef TensorContractionInputMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ Self;
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+ typedef SubMapper VectorMapper;
+ typedef SubMapper LinearMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_DEVICE_FUNC
+ TensorContractionInputMapper(
+ const TensorEvaluator<
+ const TensorReshapingOp<
+ NewDimension,
+ const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >,
+ Device>& tensor,
+ const nocontract_t&, const nocontract_t&, const contract_t&,
+ const contract_t&)
+ : m_impl(tensor.impl().impl()) {
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ m_patch_depth = tensor.impl().dimensions()[0];
+ m_patch_planes = tensor.impl().dimensions()[1];
+ m_patch_rows = tensor.impl().dimensions()[2];
+ m_patch_cols = tensor.impl().dimensions()[3];
+ m_num_patches = tensor.impl().dimensions()[4];
+ } else {
+ const int NumDims = tensor.impl().dimensions().size();
+ m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
+ m_patch_planes = tensor.impl().dimensions()[NumDims - 2];
+ m_patch_rows = tensor.impl().dimensions()[NumDims - 3];
+ m_patch_cols = tensor.impl().dimensions()[NumDims - 4];
+ m_num_patches = tensor.impl().dimensions()[NumDims - 5];
+ }
+
+ // Strides for the output tensor.
+ // IMPORTANT: These strides are used to locate an element in a patch at
+ // depth zero (channel), which is not quite the same as "traditional"
+ // stride. 
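A worked instance of these strides may help; the sizes below are assumed, not taken from the change:

// Example: patch_depth = 4 (channels) and a 3x3x3 patch (planes x rows x cols)
// gives, via the assignments just below:
//   m_rowStride   = 3            (one step along rows skips 3 planes)
//   m_colStride   = 3 * 3 = 9
//   m_patchStride = 9 * 3 * 4 = 108   (whole patch, including depth)
// In loadCoeff() further down, patchOffset = patchId / patch_depth indexes the
// 27-element planes x rows x cols lattice, which m_colStride and m_rowStride
// decompose into (planeOffset, rowOffset, colOffset).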
+ m_rowStride = m_patch_planes;
+ m_colStride = m_patch_rows * m_rowStride;
+ m_patchStride = m_colStride * m_patch_cols * m_patch_depth;
+ m_otherStride = m_patchStride * m_num_patches;
+
+ m_outputPlanes = tensor.impl().outputPlanes();
+ m_outputRows = tensor.impl().outputRows();
+ m_outputCols = tensor.impl().outputCols();
+
+ m_outputPlanesRows = m_outputPlanes * m_outputRows;
+
+ m_plane_strides = tensor.impl().userPlaneStride();
+ m_row_strides = tensor.impl().userRowStride();
+ m_col_strides = tensor.impl().userColStride();
+
+ m_in_plane_strides = tensor.impl().userInPlaneStride();
+ m_in_row_strides = tensor.impl().userInRowStride();
+ m_in_col_strides = tensor.impl().userInColStride();
+
+ m_patch_plane_inflate_strides = tensor.impl().planeInflateStride();
+ m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
+ m_patch_col_inflate_strides = tensor.impl().colInflateStride();
+
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ m_inputDepth = tensor.impl().impl().dimensions()[0];
+ m_inputPlanes = tensor.impl().impl().dimensions()[1];
+ m_inputRows = tensor.impl().impl().dimensions()[2];
+ m_inputCols = tensor.impl().impl().dimensions()[3];
+ } else {
+ const int NumDims = tensor.impl().impl().dimensions().size();
+ m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1];
+ m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2];
+ m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3];
+ m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4];
+ }
+
+ // Strides for navigating through the input tensor.
+ m_planeInputStride = m_inputDepth;
+ m_rowInputStride = m_inputDepth * m_inputPlanes;
+ m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
+ m_patchInputStride =
+ m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;
+
+ m_planePaddingTop = tensor.impl().planePaddingTop();
+ m_rowPaddingTop = tensor.impl().rowPaddingTop();
+ m_colPaddingLeft = tensor.impl().colPaddingLeft();
+
+ m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
+
+ m_fastInputPlaneStride =
+ internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
+ m_fastInputRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
+ m_fastInputColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
+
+ m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
+ m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
+
+ m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);
+ m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
+ m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
+ m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols);
+
+ m_fastOutputPlanesRows =
+ internal::TensorIntDivisor<Index>(m_outputPlanesRows);
+ }
+
+ EIGEN_DEVICE_FUNC
+ TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
+ : m_impl(base_mapper.m_impl) {
+ m_patch_depth = base_mapper.m_patch_depth;
+ m_patch_planes = base_mapper.m_patch_planes;
+ m_patch_rows = base_mapper.m_patch_rows;
+ m_patch_cols = base_mapper.m_patch_cols;
+ m_num_patches = base_mapper.m_num_patches;
+
+ m_rowStride = base_mapper.m_rowStride;
+ m_colStride = base_mapper.m_colStride;
+ m_patchStride = base_mapper.m_patchStride;
+ m_otherStride = base_mapper.m_otherStride;
+
+ m_planeInputStride = base_mapper.m_planeInputStride;
+ m_rowInputStride = base_mapper.m_rowInputStride;
+ 
m_colInputStride = base_mapper.m_colInputStride;
+ m_patchInputStride = base_mapper.m_patchInputStride;
+ m_otherInputStride = base_mapper.m_otherInputStride;
+
+ m_inputDepth = base_mapper.m_inputDepth;
+ m_inputPlanes = base_mapper.m_inputPlanes;
+ m_inputRows = base_mapper.m_inputRows;
+ m_inputCols = base_mapper.m_inputCols;
+
+ m_outputPlanes = base_mapper.m_outputPlanes;
+ m_outputRows = base_mapper.m_outputRows;
+ m_outputCols = base_mapper.m_outputCols;
+
+ m_plane_strides = base_mapper.m_plane_strides;
+ m_row_strides = base_mapper.m_row_strides;
+ m_col_strides = base_mapper.m_col_strides;
+
+ m_in_plane_strides = base_mapper.m_in_plane_strides;
+ m_in_row_strides = base_mapper.m_in_row_strides;
+ m_in_col_strides = base_mapper.m_in_col_strides;
+
+ m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides;
+ m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
+ m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
+
+ m_planePaddingTop = base_mapper.m_planePaddingTop;
+ m_rowPaddingTop = base_mapper.m_rowPaddingTop;
+ m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+
+ m_outputPlanesRows = base_mapper.m_outputPlanesRows;
+
+ m_fastNumPatches = base_mapper.m_fastNumPatches;
+ m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
+ m_fastInputRowStride = base_mapper.m_fastInputRowStride;
+ m_fastInputColStride = base_mapper.m_fastInputColStride;
+ m_fastRowStride = base_mapper.m_fastRowStride;
+ m_fastColStride = base_mapper.m_fastColStride;
+ m_fastOutputPlanes = base_mapper.m_fastOutputPlanes;
+ m_fastOutputRows = base_mapper.m_fastOutputRows;
+ m_fastOutputCols = base_mapper.m_fastOutputCols;
+ m_fastDimZero = base_mapper.m_fastDimZero;
+ m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows;
+ }
+
+ // If true, turns off some optimizations for loading packets since the image
+ // patches are "non-standard", e.g. there are non-trivial strides or
+ // inflations in the input.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+ return m_in_plane_strides != 1 || m_in_row_strides != 1 ||
+ m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 ||
+ m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
+ return SubMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
+ return LinearMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ // Load the coefficient at the patchIndex location instead of the usual
+ // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
+ // gpu code. 
+ EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); + return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex); + return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + // Load the packet at the patchIndex location instead of the usual m_rowIndex, + // m_colIndex, m_otherIndex. This is currently only used by the gpu code. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const { + Index planeIndex, rowIndex, colIndex, otherIndex; + computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex); + return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const { + return m_impl; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_patch_depth; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_patch_planes; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, + const Index baseIndex) const { + const Index inputIndex = depth + baseIndex; + return m_impl.template packet<Unaligned>(inputIndex); + } + + private: + friend class TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment>; + + // Load coefficient from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + // Find the offset of the element wrt the location of the first element. + const Index patchOffset = patchId / m_fastDimZero; + + const Index colOffset = patchOffset / m_fastColStride; + const Index inputCol = colIndex + colOffset * m_in_col_strides; + const Index origInputCol = + (m_patch_col_inflate_strides == 1) + ? inputCol + : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0); + + const Index rowOffset = + (patchOffset - colOffset * m_colStride) / m_fastRowStride; + const Index inputRow = rowIndex + rowOffset * m_in_row_strides; + const Index origInputRow = + (m_patch_row_inflate_strides == 1) + ? inputRow + : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0); + + const Index planeOffset = + patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; + const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides; + const Index origInputPlane = + (m_patch_plane_inflate_strides == 1) + ? inputPlane + : ((inputPlane >= 0) ? 
(inputPlane / m_fastInputPlaneStride) : 0); + + if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 || + origInputCol >= m_inputCols || origInputRow >= m_inputRows || + origInputPlane >= m_inputPlanes || + (inputCol != origInputCol * m_patch_col_inflate_strides) || + (inputRow != origInputRow * m_patch_row_inflate_strides) || + (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) { + return Scalar(0); + } + + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + origInputPlane * m_planeInputStride + + origInputRow * m_rowInputStride + + origInputCol * m_colInputStride + otherIndex; + + return m_impl.coeff(inputIndex); + } + + // This is the same as loadCoeff(...), but optimized for all `inflate_strides` + // and `in_strides` equal to 1 (template specialization without templates). + EIGEN_DEVICE_FUNC + EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + eigen_assert(!nonStandardPatches()); + + // Find the offset of the element wrt the location of the first element. + const Index patchOffset = patchId / m_fastDimZero; + + const Index colOffset = patchOffset / m_fastColStride; + const Index inputCol = colIndex + colOffset; + + const Index rowOffset = + (patchOffset - colOffset * m_colStride) / m_fastRowStride; + const Index inputRow = rowIndex + rowOffset; + + const Index planeOffset = + patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; + const Index inputPlane = planeIndex + planeOffset; + + if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 || + inputRow >= m_inputRows || inputPlane < 0 || + inputPlane >= m_inputPlanes) { + return Scalar(0); + } + + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + inputPlane * m_planeInputStride + + inputRow * m_rowInputStride + + inputCol * m_colInputStride + otherIndex; + + return m_impl.coeff(inputIndex); + } + + // Load packet from a patch specified by the "within patch offset" + // (patchId) and the precomputed indices of the first element of the patch. + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + const Index packetSize = internal::unpacket_traits<Packet>::size; + + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < + patchDepth() * patchPlanes() * patchRows() * patchCols()); + + if (nonStandardPatches()) { + return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } + return loadPacketStandard(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + const Index packetSize = internal::unpacket_traits<Packet>::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < + patchDepth() * patchPlanes() * patchRows() * patchCols()); + eigen_assert(!nonStandardPatches()); + + if ((patchDepth() % packetSize) == 0) { + return loadPacketFast(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } else { + // Offsets and input calculation here are identical to + // loadCoeffStandard(...), but repeated twice. 
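+ // The offsets below are computed for the first and the last coefficient + of the packet: if both endpoints fall into the same input column, row and + plane, the whole packet lies in one contiguous span of the input and a + single unaligned packet load suffices; a fully padded span yields a zero + packet, and every remaining case falls back to the element-wise + packetWithPossibleZero(...) path.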
+ + const Index patchOffsets[2] = { + patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero}; + + const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride, + patchOffsets[1] / m_fastColStride}; + eigen_assert(colOffsets[0] <= colOffsets[1]); + + const Index inputCols[2] = {colIndex + colOffsets[0], + colIndex + colOffsets[1]}; + if (inputCols[0] >= m_inputCols || inputCols[1] < 0) { + return internal::pset1<Packet>(Scalar(0)); + } + + if (inputCols[0] == inputCols[1]) { + const Index rowOffsets[2] = { + (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride, + (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride}; + eigen_assert(rowOffsets[0] <= rowOffsets[1]); + const Index inputRows[2] = {rowIndex + rowOffsets[0], + rowIndex + rowOffsets[1]}; + + if (inputRows[0] >= m_inputRows || inputRows[1] < 0) { + return internal::pset1<Packet>(Scalar(0)); + } + + if (inputRows[0] == inputRows[1]) { + const Index planeOffsets[2] = { + patchOffsets[0] - colOffsets[0] * m_colStride - + rowOffsets[0] * m_rowStride, + patchOffsets[1] - colOffsets[1] * m_colStride - + rowOffsets[1] * m_rowStride}; + eigen_assert(planeOffsets[0] <= planeOffsets[1]); + const Index inputPlanes[2] = {planeIndex + planeOffsets[0], + planeIndex + planeOffsets[1]}; + + if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) { + return internal::pset1<Packet>(Scalar(0)); + } + + if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) { + const Index depth = patchId - patchOffsets[0] * patchDepth(); + const Index inputIndex = + depth + inputPlanes[0] * m_planeInputStride + + inputRows[0] * m_rowInputStride + + inputCols[0] * m_colInputStride + otherIndex; + return m_impl.template packet<Unaligned>(inputIndex); + } + } + } + } + + return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex, + otherIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex, + Index rowIndex, Index colIndex, + Index otherIndex) const { + const Index packetSize = internal::unpacket_traits<Packet>::size; + EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE) + eigen_assert(patchId < + patchDepth() * patchPlanes() * patchRows() * patchCols()); + + eigen_assert(!nonStandardPatches()); + eigen_assert((patchDepth() % packetSize) == 0); + + // Find the offset of the element wrt the location of the first element. 
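+ // Because patchDepth() is a multiple of the packet size here, all + packetSize coefficients starting at patchId share the same plane, row and + column of the patch, so a single offset computation covers the whole + packet (the eigen_assert below checks this invariant).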
+ const Index patchOffset = patchId / m_fastDimZero; + eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset); + + const Index colOffset = patchOffset / m_fastColStride; + const Index inputCol = colIndex + colOffset; + const Index rowOffset = + (patchOffset - colOffset * m_colStride) / m_fastRowStride; + const Index inputRow = rowIndex + rowOffset; + const Index planeOffset = + patchOffset - colOffset * m_colStride - rowOffset * m_rowStride; + const Index inputPlane = planeIndex + planeOffset; + + if (inputCol < 0 || inputRow < 0 || inputPlane < 0 || + inputCol >= m_inputCols || inputRow >= m_inputRows || + inputPlane >= m_inputPlanes) { + return internal::pset1<Packet>(Scalar(0)); + } + + const Index depth = patchId - patchOffset * patchDepth(); + const Index inputIndex = depth + inputPlane * m_planeInputStride + + inputRow * m_rowInputStride + + inputCol * m_colInputStride + otherIndex; + return m_impl.template packet<Unaligned>(inputIndex); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet + packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex, + Index colIndex, Index otherIndex) const { + const int packetSize = internal::unpacket_traits<Packet>::size; + EIGEN_ALIGN_MAX + typename internal::remove_const<Scalar>::type values[packetSize]; + for (int i = 0; i < packetSize; ++i) { + values[i] = + loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex); + } + Packet rslt = internal::pload<Packet>(values); + return rslt; + } + + // Precompute the indices (plane, row, col, other) of the first element of + // the given patch index, within the output tensor of the TensorVolumePatchOp. + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices( + Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex, + Index& otherIndex) const { + const int NumInputDims = array_size< + typename TensorEvaluator<ArgType, Device>::Dimensions>::value; + + // Check if patchIndex might contain batch and other dimensions. + otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches; + + // Compute index of the patch within the batch (and other dimensions). + const Index patch3DIndex = (NumInputDims == 4) + ? patchIndex + : (patchIndex - otherIndex * m_num_patches); + + otherIndex *= m_patchInputStride; + + colIndex = patch3DIndex / m_fastOutputPlanesRows; + rowIndex = + (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes; + planeIndex = + patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes; + + colIndex = colIndex * m_col_strides - m_colPaddingLeft; + rowIndex = rowIndex * m_row_strides - m_rowPaddingTop; + planeIndex = planeIndex * m_plane_strides - m_planePaddingTop; + } + + Index m_patch_depth; // number of channels in the patch + Index m_patch_planes; // number of planes in the patch + Index m_patch_rows; // number of rows in the patch + Index m_patch_cols; // number of columns in the patch + Index m_num_patches; // number of patches to extract + + // Strides for the output tensor. 
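+ // (m_rowStride and m_colStride are also what loadCoeff(...) above uses to + decompose a linear within-patch offset into plane, row and column + components.)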
+ Index m_rowStride; + Index m_colStride; + Index m_patchStride; + Index m_otherStride; + + Index m_planeInputStride; // Plane stride in the input tensor + Index m_rowInputStride; // Row stride in the input tensor + Index m_colInputStride; // Col stride in the input tensor + Index m_patchInputStride; // Patch stride in the input tensor + Index m_otherInputStride; + + Index m_inputDepth; // Depth of the input tensor + Index m_inputPlanes; // Number of planes in the input tensor + Index m_inputRows; // Number of rows in the input tensor + Index m_inputCols; // Number of cols in the input tensor + + Index m_outputPlanes; // Number of output planes + Index m_outputRows; // Number of output rows + Index m_outputCols; // Number of output cols + Index m_outputPlanesRows; // Cached outputPlanes * outputRows. + + Index m_plane_strides; // User specified plane stride + Index m_row_strides; // User specified row stride + Index m_col_strides; // User specified col stride + + // User specified plane/row/col atrous convolution strides. + Index m_in_plane_strides; + Index m_in_row_strides; + Index m_in_col_strides; + + // User specified plane/row/col inflation strides in the image patch. + Index m_patch_plane_inflate_strides; + Index m_patch_row_inflate_strides; + Index m_patch_col_inflate_strides; + + Index m_planePaddingTop; // Plane padding + Index m_rowPaddingTop; // Row padding + Index m_colPaddingLeft; // Column padding + + // Fast representation of various divisors. + internal::TensorIntDivisor<Index> m_fastNumPatches; + + internal::TensorIntDivisor<Index> m_fastInputPlaneStride; + internal::TensorIntDivisor<Index> m_fastInputRowStride; + internal::TensorIntDivisor<Index> m_fastInputColStride; + + internal::TensorIntDivisor<Index> m_fastRowStride; + internal::TensorIntDivisor<Index> m_fastColStride; + + internal::TensorIntDivisor<Index> m_fastDimZero; // aka output depth + internal::TensorIntDivisor<Index> m_fastOutputPlanes; + internal::TensorIntDivisor<Index> m_fastOutputRows; + internal::TensorIntDivisor<Index> m_fastOutputCols; + internal::TensorIntDivisor<Index> m_fastOutputPlanesRows; + + const TensorEvaluator<ArgType, Device> m_impl; +}; + +template <typename NewDimension, DenseIndex Planes, DenseIndex Rows, + DenseIndex Cols, typename ArgType, typename Device, typename Scalar, + typename Index, typename nocontract_t, typename contract_t, int Side, + int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, + int Alignment> +class TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator<const TensorReshapingOp<NewDimension, + const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> { + public: + typedef typename packet_traits<Scalar>::type Packet; + typedef typename packet_traits<Scalar>::half HalfPacket; + + typedef TensorContractionInputMapper< + Scalar, Index, Side, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + ParentMapper; + typedef TensorContractionSubMapper< + Scalar, Index, Side, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + Self; + typedef Self LinearMapper; + + EIGEN_DEVICE_FUNC 
EIGEN_STRONG_INLINE TensorContractionSubMapper( + const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset) + : m_base_mapper(base_mapper), + m_depth_offset(vert_offset), + m_col_offset(horiz_offset) { + m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, + m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper( + const Self& base_mapper, Index vert_offset, Index horiz_offset) + : m_base_mapper(base_mapper.m_base_mapper), + m_depth_offset(vert_offset + base_mapper.m_depth_offset), + m_col_offset(horiz_offset + base_mapper.m_col_offset) { + m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex, + m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const { + return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex, + m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i, + Index j) const { + return m_base_mapper(i + m_depth_offset, j + m_col_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const { + return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex, + m_rowIndex, m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i, + Index j) const { + return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset, + j + m_col_offset); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar + loadCoeffStandard(Index i) const { + return m_base_mapper.loadCoeffStandard( + i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const { + return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex, + m_rowIndex, m_colIndex, m_otherIndex); + } + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet + loadPacketStandard(Index i) const { + return m_base_mapper.loadPacketStandard( + i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex); + } + template <typename Packet> + EIGEN_DEVICE_FUNC bool aligned(Index) const { + return false; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool nonStandardPatches() const { + return m_base_mapper.nonStandardPatches(); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchDepth() const { + return m_base_mapper.m_patch_depth; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchPlanes() const { + return m_base_mapper.m_patch_planes; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchRows() const { + return m_base_mapper.m_patch_rows; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index patchCols() const { + return m_base_mapper.m_patch_cols; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth, + const Index baseIndex) const { + const Index inputIndex = depth + baseIndex; + return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex); + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const { + const Index p = m_planeIndex + plane; + return p < 0 || p >= m_base_mapper.m_inputPlanes; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padRow(const Index row) const { + const Index r = m_rowIndex + row; + return r < 0 || r >= m_base_mapper.m_inputRows; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE bool padCol(const Index col) const { + const Index c = m_colIndex + col; + return c < 0 || c >= m_base_mapper.m_inputCols; + } + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index baseIndex(const Index 
plane, const Index row, + const Index col) const { + const Index p = m_planeIndex + plane; + const Index r = m_rowIndex + row; + const Index c = m_colIndex + col; + return p * m_base_mapper.m_planeInputStride + + r * m_base_mapper.m_rowInputStride + + c * m_base_mapper.m_colInputStride + m_otherIndex; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index planeOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + const Index rowOffset = + (patchOffset - colOffset * m_base_mapper.m_colStride) / + m_base_mapper.m_fastRowStride; + const Index planeOffset = patchOffset - + colOffset * m_base_mapper.m_colStride - + rowOffset * m_base_mapper.m_rowStride; + return planeOffset; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index rowOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + const Index rowOffset = + (patchOffset - colOffset * m_base_mapper.m_colStride) / + m_base_mapper.m_fastRowStride; + return rowOffset; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index colOffset() const { + const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero; + const Index colOffset = patchOffset / m_base_mapper.m_fastColStride; + return colOffset; + } + + EIGEN_DEVICE_FUNC + EIGEN_ALWAYS_INLINE Index depthOffset() const { + const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth(); + return patchOffset; + } + + EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper + getLinearMapper(Index i, Index j) const { + return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset); + } + + private: + const ParentMapper& m_base_mapper; + Index m_depth_offset; // First row in the input matrix + Index m_col_offset; // First col in the input matrix + + // Knowing that: col_offset == patchIndex * OTHERS, we keep precomputed base + // indices for the first element in a patch specified by col_offset + // (see computeBaseIndices(...) for details). + Index m_planeIndex; + Index m_rowIndex; + Index m_colIndex; + Index m_otherIndex; +}; + +// Arrange a block of the right input matrix (in our case it's always a "virtual +// matrix" constructed from extracted volume patches) in contiguous memory. +// +// Given column major input (A0 beside A1 in memory): +// A0 B0 C0 D0 E0 F0 G0 H0 ... +// A1 B1 C1 D1 E1 F1 G1 H1 ... +// A2 B2 C2 D2 E2 F2 G2 H2 ... +// A3 B3 C3 D3 E3 F3 G3 H3 ... +// A4 B4 C4 D4 E4 F4 G4 H4 ... +// A5 B5 C5 D5 E5 F5 G5 H5 ... +// A6 B6 C6 D6 E6 F6 G6 H6 ... +// A7 B7 C7 D7 E7 F7 G7 H7 ... +// A8 ... +// ... +// +// Packing yields row major output (A0 beside A1 in memory): +// A0 A1 A2 A3 A4 A5 A6 A7 +// B0 B1 B2 B3 B4 B5 B6 B7 +// C0 ... +// ... +// +// *) A, B, C, ... - patches extracted from the original input. +// *) nr - number of registers along the 'n' dimension. +// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix +// Multiplication" paper. 
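+// The packing code below walks the patch dimensions from the slowest to the +// fastest varying one (columns, then rows, then planes, then depth), four +// output columns at a time, and writes zero packets for the padded regions +// of each patch.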
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows, + DenseIndex Cols, typename ArgType, typename Device, typename Scalar, + typename Index, typename nocontract_t, typename contract_t, + int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered, + int Alignment, int nr> +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, packet_size, inner_dim_contiguous, + inner_dim_reordered, Alignment> + SubMapper; + typedef SubMapper DataMapper; + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); + typedef typename packet_traits<Scalar>::type Packet; + + const Index packet_cols4 = (cols / 4) * 4; + const Index peeled_k = (depth / packet_size) * packet_size; + const bool non_standard_patches = rhs.nonStandardPatches(); + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + Index k = 0; + if ((packet_size % 4) == 0 && !non_standard_patches) { + const Index patch_depth = rhs.patchDepth(); + + if ((patch_depth % packet_size) == 0) { + const Index patch_cols = rhs.patchCols(); + const Index patch_rows = rhs.patchRows(); + const Index patch_planes = rhs.patchPlanes(); + + const Index startCol = rhs.colOffset(); + const Index max_cols = std::min<Index>( + Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) + + startCol, + patch_cols); + + for (Index c = startCol; c < max_cols; ++c) { + eigen_assert(k < peeled_k); + + const Index startRow = (c == startCol) ? rhs.rowOffset() : 0; + const Index max_rows = std::min<Index>( + Eigen::divup( + peeled_k - c * patch_rows * patch_planes * patch_depth, + patch_planes * patch_depth) + + startRow, + patch_rows); + + const bool pad_col0 = dm0.padCol(c); + const bool pad_col1 = dm1.padCol(c); + const bool pad_col2 = dm2.padCol(c); + const bool pad_col3 = dm3.padCol(c); + + for (Index r = startRow; r < max_rows; ++r) { + eigen_assert(k < peeled_k); + + const Index startPlane = + ((c == startCol) && (r == startRow)) ? 
rhs.planeOffset() : 0; + const Index max_planes = std::min<Index>( + Eigen::divup( + peeled_k - + c * patch_rows * patch_planes * patch_depth - // col + r * patch_planes * patch_depth, // row + patch_depth) + + startPlane, + patch_planes); + + const bool pad_row0 = dm0.padRow(r); + const bool pad_row1 = dm1.padRow(r); + const bool pad_row2 = dm2.padRow(r); + const bool pad_row3 = dm3.padRow(r); + + for (Index p = startPlane; p < max_planes; ++p) { + eigen_assert(k < peeled_k); + + const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p); + const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p); + const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p); + const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p); + + const Index idx0 = dm0.baseIndex(p, r, c); + const Index idx1 = dm1.baseIndex(p, r, c); + const Index idx2 = dm2.baseIndex(p, r, c); + const Index idx3 = dm3.baseIndex(p, r, c); + + const Index startDepth = + ((c == startCol) && (r == startRow) && (p == startPlane)) + ? rhs.depthOffset() + : 0; + const Index max_depth = std::min<Index>( + peeled_k - + c * patch_rows * patch_planes * patch_depth - // col + r * patch_planes * patch_depth - // row + p * patch_depth + // plane + startDepth, + patch_depth); + eigen_assert((max_depth - startDepth) % packet_size == 0); + + for (Index d = startDepth; d < max_depth; d += packet_size) { + eigen_assert(k < peeled_k); + PacketBlock<Packet, 4> kernel; + kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) + : rhs.packetNoPadding(d, idx0); + kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) + : rhs.packetNoPadding(d, idx1); + kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0)) + : rhs.packetNoPadding(d, idx2); + kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0)) + : rhs.packetNoPadding(d, idx3); + ptranspose(kernel); + pstoreu(block + 0 * packet_size, kernel.packet[0]); + pstoreu(block + 1 * packet_size, kernel.packet[1]); + pstoreu(block + 2 * packet_size, kernel.packet[2]); + pstoreu(block + 3 * packet_size, kernel.packet[3]); + block += 4 * packet_size; + k += packet_size; + } + } + } + } + + for (; k < peeled_k; k += packet_size) { + PacketBlock<Packet, 4> kernel; + kernel.packet[0] = dm0.loadPacketFast(k); + kernel.packet[1] = dm1.loadPacketFast(k); + kernel.packet[2] = dm2.loadPacketFast(k); + kernel.packet[3] = dm3.loadPacketFast(k); + ptranspose(kernel); + pstoreu(block + 0 * packet_size, kernel.packet[0]); + pstoreu(block + 1 * packet_size, kernel.packet[1]); + pstoreu(block + 2 * packet_size, kernel.packet[2]); + pstoreu(block + 3 * packet_size, kernel.packet[3]); + block += 4 * packet_size; + } + } else { + for (; k < peeled_k; k += packet_size) { + PacketBlock<Packet, 4> kernel; + kernel.packet[0] = dm0.loadPacketStandard(k); + kernel.packet[1] = dm1.loadPacketStandard(k); + kernel.packet[2] = dm2.loadPacketStandard(k); + kernel.packet[3] = dm3.loadPacketStandard(k); + ptranspose(kernel); + pstoreu(block + 0 * packet_size, kernel.packet[0]); + pstoreu(block + 1 * packet_size, kernel.packet[1]); + pstoreu(block + 2 * packet_size, kernel.packet[2]); + pstoreu(block + 3 * packet_size, kernel.packet[3]); + block += 4 * packet_size; + } + } + } + if (!rhs.nonStandardPatches()) { + for (; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = dm3(k); + block += 4; + } + } + } + + // 
copy the remaining columns one at a time (nr==1) + for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; + +// Template specialization for packet_size = 2. We must special-case packet +// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>. +template <typename NewDimension, DenseIndex Planes, DenseIndex Rows, + DenseIndex Cols, typename ArgType, typename Device, typename Scalar, + typename Index, typename nocontract_t, typename contract_t, + bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, + int nr> +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous, + inner_dim_reordered, Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous, + inner_dim_reordered, Alignment> + SubMapper; + typedef SubMapper DataMapper; + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); + typedef typename packet_traits<Scalar>::type Packet; + + const int packet_size = 2; + + const Index packet_cols4 = (cols / 4) * 4; + const Index peeled_k = (depth / packet_size) * packet_size; + const bool non_standard_patches = rhs.nonStandardPatches(); + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + Index k = 0; + if (!non_standard_patches) { + const Index patch_depth = rhs.patchDepth(); + + if ((patch_depth % packet_size) == 0) { + const Index patch_cols = rhs.patchCols(); + const Index patch_rows = rhs.patchRows(); + const Index patch_planes = rhs.patchPlanes(); + + const Index startCol = rhs.colOffset(); + const Index max_cols = std::min<Index>( + Eigen::divup(peeled_k, patch_rows * patch_planes * patch_depth) + + startCol, + patch_cols); + + for (Index c = startCol; c < max_cols; ++c) { + eigen_assert(k < peeled_k); + + const Index startRow = (c == startCol) ? rhs.rowOffset() : 0; + const Index max_rows = std::min<Index>( + Eigen::divup( + peeled_k - c * patch_rows * patch_planes * patch_depth, + patch_planes * patch_depth) + + startRow, + patch_rows); + + const bool pad_col0 = dm0.padCol(c); + const bool pad_col1 = dm1.padCol(c); + const bool pad_col2 = dm2.padCol(c); + const bool pad_col3 = dm3.padCol(c); + + for (Index r = startRow; r < max_rows; ++r) { + eigen_assert(k < peeled_k); + + const Index startPlane = + ((c == startCol) && (r == startRow)) ? 
rhs.planeOffset() : 0; + const Index max_planes = std::min<Index>( + Eigen::divup( + peeled_k - + c * patch_rows * patch_planes * patch_depth - // col + r * patch_planes * patch_depth, // row + patch_depth) + + startPlane, + patch_planes); + + const bool pad_row0 = dm0.padRow(r); + const bool pad_row1 = dm1.padRow(r); + const bool pad_row2 = dm2.padRow(r); + const bool pad_row3 = dm3.padRow(r); + + for (Index p = startPlane; p < max_planes; ++p) { + eigen_assert(k < peeled_k); + + const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p); + const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p); + const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p); + const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p); + + const Index idx0 = dm0.baseIndex(p, r, c); + const Index idx1 = dm1.baseIndex(p, r, c); + const Index idx2 = dm2.baseIndex(p, r, c); + const Index idx3 = dm3.baseIndex(p, r, c); + + const Index startDepth = + ((c == startCol) && (r == startRow) && (p == startPlane)) + ? rhs.depthOffset() + : 0; + const Index max_depth = std::min<Index>( + peeled_k - + c * patch_rows * patch_planes * patch_depth - // col + r * patch_planes * patch_depth - // row + p * patch_depth + // plane + startDepth, + patch_depth); + eigen_assert((max_depth - startDepth) % packet_size == 0); + + for (Index d = startDepth; d < max_depth; d += packet_size) { + eigen_assert(k < peeled_k); + PacketBlock<Packet, 2> kernel0; + PacketBlock<Packet, 2> kernel1; + kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0)) + : rhs.packetNoPadding(d, idx0); + kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0)) + : rhs.packetNoPadding(d, idx1); + kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0)) + : rhs.packetNoPadding(d, idx2); + kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0)) + : rhs.packetNoPadding(d, idx3); + ptranspose(kernel0); + ptranspose(kernel1); + pstoreu(block + 0 * packet_size, kernel0.packet[0]); + pstoreu(block + 1 * packet_size, kernel1.packet[0]); + pstoreu(block + 2 * packet_size, kernel0.packet[1]); + pstoreu(block + 3 * packet_size, kernel1.packet[1]); + block += 4 * packet_size; + k += packet_size; + } + } + } + } + + for (; k < peeled_k; k += packet_size) { + PacketBlock<Packet, 2> kernel0; + PacketBlock<Packet, 2> kernel1; + kernel0.packet[0] = dm0.loadPacketFast(k); + kernel0.packet[1] = dm1.loadPacketFast(k); + kernel1.packet[0] = dm2.loadPacketFast(k); + kernel1.packet[1] = dm3.loadPacketFast(k); + ptranspose(kernel0); + ptranspose(kernel1); + pstoreu(block + 0 * packet_size, kernel0.packet[0]); + pstoreu(block + 1 * packet_size, kernel1.packet[0]); + pstoreu(block + 2 * packet_size, kernel0.packet[1]); + pstoreu(block + 3 * packet_size, kernel1.packet[1]); + block += 4 * packet_size; + } + } else { + for (; k < peeled_k; k += packet_size) { + PacketBlock<Packet, 2> kernel0; + PacketBlock<Packet, 2> kernel1; + kernel0.packet[0] = dm0.loadPacketStandard(k); + kernel0.packet[1] = dm1.loadPacketStandard(k); + kernel1.packet[0] = dm2.loadPacketStandard(k); + kernel1.packet[1] = dm3.loadPacketStandard(k); + ptranspose(kernel0); + ptranspose(kernel1); + pstoreu(block + 0 * packet_size, kernel0.packet[0]); + pstoreu(block + 1 * packet_size, kernel1.packet[0]); + pstoreu(block + 2 * packet_size, kernel0.packet[1]); + pstoreu(block + 3 * packet_size, kernel1.packet[1]); + block += 4 * packet_size; + } + } + } + if (!rhs.nonStandardPatches()) { + for (; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = 
dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = dm3(k); + block += 4; + } + } + } + + // copy the remaining columns one at a time (nr==1) + for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; + +// Special case for non-vectorized types such as float16 (packet_size = 1). +template <typename NewDimension, DenseIndex Planes, DenseIndex Rows, + DenseIndex Cols, typename ArgType, typename Device, typename Scalar, + typename Index, typename nocontract_t, typename contract_t, + bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment, + int nr> +struct gemm_pack_rhs< + Scalar, Index, + TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous, + inner_dim_reordered, Alignment>, + nr, ColMajor, false, false> { + typedef TensorContractionSubMapper< + Scalar, Index, Rhs, + TensorEvaluator<const TensorReshapingOp< + NewDimension, const TensorVolumePatchOp< + Planes, Rows, Cols, ArgType> >, + Device>, + nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered, + Alignment> + SubMapper; + typedef SubMapper DataMapper; + + EIGEN_DEVICE_FUNC + EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs, + Index depth, Index cols, Index stride = 0, + Index offset = 0) const { + eigen_assert(stride == 0); + eigen_assert(offset == 0); + + EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE); + + const Index packet_cols4 = (cols / 4) * 4; + + for (Index j2 = 0; j2 < packet_cols4; j2 += 4) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0); + const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1); + const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2); + const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3); + + if (!rhs.nonStandardPatches()) { + for (Index k = 0; k < depth; k++) { + block[0] = dm0.loadCoeffStandard(k); + block[1] = dm1.loadCoeffStandard(k); + block[2] = dm2.loadCoeffStandard(k); + block[3] = dm3.loadCoeffStandard(k); + block += 4; + } + } else { + for (Index k = 0; k < depth; k++) { + block[0] = dm0(k); + block[1] = dm1(k); + block[2] = dm2(k); + block[3] = dm3(k); + block += 4; + } + } + } + + // copy the remaining columns one at a time (nr==1) + for (Index j2 = packet_cols4; j2 < cols; ++j2) { + const SubMapper dm0 = rhs.getLinearMapper(0, j2); + for (Index k = 0; k < depth; k++) { + *block = dm0(k); + block += 1; + } + } + } +}; + +} // namespace internal + /** CuboidConvolution * \ingroup CXX11_NeuralNetworks_Module * diff --git a/tensorflow/core/kernels/eigen_volume_patch.h b/tensorflow/core/kernels/eigen_volume_patch.h index a3d795813d..80ab745bfe 100644 --- a/tensorflow/core/kernels/eigen_volume_patch.h +++ b/tensorflow/core/kernels/eigen_volume_patch.h @@ -43,6 +43,7 @@ struct CustomTensorEvaluator { IsAligned = false, PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess, BlockAccess = false, + PreferBlockAccess = false, Layout = TensorEvaluator<ArgType, Device>::Layout, CoordAccess = NumDims == 6, RawAccess = false diff --git a/tensorflow/core/kernels/fuzzing/BUILD b/tensorflow/core/kernels/fuzzing/BUILD index 8bfa40304e..f2e0b2558f 100644 --- a/tensorflow/core/kernels/fuzzing/BUILD +++ 
b/tensorflow/core/kernels/fuzzing/BUILD @@ -43,4 +43,6 @@ tf_ops_fuzz_target_lib("example_proto_fast_parsing") tf_ops_fuzz_target_lib("parse_tensor_op") +tf_ops_fuzz_target_lib("decode_compressed") + tf_ops_fuzz_target_lib("decode_json_example") diff --git a/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc new file mode 100644 index 0000000000..0a56f4b63f --- /dev/null +++ b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc @@ -0,0 +1,45 @@ +/* Copyright 2018 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/kernels/fuzzing/fuzz_session.h" + +namespace tensorflow { +namespace fuzzing { + +class FuzzDecodeCompressed : public FuzzStringInputOp { + void BuildGraph(const Scope& scope) override { + auto input = + tensorflow::ops::Placeholder(scope.WithOpName("input1"), DT_STRING); + auto d1 = tensorflow::ops::DecodeCompressed( + scope.WithOpName("d1"), input, + tensorflow::ops::DecodeCompressed::CompressionType("")); + auto d2 = tensorflow::ops::DecodeCompressed( + scope.WithOpName("d2"), input, + tensorflow::ops::DecodeCompressed::CompressionType("ZLIB")); + auto d3 = tensorflow::ops::DecodeCompressed( + scope.WithOpName("d3"), input, + tensorflow::ops::DecodeCompressed::CompressionType("GZIP")); + Scope grouper = + scope.WithControlDependencies(std::vector<tensorflow::Operation>{ + d1.output.op(), d2.output.op(), d3.output.op()}); + (void)tensorflow::ops::NoOp(grouper.WithOpName("output")); + } +}; + +STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeCompressed); + +} // namespace fuzzing +} // namespace tensorflow diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h index cd2873bdca..7710cf93d6 100644 --- a/tensorflow/core/kernels/gather_functor.h +++ b/tensorflow/core/kernels/gather_functor.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/type_traits.h" +#include "tensorflow/core/framework/variant.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/platform/prefetch.h" #include "tensorflow/core/platform/types.h" diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h index ad0112e6cb..277ee2be02 100644 --- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h @@ -113,10 +113,25 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> { #endif generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator( slice_size, Tindices, Tparams, Tout, &error_loc); + +#ifdef INTEL_MKL +// Eigen implementation below is not highly performant. gather_nd_generator +// does not seem to be called in parallel, leading to very poor performance. 
+// Additionally, since it uses scalar (Tscratch) to invoke 'generate', it + needs to go through redundant operations like 'reshape', 'broadcast' and + 'sum'. The OpenMP loop below essentially does the same thing as the Eigen + code, but is considerably more efficient. +#pragma omp parallel for + for (Eigen::DenseIndex i = 0; i < batch_size; i++) { + const Eigen::array<Eigen::DenseIndex, 1> loc{i}; + gather_nd_generator(loc); + } +#else // INTEL_MKL Tscratch.device(d) = Tscratch.reshape(reshape_dims) .broadcast(broadcast_dims) .generate(gather_nd_generator) .sum(); +#endif // error_loc() returns -1 if there's no out-of-bounds index, // otherwise it returns the location of an OOB index in Tindices. diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h index c7dbefa0b4..86146f75f4 100644 --- a/tensorflow/core/kernels/gpu_utils.h +++ b/tensorflow/core/kernels/gpu_utils.h @@ -123,8 +123,7 @@ class AutoTuneMap { string GetActionSummary(StringPiece action, const Parameters& params, const Config& config) { return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(), - std::string(action).c_str(), - params.ToString().c_str(), + string(action).c_str(), params.ToString().c_str(), config.ToString().c_str()); } diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc index bca1cff41c..2088c13586 100644 --- a/tensorflow/core/kernels/list_kernels.cc +++ b/tensorflow/core/kernels/list_kernels.cc @@ -77,9 +77,9 @@ static Status TensorListDeviceCopy( return Status::OK(); } -#define REGISTER_LIST_COPY(DIRECTION) \ - INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \ - TensorList, DIRECTION, TensorList::kTypeName, TensorListDeviceCopy) +#define REGISTER_LIST_COPY(DIRECTION) \ + INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \ + TensorListDeviceCopy) REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE); REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST); @@ -92,8 +92,7 @@ Status TensorListShape(const TensorList& t, TensorShape* s) { return Status::OK(); } -REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorList::kTypeName, - TensorListShape); +REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorListShape); bool TensorList::Decode(const VariantTensorData& data) { tensors = data.tensors(); @@ -625,12 +624,11 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(bfloat16); #undef REGISTER_TENSOR_LIST_FROM_TENSOR_CPU REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, - TensorList, TensorList::kTypeName, + TensorList, TensorListBinaryAdd<CPUDevice>); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, TensorList, - TensorList::kTypeName, TensorListZerosLike<CPUDevice>); } // namespace tensorflow diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc index c591226b76..a00bf700ca 100644 --- a/tensorflow/core/kernels/list_kernels.cu.cc +++ b/tensorflow/core/kernels/list_kernels.cu.cc @@ -94,11 +94,10 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(bool); #undef REGISTER_TENSOR_LIST_FROM_TENSOR_GPU REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU, - TensorList, TensorList::kTypeName, + TensorList, TensorListBinaryAdd<GPUDevice>); REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, TensorList, - TensorList::kTypeName, TensorListZerosLike<GPUDevice>); } // namespace tensorflow diff --git a/tensorflow/core/kernels/list_kernels.h
b/tensorflow/core/kernels/list_kernels.h index 066a1d603b..72581c9293 100644 --- a/tensorflow/core/kernels/list_kernels.h +++ b/tensorflow/core/kernels/list_kernels.h @@ -374,7 +374,12 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, y->tensors.reserve(x.tensors.size()); for (const Tensor& t : x.tensors) { Tensor out_tensor; - TF_RETURN_IF_ERROR(c->allocate_temp(t.dtype(), t.shape(), &out_tensor)); + AllocatorAttributes attr; + if (t.dtype() == DT_VARIANT) { + attr.set_on_host(true); + } + TF_RETURN_IF_ERROR( + c->allocate_temp(t.dtype(), t.shape(), &out_tensor, attr)); switch (out_tensor.dtype()) { #define DTYPE_CASE(dtype) \ case DataTypeToEnum<dtype>::value: \ @@ -385,6 +390,20 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x, TF_CALL_POD_TYPES(DTYPE_CASE) #undef DTYPE_CASE + + case DataTypeToEnum<Variant>::value: { + const TensorList* inner_x = t.scalar<Variant>()().get<TensorList>(); + if (inner_x == nullptr) { + return errors::InvalidArgument("Input handle is not a list. Saw: '", + t.scalar<Variant>()().DebugString(), + "'"); + } + TensorList inner_y; + TF_RETURN_IF_ERROR(TensorListZerosLike<Device>(c, *inner_x, &inner_y)); + out_tensor.scalar<Variant>()() = std::move(inner_y); + break; + } + default: return errors::InvalidArgument( "Trying to compute zeros_like for unsupported dtype ", diff --git a/tensorflow/core/kernels/logistic-loss.h b/tensorflow/core/kernels/logistic-loss.h index b43902e0b9..9198a98e47 100644 --- a/tensorflow/core/kernels/logistic-loss.h +++ b/tensorflow/core/kernels/logistic-loss.h @@ -86,7 +86,7 @@ class LogisticLossUpdater : public DualLossUpdater { } else { inverse_exp_term = 1 / (1 + exp(label * wx)); } - return inverse_exp_term * label * example_weight; + return -inverse_exp_term * label * example_weight; } // The smoothness constant is 4 since the derivative of logistic loss, which diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc index 2e8d9c623c..a495758861 100644 --- a/tensorflow/core/kernels/lookup_table_op.cc +++ b/tensorflow/core/kernels/lookup_table_op.cc @@ -50,7 +50,7 @@ class MutableHashTableOfScalars final : public LookupInterface { MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {} size_t size() const override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return table_.size(); } @@ -60,7 +60,7 @@ class MutableHashTableOfScalars final : public LookupInterface { const auto key_values = key.flat<K>(); auto value_values = value->flat<V>(); - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (int64 i = 0; i < key_values.size(); ++i) { value_values(i) = gtl::FindWithDefault( table_, SubtleMustCopyIfIntegral(key_values(i)), default_val); @@ -95,7 +95,7 @@ class MutableHashTableOfScalars final : public LookupInterface { } Status ExportValues(OpKernelContext* ctx) override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); int64 size = table_.size(); Tensor* keys; @@ -125,7 +125,7 @@ class MutableHashTableOfScalars final : public LookupInterface { int64 MemoryUsed() const override { int64 ret = 0; - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (unsigned i = 0; i < table_.bucket_count(); ++i) { size_t bucket_size = table_.bucket_size(i); if (bucket_size == 0) { @@ -138,7 +138,6 @@ class MutableHashTableOfScalars final : public LookupInterface { } private: - // TODO(andreasst): consider using a read/write lock or a concurrent map mutable mutex mu_; std::unordered_map<K, V> table_ GUARDED_BY(mu_); }; @@ -158,7 +157,7 @@ class 
MutableHashTableOfTensors final : public LookupInterface { } size_t size() const override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return table_.size(); } @@ -169,7 +168,7 @@ class MutableHashTableOfTensors final : public LookupInterface { auto value_values = value->flat_inner_dims<V, 2>(); int64 value_dim = value_shape_.dim_size(0); - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (int64 i = 0; i < key_values.size(); ++i) { ValueArray* value_vec = gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i))); @@ -219,7 +218,7 @@ class MutableHashTableOfTensors final : public LookupInterface { } Status ExportValues(OpKernelContext* ctx) override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); int64 size = table_.size(); int64 value_dim = value_shape_.dim_size(0); @@ -254,7 +253,7 @@ class MutableHashTableOfTensors final : public LookupInterface { int64 MemoryUsed() const override { int64 ret = 0; - mutex_lock l(mu_); + tf_shared_lock l(mu_); for (unsigned i = 0; i < table_.bucket_count(); ++i) { size_t bucket_size = table_.bucket_size(i); if (bucket_size == 0) { @@ -268,7 +267,6 @@ class MutableHashTableOfTensors final : public LookupInterface { private: TensorShape value_shape_; - // TODO(andreasst): consider using a read/write lock or a concurrent map mutable mutex mu_; typedef gtl::InlinedVector<V, 4> ValueArray; std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_); @@ -335,7 +333,7 @@ class MutableDenseHashTable final : public LookupInterface { } size_t size() const override LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return num_entries_; } @@ -355,7 +353,7 @@ class MutableDenseHashTable final : public LookupInterface { auto value_matrix = value->shaped<V, 2>({num_elements, value_size}); const auto default_flat = default_value.flat<V>(); - mutex_lock l(mu_); + tf_shared_lock l(mu_); const auto key_buckets_matrix = key_buckets_.AccessTensor(ctx)->template matrix<K>(); const auto value_buckets_matrix = @@ -451,7 +449,7 @@ class MutableDenseHashTable final : public LookupInterface { } Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); + tf_shared_lock l(mu_); Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx); Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx); TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor)); @@ -493,7 +491,7 @@ class MutableDenseHashTable final : public LookupInterface { TensorShape value_shape() const override { return value_shape_; } int64 MemoryUsed() const override { - mutex_lock l(mu_); + tf_shared_lock l(mu_); return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() + value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes(); } diff --git a/tensorflow/core/kernels/loss_test.cc b/tensorflow/core/kernels/loss_test.cc index 460d65c5c2..9209ed2ab7 100644 --- a/tensorflow/core/kernels/loss_test.cc +++ b/tensorflow/core/kernels/loss_test.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/core/kernels/hinge-loss.h" #include "tensorflow/core/kernels/logistic-loss.h" +#include "tensorflow/core/kernels/poisson-loss.h" #include "tensorflow/core/kernels/smooth-hinge-loss.h" #include "tensorflow/core/kernels/squared-loss.h" #include "tensorflow/core/lib/core/errors.h" @@ -30,6 +31,24 @@ namespace { // TODO(sibyl-Aix6ihai): add a test to show the improvements of the Newton // modification detailed in readme.md +// This test checks that the dual value after update is optimal. 
+// At the optimum the dual value should be the opposite of the primal gradient. +// This does not hold at a point where the primal is not differentiable. +void TestComputeUpdatedDual(const DualLossUpdater &loss_updater, + const int num_loss_partitions, const double label, + const double example_weight, + const double current_dual, const double wx, + const double weighted_example_norm) { + double new_dual = loss_updater.ComputeUpdatedDual( + num_loss_partitions, label, example_weight, current_dual, wx, + weighted_example_norm); + // The primal gradient needs to be computed after the weight update. + double new_wx = wx + (new_dual - current_dual) * num_loss_partitions * + weighted_example_norm * example_weight; + EXPECT_NEAR(new_dual, -loss_updater.PrimalLossDerivative(new_wx, label, 1.0), + 1e-5); +} + TEST(LogisticLoss, ComputePrimalLoss) { LogisticLossUpdater loss_updater; EXPECT_NEAR(0.693147, @@ -65,19 +84,12 @@ TEST(LogisticLoss, ComputeDualLoss) { TEST(LogisticLoss, ComputeUpdatedDual) { LogisticLossUpdater loss_updater; - EXPECT_NEAR(0.479, - loss_updater.ComputeUpdatedDual( - 1 /* num partitions */, 1.0 /* label */, - 1.0 /* example weight */, 0.5 /* current_dual */, - 0.3 /* wx */, 10.0 /* weighted_example_norm */), - 1e-3); - - EXPECT_NEAR(-0.031, - loss_updater.ComputeUpdatedDual( - 2 /* num partitions */, -1.0 /* label */, - 1.0 /* example weight */, 0.1 /* current_dual */, - -0.8 /* wx */, 10.0 /* weighted_example_norm */), - 1e-3); + TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */, + 1.0 /* example weight */, 0.5 /* current_dual */, + 0.3 /* wx */, 10.0 /* weighted_example_norm */); + TestComputeUpdatedDual(loss_updater, 2 /* num partitions */, -1.0 /* label */, + 1.0 /* example weight */, 0.1 /* current_dual */, + -0.8 /* wx */, 10.0 /* weighted_example_norm */); } TEST(SquaredLoss, ComputePrimalLoss) { @@ -126,19 +138,12 @@ TEST(SquaredLoss, ComputeDualLoss) { TEST(SquaredLoss, ComputeUpdatedDual) { SquaredLossUpdater loss_updater; - EXPECT_NEAR(0.336, - loss_updater.ComputeUpdatedDual( - 1 /* num partitions */, 1.0 /* label */, - 1.0 /* example weight */, 0.3 /* current_dual */, - 0.3 /* wx */, 10.0 /* weighted_example_norm */), - 1e-3); - - EXPECT_NEAR(-0.427, - loss_updater.ComputeUpdatedDual( - 5 /* num partitions */, -1.0 /* label */, - 1.0 /* example weight */, -0.4 /* current_dual */, - 0.8 /* wx */, 10.0 /* weighted_example_norm */), - 1e-3); + TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */, + 1.0 /* example weight */, 0.3 /* current_dual */, + 0.3 /* wx */, 10.0 /* weighted_example_norm */); + TestComputeUpdatedDual(loss_updater, 5 /* num partitions */, -1.0 /* label */, + 1.0 /* example weight */, -0.4 /* current_dual */, + 0.8 /* wx */, 10.0 /* weighted_example_norm */); } TEST(HingeLoss, ComputePrimalLoss) { @@ -207,48 +212,27 @@ TEST(HingeLoss, ConvertLabel) { TEST(HingeLoss, ComputeUpdatedDual) { HingeLossUpdater loss_updater; - // When label=1.0, example_weight=1.0, current_dual=0.5, wx=0.3 and - // weighted_example_norm=100.0, it turns out that the optimal value to update - // the dual to is 0.507 which is within the permitted range and thus should be - // the value returned. + // For the two tests below, y*wx=1 after the update, which is a + non-differentiable point of the hinge loss, so TestComputeUpdatedDual + cannot be used. Check the value of the dual variable instead.
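+ // (At y*wx = 1 the hinge loss is not differentiable: its derivative jumps + from -y to 0, so the first-order optimality check in TestComputeUpdatedDual + has no single-valued target there.)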
EXPECT_NEAR(0.507, loss_updater.ComputeUpdatedDual( 1 /* num partitions */, 1.0 /* label */, 1.0 /* example weight */, 0.5 /* current_dual */, 0.3 /* wx */, 100.0 /* weighted_example_norm */), 1e-3); - // When label=-1.0, example_weight=1.0, current_dual=0.4, wx=0.6, - // weighted_example_norm=10.0 and num_loss_partitions=10, it turns out that - // the optimal value to update the dual to is 0.384 which is within the - // permitted range and thus should be the value returned. EXPECT_NEAR(-0.416, loss_updater.ComputeUpdatedDual( 10 /* num partitions */, -1.0 /* label */, 1.0 /* example weight */, -0.4 /* current_dual */, 0.6 /* wx */, 10.0 /* weighted_example_norm */), 1e-3); - // When label=1.0, example_weight=1.0, current_dual=-0.5, wx=0.3 and - // weighted_example_norm=10.0, it turns out that the optimal value to update - // the dual to is -0.43. However, this is outside the allowed [0.0, 1.0] range - // and hence the closest permitted value (0.0) should be returned instead. - EXPECT_NEAR(0.0, - loss_updater.ComputeUpdatedDual( - 1 /* num partitions */, 1.0 /* label */, - 1.0 /* example weight */, -0.5 /* current_dual */, - 0.3 /* wx */, 10.0 /* weighted_example_norm */), - 1e-3); - - // When label=-1.0, example_weight=2.0, current_dual=-1.0, wx=0.3 and - // weighted_example_norm=10.0, it turns out that the optimal value to update - // the dual to is -1.065. However, this is outside the allowed [-1.0, 0.0] - // range and hence the closest permitted value (-1.0) should be returned - // instead. - EXPECT_NEAR(-1.0, - loss_updater.ComputeUpdatedDual( - 1 /* num partitions */, -1.0 /* label */, - 2.0 /* example weight */, -1.0 /* current_dual */, - 0.3 /* wx */, 10.0 /* weighted_example_norm */), - 1e-3); + TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */, + 1.0 /* example weight */, -0.5 /* current_dual */, + 0.3 /* wx */, 10.0 /* weighted_example_norm */); + TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, -1.0 /* label */, + 2.0 /* example weight */, -1.0 /* current_dual */, + 0.3 /* wx */, 10.0 /* weighted_example_norm */); } TEST(SmoothHingeLoss, ComputePrimalLoss) { @@ -297,19 +281,75 @@ TEST(SmoothHingeLoss, ComputeDualLoss) { TEST(SmoothHingeLoss, ComputeUpdatedDual) { SmoothHingeLossUpdater loss_updater; - EXPECT_NEAR(0.336, - loss_updater.ComputeUpdatedDual( - 1 /* num partitions */, 1.0 /* label */, - 1.0 /* example weight */, 0.3 /* current_dual */, - 0.3 /* wx */, 10.0 /* weighted_example_norm */), - 1e-3); + TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */, + 1.0 /* example weight */, 0.3 /* current_dual */, + 0.3 /* wx */, 10.0 /* weighted_example_norm */); + TestComputeUpdatedDual(loss_updater, 5 /* num partitions */, -1.0 /* label */, + 1.0 /* example weight */, -0.4 /* current_dual */, + 0.8 /* wx */, 10.0 /* weighted_example_norm */); +} - EXPECT_NEAR(-0.427, - loss_updater.ComputeUpdatedDual( - 5 /* num partitions */, -1.0 /* label */, - 1.0 /* example weight */, -0.4 /* current_dual */, - 0.8 /* wx */, 10.0 /* weighted_example_norm */), +TEST(PoissonLoss, ComputePrimalLoss) { + PoissonLossUpdater loss_updater; + EXPECT_NEAR(1.0, + loss_updater.ComputePrimalLoss(0.0 /* wx */, 3.0 /* label */, + 1.0 /* example weight */), 1e-3); + EXPECT_NEAR(21996.0, + loss_updater.ComputePrimalLoss(10.0 /* wx */, 3.0 /* label */, + 1.0 /* example weight */), + 1.0); + EXPECT_NEAR(0.606, + loss_updater.ComputePrimalLoss(-0.5 /* wx */, 0.0 /* label */, + 1.0 /* example weight */), + 1e-3); + EXPECT_NEAR(6.64, + 
loss_updater.ComputePrimalLoss(1.2 /* wx */, 0.0 /* label */, + 2.0 /* example weight */), + 1e-2); +} + +TEST(PoissonLoss, ComputeDualLoss) { + PoissonLossUpdater loss_updater; + // Dual is undefined. + EXPECT_NEAR( + std::numeric_limits<double>::max(), + loss_updater.ComputeDualLoss(1.0 /* current dual */, 0.0 /* label */, + 1.0 /* example weight */), + 1e-3); + EXPECT_NEAR( + 0.0, + loss_updater.ComputeDualLoss(0.0 /* current dual */, 0.0 /* label */, + 3.0 /* example weight */), + 1e-3); + EXPECT_NEAR( + -0.847, + loss_updater.ComputeDualLoss(1.5 /* current dual */, 2.0 /* label */, + 1.0 /* example weight */), + 1e-3); + EXPECT_NEAR( + -2.675, + loss_updater.ComputeDualLoss(0.5 /* current dual */, 2.0 /* label */, + 3.0 /* example weight */), + 1e-3); +} + +TEST(PoissonLoss, ConvertLabel) { + PoissonLossUpdater loss_updater; + float example_label = -1.0; + // Negative label should throw an error. + Status status = loss_updater.ConvertLabel(&example_label); + EXPECT_FALSE(status.ok()); +} + +TEST(PoissonLoss, ComputeUpdatedDual) { + PoissonLossUpdater loss_updater; + TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 2.0 /* label */, + 1.0 /* example weight */, 0.5 /* current_dual */, + 0.3 /* wx */, 10.0 /* weighted_example_norm */); + TestComputeUpdatedDual(loss_updater, 2 /* num partitions */, 0.0 /* label */, + 1.0 /* example weight */, 0.0 /* current_dual */, + -0.8 /* wx */, 10.0 /* weighted_example_norm */); } } // namespace diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc index bdc3b5778f..dd89597369 100644 --- a/tensorflow/core/kernels/map_stage_op.cc +++ b/tensorflow/core/kernels/map_stage_op.cc @@ -410,8 +410,9 @@ class StagingMap : public ResourceBase { copy_or_move_tensors(&it->second, *key, *indices, tuple)); // Remove entry if all the values have been consumed - if (!std::any_of(it->second.begin(), it->second.end(), - std::mem_fn(&OptionalTensor::has_value))) { + if (!std::any_of( + it->second.begin(), it->second.end(), + [](const OptionalTensor& tensor) { return tensor.has_value(); })) { map_.erase(it); } @@ -444,8 +445,9 @@ class StagingMap : public ResourceBase { *key = it->first; // Remove entry if all the values have been consumed - if (!std::any_of(it->second.begin(), it->second.end(), - std::mem_fn(&OptionalTensor::has_value))) { + if (!std::any_of( + it->second.begin(), it->second.end(), + [](const OptionalTensor& tensor) { return tensor.has_value(); })) { map_.erase(it); } diff --git a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc index 10e468ce46..693ed8a8f0 100644 --- a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc +++ b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc @@ -114,9 +114,7 @@ class MergeV2CheckpointsOpTest : public OpsTestBase { // Exercises "delete_old_dirs". 
  for (int i = 0; i < 2; ++i) {
    int directory_found =
-        Env::Default()
-            ->IsDirectory(std::string(io::Dirname(prefixes[i])))
-            .code();
+        Env::Default()->IsDirectory(string(io::Dirname(prefixes[i]))).code();
    if (delete_old_dirs) {
      EXPECT_EQ(error::NOT_FOUND, directory_found);
    } else {
diff --git a/tensorflow/core/kernels/mirror_pad_op.h b/tensorflow/core/kernels/mirror_pad_op.h
index cc4b6941b9..62aa7d5c29 100644
--- a/tensorflow/core/kernels/mirror_pad_op.h
+++ b/tensorflow/core/kernels/mirror_pad_op.h
@@ -103,6 +103,7 @@ struct TensorEvaluator<const TensorMirrorPadOp<PaddingDimensions, ArgType>,
     IsAligned = false,
     PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
     BlockAccess = false,
+    PreferBlockAccess = false,
     Layout = TensorEvaluator<ArgType, Device>::Layout,
     CoordAccess = true,
     RawAccess = false
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index afbfaa83f3..52157ed5fb 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -300,19 +300,24 @@ template <typename T>
 class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
  public:
   static MklConvBwdFilterPrimitive<T>* Get(
-      const MklConvBwdFilterParams& convBwdFilterDims) {
+      const MklConvBwdFilterParams& convBwdFilterDims, bool do_not_cache) {
     MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
 
-    // look into the pool for reusable primitive
-    conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*>(
+    if (do_not_cache) { /* Create new primitive always */
+      conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
+    } else {
+      // look into the pool for reusable primitive
+      conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*> (
         MklConvBwdFilterPrimitiveFactory<T>::GetInstance().GetConvBwdFilter(
             convBwdFilterDims));
 
-    if (conv_bwd_filter == nullptr) {
-      conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
-      MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter(
-          convBwdFilterDims, conv_bwd_filter);
+      if (conv_bwd_filter == nullptr) {
+        conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
+        MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter(
+            convBwdFilterDims, conv_bwd_filter);
+      }
     }
+
     return conv_bwd_filter;
   }
 
@@ -845,8 +850,13 @@ class MklConvCustomBackpropFilterOp
       MklConvBwdFilterParams convBwdFilterDims(fwd_src_dims, fwd_filter_dims,
          diff_bias_dims, diff_dst_dims, strides,
          dilations, padding_left, padding_right,
          TFPaddingToMklDnnPadding(this->padding_));
-      conv_bwd_filter =
-          MklConvBwdFilterPrimitiveFactory<T>::Get(convBwdFilterDims);
+
+      // MKL DNN allocates large buffers when a conv gradient filter primitive
+      // is created. So we don't cache conv backward primitives when the env
+      // variable TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is set to true.
+      bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();
+      conv_bwd_filter = MklConvBwdFilterPrimitiveFactory<T>::Get(
+          convBwdFilterDims, do_not_cache);
       auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
 
       // allocate output tensors: diff_filter and diff_bias (w bias)
@@ -938,6 +948,9 @@ class MklConvCustomBackpropFilterOp
       if (diff_filter_reorder_required) {
         diff_filter.InsertReorderToUserMem();
       }
+
+      // delete primitive since it is not cached.
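+      // (Nothing else owns the primitive when do_not_cache is true, since it
+      // was never placed in the factory pool, so it must be freed here. The
+      // gating check presumably just reads the environment variable named
+      // above; a rough sketch, assuming ReadBoolFromEnvVar from
+      // util/env_var.h and a default of false -- the real implementation
+      // lives in MklPrimitiveFactory and may differ:
+      //   bool do_not_cache = false;
+      //   TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE",
+      //                                  /*default_val=*/false,
+      //                                  &do_not_cache));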
+      if (do_not_cache) delete conv_bwd_filter;
     } catch (mkldnn::error& e) {
       string error_msg = "Status: " + std::to_string(e.status) +
                          ", message: " + string(e.message) + ", in file " +
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index b5a98301e2..c38c9cc27c 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -174,7 +174,6 @@ class MklConvBwdInputPrimitive : public MklPrimitive {
     }
   };
 
-
   void Setup(const MklConvBwdInputParams& convBwdInputDims) {
     // create memory descriptors for convolution data w/ no specified format
     context_.diff_src_md.reset(new memory::desc(
@@ -242,19 +241,23 @@ class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
  public:
   static MklConvBwdInputPrimitive<T>* Get(
-      const MklConvBwdInputParams& convBwdInputDims) {
+      const MklConvBwdInputParams& convBwdInputDims, bool do_not_cache) {
     MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
 
-    // look into the pool for reusable primitive
-    conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
-        MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
-            convBwdInputDims));
-
-    if (conv_bwd_input == nullptr) {
+    if (do_not_cache) { /* Always allocate primitive */
       conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
-      MklConvBwdInputPrimitiveFactory<T>::GetInstance().SetConvBwdInput(
-          convBwdInputDims, conv_bwd_input);
+    } else {
+      // look into the pool for reusable primitive
+      conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
+          MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
+              convBwdInputDims));
+      if (conv_bwd_input == nullptr) {
+        conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
+        MklConvBwdInputPrimitiveFactory<T>::GetInstance().SetConvBwdInput(
+            convBwdInputDims, conv_bwd_input);
+      }
     }
+
     return conv_bwd_input;
   }
 
@@ -708,8 +711,18 @@ class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
       MklConvBwdInputParams convBwdInputDims(fwd_src_dims, fwd_filter_dims,
          diff_dst_dims, strides, dilations, padding_left, padding_right,
          TFPaddingToMklDnnPadding(this->padding_));
-      conv_bwd_input =
-          MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims);
+
+      // We don't cache these primitives if the env variable
+      // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true and if the primitive
+      // descriptor includes potentially large buffers. MKL DNN allocates
+      // buffers in the following cases:
+      //   1. Legacy CPU without AVX512/AVX2, or
+      //   2. 1x1 convolution with stride != 1
+      bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
+                          (MklPrimitiveFactory<T>::IsLegacyPlatform() ||
+                           IsConv1x1StrideNot1(fwd_filter_dims, strides));
+      conv_bwd_input = MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims,
+                                                               do_not_cache);
       auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc();
 
       // allocate output tensor
@@ -755,6 +768,11 @@ class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
       // execute convolution input bwd
       conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+
+      // delete primitive since it is not cached.
+      if (do_not_cache) {
+        delete conv_bwd_input;
+      }
     } catch (mkldnn::error& e) {
       string error_msg = "Status: " + std::to_string(e.status) +
                          ", message: " + string(e.message) + ", in file " +
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index c6295c7280..184e0cb003 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -271,18 +271,23 @@ class MklConvFwdPrimitive : public MklPrimitive {
 template <typename T>
 class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
  public:
-  static MklConvFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
+  static MklConvFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims,
+                                     bool do_not_cache) {
     MklConvFwdPrimitive<T>* conv_fwd = nullptr;
 
-    // try to find a suitable one in pool
-    conv_fwd = dynamic_cast<MklConvFwdPrimitive<T>*>(
-        MklConvFwdPrimitiveFactory<T>::GetInstance().GetConvFwd(convFwdDims));
-
-    if (conv_fwd == nullptr) {
+    if (do_not_cache) { /* Always create new primitive */
       conv_fwd = new MklConvFwdPrimitive<T>(convFwdDims);
-      MklConvFwdPrimitiveFactory<T>::GetInstance().SetConvFwd(convFwdDims,
-                                                              conv_fwd);
+    } else {
+      // try to find a suitable one in pool
+      conv_fwd = dynamic_cast<MklConvFwdPrimitive<T>*>(
+          MklConvFwdPrimitiveFactory<T>::GetInstance().GetConvFwd(convFwdDims));
+      if (conv_fwd == nullptr) {
+        conv_fwd = new MklConvFwdPrimitive<T>(convFwdDims);
+        MklConvFwdPrimitiveFactory<T>::GetInstance().SetConvFwd(convFwdDims,
+                                                                conv_fwd);
+      }
     }
+
     return conv_fwd;
   }
 
@@ -894,6 +899,17 @@ class MklConvOp : public OpKernel {
       // MKLDNN dilation starts from 0.
       for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
 
+      // In some cases, the primitive descriptor includes potentially large
+      // buffers; we don't cache those primitives if the env variable
+      // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true. MKL DNN allocates buffers
+      // in the following cases:
+      //   1. Legacy CPU without AVX512/AVX2, or
+      //   2. 1x1 convolution with stride != 1
+      bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
+                          (src_dims[MklDnnDims::Dim_N] > kSmallBatchSize) &&
+                          (MklPrimitiveFactory<T>::IsLegacyPlatform() ||
+                           IsConv1x1StrideNot1(filter_dims, strides));
+
       // get a conv2d fwd from primitive pool
       MklConvFwdPrimitive<T>* conv_fwd = nullptr;
       if (biasEnabled) {
@@ -902,12 +918,14 @@ class MklConvOp : public OpKernel {
         MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
                                      dst_dims_mkl_order, strides, dilations,
                                      padding_left, padding_right);
-        conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(convFwdDims);
+        conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(
+            convFwdDims, do_not_cache);
       } else {
         MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS,
                                      dst_dims_mkl_order, strides, dilations,
                                      padding_left, padding_right);
-        conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(convFwdDims);
+        conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(
+            convFwdDims, do_not_cache);
       }
 
       // allocate output tensors output_tensor and filter_out_tensor
@@ -952,6 +970,9 @@ class MklConvOp : public OpKernel {
       } else {
         conv_fwd->Execute(src_data, filter_data, dst_data);
       }
+
+      // delete primitive since it is not cached.
+ if (do_not_cache) delete conv_fwd; } catch (mkldnn::error &e) { string error_msg = tensorflow::strings::StrCat( "Status: ", e.status, ", message: ", string(e.message), ", in file ", @@ -1062,7 +1083,7 @@ class MklConvOp : public OpKernel { #endif // Register 2D operations -#define REGISTER_MKL_CPU(T) \ +#define REGISTER_MKL_CPU_2D(T) \ REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \ .Device(DEVICE_CPU) \ .TypeConstraint<T>("T") \ @@ -1079,16 +1100,16 @@ class MklConvOp : public OpKernel { .Label(mkl_op_registry::kMklOpLabel), \ MklDummyOp<CPUDevice, T>); -TF_CALL_float(REGISTER_MKL_CPU); +TF_CALL_float(REGISTER_MKL_CPU_2D); // Register 3D operations -#define REGISTER_MKL_CPU(T) \ +#define REGISTER_MKL_CPU_3D(T) \ REGISTER_KERNEL_BUILDER(Name("_MklConv3D") \ .Device(DEVICE_CPU) \ .TypeConstraint<T>("T") \ .Label(mkl_op_registry::kMklOpLabel), \ MklConvOp<CPUDevice, T, false>); -TF_CALL_float(REGISTER_MKL_CPU); +TF_CALL_float(REGISTER_MKL_CPU_3D); } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_conv_ops_test.cc b/tensorflow/core/kernels/mkl_conv_ops_test.cc new file mode 100644 index 0000000000..a055351337 --- /dev/null +++ b/tensorflow/core/kernels/mkl_conv_ops_test.cc @@ -0,0 +1,407 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/ops/const_op.h" +#include "tensorflow/cc/ops/nn_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/public/session.h" + +#if defined(INTEL_MKL_DNN_ONLY) +#include "third_party/intel_mkl_dnn/include/mkldnn.h" +#include "tensorflow/core/util/mkl_util.h" +#endif + +// TODO(ezhulenev): Add numerical tests that will compare results of default +// (aka Eigen) convolutions with MKL convolutions. + +// -------------------------------------------------------------------------- // +// Performance Benchmarks. // +// -------------------------------------------------------------------------- // + +// Compare performance of default Tensorflow convolution kernels (Eigen) with +// MKL kernels on CPU. 
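+//
+// A typical invocation might look like the following (the exact Bazel target
+// name and flags are assumptions, not part of this change):
+//   bazel run -c opt --config=mkl \
+//     //tensorflow/core/kernels:mkl_conv_ops_test -- --benchmarks=all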
+ +// Before running these benchmarks configure OpenMP environment variables: +// export KMP_BLOCKTIME=0 +// export OMP_NUM_THREADS=${num_threads} + +namespace tensorflow { + +struct Conv2DDimensions { + Conv2DDimensions(int n, int h, int w, int c, int fc, int fh, int fw) + : input_batches(n), + input_height(h), + input_width(w), + input_depth(c), + filter_count(fc), + filter_height(fh), + filter_width(fw) {} + + int input_batches; + int input_height; + int input_width; + int input_depth; + int filter_count; + int filter_height; + int filter_width; +}; + +static Tensor GetRandomTensor(const TensorShape& shape) { + Tensor tensor(DT_FLOAT, TensorShape(shape)); + tensor.flat<float>() = tensor.flat<float>().setRandom(); + return tensor; +} + +// Get a random Tensor for the Conv2D input. +static Tensor GetRandomInputTensor(const Conv2DDimensions& dims) { + return GetRandomTensor({dims.input_batches, dims.input_height, + dims.input_width, dims.input_depth}); +} + +// Get a random Tensor for the Conv2D filter. +static Tensor GetRandomFilterTensor(const Conv2DDimensions& dims) { + return GetRandomTensor({dims.filter_height, dims.filter_width, + dims.input_depth, dims.filter_count}); +} + +// Get a random Tensor for the Conv2D output (assuming SAME padding). +static Tensor GetRandomOutputTensor(const Conv2DDimensions& dims) { + return GetRandomTensor({dims.input_batches, dims.input_height, + dims.input_width, dims.filter_count}); +} + +// Get a Tensor encoding Conv2D input shape. +static Tensor GetInputSizesTensor(const Conv2DDimensions& dims) { + return test::AsTensor<int32>({dims.input_batches, dims.input_height, + dims.input_width, dims.input_depth}); +} + +// Get a Tensor encoding Conv2D filter shape. +static Tensor GetFilterSizesTensor(const Conv2DDimensions& dims) { + return test::AsTensor<int32>({dims.filter_height, dims.filter_width, + dims.input_depth, dims.filter_count}); +} + +#if defined(INTEL_MKL_DNN_ONLY) +static Tensor NonMklTensor() { + MklDnnShape non_mkl_shape; + non_mkl_shape.SetMklTensor(false); + + auto size = static_cast<int64>(non_mkl_shape.GetSerializeBufferSize()); + Tensor tensor(DT_UINT8, {size}); + + non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(), + size * sizeof(uint8)); + return tensor; +} +#endif + +static Graph* DefaultConv2D(const Conv2DDimensions& dims) { + auto* graph = new Graph(OpRegistry::Global()); + + Tensor input_t = GetRandomInputTensor(dims); + Tensor filter_t = GetRandomFilterTensor(dims); + + Node* input = test::graph::Constant(graph, input_t, "input"); + Node* filter = test::graph::Constant(graph, filter_t, "filter"); + + Node* conv2d; + TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d"), "Conv2D") + .Input(input) + .Input(filter) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, 1, 1, 1}) + .Attr("padding", "SAME") + .Finalize(graph, &conv2d)); + + return graph; +} + +#if defined(INTEL_MKL_DNN_ONLY) +static Graph* MklConv2D(const Conv2DDimensions& dims) { + auto* graph = new Graph(OpRegistry::Global()); + + Tensor input_t = GetRandomInputTensor(dims); + Tensor filter_t = GetRandomFilterTensor(dims); + + Node* input = test::graph::Constant(graph, input_t, "input"); + Node* filter = test::graph::Constant(graph, filter_t, "filter"); + + Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl"); + + Node* conv2d; + TF_CHECK_OK(NodeBuilder(graph->NewName("mkl_conv_2d"), "_MklConv2D") + .Input(input) + .Input(filter) + .Input(not_mkl_shape) + .Input(not_mkl_shape) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, 1, 1, 1}) + 
.Attr("padding", "SAME") + .Attr("_kernel", "MklOp") + .Finalize(graph, &conv2d)); + + return graph; +} +#endif + +static Graph* DefaultConv2DBwdInput(const Conv2DDimensions& dims) { + auto* graph = new Graph(OpRegistry::Global()); + + Tensor input_sizes_t = GetInputSizesTensor(dims); + Tensor filter_t = GetRandomFilterTensor(dims); + Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding + + Node* input_sizes = + test::graph::Constant(graph, input_sizes_t, "input_sizes"); + Node* filter = test::graph::Constant(graph, filter_t, "filter"); + Node* out_backprop = + test::graph::Constant(graph, out_backprop_t, "out_backprop"); + + Node* conv2d_bwd_input; + TF_CHECK_OK( + NodeBuilder(graph->NewName("conv_2d_bwd_input"), "Conv2DBackpropInput") + .Input(input_sizes) + .Input(filter) + .Input(out_backprop) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, 1, 1, 1}) + .Attr("padding", "SAME") + .Finalize(graph, &conv2d_bwd_input)); + + return graph; +} + +#if defined(INTEL_MKL_DNN_ONLY) +static Graph* MklConv2DBwdInput(const Conv2DDimensions& dims) { + auto* graph = new Graph(OpRegistry::Global()); + + Tensor input_sizes_t = GetInputSizesTensor(dims); + Tensor filter_t = GetRandomFilterTensor(dims); + Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding + + Node* input_sizes = + test::graph::Constant(graph, input_sizes_t, "input_sizes"); + Node* filter = test::graph::Constant(graph, filter_t, "filter"); + Node* out_backprop = + test::graph::Constant(graph, out_backprop_t, "out_backprop"); + + Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl"); + + Node* conv2d_bwd_input; + TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_input"), + "_MklConv2DBackpropInput") + .Input(input_sizes) + .Input(filter) + .Input(out_backprop) + .Input(not_mkl_shape) + .Input(not_mkl_shape) + .Input(not_mkl_shape) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, 1, 1, 1}) + .Attr("padding", "SAME") + .Attr("_kernel", "MklOp") + .Finalize(graph, &conv2d_bwd_input)); + + return graph; +} +#endif + +static Graph* DefaultConv2DBwdFilter(const Conv2DDimensions& dims) { + auto* graph = new Graph(OpRegistry::Global()); + + Tensor input_t = GetRandomInputTensor(dims); + Tensor filter_sizes_t = GetFilterSizesTensor(dims); + Tensor filter_t = GetRandomFilterTensor(dims); + Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding + + Node* input = test::graph::Constant(graph, input_t, "input"); + Node* filter_sizes = + test::graph::Constant(graph, filter_sizes_t, "filter_sizes"); + Node* out_backprop = + test::graph::Constant(graph, out_backprop_t, "out_backprop"); + + Node* conv2d_bwd_filter; + TF_CHECK_OK( + NodeBuilder(graph->NewName("conv_2d_bwd_filter"), "Conv2DBackpropFilter") + .Input(input) + .Input(filter_sizes) + .Input(out_backprop) + .Attr("T", DT_FLOAT) + .Attr("strides", {1, 1, 1, 1}) + .Attr("padding", "SAME") + .Finalize(graph, &conv2d_bwd_filter)); + + return graph; +} + +#if defined(INTEL_MKL_DNN_ONLY) +static Graph* MklConv2DBwdFilter(const Conv2DDimensions& dims) { + Graph* graph = new Graph(OpRegistry::Global()); + + Tensor input_t = GetRandomInputTensor(dims); + Tensor filter_sizes_t = GetFilterSizesTensor(dims); + Tensor filter_t = GetRandomFilterTensor(dims); + Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding + + Node* input = test::graph::Constant(graph, input_t, "input"); + Node* filter_sizes = + test::graph::Constant(graph, filter_sizes_t, "filter_sizes"); + Node* out_backprop = + 
          test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+  Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+  Node* conv2d_bwd_filter;
+  TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_filter"),
+                          "_MklConv2DBackpropFilter")
+                  .Input(input)
+                  .Input(filter_sizes)
+                  .Input(out_backprop)
+                  .Input(not_mkl_shape)
+                  .Input(not_mkl_shape)
+                  .Input(not_mkl_shape)
+                  .Attr("T", DT_FLOAT)
+                  .Attr("strides", {1, 1, 1, 1})
+                  .Attr("padding", "SAME")
+                  .Attr("_kernel", "MklOp")
+                  .Finalize(graph, &conv2d_bwd_filter));
+
+  return graph;
+}
+#endif
+
+// Macro argument names: ----------------------------------------------------- //
+//    N: batch size
+//    H: height
+//    W: width
+//    C: channels
+//   FC: filter count
+//   FH: filter height
+//   FW: filter width
+
+#define BM_CONCAT(a, b) a##b
+
+#define BM_NAME(p, type, N, H, W, C, FC, FH, FW) \
+  BM_CONCAT(BM_##p##_##type##_in_##N##_##H##_##W##_##C, _f_##FC##_##FH##_##FW)
+
+// Flops computation in these benchmarks is the same as in
+// eigen_benchmark_cpu_test.cc.
+
+#define BM_Conv2DT(kind, N, H, W, C, FC, FH, FW, type, LABEL)            \
+  static void BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH,           \
+                      FW)(int iters) {                                   \
+    testing::SetLabel(LABEL);                                            \
+                                                                         \
+    int64 num_computed_elements = (N) * (H) * (W) * (FC);                \
+    int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW));  \
+    testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+                                                                         \
+    Conv2DDimensions dims(N, H, W, C, FC, FH, FW);                       \
+    test::Benchmark(#type, BM_CONCAT(kind, Conv2D)(dims)).Run(iters);    \
+  }                                                                      \
+  BENCHMARK(BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL)      \
+  BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+  BM_Conv2DT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL)      \
+  BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdInputT(kind, N, H, W, C, FC, FH, FW, type, LABEL)         \
+  static void BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH,        \
+                      FW)(int iters) {                                        \
+    testing::SetLabel(LABEL);                                                 \
+                                                                              \
+    int64 num_computed_elements = (N) * (H) * (W) * (C);                      \
+    int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW));       \
+    testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter);      \
+                                                                              \
+    Conv2DDimensions dims(N, H, W, C, FC, FH, FW);                            \
+    test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdInput)(dims)).Run(iters); \
+  }                                                                           \
+  BENCHMARK(BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL)      \
+  BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+  BM_Conv2DBwdInputT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL)      \
+  BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdFilterT(kind, N, H, W, C, FC, FH, FW, type, LABEL)         \
+  static void BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH,        \
+                      FW)(int iters) {                                         \
+    testing::SetLabel(LABEL);                                                  \
+                                                                               \
+    int64 num_computed_elements = (FH) * (FW) * (C) * (FC);                    \
+    int64 flops_per_iter = num_computed_elements * ((N) * (H) * (W));          \
+    testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter);       \
+                                                                               \
+    Conv2DDimensions dims(N, H, W, C, FC, FH, FW);                             \
+    test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdFilter)(dims)).Run(iters); \
+  }                                                                            \
BENCHMARK(BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, FW)) + +#if defined(INTEL_MKL_DNN_ONLY) +#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \ + BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \ + BM_Conv2DBwdFilterT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL); +#else +#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \ + BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL); +#endif + +// ImageNet Convolutions ---------------------------------------------------- // + +BM_Conv2D(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3"); +BM_Conv2D(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5"); +BM_Conv2D(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3"); +BM_Conv2D(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5"); +BM_Conv2D(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3"); +BM_Conv2D(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5"); +BM_Conv2D(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3"); + +BM_Conv2DBwdInput(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3"); +BM_Conv2DBwdInput(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5"); +BM_Conv2DBwdInput(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3"); +BM_Conv2DBwdInput(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5"); +BM_Conv2DBwdInput(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3"); +BM_Conv2DBwdInput(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5"); +BM_Conv2DBwdInput(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3"); + +BM_Conv2DBwdFilter(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3"); +BM_Conv2DBwdFilter(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5"); +BM_Conv2DBwdFilter(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3"); +BM_Conv2DBwdFilter(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5"); +BM_Conv2DBwdFilter(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3"); +BM_Conv2DBwdFilter(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5"); +BM_Conv2DBwdFilter(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3"); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc index ec6d241e17..5398e6113f 100644 --- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc +++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc @@ -34,11 +34,11 @@ using mkldnn::prop_kind; template <typename T> void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) { - if (fwdParams.alg_kind != pooling_max && fwdParams.alg_kind != pooling_avg && - fwdParams.alg_kind != pooling_avg_include_padding && - fwdParams.alg_kind != pooling_avg_exclude_padding) { - assert("Pooling algorithm kind is not supported\n"); - } + DCHECK(fwdParams.alg_kind == pooling_max || + fwdParams.alg_kind == pooling_avg || + fwdParams.alg_kind == pooling_avg_include_padding || + fwdParams.alg_kind == pooling_avg_exclude_padding) + << "Pooling algorithm kind is not supported"; context_.alg_kind = fwdParams.alg_kind; // create memory desc @@ -102,7 +102,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data, static_cast<void*>(const_cast<T*>(src_data))); context_.dst_mem->set_data_handle(static_cast<void*>(dst_data)); if (context_.alg_kind == pooling_max) { // max pooling must have ws - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(ws_data); } context_.fwd_stream->submit(context_.fwd_primitives); @@ -111,7 +111,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data, context_.src_mem->set_data_handle(DummyData); context_.dst_mem->set_data_handle(DummyData); if 
(context_.alg_kind == pooling_max) { // max pooling must have ws - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(DummyData); } } @@ -120,11 +120,11 @@ template class MklPoolingFwdPrimitive<float>; template <typename T> void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) { - if (bwdParams.alg_kind != pooling_max && bwdParams.alg_kind != pooling_avg && - bwdParams.alg_kind != pooling_avg_include_padding && - bwdParams.alg_kind != pooling_avg_exclude_padding) { - assert("Pooling algorithm kind is not supported\n"); - } + DCHECK(bwdParams.alg_kind == pooling_max || + bwdParams.alg_kind == pooling_avg || + bwdParams.alg_kind == pooling_avg_include_padding || + bwdParams.alg_kind == pooling_avg_exclude_padding) + << "Pooling algorithm kind is not supported"; context_.alg_kind = bwdParams.alg_kind; // check whether it is 2d or 3d @@ -190,7 +190,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data, static_cast<void*>(const_cast<T*>(diff_dst_data))); context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data)); if (context_.alg_kind == pooling_max) { - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(const_cast<void*>(ws_data)); } @@ -199,7 +199,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data, context_.diff_dst_mem->set_data_handle(DummyData); context_.diff_src_mem->set_data_handle(DummyData); if (context_.alg_kind == pooling_max) { - assert(ws_data != nullptr); + DCHECK(ws_data != nullptr); context_.ws_mem->set_data_handle(DummyData); } } diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc index f4cfc48af5..84385356e1 100644 --- a/tensorflow/core/kernels/mkl_relu_op.cc +++ b/tensorflow/core/kernels/mkl_relu_op.cc @@ -40,7 +40,6 @@ using mkldnn::memory; #include "mkl_dnn.h" #include "mkl_dnn_types.h" #endif -#include "tensorflow/core/platform/default/logging.h" #include "tensorflow/core/util/mkl_util.h" namespace tensorflow { diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc index 8bde966be9..cfab529662 100644 --- a/tensorflow/core/kernels/mkl_softmax_op.cc +++ b/tensorflow/core/kernels/mkl_softmax_op.cc @@ -50,6 +50,7 @@ class MklSoftmaxOp : public OpKernel { // src_tensor now points to the 0-th input of global data struct "context" size_t src_idx = 0; const Tensor& src_tensor = MklGetInput(context, src_idx); + const int input_dims = src_tensor.dims(); // Add: get MklShape MklDnnShape src_mkl_shape; @@ -62,7 +63,33 @@ class MklSoftmaxOp : public OpKernel { : src_tensor.shape(); auto src_dims = TFShapeToMklDnnDims(src_tf_shape); auto output_dims = src_dims; - + memory::format layout_type; + // In MKL, data format passed to mkl softmax op depends on dimension of the input tensor. + // Here "x" data format in MKL is used for 1 dim tensor, "nc" for 2 dim tensor, + // "tnc" for 3 dim tensor, "nchw" for 4 dim tensor, and "ncdhw" for 5 dim tensor. 
+  // Each of the symbols has the following meaning:
+  // n = batch, c = channels, t = sequence length, h = height,
+  // w = width, d = depth
+  switch (input_dims) {
+    case 1:
+      layout_type = memory::format::x;
+      break;
+    case 2:
+      layout_type = memory::format::nc;
+      break;
+    case 3:
+      layout_type = memory::format::tnc;
+      break;
+    case 4:
+      layout_type = memory::format::nchw;
+      break;
+    case 5:
+      layout_type = memory::format::ncdhw;
+      break;
+    default:
+      OP_REQUIRES_OK(context,
+                     errors::Aborted("Input dims must be <= 5 and >= 1"));
+      return;
+  }
     // Create softmax memory for src, dst: both are defined in mkl_util.h,
     // they are wrapper
     MklDnnData<T> src(&cpu_engine);
@@ -75,7 +102,7 @@ class MklSoftmaxOp : public OpKernel {
     auto src_md =
         src_mkl_shape.IsMklTensor()
             ? src_mkl_shape.GetMklLayout()
-            : memory::desc(src_dims, MklDnnType<T>(), memory::format::nc);
+            : memory::desc(src_dims, MklDnnType<T>(), layout_type);
 
     // src: setting memory descriptor and op memory descriptor
     // Basically following two functions maps the TF "src_tensor" to mkl
@@ -84,10 +111,11 @@ class MklSoftmaxOp : public OpKernel {
     // data format is "nc" for src and dst; since the src and dst buffer is
     // always in 2D shape
     src.SetUsrMem(src_md, &src_tensor);
-    src.SetOpMemDesc(src_dims, memory::format::nc);
+    src.SetOpMemDesc(src_dims, layout_type);
 
     // creating a memory descriptor
-    int axis = 1;  // axis to which softmax will be applied
+    // use the innermost (i.e. last) dim as the default axis along which
+    // softmax is applied
+    int axis = input_dims - 1;
     auto softmax_fwd_desc = softmax_forward::desc(prop_kind::forward_scoring,
                                                   src.GetOpMemDesc(), axis);
     auto softmax_fwd_pd =
@@ -107,7 +135,7 @@ class MklSoftmaxOp : public OpKernel {
       output_mkl_shape.SetMklLayout(&dst_pd);
       output_mkl_shape.SetElemType(MklDnnType<T>());
       output_mkl_shape.SetTfLayout(output_dims.size(), output_dims,
-                                   memory::format::nc);
+                                   layout_type);
       output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T)));
     } else {  // then output is also TF shape
       output_mkl_shape.SetMklTensor(false);
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index 5d9257e20b..81ce6d6e95 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -75,28 +75,28 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
 }
 
 // Return intersection-over-union overlap between boxes i and j
-static inline float IOUGreaterThanThreshold(
-    typename TTypes<float, 2>::ConstTensor boxes, int i, int j,
-    float iou_threshold) {
-  const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2));
-  const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3));
-  const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2));
-  const float xmax_i = std::max<float>(boxes(i, 1), boxes(i, 3));
-  const float ymin_j = std::min<float>(boxes(j, 0), boxes(j, 2));
-  const float xmin_j = std::min<float>(boxes(j, 1), boxes(j, 3));
-  const float ymax_j = std::max<float>(boxes(j, 0), boxes(j, 2));
-  const float xmax_j = std::max<float>(boxes(j, 1), boxes(j, 3));
-  const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
-  const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
-  if (area_i <= 0 || area_j <= 0) return 0.0;
-  const float intersection_ymin = std::max<float>(ymin_i, ymin_j);
-  const float intersection_xmin = std::max<float>(xmin_i, xmin_j);
-  const float intersection_ymax = std::min<float>(ymax_i, ymax_j);
-  const float intersection_xmax = std::min<float>(xmax_i, xmax_j);
-
const float intersection_area = - std::max<float>(intersection_ymax - intersection_ymin, 0.0) * - std::max<float>(intersection_xmax - intersection_xmin, 0.0); - const float iou = intersection_area / (area_i + area_j - intersection_area); +template <typename T> +static inline bool IOUGreaterThanThreshold( + typename TTypes<T, 2>::ConstTensor boxes, int i, int j, T iou_threshold) { + const T ymin_i = std::min<T>(boxes(i, 0), boxes(i, 2)); + const T xmin_i = std::min<T>(boxes(i, 1), boxes(i, 3)); + const T ymax_i = std::max<T>(boxes(i, 0), boxes(i, 2)); + const T xmax_i = std::max<T>(boxes(i, 1), boxes(i, 3)); + const T ymin_j = std::min<T>(boxes(j, 0), boxes(j, 2)); + const T xmin_j = std::min<T>(boxes(j, 1), boxes(j, 3)); + const T ymax_j = std::max<T>(boxes(j, 0), boxes(j, 2)); + const T xmax_j = std::max<T>(boxes(j, 1), boxes(j, 3)); + const T area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i); + const T area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j); + if (area_i <= static_cast<T>(0) || area_j <= static_cast<T>(0)) return 0; + const T intersection_ymin = std::max<T>(ymin_i, ymin_j); + const T intersection_xmin = std::max<T>(xmin_i, xmin_j); + const T intersection_ymax = std::min<T>(ymax_i, ymax_j); + const T intersection_xmax = std::min<T>(xmax_i, xmax_j); + const T intersection_area = + std::max<T>(intersection_ymax - intersection_ymin, static_cast<T>(0.0)) * + std::max<T>(intersection_xmax - intersection_xmin, static_cast<T>(0.0)); + const T iou = intersection_area / (area_i + area_j - intersection_area); return iou > iou_threshold; } @@ -106,11 +106,13 @@ static inline bool OverlapsGreaterThanThreshold( return overlaps(i, j) > overlap_threshold; } +template <typename T> static inline std::function<bool(int, int)> CreateIOUSuppressCheckFn( const Tensor& boxes, float threshold) { - typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>(); - return std::bind(&IOUGreaterThanThreshold, boxes_data, std::placeholders::_1, - std::placeholders::_2, threshold); + typename TTypes<T, 2>::ConstTensor boxes_data = boxes.tensor<T, 2>(); + return std::bind(&IOUGreaterThanThreshold<T>, boxes_data, + std::placeholders::_1, std::placeholders::_2, + static_cast<T>(threshold)); } static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn( @@ -121,6 +123,7 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn( std::placeholders::_1, std::placeholders::_2, threshold); } +template <typename T> void DoNonMaxSuppressionOp( OpKernelContext* context, const Tensor& scores, int num_boxes, const Tensor& max_output_size, const float score_threshold, @@ -128,13 +131,13 @@ void DoNonMaxSuppressionOp( bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) { const int output_size = max_output_size.scalar<int>()(); - std::vector<float> scores_data(num_boxes); - std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin()); + std::vector<T> scores_data(num_boxes); + std::copy_n(scores.flat<T>().data(), num_boxes, scores_data.begin()); // Data structure for selection candidate in NMS. 
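+  // (A Candidate pairs a box index with its score; the priority queue built
+  // from them below pops candidates in order of decreasing score, which is
+  // the order NMS needs for greedy suppression.)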
struct Candidate { int box_index; - float score; + T score; }; auto cmp = [](const Candidate bs_i, const Candidate bs_j) { @@ -143,13 +146,13 @@ void DoNonMaxSuppressionOp( std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)> candidate_priority_queue(cmp); for (int i = 0; i < scores_data.size(); ++i) { - if (scores_data[i] > score_threshold) { + if (static_cast<float>(scores_data[i]) > score_threshold) { candidate_priority_queue.emplace(Candidate({i, scores_data[i]})); } } std::vector<int> selected; - std::vector<float> selected_scores; + std::vector<T> selected_scores; Candidate next_candidate; while (selected.size() < output_size && !candidate_priority_queue.empty()) { @@ -176,7 +179,7 @@ void DoNonMaxSuppressionOp( int num_valid_outputs = selected.size(); if (pad_to_max_output_size) { selected.resize(output_size, 0); - selected_scores.resize(output_size, 0); + selected_scores.resize(output_size, static_cast<T>(0)); } if (ptr_num_valid_outputs) { *ptr_num_valid_outputs = num_valid_outputs; @@ -221,18 +224,19 @@ class NonMaxSuppressionOp : public OpKernel { if (!context->status().ok()) { return; } - auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_); + auto suppress_check_fn = + CreateIOUSuppressCheckFn<float>(boxes, iou_threshold_); const float score_threshold_val = std::numeric_limits<float>::lowest(); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size, + score_threshold_val, suppress_check_fn); } private: float iou_threshold_; }; -template <typename Device> +template <typename Device, typename T> class NonMaxSuppressionV2Op : public OpKernel { public: explicit NonMaxSuppressionV2Op(OpKernelConstruction* context) @@ -264,11 +268,12 @@ class NonMaxSuppressionV2Op : public OpKernel { if (!context->status().ok()) { return; } - auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val); + auto suppress_check_fn = + CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val); const float score_threshold_val = std::numeric_limits<float>::lowest(); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size, + score_threshold_val, suppress_check_fn); } }; @@ -325,7 +330,7 @@ class NonMaxSuppressionV3V4Base : public OpKernel { float score_threshold_val_; }; -template <typename Device> +template <typename Device, typename T> class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base { public: explicit NonMaxSuppressionV3Op(OpKernelConstruction* context) @@ -334,14 +339,14 @@ class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base { protected: void DoComputeAndPostProcess(OpKernelContext* context) override { auto suppress_check_fn = - CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_); - DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, - score_threshold_val_, suppress_check_fn); + DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_, + score_threshold_val_, suppress_check_fn); } }; -template <typename Device> +template <typename Device, typename T> class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base { public: explicit NonMaxSuppressionV4Op(OpKernelConstruction* context) @@ -353,12 +358,12 @@ class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base { protected: void 
DoComputeAndPostProcess(OpKernelContext* context) override { auto suppress_check_fn = - CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_); + CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_); int num_valid_outputs; - DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_, - score_threshold_val_, suppress_check_fn, - pad_to_max_output_size_, &num_valid_outputs); + DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_, + score_threshold_val_, suppress_check_fn, + pad_to_max_output_size_, &num_valid_outputs); // Allocate scalar output tensor for number of indices computed. Tensor* num_outputs_t = nullptr; @@ -413,22 +418,37 @@ class NonMaxSuppressionWithOverlapsOp : public OpKernel { auto suppress_check_fn = CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val); - DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size, - score_threshold_val, suppress_check_fn); + DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size, + score_threshold_val, suppress_check_fn); } }; REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU), NonMaxSuppressionOp<CPUDevice>); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU), - NonMaxSuppressionV2Op<CPUDevice>); +REGISTER_KERNEL_BUILDER( + Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_CPU), + NonMaxSuppressionV2Op<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2") + .TypeConstraint<Eigen::half>("T") + .Device(DEVICE_CPU), + NonMaxSuppressionV2Op<CPUDevice, Eigen::half>); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU), - NonMaxSuppressionV3Op<CPUDevice>); +REGISTER_KERNEL_BUILDER( + Name("NonMaxSuppressionV3").TypeConstraint<float>("T").Device(DEVICE_CPU), + NonMaxSuppressionV3Op<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3") + .TypeConstraint<Eigen::half>("T") + .Device(DEVICE_CPU), + NonMaxSuppressionV3Op<CPUDevice, Eigen::half>); -REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU), - NonMaxSuppressionV4Op<CPUDevice>); +REGISTER_KERNEL_BUILDER( + Name("NonMaxSuppressionV4").TypeConstraint<float>("T").Device(DEVICE_CPU), + NonMaxSuppressionV4Op<CPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4") + .TypeConstraint<Eigen::half>("T") + .Device(DEVICE_CPU), + NonMaxSuppressionV4Op<CPUDevice, Eigen::half>); REGISTER_KERNEL_BUILDER( Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index 876a1704c7..fc1c9003aa 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/optimization_registry.h" #include "tensorflow/core/common_runtime/placer.h" #include "tensorflow/core/common_runtime/rendezvous_mgr.h" #include "tensorflow/core/framework/function.h" @@ -104,13 +105,6 @@ class PartitionedCallOp : public AsyncOpKernel { for (auto d : lib->device_mgr()->ListDevices()) { device_set.AddDevice(d); } - Placer placer(graph.get(), &device_set); - OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done); - - std::unordered_map<string, std::unique_ptr<Graph>> subgraphs; - OP_REQUIRES_OK_ASYNC( - ctx, PartitionHelper(device_set, std::move(graph), &subgraphs), - done); // The FunctionLibraryRuntime's library cannot be mutated from within // an OpKernel, so functions are instantiated in an overlay library. @@ -124,6 +118,47 @@ class PartitionedCallOp : public AsyncOpKernel { new FunctionLibraryDefinition(*lib->GetFunctionLibraryDefinition()); overlay_libs_.emplace(lib, overlay_lib); + GraphOptimizationPassOptions optimization_options; + // TODO(akshayka): Thread SessionOptions (if any) into this kernel, or + // make it possible to specify the relevant options via attributes. + SessionOptions session_options; + session_options.env = ctx->env(); + optimization_options.session_options = &session_options; + optimization_options.graph = &graph; + optimization_options.flib_def = overlay_lib; + optimization_options.device_set = &device_set; + OP_REQUIRES_OK_ASYNC( + ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::PRE_PLACEMENT, optimization_options), + done); + Placer placer(graph.get(), &device_set); + OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done); + OP_REQUIRES_OK_ASYNC( + ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_PLACEMENT, optimization_options), + done); + OP_REQUIRES_OK_ASYNC( + ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, + optimization_options), + done); + + std::unordered_map<string, std::unique_ptr<Graph>> subgraphs; + OP_REQUIRES_OK_ASYNC( + ctx, PartitionHelper(device_set, std::move(graph), &subgraphs), + done); + optimization_options.graph = nullptr; + optimization_options.device_set = nullptr; + optimization_options.partition_graphs = &subgraphs; + OP_REQUIRES_OK_ASYNC(ctx, + OptimizationPassRegistry::Global()->RunGrouping( + OptimizationPassRegistry::POST_PARTITIONING, + optimization_options), + done); + auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>(); for (const auto& pair : subgraphs) { // TODO(akshayka): Fail gracefully if the set of devices corresponds @@ -175,7 +210,7 @@ class PartitionedCallOp : public AsyncOpKernel { TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value)); DataType dtype = attr_value->type(); if (dtype == DT_RESOURCE) { - ResourceHandle handle = args[index].flat<ResourceHandle>()(0); + const ResourceHandle& handle = args[index].flat<ResourceHandle>()(0); node->set_assigned_device_name(handle.device()); } } diff --git a/tensorflow/core/kernels/poisson-loss.h b/tensorflow/core/kernels/poisson-loss.h new file mode 100644 index 0000000000..f91244454e --- /dev/null +++ b/tensorflow/core/kernels/poisson-loss.h @@ -0,0 +1,109 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
+
+#include <cmath>
+#include <limits>  // for std::numeric_limits, used in ComputeDualLoss
+
+#include "tensorflow/core/kernels/loss.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+class PoissonLossUpdater : public DualLossUpdater {
+ public:
+  // Update is found by a Newton algorithm (see readme.md).
+  double ComputeUpdatedDual(const int num_loss_partitions, const double label,
+                            const double example_weight,
+                            const double current_dual, const double wx,
+                            const double weighted_example_norm) const final {
+    // Newton's algorithm converges quadratically, so 10 steps are more than
+    // enough to achieve very good precision.
+    static const int newton_total_steps = 10;
+    // Initialize the Newton optimization at x such that
+    // exp(x) = label - current_dual
+    const double y_minus_a = label - current_dual;
+    double x = (y_minus_a > 0) ? log(y_minus_a) : 0;
+    for (int i = 0; i < newton_total_steps; ++i) {
+      x = NewtonStep(x, num_loss_partitions, label, wx, example_weight,
+                     weighted_example_norm, current_dual);
+    }
+    return label - exp(x);
+  }
+
+  // Dual of the Poisson loss function.
+  // https://en.wikipedia.org/wiki/Convex_conjugate
+  double ComputeDualLoss(const double current_dual, const double example_label,
+                         const double example_weight) const final {
+    // The dual of the Poisson loss function is
+    // (y-a)*(log(y-a)-1), where a is the dual variable.
+    // It is defined only for a < y.
+    const double y_minus_a = example_label - current_dual;
+    if (y_minus_a == 0.0) {
+      // (y-a)*(log(y-a)-1) approaches 0 as y-a approaches 0.
+      return 0.0;
+    }
+    if (y_minus_a < 0.0) {
+      return std::numeric_limits<double>::max();
+    }
+    return y_minus_a * (log(y_minus_a) - 1) * example_weight;
+  }
+
+  double ComputePrimalLoss(const double wx, const double example_label,
+                           const double example_weight) const final {
+    return (exp(wx) - wx * example_label) * example_weight;
+  }
+
+  double PrimalLossDerivative(const double wx, const double label,
+                              const double example_weight) const final {
+    return (exp(wx) - label) * example_weight;
+  }
+
+  // TODO(chapelle): We need to introduce a maximum_prediction parameter,
+  // expose that parameter to the user and have this method return
+  // 1.0/maximum_prediction.
+  // Setting this at 1 for now, it only impacts the adaptive sampling.
+  double SmoothnessConstant() const final { return 1; }
+
+  Status ConvertLabel(float* const example_label) const final {
+    if (*example_label < 0.0) {
+      return errors::InvalidArgument(
+          "Only non-negative labels can be used with the Poisson log loss. "
+          "Found example with label: ", *example_label);
+    }
+    return Status::OK();
+  }
+
+ private:
+  // One Newton step (see readme.md).
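+  // The step solves the dual optimality condition
+  //   new_dual = label - exp(new_wx),  where
+  //   new_wx = wx + (new_dual - current_dual) * num_loss_partitions *
+  //            weighted_example_norm * example_weight.
+  // Substituting new_dual = label - exp(x) (so that x = new_wx) turns this
+  // into the root-finding problem f(x) = 0 with
+  //   f(x)  = x - wx - num_loss_partitions * weighted_example_norm *
+  //           example_weight * (label - current_dual - exp(x)),
+  //   f'(x) = 1 + num_loss_partitions * weighted_example_norm *
+  //           example_weight * exp(x),
+  // which are the numerator and denominator below; the function returns the
+  // Newton update x - f(x)/f'(x). (This derivation is reconstructed from the
+  // code; readme.md may present it differently.)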
+  double NewtonStep(const double x, const int num_loss_partitions,
+                    const double label, const double wx,
+                    const double example_weight,
+                    const double weighted_example_norm,
+                    const double current_dual) const {
+    const double expx = exp(x);
+    const double numerator =
+        x - wx - num_loss_partitions * weighted_example_norm *
+        example_weight * (label - current_dual - expx);
+    const double denominator =
+        1 + num_loss_partitions * weighted_example_norm * example_weight * expx;
+    return x - numerator / denominator;
+  }
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
diff --git a/tensorflow/core/kernels/qr_op_complex128.cc b/tensorflow/core/kernels/qr_op_complex128.cc
index d8d589f5aa..8a3e3dc0a9 100644
--- a/tensorflow/core/kernels/qr_op_complex128.cc
+++ b/tensorflow/core/kernels/qr_op_complex128.cc
@@ -24,7 +24,13 @@ REGISTER_LINALG_OP("Qr", (QrOp<complex128>), complex128);
 // cuSolver affecting older hardware. The cuSolver team is tracking the issue
 // (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
 // this feature when a fix is available.
-// REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex128>), complex128);
+REGISTER_KERNEL_BUILDER(Name("Qr")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<complex128>("T")
+                            .HostMemory("input")
+                            .HostMemory("q")
+                            .HostMemory("r"),
+                        QrOp<complex128>);
 #endif
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_double.cc b/tensorflow/core/kernels/qr_op_double.cc
index 63f2e03b3b..05537a0eaa 100644
--- a/tensorflow/core/kernels/qr_op_double.cc
+++ b/tensorflow/core/kernels/qr_op_double.cc
@@ -24,7 +24,13 @@ REGISTER_LINALG_OP("Qr", (QrOp<double>), double);
 // cuSolver affecting older hardware. The cuSolver team is tracking the issue
 // (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
 // this feature when a fix is available.
-// REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<double>), double);
+REGISTER_KERNEL_BUILDER(Name("Qr")
+                            .Device(DEVICE_GPU)
+                            .TypeConstraint<double>("T")
+                            .HostMemory("input")
+                            .HostMemory("q")
+                            .HostMemory("r"),
+                        QrOp<double>);
 #endif
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_float.cc b/tensorflow/core/kernels/qr_op_float.cc
index 0b1a0aaa76..6aebd98186 100644
--- a/tensorflow/core/kernels/qr_op_float.cc
+++ b/tensorflow/core/kernels/qr_op_float.cc
@@ -24,7 +24,13 @@ REGISTER_LINALG_OP("Qr", (QrOp<float>), float);
 // cuSolver affecting older hardware. The cuSolver team is tracking the issue
 // (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
 // this feature when a fix is available.
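 // As a workaround, the change below registers the CPU-based QrOp for
 // DEVICE_GPU with all of its tensors pinned to host memory, so graphs placed
 // on GPU keep working while the factorization itself runs on the host.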
-// REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<float>), float); +REGISTER_KERNEL_BUILDER(Name("Qr") + .Device(DEVICE_GPU) + .TypeConstraint<float>("T") + .HostMemory("input") + .HostMemory("q") + .HostMemory("r"), + QrOp<float>); #endif } // namespace tensorflow diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc index c4d404259b..97ddc852f7 100644 --- a/tensorflow/core/kernels/queue_ops.cc +++ b/tensorflow/core/kernels/queue_ops.cc @@ -65,7 +65,7 @@ class FakeQueueOp : public OpKernel { } void Compute(OpKernelContext* context) override { - ResourceHandle ref = context->input(0).flat<ResourceHandle>()(0); + const ResourceHandle& ref = context->input(0).flat<ResourceHandle>()(0); handle_.AccessTensor(context)->flat<string>()(0) = ref.container(); handle_.AccessTensor(context)->flat<string>()(1) = ref.name(); context->set_output_ref(0, &mu_, handle_.AccessTensor(context)); diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc index 9cf953f4bf..8bfa44b2d0 100644 --- a/tensorflow/core/kernels/reduction_ops_max.cc +++ b/tensorflow/core/kernels/reduction_ops_max.cc @@ -50,6 +50,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS); .TypeConstraint<int64>("Tidx") \ .HostMemory("reduction_indices"), \ ReductionOp<GPUDevice, type, int64, Eigen::internal::MaxReducer<type>>); + +REGISTER_GPU_KERNELS(Eigen::half); REGISTER_GPU_KERNELS(float); REGISTER_GPU_KERNELS(double); REGISTER_GPU_KERNELS(int64); diff --git a/tensorflow/core/kernels/reduction_ops_sum.cc b/tensorflow/core/kernels/reduction_ops_sum.cc index 5318d8c133..e4ca89eca3 100644 --- a/tensorflow/core/kernels/reduction_ops_sum.cc +++ b/tensorflow/core/kernels/reduction_ops_sum.cc @@ -76,7 +76,15 @@ REGISTER_KERNEL_BUILDER( .HostMemory("output") .HostMemory("reduction_indices"), ReductionOp<CPUDevice, int32, int64, Eigen::internal::SumReducer<int32>>); - +REGISTER_KERNEL_BUILDER( + Name("Sum") + .Device(DEVICE_GPU) + .TypeConstraint<int64>("T") + .TypeConstraint<int32>("Tidx") + .HostMemory("input") + .HostMemory("output") + .HostMemory("reduction_indices"), + ReductionOp<CPUDevice, int64, int32, Eigen::internal::SumReducer<int64>>); #endif #ifdef TENSORFLOW_USE_SYCL diff --git a/tensorflow/core/kernels/regex_full_match_op.cc b/tensorflow/core/kernels/regex_full_match_op.cc index 5863a2c8e4..7edaaad8f7 100644 --- a/tensorflow/core/kernels/regex_full_match_op.cc +++ b/tensorflow/core/kernels/regex_full_match_op.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -56,4 +57,36 @@ class RegexFullMatchOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU), RegexFullMatchOp); +class StaticRegexFullMatchOp : public OpKernel { + public: + explicit StaticRegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string pattern; + OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern)); + re_ = MakeUnique<RE2>(pattern); + OP_REQUIRES(ctx, re_->ok(), + errors::InvalidArgument("Invalid pattern: ", pattern, + ", error: ", re_->error())); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor)); + const auto& input_flat = input_tensor->flat<string>(); + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(), + &output_tensor)); + auto output_flat = output_tensor->flat<bool>(); + for (size_t i = 0; i < input_flat.size(); ++i) { + output_flat(i) = RE2::FullMatch(input_flat(i), *re_); + } + } + + private: + std::unique_ptr<RE2> re_; +}; + +REGISTER_KERNEL_BUILDER(Name("StaticRegexFullMatch").Device(DEVICE_CPU), + StaticRegexFullMatchOp); + } // namespace tensorflow diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc index 194a711d98..26f107f940 100644 --- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc +++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc @@ -47,7 +47,7 @@ std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts( std::unordered_set<string> retval; for (const string& node_name_and_port : node_names_and_ports) { const TensorId tid = ParseTensorName(node_name_and_port); - retval.emplace(std::string(tid.first)); + retval.emplace(tid.first); } return retval; } @@ -64,7 +64,7 @@ Node* FindMutableNodeByName(const string& name, Graph* graph) { const NodeDef* FindNodeDefByName(const string& input, const GraphDef& graph_def) { const TensorId tid = ParseTensorName(input); - const string name = std::string(tid.first); + const string name = string(tid.first); for (const NodeDef& node_def : graph_def.node()) { if (node_def.name() == name) { return &node_def; @@ -423,7 +423,7 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap( std::vector<DataType> data_types; std::vector<TensorShape> shapes; const TensorId tid = ParseTensorName(name_and_port); - const string node_name = std::string(tid.first); + const string node_name(tid.first); const int port = tid.second; const NodeDef* node_def = FindNodeDefByName(node_name, graph_def); CHECK_NOTNULL(node_def); @@ -522,8 +522,7 @@ RemoteFusedGraphExecuteUtils::GetTensorShapeType( const TensorShapeMap& tensor_shape_map, const string& node_name) { if (node_name.find(':') != string::npos) { const TensorId tid = ParseTensorName(node_name); - return GetTensorShapeType(tensor_shape_map, std::string(tid.first), - tid.second); + return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second); } else { return GetTensorShapeType(tensor_shape_map, node_name, 0); } @@ -570,7 +569,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto( const TensorId tid = ParseTensorName(name); CHECK_EQ(tensor_shape_map->count(name), 0); tensor_shape_map->emplace( - std::string(tid.first), + 
string(tid.first), std::make_pair(tid.second, std::make_pair(tensor.dtype(), tensor.shape()))); } @@ -692,7 +691,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( std::vector<NodeBuilder::NodeOut> node_out_list; for (const string& input : inputs) { const TensorId tid = ParseTensorName(input); - Node* node = FindMutableNodeByName(std::string(tid.first), graph); + Node* node = FindMutableNodeByName(string(tid.first), graph); CHECK_NOTNULL(node); node_out_list.emplace_back(node, tid.second); } @@ -848,7 +847,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( for (const string& subgraph_input : std::get<1>(cluster)) { const TensorId tid = ParseTensorName(subgraph_input); - const string subgraph_input_name = std::string(tid.first); + const string subgraph_input_name(tid.first); const int subgraph_input_port = tid.second; const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def); CHECK_NOTNULL(node_def); @@ -895,7 +894,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( std::deque<const Node*> queue; for (const string& output : border_outputs) { const TensorId tid = ParseTensorName(output); - const string& output_node_name = std::string(tid.first); + const string output_node_name(tid.first); for (const Node* node : graph.nodes()) { if (output_node_name == node->name()) { queue.push_back(node); @@ -975,7 +974,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( for (int j = 0; j < border_outputs.size(); ++j) { const string& output = border_outputs.at(j); const TensorId tid = ParseTensorName(output); - const string output_name = std::string(tid.first); + const string output_name(tid.first); Node* src_node = edge->src(); if (src_node != nullptr && src_node->name() == output_name && edge->src_output() == tid.second) { @@ -995,12 +994,11 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode( // RemoteFusedGraphExecuteOpNode for (const string& output : outputs) { const TensorId output_tid = ParseTensorName(output); - const string output_name = std::string(output_tid.first); + const string output_name(output_tid.first); for (size_t i = 0; i < border_outputs.size(); ++i) { const TensorId subgraph_output_tid = ParseTensorName(border_outputs.at(i)); - const string& subgraph_output_name = - std::string(subgraph_output_tid.first); + const string subgraph_output_name(subgraph_output_tid.first); if (output_name == subgraph_output_name) { LOG(INFO) << "As graph output and subgraph output are same, " << "the graph output node is replaced by identity node"; @@ -1435,7 +1433,7 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions( GraphDef* graph_def) { const TensorId tid = ParseTensorName(input); CHECK_EQ(0, tid.second); - const string node_name = std::string(tid.first); + const string node_name(tid.first); for (NodeDef& node : *graph_def->mutable_node()) { if (node.name() != node_name) { continue; diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index ebcfb673d1..26705a8d34 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -79,7 +79,7 @@ ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) { void ReadVariableOp::Compute(OpKernelContext* ctx) { Var* variable = nullptr; - ResourceHandle handle = HandleFromInput(ctx, 0); + const ResourceHandle& handle = HandleFromInput(ctx, 0); const auto status = LookupResource(ctx, handle, &variable); 
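The ReadVariableOp change just above binds the ResourceHandle by const reference, avoiding a handle copy on every Compute call. For context, the surrounding lookup idiom is roughly the following (a sketch of the common pattern; the error message and locking details are elided):

void Compute(OpKernelContext* ctx) override {
  Var* variable = nullptr;
  const ResourceHandle& handle = HandleFromInput(ctx, 0);  // no copy made
  OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &variable));
  core::ScopedUnref unref(variable);  // LookupResource hands back a new ref
  // ... read variable->tensor() into the output ...
}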
OP_REQUIRES(ctx, status.ok(), errors::FailedPrecondition( diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc index 15a707a9c6..cded417986 100644 --- a/tensorflow/core/kernels/reverse_sequence_op.cc +++ b/tensorflow/core/kernels/reverse_sequence_op.cc @@ -64,7 +64,7 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) { OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim), errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim, "), ", "(", seq_lens.NumElements(), - " vs. ", input.dim_size(batch_dim))); + " vs. ", input.dim_size(batch_dim), ")")); for (size_t d = 0; d < seq_lens_vec.size(); ++d) { OP_REQUIRES(context, seq_lens_vec[d] >= 0, @@ -91,7 +91,7 @@ void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) { OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim), errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim, "), ", "(", seq_lens.NumElements(), - " vs. ", input.dim_size(batch_dim))); + " vs. ", input.dim_size(batch_dim), ")")); } template <> @@ -127,6 +127,7 @@ class ReverseSequenceOp : public OpKernel { auto seq_lens_t = seq_lens.vec<Tlen>(); CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_); + if (!context->status().ok()) return; const int input_dims = input.dims(); diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc index e335e38bdc..82546d581a 100644 --- a/tensorflow/core/kernels/save_restore_tensor.cc +++ b/tensorflow/core/kernels/save_restore_tensor.cc @@ -161,9 +161,12 @@ void RestoreTensor(OpKernelContext* context, // If we cannot find a cached reader we will allocate our own. std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader; - const checkpoint::TensorSliceReader* reader = - context->slice_reader_cache()->GetReader(file_pattern, open_func, - preferred_shard); + const checkpoint::TensorSliceReader* reader = nullptr; + + if (context->slice_reader_cache()) { + reader = context->slice_reader_cache()->GetReader(file_pattern, open_func, + preferred_shard); + } if (!reader) { allocated_reader.reset(new checkpoint::TensorSliceReader( file_pattern, open_func, preferred_shard)); diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc index ab4de6c815..180eb3ca34 100644 --- a/tensorflow/core/kernels/save_restore_v2_ops.cc +++ b/tensorflow/core/kernels/save_restore_v2_ops.cc @@ -220,9 +220,9 @@ class MergeV2Checkpoints : public OpKernel { context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix)); if (delete_old_dirs_) { - const string& merged_dir = std::string(io::Dirname(merged_prefix)); + const string merged_dir(io::Dirname(merged_prefix)); for (const string& input_prefix : input_prefixes) { - const string& dirname = std::string(io::Dirname(input_prefix)); + const string dirname(io::Dirname(input_prefix)); if (dirname == merged_dir) continue; Status status = env->DeleteDir(dirname); // For sharded save, only the first delete will go through and all diff --git a/tensorflow/core/kernels/sdca_internal.cc b/tensorflow/core/kernels/sdca_internal.cc index 1c071d3d41..a8e9b3261c 100644 --- a/tensorflow/core/kernels/sdca_internal.cc +++ b/tensorflow/core/kernels/sdca_internal.cc @@ -251,7 +251,7 @@ Status Examples::SampleAdaptiveProbabilities( num_weight_vectors); const double kappa = example_state_data(example_id, 0) + loss_updater->PrimalLossDerivative( - example_statistics.wx[0], label, 
example_weight); + example_statistics.wx[0], label, 1.0); probabilities_[example_id] = example_weight * sqrt(examples_[example_id].squared_norm_ + regularization.symmetric_l2() * diff --git a/tensorflow/core/kernels/sdca_ops.cc b/tensorflow/core/kernels/sdca_ops.cc index 05c835ebc4..3bd4168dc7 100644 --- a/tensorflow/core/kernels/sdca_ops.cc +++ b/tensorflow/core/kernels/sdca_ops.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/core/kernels/hinge-loss.h" #include "tensorflow/core/kernels/logistic-loss.h" #include "tensorflow/core/kernels/loss.h" +#include "tensorflow/core/kernels/poisson-loss.h" #include "tensorflow/core/kernels/sdca_internal.h" #include "tensorflow/core/kernels/smooth-hinge-loss.h" #include "tensorflow/core/kernels/squared-loss.h" @@ -75,6 +76,8 @@ struct ComputeOptions { loss_updater.reset(new HingeLossUpdater); } else if (loss_type == "smooth_hinge_loss") { loss_updater.reset(new SmoothHingeLossUpdater); + } else if (loss_type == "poisson_loss") { + loss_updater.reset(new PoissonLossUpdater); } else { OP_REQUIRES( context, false, diff --git a/tensorflow/core/kernels/shape_op_test.cc b/tensorflow/core/kernels/shape_op_test.cc index 9cd590ae61..30cb1e0a7f 100644 --- a/tensorflow/core/kernels/shape_op_test.cc +++ b/tensorflow/core/kernels/shape_op_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/abi.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -60,8 +61,7 @@ Status GetShapeFromKnownVecSize(const KnownVecSize& ks, TensorShape* s) { REGISTER_UNARY_VARIANT_DECODE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE"); -REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE", - GetShapeFromKnownVecSize); +REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, GetShapeFromKnownVecSize); static void ExpectHasError(const Status& s, StringPiece substr) { EXPECT_TRUE(str_util::StrContains(s.ToString(), substr)) @@ -94,9 +94,9 @@ TEST_F(ShapeOpTest, Simple) { Status s = session.Run({{input, variant_tensor}}, {shape_output}, &outputs); EXPECT_FALSE(s.ok()); ExpectHasError( - s, - "No unary variant shape function found for Variant type_name: " - "NO KNOWN SHAPE"); + s, strings::StrCat( + "No unary variant shape function found for Variant type_index: ", + port::MaybeAbiDemangle(MakeTypeIndex<NoKnownShape>().name()))); } { diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator.h b/tensorflow/core/kernels/sparse_conditional_accumulator.h index 11149c4d16..a4453bd7ab 100644 --- a/tensorflow/core/kernels/sparse_conditional_accumulator.h +++ b/tensorflow/core/kernels/sparse_conditional_accumulator.h @@ -50,10 +50,10 @@ class SparseConditionalAccumulator public: SparseConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape, - const string& name) + const string& name, const string& reduction_type) : TypedConditionalAccumulatorBase< std::tuple<const Tensor*, const Tensor*, const Tensor*>>( - dtype, shape, name) { + dtype, shape, name, reduction_type) { accum_idx_vec_ = nullptr; count_element_ = nullptr; accum_val_ = nullptr; diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc index 80bc1f1934..1e542a26a7 100644 --- a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc +++ 
b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc @@ -34,8 +34,8 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp { Creator GetCreator() const override { return [this](ConditionalAccumulatorBase** ret) { SparseConditionalAccumulator<Device, T>* accumulator = - new SparseConditionalAccumulator<Device, T>(dtype_, shape_, - cinfo_.name()); + new SparseConditionalAccumulator<Device, T>( + dtype_, shape_, cinfo_.name(), reduction_type_); *ret = accumulator; return Status::OK(); }; diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc index 7cc3c532c9..11db72bfa3 100644 --- a/tensorflow/core/kernels/split_op.cc +++ b/tensorflow/core/kernels/split_op.cc @@ -49,7 +49,12 @@ class SplitOpBase : public OpKernel { void ComputeEasyCases(OpKernelContext* context, bool* done) { const Tensor& input = context->input(1); const TensorShape& input_shape = input.shape(); - const int32 split_dim_orig = context->input(0).flat<int32>()(0); + const Tensor& split_dim_tensor = context->input(0); + OP_REQUIRES( + context, split_dim_tensor.shape().dims() == 0, + errors::InvalidArgument("split_dim must be a scalar but has rank ", + split_dim_tensor.shape().dims())); + const int32 split_dim_orig = split_dim_tensor.flat<int32>()(0); const int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; const int32 num_split = num_outputs(); diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc index 65296f61fd..add4afafc9 100644 --- a/tensorflow/core/kernels/stack_ops.cc +++ b/tensorflow/core/kernels/stack_ops.cc @@ -131,10 +131,8 @@ class Stack : public ResourceBase { }; Status GetStack(OpKernelContext* ctx, Stack** stack) { - string key; if (ctx->input_dtype(0) == DT_RESOURCE) { - auto resource = ctx->input(0).flat<ResourceHandle>()(0); - key = resource.name(); + return LookupResource(ctx, HandleFromInput(ctx, 0), stack); } else { Tensor Tstack_handle = ctx->mutable_input(0, false); if (Tstack_handle.NumElements() != 2) { @@ -144,18 +142,18 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) { } const string& container = Tstack_handle.flat<string>()(0); const string& stack_name = Tstack_handle.flat<string>()(1); - key = strings::StrCat(container, stack_name); - } - ResourceMgr* rm = ctx->resource_manager(); - if (rm == nullptr) { - return errors::Internal("No resource manager."); - } - auto* step_container = ctx->step_container(); - if (step_container == nullptr) { - return errors::Internal("No step container."); + string key = strings::StrCat(container, stack_name); + ResourceMgr* rm = ctx->resource_manager(); + if (rm == nullptr) { + return errors::Internal("No resource manager."); + } + auto* step_container = ctx->step_container(); + if (step_container == nullptr) { + return errors::Internal("No step container."); + } + TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack)); + return Status::OK(); } - TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack)); - return Status::OK(); } std::atomic<int64> Stack::stack_counter{0}; diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc index 2aeafa28c4..544dca96ba 100644 --- a/tensorflow/core/kernels/string_strip_op.cc +++ b/tensorflow/core/kernels/string_strip_op.cc @@ -43,7 +43,7 @@ class StringStripOp : public OpKernel { for (int64 i = 0; i < input.size(); ++i) { StringPiece entry(input(i)); str_util::RemoveWhitespaceContext(&entry); - output(i) = std::string(entry); + 
output(i) = string(entry); } } }; diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index 22e45918a0..07f1d6e767 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include <cstddef> +#include <cstdlib> #include <string> #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" @@ -25,6 +27,8 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" #include "tensorflow/core/util/bcast.h" namespace tensorflow { @@ -64,26 +68,28 @@ class SubstrOp : public OpKernel { const T len = tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()()); for (size_t i = 0; i < input_tensor.NumElements(); ++i) { - string in = input(i); + StringPiece in(input(i)); OP_REQUIRES( - context, FastBoundsCheck(pos, in.size() + 1), + context, FastBoundsCheck(std::abs(pos), in.size() + 1), errors::InvalidArgument("pos ", pos, " out of range for string", "b'", in, "' at index ", i)); - output(i) = in.substr(pos, len); + StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + output(i).assign(sub_in.data(), sub_in.size()); } } else { // Perform Op element-wise with tensor pos/len auto pos_flat = pos_tensor.flat<T>(); auto len_flat = len_tensor.flat<T>(); for (size_t i = 0; i < input_tensor.NumElements(); ++i) { - string in = input(i); + StringPiece in(input(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); OP_REQUIRES( - context, FastBoundsCheck(pos, in.size() + 1), + context, FastBoundsCheck(std::abs(pos), in.size() + 1), errors::InvalidArgument("pos ", pos, " out of range for string", "b'", in, "' at index ", i)); - output(i) = in.substr(pos, len); + StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + output(i).assign(sub_in.data(), sub_in.size()); } } } else { @@ -142,14 +148,16 @@ class SubstrOp : public OpKernel { // Iterate through broadcasted tensors and perform substr for (int i = 0; i < output_shape.dim_size(0); ++i) { - string in = input_bcast(i); + StringPiece in(input_bcast(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); OP_REQUIRES( - context, FastBoundsCheck(pos, input_bcast(i).size() + 1), + context, + FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1), errors::InvalidArgument("pos ", pos, " out of range for string", "b'", in, "' at index ", i)); - output(i) = in.substr(pos, len); + StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + output(i).assign(sub_in.data(), sub_in.size()); } break; } @@ -192,16 +200,18 @@ class SubstrOp : public OpKernel { // Iterate through broadcasted tensors and perform substr for (int i = 0; i < output_shape.dim_size(0); ++i) { for (int j = 0; j < output_shape.dim_size(1); ++j) { - string in = input_bcast(i, j); + StringPiece in(input_bcast(i, j)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i, j)); - OP_REQUIRES(context, FastBoundsCheck(pos, in.size() + 1), - errors::InvalidArgument( - 
"pos ", pos, " out of range for ", "string b'", - in, "' at index (", i, ", ", j, ")")); - output(i, j) = in.substr(pos, len); + OP_REQUIRES( + context, FastBoundsCheck(std::abs(pos), in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index (", i, + ", ", j, ")")); + StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + output(i, j).assign(sub_in.data(), sub_in.size()); } } break; @@ -213,6 +223,16 @@ class SubstrOp : public OpKernel { } } } + + private: + // This adjusts the requested position. Note it does not perform any bound + // checks. + T AdjustedPosIndex(const T pos_requested, const StringPiece s) { + if (pos_requested < 0) { + return s.size() + pos_requested; + } + return pos_requested; + } }; #define REGISTER_SUBSTR(type) \ diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc new file mode 100644 index 0000000000..2e07050260 --- /dev/null +++ b/tensorflow/core/kernels/substr_op_test.cc @@ -0,0 +1,105 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include <string> + +#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h" +#include "tensorflow/core/framework/allocator.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/graph/testlib.h" +#include "tensorflow/core/kernels/ops_testutil.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Test data from the TensorFlow README.md. 
+const char* lines[] = { + "**TensorFlow** is an open source software library for numerical " + "computation using data flow graphs.", + "The graph nodes represent mathematical operations, while the graph edges " + "represent the multidimensional data arrays (tensors) that flow between " + "them.", + "This flexible architecture enables you to deploy computation to one or " + "more CPUs or GPUs in a desktop, server, or mobile device without " + "rewriting code.", + "TensorFlow also includes " + "[TensorBoard](https://www.tensorflow.org/guide/" + "summaries_and_tensorboard), a data visualization toolkit.", + "TensorFlow was originally developed by researchers and engineers working " + "on the Google Brain team within Google's Machine Intelligence Research " + "organization for the purposes of conducting machine learning and deep " + "neural networks research.", + "The system is general enough to be applicable in a wide variety of other " + "domains, as well.", + "TensorFlow provides stable Python API and C APIs as well as without API " + "backwards compatibility guarantee like C++, Go, Java, JavaScript and " + "Swift."}; + +Tensor GetTestTensor(int batch) { + const int sz = TF_ARRAYSIZE(lines); + Tensor t(DT_STRING, {batch}); + auto s = t.flat<string>(); + for (int i = 0; i < batch; ++i) { + s(i) = lines[i % sz]; + } + return t; +} + +Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor position(DT_INT32, TensorShape({})); + position.flat<int32>().setConstant(pos); + Tensor length(DT_INT32, TensorShape({})); + length.flat<int32>().setConstant(len); + + TF_CHECK_OK(NodeBuilder("substr_op", "Substr") + .Input(test::graph::Constant(g, input)) + .Input(test::graph::Constant(g, position)) + .Input(test::graph::Constant(g, length)) + .Finalize(g, nullptr /* node */)); + return g; +} + +void BM_Substr(int iters, int batch_size) { + testing::StopTiming(); + testing::ItemsProcessed(static_cast<int64>(iters)); + testing::UseRealTime(); + Tensor input = GetTestTensor(batch_size); + Graph* g = SetupSubstrGraph(input, 3, 30); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); +} + +BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg( + 256); + +} // end namespace tensorflow diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc index 632b65e9b6..fe93b91eb8 100644 --- a/tensorflow/core/kernels/tensor_array_ops.cc +++ b/tensorflow/core/kernels/tensor_array_ops.cc @@ -290,14 +290,14 @@ class TensorArrayGradOp : public TensorArrayCreationOp { } } else { container = "_tensor_arrays"; - auto resource = ctx->input(0).flat<ResourceHandle>()(0); + const auto& resource = ctx->input(0).flat<ResourceHandle>()(0); if (StringPiece(resource.name()).substr(0, container.size()) != container) { return errors::InvalidArgument("Wrong input container. 
", resource.name()); } tensor_array_name = - std::string(StringPiece(resource.name()).substr(container.size())); + string(StringPiece(resource.name()).substr(container.size())); } auto output_handle = tensor_array_output_handle->flat<string>(); diff --git a/tensorflow/core/kernels/typed_conditional_accumulator_base.h b/tensorflow/core/kernels/typed_conditional_accumulator_base.h index 9dedb618f9..ca341e511e 100644 --- a/tensorflow/core/kernels/typed_conditional_accumulator_base.h +++ b/tensorflow/core/kernels/typed_conditional_accumulator_base.h @@ -35,8 +35,9 @@ class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase { public: TypedConditionalAccumulatorBase(const DataType& dtype, const PartialTensorShape& shape, - const string& name) - : ConditionalAccumulatorBase(dtype, shape, name) {} + const string& name, + const string& reduction_type) + : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {} /** * Attempts to add a gradient to the accumulator. An ApplyGrad attempt is diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc index 62e814ff77..8d839ba85a 100644 --- a/tensorflow/core/kernels/unravel_index_op.cc +++ b/tensorflow/core/kernels/unravel_index_op.cc @@ -97,10 +97,12 @@ class UnravelIndexOp : public OpKernel { auto output = output_tensor->matrix<Tidx>(); - Eigen::array<int64, 2> reshape{{dims_tensor.NumElements(), 1}}; - Eigen::array<int64, 2> bcast({1, indices_tensor.NumElements()}); - Eigen::array<int64, 2> indices_reshape{{1, indices_tensor.NumElements()}}; - Eigen::array<int64, 2> indices_bcast({dims_tensor.NumElements(), 1}); + Eigen::array<Eigen::Index, 2> reshape{{dims_tensor.NumElements(), 1}}; + Eigen::array<Eigen::Index, 2> bcast({1, indices_tensor.NumElements()}); + Eigen::array<Eigen::Index, 2> indices_reshape{ + {1, indices_tensor.NumElements()}}; + Eigen::array<Eigen::Index, 2> indices_bcast( + {dims_tensor.NumElements(), 1}); output = indices_tensor.vec<Tidx>() .reshape(indices_reshape) diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc index ed2bf3e8e2..1bf46b5e46 100644 --- a/tensorflow/core/kernels/whole_file_read_ops.cc +++ b/tensorflow/core/kernels/whole_file_read_ops.cc @@ -134,7 +134,7 @@ class WriteFileOp : public OpKernel { "Contents tensor must be scalar, but had shape: ", contents_input->shape().DebugString())); const string& filename = filename_input->scalar<string>()(); - const string dir = std::string(io::Dirname(filename)); + const string dir(io::Dirname(filename)); if (!context->env()->FileExists(dir).ok()) { OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir)); } diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h index 49a8a4dbd4..d5cbe6c616 100644 --- a/tensorflow/core/lib/core/errors.h +++ b/tensorflow/core/lib/core/errors.h @@ -131,11 +131,23 @@ inline string FormatNodeNameForError(const string& name) { // LINT.ThenChange(//tensorflow/python/client/session.py) template <typename T> string FormatNodeNamesForError(const T& names) { - ::tensorflow::str_util::Formatter<string> f( - [](string* output, const string& s) { + return ::tensorflow::str_util::Join( + names, ", ", [](string* output, const string& s) { ::tensorflow::strings::StrAppend(output, FormatNodeNameForError(s)); }); - return ::tensorflow::str_util::Join(names, ", ", f); +} +// LINT.IfChange +inline string FormatColocationNodeForError(const string& name) { + return strings::StrCat("{{colocation_node ", 
name, "}}"); +} +// LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py) +template <typename T> +string FormatColocationNodeForError(const T& names) { + return ::tensorflow::str_util::Join( + names, ", ", [](string* output, const string& s) { + ::tensorflow::strings::StrAppend(output, + FormatColocationNodeForError(s)); + }); } // The CanonicalCode() for non-errors. diff --git a/tensorflow/core/lib/core/status.h b/tensorflow/core/lib/core/status.h index 49f74ff47f..eb0ff555a5 100644 --- a/tensorflow/core/lib/core/status.h +++ b/tensorflow/core/lib/core/status.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h index e7b17c9b36..6edff139ae 100644 --- a/tensorflow/core/lib/core/stringpiece.h +++ b/tensorflow/core/lib/core/stringpiece.h @@ -26,13 +26,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_ #define TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_ -#include <assert.h> -#include <stddef.h> -#include <string.h> -#include <iosfwd> -#include <string> #include "absl/strings/string_view.h" -#include "tensorflow/core/platform/types.h" namespace tensorflow { diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h index c18dc9ad1a..2d622dc229 100644 --- a/tensorflow/core/lib/gtl/inlined_vector.h +++ b/tensorflow/core/lib/gtl/inlined_vector.h @@ -13,674 +13,19 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// An InlinedVector<T,N,A> is like a std::vector<T,A>, except that storage -// for sequences of length <= N are provided inline without requiring -// any heap allocation. Typically N is very small (e.g., 4) so that -// sequences that are expected to be short do not require allocations. -// -// Only some of the std::vector<> operations are currently implemented. -// Other operations may be added as needed to facilitate migrating -// code that uses std::vector<> to InlinedVector<>. -// -// NOTE: If you want an inlined version to replace use of a -// std::vector<bool>, consider using util::bitmap::InlinedBitVector<NBITS> -// in util/bitmap/inlined_bitvector.h -// -// TODO(billydonahue): change size_t to size_type where appropriate. - #ifndef TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ #define TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_ -#include <stddef.h> -#include <stdlib.h> -#include <string.h> -#include <sys/types.h> -#include <algorithm> -#include <cstddef> -#include <iterator> -#include <memory> -#include <type_traits> -#include <vector> - -#include "tensorflow/core/lib/gtl/manual_constructor.h" -#include "tensorflow/core/platform/byte_order.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/mem.h" +#include "absl/container/inlined_vector.h" +// TODO(kramerb): This is kept only because lots of targets transitively depend +// on it. Remove all targets' dependencies. 
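After this rewrite, tensorflow::gtl::InlinedVector is only an alias for absl::InlinedVector (the using-declaration appears at the end of this file's hunk), so existing call sites keep compiling unchanged while the hand-rolled implementation below is deleted. Illustrative usage only:

#include "tensorflow/core/lib/gtl/inlined_vector.h"

tensorflow::gtl::InlinedVector<int, 4> dims;
dims.push_back(2);  // up to 4 elements live inline, with no heap allocation
dims.push_back(3);  // beyond that, absl::InlinedVector spills to the heap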
+#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" -#include <initializer_list> // NOLINT(build/include_order) - namespace tensorflow { namespace gtl { -template <typename T, int N> -class InlinedVector { - public: - typedef T value_type; - typedef T* pointer; - typedef const T* const_pointer; - typedef T& reference; - typedef const T& const_reference; - typedef size_t size_type; - typedef std::ptrdiff_t difference_type; - typedef pointer iterator; - typedef const_pointer const_iterator; - - // Create an empty vector - InlinedVector(); - - // Create a vector with n copies of value_type(). - explicit InlinedVector(size_t n); - - // Create a vector with n copies of elem - InlinedVector(size_t n, const value_type& elem); - - // Create and initialize with the elements [range_start .. range_end). - // The unused enable_if argument restricts this constructor so that it is - // elided when value_type is an integral type. This prevents ambiguous - // interpretation between a call to this constructor with two integral - // arguments and a call to the preceding (n, elem) constructor. - template <typename InputIterator> - InlinedVector( - InputIterator range_start, InputIterator range_end, - typename std::enable_if<!std::is_integral<InputIterator>::value>::type* = - NULL) { - InitRep(); - AppendRange(range_start, range_end); - } - - InlinedVector(std::initializer_list<value_type> init) { - InitRep(); - AppendRange(init.begin(), init.end()); - } - - InlinedVector(const InlinedVector& v); - - ~InlinedVector() { clear(); } - - InlinedVector& operator=(const InlinedVector& v) { - // Optimized to avoid reallocation. - // Prefer reassignment to copy construction for elements. - const size_t s = size(); - const size_t vs = v.size(); - if (s < vs) { // grow - reserve(vs); - if (s) std::copy(v.begin(), v.begin() + s, begin()); - std::copy(v.begin() + s, v.end(), std::back_inserter(*this)); - } else { // maybe shrink - erase(begin() + vs, end()); - std::copy(v.begin(), v.end(), begin()); - } - return *this; - } - - size_t size() const { return size_internal(); } - - bool empty() const { return (size() == 0); } - - // Return number of elements that can be stored in vector - // without requiring a reallocation of underlying memory - size_t capacity() const { - if (is_inline()) { - return kFit; - } else { - return static_cast<size_t>(1) << u_.data[kSize - 2]; - } - } - - // Return a pointer to the underlying array. - // Only result[0,size()-1] are defined. 
- pointer data() { - if (is_inline()) { - return reinterpret_cast<T*>(u_.data); - } else { - return outofline_pointer(); - } - } - const_pointer data() const { - return const_cast<InlinedVector<T, N>*>(this)->data(); - } - - // Remove all elements - void clear() { - DiscardStorage(); - u_.data[kSize - 1] = 0; - } - - // Return the ith element - // REQUIRES: 0 <= i < size() - const value_type& at(size_t i) const { - DCHECK_LT(i, size()); - return data()[i]; - } - const value_type& operator[](size_t i) const { - DCHECK_LT(i, size()); - return data()[i]; - } - - // Return a non-const reference to the ith element - // REQUIRES: 0 <= i < size() - value_type& at(size_t i) { - DCHECK_LT(i, size()); - return data()[i]; - } - value_type& operator[](size_t i) { - DCHECK_LT(i, size()); - return data()[i]; - } - - value_type& back() { - DCHECK(!empty()); - return at(size() - 1); - } - - const value_type& back() const { - DCHECK(!empty()); - return at(size() - 1); - } - - value_type& front() { - DCHECK(!empty()); - return at(0); - } - - const value_type& front() const { - DCHECK(!empty()); - return at(0); - } - - // Append a T constructed with args to the vector. - // Increases size() by one. - // Amortized complexity: O(1) - // Worst-case complexity: O(size()) - template <typename... Args> - void emplace_back(Args&&... args) { - size_t s = size(); - DCHECK_LE(s, capacity()); - if (s < capacity()) { - new (data() + s) T(std::forward<Args>(args)...); - set_size_internal(s + 1); - } else { - EmplaceBackSlow(std::forward<Args>(args)...); - } - } - - // Append t to the vector. - // Increases size() by one. - // Amortized complexity: O(1) - // Worst-case complexity: O(size()) - void push_back(const value_type& t) { emplace_back(t); } - void push_back(value_type&& t) { emplace_back(std::move(t)); } - - inline void pop_back() { - DCHECK(!empty()); - const size_t s = size(); - Destroy(data() + s - 1, 1); - set_size_internal(s - 1); - } - - // Resizes the vector to contain "n" elements. - // If "n" is smaller than the initial size, extra elements are destroyed. - // If "n" is larger than the initial size, enough copies of "elem" - // are appended to increase the size to "n". If "elem" is omitted, - // new elements are value-initialized. - void resize(size_t n) { Resize<ValueInit>(n, nullptr); } - void resize(size_t n, const value_type& elem) { Resize<Fill>(n, &elem); } - - iterator begin() { return data(); } - const_iterator begin() const { return data(); } - - iterator end() { return data() + size(); } - const_iterator end() const { return data() + size(); } - - iterator insert(iterator pos, const value_type& v); - - iterator erase(iterator pos) { - DCHECK_LT(pos, end()); - DCHECK_GE(pos, begin()); - std::copy(pos + 1, end(), pos); - pop_back(); - return pos; - } - - iterator erase(iterator first, iterator last); - - // Enlarges the underlying representation so it can hold at least - // "n" elements without reallocation. - // Does not change size() or the actual contents of the vector. - void reserve(size_t n) { - if (n > capacity()) { - // Make room for new elements - Grow<Move>(n); - } - } - - // Swap the contents of *this with other. - // REQUIRES: value_type is swappable and copyable. - void swap(InlinedVector& other); - - private: - // Representation can either be inlined or out-of-line. - // In either case, at least sizeof(void*) + 8 bytes are available. - // - // Inlined: - // Last byte holds the length. - // First (length*sizeof(T)) bytes stores the elements. 
- // Outlined: - // Last byte holds kSentinel. - // Second-last byte holds lg(capacity) - // Preceding 6 bytes hold size. - // First sizeof(T*) bytes hold pointer. - - // Compute rep size. - static const size_t kSizeUnaligned = N * sizeof(T) + 1; // Room for tag - static const size_t kSize = ((kSizeUnaligned + 15) / 16) * 16; // Align - - // See how many fit T we can fit inside kSize, but no more than 254 - // since 255 is used as sentinel tag for out-of-line allocation. - static const unsigned int kSentinel = 255; - static const size_t kFit1 = (kSize - 1) / sizeof(T); - static const size_t kFit = (kFit1 >= kSentinel) ? (kSentinel - 1) : kFit1; - - union { - unsigned char data[kSize]; - // Force data to be aligned enough for a pointer. - T* unused_aligner; - } u_; - - inline void InitRep() { u_.data[kSize - 1] = 0; } - inline bool is_inline() const { return u_.data[kSize - 1] != kSentinel; } - - inline T* outofline_pointer() const { - T* ptr; - memcpy(&ptr, &u_.data[0], sizeof(ptr)); - return ptr; - } - - inline void set_outofline_pointer(T* p) { - memcpy(&u_.data[0], &p, sizeof(p)); - } - - inline uint64_t outofline_word() const { - uint64_t word; - memcpy(&word, &u_.data[kSize - 8], sizeof(word)); - return word; - } - - inline void set_outofline_word(uint64_t w) { - memcpy(&u_.data[kSize - 8], &w, sizeof(w)); - } - - inline size_t size_internal() const { - uint8_t s = static_cast<uint8_t>(u_.data[kSize - 1]); - if (s != kSentinel) { - return static_cast<size_t>(s); - } else { - const uint64_t word = outofline_word(); - if (port::kLittleEndian) { - // The sentinel and capacity bits are most-significant bits in word. - return static_cast<size_t>(word & 0xffffffffffffull); - } else { - // The sentinel and capacity bits are least-significant bits in word. - return static_cast<size_t>(word >> 16); - } - } - } - - void set_size_internal(size_t n) { - if (is_inline()) { - DCHECK_LT(n, kSentinel); - u_.data[kSize - 1] = static_cast<unsigned char>(n); - } else { - uint64_t word; - if (port::kLittleEndian) { - // The sentinel and capacity bits are most-significant bits in word. - word = (static_cast<uint64_t>(n) | - (static_cast<uint64_t>(u_.data[kSize - 2]) << 48) | - (static_cast<uint64_t>(kSentinel) << 56)); - } else { - // The sentinel and capacity bits are least-significant bits in word. - word = ((static_cast<uint64_t>(n) << 16) | - (static_cast<uint64_t>(u_.data[kSize - 2]) << 8) | - (static_cast<uint64_t>(kSentinel))); - } - set_outofline_word(word); - DCHECK_EQ(u_.data[kSize - 1], kSentinel) << n; - } - } - - void DiscardStorage() { - T* base = data(); - size_t n = size(); - Destroy(base, n); - if (!is_inline()) { - port::Free(base); - } - } - - template <typename... Args> - void EmplaceBackSlow(Args&&... args) { - const size_t s = size(); - DCHECK_EQ(s, capacity()); - Grow<Move, Construct>(s + 1, std::forward<Args>(args)...); - set_size_internal(s + 1); - } - - // Movers for Grow - // Does nothing. - static void Nop(T* src, size_t n, T* dst) {} - - // Moves srcs[0,n-1] contents to dst[0,n-1]. - static void Move(T* src, size_t n, T* dst) { - for (size_t i = 0; i < n; i++) { - new (dst + i) T(std::move(*(src + i))); - } - } - - // Initializers for Resize. - // Initializes dst[0,n-1] with empty constructor. - static void ValueInit(const T*, size_t n, T* dst) { - for (size_t i = 0; i < n; i++) { - new (dst + i) T(); - } - } - - // Initializes dst[0,n-1] with copies of *src. 
- static void Fill(const T* src, size_t n, T* dst) { - for (size_t i = 0; i < n; i++) { - new (dst + i) T(*src); - } - } - - void Destroy(T* src, int n) { - if (!std::is_trivially_destructible<T>::value) { - for (int i = 0; i < n; i++) { - (src + i)->~T(); - } - } - } - - // Initialization methods for Grow. - // 1) Leave uninitialized memory. - struct Uninitialized { - void operator()(T*) const {} - }; - // 2) Construct a T with args at not-yet-initialized memory pointed by dst. - struct Construct { - template <class... Args> - void operator()(T* dst, Args&&... args) const { - new (dst) T(std::forward<Args>(args)...); - } - }; - - // Grow so that capacity >= n. Uses Mover to move existing elements - // to new buffer, and possibly initialize the new element according - // to InitType. - // We pass the InitType and Mover as template arguments so that - // this code compiles even if T does not support copying or default - // construction. - template <void(Mover)(T*, size_t, T*), class InitType = Uninitialized, - class... Args> - void Grow(size_t n, Args&&... args) { - size_t s = size(); - DCHECK_LE(s, capacity()); - - // Compute new capacity by repeatedly doubling current capacity - size_t target = 1; - size_t target_lg = 0; - while (target < kFit || target < n) { - // TODO(psrc): Check and avoid overflow? - target_lg++; - target <<= 1; - } - - T* src = data(); - T* dst = static_cast<T*>(port::Malloc(target * sizeof(T))); - - // Need to copy elem before discarding src since it might alias src. - InitType{}(dst + s, std::forward<Args>(args)...); - Mover(src, s, dst); - DiscardStorage(); - - u_.data[kSize - 1] = kSentinel; - u_.data[kSize - 2] = static_cast<unsigned char>(target_lg); - set_size_internal(s); - DCHECK_EQ(capacity(), target); - set_outofline_pointer(dst); - } - - // Resize to size n. Any new elements are initialized by passing - // elem and the destination to Initializer. We pass the Initializer - // as a template argument so that this code compiles even if T does - // not support copying. - template <void(Initializer)(const T*, size_t, T*)> - void Resize(size_t n, const T* elem) { - size_t s = size(); - if (n <= s) { - Destroy(data() + n, s - n); - set_size_internal(n); - return; - } - reserve(n); - DCHECK_GE(capacity(), n); - set_size_internal(n); - Initializer(elem, n - s, data() + s); - } - - template <typename Iter> - void AppendRange(Iter first, Iter last, std::input_iterator_tag); - - // Faster path for forward iterators. - template <typename Iter> - void AppendRange(Iter first, Iter last, std::forward_iterator_tag); - - template <typename Iter> - void AppendRange(Iter first, Iter last); -}; - -// Provide linkage for constants. 
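The constants being deleted here encoded the old inline layout: the union is padded to a multiple of 16 bytes (kSize), the last byte is a tag holding the length when inline or the 255 sentinel when heap-allocated, and kFit is how many elements actually fit inline, capped at 254 so the sentinel value stays unambiguous. Worked through for the old InlinedVector<int32, 8> (arithmetic from the definitions above, for illustration):

// kSizeUnaligned = 8 * sizeof(int32) + 1 = 33    (payload plus tag byte)
// kSize          = ((33 + 15) / 16) * 16 = 48    (rounded up to 16 bytes)
// kFit1          = (48 - 1) / sizeof(int32) = 11
// kFit           = 11                            (11 < kSentinel, so kept)
// A vector declared with N = 8 therefore stored up to 11 int32s inline.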
-template <typename T, int N> -const size_t InlinedVector<T, N>::kSizeUnaligned; -template <typename T, int N> -const size_t InlinedVector<T, N>::kSize; -template <typename T, int N> -const unsigned int InlinedVector<T, N>::kSentinel; -template <typename T, int N> -const size_t InlinedVector<T, N>::kFit1; -template <typename T, int N> -const size_t InlinedVector<T, N>::kFit; - -template <typename T, int N> -inline void swap(InlinedVector<T, N>& a, InlinedVector<T, N>& b) { - a.swap(b); -} - -template <typename T, int N> -inline bool operator==(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin()); -} - -template <typename T, int N> -inline bool operator!=(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return !(a == b); -} - -template <typename T, int N> -inline bool operator<(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end()); -} - -template <typename T, int N> -inline bool operator>(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return b < a; -} - -template <typename T, int N> -inline bool operator<=(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return !(b < a); -} - -template <typename T, int N> -inline bool operator>=(const InlinedVector<T, N>& a, - const InlinedVector<T, N>& b) { - return !(a < b); -} - -// ======================================== -// Implementation - -template <typename T, int N> -inline InlinedVector<T, N>::InlinedVector() { - InitRep(); -} - -template <typename T, int N> -inline InlinedVector<T, N>::InlinedVector(size_t n) { - InitRep(); - if (n > capacity()) { - Grow<Nop>(n); // Must use Nop in case T is not copyable - } - set_size_internal(n); - ValueInit(nullptr, n, data()); -} - -template <typename T, int N> -inline InlinedVector<T, N>::InlinedVector(size_t n, const value_type& elem) { - InitRep(); - if (n > capacity()) { - Grow<Nop>(n); // Can use Nop since we know we have nothing to copy - } - set_size_internal(n); - Fill(&elem, n, data()); -} - -template <typename T, int N> -inline InlinedVector<T, N>::InlinedVector(const InlinedVector& v) { - InitRep(); - *this = v; -} - -template <typename T, int N> -typename InlinedVector<T, N>::iterator InlinedVector<T, N>::insert( - iterator pos, const value_type& v) { - DCHECK_GE(pos, begin()); - DCHECK_LE(pos, end()); - if (pos == end()) { - push_back(v); - return end() - 1; - } - size_t s = size(); - size_t idx = std::distance(begin(), pos); - if (s == capacity()) { - Grow<Move>(s + 1); - } - CHECK_LT(s, capacity()); - pos = begin() + idx; // Reset 'pos' into a post-enlarge iterator. - Fill(data() + s - 1, 1, data() + s); // data[s] = data[s-1] - std::copy_backward(pos, data() + s - 1, data() + s); - *pos = v; - - set_size_internal(s + 1); - return pos; -} - -template <typename T, int N> -typename InlinedVector<T, N>::iterator InlinedVector<T, N>::erase( - iterator first, iterator last) { - DCHECK_LE(begin(), first); - DCHECK_LE(first, last); - DCHECK_LE(last, end()); - - size_t s = size(); - ptrdiff_t erase_gap = std::distance(first, last); - std::copy(last, data() + s, first); - Destroy(data() + s - erase_gap, erase_gap); - set_size_internal(s - erase_gap); - return first; -} - -template <typename T, int N> -void InlinedVector<T, N>::swap(InlinedVector& other) { - using std::swap; // Augment ADL with std::swap. 
- if (&other == this) { - return; - } - - InlinedVector* a = this; - InlinedVector* b = &other; - - const bool a_inline = a->is_inline(); - const bool b_inline = b->is_inline(); - - if (!a_inline && !b_inline) { - // Just swap the top-level representations. - T* aptr = a->outofline_pointer(); - T* bptr = b->outofline_pointer(); - a->set_outofline_pointer(bptr); - b->set_outofline_pointer(aptr); - - uint64_t aword = a->outofline_word(); - uint64_t bword = b->outofline_word(); - a->set_outofline_word(bword); - b->set_outofline_word(aword); - return; - } - - // Make a the larger of the two to reduce number of cases. - size_t a_size = a->size(); - size_t b_size = b->size(); - if (a->size() < b->size()) { - swap(a, b); - swap(a_size, b_size); - } - DCHECK_GE(a_size, b_size); - - if (b->capacity() < a_size) { - b->Grow<Move>(a_size); - } - - // One is inline and one is not. - // 'a' is larger. Swap the elements up to the smaller array size. - std::swap_ranges(a->data(), a->data() + b_size, b->data()); - std::uninitialized_copy(a->data() + b_size, a->data() + a_size, - b->data() + b_size); - Destroy(a->data() + b_size, a_size - b_size); - a->set_size_internal(b_size); - b->set_size_internal(a_size); - DCHECK_EQ(b->size(), a_size); - DCHECK_EQ(a->size(), b_size); -} - -template <typename T, int N> -template <typename Iter> -inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last, - std::input_iterator_tag) { - std::copy(first, last, std::back_inserter(*this)); -} - -template <typename T, int N> -template <typename Iter> -inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last, - std::forward_iterator_tag) { - typedef typename std::iterator_traits<Iter>::difference_type Length; - Length length = std::distance(first, last); - size_t s = size(); - reserve(s + length); - std::uninitialized_copy_n(first, length, data() + s); - set_size_internal(s + length); -} - -template <typename T, int N> -template <typename Iter> -inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last) { - typedef typename std::iterator_traits<Iter>::iterator_category IterTag; - AppendRange(first, last, IterTag()); -} +using absl::InlinedVector; } // namespace gtl } // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc deleted file mode 100644 index 2721885c4a..0000000000 --- a/tensorflow/core/lib/gtl/inlined_vector_test.cc +++ /dev/null @@ -1,898 +0,0 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/lib/gtl/inlined_vector.h" - -#include <list> -#include <memory> -#include <string> -#include <vector> - -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/macros.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/test_benchmark.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -typedef tensorflow::gtl::InlinedVector<int, 8> IntVec; - -// A type that counts number of live occurrences of the type -static int64 instances = 0; -class Instance { - public: - int value_; - explicit Instance(int x) : value_(x) { instances++; } - Instance(const Instance& x) : value_(x.value_) { instances++; } - ~Instance() { instances--; } - - friend inline void swap(Instance& a, Instance& b) { - using std::swap; - swap(a.value_, b.value_); - } - - friend std::ostream& operator<<(std::ostream& o, const Instance& v) { - return o << "[value:" << v.value_ << "]"; - } -}; - -typedef tensorflow::gtl::InlinedVector<Instance, 8> InstanceVec; - -// A simple reference counted class to make sure that the proper elements are -// destroyed in the erase(begin, end) test. -class RefCounted { - public: - RefCounted(int value, int* count) : value_(value), count_(count) { Ref(); } - - RefCounted(const RefCounted& v) : value_(v.value_), count_(v.count_) { - VLOG(5) << "[RefCounted: copy" - << " from count @" << v.count_ << "]"; - Ref(); - } - - ~RefCounted() { - Unref(); - count_ = nullptr; - } - - friend void swap(RefCounted& a, RefCounted& b) { - using std::swap; - swap(a.value_, b.value_); - swap(a.count_, b.count_); - } - - RefCounted& operator=(RefCounted v) { - using std::swap; - swap(*this, v); - return *this; - } - - void Ref() const { - CHECK(count_ != nullptr); - ++(*count_); - VLOG(5) << "[Ref: refcount " << *count_ << " on count @" << count_ << "]"; - } - - void Unref() const { - --(*count_); - CHECK_GE(*count_, 0); - VLOG(5) << "[Unref: refcount " << *count_ << " on count @" << count_ << "]"; - } - - int count() const { return *count_; } - - friend std::ostream& operator<<(std::ostream& o, const RefCounted& v) { - return o << "[value:" << v.value_ << ", count:" << *v.count_ << "]"; - } - - int value_; - int* count_; -}; - -typedef tensorflow::gtl::InlinedVector<RefCounted, 8> RefCountedVec; - -// A class with a vtable pointer -class Dynamic { - public: - virtual ~Dynamic() {} - - friend std::ostream& operator<<(std::ostream& o, const Dynamic& v) { - return o << "[Dynamic]"; - } -}; - -typedef tensorflow::gtl::InlinedVector<Dynamic, 8> DynamicVec; - -// Append 0..len-1 to *v -static void Fill(IntVec* v, int len, int offset = 0) { - for (int i = 0; i < len; i++) { - v->push_back(i + offset); - } -} - -static IntVec Fill(int len, int offset = 0) { - IntVec v; - Fill(&v, len, offset); - return v; -} - -TEST(IntVec, SimpleOps) { - for (int len = 0; len < 20; len++) { - IntVec v; - const IntVec& cv = v; // const alias - - Fill(&v, len); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - - for (int i = 0; i < len; i++) { - EXPECT_EQ(i, v[i]); - } - EXPECT_EQ(v.begin(), v.data()); - EXPECT_EQ(cv.begin(), cv.data()); - - int counter = 0; - for (IntVec::iterator iter = v.begin(); iter != v.end(); ++iter) { - EXPECT_EQ(counter, *iter); - counter++; - } - EXPECT_EQ(counter, len); - - counter = 0; - for (IntVec::const_iterator iter = v.begin(); iter != v.end(); ++iter) { - EXPECT_EQ(counter, *iter); - counter++; - } - 
EXPECT_EQ(counter, len); - - if (len > 0) { - EXPECT_EQ(0, v.front()); - EXPECT_EQ(len - 1, v.back()); - v.pop_back(); - EXPECT_EQ(len - 1, v.size()); - for (size_t i = 0; i < v.size(); ++i) { - EXPECT_EQ(i, v[i]); - } - } - } -} - -TEST(IntVec, Erase) { - for (int len = 1; len < 20; len++) { - for (int i = 0; i < len; ++i) { - IntVec v; - Fill(&v, len); - v.erase(v.begin() + i); - EXPECT_EQ(len - 1, v.size()); - for (int j = 0; j < i; ++j) { - EXPECT_EQ(j, v[j]); - } - for (int j = i; j < len - 1; ++j) { - EXPECT_EQ(j + 1, v[j]); - } - } - } -} - -// At the end of this test loop, the elements between [erase_begin, erase_end) -// should have reference counts == 0, and all others elements should have -// reference counts == 1. -TEST(RefCountedVec, EraseBeginEnd) { - for (int len = 1; len < 20; ++len) { - for (int erase_begin = 0; erase_begin < len; ++erase_begin) { - for (int erase_end = erase_begin; erase_end <= len; ++erase_end) { - std::vector<int> counts(len, 0); - RefCountedVec v; - for (int i = 0; i < len; ++i) { - v.push_back(RefCounted(i, &counts[i])); - } - - int erase_len = erase_end - erase_begin; - - v.erase(v.begin() + erase_begin, v.begin() + erase_end); - - EXPECT_EQ(len - erase_len, v.size()); - - // Check the elements before the first element erased. - for (int i = 0; i < erase_begin; ++i) { - EXPECT_EQ(i, v[i].value_); - } - - // Check the elements after the first element erased. - for (size_t i = erase_begin; i < v.size(); ++i) { - EXPECT_EQ(i + erase_len, v[i].value_); - } - - // Check that the elements at the beginning are preserved. - for (int i = 0; i < erase_begin; ++i) { - EXPECT_EQ(1, counts[i]); - } - - // Check that the erased elements are destroyed - for (int i = erase_begin; i < erase_end; ++i) { - EXPECT_EQ(0, counts[i]); - } - - // Check that the elements at the end are preserved. - for (int i = erase_end; i < len; ++i) { - EXPECT_EQ(1, counts[i]); - } - } - } - } -} - -struct NoDefaultCtor { - explicit NoDefaultCtor(int) {} -}; -struct NoCopy { - NoCopy() {} - NoCopy(const NoCopy&) = delete; -}; -struct NoAssign { - NoAssign() {} - NoAssign& operator=(const NoAssign&) = delete; -}; -struct MoveOnly { - MoveOnly() {} - MoveOnly(MoveOnly&&) = default; - MoveOnly& operator=(MoveOnly&&) = default; -}; -TEST(InlinedVectorTest, NoDefaultCtor) { - tensorflow::gtl::InlinedVector<NoDefaultCtor, 1> v(10, NoDefaultCtor(2)); - (void)v; -} -TEST(InlinedVectorTest, NoCopy) { - tensorflow::gtl::InlinedVector<NoCopy, 1> v(10); - (void)v; -} -TEST(InlinedVectorTest, NoAssign) { - tensorflow::gtl::InlinedVector<NoAssign, 1> v(10); - (void)v; -} -TEST(InlinedVectorTest, MoveOnly) { - gtl::InlinedVector<MoveOnly, 2> v; - v.push_back(MoveOnly{}); - v.push_back(MoveOnly{}); - v.push_back(MoveOnly{}); -} - -TEST(IntVec, Insert) { - for (int len = 0; len < 20; len++) { - for (int pos = 0; pos <= len; pos++) { - IntVec v; - Fill(&v, len); - v.insert(v.begin() + pos, 9999); - EXPECT_EQ(v.size(), len + 1); - for (int i = 0; i < pos; i++) { - EXPECT_EQ(v[i], i); - } - EXPECT_EQ(v[pos], 9999); - for (size_t i = pos + 1; i < v.size(); i++) { - EXPECT_EQ(v[i], i - 1); - } - } - } -} - -TEST(RefCountedVec, InsertConstructorDestructor) { - // Make sure the proper construction/destruction happen during insert - // operations. 
- for (int len = 0; len < 20; len++) { - SCOPED_TRACE(len); - for (int pos = 0; pos <= len; pos++) { - SCOPED_TRACE(pos); - std::vector<int> counts(len, 0); - int inserted_count = 0; - RefCountedVec v; - for (int i = 0; i < len; ++i) { - SCOPED_TRACE(i); - v.push_back(RefCounted(i, &counts[i])); - } - - for (auto elem : counts) { - EXPECT_EQ(1, elem); - } - - RefCounted insert_element(9999, &inserted_count); - EXPECT_EQ(1, inserted_count); - v.insert(v.begin() + pos, insert_element); - EXPECT_EQ(2, inserted_count); - // Check that the elements at the end are preserved. - for (auto elem : counts) { - EXPECT_EQ(1, elem); - } - EXPECT_EQ(2, inserted_count); - } - } -} - -TEST(IntVec, Resize) { - for (int len = 0; len < 20; len++) { - IntVec v; - Fill(&v, len); - - // Try resizing up and down by k elements - static const int kResizeElem = 1000000; - for (int k = 0; k < 10; k++) { - // Enlarging resize - v.resize(len + k, kResizeElem); - EXPECT_EQ(len + k, v.size()); - EXPECT_LE(len + k, v.capacity()); - for (int i = 0; i < len + k; i++) { - if (i < len) { - EXPECT_EQ(i, v[i]); - } else { - EXPECT_EQ(kResizeElem, v[i]); - } - } - - // Shrinking resize - v.resize(len, kResizeElem); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - for (int i = 0; i < len; i++) { - EXPECT_EQ(i, v[i]); - } - } - } -} - -TEST(IntVec, InitWithLength) { - for (int len = 0; len < 20; len++) { - IntVec v(len, 7); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - for (int i = 0; i < len; i++) { - EXPECT_EQ(7, v[i]); - } - } -} - -TEST(IntVec, CopyConstructorAndAssignment) { - for (int len = 0; len < 20; len++) { - IntVec v; - Fill(&v, len); - EXPECT_EQ(len, v.size()); - EXPECT_LE(len, v.capacity()); - - IntVec v2(v); - EXPECT_EQ(v, v2); - - for (int start_len = 0; start_len < 20; start_len++) { - IntVec v3; - Fill(&v3, start_len, 99); // Add dummy elements that should go away - v3 = v; - EXPECT_EQ(v, v3); - } - } -} - -TEST(OverheadTest, Storage) { - // Check for size overhead. 
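-  // The expectations below are consistent with a representation of N inline
-  // elements plus one byte of size/is-inline metadata, rounded up to the
-  // size of two pointers (a sketch of the assumed layout, not the actual
-  // header):
-  //
-  //   sizeof(InlinedVector<T, N>) ==
-  //       round_up(N * sizeof(T) + 1, 2 * sizeof(void*))
-  //
-  // e.g. on a 64-bit platform, InlinedVector<int*, 4>:
-  // 4 * 8 + 1 = 33 -> 48 == 6 * sizeof(int*).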
- using tensorflow::gtl::InlinedVector; - EXPECT_EQ(2 * sizeof(int*), sizeof(InlinedVector<int*, 1>)); - EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 2>)); - EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 3>)); - EXPECT_EQ(6 * sizeof(int*), sizeof(InlinedVector<int*, 4>)); - - EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 1>)); - EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 2>)); - EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 3>)); - EXPECT_EQ(2 * sizeof(char*), - sizeof(InlinedVector<char, 2 * sizeof(char*) - 1>)); - EXPECT_EQ(4 * sizeof(char*), sizeof(InlinedVector<char, 2 * sizeof(char*)>)); -} - -TEST(IntVec, Clear) { - for (int len = 0; len < 20; len++) { - SCOPED_TRACE(len); - IntVec v; - Fill(&v, len); - v.clear(); - EXPECT_EQ(0, v.size()); - EXPECT_EQ(v.begin(), v.end()); - } -} - -TEST(IntVec, Reserve) { - for (size_t len = 0; len < 20; len++) { - IntVec v; - Fill(&v, len); - - for (size_t newlen = 0; newlen < 100; newlen++) { - const int* start_rep = v.data(); - v.reserve(newlen); - const int* final_rep = v.data(); - if (newlen <= len) { - EXPECT_EQ(start_rep, final_rep); - } - EXPECT_LE(newlen, v.capacity()); - - // Filling up to newlen should not change rep - while (v.size() < newlen) { - v.push_back(0); - } - EXPECT_EQ(final_rep, v.data()); - } - } -} - -template <typename T> -static std::vector<typename T::value_type> Vec(const T& src) { - std::vector<typename T::value_type> result; - for (const auto& elem : src) { - result.push_back(elem); - } - return result; -} - -TEST(IntVec, SelfRefPushBack) { - std::vector<string> std_v; - tensorflow::gtl::InlinedVector<string, 4> v; - const string s = "A quite long string to ensure heap."; - std_v.push_back(s); - v.push_back(s); - for (int i = 0; i < 20; ++i) { - EXPECT_EQ(std_v, Vec(v)); - - v.push_back(v.back()); - std_v.push_back(std_v.back()); - } - EXPECT_EQ(std_v, Vec(v)); -} - -TEST(IntVec, SelfRefPushBackWithMove) { - std::vector<string> std_v; - gtl::InlinedVector<string, 4> v; - const string s = "A quite long string to ensure heap."; - std_v.push_back(s); - v.push_back(s); - for (int i = 0; i < 20; ++i) { - EXPECT_EQ(v.back(), std_v.back()); - - v.push_back(std::move(v.back())); - std_v.push_back(std::move(std_v.back())); - } - EXPECT_EQ(v.back(), std_v.back()); -} - -TEST(IntVec, Swap) { - for (int l1 = 0; l1 < 20; l1++) { - SCOPED_TRACE(l1); - for (int l2 = 0; l2 < 20; l2++) { - SCOPED_TRACE(l2); - IntVec a = Fill(l1, 0); - IntVec b = Fill(l2, 100); - { - using std::swap; - swap(a, b); - } - EXPECT_EQ(l1, b.size()); - EXPECT_EQ(l2, a.size()); - for (int i = 0; i < l1; i++) { - SCOPED_TRACE(i); - EXPECT_EQ(i, b[i]); - } - for (int i = 0; i < l2; i++) { - SCOPED_TRACE(i); - EXPECT_EQ(100 + i, a[i]); - } - } - } -} - -TEST(InstanceVec, Swap) { - for (int l1 = 0; l1 < 20; l1++) { - for (int l2 = 0; l2 < 20; l2++) { - InstanceVec a, b; - for (int i = 0; i < l1; i++) a.push_back(Instance(i)); - for (int i = 0; i < l2; i++) b.push_back(Instance(100 + i)); - EXPECT_EQ(l1 + l2, instances); - { - using std::swap; - swap(a, b); - } - EXPECT_EQ(l1 + l2, instances); - EXPECT_EQ(l1, b.size()); - EXPECT_EQ(l2, a.size()); - for (int i = 0; i < l1; i++) { - EXPECT_EQ(i, b[i].value_); - } - for (int i = 0; i < l2; i++) { - EXPECT_EQ(100 + i, a[i].value_); - } - } - } -} - -TEST(IntVec, EqualAndNotEqual) { - IntVec a, b; - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - - a.push_back(3); - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - b.push_back(3); - EXPECT_TRUE(a == b); - EXPECT_FALSE(a 
!= b); - - b.push_back(7); - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - a.push_back(6); - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - a.clear(); - b.clear(); - for (int i = 0; i < 100; i++) { - a.push_back(i); - b.push_back(i); - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - - b[i] = b[i] + 1; - EXPECT_FALSE(a == b); - EXPECT_TRUE(a != b); - - b[i] = b[i] - 1; // Back to before - EXPECT_TRUE(a == b); - EXPECT_FALSE(a != b); - } -} - -TEST(IntVec, RelationalOps) { - IntVec a, b; - EXPECT_FALSE(a < b); - EXPECT_FALSE(b < a); - EXPECT_FALSE(a > b); - EXPECT_FALSE(b > a); - EXPECT_TRUE(a <= b); - EXPECT_TRUE(b <= a); - EXPECT_TRUE(a >= b); - EXPECT_TRUE(b >= a); - b.push_back(3); - EXPECT_TRUE(a < b); - EXPECT_FALSE(b < a); - EXPECT_FALSE(a > b); - EXPECT_TRUE(b > a); - EXPECT_TRUE(a <= b); - EXPECT_FALSE(b <= a); - EXPECT_FALSE(a >= b); - EXPECT_TRUE(b >= a); -} - -TEST(InstanceVec, CountConstructorsDestructors) { - const int start = instances; - for (int len = 0; len < 20; len++) { - InstanceVec v; - for (int i = 0; i < len; i++) { - v.push_back(Instance(i)); - } - EXPECT_EQ(start + len, instances); - - { // Copy constructor should create 'len' more instances. - InstanceVec v_copy(v); - EXPECT_EQ(start + len + len, instances); - } - EXPECT_EQ(start + len, instances); - - // Enlarging resize() must construct some objects - v.resize(len + 10, Instance(100)); - EXPECT_EQ(start + len + 10, instances); - - // Shrinking resize() must destroy some objects - v.resize(len, Instance(100)); - EXPECT_EQ(start + len, instances); - - // reserve() must not increase the number of initialized objects - v.reserve(len + 1000); - EXPECT_EQ(start + len, instances); - - // pop_back() and erase() must destroy one object - if (len > 0) { - v.pop_back(); - EXPECT_EQ(start + len - 1, instances); - if (!v.empty()) { - v.erase(v.begin()); - EXPECT_EQ(start + len - 2, instances); - } - } - } - EXPECT_EQ(start, instances); -} - -TEST(InstanceVec, CountConstructorsDestructorsOnAssignment) { - const int start = instances; - for (int len = 0; len < 20; len++) { - for (int longorshort = 0; longorshort <= 1; ++longorshort) { - InstanceVec longer, shorter; - for (int i = 0; i < len; i++) { - longer.push_back(Instance(i)); - shorter.push_back(Instance(i)); - } - longer.push_back(Instance(len)); - EXPECT_EQ(start + len + len + 1, instances); - - if (longorshort) { - shorter = longer; - EXPECT_EQ(start + (len + 1) + (len + 1), instances); - } else { - longer = shorter; - EXPECT_EQ(start + len + len, instances); - } - } - } - EXPECT_EQ(start, instances); -} - -TEST(RangedConstructor, SimpleType) { - std::vector<int> source_v = {4, 5, 6, 7}; - // First try to fit in inline backing - tensorflow::gtl::InlinedVector<int, 4> v(source_v.begin(), source_v.end()); - tensorflow::gtl::InlinedVector<int, 4> empty4; - EXPECT_EQ(4, v.size()); - EXPECT_EQ(empty4.capacity(), v.capacity()); // Must still be inline - EXPECT_EQ(4, v[0]); - EXPECT_EQ(5, v[1]); - EXPECT_EQ(6, v[2]); - EXPECT_EQ(7, v[3]); - - // Now, force a re-allocate - tensorflow::gtl::InlinedVector<int, 2> realloc_v(source_v.begin(), - source_v.end()); - tensorflow::gtl::InlinedVector<int, 2> empty2; - EXPECT_EQ(4, realloc_v.size()); - EXPECT_LT(empty2.capacity(), realloc_v.capacity()); - EXPECT_EQ(4, realloc_v[0]); - EXPECT_EQ(5, realloc_v[1]); - EXPECT_EQ(6, realloc_v[2]); - EXPECT_EQ(7, realloc_v[3]); -} - -TEST(RangedConstructor, ComplexType) { - // We also use a list here to pass a different flavor of iterator (e.g. not - // random-access). 
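-  // For a std::list, std::distance(first, last) is linear rather than O(1),
-  // so the constructor cannot cheaply pre-size its backing store; this
-  // exercises the incremental growth path for non-random-access iterators.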
- std::list<Instance> source_v = {Instance(0)}; - - // First try to fit in inline backing - tensorflow::gtl::InlinedVector<Instance, 1> v(source_v.begin(), - source_v.end()); - tensorflow::gtl::InlinedVector<Instance, 1> empty1; - EXPECT_EQ(1, v.size()); - EXPECT_EQ(empty1.capacity(), v.capacity()); // Must still be inline - EXPECT_EQ(0, v[0].value_); - - std::list<Instance> source_v2 = {Instance(0), Instance(1), Instance(2), - Instance(3)}; - // Now, force a re-allocate - tensorflow::gtl::InlinedVector<Instance, 1> realloc_v(source_v2.begin(), - source_v2.end()); - EXPECT_EQ(4, realloc_v.size()); - EXPECT_LT(empty1.capacity(), realloc_v.capacity()); - EXPECT_EQ(0, realloc_v[0].value_); - EXPECT_EQ(1, realloc_v[1].value_); - EXPECT_EQ(2, realloc_v[2].value_); - EXPECT_EQ(3, realloc_v[3].value_); -} - -TEST(RangedConstructor, ElementsAreConstructed) { - std::vector<string> source_v = {"cat", "dog"}; - - // Force expansion and re-allocation of v. Ensures that when the vector is - // expanded that new elements are constructed. - tensorflow::gtl::InlinedVector<string, 1> v(source_v.begin(), source_v.end()); - EXPECT_EQ("cat", v[0]); - EXPECT_EQ("dog", v[1]); -} - -TEST(InitializerListConstructor, SimpleTypeWithInlineBacking) { - auto vec = tensorflow::gtl::InlinedVector<int, 3>{4, 5, 6}; - EXPECT_EQ(3, vec.size()); - EXPECT_EQ(3, vec.capacity()); - EXPECT_EQ(4, vec[0]); - EXPECT_EQ(5, vec[1]); - EXPECT_EQ(6, vec[2]); -} - -TEST(InitializerListConstructor, SimpleTypeWithReallocationRequired) { - auto vec = tensorflow::gtl::InlinedVector<int, 2>{4, 5, 6}; - EXPECT_EQ(3, vec.size()); - EXPECT_LE(3, vec.capacity()); - EXPECT_EQ(4, vec[0]); - EXPECT_EQ(5, vec[1]); - EXPECT_EQ(6, vec[2]); -} - -TEST(InitializerListConstructor, DisparateTypesInList) { - EXPECT_EQ((std::vector<int>{-7, 8}), - Vec(tensorflow::gtl::InlinedVector<int, 2>{-7, 8ULL})); - - EXPECT_EQ( - (std::vector<string>{"foo", "bar"}), - Vec(tensorflow::gtl::InlinedVector<string, 2>{"foo", string("bar")})); -} - -TEST(InitializerListConstructor, ComplexTypeWithInlineBacking) { - tensorflow::gtl::InlinedVector<Instance, 1> empty; - auto vec = tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0)}; - EXPECT_EQ(1, vec.size()); - EXPECT_EQ(empty.capacity(), vec.capacity()); - EXPECT_EQ(0, vec[0].value_); -} - -TEST(InitializerListConstructor, ComplexTypeWithReallocationRequired) { - auto vec = - tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0), Instance(1)}; - EXPECT_EQ(2, vec.size()); - EXPECT_LE(2, vec.capacity()); - EXPECT_EQ(0, vec[0].value_); - EXPECT_EQ(1, vec[1].value_); -} - -TEST(DynamicVec, DynamicVecCompiles) { - DynamicVec v; - (void)v; -} - -static void BM_InlinedVectorFill(int iters, int len) { - for (int i = 0; i < iters; i++) { - IntVec v; - for (int j = 0; j < len; j++) { - v.push_back(j); - } - } - testing::BytesProcessed((int64{iters} * len) * sizeof(int)); -} -BENCHMARK(BM_InlinedVectorFill)->Range(0, 1024); - -static void BM_InlinedVectorFillRange(int iters, int len) { - std::unique_ptr<int[]> ia(new int[len]); - for (int j = 0; j < len; j++) { - ia[j] = j; - } - for (int i = 0; i < iters; i++) { - IntVec TF_ATTRIBUTE_UNUSED v(ia.get(), ia.get() + len); - } - testing::BytesProcessed((int64{iters} * len) * sizeof(int)); -} -BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024); - -static void BM_StdVectorFill(int iters, int len) { - for (int i = 0; i < iters; i++) { - std::vector<int> v; - v.reserve(len); - for (int j = 0; j < len; j++) { - v.push_back(j); - } - } - testing::BytesProcessed((int64{iters} * len) * 
sizeof(int)); -} -BENCHMARK(BM_StdVectorFill)->Range(0, 1024); - -bool StringRepresentedInline(string s) { - const char* chars = s.data(); - string s1 = std::move(s); - return s1.data() != chars; -} - -static void BM_InlinedVectorFillString(int iters, int len) { - string strings[4] = {"a quite long string", "another long string", - "012345678901234567", "to cause allocation"}; - for (int i = 0; i < iters; i++) { - gtl::InlinedVector<string, 8> v; - for (int j = 0; j < len; j++) { - v.push_back(strings[j & 3]); - } - } - testing::ItemsProcessed(int64{iters} * len); -} -BENCHMARK(BM_InlinedVectorFillString)->Range(0, 1024); - -static void BM_StdVectorFillString(int iters, int len) { - string strings[4] = {"a quite long string", "another long string", - "012345678901234567", "to cause allocation"}; - for (int i = 0; i < iters; i++) { - std::vector<string> v; - v.reserve(len); - for (int j = 0; j < len; j++) { - v.push_back(strings[j & 3]); - } - } - testing::ItemsProcessed(int64{iters} * len); - // The purpose of the benchmark is to verify that inlined vector is - // efficient when moving is more efficient than copying. To do so, we - // use strings that are larger than the small string optimization. - CHECK(!StringRepresentedInline(strings[0])); -} -BENCHMARK(BM_StdVectorFillString)->Range(0, 1024); - -namespace { -struct Buffer { // some arbitrary structure for benchmarking. - char* base; - int length; - int capacity; - void* user_data; -}; -} // anonymous namespace - -static void BM_InlinedVectorTenAssignments(int iters, int len) { - typedef tensorflow::gtl::InlinedVector<Buffer, 2> BufferVec; - - BufferVec src; - src.resize(len); - - iters *= 10; - BufferVec dst; - for (int i = 0; i < iters; i++) { - dst = src; - } -} -BENCHMARK(BM_InlinedVectorTenAssignments) - ->Arg(0) - ->Arg(1) - ->Arg(2) - ->Arg(3) - ->Arg(4) - ->Arg(20); - -static void BM_CreateFromInitializerList(int iters) { - for (; iters > 0; iters--) { - tensorflow::gtl::InlinedVector<int, 4> x{1, 2, 3}; - (void)x[0]; - } -} -BENCHMARK(BM_CreateFromInitializerList); - -namespace { - -struct LargeSwappable { - LargeSwappable() : d_(1024, 17) {} - ~LargeSwappable() {} - LargeSwappable(const LargeSwappable& o) : d_(o.d_) {} - - friend void swap(LargeSwappable& a, LargeSwappable& b) { - using std::swap; - swap(a.d_, b.d_); - } - - LargeSwappable& operator=(LargeSwappable o) { - using std::swap; - swap(*this, o); - return *this; - } - - std::vector<int> d_; -}; - -} // namespace - -static void BM_LargeSwappableElements(int iters, int len) { - typedef tensorflow::gtl::InlinedVector<LargeSwappable, 32> Vec; - Vec a(len); - Vec b; - while (--iters >= 0) { - using std::swap; - swap(a, b); - } -} -BENCHMARK(BM_LargeSwappableElements)->Range(0, 1024); - -} // namespace tensorflow diff --git a/tensorflow/core/lib/gtl/optional.h b/tensorflow/core/lib/gtl/optional.h index 7ad916ad3d..238aa18e1e 100644 --- a/tensorflow/core/lib/gtl/optional.h +++ b/tensorflow/core/lib/gtl/optional.h @@ -16,861 +16,18 @@ limitations under the License. #ifndef TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_ #define TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_ -#include <assert.h> -#include <functional> -#include <initializer_list> -#include <type_traits> -#include <utility> - -#include "tensorflow/core/platform/logging.h" +#include "absl/types/optional.h" namespace tensorflow { namespace gtl { -// A value of type gtl::optional<T> holds either a value of T or an -// "empty" value. 
When it holds a value of T, it stores it as a direct -// subobject, so sizeof(optional<T>) is approximately sizeof(T)+1. The interface -// is based on the upcoming std::optional<T>, and gtl::optional<T> is -// designed to be cheaply drop-in replaceable by std::optional<T>, once it is -// rolled out. -// -// This implementation is based on the specification in the latest draft as of -// 2017-01-05, section 20.6. -// -// Differences between gtl::optional<T> and std::optional<T> include: -// - constexpr not used for nonconst member functions. -// (dependency on some differences between C++11 and C++14.) -// - nullopt and in_place are not constexpr. We need the inline variable -// support in C++17 for external linkage. -// - CHECK instead of throwing std::bad_optional_access. -// - optional::swap() and swap() relies on std::is_(nothrow_)swappable -// which is introduced in C++17. So we assume is_swappable is always true -// and is_nothrow_swappable is same as std::is_trivial. -// - make_optional cannot be constexpr due to absence of guaranteed copy -// elision. -// -// Synopsis: -// -// #include "tensorflow/core/lib/gtl/optional.h" -// -// tensorflow::gtl::optional<string> f() { -// string result; -// if (...) { -// ... -// result = ...; -// return result; -// } else { -// ... -// return tensorflow::gtl::nullopt; -// } -// } -// -// int main() { -// tensorflow::gtl::optional<string> optstr = f(); -// if (optstr) { -// // non-empty -// print(optstr.value()); -// } else { -// // empty -// error(); -// } -// } -template <typename T> -class optional; - -// The tag constant `in_place` is used as the first parameter of an optional<T> -// constructor to indicate that the remaining arguments should be forwarded -// to the underlying T constructor. -struct in_place_t {}; -extern const in_place_t in_place; - -// The tag constant `nullopt` is used to indicate an empty optional<T> in -// certain functions, such as construction or assignment. -struct nullopt_t { - struct init_t {}; - static init_t init; - // It must not be default-constructible to avoid ambiguity for opt = {}. - // Note the non-const reference, it is to eliminate ambiguity for code like: - // struct S { int value; }; - // - // void Test() { - // optional<S> opt; - // opt = {{}}; - // } - explicit constexpr nullopt_t(init_t& /*unused*/) {} // NOLINT -}; -extern const nullopt_t nullopt; - -namespace internal_optional { - -// define forward locally because std::forward is not constexpr until C++14 -template <typename T> -constexpr T&& forward(typename std::remove_reference<T>::type& - t) noexcept { // NOLINT(runtime/references) - return static_cast<T&&>(t); -} - -struct empty_struct {}; -// This class stores the data in optional<T>. -// It is specialized based on whether T is trivially destructible. -// This is the specialization for non trivially destructible type. -template <typename T, bool = std::is_trivially_destructible<T>::value> -class optional_data_dtor_base { - protected: - // Whether there is data or not. - bool engaged_; - // data storage - union { - empty_struct dummy_; - T data_; - }; - - void destruct() noexcept { - if (engaged_) { - data_.~T(); - engaged_ = false; - } - } - - // dummy_ must be initialized for constexpr constructor - constexpr optional_data_dtor_base() noexcept : engaged_(false), dummy_{} {} - - template <typename... Args> - constexpr explicit optional_data_dtor_base(in_place_t, Args&&... args) - : engaged_(true), data_(internal_optional::forward<Args>(args)...) 
{} - - ~optional_data_dtor_base() { destruct(); } -}; - -// Specialization for trivially destructible type. -template <typename T> -class optional_data_dtor_base<T, true> { - protected: - // Whether there is data or not. - bool engaged_; - // data storage - union { - empty_struct dummy_; - T data_; - }; - void destruct() noexcept { engaged_ = false; } - - // dummy_ must be initialized for constexpr constructor - constexpr optional_data_dtor_base() noexcept : engaged_(false), dummy_{} {} - - template <typename... Args> - constexpr explicit optional_data_dtor_base(in_place_t, Args&&... args) - : engaged_(true), data_(internal_optional::forward<Args>(args)...) {} - - ~optional_data_dtor_base() = default; -}; - -template <typename T> -class optional_data : public optional_data_dtor_base<T> { - protected: - using base = optional_data_dtor_base<T>; - using base::base; - - T* pointer() { return &this->data_; } - - constexpr const T* pointer() const { return &this->data_; } - - template <typename... Args> - void construct(Args&&... args) { - new (pointer()) T(std::forward<Args>(args)...); - this->engaged_ = true; - } - - template <typename U> - void assign(U&& u) { - if (this->engaged_) { - this->data_ = std::forward<U>(u); - } else { - construct(std::forward<U>(u)); - } - } - - optional_data() = default; - - optional_data(const optional_data& rhs) { - if (rhs.engaged_) { - construct(rhs.data_); - } - } - - optional_data(optional_data&& rhs) noexcept( - std::is_nothrow_move_constructible<T>::value) { - if (rhs.engaged_) { - construct(std::move(rhs.data_)); - } - } - - optional_data& operator=(const optional_data& rhs) { - if (rhs.engaged_) { - assign(rhs.data_); - } else { - this->destruct(); - } - return *this; - } - - optional_data& operator=(optional_data&& rhs) noexcept( - std::is_nothrow_move_assignable<T>::value&& - std::is_nothrow_move_constructible<T>::value) { - if (rhs.engaged_) { - assign(std::move(rhs.data_)); - } else { - this->destruct(); - } - return *this; - } -}; - -// ordered by level of restriction, from low to high. -// copyable implies movable. -enum class copy_traits { copyable = 0, movable = 1, non_movable = 2 }; - -// base class for enabling/disabling copy/move constructor. -template <copy_traits> -class optional_ctor_base; - -template <> -class optional_ctor_base<copy_traits::copyable> { - public: - constexpr optional_ctor_base() = default; - optional_ctor_base(const optional_ctor_base&) = default; - optional_ctor_base(optional_ctor_base&&) = default; - optional_ctor_base& operator=(const optional_ctor_base&) = default; - optional_ctor_base& operator=(optional_ctor_base&&) = default; -}; - -template <> -class optional_ctor_base<copy_traits::movable> { - public: - constexpr optional_ctor_base() = default; - optional_ctor_base(const optional_ctor_base&) = delete; - optional_ctor_base(optional_ctor_base&&) = default; - optional_ctor_base& operator=(const optional_ctor_base&) = default; - optional_ctor_base& operator=(optional_ctor_base&&) = default; -}; - -template <> -class optional_ctor_base<copy_traits::non_movable> { - public: - constexpr optional_ctor_base() = default; - optional_ctor_base(const optional_ctor_base&) = delete; - optional_ctor_base(optional_ctor_base&&) = delete; - optional_ctor_base& operator=(const optional_ctor_base&) = default; - optional_ctor_base& operator=(optional_ctor_base&&) = default; -}; - -// base class for enabling/disabling copy/move assignment. 
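-// The mechanism used by both bases (a sketch): a defaulted special member
-// that would invoke a deleted base-class member is itself implicitly
-// deleted, so deriving from the right specialization disables exactly the
-// unwanted operations. For example:
-//
-//   struct S : optional_assign_base<copy_traits::movable> {
-//     S& operator=(const S&) = default;  // implicitly deleted via the base
-//     S& operator=(S&&) = default;       // remains usable
-//   };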
-template <copy_traits> -class optional_assign_base; - -template <> -class optional_assign_base<copy_traits::copyable> { - public: - constexpr optional_assign_base() = default; - optional_assign_base(const optional_assign_base&) = default; - optional_assign_base(optional_assign_base&&) = default; - optional_assign_base& operator=(const optional_assign_base&) = default; - optional_assign_base& operator=(optional_assign_base&&) = default; -}; - -template <> -class optional_assign_base<copy_traits::movable> { - public: - constexpr optional_assign_base() = default; - optional_assign_base(const optional_assign_base&) = default; - optional_assign_base(optional_assign_base&&) = default; - optional_assign_base& operator=(const optional_assign_base&) = delete; - optional_assign_base& operator=(optional_assign_base&&) = default; -}; - -template <> -class optional_assign_base<copy_traits::non_movable> { - public: - constexpr optional_assign_base() = default; - optional_assign_base(const optional_assign_base&) = default; - optional_assign_base(optional_assign_base&&) = default; - optional_assign_base& operator=(const optional_assign_base&) = delete; - optional_assign_base& operator=(optional_assign_base&&) = delete; -}; - +// Deprecated: please use absl::optional directly. +using absl::make_optional; +using absl::nullopt; template <typename T> -constexpr copy_traits get_ctor_copy_traits() { - return std::is_copy_constructible<T>::value - ? copy_traits::copyable - : std::is_move_constructible<T>::value ? copy_traits::movable - : copy_traits::non_movable; -} - -template <typename T> -constexpr copy_traits get_assign_copy_traits() { - return std::is_copy_assignable<T>::value && - std::is_copy_constructible<T>::value - ? copy_traits::copyable - : std::is_move_assignable<T>::value && - std::is_move_constructible<T>::value - ? copy_traits::movable - : copy_traits::non_movable; -} - -// Whether T is constructible or convertible from optional<U>. -template <typename T, typename U> -struct is_constructible_convertible_from_optional - : std::integral_constant< - bool, std::is_constructible<T, optional<U>&>::value || - std::is_constructible<T, optional<U>&&>::value || - std::is_constructible<T, const optional<U>&>::value || - std::is_constructible<T, const optional<U>&&>::value || - std::is_convertible<optional<U>&, T>::value || - std::is_convertible<optional<U>&&, T>::value || - std::is_convertible<const optional<U>&, T>::value || - std::is_convertible<const optional<U>&&, T>::value> {}; - -// Whether T is constructible or convertible or assignable from optional<U>. -template <typename T, typename U> -struct is_constructible_convertible_assignable_from_optional - : std::integral_constant< - bool, is_constructible_convertible_from_optional<T, U>::value || - std::is_assignable<T&, optional<U>&>::value || - std::is_assignable<T&, optional<U>&&>::value || - std::is_assignable<T&, const optional<U>&>::value || - std::is_assignable<T&, const optional<U>&&>::value> {}; - -} // namespace internal_optional - -template <typename T> -class optional : private internal_optional::optional_data<T>, - private internal_optional::optional_ctor_base< - internal_optional::get_ctor_copy_traits<T>()>, - private internal_optional::optional_assign_base< - internal_optional::get_assign_copy_traits<T>()> { - using data_base = internal_optional::optional_data<T>; - - public: - typedef T value_type; - - // [optional.ctor], constructors - - // A default constructed optional holds the empty value, NOT a default - // constructed T. 
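-  // For example (illustrative only):
-  //
-  //   optional<string> opt;                  // empty; no string constructed
-  //   opt.emplace("constructed on demand");  // a string is built only here
-  //   opt.reset();                           // destroys it; empty again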
- constexpr optional() noexcept {} - - // An optional initialized with `nullopt` holds the empty value. - constexpr optional(nullopt_t) noexcept {} // NOLINT(runtime/explicit) - - // Copy constructor, standard semantics. - optional(const optional& src) = default; - - // Move constructor, standard semantics. - optional(optional&& src) = default; - - // optional<T>(in_place, arg1, arg2, arg3) constructs a non-empty optional - // with an in-place constructed value of T(arg1,arg2,arg3). - // TODO(b/34201852): Add std::is_constructible<T, Args&&...> SFINAE. - template <typename... Args> - constexpr explicit optional(in_place_t, Args&&... args) - : data_base(in_place_t(), internal_optional::forward<Args>(args)...) {} - - // optional<T>(in_place, {arg1, arg2, arg3}) constructs a non-empty optional - // with an in-place list-initialized value of T({arg1, arg2, arg3}). - template <typename U, typename... Args, - typename = typename std::enable_if<std::is_constructible< - T, std::initializer_list<U>&, Args&&...>::value>::type> - constexpr explicit optional(in_place_t, std::initializer_list<U> il, - Args&&... args) - : data_base(in_place_t(), il, internal_optional::forward<Args>(args)...) { - } - - template < - typename U = T, - typename std::enable_if< - std::is_constructible<T, U&&>::value && - !std::is_same<in_place_t, typename std::decay<U>::type>::value && - !std::is_same<optional<T>, typename std::decay<U>::type>::value && - std::is_convertible<U&&, T>::value, - bool>::type = false> - constexpr optional(U&& v) // NOLINT - : data_base(in_place_t(), internal_optional::forward<U>(v)) {} - - template < - typename U = T, - typename std::enable_if< - std::is_constructible<T, U&&>::value && - !std::is_same<in_place_t, typename std::decay<U>::type>::value && - !std::is_same<optional<T>, typename std::decay<U>::type>::value && - !std::is_convertible<U&&, T>::value, - bool>::type = false> - explicit constexpr optional(U&& v) - : data_base(in_place_t(), internal_optional::forward<U>(v)) {} - - // Converting copy constructor (implicit) - template < - typename U, - typename std::enable_if< - std::is_constructible<T, const U&>::value && - !internal_optional::is_constructible_convertible_from_optional< - T, U>::value && - std::is_convertible<const U&, T>::value, - bool>::type = false> - optional(const optional<U>& rhs) { // NOLINT - if (rhs) { - this->construct(*rhs); - } - } - - // Converting copy constructor (explicit) - template < - typename U, - typename std::enable_if< - std::is_constructible<T, const U&>::value && - !internal_optional::is_constructible_convertible_from_optional< - T, U>::value && - !std::is_convertible<const U&, T>::value, - bool>::type = false> - explicit optional(const optional<U>& rhs) { - if (rhs) { - this->construct(*rhs); - } - } - - // Converting move constructor (implicit) - template < - typename U, - typename std::enable_if< - std::is_constructible<T, U&&>::value && - !internal_optional::is_constructible_convertible_from_optional< - T, U>::value && - std::is_convertible<U&&, T>::value, - bool>::type = false> - optional(optional<U>&& rhs) { // NOLINT - if (rhs) { - this->construct(std::move(*rhs)); - } - } - - // Converting move constructor (explicit) - template < - typename U, - typename std::enable_if< - std::is_constructible<T, U&&>::value && - !internal_optional::is_constructible_convertible_from_optional< - T, U>::value && - !std::is_convertible<U&&, T>::value, - bool>::type = false> - explicit optional(optional<U>&& rhs) { - if (rhs) { - this->construct(std::move(*rhs)); - 
} - } - - // [optional.dtor], destructor, trivial if T is trivially destructible. - ~optional() = default; - - // [optional.assign], assignment - - // Assignment from nullopt: opt = nullopt - optional& operator=(nullopt_t) noexcept { - this->destruct(); - return *this; - } - - // Copy assignment, standard semantics. - optional& operator=(const optional& src) = default; - - // Move assignment, standard semantics. - optional& operator=(optional&& src) = default; - - // Value assignment - template < - typename U = T, - typename = typename std::enable_if< - !std::is_same<optional<T>, typename std::decay<U>::type>::value && - (!std::is_scalar<T>::value || - !std::is_same<T, typename std::decay<U>::type>::value) && - std::is_constructible<T, U>::value && - std::is_assignable<T&, U>::value>::type> - optional& operator=(U&& v) { - this->assign(std::forward<U>(v)); - return *this; - } - - template <typename U, - typename = typename std::enable_if< - std::is_constructible<T, const U&>::value && - std::is_assignable<T&, const U&>::value && - !internal_optional:: - is_constructible_convertible_assignable_from_optional< - T, U>::value>::type> - optional& operator=(const optional<U>& rhs) { - if (rhs) { - this->assign(*rhs); - } else { - this->destruct(); - } - return *this; - } - - template <typename U, - typename = typename std::enable_if< - std::is_constructible<T, U>::value && - std::is_assignable<T&, U>::value && - !internal_optional:: - is_constructible_convertible_assignable_from_optional< - T, U>::value>::type> - optional& operator=(optional<U>&& rhs) { - if (rhs) { - this->assign(std::move(*rhs)); - } else { - this->destruct(); - } - return *this; - } - - // [optional.mod], modifiers - // Destroys the inner T value if one is present. - void reset() noexcept { this->destruct(); } - - // Emplace reconstruction. (Re)constructs the underlying T in-place with the - // given arguments forwarded: - // - // optional<Foo> opt; - // opt.emplace(arg1,arg2,arg3); (Constructs Foo(arg1,arg2,arg3)) - // - // If the optional is non-empty, and the `args` refer to subobjects of the - // current object, then behavior is undefined. This is because the current - // object will be destructed before the new object is constructed with `args`. - // - template <typename... Args, - typename = typename std::enable_if< - std::is_constructible<T, Args&&...>::value>::type> - void emplace(Args&&... args) { - this->destruct(); - this->construct(std::forward<Args>(args)...); - } - - // Emplace reconstruction with initializer-list. See immediately above. - template <class U, class... Args, - typename = typename std::enable_if<std::is_constructible< - T, std::initializer_list<U>&, Args&&...>::value>::type> - void emplace(std::initializer_list<U> il, Args&&... args) { - this->destruct(); - this->construct(il, std::forward<Args>(args)...); - } - - // [optional.swap], swap - // Swap, standard semantics. - void swap(optional& rhs) noexcept( - std::is_nothrow_move_constructible<T>::value&& - std::is_trivial<T>::value) { - if (*this) { - if (rhs) { - using std::swap; - swap(**this, *rhs); - } else { - rhs.construct(std::move(**this)); - this->destruct(); - } - } else { - if (rhs) { - this->construct(std::move(*rhs)); - rhs.destruct(); - } else { - // no effect (swap(disengaged, disengaged)) - } - } - } - - // [optional.observe], observers - // You may use `*opt`, and `opt->m`, to access the underlying T value and T's - // member `m`, respectively. If the optional is empty, behavior is - // undefined. 
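-  // For example (illustrative only):
-  //
-  //   optional<std::pair<int, int>> p(in_place, 1, 2);
-  //   if (p) {
-  //     int a = p->first;  // member access through operator->
-  //     auto& r = *p;      // operator* yields a reference to the pair
-  //   }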
-  constexpr const T* operator->() const { return this->pointer(); }
-  T* operator->() {
-    assert(this->engaged_);
-    return this->pointer();
-  }
-  constexpr const T& operator*() const& { return reference(); }
-  T& operator*() & {
-    assert(this->engaged_);
-    return reference();
-  }
-  constexpr const T&& operator*() const&& { return std::move(reference()); }
-  T&& operator*() && {
-    assert(this->engaged_);
-    return std::move(reference());
-  }
-
-  // In a bool context an optional<T> will return false if and only if it is
-  // empty.
-  //
-  //   if (opt) {
-  //     // do something with opt.value();
-  //   } else {
-  //     // opt is empty
-  //   }
-  //
-  constexpr explicit operator bool() const noexcept { return this->engaged_; }
-
-  // Returns false if and only if *this is empty.
-  constexpr bool has_value() const noexcept { return this->engaged_; }
-
-  // Use `opt.value()` to get a reference to the underlying value. The
-  // constness and lvalue/rvalue-ness of `opt` is preserved to the view of
-  // the T subobject.
-  const T& value() const& {
-    CHECK(*this) << "Bad optional access";
-    return reference();
-  }
-  T& value() & {
-    CHECK(*this) << "Bad optional access";
-    return reference();
-  }
-  T&& value() && {  // NOLINT(build/c++11)
-    CHECK(*this) << "Bad optional access";
-    return std::move(reference());
-  }
-  const T&& value() const&& {  // NOLINT(build/c++11)
-    CHECK(*this) << "Bad optional access";
-    return std::move(reference());
-  }
-
-  // Use `opt.value_or(val)` to get either the value of T or the given default
-  // `val` in the empty case.
-  template <class U>
-  constexpr T value_or(U&& v) const& {
-    return static_cast<bool>(*this) ? **this
-                                    : static_cast<T>(std::forward<U>(v));
-  }
-  template <class U>
-  T value_or(U&& v) && {  // NOLINT(build/c++11)
-    return static_cast<bool>(*this) ? std::move(**this)
-                                    : static_cast<T>(std::forward<U>(v));
-  }
-
- private:
-  // Private accessors for internal storage viewed as reference to T.
-  constexpr const T& reference() const { return *this->pointer(); }
-  T& reference() { return *(this->pointer()); }
-
-  // T constraint checks. You can't have an optional of nullopt_t, in_place_t
-  // or a reference.
-  static_assert(
-      !std::is_same<nullopt_t, typename std::remove_cv<T>::type>::value,
-      "optional<nullopt_t> is not allowed.");
-  static_assert(
-      !std::is_same<in_place_t, typename std::remove_cv<T>::type>::value,
-      "optional<in_place_t> is not allowed.");
-  static_assert(!std::is_reference<T>::value,
-                "optional<reference> is not allowed.");
-};
-
-// [optional.specalg]
-// Swap, standard semantics.
-// This function shall not participate in overload resolution unless
-// is_move_constructible_v<T> is true and is_swappable_v<T> is true.
-// NOTE: we assume is_swappable is always true. There will be a compile error
-// if T is actually not Swappable.
-template <typename T,
-          typename std::enable_if<std::is_move_constructible<T>::value,
-                                  bool>::type = false>
-void swap(optional<T>& a, optional<T>& b) noexcept(noexcept(a.swap(b))) {
-  a.swap(b);
-}
-
-// NOTE: make_optional cannot be constexpr in C++11 because the copy/move
-// constructor is not constexpr and we don't have guaranteed copy elision
-// until C++17. But they are still declared constexpr for consistency with
-// the standard.
-
-// make_optional(v) creates a non-empty optional<T> where the type T is deduced
-// from v. Can also be explicitly instantiated as make_optional<T>(v).
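-// For example (illustrative only):
-//
-//   auto a = make_optional(3.0);             // optional<double>
-//   auto b = make_optional<string>(3, 'x');  // optional<string> holding "xxx"
-//   auto c = make_optional<std::vector<int>>({1, 2, 3});  // list-init form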
-template <typename T> -constexpr optional<typename std::decay<T>::type> make_optional(T&& v) { - return optional<typename std::decay<T>::type>(std::forward<T>(v)); -} - -template <typename T, typename... Args> -constexpr optional<T> make_optional(Args&&... args) { - return optional<T>(in_place_t(), internal_optional::forward<Args>(args)...); -} - -template <typename T, typename U, typename... Args> -constexpr optional<T> make_optional(std::initializer_list<U> il, - Args&&... args) { - return optional<T>(in_place_t(), il, - internal_optional::forward<Args>(args)...); -} - -// Relational operators. Empty optionals are considered equal to each -// other and less than non-empty optionals. Supports relations between -// optional<T> and optional<T>, between optional<T> and T, and between -// optional<T> and nullopt. -// Note: We're careful to support T having non-bool relationals. - -// Relational operators [optional.relops] -// The C++17 (N4606) "Returns:" statements are translated into code -// in an obvious way here, and the original text retained as function docs. -// Returns: If bool(x) != bool(y), false; otherwise if bool(x) == false, true; -// otherwise *x == *y. -template <class T> -constexpr bool operator==(const optional<T>& x, const optional<T>& y) { - return static_cast<bool>(x) != static_cast<bool>(y) - ? false - : static_cast<bool>(x) == false ? true : *x == *y; -} -// Returns: If bool(x) != bool(y), true; otherwise, if bool(x) == false, false; -// otherwise *x != *y. -template <class T> -constexpr bool operator!=(const optional<T>& x, const optional<T>& y) { - return static_cast<bool>(x) != static_cast<bool>(y) - ? true - : static_cast<bool>(x) == false ? false : *x != *y; -} -// Returns: If !y, false; otherwise, if !x, true; otherwise *x < *y. -template <class T> -constexpr bool operator<(const optional<T>& x, const optional<T>& y) { - return !y ? false : !x ? true : *x < *y; -} -// Returns: If !x, false; otherwise, if !y, true; otherwise *x > *y. -template <class T> -constexpr bool operator>(const optional<T>& x, const optional<T>& y) { - return !x ? false : !y ? true : *x > *y; -} -// Returns: If !x, true; otherwise, if !y, false; otherwise *x <= *y. -template <class T> -constexpr bool operator<=(const optional<T>& x, const optional<T>& y) { - return !x ? true : !y ? false : *x <= *y; -} -// Returns: If !y, true; otherwise, if !x, false; otherwise *x >= *y. -template <class T> -constexpr bool operator>=(const optional<T>& x, const optional<T>& y) { - return !y ? true : !x ? false : *x >= *y; -} - -// Comparison with nullopt [optional.nullops] -// The C++17 (N4606) "Returns:" statements are used directly here. 
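-// For example (illustrative only), an empty optional compares equal to
-// nullopt, and nullopt orders before any engaged optional:
-//
-//   optional<int> x;
-//   assert(x == nullopt && !(nullopt < x));
-//   x = 1;
-//   assert(x != nullopt && nullopt < x);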
-template <class T> -constexpr bool operator==(const optional<T>& x, nullopt_t) noexcept { - return !x; -} -template <class T> -constexpr bool operator==(nullopt_t, const optional<T>& x) noexcept { - return !x; -} -template <class T> -constexpr bool operator!=(const optional<T>& x, nullopt_t) noexcept { - return static_cast<bool>(x); -} -template <class T> -constexpr bool operator!=(nullopt_t, const optional<T>& x) noexcept { - return static_cast<bool>(x); -} -template <class T> -constexpr bool operator<(const optional<T>& x, nullopt_t) noexcept { - return false; -} -template <class T> -constexpr bool operator<(nullopt_t, const optional<T>& x) noexcept { - return static_cast<bool>(x); -} -template <class T> -constexpr bool operator<=(const optional<T>& x, nullopt_t) noexcept { - return !x; -} -template <class T> -constexpr bool operator<=(nullopt_t, const optional<T>& x) noexcept { - return true; -} -template <class T> -constexpr bool operator>(const optional<T>& x, nullopt_t) noexcept { - return static_cast<bool>(x); -} -template <class T> -constexpr bool operator>(nullopt_t, const optional<T>& x) noexcept { - return false; -} -template <class T> -constexpr bool operator>=(const optional<T>& x, nullopt_t) noexcept { - return true; -} -template <class T> -constexpr bool operator>=(nullopt_t, const optional<T>& x) noexcept { - return !x; -} - -// Comparison with T [optional.comp_with_t] -// The C++17 (N4606) "Equivalent to:" statements are used directly here. -template <class T> -constexpr bool operator==(const optional<T>& x, const T& v) { - return static_cast<bool>(x) ? *x == v : false; -} -template <class T> -constexpr bool operator==(const T& v, const optional<T>& x) { - return static_cast<bool>(x) ? v == *x : false; -} -template <class T> -constexpr bool operator!=(const optional<T>& x, const T& v) { - return static_cast<bool>(x) ? *x != v : true; -} -template <class T> -constexpr bool operator!=(const T& v, const optional<T>& x) { - return static_cast<bool>(x) ? v != *x : true; -} -template <class T> -constexpr bool operator<(const optional<T>& x, const T& v) { - return static_cast<bool>(x) ? *x < v : true; -} -template <class T> -constexpr bool operator<(const T& v, const optional<T>& x) { - return static_cast<bool>(x) ? v < *x : false; -} -template <class T> -constexpr bool operator<=(const optional<T>& x, const T& v) { - return static_cast<bool>(x) ? *x <= v : true; -} -template <class T> -constexpr bool operator<=(const T& v, const optional<T>& x) { - return static_cast<bool>(x) ? v <= *x : false; -} -template <class T> -constexpr bool operator>(const optional<T>& x, const T& v) { - return static_cast<bool>(x) ? *x > v : false; -} -template <class T> -constexpr bool operator>(const T& v, const optional<T>& x) { - return static_cast<bool>(x) ? v > *x : true; -} -template <class T> -constexpr bool operator>=(const optional<T>& x, const T& v) { - return static_cast<bool>(x) ? *x >= v : false; -} -template <class T> -constexpr bool operator>=(const T& v, const optional<T>& x) { - return static_cast<bool>(x) ? v >= *x : true; -} +using optional = absl::optional<T>; } // namespace gtl } // namespace tensorflow -namespace std { - -// Normally std::hash specializations are not recommended in tensorflow code, -// but we allow this as it is following a standard library component. 
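-// For example (illustrative only), this lets optionals serve as keys:
-//
-//   std::unordered_map<tensorflow::gtl::optional<int>, int> m;
-//   m[tensorflow::gtl::nullopt] = 0;  // empty key hashes to a fixed value
-//   m[42] = 1;                        // engaged key hashes via hash<int>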
-template <class T> -struct hash<::tensorflow::gtl::optional<T>> { - size_t operator()(const ::tensorflow::gtl::optional<T>& opt) const { - if (opt) { - return hash<T>()(*opt); - } else { - return static_cast<size_t>(0x297814aaad196e6dULL); - } - } -}; - -} // namespace std - #endif // TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_ diff --git a/tensorflow/core/lib/gtl/optional_test.cc b/tensorflow/core/lib/gtl/optional_test.cc deleted file mode 100644 index 12b5bbc60b..0000000000 --- a/tensorflow/core/lib/gtl/optional_test.cc +++ /dev/null @@ -1,1098 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/core/lib/gtl/optional.h" - -#include <string> -#include <utility> - -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { -namespace { - -using tensorflow::gtl::in_place; -using tensorflow::gtl::in_place_t; -using tensorflow::gtl::make_optional; -using tensorflow::gtl::nullopt; -using tensorflow::gtl::nullopt_t; -using tensorflow::gtl::optional; - -template <typename T> -string TypeQuals(T&) { - return "&"; -} -template <typename T> -string TypeQuals(T&&) { - return "&&"; -} -template <typename T> -string TypeQuals(const T&) { - return "c&"; -} -template <typename T> -string TypeQuals(const T&&) { - return "c&&"; -} - -struct StructorListener { - int construct0 = 0; - int construct1 = 0; - int construct2 = 0; - int listinit = 0; - int copy = 0; - int move = 0; - int copy_assign = 0; - int move_assign = 0; - int destruct = 0; -}; - -struct Listenable { - static StructorListener* listener; - - Listenable() { ++listener->construct0; } - Listenable(int /*unused*/) { ++listener->construct1; } // NOLINT - Listenable(int /*unused*/, int /*unused*/) { ++listener->construct2; } - Listenable(std::initializer_list<int> /*unused*/) { ++listener->listinit; } - Listenable(const Listenable& /*unused*/) { ++listener->copy; } - Listenable(Listenable&& /*unused*/) { ++listener->move; } // NOLINT - Listenable& operator=(const Listenable& /*unused*/) { - ++listener->copy_assign; - return *this; - } - Listenable& operator=(Listenable&& /*unused*/) { // NOLINT - ++listener->move_assign; - return *this; - } - ~Listenable() { ++listener->destruct; } -}; - -StructorListener* Listenable::listener = nullptr; - -// clang on macos -- even the latest major version at time of writing (8.x) -- -// does not like much of our constexpr business. clang < 3.0 also has trouble. 
-#if defined(__clang__) && defined(__APPLE__)
-#define SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
-#endif
-
-struct ConstexprType {
-  constexpr ConstexprType() : x(0) {}
-  constexpr explicit ConstexprType(int i) : x(i) {}
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
-  constexpr ConstexprType(std::initializer_list<int> il) : x(il.size()) {}
-#endif
-  constexpr ConstexprType(const char* s) : x(-1) {}  // NOLINT
-  int x;
-};
-
-struct Copyable {
-  Copyable() {}
-  Copyable(const Copyable&) {}
-  Copyable& operator=(const Copyable&) { return *this; }
-};
-
-struct MoveableThrow {
-  MoveableThrow() {}
-  MoveableThrow(MoveableThrow&&) {}
-  MoveableThrow& operator=(MoveableThrow&&) { return *this; }
-};
-
-struct MoveableNoThrow {
-  MoveableNoThrow() {}
-  MoveableNoThrow(MoveableNoThrow&&) noexcept {}
-  MoveableNoThrow& operator=(MoveableNoThrow&&) noexcept { return *this; }
-};
-
-struct NonMovable {
-  NonMovable() {}
-  NonMovable(const NonMovable&) = delete;
-  NonMovable& operator=(const NonMovable&) = delete;
-  NonMovable(NonMovable&&) = delete;
-  NonMovable& operator=(NonMovable&&) = delete;
-};
-
-TEST(optionalTest, DefaultConstructor) {
-  optional<int> empty;
-  EXPECT_FALSE(!!empty);
-  constexpr optional<int> cempty;
-  static_assert(!cempty.has_value(), "");
-  EXPECT_TRUE(std::is_nothrow_default_constructible<optional<int>>::value);
-}
-
-TEST(optionalTest, NullOptConstructor) {
-  optional<int> empty(nullopt);
-  EXPECT_FALSE(!!empty);
-  // Creating a temporary nullopt_t object instead of using nullopt because
-  // nullopt cannot be constexpr and have external linkage at the same time.
-  constexpr optional<int> cempty{nullopt_t(nullopt_t::init)};
-  static_assert(!cempty.has_value(), "");
-  EXPECT_TRUE((std::is_nothrow_constructible<optional<int>, nullopt_t>::value));
-}
-
-TEST(optionalTest, CopyConstructor) {
-  optional<int> empty, opt42 = 42;
-  optional<int> empty_copy(empty);
-  EXPECT_FALSE(!!empty_copy);
-  optional<int> opt42_copy(opt42);
-  EXPECT_TRUE(!!opt42_copy);
-  EXPECT_EQ(42, opt42_copy);
-  // test copyability
-  EXPECT_TRUE(std::is_copy_constructible<optional<int>>::value);
-  EXPECT_TRUE(std::is_copy_constructible<optional<Copyable>>::value);
-  EXPECT_FALSE(std::is_copy_constructible<optional<MoveableThrow>>::value);
-  EXPECT_FALSE(std::is_copy_constructible<optional<MoveableNoThrow>>::value);
-  EXPECT_FALSE(std::is_copy_constructible<optional<NonMovable>>::value);
-}
-
-TEST(optionalTest, MoveConstructor) {
-  optional<int> empty, opt42 = 42;
-  optional<int> empty_move(std::move(empty));
-  EXPECT_FALSE(!!empty_move);
-  optional<int> opt42_move(std::move(opt42));
-  EXPECT_TRUE(!!opt42_move);
-  EXPECT_EQ(42, opt42_move);
-  // test movability
-  EXPECT_TRUE(std::is_move_constructible<optional<int>>::value);
-  EXPECT_TRUE(std::is_move_constructible<optional<Copyable>>::value);
-  EXPECT_TRUE(std::is_move_constructible<optional<MoveableThrow>>::value);
-  EXPECT_TRUE(std::is_move_constructible<optional<MoveableNoThrow>>::value);
-  EXPECT_FALSE(std::is_move_constructible<optional<NonMovable>>::value);
-  // test noexcept
-  EXPECT_TRUE(std::is_nothrow_move_constructible<optional<int>>::value);
-  EXPECT_FALSE(
-      std::is_nothrow_move_constructible<optional<MoveableThrow>>::value);
-  EXPECT_TRUE(
-      std::is_nothrow_move_constructible<optional<MoveableNoThrow>>::value);
-}
-
-TEST(optionalTest, Destructor) {
-  struct Trivial {};
-
-  struct NonTrivial {
-    ~NonTrivial() {}
-  };
-
-  EXPECT_TRUE(std::is_trivially_destructible<optional<int>>::value);
-  EXPECT_TRUE(std::is_trivially_destructible<optional<Trivial>>::value);
-  EXPECT_FALSE(std::is_trivially_destructible<optional<NonTrivial>>::value);
-}
-
-TEST(optionalTest, InPlaceConstructor) {
-  constexpr optional<ConstexprType> opt0{in_place_t()};
-  static_assert(opt0, "");
-  static_assert(opt0->x == 0, "");
-  constexpr optional<ConstexprType> opt1{in_place_t(), 1};
-  static_assert(opt1, "");
-  static_assert(opt1->x == 1, "");
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
-  constexpr optional<ConstexprType> opt2{in_place_t(), {1, 2}};
-  static_assert(opt2, "");
-  static_assert(opt2->x == 2, "");
-#endif
-
-  // TODO(b/34201852): uncomment these when std::is_constructible<T, Args&&...>
-  // SFINAE is added to optional::optional(in_place_t, Args&&...).
-  // struct I {
-  //   I(in_place_t);
-  // };
-
-  // EXPECT_FALSE((std::is_constructible<optional<I>, in_place_t>::value));
-  // EXPECT_FALSE((std::is_constructible<optional<I>, const
-  // in_place_t&>::value));
-}
-
-// template<U=T> optional(U&&);
-TEST(optionalTest, ValueConstructor) {
-  constexpr optional<int> opt0(0);
-  static_assert(opt0, "");
-  static_assert(*opt0 == 0, "");
-  EXPECT_TRUE((std::is_convertible<int, optional<int>>::value));
-  // Copy initialization ( = "abc") won't work because optional(optional&&)
-  // is not constexpr. Use list initialization instead. This invokes
-  // optional<ConstexprType>::optional<U>(U&&), with U = const char (&) [4],
-  // which direct-initializes the ConstexprType value held by the optional
-  // via ConstexprType::ConstexprType(const char*).
-  constexpr optional<ConstexprType> opt1 = {"abc"};
-  static_assert(opt1, "");
-  static_assert(-1 == opt1->x, "");
-  EXPECT_TRUE(
-      (std::is_convertible<const char*, optional<ConstexprType>>::value));
-  // direct initialization
-  constexpr optional<ConstexprType> opt2{2};
-  static_assert(opt2, "");
-  static_assert(2 == opt2->x, "");
-  EXPECT_FALSE((std::is_convertible<int, optional<ConstexprType>>::value));
-
-  // this invokes optional<int>::optional(int&&)
-  // NOTE: this has different behavior than assignment, e.g.
-  // "opt3 = {};" clears the optional rather than setting the value to 0
-  constexpr optional<int> opt3({});
-  static_assert(opt3, "");
-  static_assert(*opt3 == 0, "");
-
-  // this invokes the move constructor with a default constructed optional
-  // because the non-template function is a better match than the template
-  // function.
- optional<ConstexprType> opt4({}); - EXPECT_FALSE(!!opt4); -} - -struct Implicit {}; - -struct Explicit {}; - -struct Convert { - Convert(const Implicit&) // NOLINT(runtime/explicit) - : implicit(true), move(false) {} - Convert(Implicit&&) // NOLINT(runtime/explicit) - : implicit(true), move(true) {} - explicit Convert(const Explicit&) : implicit(false), move(false) {} - explicit Convert(Explicit&&) : implicit(false), move(true) {} - - bool implicit; - bool move; -}; - -struct ConvertFromOptional { - ConvertFromOptional(const Implicit&) // NOLINT(runtime/explicit) - : implicit(true), move(false), from_optional(false) {} - ConvertFromOptional(Implicit&&) // NOLINT(runtime/explicit) - : implicit(true), move(true), from_optional(false) {} - ConvertFromOptional(const optional<Implicit>&) // NOLINT(runtime/explicit) - : implicit(true), move(false), from_optional(true) {} - ConvertFromOptional(optional<Implicit>&&) // NOLINT(runtime/explicit) - : implicit(true), move(true), from_optional(true) {} - explicit ConvertFromOptional(const Explicit&) - : implicit(false), move(false), from_optional(false) {} - explicit ConvertFromOptional(Explicit&&) - : implicit(false), move(true), from_optional(false) {} - explicit ConvertFromOptional(const optional<Explicit>&) - : implicit(false), move(false), from_optional(true) {} - explicit ConvertFromOptional(optional<Explicit>&&) - : implicit(false), move(true), from_optional(true) {} - - bool implicit; - bool move; - bool from_optional; -}; - -TEST(optionalTest, ConvertingConstructor) { - optional<Implicit> i_empty; - optional<Implicit> i(in_place); - optional<Explicit> e_empty; - optional<Explicit> e(in_place); - { - // implicitly constructing optional<Convert> from optional<Implicit> - optional<Convert> empty = i_empty; - EXPECT_FALSE(!!empty); - optional<Convert> opt_copy = i; - EXPECT_TRUE(!!opt_copy); - EXPECT_TRUE(opt_copy->implicit); - EXPECT_FALSE(opt_copy->move); - optional<Convert> opt_move = optional<Implicit>(in_place); - EXPECT_TRUE(!!opt_move); - EXPECT_TRUE(opt_move->implicit); - EXPECT_TRUE(opt_move->move); - } - { - // explicitly constructing optional<Convert> from optional<Explicit> - optional<Convert> empty(e_empty); - EXPECT_FALSE(!!empty); - optional<Convert> opt_copy(e); - EXPECT_TRUE(!!opt_copy); - EXPECT_FALSE(opt_copy->implicit); - EXPECT_FALSE(opt_copy->move); - EXPECT_FALSE((std::is_convertible<const optional<Explicit>&, - optional<Convert>>::value)); - optional<Convert> opt_move{optional<Explicit>(in_place)}; - EXPECT_TRUE(!!opt_move); - EXPECT_FALSE(opt_move->implicit); - EXPECT_TRUE(opt_move->move); - EXPECT_FALSE( - (std::is_convertible<optional<Explicit>&&, optional<Convert>>::value)); - } - { - // implicitly constructing optional<ConvertFromOptional> from - // optional<Implicit> via ConvertFromOptional(optional<Implicit>&&) - // check that ConvertFromOptional(Implicit&&) is NOT called - static_assert( - gtl::internal_optional::is_constructible_convertible_from_optional< - ConvertFromOptional, Implicit>::value, - ""); - optional<ConvertFromOptional> opt0 = i_empty; - EXPECT_TRUE(!!opt0); - EXPECT_TRUE(opt0->implicit); - EXPECT_FALSE(opt0->move); - EXPECT_TRUE(opt0->from_optional); - optional<ConvertFromOptional> opt1 = optional<Implicit>(); - EXPECT_TRUE(!!opt1); - EXPECT_TRUE(opt1->implicit); - EXPECT_TRUE(opt1->move); - EXPECT_TRUE(opt1->from_optional); - } - { - // implicitly constructing optional<ConvertFromOptional> from - // optional<Explicit> via ConvertFromOptional(optional<Explicit>&&) - // check that 
ConvertFromOptional(Explicit&&) is NOT called - optional<ConvertFromOptional> opt0(e_empty); - EXPECT_TRUE(!!opt0); - EXPECT_FALSE(opt0->implicit); - EXPECT_FALSE(opt0->move); - EXPECT_TRUE(opt0->from_optional); - EXPECT_FALSE((std::is_convertible<const optional<Explicit>&, - optional<ConvertFromOptional>>::value)); - optional<ConvertFromOptional> opt1{optional<Explicit>()}; - EXPECT_TRUE(!!opt1); - EXPECT_FALSE(opt1->implicit); - EXPECT_TRUE(opt1->move); - EXPECT_TRUE(opt1->from_optional); - EXPECT_FALSE((std::is_convertible<optional<Explicit>&&, - optional<ConvertFromOptional>>::value)); - } -} - -TEST(optionalTest, StructorBasic) { - StructorListener listener; - Listenable::listener = &listener; - { - optional<Listenable> empty; - EXPECT_FALSE(!!empty); - optional<Listenable> opt0(in_place); - EXPECT_TRUE(!!opt0); - optional<Listenable> opt1(in_place, 1); - EXPECT_TRUE(!!opt1); - optional<Listenable> opt2(in_place, 1, 2); - EXPECT_TRUE(!!opt2); - } - EXPECT_EQ(1, listener.construct0); - EXPECT_EQ(1, listener.construct1); - EXPECT_EQ(1, listener.construct2); - EXPECT_EQ(3, listener.destruct); -} - -TEST(optionalTest, CopyMoveStructor) { - StructorListener listener; - Listenable::listener = &listener; - optional<Listenable> original(in_place); - EXPECT_EQ(1, listener.construct0); - EXPECT_EQ(0, listener.copy); - EXPECT_EQ(0, listener.move); - optional<Listenable> copy(original); - EXPECT_EQ(1, listener.construct0); - EXPECT_EQ(1, listener.copy); - EXPECT_EQ(0, listener.move); - optional<Listenable> move(std::move(original)); - EXPECT_EQ(1, listener.construct0); - EXPECT_EQ(1, listener.copy); - EXPECT_EQ(1, listener.move); -} - -TEST(optionalTest, ListInit) { - StructorListener listener; - Listenable::listener = &listener; - optional<Listenable> listinit1(in_place, {1}); - optional<Listenable> listinit2(in_place, {1, 2}); - EXPECT_EQ(2, listener.listinit); -} - -TEST(optionalTest, AssignFromNullopt) { - optional<int> opt(1); - opt = nullopt; - EXPECT_FALSE(!!opt); - - StructorListener listener; - Listenable::listener = &listener; - optional<Listenable> opt1(in_place); - opt1 = nullopt; - EXPECT_FALSE(opt1); - EXPECT_EQ(1, listener.construct0); - EXPECT_EQ(1, listener.destruct); - - EXPECT_TRUE((std::is_nothrow_assignable<optional<int>, nullopt_t>::value)); - EXPECT_TRUE( - (std::is_nothrow_assignable<optional<Listenable>, nullopt_t>::value)); -} - -TEST(optionalTest, CopyAssignment) { - const optional<int> empty, opt1 = 1, opt2 = 2; - optional<int> empty_to_opt1, opt1_to_opt2, opt2_to_empty; - - EXPECT_FALSE(!!empty_to_opt1); - empty_to_opt1 = empty; - EXPECT_FALSE(!!empty_to_opt1); - empty_to_opt1 = opt1; - EXPECT_TRUE(!!empty_to_opt1); - EXPECT_EQ(1, empty_to_opt1.value()); - - EXPECT_FALSE(!!opt1_to_opt2); - opt1_to_opt2 = opt1; - EXPECT_TRUE(!!opt1_to_opt2); - EXPECT_EQ(1, opt1_to_opt2.value()); - opt1_to_opt2 = opt2; - EXPECT_TRUE(!!opt1_to_opt2); - EXPECT_EQ(2, opt1_to_opt2.value()); - - EXPECT_FALSE(!!opt2_to_empty); - opt2_to_empty = opt2; - EXPECT_TRUE(!!opt2_to_empty); - EXPECT_EQ(2, opt2_to_empty.value()); - opt2_to_empty = empty; - EXPECT_FALSE(!!opt2_to_empty); - - EXPECT_TRUE(std::is_copy_assignable<optional<Copyable>>::value); - EXPECT_FALSE(std::is_copy_assignable<optional<MoveableThrow>>::value); - EXPECT_FALSE(std::is_copy_assignable<optional<MoveableNoThrow>>::value); - EXPECT_FALSE(std::is_copy_assignable<optional<NonMovable>>::value); -} - -TEST(optionalTest, MoveAssignment) { - StructorListener listener; - Listenable::listener = &listener; - - optional<Listenable> 
empty1, empty2, set1(in_place), set2(in_place); - EXPECT_EQ(2, listener.construct0); - optional<Listenable> empty_to_empty, empty_to_set, set_to_empty(in_place), - set_to_set(in_place); - EXPECT_EQ(4, listener.construct0); - empty_to_empty = std::move(empty1); - empty_to_set = std::move(set1); - set_to_empty = std::move(empty2); - set_to_set = std::move(set2); - EXPECT_EQ(0, listener.copy); - EXPECT_EQ(1, listener.move); - EXPECT_EQ(1, listener.destruct); - EXPECT_EQ(1, listener.move_assign); - - EXPECT_TRUE(std::is_move_assignable<optional<Copyable>>::value); - EXPECT_TRUE(std::is_move_assignable<optional<MoveableThrow>>::value); - EXPECT_TRUE(std::is_move_assignable<optional<MoveableNoThrow>>::value); - EXPECT_FALSE(std::is_move_assignable<optional<NonMovable>>::value); - - EXPECT_FALSE(std::is_nothrow_move_assignable<optional<MoveableThrow>>::value); - EXPECT_TRUE( - std::is_nothrow_move_assignable<optional<MoveableNoThrow>>::value); -} - -struct NoConvertToOptional { - // disable implicit conversion from const NoConvertToOptional& - // to optional<NoConvertToOptional>. - NoConvertToOptional(const NoConvertToOptional&) = delete; -}; - -struct CopyConvert { - CopyConvert(const NoConvertToOptional&); - CopyConvert& operator=(const CopyConvert&) = delete; - CopyConvert& operator=(const NoConvertToOptional&); -}; - -struct CopyConvertFromOptional { - CopyConvertFromOptional(const NoConvertToOptional&); - CopyConvertFromOptional(const optional<NoConvertToOptional>&); - CopyConvertFromOptional& operator=(const CopyConvertFromOptional&) = delete; - CopyConvertFromOptional& operator=(const NoConvertToOptional&); - CopyConvertFromOptional& operator=(const optional<NoConvertToOptional>&); -}; - -struct MoveConvert { - MoveConvert(NoConvertToOptional&&); - MoveConvert& operator=(const MoveConvert&) = delete; - MoveConvert& operator=(NoConvertToOptional&&); -}; - -struct MoveConvertFromOptional { - MoveConvertFromOptional(NoConvertToOptional&&); - MoveConvertFromOptional(optional<NoConvertToOptional>&&); - MoveConvertFromOptional& operator=(const MoveConvertFromOptional&) = delete; - MoveConvertFromOptional& operator=(NoConvertToOptional&&); - MoveConvertFromOptional& operator=(optional<NoConvertToOptional>&&); -}; - -// template <class U = T> optional<T>& operator=(U&& v); -TEST(optionalTest, ValueAssignment) { - optional<int> opt; - EXPECT_FALSE(!!opt); - opt = 42; - EXPECT_TRUE(!!opt); - EXPECT_EQ(42, opt.value()); - opt = nullopt; - EXPECT_FALSE(!!opt); - opt = 42; - EXPECT_TRUE(!!opt); - EXPECT_EQ(42, opt.value()); - opt = 43; - EXPECT_TRUE(!!opt); - EXPECT_EQ(43, opt.value()); - opt = {}; // this should clear optional - EXPECT_FALSE(!!opt); - - opt = {44}; - EXPECT_TRUE(!!opt); - EXPECT_EQ(44, opt.value()); - - // U = const NoConvertToOptional& - EXPECT_TRUE((std::is_assignable<optional<CopyConvert>&, - const NoConvertToOptional&>::value)); - // U = const optional<NoConvertToOptional>& - EXPECT_TRUE((std::is_assignable<optional<CopyConvertFromOptional>&, - const NoConvertToOptional&>::value)); - // U = const NoConvertToOptional& triggers SFINAE because - // std::is_constructible_v<MoveConvert, const NoConvertToOptional&> is false - EXPECT_FALSE((std::is_assignable<optional<MoveConvert>&, - const NoConvertToOptional&>::value)); - // U = NoConvertToOptional - EXPECT_TRUE((std::is_assignable<optional<MoveConvert>&, - NoConvertToOptional&&>::value)); - // U = const NoConvertToOptional& triggers SFINAE because - // std::is_constructible_v<MoveConvertFromOptional, const - // NoConvertToOptional&> 
is false - EXPECT_FALSE((std::is_assignable<optional<MoveConvertFromOptional>&, - const NoConvertToOptional&>::value)); - // U = NoConvertToOptional - EXPECT_TRUE((std::is_assignable<optional<MoveConvertFromOptional>&, - NoConvertToOptional&&>::value)); - // U = const optional<NoConvertToOptional>& - EXPECT_TRUE( - (std::is_assignable<optional<CopyConvertFromOptional>&, - const optional<NoConvertToOptional>&>::value)); - // U = optional<NoConvertToOptional> - EXPECT_TRUE((std::is_assignable<optional<MoveConvertFromOptional>&, - optional<NoConvertToOptional>&&>::value)); -} - -// template <class U> optional<T>& operator=(const optional<U>& rhs); -// template <class U> optional<T>& operator=(optional<U>&& rhs); -TEST(optionalTest, ConvertingAssignment) { - optional<int> opt_i; - optional<char> opt_c('c'); - opt_i = opt_c; - EXPECT_TRUE(!!opt_i); - EXPECT_EQ(*opt_c, *opt_i); - opt_i = optional<char>(); - EXPECT_FALSE(!!opt_i); - opt_i = optional<char>('d'); - EXPECT_TRUE(!!opt_i); - EXPECT_EQ('d', *opt_i); - - optional<string> opt_str; - optional<const char*> opt_cstr("abc"); - opt_str = opt_cstr; - EXPECT_TRUE(!!opt_str); - EXPECT_EQ(string("abc"), *opt_str); - opt_str = optional<const char*>(); - EXPECT_FALSE(!!opt_str); - opt_str = optional<const char*>("def"); - EXPECT_TRUE(!!opt_str); - EXPECT_EQ(string("def"), *opt_str); - - // operator=(const optional<U>&) with U = NoConvertToOptional - EXPECT_TRUE( - (std::is_assignable<optional<CopyConvert>, - const optional<NoConvertToOptional>&>::value)); - // operator=(const optional<U>&) with U = NoConvertToOptional - // triggers SFINAE because - // std::is_constructible_v<MoveConvert, const NoConvertToOptional&> is false - EXPECT_FALSE( - (std::is_assignable<optional<MoveConvert>&, - const optional<NoConvertToOptional>&>::value)); - // operator=(optional<U>&&) with U = NoConvertToOptional - EXPECT_TRUE((std::is_assignable<optional<MoveConvert>&, - optional<NoConvertToOptional>&&>::value)); - // operator=(const optional<U>&) with U = NoConvertToOptional triggers SFINAE - // because std::is_constructible_v<MoveConvertFromOptional, - // const NoConvertToOptional&> is false. - // operator=(U&&) with U = const optional<NoConvertToOptional>& triggers SFINAE - // because std::is_constructible<MoveConvertFromOptional, - // optional<NoConvertToOptional>&&> is true. 
- EXPECT_FALSE( - (std::is_assignable<optional<MoveConvertFromOptional>&, - const optional<NoConvertToOptional>&>::value)); -} - -TEST(optionalTest, ResetAndHasValue) { - StructorListener listener; - Listenable::listener = &listener; - optional<Listenable> opt; - EXPECT_FALSE(!!opt); - EXPECT_FALSE(opt.has_value()); - opt.emplace(); - EXPECT_TRUE(!!opt); - EXPECT_TRUE(opt.has_value()); - opt.reset(); - EXPECT_FALSE(!!opt); - EXPECT_FALSE(opt.has_value()); - EXPECT_EQ(1, listener.destruct); - opt.reset(); - EXPECT_FALSE(!!opt); - EXPECT_FALSE(opt.has_value()); - - constexpr optional<int> empty; - static_assert(!empty.has_value(), ""); - constexpr optional<int> nonempty(1); - static_assert(nonempty.has_value(), ""); -} - -TEST(optionalTest, Emplace) { - StructorListener listener; - Listenable::listener = &listener; - optional<Listenable> opt; - EXPECT_FALSE(!!opt); - opt.emplace(1); - EXPECT_TRUE(!!opt); - opt.emplace(1, 2); - EXPECT_EQ(1, listener.construct1); - EXPECT_EQ(1, listener.construct2); - EXPECT_EQ(1, listener.destruct); -} - -TEST(optionalTest, ListEmplace) { - StructorListener listener; - Listenable::listener = &listener; - optional<Listenable> opt; - EXPECT_FALSE(!!opt); - opt.emplace({1}); - EXPECT_TRUE(!!opt); - opt.emplace({1, 2}); - EXPECT_EQ(2, listener.listinit); - EXPECT_EQ(1, listener.destruct); -} - -TEST(optionalTest, Swap) { - optional<int> opt_empty, opt1 = 1, opt2 = 2; - EXPECT_FALSE(!!opt_empty); - EXPECT_TRUE(!!opt1); - EXPECT_EQ(1, opt1.value()); - EXPECT_TRUE(!!opt2); - EXPECT_EQ(2, opt2.value()); - swap(opt_empty, opt1); - EXPECT_FALSE(!!opt1); - EXPECT_TRUE(!!opt_empty); - EXPECT_EQ(1, opt_empty.value()); - EXPECT_TRUE(!!opt2); - EXPECT_EQ(2, opt2.value()); - swap(opt_empty, opt1); - EXPECT_FALSE(!!opt_empty); - EXPECT_TRUE(!!opt1); - EXPECT_EQ(1, opt1.value()); - EXPECT_TRUE(!!opt2); - EXPECT_EQ(2, opt2.value()); - swap(opt1, opt2); - EXPECT_FALSE(!!opt_empty); - EXPECT_TRUE(!!opt1); - EXPECT_EQ(2, opt1.value()); - EXPECT_TRUE(!!opt2); - EXPECT_EQ(1, opt2.value()); - - EXPECT_TRUE(noexcept(opt1.swap(opt2))); - EXPECT_TRUE(noexcept(swap(opt1, opt2))); -} - -TEST(optionalTest, PointerStuff) { - optional<string> opt(in_place, "foo"); - EXPECT_EQ("foo", *opt); - const auto& opt_const = opt; - EXPECT_EQ("foo", *opt_const); - EXPECT_EQ(opt->size(), 3); - EXPECT_EQ(opt_const->size(), 3); - - constexpr optional<ConstexprType> opt1(1); - static_assert(opt1->x == 1, ""); -} - -// gcc has a bug pre 4.9 where it doesn't do correct overload resolution -// between rvalue reference qualified member methods. Skip that test to make -// the build green again when using the old compiler. 
-#if defined(__GNUC__) && !defined(__clang__) -#if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 9) -#define SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG -#endif -#endif - -TEST(optionalTest, Value) { - using O = optional<string>; - using CO = const optional<string>; - O lvalue(in_place, "lvalue"); - CO clvalue(in_place, "clvalue"); - EXPECT_EQ("lvalue", lvalue.value()); - EXPECT_EQ("clvalue", clvalue.value()); - EXPECT_EQ("xvalue", O(in_place, "xvalue").value()); -#ifndef SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG - EXPECT_EQ("cxvalue", CO(in_place, "cxvalue").value()); - EXPECT_EQ("&", TypeQuals(lvalue.value())); - EXPECT_EQ("c&", TypeQuals(clvalue.value())); - EXPECT_EQ("&&", TypeQuals(O(in_place, "xvalue").value())); - EXPECT_EQ("c&&", TypeQuals(CO(in_place, "cxvalue").value())); -#endif -} - -TEST(optionalTest, DerefOperator) { - using O = optional<string>; - using CO = const optional<string>; - O lvalue(in_place, "lvalue"); - CO clvalue(in_place, "clvalue"); - EXPECT_EQ("lvalue", *lvalue); - EXPECT_EQ("clvalue", *clvalue); - EXPECT_EQ("xvalue", *O(in_place, "xvalue")); -#ifndef SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG - EXPECT_EQ("cxvalue", *CO(in_place, "cxvalue")); - EXPECT_EQ("&", TypeQuals(*lvalue)); - EXPECT_EQ("c&", TypeQuals(*clvalue)); - EXPECT_EQ("&&", TypeQuals(*O(in_place, "xvalue"))); - EXPECT_EQ("c&&", TypeQuals(*CO(in_place, "cxvalue"))); -#endif - - constexpr optional<int> opt1(1); - static_assert(*opt1 == 1, ""); - -#if !defined(SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG) && \ - !defined(SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG) - using COI = const optional<int>; - static_assert(*COI(2) == 2, ""); -#endif -} - -TEST(optionalTest, ValueOr) { - optional<double> opt_empty, opt_set = 1.2; - EXPECT_EQ(42.0, opt_empty.value_or(42)); - EXPECT_EQ(1.2, opt_set.value_or(42)); - EXPECT_EQ(42.0, optional<double>().value_or(42)); - EXPECT_EQ(1.2, optional<double>(1.2).value_or(42)); - -#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG - constexpr optional<double> copt_empty; - static_assert(42.0 == copt_empty.value_or(42), ""); - - constexpr optional<double> copt_set = {1.2}; - static_assert(1.2 == copt_set.value_or(42), ""); - - using COD = const optional<double>; - static_assert(42.0 == COD().value_or(42), ""); - static_assert(1.2 == COD(1.2).value_or(42), ""); -#endif -} - -// make_optional cannot be constexpr until C++17 -TEST(optionalTest, make_optional) { - auto opt_int = make_optional(42); - EXPECT_TRUE((std::is_same<decltype(opt_int), optional<int>>::value)); - EXPECT_EQ(42, opt_int); - - StructorListener listener; - Listenable::listener = &listener; - - optional<Listenable> opt0 = make_optional<Listenable>(); - EXPECT_EQ(1, listener.construct0); - optional<Listenable> opt1 = make_optional<Listenable>(1); - EXPECT_EQ(1, listener.construct1); - optional<Listenable> opt2 = make_optional<Listenable>(1, 2); - EXPECT_EQ(1, listener.construct2); - optional<Listenable> opt3 = make_optional<Listenable>({1}); - optional<Listenable> opt4 = make_optional<Listenable>({1, 2}); - EXPECT_EQ(2, listener.listinit); -} - -TEST(optionalTest, Comparisons) { - optional<int> ae, be, a2 = 2, b2 = 2, a4 = 4, b4 = 4; - -#define optionalTest_Comparisons_EXPECT_LESS(x, y) \ - EXPECT_FALSE((x) == (y)); \ - EXPECT_TRUE((x) != (y)); \ - EXPECT_TRUE((x) < (y)); \ - EXPECT_FALSE((x) > (y)); \ - EXPECT_TRUE((x) <= (y)); \ - EXPECT_FALSE((x) >= (y)); - -#define optionalTest_Comparisons_EXPECT_SAME(x, y) \ - EXPECT_TRUE((x) == (y)); \ - EXPECT_FALSE((x) != (y)); \ - EXPECT_FALSE((x) < (y)); \ - EXPECT_FALSE((x) > (y)); \ - EXPECT_TRUE((x) <= 
(y)); \ - EXPECT_TRUE((x) >= (y)); - -#define optionalTest_Comparisons_EXPECT_GREATER(x, y) \ - EXPECT_FALSE((x) == (y)); \ - EXPECT_TRUE((x) != (y)); \ - EXPECT_FALSE((x) < (y)); \ - EXPECT_TRUE((x) > (y)); \ - EXPECT_FALSE((x) <= (y)); \ - EXPECT_TRUE((x) >= (y)); - - // LHS: nullopt, ae, a2, 3, a4 - // RHS: nullopt, be, b2, 3, b4 - - // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(nullopt,nullopt); - optionalTest_Comparisons_EXPECT_SAME(nullopt, be); - optionalTest_Comparisons_EXPECT_LESS(nullopt, b2); - // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(nullopt,3); - optionalTest_Comparisons_EXPECT_LESS(nullopt, b4); - - optionalTest_Comparisons_EXPECT_SAME(ae, nullopt); - optionalTest_Comparisons_EXPECT_SAME(ae, be); - optionalTest_Comparisons_EXPECT_LESS(ae, b2); - optionalTest_Comparisons_EXPECT_LESS(ae, 3); - optionalTest_Comparisons_EXPECT_LESS(ae, b4); - - optionalTest_Comparisons_EXPECT_GREATER(a2, nullopt); - optionalTest_Comparisons_EXPECT_GREATER(a2, be); - optionalTest_Comparisons_EXPECT_SAME(a2, b2); - optionalTest_Comparisons_EXPECT_LESS(a2, 3); - optionalTest_Comparisons_EXPECT_LESS(a2, b4); - - // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(3,nullopt); - optionalTest_Comparisons_EXPECT_GREATER(3, be); - optionalTest_Comparisons_EXPECT_GREATER(3, b2); - optionalTest_Comparisons_EXPECT_SAME(3, 3); - optionalTest_Comparisons_EXPECT_LESS(3, b4); - - optionalTest_Comparisons_EXPECT_GREATER(a4, nullopt); - optionalTest_Comparisons_EXPECT_GREATER(a4, be); - optionalTest_Comparisons_EXPECT_GREATER(a4, b2); - optionalTest_Comparisons_EXPECT_GREATER(a4, 3); - optionalTest_Comparisons_EXPECT_SAME(a4, b4); -} - -TEST(optionalTest, SwapRegression) { - StructorListener listener; - Listenable::listener = &listener; - - { - optional<Listenable> a; - optional<Listenable> b(in_place); - a.swap(b); - } - - EXPECT_EQ(1, listener.construct0); - EXPECT_EQ(1, listener.move); - EXPECT_EQ(2, listener.destruct); - - { - optional<Listenable> a(in_place); - optional<Listenable> b; - a.swap(b); - } - - EXPECT_EQ(2, listener.construct0); - EXPECT_EQ(2, listener.move); - EXPECT_EQ(4, listener.destruct); -} - -TEST(optionalTest, BigStringLeakCheck) { - constexpr size_t n = 1 << 16; - - using OS = optional<string>; - - OS a; - OS b = nullopt; - OS c = string(n, 'c'); - string sd(n, 'd'); - OS d = sd; - OS e(in_place, n, 'e'); - OS f; - f.emplace(n, 'f'); - - OS ca(a); - OS cb(b); - OS cc(c); - OS cd(d); - OS ce(e); - - OS oa; - OS ob = nullopt; - OS oc = string(n, 'c'); - string sod(n, 'd'); - OS od = sod; - OS oe(in_place, n, 'e'); - OS of; - of.emplace(n, 'f'); - - OS ma(std::move(oa)); - OS mb(std::move(ob)); - OS mc(std::move(oc)); - OS md(std::move(od)); - OS me(std::move(oe)); - OS mf(std::move(of)); - - OS aa1; - OS ab1 = nullopt; - OS ac1 = string(n, 'c'); - string sad1(n, 'd'); - OS ad1 = sad1; - OS ae1(in_place, n, 'e'); - OS af1; - af1.emplace(n, 'f'); - - OS aa2; - OS ab2 = nullopt; - OS ac2 = string(n, 'c'); - string sad2(n, 'd'); - OS ad2 = sad2; - OS ae2(in_place, n, 'e'); - OS af2; - af2.emplace(n, 'f'); - - aa1 = af2; - ab1 = ae2; - ac1 = ad2; - ad1 = ac2; - ae1 = ab2; - af1 = aa2; - - OS aa3; - OS ab3 = nullopt; - OS ac3 = string(n, 'c'); - string sad3(n, 'd'); - OS ad3 = sad3; - OS ae3(in_place, n, 'e'); - OS af3; - af3.emplace(n, 'f'); - - aa3 = nullopt; - ab3 = nullopt; - ac3 = nullopt; - ad3 = nullopt; - ae3 = nullopt; - af3 = nullopt; - - OS aa4; - OS ab4 = nullopt; - OS ac4 = string(n, 'c'); - string sad4(n, 'd'); - OS ad4 = sad4; - OS ae4(in_place, n, 'e'); - OS af4; - af4.emplace(n, 
'f'); - - aa4 = OS(in_place, n, 'a'); - ab4 = OS(in_place, n, 'b'); - ac4 = OS(in_place, n, 'c'); - ad4 = OS(in_place, n, 'd'); - ae4 = OS(in_place, n, 'e'); - af4 = OS(in_place, n, 'f'); - - OS aa5; - OS ab5 = nullopt; - OS ac5 = string(n, 'c'); - string sad5(n, 'd'); - OS ad5 = sad5; - OS ae5(in_place, n, 'e'); - OS af5; - af5.emplace(n, 'f'); - - string saa5(n, 'a'); - string sab5(n, 'a'); - string sac5(n, 'a'); - string sad52(n, 'a'); - string sae5(n, 'a'); - string saf5(n, 'a'); - - aa5 = saa5; - ab5 = sab5; - ac5 = sac5; - ad5 = sad52; - ae5 = sae5; - af5 = saf5; - - OS aa6; - OS ab6 = nullopt; - OS ac6 = string(n, 'c'); - string sad6(n, 'd'); - OS ad6 = sad6; - OS ae6(in_place, n, 'e'); - OS af6; - af6.emplace(n, 'f'); - - aa6 = string(n, 'a'); - ab6 = string(n, 'b'); - ac6 = string(n, 'c'); - ad6 = string(n, 'd'); - ae6 = string(n, 'e'); - af6 = string(n, 'f'); - - OS aa7; - OS ab7 = nullopt; - OS ac7 = string(n, 'c'); - string sad7(n, 'd'); - OS ad7 = sad7; - OS ae7(in_place, n, 'e'); - OS af7; - af7.emplace(n, 'f'); - - aa7.emplace(n, 'A'); - ab7.emplace(n, 'B'); - ac7.emplace(n, 'C'); - ad7.emplace(n, 'D'); - ae7.emplace(n, 'E'); - af7.emplace(n, 'F'); -} - -TEST(optionalTest, MoveAssignRegression) { - StructorListener listener; - Listenable::listener = &listener; - - { - optional<Listenable> a; - Listenable b; - a = std::move(b); - } - - EXPECT_EQ(1, listener.construct0); - EXPECT_EQ(1, listener.move); - EXPECT_EQ(2, listener.destruct); -} - -TEST(optionalTest, ValueType) { - EXPECT_TRUE((std::is_same<optional<int>::value_type, int>::value)); - EXPECT_TRUE((std::is_same<optional<string>::value_type, string>::value)); - EXPECT_FALSE((std::is_same<optional<int>::value_type, nullopt_t>::value)); -} - -TEST(optionalTest, Hash) { - std::hash<optional<int>> hash; - std::set<size_t> hashcodes; - hashcodes.insert(hash(nullopt)); - for (int i = 0; i < 100; ++i) { - hashcodes.insert(hash(i)); - } - EXPECT_GT(hashcodes.size(), 90); -} - -struct MoveMeNoThrow { - MoveMeNoThrow() : x(0) {} - MoveMeNoThrow(const MoveMeNoThrow& other) : x(other.x) { - LOG(FATAL) << "Should not be called."; - } - MoveMeNoThrow(MoveMeNoThrow&& other) noexcept : x(other.x) {} - int x; -}; - -struct MoveMeThrow { - MoveMeThrow() : x(0) {} - MoveMeThrow(const MoveMeThrow& other) : x(other.x) {} - MoveMeThrow(MoveMeThrow&& other) : x(other.x) {} - int x; -}; - -TEST(optionalTest, NoExcept) { - static_assert( - std::is_nothrow_move_constructible<optional<MoveMeNoThrow>>::value, ""); - static_assert( - !std::is_nothrow_move_constructible<optional<MoveMeThrow>>::value, ""); - std::vector<optional<MoveMeNoThrow>> v; - v.reserve(10); - for (int i = 0; i < 10; ++i) v.emplace_back(); -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h index e2927689d2..117b6a0bb8 100644 --- a/tensorflow/core/lib/io/block_builder.h +++ b/tensorflow/core/lib/io/block_builder.h @@ -20,6 +20,7 @@ limitations under the License. #include <stdint.h> #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace table { diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h index e3649fd0c9..38fb0c5d86 100644 --- a/tensorflow/core/lib/io/path.h +++ b/tensorflow/core/lib/io/path.h @@ -17,6 +17,7 @@ limitations under the License. 
#define TENSORFLOW_CORE_LIB_IO_PATH_H_ #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace io { diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc index c24628be57..f93ebea771 100644 --- a/tensorflow/core/lib/io/record_reader.cc +++ b/tensorflow/core/lib/io/record_reader.cc @@ -109,9 +109,6 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) { } Status RecordReader::ReadRecord(uint64* offset, string* record) { - static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); - static const size_t kFooterSize = sizeof(uint32); - // Position the input stream. int64 curr_pos = input_stream_->Tell(); int64 desired_pos = static_cast<int64>(*offset); diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h index c05f9e1b36..11af1366b0 100644 --- a/tensorflow/core/lib/io/record_reader.h +++ b/tensorflow/core/lib/io/record_reader.h @@ -58,6 +58,14 @@ class RecordReaderOptions { // Note: this class is not thread safe; external synchronization required. class RecordReader { public: + // Format of a single record: + // uint64 length + // uint32 masked crc of length + // byte data[length] + // uint32 masked crc of data + static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); + static const size_t kFooterSize = sizeof(uint32); + // Create a reader that will return log records from "*file". // "*file" must remain live while this Reader is in use. explicit RecordReader( diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc index 6e71d23e71..2c6db2487e 100644 --- a/tensorflow/core/lib/io/record_writer.cc +++ b/tensorflow/core/lib/io/record_writer.cc @@ -88,10 +88,6 @@ RecordWriter::~RecordWriter() { } } -static uint32 MaskedCrc(const char* data, size_t n) { - return crc32c::Mask(crc32c::Value(data, n)); -} - Status RecordWriter::WriteRecord(StringPiece data) { if (dest_ == nullptr) { return Status(::tensorflow::error::FAILED_PRECONDITION, @@ -102,13 +98,10 @@ Status RecordWriter::WriteRecord(StringPiece data) { // uint32 masked crc of length // byte data[length] // uint32 masked crc of data - char header[sizeof(uint64) + sizeof(uint32)]; - core::EncodeFixed64(header + 0, data.size()); - core::EncodeFixed32(header + sizeof(uint64), - MaskedCrc(header, sizeof(uint64))); - char footer[sizeof(uint32)]; - core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size())); - + char header[kHeaderSize]; + char footer[kFooterSize]; + PopulateHeader(header, data.data(), data.size()); + PopulateFooter(footer, data.data(), data.size()); TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header)))); TF_RETURN_IF_ERROR(dest_->Append(data)); return dest_->Append(StringPiece(footer, sizeof(footer))); diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h index 2f6afa5487..1212e1fafb 100644 --- a/tensorflow/core/lib/io/record_writer.h +++ b/tensorflow/core/lib/io/record_writer.h @@ -16,8 +16,10 @@ limitations under the License. 
#ifndef TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_ #define TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_ +#include "tensorflow/core/lib/core/coding.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/lib/hash/crc32c.h" #if !defined(IS_SLIM_BUILD) #include "tensorflow/core/lib/io/zlib_compression_options.h" #include "tensorflow/core/lib/io/zlib_outputbuffer.h" @@ -41,12 +43,20 @@ class RecordWriterOptions { // Options specific to zlib compression. #if !defined(IS_SLIM_BUILD) - ZlibCompressionOptions zlib_options; + tensorflow::io::ZlibCompressionOptions zlib_options; #endif // IS_SLIM_BUILD }; class RecordWriter { public: + // Format of a single record: + // uint64 length + // uint32 masked crc of length + // byte data[length] + // uint32 masked crc of data + static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32); + static const size_t kFooterSize = sizeof(uint32); + // Create a writer that will append data to "*dest". // "*dest" must be initially empty. // "*dest" must remain live while this Writer is in use. @@ -72,13 +82,35 @@ class RecordWriter { // are invalid. Status Close(); + // Utility method to populate TFRecord headers. Populates record-header in + // "header[0,kHeaderSize-1]". The record-header is based on data[0, n-1]. + inline static void PopulateHeader(char* header, const char* data, size_t n); + + // Utility method to populate TFRecord footers. Populates record-footer in + // "footer[0,kFooterSize-1]". The record-footer is based on data[0, n-1]. + inline static void PopulateFooter(char* footer, const char* data, size_t n); + private: WritableFile* dest_; RecordWriterOptions options_; + inline static uint32 MaskedCrc(const char* data, size_t n) { + return crc32c::Mask(crc32c::Value(data, n)); + } + TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter); }; +void RecordWriter::PopulateHeader(char* header, const char* data, size_t n) { + core::EncodeFixed64(header + 0, n); + core::EncodeFixed32(header + sizeof(uint64), + MaskedCrc(header, sizeof(uint64))); +} + +void RecordWriter::PopulateFooter(char* footer, const char* data, size_t n) { + core::EncodeFixed32(footer, MaskedCrc(data, n)); +} + } // namespace io } // namespace tensorflow diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc index da514bd21c..946d7188d3 100644 --- a/tensorflow/core/lib/io/recordio_test.cc +++ b/tensorflow/core/lib/io/recordio_test.cc @@ -58,7 +58,7 @@ class StringDest : public WritableFile { Status Close() override { return Status::OK(); } Status Flush() override { return Status::OK(); } Status Sync() override { return Status::OK(); } - Status Append(const StringPiece& slice) override { + Status Append(StringPiece slice) override { contents_->append(slice.data(), slice.size()); return Status::OK(); } diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc index 877ac40f1c..9cebbf40c6 100644 --- a/tensorflow/core/lib/io/table_test.cc +++ b/tensorflow/core/lib/io/table_test.cc @@ -98,7 +98,7 @@ class StringSink : public WritableFile { Status Flush() override { return Status::OK(); } Status Sync() override { return Status::OK(); } - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { contents_.append(data.data(), data.size()); return Status::OK(); } diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.cc b/tensorflow/core/lib/io/zlib_outputbuffer.cc index 84b47c171f..cba139e6ad 100644 --- 
a/tensorflow/core/lib/io/zlib_outputbuffer.cc +++ b/tensorflow/core/lib/io/zlib_outputbuffer.cc @@ -143,7 +143,7 @@ Status ZlibOutputBuffer::FlushOutputBufferToFile() { return Status::OK(); } -Status ZlibOutputBuffer::Append(const StringPiece& data) { +Status ZlibOutputBuffer::Append(StringPiece data) { // If there is sufficient free space in z_stream_input_ to fit data we // add it there and return. // If there isn't enough space we deflate the existing contents of diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h index 3d86d89a99..ccad2fda44 100644 --- a/tensorflow/core/lib/io/zlib_outputbuffer.h +++ b/tensorflow/core/lib/io/zlib_outputbuffer.h @@ -62,7 +62,7 @@ class ZlibOutputBuffer : public WritableFile { // to file when the buffer is full. // // To immediately write contents to file call `Flush()`. - Status Append(const StringPiece& data) override; + Status Append(StringPiece data) override; // Deflates any cached input and writes all output to file. Status Flush() override; diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h index c204d52cfe..9e4e1989dd 100644 --- a/tensorflow/core/lib/monitoring/collection_registry.h +++ b/tensorflow/core/lib/monitoring/collection_registry.h @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace monitoring { diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h index 756e5c2af8..bc4365e439 100644 --- a/tensorflow/core/lib/monitoring/metric_def.h +++ b/tensorflow/core/lib/monitoring/metric_def.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace monitoring { diff --git a/tensorflow/core/lib/png/png_io.h b/tensorflow/core/lib/png/png_io.h index bb5d20fb68..c876c5156a 100644 --- a/tensorflow/core/lib/png/png_io.h +++ b/tensorflow/core/lib/png/png_io.h @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/png.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace png { diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc index 36d939e061..c536b5688e 100644 --- a/tensorflow/core/lib/wav/wav_io.cc +++ b/tensorflow/core/lib/wav/wav_io.cc @@ -232,6 +232,11 @@ Status DecodeLin16WaveAsFloatVector(const string& wav_string, "Bad audio format for WAV: Expected 1 (PCM), but got", audio_format); } TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset)); + if (*channel_count < 1) { + return errors::InvalidArgument( + "Bad number of channels for WAV: Expected at least 1, but got ", + *channel_count); + } TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset)); uint32 bytes_per_second; TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset)); diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc index 01452b3e85..7c4184bff4 100644 --- a/tensorflow/core/ops/boosted_trees_ops.cc +++ b/tensorflow/core/ops/boosted_trees_ops.cc @@ -22,6 +22,10 @@ limitations under the License. 
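The wav_io.cc hunk above rejects WAV headers that declare zero channels. The read-then-validate pattern it extends looks roughly like the following standalone sketch (little-endian read; ReadChannelCount is an illustrative name, not the library API):

    #include <cstddef>
    #include <cstdint>
    #include <string>

    // Reads a little-endian uint16 channel count with bounds checking and
    // applies the new guard: at least one channel is required.
    bool ReadChannelCount(const std::string& wav, size_t* offset,
                          uint16_t* channel_count) {
      if (*offset + sizeof(uint16_t) > wav.size()) return false;  // truncated header
      const uint8_t lo = static_cast<uint8_t>(wav[*offset]);
      const uint8_t hi = static_cast<uint8_t>(wav[*offset + 1]);
      *channel_count = static_cast<uint16_t>(lo | (hi << 8));
      *offset += sizeof(uint16_t);
      return *channel_count >= 1;  // mirrors the InvalidArgument check added above
    }
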
namespace tensorflow { +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource); REGISTER_OP("IsBoostedTreesEnsembleInitialized") @@ -354,4 +358,125 @@ REGISTER_OP("BoostedTreesCenterBias") return Status::OK(); }); +REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource); + +REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized") + .Input("quantile_stream_resource_handle: resource") + .Output("is_initialized: bool") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + c->set_output(0, c->Scalar()); + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesCreateQuantileStreamResource") + .Attr("max_elements: int = 1099511627776") // 1 << 40 + .Input("quantile_stream_resource_handle: resource") + .Input("epsilon: float") + .Input("num_streams: int64") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input)); + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesMakeQuantileSummaries") + .Attr("num_features: int >= 0") + .Input("float_values: num_features * float") + .Input("example_weights: float") + .Input("epsilon: float") + .Output("summaries: num_features * float") + .SetShapeFn([](InferenceContext* c) { + int num_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features)); + ShapeHandle example_weights_shape; + TF_RETURN_IF_ERROR( + c->WithRank(c->input(num_features), 1, &example_weights_shape)); + for (int i = 0; i < num_features; ++i) { + ShapeHandle feature_shape; + DimensionHandle unused_dim; + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape)); + TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0), + c->Dim(example_weights_shape, 0), + &unused_dim)); + // the columns are value, weight, min_rank, max_rank. + c->set_output(i, c->MakeShape({c->UnknownDim(), 4})); + } + // epsilon must be a scalar. + ShapeHandle unused_input; + TF_RETURN_IF_ERROR( + c->WithRank(c->input(num_features + 1), 0, &unused_input)); + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries") + .Attr("num_features: int >= 0") + .Input("quantile_stream_resource_handle: resource") + .Input("summaries: num_features * float") + .SetShapeFn([](InferenceContext* c) { + int num_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features)); + // resource handle must be a scalar. + shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + // each summary must be rank 2. + for (int i = 1; i < num_features + 1; i++) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input)); + } + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesQuantileStreamResourceFlush") + .Attr("generate_quantiles: bool = False") + .Input("quantile_stream_resource_handle: resource") + .Input("num_buckets: int64") + .SetShapeFn([](InferenceContext* c) { + // All the inputs are scalars. 
+ shape_inference::ShapeHandle unused_input; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input)); + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries") + .Attr("num_features: int >= 0") + .Input("quantile_stream_resource_handle: resource") + .Output("bucket_boundaries: num_features * float") + .SetShapeFn([](InferenceContext* c) { + int num_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features)); + shape_inference::ShapeHandle unused_input; + // resource handle must be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input)); + for (int i = 0; i < num_features; i++) { + c->set_output(i, c->Vector(c->UnknownDim())); + } + return Status::OK(); + }); + +REGISTER_OP("BoostedTreesBucketize") + .Attr("num_features: int >= 0") + .Input("float_values: num_features * float") + .Input("bucket_boundaries: num_features * float") + .Output("buckets: num_features * int32") + .SetShapeFn([](InferenceContext* c) { + int num_features; + TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features)); + ShapeHandle feature_shape; + DimensionHandle unused_dim; + for (int i = 0; i < num_features; i++) { + TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape)); + TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0), + c->Dim(c->input(0), 0), &unused_dim)); + } + // Bucketized result should have same dimension as input. + for (int i = 0; i < num_features; i++) { + c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 1})); + } + return Status::OK(); + }); + } // namespace tensorflow diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 9e67662fa6..e59958749c 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -11360,6 +11360,29 @@ op { is_commutative: true } op { + name: "BoostedTreesBucketize" + input_arg { + name: "float_values" + type: DT_FLOAT + number_attr: "num_features" + } + input_arg { + name: "bucket_boundaries" + type: DT_FLOAT + number_attr: "num_features" + } + output_arg { + name: "buckets" + type: DT_INT32 + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } +} +op { name: "BoostedTreesCalculateBestGainsPerFeature" input_arg { name: "node_id_range" @@ -11469,6 +11492,29 @@ op { is_stateful: true } op { + name: "BoostedTreesCreateQuantileStreamResource" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "epsilon" + type: DT_FLOAT + } + input_arg { + name: "num_streams" + type: DT_INT64 + } + attr { + name: "max_elements" + type: "int" + default_value { + i: 1099511627776 + } + } + is_stateful: true +} +op { name: "BoostedTreesDeserializeEnsemble" input_arg { name: "tree_ensemble_handle" @@ -11562,6 +11608,32 @@ op { is_stateful: true } op { + name: "BoostedTreesMakeQuantileSummaries" + input_arg { + name: "float_values" + type: DT_FLOAT + number_attr: "num_features" + } + input_arg { + name: "example_weights" + type: DT_FLOAT + } + input_arg { + name: "epsilon" + type: DT_FLOAT + } + output_arg { + name: "summaries" + type: DT_FLOAT + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } +} +op { name: "BoostedTreesMakeStatsSummary" input_arg { name: "node_ids" @@ -11631,6 +11703,83 @@ op { is_stateful: true } op { + name: 
"BoostedTreesQuantileStreamResourceAddSummaries" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "summaries" + type: DT_FLOAT + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } + is_stateful: true +} +op { + name: "BoostedTreesQuantileStreamResourceFlush" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "num_buckets" + type: DT_INT64 + } + attr { + name: "generate_quantiles" + type: "bool" + default_value { + b: false + } + } + is_stateful: true +} +op { + name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + output_arg { + name: "bucket_boundaries" + type: DT_FLOAT + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } + is_stateful: true +} +op { + name: "BoostedTreesQuantileStreamResourceHandleOp" + output_arg { + name: "resource" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} +op { name: "BoostedTreesSerializeEnsemble" input_arg { name: "tree_ensemble_handle" @@ -13070,6 +13219,71 @@ op { is_stateful: true } op { + name: "ConditionalAccumulator" + output_arg { + name: "handle" + type: DT_STRING + is_ref: true + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "reduction_type" + type: "string" + default_value { + s: "MEAN" + } + allowed_values { + list { + s: "MEAN" + s: "SUM" + } + } + } + is_stateful: true +} +op { name: "Conj" input_arg { name: "input" @@ -27127,6 +27341,18 @@ op { is_stateful: true } op { + name: "IsBoostedTreesQuantileStreamResourceInitialized" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + output_arg { + name: "is_initialized" + type: DT_BOOL + } + is_stateful: true +} +op { name: "IsFinite" input_arg { name: "x" @@ -29381,6 +29607,49 @@ op { } } op { + name: "MapDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } +} +op { name: "MapDefun" input_arg { name: "arguments" @@ -34842,6 +35111,29 @@ op { } } op { + name: "ModelDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + 
has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "Mul" input_arg { name: "x" @@ -35682,6 +35974,42 @@ op { } } op { + name: "NonMaxSuppressionV2" + input_arg { + name: "boxes" + type_attr: "T" + } + input_arg { + name: "scores" + type_attr: "T" + } + input_arg { + name: "max_output_size" + type: DT_INT32 + } + input_arg { + name: "iou_threshold" + type: DT_FLOAT + } + output_arg { + name: "selected_indices" + type: DT_INT32 + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } +} +op { name: "NonMaxSuppressionV3" input_arg { name: "boxes" @@ -35709,6 +36037,46 @@ op { } } op { + name: "NonMaxSuppressionV3" + input_arg { + name: "boxes" + type_attr: "T" + } + input_arg { + name: "scores" + type_attr: "T" + } + input_arg { + name: "max_output_size" + type: DT_INT32 + } + input_arg { + name: "iou_threshold" + type: DT_FLOAT + } + input_arg { + name: "score_threshold" + type: DT_FLOAT + } + output_arg { + name: "selected_indices" + type: DT_INT32 + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } +} +op { name: "NonMaxSuppressionV4" input_arg { name: "boxes" @@ -35747,6 +36115,57 @@ op { } } op { + name: "NonMaxSuppressionV4" + input_arg { + name: "boxes" + type_attr: "T" + } + input_arg { + name: "scores" + type_attr: "T" + } + input_arg { + name: "max_output_size" + type: DT_INT32 + } + input_arg { + name: "iou_threshold" + type: DT_FLOAT + } + input_arg { + name: "score_threshold" + type: DT_FLOAT + } + output_arg { + name: "selected_indices" + type: DT_INT32 + } + output_arg { + name: "valid_outputs" + type: DT_INT32 + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } + attr { + name: "pad_to_max_output_size" + type: "bool" + default_value { + b: false + } + } +} +op { name: "NonMaxSuppressionWithOverlaps" input_arg { name: "overlaps" @@ -37037,6 +37456,54 @@ op { } } op { + name: "ParallelInterleaveDatasetV2" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "cycle_length" + type: DT_INT64 + } + input_arg { + name: "block_length" + type: DT_INT64 + } + input_arg { + name: "num_parallel_calls" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "ParallelMapDataset" input_arg { name: "input_dataset" @@ -37118,6 +37585,53 @@ op { } } op { + name: "ParallelMapDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "num_parallel_calls" + type: DT_INT32 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" 
+ has_minimum: true + minimum: 1 + } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } +} +op { name: "ParameterizedTruncatedNormal" input_arg { name: "shape" @@ -56665,6 +57179,125 @@ op { } } op { + name: "SdcaOptimizer" + input_arg { + name: "sparse_example_indices" + type: DT_INT64 + number_attr: "num_sparse_features" + } + input_arg { + name: "sparse_feature_indices" + type: DT_INT64 + number_attr: "num_sparse_features" + } + input_arg { + name: "sparse_feature_values" + type: DT_FLOAT + number_attr: "num_sparse_features_with_values" + } + input_arg { + name: "dense_features" + type: DT_FLOAT + number_attr: "num_dense_features" + } + input_arg { + name: "example_weights" + type: DT_FLOAT + } + input_arg { + name: "example_labels" + type: DT_FLOAT + } + input_arg { + name: "sparse_indices" + type: DT_INT64 + number_attr: "num_sparse_features" + } + input_arg { + name: "sparse_weights" + type: DT_FLOAT + number_attr: "num_sparse_features" + } + input_arg { + name: "dense_weights" + type: DT_FLOAT + number_attr: "num_dense_features" + } + input_arg { + name: "example_state_data" + type: DT_FLOAT + } + output_arg { + name: "out_example_state_data" + type: DT_FLOAT + } + output_arg { + name: "out_delta_sparse_weights" + type: DT_FLOAT + number_attr: "num_sparse_features" + } + output_arg { + name: "out_delta_dense_weights" + type: DT_FLOAT + number_attr: "num_dense_features" + } + attr { + name: "loss_type" + type: "string" + allowed_values { + list { + s: "logistic_loss" + s: "squared_loss" + s: "hinge_loss" + s: "smooth_hinge_loss" + s: "poisson_loss" + } + } + } + attr { + name: "adaptative" + type: "bool" + default_value { + b: false + } + } + attr { + name: "num_sparse_features" + type: "int" + has_minimum: true + } + attr { + name: "num_sparse_features_with_values" + type: "int" + has_minimum: true + } + attr { + name: "num_dense_features" + type: "int" + has_minimum: true + } + attr { + name: "l1" + type: "float" + } + attr { + name: "l2" + type: "float" + } + attr { + name: "num_loss_partitions" + type: "int" + has_minimum: true + minimum: 1 + } + attr { + name: "num_inner_iterations" + type: "int" + has_minimum: true + minimum: 1 + } +} +op { name: "SdcaShrinkL1" input_arg { name: "weights" @@ -64381,6 +65014,71 @@ op { is_stateful: true } op { + name: "SparseConditionalAccumulator" + output_arg { + name: "handle" + type: DT_STRING + is_ref: true + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "shape" + type: "shape" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + attr { + name: "reduction_type" + type: "string" + default_value { + s: "MEAN" + } + allowed_values { + list { + s: "MEAN" + s: "SUM" + } + } + } + is_stateful: true +} +op { name: "SparseCross" input_arg { name: "indices" @@ -69174,6 +69872,21 @@ op { } } op { + name: "StaticRegexFullMatch" + input_arg { + name: "input" + type: DT_STRING + } + output_arg { + name: "output" + type: DT_BOOL + } + attr { + name: "pattern" + type: "string" + } +} +op { name: "StaticRegexReplace" input_arg { name: "input" @@ 
-74889,9 +75602,21 @@ op { type: DT_VARIANT } input_arg { - name: "window_size" + name: "size" type: DT_INT64 } + input_arg { + name: "shift" + type: DT_INT64 + } + input_arg { + name: "stride" + type: DT_INT64 + } + input_arg { + name: "drop_remainder" + type: DT_BOOL + } output_arg { name: "handle" type: DT_VARIANT diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc index eed0bce174..ffab8ad661 100644 --- a/tensorflow/core/ops/data_flow_ops.cc +++ b/tensorflow/core/ops/data_flow_ops.cc @@ -419,6 +419,7 @@ REGISTER_OP("ConditionalAccumulator") .Attr("shape: shape") .Attr("container: string = ''") .Attr("shared_name: string = ''") + .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->Vector(2)); @@ -456,6 +457,7 @@ REGISTER_OP("SparseConditionalAccumulator") .Attr("shape: shape") .Attr("container: string = ''") .Attr("shared_name: string = ''") + .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ") .SetIsStateful() .SetShapeFn([](InferenceContext* c) { c->set_output(0, c->Vector(2)); diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index f03639e833..4d3f272c1b 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -198,6 +198,7 @@ REGISTER_OP("MapDataset") .Attr("Targuments: list(type) >= 0") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("use_inter_op_parallelism: bool = true") .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("ParallelMapDataset") @@ -209,6 +210,7 @@ REGISTER_OP("ParallelMapDataset") .Attr("Targuments: list(type) >= 0") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") + .Attr("use_inter_op_parallelism: bool = true") .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("MapAndBatchDataset") @@ -325,6 +327,19 @@ REGISTER_OP("ParallelInterleaveDataset") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("ParallelInterleaveDatasetV2") + .Input("input_dataset: variant") + .Input("other_arguments: Targuments") + .Input("cycle_length: int64") + .Input("block_length: int64") + .Input("num_parallel_calls: int64") + .Output("handle: variant") + .Attr("f: func") + .Attr("Targuments: list(type) >= 0") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("GroupByReducerDataset") .Input("input_dataset: variant") .Input("key_func_other_arguments: Tkey_func_other_arguments") @@ -381,14 +396,20 @@ REGISTER_OP("FilterByLastComponentDataset") REGISTER_OP("WindowDataset") .Input("input_dataset: variant") - .Input("window_size: int64") + .Input("size: int64") + .Input("shift: int64") + .Input("stride: int64") + .Input("drop_remainder: bool") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn([](shape_inference::InferenceContext* c) { shape_inference::ShapeHandle unused; - // batch_size should be a scalar. + // size, shift, stride, and drop_remainder should be scalars. 
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); return shape_inference::ScalarShape(c); }); @@ -858,6 +879,13 @@ REGISTER_OP("IteratorGetNextAsOptional") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(shape_inference::ScalarShape); +REGISTER_OP("ModelDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + REGISTER_OP("MapDefun") .Input("arguments: Targuments") .Output("output: output_types") @@ -866,7 +894,7 @@ REGISTER_OP("MapDefun") .Attr("output_shapes: list(shape) >= 1") .Attr("f: func") .SetShapeFn([](shape_inference::InferenceContext* c) { - std::vector<TensorShape> output_shapes; + std::vector<PartialTensorShape> output_shapes; TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); if (output_shapes.size() != c->num_outputs()) { return errors::InvalidArgument( @@ -876,6 +904,10 @@ REGISTER_OP("MapDefun") int64 dim_zero = -1; for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) { + if (c->Rank(c->input(i)) == 0) { + return errors::InvalidArgument( + "Inputs must have rank at least 1. Input ", i, " has rank of 0"); + } auto dim_handle = c->Dim(c->input(i), 0); if (c->ValueKnown(dim_handle)) { if (dim_zero == -1) { diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc index 11ca0bd259..5427275284 100644 --- a/tensorflow/core/ops/image_ops.cc +++ b/tensorflow/core/ops/image_ops.cc @@ -683,11 +683,12 @@ REGISTER_OP("NonMaxSuppression") }); REGISTER_OP("NonMaxSuppressionV2") - .Input("boxes: float") - .Input("scores: float") + .Input("boxes: T") + .Input("scores: T") .Input("max_output_size: int32") .Input("iou_threshold: float") .Output("selected_indices: int32") + .Attr("T: {half, float} = DT_FLOAT") .SetShapeFn([](InferenceContext* c) { // Get inputs and validate ranks. 
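WindowDataset's expanded signature (size, shift, stride, drop_remainder) matches tf.data windowing. Assuming those documented semantics, the element indices each window draws from can be sketched as follows (illustrative helper, not part of the patch):

    #include <cstdint>
    #include <vector>

    // Window k starts at k * shift and takes elements start, start + stride, ...
    // up to `size` of them; drop_remainder discards trailing partial windows.
    std::vector<std::vector<int64_t>> WindowIndices(int64_t n, int64_t size,
                                                    int64_t shift, int64_t stride,
                                                    bool drop_remainder) {
      std::vector<std::vector<int64_t>> windows;
      for (int64_t start = 0; start < n; start += shift) {
        std::vector<int64_t> w;
        for (int64_t j = 0; j < size && start + j * stride < n; ++j) {
          w.push_back(start + j * stride);
        }
        if (!drop_remainder || static_cast<int64_t>(w.size()) == size) {
          windows.push_back(std::move(w));
        }
      }
      return windows;
    }
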
ShapeHandle boxes; @@ -711,22 +712,24 @@ REGISTER_OP("NonMaxSuppressionV2") }); REGISTER_OP("NonMaxSuppressionV3") - .Input("boxes: float") - .Input("scores: float") + .Input("boxes: T") + .Input("scores: T") .Input("max_output_size: int32") .Input("iou_threshold: float") .Input("score_threshold: float") .Output("selected_indices: int32") + .Attr("T: {half, float} = DT_FLOAT") .SetShapeFn(NMSShapeFn); REGISTER_OP("NonMaxSuppressionV4") - .Input("boxes: float") - .Input("scores: float") + .Input("boxes: T") + .Input("scores: T") .Input("max_output_size: int32") .Input("iou_threshold: float") .Input("score_threshold: float") .Output("selected_indices: int32") .Output("valid_outputs: int32") + .Attr("T: {half, float} = DT_FLOAT") .Attr("pad_to_max_output_size: bool = false") .SetShapeFn([](InferenceContext* c) { TF_RETURN_IF_ERROR(NMSShapeFn(c)); diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index c0376b5721..4ece1c8953 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -4272,6 +4272,29 @@ op { is_commutative: true } op { + name: "BoostedTreesBucketize" + input_arg { + name: "float_values" + type: DT_FLOAT + number_attr: "num_features" + } + input_arg { + name: "bucket_boundaries" + type: DT_FLOAT + number_attr: "num_features" + } + output_arg { + name: "buckets" + type: DT_INT32 + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } +} +op { name: "BoostedTreesCalculateBestGainsPerFeature" input_arg { name: "node_id_range" @@ -4381,6 +4404,29 @@ op { is_stateful: true } op { + name: "BoostedTreesCreateQuantileStreamResource" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "epsilon" + type: DT_FLOAT + } + input_arg { + name: "num_streams" + type: DT_INT64 + } + attr { + name: "max_elements" + type: "int" + default_value { + i: 1099511627776 + } + } + is_stateful: true +} +op { name: "BoostedTreesDeserializeEnsemble" input_arg { name: "tree_ensemble_handle" @@ -4474,6 +4520,32 @@ op { is_stateful: true } op { + name: "BoostedTreesMakeQuantileSummaries" + input_arg { + name: "float_values" + type: DT_FLOAT + number_attr: "num_features" + } + input_arg { + name: "example_weights" + type: DT_FLOAT + } + input_arg { + name: "epsilon" + type: DT_FLOAT + } + output_arg { + name: "summaries" + type: DT_FLOAT + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } +} +op { name: "BoostedTreesMakeStatsSummary" input_arg { name: "node_ids" @@ -4543,6 +4615,83 @@ op { is_stateful: true } op { + name: "BoostedTreesQuantileStreamResourceAddSummaries" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "summaries" + type: DT_FLOAT + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + has_minimum: true + } + is_stateful: true +} +op { + name: "BoostedTreesQuantileStreamResourceFlush" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + input_arg { + name: "num_buckets" + type: DT_INT64 + } + attr { + name: "generate_quantiles" + type: "bool" + default_value { + b: false + } + } + is_stateful: true +} +op { + name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + output_arg { + name: "bucket_boundaries" + type: DT_FLOAT + number_attr: "num_features" + } + attr { + name: "num_features" + type: "int" + 
has_minimum: true + } + is_stateful: true +} +op { + name: "BoostedTreesQuantileStreamResourceHandleOp" + output_arg { + name: "resource" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} +op { name: "BoostedTreesSerializeEnsemble" input_arg { name: "tree_ensemble_handle" @@ -5592,6 +5741,19 @@ op { s: "" } } + attr { + name: "reduction_type" + type: "string" + default_value { + s: "MEAN" + } + allowed_values { + list { + s: "MEAN" + s: "SUM" + } + } + } is_stateful: true } op { @@ -13149,6 +13311,18 @@ op { is_stateful: true } op { + name: "IsBoostedTreesQuantileStreamResourceInitialized" + input_arg { + name: "quantile_stream_resource_handle" + type: DT_RESOURCE + } + output_arg { + name: "is_initialized" + type: DT_BOOL + } + is_stateful: true +} +op { name: "IsFinite" input_arg { name: "x" @@ -14542,6 +14716,13 @@ op { has_minimum: true minimum: 1 } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } } op { name: "MapDefun" @@ -16540,6 +16721,29 @@ op { } } op { + name: "ModelDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "Mul" input_arg { name: "x" @@ -17078,11 +17282,11 @@ op { name: "NonMaxSuppressionV2" input_arg { name: "boxes" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "scores" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "max_output_size" @@ -17096,16 +17300,29 @@ op { name: "selected_indices" type: DT_INT32 } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } } op { name: "NonMaxSuppressionV3" input_arg { name: "boxes" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "scores" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "max_output_size" @@ -17123,16 +17340,29 @@ op { name: "selected_indices" type: DT_INT32 } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } } op { name: "NonMaxSuppressionV4" input_arg { name: "boxes" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "scores" - type: DT_FLOAT + type_attr: "T" } input_arg { name: "max_output_size" @@ -17155,6 +17385,19 @@ op { type: DT_INT32 } attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } + attr { name: "pad_to_max_output_size" type: "bool" default_value { @@ -18192,6 +18435,54 @@ op { } } op { + name: "ParallelInterleaveDatasetV2" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "cycle_length" + type: DT_INT64 + } + input_arg { + name: "block_length" + type: DT_INT64 + } + input_arg { + name: "num_parallel_calls" + type: DT_INT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "f" + type: "func" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + 
type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "ParallelMapDataset" input_arg { name: "input_dataset" @@ -18230,6 +18521,13 @@ op { has_minimum: true minimum: 1 } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } } op { name: "ParameterizedTruncatedNormal" @@ -26977,6 +27275,7 @@ op { s: "squared_loss" s: "hinge_loss" s: "smooth_hinge_loss" + s: "poisson_loss" } } } @@ -29609,6 +29908,19 @@ op { s: "" } } + attr { + name: "reduction_type" + type: "string" + default_value { + s: "MEAN" + } + allowed_values { + list { + s: "MEAN" + s: "SUM" + } + } + } is_stateful: true } op { @@ -32107,6 +32419,21 @@ op { } } op { + name: "StaticRegexFullMatch" + input_arg { + name: "input" + type: DT_STRING + } + output_arg { + name: "output" + type: DT_BOOL + } + attr { + name: "pattern" + type: "string" + } +} +op { name: "StaticRegexReplace" input_arg { name: "input" @@ -35872,9 +36199,21 @@ op { type: DT_VARIANT } input_arg { - name: "window_size" + name: "size" type: DT_INT64 } + input_arg { + name: "shift" + type: DT_INT64 + } + input_arg { + name: "stride" + type: DT_INT64 + } + input_arg { + name: "drop_remainder" + type: DT_BOOL + } output_arg { name: "handle" type: DT_VARIANT diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc index 79ca96d249..eff453241d 100644 --- a/tensorflow/core/ops/parsing_ops.cc +++ b/tensorflow/core/ops/parsing_ops.cc @@ -343,10 +343,11 @@ REGISTER_OP("DecodeCSV") // Validate the record_defaults inputs. for (int i = 1; i < c->num_inputs(); ++i) { ShapeHandle v; - TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &v)); - if (c->Value(c->Dim(v, 0)) > 1) { + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); + if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { return errors::InvalidArgument( - "Shape of a default must be a length-0 or length-1 vector"); + "Shape of a default must be a length-0 or length-1 vector, or a " + "scalar."); } } diff --git a/tensorflow/core/ops/parsing_ops_test.cc b/tensorflow/core/ops/parsing_ops_test.cc index c65e66d1a8..ba594e400c 100644 --- a/tensorflow/core/ops/parsing_ops_test.cc +++ b/tensorflow/core/ops/parsing_ops_test.cc @@ -52,9 +52,12 @@ TEST(ParsingOpsTest, DecodeCSV_ShapeFn) { INFER_OK(op, "[1,2,?,4];?;?", "in0;in0"); INFER_OK(op, "[1,2,?,4];[?];[?]", "in0;in0"); + // Scalar defaults are ok + INFER_OK(op, "?;?;[]", "in0;in0"); + // Check errors in the record_defaults inputs. 
- INFER_ERROR("must be rank 1", op, "?;?;[]"); - INFER_ERROR("must be rank 1", op, "?;[];?"); + INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;?;[1,2]"); + INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;[3,4];?"); INFER_ERROR("Shape of a default must be", op, "?;?;[2]"); INFER_ERROR("Shape of a default must be", op, "?;[2];?"); } diff --git a/tensorflow/core/ops/sdca_ops.cc b/tensorflow/core/ops/sdca_ops.cc index 4025070adb..fdf53a55dd 100644 --- a/tensorflow/core/ops/sdca_ops.cc +++ b/tensorflow/core/ops/sdca_ops.cc @@ -41,7 +41,7 @@ static Status ApplySdcaOptimizerShapeFn(InferenceContext* c) { REGISTER_OP("SdcaOptimizer") .Attr( "loss_type: {'logistic_loss', 'squared_loss', 'hinge_loss'," - "'smooth_hinge_loss'}") + "'smooth_hinge_loss', 'poisson_loss'}") .Attr("adaptative : bool=false") .Attr("num_sparse_features: int >= 0") .Attr("num_sparse_features_with_values: int >= 0") diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index 7aa1e71809..ef8b15dc8a 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -56,6 +56,12 @@ REGISTER_OP("RegexFullMatch") return Status::OK(); }); +REGISTER_OP("StaticRegexFullMatch") + .Input("input: string") + .Attr("pattern: string") + .Output("output: bool") + .SetShapeFn(shape_inference::UnchangedShape); + REGISTER_OP("StringToHashBucketFast") .Input("input: string") .Output("output: int64") diff --git a/tensorflow/core/platform/abi.cc b/tensorflow/core/platform/abi.cc index e597a490d6..d7a13a3528 100644 --- a/tensorflow/core/platform/abi.cc +++ b/tensorflow/core/platform/abi.cc @@ -37,13 +37,13 @@ extern "C" char* __unDName(char* output_string, const char* name, namespace tensorflow { namespace port { -std::string MaybeAbiDemangle(const char* name) { +string MaybeAbiDemangle(const char* name) { #if defined(_MSC_VER) std::unique_ptr<char> demangled{__unDName(nullptr, name, 0, std::malloc, std::free, static_cast<unsigned short>(0))}; - return std::string(demangled.get() != nullptr ? demangled.get() : name); + return string(demangled.get() != nullptr ? demangled.get() : name); #else int status = 0; std::unique_ptr<char, void (*)(void*)> res{ diff --git a/tensorflow/core/platform/abi.h b/tensorflow/core/platform/abi.h index 591e83b0c4..d1498a6a64 100644 --- a/tensorflow/core/platform/abi.h +++ b/tensorflow/core/platform/abi.h @@ -17,11 +17,12 @@ limitations under the License. 
#define TENSORFLOW_CORE_PLATFORM_ABI_H_ #include <string> +#include "tensorflow/core/platform/types.h" namespace tensorflow { namespace port { -std::string MaybeAbiDemangle(const char* name); +string MaybeAbiDemangle(const char* name); } // namespace port } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc index a1be4aacce..5e1eabee5b 100644 --- a/tensorflow/core/platform/cloud/curl_http_request.cc +++ b/tensorflow/core/platform/cloud/curl_http_request.cc @@ -394,9 +394,9 @@ size_t CurlHttpRequest::HeaderCallback(const void* ptr, size_t size, .StopCapture() .OneLiteral(": ") .GetResult(&value, &name)) { - string str_value = std::string(value); + string str_value(value); str_util::StripTrailingWhitespace(&str_value); - that->response_headers_[std::string(name)] = str_value; + that->response_headers_[string(name)] = str_value; } return size * nmemb; } diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 9d33787bd5..83228fab6f 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -179,13 +179,13 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket, return errors::InvalidArgument("GCS path doesn't start with 'gs://': ", fname); } - *bucket = std::string(bucketp); + *bucket = string(bucketp); if (bucket->empty() || *bucket == ".") { return errors::InvalidArgument("GCS path doesn't contain a bucket name: ", fname); } str_util::ConsumePrefix(&objectp, "/"); - *object = std::string(objectp); + *object = string(objectp); if (!empty_object_ok && object->empty()) { return errors::InvalidArgument("GCS path doesn't contain an object name: ", fname); @@ -224,7 +224,7 @@ std::set<string> AddAllSubpaths(const std::vector<string>& paths) { for (const string& path : paths) { StringPiece subpath = io::Dirname(path); while (!subpath.empty()) { - result.emplace(std::string(subpath)); + result.emplace(string(subpath)); subpath = io::Dirname(subpath); } } @@ -371,7 +371,7 @@ class GcsWritableFile : public WritableFile { ~GcsWritableFile() override { Close().IgnoreError(); } - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { TF_RETURN_IF_ERROR(CheckWritable()); sync_needed_ = true; outfile_ << data; @@ -723,7 +723,7 @@ GcsFileSystem::GcsFileSystem() { if (!header_name.empty() && !header_value.empty()) { additional_header_.reset(new std::pair<const string, const string>( - std::string(header_name), std::string(header_value))); + string(header_name), string(header_value))); VLOG(1) << "GCS additional header ENABLED. " << "Name: " << additional_header_->first << ", " @@ -1229,7 +1229,7 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern, // Find the fixed prefix by looking for the first wildcard. 
const string& fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\")); - const string& dir = std::string(io::Dirname(fixed_prefix)); + const string dir(io::Dirname(fixed_prefix)); if (dir.empty()) { return errors::InvalidArgument( "A GCS pattern doesn't have a bucket name: ", pattern); @@ -1326,7 +1326,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, " doesn't match the prefix ", object_prefix)); } if (!relative_path.empty() || include_self_directory_marker) { - result->emplace_back(std::string(relative_path)); + result->emplace_back(relative_path); } if (++retrieved_results >= max_results) { return Status::OK(); @@ -1354,7 +1354,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname, "Unexpected response: the returned folder name ", prefix_str, " doesn't match the prefix ", object_prefix); } - result->emplace_back(std::string(relative_path)); + result->emplace_back(relative_path); if (++retrieved_results >= max_results) { return Status::OK(); } diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc index ee6ba7b041..9b85cae9b9 100644 --- a/tensorflow/core/platform/cloud/oauth_client.cc +++ b/tensorflow/core/platform/cloud/oauth_client.cc @@ -216,7 +216,7 @@ Status OAuthClient::GetTokenFromServiceAccountJson( // Send the request to the Google OAuth 2.0 server to get the token. std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); std::vector<char> response_buffer; - request->SetUri(std::string(oauth_server_uri)); + request->SetUri(string(oauth_server_uri)); request->SetPostFromBuffer(request_body.c_str(), request_body.size()); request->SetResultBuffer(&response_buffer); TF_RETURN_IF_ERROR(request->Send()); @@ -248,7 +248,7 @@ Status OAuthClient::GetTokenFromRefreshTokenJson( std::unique_ptr<HttpRequest> request(http_request_factory_->Create()); std::vector<char> response_buffer; - request->SetUri(std::string(oauth_server_uri)); + request->SetUri(string(oauth_server_uri)); request->SetPostFromBuffer(request_body.c_str(), request_body.size()); request->SetResultBuffer(&response_buffer); TF_RETURN_IF_ERROR(request->Send()); diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc index 4ffa72288b..1cd0641cd3 100644 --- a/tensorflow/core/platform/cloud/oauth_client_test.cc +++ b/tensorflow/core/platform/cloud/oauth_client_test.cc @@ -126,9 +126,9 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) { EXPECT_EQ("urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer", grant_type); - int last_dot = std::string(assertion).find_last_of("."); - string header_dot_claim = std::string(assertion.substr(0, last_dot)); - string signature_encoded = std::string(assertion.substr(last_dot + 1)); + int last_dot = assertion.rfind('.'); + string header_dot_claim(assertion.substr(0, last_dot)); + string signature_encoded(assertion.substr(last_dot + 1)); // Check that 'signature' signs 'header_dot_claim'. 
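Note on the NonMaxSuppressionV2/V3/V4 changes earlier in this diff: they follow the standard TensorFlow recipe for widening an op from a hard-coded dtype to a constrained type attr. Because the attr carries a default of DT_FLOAT, GraphDefs serialized before this change (which never set "T") keep their float behavior, and only half is newly accepted. A minimal sketch of the recipe, assuming the usual op-registration headers; the op name "ExampleNonMaxSuppression" and the stand-in shape function are illustrative, not part of this commit:

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// Sketch only: "ExampleNonMaxSuppression" is an illustrative name.
// "T: {half, float} = DT_FLOAT" constrains the attr to half/float and
// defaults it to float, so older graphs that omit the attr still load.
REGISTER_OP("ExampleNonMaxSuppression")
    .Input("boxes: T")
    .Input("scores: T")
    .Input("max_output_size: int32")
    .Output("selected_indices: int32")
    .Attr("T: {half, float} = DT_FLOAT")
    // The real ops install NMSShapeFn; UnchangedShape is a stand-in here.
    .SetShapeFn(shape_inference::UnchangedShape);

}  // namespace tensorflow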
diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h index 92aa72be89..941ab7ad65 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system.h +++ b/tensorflow/core/platform/cloud/retrying_file_system.h @@ -177,7 +177,7 @@ class RetryingWritableFile : public WritableFile { Close().IgnoreError(); } - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { return RetryingUtils::CallWithRetries( [this, &data]() { return base_file_->Append(data); }, initial_delay_microseconds_); diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc index ec2c470db7..5910fef1d2 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc +++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc @@ -72,7 +72,7 @@ class MockRandomAccessFile : public RandomAccessFile { class MockWritableFile : public WritableFile { public: explicit MockWritableFile(const ExpectedCalls& calls) : calls_(calls) {} - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { return calls_.ConsumeNextCall("Append"); } Status Close() override { return calls_.ConsumeNextCall("Close"); } diff --git a/tensorflow/core/platform/cord.h b/tensorflow/core/platform/cord.h new file mode 100644 index 0000000000..7c5c6655be --- /dev/null +++ b/tensorflow/core/platform/cord.h @@ -0,0 +1,26 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_PLATFORM_CORD_H_ +#define TENSORFLOW_CORE_PLATFORM_CORD_H_ + +// Include appropriate platform-dependent implementations +#if defined(PLATFORM_GOOGLE) +#include "tensorflow/core/platform/google/cord.h" +#else +#include "tensorflow/core/platform/default/cord.h" +#endif + +#endif // TENSORFLOW_CORE_PLATFORM_CORD_H_ diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl index 0411a8c4f9..bb841aeab7 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -625,7 +625,9 @@ def tf_additional_lib_deps(): """Additional dependencies needed to build TF libraries.""" return [ "@com_google_absl//absl/base:base", + "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/types:span", + "@com_google_absl//absl/types:optional", ] + if_static( ["@nsync//:nsync_cpp"], ["@nsync//:nsync_headers"], diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl index 3a012c23fd..37475feebe 100644 --- a/tensorflow/core/platform/default/build_config_root.bzl +++ b/tensorflow/core/platform/default/build_config_root.bzl @@ -3,64 +3,64 @@ # be separate to avoid cyclic references. 
def tf_cuda_tests_tags(): - return ["requires-gpu"] + return ["requires-gpu", "local", "gpu"] def tf_sycl_tests_tags(): - return ["requires-gpu"] + return ["requires-gpu", "local", "gpu"] def tf_additional_plugin_deps(): - return select({ - str(Label("//tensorflow:with_xla_support")): [ - str(Label("//tensorflow/compiler/jit")) - ], - "//conditions:default": [], - }) + return select({ + str(Label("//tensorflow:with_xla_support")): [ + str(Label("//tensorflow/compiler/jit")), + ], + "//conditions:default": [], + }) def tf_additional_xla_deps_py(): - return [] + return [] def tf_additional_grpc_deps_py(): - return [] + return [] def tf_additional_license_deps(): - return select({ - str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"], - "//conditions:default": [], - }) + return select({ + str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"], + "//conditions:default": [], + }) def tf_additional_verbs_deps(): - return select({ - str(Label("//tensorflow:with_verbs_support")): [ - str(Label("//tensorflow/contrib/verbs:verbs_server_lib")), - str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")), - ], - "//conditions:default": [], - }) + return select({ + str(Label("//tensorflow:with_verbs_support")): [ + str(Label("//tensorflow/contrib/verbs:verbs_server_lib")), + str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")), + ], + "//conditions:default": [], + }) def tf_additional_mpi_deps(): - return select({ - str(Label("//tensorflow:with_mpi_support")): [ - str(Label("//tensorflow/contrib/mpi:mpi_server_lib")), - ], - "//conditions:default": [], - }) + return select({ + str(Label("//tensorflow:with_mpi_support")): [ + str(Label("//tensorflow/contrib/mpi:mpi_server_lib")), + ], + "//conditions:default": [], + }) def tf_additional_gdr_deps(): - return select({ - str(Label("//tensorflow:with_gdr_support")): [ - str(Label("//tensorflow/contrib/gdr:gdr_server_lib")), - ], - "//conditions:default": [], - }) + return select({ + str(Label("//tensorflow:with_gdr_support")): [ + str(Label("//tensorflow/contrib/gdr:gdr_server_lib")), + ], + "//conditions:default": [], + }) -def if_static(extra_deps, otherwise=[]): - return select({ - str(Label("//tensorflow:framework_shared_object")): otherwise, - "//conditions:default": extra_deps, - }) +def if_static(extra_deps, otherwise = []): + return select({ + str(Label("//tensorflow:framework_shared_object")): otherwise, + "//conditions:default": extra_deps, + }) -def if_dynamic_kernels(extra_deps, otherwise=[]): - return select({ - str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps, - "//conditions:default": otherwise, - }) +def if_dynamic_kernels(extra_deps, otherwise = []): + return select({ + str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps, + "//conditions:default": otherwise, + }) diff --git a/tensorflow/core/lib/gtl/optional.cc b/tensorflow/core/platform/default/cord.h index 8dea073788..1ab682182c 100644 --- a/tensorflow/core/lib/gtl/optional.cc +++ b/tensorflow/core/platform/default/cord.h @@ -1,4 +1,4 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -13,13 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include "tensorflow/core/lib/gtl/optional.h" +#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_ +#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_ -namespace tensorflow { -namespace gtl { +class Cord; +namespace absl { +using ::Cord; +} // namespace absl -nullopt_t::init_t nullopt_t::init; -extern const nullopt_t nullopt{nullopt_t::init}; - -} // namespace gtl -} // namespace tensorflow +#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_ diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc index ccddf1eafc..0389149469 100644 --- a/tensorflow/core/platform/default/device_tracer.cc +++ b/tensorflow/core/platform/default/device_tracer.cc @@ -321,6 +321,11 @@ class DeviceTracerImpl : public DeviceTracer, return nullptr; } + bool IsEnabled(bool is_expensive) const override { + // We don't do anything with 'Activities' so we are never 'enabled'. + return false; + } + protected: // This callback is used exclusively by CUPTIManager. friend class CUPTIManager; diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc index 305a9a682f..2e32abdffb 100644 --- a/tensorflow/core/platform/env_test.cc +++ b/tensorflow/core/platform/env_test.cc @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/null_file_system.h" #include "tensorflow/core/platform/protobuf.h" #include "tensorflow/core/platform/test.h" @@ -345,7 +346,13 @@ TEST_F(DefaultEnvTest, LocalTempFilename) { // Write something to the temporary file. std::unique_ptr<WritableFile> file_to_write; TF_CHECK_OK(env->NewWritableFile(filename, &file_to_write)); +#if defined(PLATFORM_GOOGLE) + TF_CHECK_OK(file_to_write->Append("Nu")); + TF_CHECK_OK(file_to_write->Append(absl::Cord("ll"))); +#else + // TODO(ebrevdo): Remove this version. TF_CHECK_OK(file_to_write->Append("Null")); +#endif TF_CHECK_OK(file_to_write->Close()); TF_CHECK_OK(env->FileExists(filename)); diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h index 077b1d79cf..30059dc02e 100644 --- a/tensorflow/core/platform/file_system.h +++ b/tensorflow/core/platform/file_system.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/cord.h" #include "tensorflow/core/platform/file_statistics.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/platform.h" @@ -252,7 +253,12 @@ class WritableFile { virtual ~WritableFile(); /// \brief Append 'data' to the file. - virtual Status Append(const StringPiece& data) = 0; + virtual Status Append(StringPiece data) = 0; + + /// \brief Append 'cord' to the file. + virtual Status Append(const absl::Cord& cord) { + return errors::Unimplemented("Append(absl::Cord) is not implemented"); + } /// \brief Close the file.
/// diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc index 8cdb08f51b..eb35531e9f 100644 --- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc +++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc @@ -282,7 +282,7 @@ class HDFSWritableFile : public WritableFile { } } - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { if (hdfs_->hdfsWrite(fs_, file_, data.data(), static_cast<tSize>(data.size())) == -1) { return IOError(filename_, errno); diff --git a/tensorflow/core/platform/posix/posix_file_system.cc b/tensorflow/core/platform/posix/posix_file_system.cc index 47bfa020ce..c7afab9583 100644 --- a/tensorflow/core/platform/posix/posix_file_system.cc +++ b/tensorflow/core/platform/posix/posix_file_system.cc @@ -91,7 +91,7 @@ class PosixWritableFile : public WritableFile { } } - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { size_t r = fwrite(data.data(), 1, data.size(), file_); if (r != data.size()) { return IOError(filename_, errno); diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc index ce0f6cd741..e0b8e37745 100644 --- a/tensorflow/core/platform/s3/s3_file_system.cc +++ b/tensorflow/core/platform/s3/s3_file_system.cc @@ -211,7 +211,7 @@ class S3WritableFile : public WritableFile { std::ios_base::binary | std::ios_base::trunc | std::ios_base::in | std::ios_base::out)) {} - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { if (!outfile_) { return errors::FailedPrecondition( "The internal temporary file is not writable."); diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h index e5851f1dfe..9974bbbb4e 100644 --- a/tensorflow/core/platform/tracing.h +++ b/tensorflow/core/platform/tracing.h @@ -155,6 +155,10 @@ class TraceCollector { StringPiece name_part1, StringPiece name_part2, bool is_expensive) const = 0; + // Returns true if this activity handle tracking is enabled for an op of the + // given expensiveness. + virtual bool IsEnabled(bool is_expensive) const = 0; + protected: static string ConcatenateNames(StringPiece first, StringPiece second); diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc index 9079a5ccaa..6cf79634d7 100644 --- a/tensorflow/core/platform/windows/windows_file_system.cc +++ b/tensorflow/core/platform/windows/windows_file_system.cc @@ -150,7 +150,7 @@ class WindowsWritableFile : public WritableFile { } } - Status Append(const StringPiece& data) override { + Status Append(StringPiece data) override { DWORD bytes_written = 0; DWORD data_size = static_cast<DWORD>(data.size()); BOOL write_result = diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index c68504a272..85cd02350a 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -390,9 +390,12 @@ message ConfigProto { message Experimental { // Task name for group resolution. string collective_group_leader = 1; - // Whether the client will format templated errors. For example, the string: - // "The node was defined on ^^node:Foo:${file}:${line}^^". - bool client_handles_error_formatting = 2; + + // We removed the flag client_handles_error_formatting. Marking the tag + // number as reserved. 
+ // TODO(shikharagarwal): Should we just remove this tag so that it can be + // used in the future for another purpose? + reserved 2; // Which executor to use, the default executor will be used // if it is an empty string or "DEFAULT" diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 4129c93af5..b043a69431 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -19,12 +19,12 @@ limitations under the License. // TensorFlow uses semantic versioning, see http://semver.org/. #define TF_MAJOR_VERSION 1 -#define TF_MINOR_VERSION 10 +#define TF_MINOR_VERSION 11 #define TF_PATCH_VERSION 0 // TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1", // "-beta", "-rc", "-rc.1") -#define TF_VERSION_SUFFIX "" +#define TF_VERSION_SUFFIX "-rc1" #define TF_STR_HELPER(x) #x #define TF_STR(x) TF_STR_HELPER(x) diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h index 973e315f09..24002e72a0 100644 --- a/tensorflow/core/util/ctc/ctc_beam_entry.h +++ b/tensorflow/core/util/ctc/ctc_beam_entry.h @@ -1,4 +1,3 @@ -// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_ diff --git a/tensorflow/core/util/ctc/ctc_beam_scorer.h b/tensorflow/core/util/ctc/ctc_beam_scorer.h index 1a622babe1..1e45a8abd3 100644 --- a/tensorflow/core/util/ctc/ctc_beam_scorer.h +++ b/tensorflow/core/util/ctc/ctc_beam_scorer.h @@ -1,4 +1,3 @@ -// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange // Collection of scoring classes that can be extended and provided to the // CTCBeamSearchDecoder to incorporate additional scoring logic (such as a diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h index 5e2aeb7830..6fbb1ed0da 100644 --- a/tensorflow/core/util/ctc/ctc_beam_search.h +++ b/tensorflow/core/util/ctc/ctc_beam_search.h @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_ diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h index 3be36822e5..b55d7d77ac 100644 --- a/tensorflow/core/util/ctc/ctc_decoder.h +++ b/tensorflow/core/util/ctc/ctc_decoder.h @@ -1,4 +1,3 @@ -// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_ diff --git a/tensorflow/core/util/ctc/ctc_loss_util.h b/tensorflow/core/util/ctc/ctc_loss_util.h index 36be9e92ef..054412d388 100644 --- a/tensorflow/core/util/ctc/ctc_loss_util.h +++ b/tensorflow/core/util/ctc/ctc_loss_util.h @@ -1,4 +1,3 @@ -// LINT.IfChange /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// LINT.IfChange #ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_ #define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_ diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index 6474319370..680211edff 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -17,6 +17,7 @@ limitations under the License. #define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_ #ifdef INTEL_MKL +#include <string> #include <memory> #include <unordered_map> #include <utility> @@ -56,6 +57,7 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/util/padding.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/util/env_var.h" #ifndef INTEL_MKL_ML_ONLY #include "mkldnn.hpp" @@ -102,6 +104,8 @@ typedef enum { Dim3d_I = 1 } MklDnnDims3D; +static const int kSmallBatchSize = 32; + #ifdef INTEL_MKL_ML_ONLY class MklShape { public: @@ -2000,7 +2004,9 @@ const mkldnn::memory::dims NONE_DIMS = {}; template <typename T> class MklPrimitiveFactory { public: - MklPrimitiveFactory() {} + MklPrimitiveFactory() { + } + ~MklPrimitiveFactory() {} MklPrimitive* GetOp(const string& key) { @@ -2023,6 +2029,22 @@ class MklPrimitiveFactory { map[key] = op; } + /// Function to decide whether HW has AVX512 or AVX2 + /// For legacy devices (w/o AVX512 and AVX2), + /// MKL-DNN GEMM will be used.
+ static inline bool IsLegacyPlatform() { + return (!port::TestCPUFeature(port::CPUFeature::AVX512F) + && !port::TestCPUFeature(port::CPUFeature::AVX2)); + } + + /// Function to check whether primitive memory optimization is enabled + static inline bool IsPrimitiveMemOptEnabled() { + bool is_primitive_mem_opt_enabled = true; + TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE", true, + &is_primitive_mem_opt_enabled)); + return is_primitive_mem_opt_enabled; + } + private: static inline std::unordered_map<string, MklPrimitive*>& GetHashMap() { static thread_local std::unordered_map<string, MklPrimitive*> map_; @@ -2060,7 +2082,7 @@ class FactoryKeyCreator { const char delimiter = 'x'; const int kMaxKeyLength = 256; void Append(StringPiece s) { - key_.append(s.ToString()); + key_.append(string(s)); key_.append(1, delimiter); } }; @@ -2099,7 +2121,7 @@ class MklReorderPrimitive : public MklPrimitive { context_.dst_mem->set_data_handle(to->get_data_handle()); } - private: + private: struct ReorderContext { std::shared_ptr<mkldnn::memory> src_mem; std::shared_ptr<mkldnn::memory> dst_mem; @@ -2141,7 +2163,7 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> { return instance_; } - private: + private: MklReorderPrimitiveFactory() {} ~MklReorderPrimitiveFactory() {} @@ -2186,6 +2208,15 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) { return *reorder_prim->GetPrimitive(); } +// Utility function to determine whether a convolution is 1x1 with +// stride != 1, for the purpose of temporarily disabling primitive reuse +inline bool IsConv1x1StrideNot1(memory::dims filter_dims, memory::dims strides) { + if (filter_dims.size() != 4 || strides.size() != 2) return false; + + return ((filter_dims[2] == 1) && (filter_dims[3] == 1) && + ((strides[0] != 1) || (strides[1] != 1))); +} + #endif // INTEL_MKL_DNN } // namespace tensorflow diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc index 204b933051..546b0a833c 100644 --- a/tensorflow/core/util/sparse/group_iterator.cc +++ b/tensorflow/core/util/sparse/group_iterator.cc @@ -21,8 +21,8 @@ namespace sparse { void GroupIterable::IteratorStep::UpdateEndOfGroup() { ++next_loc_; - int64 N = iter_->ix_.dim_size(0); - auto ix_t = iter_->ix_.template matrix<int64>(); + const auto& ix_t = iter_->ix_matrix_; + const int64 N = ix_t.dimension(0); while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) { ++next_loc_; } @@ -54,7 +54,7 @@ GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++( std::vector<int64> Group::group() const { std::vector<int64> g; - auto ix_t = iter_->ix_.template matrix<int64>(); + const auto& ix_t = iter_->ix_matrix_; for (const int d : iter_->group_dims_) { g.push_back(ix_t(loc_, d)); } @@ -62,8 +62,8 @@ std::vector<int64> Group::group() const { } TTypes<int64>::UnalignedConstMatrix Group::indices() const { - return TTypes<int64>::UnalignedConstMatrix( - &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_); + return TTypes<int64>::UnalignedConstMatrix(&(iter_->ix_matrix_(loc_, 0)), + next_loc_ - loc_, iter_->dims_); } } // namespace sparse diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h index 3fa8cb6116..14610c61d9 100644 --- a/tensorflow/core/util/sparse/group_iterator.h +++ b/tensorflow/core/util/sparse/group_iterator.h @@ -79,6 +79,7 @@ class GroupIterable { GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims) : ix_(ix), +
ix_matrix_(ix_.matrix<int64>()), vals_(vals), dims_(dims), group_dims_(group_dims.begin(), group_dims.end()) {} @@ -127,7 +128,8 @@ class GroupIterable { private: friend class Group; - Tensor ix_; + const Tensor ix_; + const TTypes<int64>::ConstMatrix ix_matrix_; Tensor vals_; const int dims_; const gtl::InlinedVector<int64, 8> group_dims_; diff --git a/tensorflow/core/util/status_util.h b/tensorflow/core/util/status_util.h deleted file mode 100644 index ea92f61dce..0000000000 --- a/tensorflow/core/util/status_util.h +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_ -#define TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_ - -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/lib/strings/strcat.h" - -namespace tensorflow { - -// Creates a tag to be used in an exception error message. This can be parsed by -// the Python layer and replaced with information about the node. -// -// For example, error_format_tag(node, "${file}") returns -// "^^node:NODE_NAME:${line}^^" which would be rewritten by the Python layer as -// e.g. "file/where/node/was/created.py". -inline string error_format_tag(const Node& node, const string& format) { - return strings::StrCat("^^node:", node.name(), ":", format, "^^"); -} - -} // namespace tensorflow - -#endif // TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_ diff --git a/tensorflow/core/util/status_util_test.cc b/tensorflow/core/util/status_util_test.cc deleted file mode 100644 index 1f06004db2..0000000000 --- a/tensorflow/core/util/status_util_test.cc +++ /dev/null @@ -1,36 +0,0 @@ -/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/core/util/status_util.h" - -#include "tensorflow/core/graph/graph_constructor.h" -#include "tensorflow/core/graph/node_builder.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { - -TEST(TestStatusUtil, ErrorFormatTagForNode) { - Graph graph(OpRegistry::Global()); - Node* node; - TF_CHECK_OK(NodeBuilder("Foo", "NoOp").Finalize(&graph, &node)); - EXPECT_EQ(error_format_tag(*node, "${line}"), "^^node:Foo:${line}^^"); - EXPECT_EQ(error_format_tag(*node, "${file}:${line}"), - "^^node:Foo:${file}:${line}^^"); -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/core/util/tensor_bundle/naming.h b/tensorflow/core/util/tensor_bundle/naming.h index 6539d565e2..7b101971a8 100644 --- a/tensorflow/core/util/tensor_bundle/naming.h +++ b/tensorflow/core/util/tensor_bundle/naming.h @@ -35,6 +35,7 @@ limitations under the License. #define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_NAMING_H_ #include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { |
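Note on the mkl_util.h additions above: IsPrimitiveMemOptEnabled() uses ReadBoolFromEnvVar() from util/env_var.h (the include this commit adds) to gate behavior on an environment variable. A minimal sketch of that idiom, assuming TF's status headers; the flag name "TF_EXAMPLE_SWITCH" and the wrapper function are illustrative, not part of this commit:

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

// Sketch only: "TF_EXAMPLE_SWITCH" is an illustrative variable name.
// ReadBoolFromEnvVar leaves the default in place when the variable is
// unset and returns a non-OK Status when the value is not a valid bool,
// which TF_CHECK_OK then turns into a fatal error.
static bool ExampleSwitchEnabled() {
  bool enabled = true;  // default when TF_EXAMPLE_SWITCH is unset
  TF_CHECK_OK(ReadBoolFromEnvVar("TF_EXAMPLE_SWITCH", true, &enabled));
  return enabled;
}

}  // namespace tensorflow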