diff options
Diffstat (limited to 'tensorflow/core')
299 files changed, 14171 insertions, 3334 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index bc0bfb793c..900a0e11c4 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -149,6 +149,7 @@ load( "tf_cuda_tests_tags", ) load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured") load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library") load( "//third_party/mkl:build_defs.bzl", @@ -238,7 +239,6 @@ tf_proto_library( srcs = [], cc_api_version = 2, default_header = True, - java_api_version = 2, js_api_version = 2, protodeps = [ ":protos_all_proto", @@ -271,6 +271,12 @@ proto_library( visibility = ["//visibility:public"], ) +java_proto_library( + name = "example_java_proto", + visibility = ["//visibility:public"], + deps = [":example_protos"], +) + closure_proto_library( name = "example_protos_closure", visibility = ["//visibility:public"], @@ -707,14 +713,11 @@ cc_library( cc_library( name = "feature_util", srcs = ["example/feature_util.cc"], - hdrs = [ - "example/feature_util.h", - "platform/types.h", - ], + hdrs = ["example/feature_util.h"], visibility = ["//visibility:public"], deps = [ ":core_stringpiece", - ":platform_protobuf", + ":lib_proto_parsing", ":protos_all_cc", ], ) @@ -1041,6 +1044,7 @@ tf_gen_op_libs( "dataset_ops", "decode_proto_ops", "encode_proto_ops", + "experimental_dataset_ops", "function_ops", "functional_ops", "image_ops", @@ -1057,7 +1061,6 @@ tf_gen_op_libs( "random_grad", "random_ops", "remote_fused_graph_ops", - "resource_variable_ops", "rpc_ops", "scoped_allocator_ops", "sdca_ops", @@ -1099,6 +1102,14 @@ tf_gen_op_libs( deps = ["//tensorflow/core/kernels:debug_ops"], ) +tf_gen_op_libs( + is_external = False, + op_lib_names = [ + "resource_variable_ops", + ], + deps = [":lib"], +) + # And one for all user ops cc_library( name = "user_ops_op_lib", @@ -1164,6 +1175,7 @@ cc_library( ":dataset_ops_op_lib", ":decode_proto_ops_op_lib", ":encode_proto_ops_op_lib", + 
":experimental_dataset_ops_op_lib", ":function_ops_op_lib", ":functional_ops_op_lib", ":image_ops_op_lib", @@ -1230,6 +1242,7 @@ cc_library( srcs = [ "ops/math_grad.cc", "ops/random_grad.cc", + "ops/stateless_random_grad.cc", ], linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669 visibility = ["//visibility:public"], @@ -1363,6 +1376,7 @@ cc_library( "//tensorflow/core/kernels:mkl_pooling_ops", "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", + "//tensorflow/core/kernels:mkl_slice_op", "//tensorflow/core/kernels:mkl_softmax_op", "//tensorflow/core/kernels:mkl_transpose_op", "//tensorflow/core/kernels:mkl_tfconv_op", @@ -2377,7 +2391,6 @@ tf_proto_library( srcs = ERROR_CODES_PROTO_SRCS, cc_api_version = 2, default_header = True, - java_api_version = 2, js_api_version = 2, provide_cc_alias = True, ) @@ -2398,7 +2411,6 @@ tf_proto_library( srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS, cc_api_version = 2, default_header = True, - java_api_version = 2, js_api_version = 2, protodeps = [ ":error_codes_proto", @@ -2478,6 +2490,8 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [ "framework/op_segment.h", "framework/rendezvous.h", # only needed for tests "framework/resource_var.h", + "framework/run_handler.h", + "framework/run_handler_util.h", "framework/tensor_reference.h", "framework/tracking_allocator.h", # only needed for tests "framework/unique_tensor_references.h", @@ -2554,6 +2568,7 @@ tf_cuda_library( "**/*test*", "**/*main.cc", "example/example_parser_configuration.*", + "example/feature_util.cc", "util/reporter.cc", "framework/fake_input.*", "framework/op_gen_lib.*", @@ -2583,6 +2598,7 @@ tf_cuda_library( ], }), deps = [ + ":feature_util", ":lib", ":lib_internal", ":protos_all_proto_text", @@ -2962,6 +2978,7 @@ tf_cuda_library( ":core_cpu_internal", ":device_tracer", ":framework", + ":framework_internal", ":graph", ":lib", ":lib_internal", @@ -2999,7 +3016,7 @@ tf_cuda_library( "platform/device_tracer.h", ], 
copts = tf_copts(), - cuda_deps = tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps(), + cuda_deps = if_cuda_is_configured(tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps()), visibility = ["//visibility:private"], deps = [ ":core_cpu_internal", @@ -3734,7 +3751,7 @@ tf_cc_tests_gpu( tf_cc_tests_gpu( name = "hierarchical_tree_broadcaster_test", - size = "small", + size = "medium", srcs = [ "common_runtime/hierarchical_tree_broadcaster_test.cc", ], @@ -3821,6 +3838,7 @@ tf_cc_test_mkl( "//tensorflow/core/kernels:mkl_pooling_ops", "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", + "//tensorflow/core/kernels:mkl_slice_op", "//tensorflow/core/kernels:mkl_softmax_op", "//tensorflow/core/kernels:mkl_tfconv_op", ]), @@ -4108,6 +4126,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "framework_run_handler_util_test", + size = "small", + srcs = ["framework/run_handler_util_test.cc"], + linkstatic = tf_kernel_tests_linkstatic(), + deps = [ + ":framework_internal", + ":lib", + ":test", + ":test_main", + ], +) + tf_cuda_cc_test( name = "common_runtime_direct_session_test", size = "small", diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt new file mode 100644 index 0000000000..fa8fc96bb2 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ExperimentalAssertNextDataset" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt new file mode 100644 index 0000000000..5fd88e7a0c --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ExperimentalCSVDataset" + visibility: HIDDEN +} diff --git 
a/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt new file mode 100644 index 0000000000..ac1f9719fe --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt @@ -0,0 +1,21 @@ +op { + graph_op_name: "ExperimentalDirectedInterleaveDataset" + in_arg { + name: "selector_input_dataset" + description: <<END +A dataset of scalar `DT_INT64` elements that determines which of the +`N` data inputs should produce the next output element. +END + } + in_arg { + name: "data_input_datasets" + description: <<END +`N` datasets with the same type that will be interleaved according to +the values of `selector_input_dataset`. +END + } + summary: <<END +A substitute for `InterleaveDataset` on a fixed list of `N` datasets. +END + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt new file mode 100644 index 0000000000..66511eff60 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt @@ -0,0 +1,58 @@ +op { + graph_op_name: "ExperimentalFunctionBufferingResource" + in_arg { + name: "string_arg" + description: <<END +String argument to the function call. +END + } + in_arg { + name: "target_device" + description: <<END +Target device to execute the function on. +END + } + out_arg { + name: "resource" + description: <<END +Handle to the resource created. +END + } + attr { + name: "shared_name" + description: <<END +If non-empty, this resource will be shared under the given name across +multiple sessions. +END + } + attr { + name: "container" + description: <<END +If non-empty, this resource is placed in the given container. +Otherwise, a default container is used. 
+END + } + attr { + name: "f" + description: <<END +Function to be executed. +END + } + attr { + name: "buffer_size" + description: <<END +Size of the buffer. +END + } + attr { + name: "output_types" + description: <<END +The type list for the return values. +END + } + summary: <<END +Creates a resource that fills up a buffer by making function calls. +END + visibility: HIDDEN +} + diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt new file mode 100644 index 0000000000..bf4b66b22b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt @@ -0,0 +1,25 @@ +op { + graph_op_name: "ExperimentalFunctionBufferingResourceGetNext" + in_arg { + name: "function_buffer_resource" + description: <<END +The FunctionBufferingResource handle. +END + } + out_arg { + name: "output" + description: <<END +A list of return values. +END + } + attr { + name: "output_types" + description: <<END +The type list for the return values. +END + } + summary: <<END +Gets the next element from a FunctionBufferingResource. +END + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt new file mode 100644 index 0000000000..729718ddb3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt @@ -0,0 +1,13 @@ +op { + graph_op_name: "ExperimentalFunctionBufferingResourceReset" + in_arg { + name: "function_buffer_resource" + description: <<END +The FunctionBufferingResource handle. +END + } + summary: <<END +Resets the FunctionBufferingResource. 
+END + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt new file mode 100644 index 0000000000..fe266c111f --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ExperimentalIdentityIndexedDataset" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt new file mode 100644 index 0000000000..d42546516d --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt @@ -0,0 +1,8 @@ +op { + graph_op_name: "ExperimentalIgnoreErrorsDataset" + summary: <<END +Creates a dataset that contains the elements of `input_dataset` ignoring errors. +END + visibility: HIDDEN +} + diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt new file mode 100644 index 0000000000..e285f87e10 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ExperimentalIndexedDatasetGet" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt new file mode 100644 index 0000000000..60c32473b5 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ExperimentalIndexedDatasetMaterialize" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt 
b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt new file mode 100644 index 0000000000..b72b229e9a --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt @@ -0,0 +1,8 @@ +op { + graph_op_name: "ExperimentalIteratorGetDevice" + summary: <<END +Returns the name of the device on which `resource` has been placed. +END + visibility: HIDDEN +} + diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt new file mode 100644 index 0000000000..b38b23a51d --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ExperimentalLMDBDataset" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt new file mode 100644 index 0000000000..9676b9d284 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "ExperimentalMaterializedIndexDatasetHandle" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt new file mode 100644 index 0000000000..d73b5bfda3 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt @@ -0,0 +1,13 @@ +op { + graph_op_name: "ExperimentalThreadPoolDataset" + in_arg { + name: "thread_pool" + description: <<END +A resource produced by the ThreadPoolHandle op. +END + } + summary: <<END +Creates a dataset that uses a custom thread pool to compute `input_dataset`. 
+END + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt new file mode 100644 index 0000000000..48bf93406c --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt @@ -0,0 +1,34 @@ +op { + graph_op_name: "ExperimentalThreadPoolHandle" + out_arg { + name: "handle" + description: <<END +A resource that can be consumed by one or more ExperimentalThreadPoolDataset +ops. +END + } + attr { + name: "num_threads" + description: <<END +The number of threads in the thread pool. +END + } + attr { + name: "max_intra_op_parallelism" + description: <<END +The maximum degree of parallelism to use within operations that execute on this +threadpool. +END + } + attr { + name: "display_name" + description: <<END +A human-readable name for the threads that may be visible in some +visualizations. +END + } + summary: <<END +Creates a custom thread pool for computing `input_dataset`. +END + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt new file mode 100644 index 0000000000..68ed797a0c --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt @@ -0,0 +1,8 @@ +op { + graph_op_name: "ExperimentalUniqueDataset" + summary: <<END +Creates a dataset that contains the unique elements of `input_dataset`. 
+END + visibility: HIDDEN +} + diff --git a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt index 40d7d371ca..7142a0e3f2 100644 --- a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt @@ -9,7 +9,7 @@ The lower regularized incomplete Gamma function is defined as: where -\\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\) +\\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\) is the lower incomplete Gamma function. diff --git a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt index 4433693759..d158f4b502 100644 --- a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt @@ -4,16 +4,23 @@ op { in_arg { name: "arguments" description: <<END - A list of tensors whose types are Targuments, corresponding to the inputs the - function should be mapped over. + A list of tensors whose types are `Targuments`, corresponding to the inputs + the function should be mapped over. +END + } + in_arg { + name: "captured_inputs" + description: <<END + A list of tensors whose types are `Tcaptured`, corresponding to the captured + inputs of the defun. END } out_arg { name: "output" description: <<END - A list of output tensors whose types are output_types and whose dimensions 0 - are the same as the dimensions 0 of the tensors in arguments, and whose - remaining dimensions correspond to those in output_shapes. + A list of output tensors whose types are `output_types` and whose dimensions + 0 are the same as the dimensions 0 of the tensors in `arguments`, and whose + remaining dimensions correspond to those in `output_shapes`. END } attr { @@ -21,6 +28,10 @@ END description: "A list of types." } attr { + name: "Tcaptured" + description: "A list of types." + } + attr { name: "output_types" description: "A list of types." 
} @@ -29,6 +40,6 @@ END description: "A list of shapes." } summary: <<END - Maps a function on the list of tensors unpacked from inputs on dimension 0. + Maps a function on the list of tensors unpacked from arguments on dimension 0. END } diff --git a/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt new file mode 100644 index 0000000000..08414b3e68 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt @@ -0,0 +1,26 @@ +op { + visibility: HIDDEN + graph_op_name: "ReduceDataset" + in_arg { + name: "input_dataset" + description: <<END +A variant tensor representing the input dataset. +END + } + in_arg { + name: "initial_state" + description: <<END +A nested structure of tensors, representing the initial state of the +transformation. +END + } + attr { + name: "f" + description: <<END +A function that maps `(old_state, input_element)` to `new_state`. It must take +two arguments and return a nested structures of tensors. The structure of +`new_state` must match the structure of `initial_state`. +END + } + summary: "Reduces the input dataset to a singleton using a reduce function." +} diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt new file mode 100644 index 0000000000..b6a6dbdf54 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt @@ -0,0 +1,46 @@ +op { + graph_op_name: "StatelessRandomUniformInt" + visibility: HIDDEN + in_arg { + name: "shape" + description: <<END +The shape of the output tensor. +END + } + in_arg { + name: "seed" + description: <<END +2 seeds (shape [2]). +END + } + in_arg { + name: "minval" + description: <<END +Minimum value (inclusive, scalar). +END + } + in_arg { + name: "maxval" + description: <<END +Maximum value (exclusive, scalar). 
+END + } + out_arg { + name: "output" + description: <<END +Random values with specified shape. +END + } + attr { + name: "dtype" + description: <<END +The type of the output. +END + } + summary: "Outputs deterministic pseudorandom random integers from a uniform distribution." + description: <<END +The generated values follow a uniform distribution in the range `[minval, maxval)`. + +The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt index 5246090ab3..fe0fcc9508 100644 --- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt @@ -18,6 +18,16 @@ END Scalar defining the number of characters to include in each substring END } + attr { + name: "unit" + description: <<END +The unit that is used to create the substring. One of: `"BYTE"` (for +defining position and length by bytes) or `"UTF8_CHAR"` (for the UTF-8 +encoded Unicode code points). The default is `"BYTE"`. Results are undefined if +`unit=UTF8_CHAR` and the `input` strings do not contain structurally valid +UTF-8. +END + } out_arg { name: "output" description: <<END diff --git a/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt new file mode 100644 index 0000000000..7898fe8d6b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt @@ -0,0 +1,28 @@ +op { + graph_op_name: "UnicodeScript" + endpoint { + name: "UnicodeScript" + } + in_arg { + name: "input" + description: <<END +A Tensor of int32 Unicode code points. +END + } + out_arg { + name: "output" + description: <<END +A Tensor of int32 script codes corresponding to each input code point. +END + } + summary: <<END +Determine the script codes of a given tensor of Unicode integer code points. 
+END + description: <<END +This operation converts Unicode code points to script codes corresponding to +each code point. Script codes correspond to International Components for +Unicode (ICU) UScriptCode values. See http://icu-project.org/apiref/icu4c/uscript_8h.html. +Returns -1 (USCRIPT_INVALID_CODE) for invalid codepoints. Output shape will +match input shape. +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt b/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt new file mode 100644 index 0000000000..ca107abc6b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "Xdivy" + summary: "Returns 0 if x == 0, and x / y otherwise, elementwise." +} diff --git a/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt b/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt new file mode 100644 index 0000000000..da625f7836 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "Xlogy" + summary: "Returns 0 if x == 0, and x * log(y) otherwise, elementwise." 
+} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListPushBackBatch.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListPushBackBatch.pbtxt new file mode 100644 index 0000000000..3d937c745c --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListPushBackBatch.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListPushBackBatch" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt index 9552fc92e3..e395e333bf 100644 --- a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "BatchToSpaceND" endpoint { - name: "manip.batch_to_space_nd" + name: "batch_to_space_nd" } endpoint { - name: "batch_to_space_nd" + name: "manip.batch_to_space_nd" deprecated: true } } diff --git a/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt b/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt new file mode 100644 index 0000000000..44f25b5d93 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "EmptyTensorList" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt index 71257c8855..598f23bde3 100644 --- a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "GatherNd" endpoint { - name: "manip.gather_nd" + name: "gather_nd" } endpoint { - name: "gather_nd" + name: "manip.gather_nd" deprecated: true } } diff --git a/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt index b17806b338..5020844204 --- 
a/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt @@ -1,10 +1,4 @@ op { graph_op_name: "RegexReplace" - endpoint { - name: "strings.regex_replace" - } - endpoint { - name: "regex_replace" - deprecated: true - } + visibility: HIDDEN } diff --git a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt index c469665b66..b3d596de7a 100644 --- a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "Reshape" endpoint { - name: "manip.reshape" + name: "reshape" } endpoint { - name: "reshape" + name: "manip.reshape" deprecated: true } } diff --git a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt index 77f595927b..51478b7c34 100644 --- a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "ReverseV2" endpoint { - name: "manip.reverse" + name: "reverse" } endpoint { - name: "reverse" + name: "manip.reverse" deprecated: true } endpoint { diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt index a65a19b542..85888da45a 100644 --- a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "ScatterNd" endpoint { - name: "manip.scatter_nd" + name: "scatter_nd" } endpoint { - name: "scatter_nd" + name: "manip.scatter_nd" deprecated: true } } diff --git a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt index af323a6cf3..146b97f444 100644 --- 
a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "SpaceToBatchND" endpoint { - name: "manip.space_to_batch_nd" + name: "space_to_batch_nd" } endpoint { - name: "space_to_batch_nd" + name: "manip.space_to_batch_nd" deprecated: true } } diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt new file mode 100644 index 0000000000..d3c70190dd --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "StatelessMultinomial" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt new file mode 100644 index 0000000000..e294325fb8 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "StatelessRandomNormal" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt new file mode 100644 index 0000000000..95d414c54a --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "StatelessRandomUniform" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt new file mode 100644 index 0000000000..c72bdda94a --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "StatelessTruncatedNormal" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt 
b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt index 4778d7927c..4fb9ee56e9 100644 --- a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt @@ -1,10 +1,4 @@ op { graph_op_name: "Substr" - endpoint { - name: "strings.substr" - } - endpoint { - name: "substr" - deprecated: true - } + visibility: HIDDEN } diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt new file mode 100644 index 0000000000..45fc55e71e --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListConcatLists" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt new file mode 100644 index 0000000000..e1ad713e7f --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListElementShape" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt new file mode 100644 index 0000000000..4aaefba3c5 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListFromTensor" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt new file mode 100644 index 0000000000..aaf607d70e --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListGather" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt 
b/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt new file mode 100644 index 0000000000..3bb5f39cbc --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListGetItem" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt new file mode 100644 index 0000000000..a04c20bb8a --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListLength" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt new file mode 100644 index 0000000000..9287162f22 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListPopBack" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt new file mode 100644 index 0000000000..da2bc11721 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListPushBack" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt new file mode 100644 index 0000000000..77e63747d5 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListReserve" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt new file mode 100644 index 0000000000..0015189d7f --- /dev/null +++ 
b/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListScatter" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt new file mode 100644 index 0000000000..4999ee7ad9 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListSetItem" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt new file mode 100644 index 0000000000..2dc7b2784b --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt @@ -0,0 +1,4 @@ +op { + graph_op_name: "TensorListStack" + visibility: HIDDEN +} diff --git a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt index c34061c941..1d8695f1fd 100644 --- a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt +++ b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt @@ -1,10 +1,10 @@ op { graph_op_name: "Tile" endpoint { - name: "manip.tile" + name: "tile" } endpoint { - name: "tile" + name: "manip.tile" deprecated: true } } diff --git a/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt new file mode 100644 index 0000000000..a884a46143 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt @@ -0,0 +1,6 @@ +op { + graph_op_name: "UnicodeScript" + endpoint { + name: "strings.unicode_script" + } +} diff --git a/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt b/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt new file mode 100644 index 0000000000..984442ba2b --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt @@ -0,0 +1,6 @@ +op { + graph_op_name: 
"Xdivy" + endpoint { + name: "math.xdivy" + } +} diff --git a/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt b/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt new file mode 100644 index 0000000000..b4a5299256 --- /dev/null +++ b/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt @@ -0,0 +1,6 @@ +op { + graph_op_name: "Xlogy" + endpoint { + name: "math.xlogy" + } +} diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc index 3b2dc6a050..7cb90de3c7 100644 --- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc +++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc @@ -522,7 +522,6 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams( InitInstanceSharedParams( gr, cp, ir, [this, ir, done](const Status& s) UNLOCK_FUNCTION(ir->out_mu) { - DCHECK(!ir->out_mu.try_lock()); DCHECK(ir->out_mu_available); ir->status.Update(s); ir->out_mu.unlock(); diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc index 99cb9ac6a0..e81e61b633 100644 --- a/tensorflow/core/common_runtime/constant_folding.cc +++ b/tensorflow/core/common_runtime/constant_folding.cc @@ -470,19 +470,19 @@ bool ReplaceTensorWithConstant( const ConstantFoldNameGenerator& generate_new_name) { // Be conservative when replacing a tensor with a constant, when not // running on CPU. - // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY + // 1) Do not replace another constant. + // 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY // constraint, do not replace it. - // 2) If the destination tensor is an int32 tensor, but has DEVICE_MEMORY + // 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY // constraint, do not replace it. - // 3) If the constant op created does not have a kernel implementation - // for the device, do not use it. 
// 4) If the size of the constant in bytes is too large (> // max_constant_in_bytes), do not replace it. This prevents the size of the // Graph from growing too large. + // 5) If the constant op created does not have a kernel implementation + // for the device, do not use it. // TODO(keveman): Consider adding a new constant op that has a kernel // implementation for all types, but with HostMemory constraint on it's // output. - // 5) Do not replace another constant. if (tensor.first->IsConstant()) { return false; } diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc index d800a86199..6e2eb66b94 100644 --- a/tensorflow/core/common_runtime/copy_tensor.cc +++ b/tensorflow/core/common_runtime/copy_tensor.cc @@ -61,26 +61,33 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator, status_cb->Unref(); }; auto copier = std::bind( - [dst, recv_dev_context, out_allocator, status_cb]( - StatusCallback wrapped_done_, - // Begin unbound arguments - const Tensor& from, Tensor* to) { - if (!DMAHelper::CanUseDMA(&from)) { - Status err = errors::InvalidArgument( - "During Variant Host->Device Copy: " - "non-DMA-copy attempted of tensor type: ", - DataTypeString(from.dtype())); - status_cb->UpdateStatus(err); - return err; - } - if (status_cb->ok()) { + [dst, recv_dev_context, out_allocator, status_cb, cpu_allocator, + edge_name](StatusCallback wrapped_done_, + // Begin unbound arguments + const Tensor& from, Tensor* to) { + if (from.dtype() == DT_VARIANT) { status_cb->Ref(); - *to = Tensor(out_allocator, from.dtype(), from.shape()); - recv_dev_context->CopyCPUTensorToDevice(&from, dst, to, - wrapped_done_); + CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name, + dst, to, recv_dev_context, wrapped_done_); return Status::OK(); } else { - return status_cb->status(); + if (!DMAHelper::CanUseDMA(&from)) { + Status err = errors::InvalidArgument( + "During Variant Host->Device Copy: " + "non-DMA-copy 
attempted of tensor type: ", + DataTypeString(from.dtype())); + status_cb->UpdateStatus(err); + return err; + } + if (status_cb->ok()) { + status_cb->Ref(); + *to = Tensor(out_allocator, from.dtype(), from.shape()); + recv_dev_context->CopyCPUTensorToDevice(&from, dst, to, + wrapped_done_); + return Status::OK(); + } else { + return status_cb->status(); + } } }, std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2); @@ -119,26 +126,33 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator, status_cb->Unref(); }; auto copier = std::bind( - [edge_name, src, send_dev_context, out_allocator, status_cb]( - StatusCallback wrapped_done_, - // Begin unbound arguments - const Tensor& from, Tensor* to) { - if (!DMAHelper::CanUseDMA(&from)) { - Status err = errors::InvalidArgument( - "During Variant Device->Host Copy: " - "non-DMA-copy attempted of tensor type: ", - DataTypeString(from.dtype())); - status_cb->UpdateStatus(err); - return err; - } - if (status_cb->ok()) { + [edge_name, src, send_dev_context, out_allocator, status_cb, + cpu_allocator](StatusCallback wrapped_done_, + // Begin unbound arguments + const Tensor& from, Tensor* to) { + if (from.dtype() == DT_VARIANT) { status_cb->Ref(); - *to = Tensor(out_allocator, from.dtype(), from.shape()); - send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to, - wrapped_done_); + CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name, + src, to, send_dev_context, wrapped_done_); return Status::OK(); } else { - return status_cb->status(); + if (!DMAHelper::CanUseDMA(&from)) { + Status err = errors::InvalidArgument( + "During Variant Device->Host Copy: " + "non-DMA-copy attempted of tensor type: ", + DataTypeString(from.dtype())); + status_cb->UpdateStatus(err); + return err; + } + if (status_cb->ok()) { + status_cb->Ref(); + *to = Tensor(out_allocator, from.dtype(), from.shape()); + send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to, + wrapped_done_); + return 
Status::OK(); + } else { + return status_cb->status(); + } } }, std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2); diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc index af5d5b17e7..458e133b68 100644 --- a/tensorflow/core/common_runtime/direct_session.cc +++ b/tensorflow/core/common_runtime/direct_session.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/log_memory.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/run_handler.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/algorithm.h" @@ -244,6 +245,21 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool, #endif // __ANDROID__ } +static RunHandlerPool* GetOrCreateRunHandlerPool( + const SessionOptions& options) { + static RunHandlerPool* pool = + new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options)); + return pool; +} + +bool DirectSession::ShouldUseRunHandlerPool() const { + if (options_.config.session_inter_op_thread_pool_size() > 0 || + options_.config.use_per_session_threads()) { + return false; + } + return true; +} + DirectSession::DirectSession(const SessionOptions& options, const DeviceMgr* device_mgr, DirectSessionFactory* const factory) @@ -363,7 +379,7 @@ Status DirectSession::MaybeInitializeExecutionState( Status DirectSession::Create(const GraphDef& graph) { TF_RETURN_IF_ERROR(init_error_); if (graph.node_size() > 0) { - mutex_lock l(graph_def_lock_); + mutex_lock l(graph_state_lock_); if (graph_created_) { return errors::AlreadyExists( "A Graph has already been created for this session."); @@ -375,7 +391,7 @@ Status DirectSession::Create(const GraphDef& graph) { Status DirectSession::Extend(const GraphDef& graph) { TF_RETURN_IF_ERROR(CheckNotClosed()); - mutex_lock l(graph_def_lock_); 
+ mutex_lock l(graph_state_lock_); return ExtendLocked(graph); } @@ -582,16 +598,37 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options, } } - Executor::Args::Runner default_runner = [this, - pool](Executor::Args::Closure c) { - SchedClosure(pool, std::move(c)); - }; + std::unique_ptr<RunHandler> handler; + if (ShouldUseRunHandlerPool() && + run_options.experimental().use_run_handler_pool()) { + // Non-null only when a global inter-op pool is used. + VLOG(1) << "Using RunHandler to scheduler inter-op closures."; + handler = GetOrCreateRunHandlerPool(options_)->Get(); + } + auto* handler_ptr = handler.get(); + + Executor::Args::Runner default_runner = nullptr; + + if (pool == nullptr) { + default_runner = [](Executor::Args::Closure c) { c(); }; + } else if (handler_ptr != nullptr) { + default_runner = [handler_ptr](Executor::Args::Closure c) { + handler_ptr->ScheduleInterOpClosure(std::move(c)); + }; + } else { + default_runner = [this, pool](Executor::Args::Closure c) { + SchedClosure(pool, std::move(c)); + }; + } + for (const auto& item : executors_and_keys->items) { - // TODO(zhengxq): support partial run. - // TODO(zhengxq): if the device picks its own threadpool, we need to assign + // TODO(azaks): support partial run. + // TODO(azaks): if the device picks its own threadpool, we need to assign // less threads to the main compute pool by default. thread::ThreadPool* device_thread_pool = item.device->tensorflow_device_thread_pool(); + // TODO(crk): Investigate usage of RunHandlerPool when using device specific + // thread pool(s). 
if (!device_thread_pool) { args.runner = default_runner; } else { @@ -1172,7 +1209,7 @@ Status DirectSession::CreateExecutors( int graph_def_version; { - mutex_lock l(graph_def_lock_); + mutex_lock l(graph_state_lock_); graph_def_version = execution_state_->original_graph_def().versions().producer(); } @@ -1400,7 +1437,7 @@ Status DirectSession::CreateGraphs( std::unique_ptr<FunctionLibraryDefinition>* flib_def, RunStateArgs* run_state_args, DataTypeVector* input_types, DataTypeVector* output_types, int64* collective_graph_key) { - mutex_lock l(graph_def_lock_); + mutex_lock l(graph_state_lock_); std::unique_ptr<ClientGraph> client_graph; std::unique_ptr<GraphExecutionState> temp_exec_state_holder; diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h index c2cf3c7fd7..3a168bbe3f 100644 --- a/tensorflow/core/common_runtime/direct_session.h +++ b/tensorflow/core/common_runtime/direct_session.h @@ -215,7 +215,7 @@ class DirectSession : public Session { // if not already initialized. Status MaybeInitializeExecutionState(const GraphDef& graph, bool* out_already_initialized) - EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_); + EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_); // Retrieves an already existing set of executors to run 'inputs' and // 'outputs', or creates and caches them for future use. @@ -247,8 +247,11 @@ class DirectSession : public Session { ExecutorsAndKeys* executors_and_keys, RunMetadata* run_metadata); + // Returns whether inter-op execution uses a global pool. 
+ bool ShouldUseRunHandlerPool() const; + ::tensorflow::Status ExtendLocked(const GraphDef& graph) - EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_); + EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_); ::tensorflow::Status ResourceHandleToInputTensor( const Tensor& resource_tensor, Tensor* retrieved_tensor); @@ -289,7 +292,7 @@ class DirectSession : public Session { } ::tensorflow::Status CheckGraphCreated(const char* method) { - mutex_lock l(graph_def_lock_); + mutex_lock l(graph_state_lock_); if (!graph_created_) { return errors::InvalidArgument( "Session was not created with a graph before ", method, "!"); @@ -313,10 +316,8 @@ class DirectSession : public Session { DeviceSet device_set_; string session_handle_; - bool graph_created_ GUARDED_BY(graph_def_lock_) = false; - - mutex graph_def_lock_; - GraphDef graph_def_ GUARDED_BY(graph_def_lock_); + mutex graph_state_lock_; + bool graph_created_ GUARDED_BY(graph_state_lock_) = false; // The thread-pools to use for running ops, with a bool indicating if the pool // is owned. @@ -367,11 +368,11 @@ class DirectSession : public Session { // nodes can not be moved to a different device. Maps node names to // device names. std::unordered_map<string, string> stateful_placements_ - GUARDED_BY(graph_def_lock_); + GUARDED_BY(graph_state_lock_); // Execution_state; used when placing the entire graph. std::unique_ptr<GraphExecutionState> execution_state_ - GUARDED_BY(graph_def_lock_); + GUARDED_BY(graph_state_lock_); // The function library, before any rewrites or optimizations have been // performed. In particular, CreateGraphs() may need to modify the function @@ -386,7 +387,7 @@ class DirectSession : public Session { std::atomic<int64> edge_name_counter_ = {0}; std::atomic<int64> handle_name_counter_ = {0}; - // For generating step ids that are unique across all sessions. + // For generating step ids that are unique across this sessions. 
static std::atomic_int_fast64_t step_id_counter_; // Global timeout for all blocking operations in this session. @@ -395,8 +396,6 @@ class DirectSession : public Session { // Manages all the cost models for the graphs executed in this session. CostModelManager cost_model_manager_; - Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr; - // For testing collective graph key generation. mutex collective_graph_key_lock_; int64 collective_graph_key_ GUARDED_BY(collective_graph_key_lock_) = -1; diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc index 65e816c202..a6440c55ad 100644 --- a/tensorflow/core/common_runtime/direct_session_test.cc +++ b/tensorflow/core/common_runtime/direct_session_test.cc @@ -625,6 +625,34 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) { EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2); } +TEST_F(DirectSessionMinusAXTest, UseRunHandlerPool) { + Initialize({3, 2, -1, 0}); + auto session = CreateSession(); + ASSERT_TRUE(session != nullptr); + TF_ASSERT_OK(session->Create(def_)); + std::vector<std::pair<string, Tensor>> inputs; + + // Request two targets: one fetch output and one non-fetched output. + std::vector<string> output_names = {y_ + ":0"}; + std::vector<string> target_nodes = {y_neg_}; + std::vector<Tensor> outputs; + + // Prepares RunOptions and RunMetadata + RunOptions run_options; + run_options.mutable_experimental()->set_use_run_handler_pool(true); + + Status s = session->Run(run_options, inputs, output_names, target_nodes, + &outputs, nullptr); + TF_ASSERT_OK(s); + + ASSERT_EQ(1, outputs.size()); + // The first output should be initialized and have the correct + // output. 
+ auto mat = outputs[0].matrix<float>(); + ASSERT_TRUE(outputs[0].IsInitialized()); + EXPECT_FLOAT_EQ(5.0, mat(0, 0)); +} + TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) { GraphDef def; Graph g(OpRegistry::Global()); @@ -2234,8 +2262,8 @@ class DirectSessionCollectiveTest : public ::testing::Test { TF_RETURN_IF_ERROR(session->Create(g)); std::vector<Tensor> outputs; TF_RETURN_IF_ERROR( - session->Run({{"input1:0", t1}, {"input2:0", t2}}, {}, - {"collective_call1:0", "collective_call2:0"}, &outputs)); + session->Run({{"input0:0", t1}, {"input1:0", t2}}, {}, + {"collective_call0:0", "collective_call1:0"}, &outputs)); DirectSession* direct_session = static_cast<DirectSession*>(session.get()); { mutex_lock l(direct_session->collective_graph_key_lock_); @@ -2273,6 +2301,26 @@ class DirectSessionCollectiveTest : public ::testing::Test { }}); } + NodeDef Input(int id) { + AttrValue dtype_attr; + SetAttrValue(DT_FLOAT, &dtype_attr); + NodeDef input; + input.set_name(strings::StrCat("input", id)); + input.set_op("Placeholder"); + input.mutable_attr()->insert({"dtype", dtype_attr}); + return input; + } + + NodeDef CollectiveCall(const string& op, const string& input, int cpu_id) { + NodeDef collective_call; + collective_call.set_name(strings::StrCat("collective_call", cpu_id)); + collective_call.set_op(op); + collective_call.add_input(input); + collective_call.set_device( + strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", cpu_id)); + return collective_call; + } + // Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and // CPU1, with instance_key 1, and appropriate placeholder inputs. If // `add_unused_function` is true, adds another CollectiveFunction with @@ -2289,42 +2337,17 @@ class DirectSessionCollectiveTest : public ::testing::Test { *lib->add_function() = unused_function; } - // Inputs. 
- AttrValue dtype_attr; - SetAttrValue(DT_FLOAT, &dtype_attr); - NodeDef input1; - input1.set_name("input1"); - input1.set_op("Placeholder"); - input1.mutable_attr()->insert({"dtype", dtype_attr}); - NodeDef input2; - input2.set_name("input2"); - input2.set_op("Placeholder"); - input2.mutable_attr()->insert({"dtype", dtype_attr}); - + *g.add_node() = Input(0); + *g.add_node() = Input(1); // CollectiveReduce on CPU0 with instance_key 1. - NodeDef collective_call1; - collective_call1.set_name("collective_call1"); - collective_call1.set_op("CollectiveFunction1"); - collective_call1.add_input("input1"); - collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:0"); + *g.add_node() = CollectiveCall("CollectiveFunction1", "input0", 0); // CollectiveReduce on CPU1 with instance_key 1. - NodeDef collective_call2; - collective_call2.set_name("collective_call2"); - collective_call2.set_op("CollectiveFunction1"); - collective_call2.add_input("input2"); - collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:1"); - - *g.add_node() = input1; - *g.add_node() = input2; - *g.add_node() = collective_call1; - *g.add_node() = collective_call2; + *g.add_node() = CollectiveCall("CollectiveFunction1", "input1", 1); return g; } }; -#ifndef GOOGLE_CUDA -// TODO(ayushd): enable this test for GPU builds. 
TEST_F(DirectSessionCollectiveTest, TestCollectiveGraphKeyUsesOnlyCalledFunctions) { int64 key1; @@ -2333,6 +2356,5 @@ TEST_F(DirectSessionCollectiveTest, TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key2)); ASSERT_EQ(key1, key2); } -#endif } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc index 2ed4f69f90..2c63b8704e 100644 --- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc +++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc @@ -108,7 +108,7 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { EXPECT_EQ(2, shape.dim(0).size()); EXPECT_EQ(1, shape.dim(1).size()); if (node->name() == y->name()) { -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) // if MKL is used, it goes through various additional // graph rewrite pass. In TF, everytime a graph pass // happens, "constant" nodes are allocated @@ -117,16 +117,16 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) { // which increments the value of AllocationId. // Thus AllocationId becomes more than TF if MKL // is used. Now IDs for MKL are 8 more than TF. 
- EXPECT_EQ(29, cm->AllocationId(node, 0)); -#else EXPECT_EQ(21, cm->AllocationId(node, 0)); -#endif - } else { -#ifdef INTEL_MKL - EXPECT_EQ(30, cm->AllocationId(node, 0)); #else + EXPECT_EQ(13, cm->AllocationId(node, 0)); +#endif // INTEL_MKL && ENABLE_MKL + } else { +#if defined(INTEL_MKL) && defined(ENABLE_MKL) EXPECT_EQ(22, cm->AllocationId(node, 0)); -#endif +#else + EXPECT_EQ(14, cm->AllocationId(node, 0)); +#endif // INTEL_MKL && ENABLE_MKL } } EXPECT_LE(0, cm->MaxExecutionTime(node)); diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD index be5f3bae3a..7b74c67c85 100644 --- a/tensorflow/core/common_runtime/eager/BUILD +++ b/tensorflow/core/common_runtime/eager/BUILD @@ -147,10 +147,11 @@ tf_cuda_library( "kernel_and_device.h", ], visibility = ["//tensorflow:internal"], - deps = select({ + deps = [ + "@farmhash_archive//:farmhash", + ] + select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib_lite", - "//util/hash:farmhash_fingerprint", ], "//conditions:default": [ "//tensorflow/core:core_cpu_lib", @@ -219,13 +220,13 @@ tf_cuda_library( visibility = ["//tensorflow:internal"], deps = [ ":kernel_and_device", + "@farmhash_archive//:farmhash", # Only the TF_AttrType enum is required, so pull in just the C headers. # TODO(b/113535673): Break this dependency and avoid the C header completely. 
"//tensorflow/c:c_api_headers", ] + select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib_lite", - "//util/hash:farmhash_fingerprint", ], "//conditions:default": [ "//tensorflow/core:core_cpu", diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc index cf1cd4134e..5c8369de87 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder.cc +++ b/tensorflow/core/common_runtime/eager/attr_builder.cc @@ -136,6 +136,22 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m, m->insert(*it); } } + // For any attr-value pairs that exist in the op def (from op registry) but + // not `m`, fill them into `m`, so that we can run a TFE_Op without having to + // specify all the default attr values (e.g. for matmul, the `transpose_a` + // attr defaults to false). + const OpDef* op_def = nullptr; + Status s = OpDefForOp(op_name_.c_str(), &op_def); + // This is expected, if this op is a custom function, and is therefore not + // present in the op registry. + if (!s.ok()) return; + + DCHECK(op_def); + for (const auto& attr_def : op_def->attr()) { + if (attr_def.has_default_value() && !m->count(attr_def.name())) { + SetInAttrValueMap(m, attr_def.name(), attr_def.default_value()); + } + } } const NodeDef& AttrBuilder::BuildNodeDef() { diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h index cbe6a1cb50..c114ea4ba0 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder.h +++ b/tensorflow/core/common_runtime/eager/attr_builder.h @@ -110,6 +110,12 @@ class AttrBuilder { using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>; void MayBeInitializeNodeDef(); + // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as + // well as any default attr-value pairs from the associated op_def, if there + // is one. 
+ // + // If `include_those_in_node_def` is true, also include any attr-value pairs + // from `node_def_`. void FillAttrValueMap(AttrValueMap* m, bool include_those_in_node_def) const; template <class T> diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 18420b60fd..f23cefb33d 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -70,7 +70,9 @@ EagerContext::EagerContext(const SessionOptions& opts, async_default_(async), log_memory_(LogMemory::IsEnabled()), env_(opts.env), - use_send_tensor_rpc_(false) { + use_send_tensor_rpc_(false), + pin_small_ops_to_cpu_(ReadBoolFromEnvVar( + "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", true)) { if (device_mgr_owned) { local_device_manager_.reset(device_mgr); local_unowned_device_manager_ = nullptr; diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h index 5ed6057ec6..15eeaa8066 100644 --- a/tensorflow/core/common_runtime/eager/context.h +++ b/tensorflow/core/common_runtime/eager/context.h @@ -202,6 +202,7 @@ class EagerContext { // EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used // instead (which in-turn use WorkerService.RecvTensor RPCs). 
bool UseSendTensorRPC() { return use_send_tensor_rpc_; } + bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; } private: void InitDeviceMapAndAsync(); @@ -293,6 +294,7 @@ class EagerContext { #endif bool use_send_tensor_rpc_; + const bool pin_small_ops_to_cpu_; }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc index 1bc63616d0..a52f933d75 100644 --- a/tensorflow/core/common_runtime/eager/execute.cc +++ b/tensorflow/core/common_runtime/eager/execute.cc @@ -579,19 +579,23 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals, return Status::OK(); #endif } -} // namespace -Status EagerExecute(EagerOperation* op, - gtl::InlinedVector<TensorHandle*, 2>* retvals, - int* num_retvals) { - // Ensure all resource-touching ops run in the device the resource is, - // regardless of anything else that has been specified. This is identical to - // the graph mode behavior. +// The Op device may be updated if: +// - A resource touching input is specified: all resource-touching ops run in +// the device the resource is, regardless of anything else that has been +// specified. This is identical to the graph mode behavior. +// +// - All op inputs are on the CPU, small (<64 elements) and integers +// (int32/int64). This can be disabled by setting the environment variable +// "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false". 
+Status MaybeUpdateOpDevice(EagerOperation* op) { EagerContext* ctx = op->EagerContext(); + bool device_set_for_resource_variable = false; + bool all_inputs_eligible_for_cpu_pinning = ctx->PinSmallOpsToCPU(); + for (int i = 0; i < op->Inputs().size(); ++i) { Device* input_op_device = nullptr; - auto status = op->Inputs()[i]->OpDevice(&input_op_device); - if (!status.ok()) return status; + TF_RETURN_IF_ERROR(op->Inputs()[i]->OpDevice(&input_op_device)); VLOG(2) << "for op " << op->Name() << " input " << i << " " << DataTypeString(op->Inputs()[i]->dtype) << " " << (input_op_device == nullptr ? "cpu" : input_op_device->name()) @@ -603,8 +607,53 @@ Status EagerExecute(EagerOperation* op, << d->name() << " because input #" << i << " is a resource in this device."; op->SetDevice(d); + + device_set_for_resource_variable = true; + all_inputs_eligible_for_cpu_pinning = false; + } else if (all_inputs_eligible_for_cpu_pinning) { + TensorHandle* handle = op->Inputs()[i]; + + // Input is on CPU. + if (input_op_device != nullptr && input_op_device != ctx->HostCPU()) { + all_inputs_eligible_for_cpu_pinning = false; + continue; + } + + if (handle->dtype != DataType::DT_INT32 && + handle->dtype != DataType::DT_INT64) { + all_inputs_eligible_for_cpu_pinning = false; + continue; + } + + int64 num_elements; + TF_RETURN_IF_ERROR(handle->NumElements(&num_elements)); + if (num_elements > 64) { + all_inputs_eligible_for_cpu_pinning = false; + } } } + + // Ops without inputs are usually ops that generate a tensor in some way and + // usually require being present on whatever device they are scheduled on + // - for e.g. VarHandleOp or _Recv). + // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for + // an op, but there is a GPU kernel? 
+ if (!op->Inputs().empty() && all_inputs_eligible_for_cpu_pinning) { + VLOG(1) << "Forcing op " << op->Name() + << " to be on the CPU since all input tensors have an " + "int32/int64 dtype, and are small (less than 64 elements)."; + op->SetDevice(ctx->HostCPU()); + } + + return Status::OK(); +} +} // namespace + +Status EagerExecute(EagerOperation* op, + gtl::InlinedVector<TensorHandle*, 2>* retvals, + int* num_retvals) { + TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op)); + bool op_is_local = IsLocal(op->EagerContext(), op->Device()); if (op_is_local) { diff --git a/tensorflow/core/common_runtime/eval_const_tensor.cc b/tensorflow/core/common_runtime/eval_const_tensor.cc index c1542f1f57..87749da7af 100644 --- a/tensorflow/core/common_runtime/eval_const_tensor.cc +++ b/tensorflow/core/common_runtime/eval_const_tensor.cc @@ -113,6 +113,13 @@ Status TryToInferTensorOutputFromInputShapes(const Edge& edge, return Status::OK(); } +// Returns true if 'node' has a registered CPU kernel. +bool HasCpuKernel(const Node& node) { + return FindKernelDef(DeviceType(DEVICE_CPU), node.def(), /*def=*/nullptr, + /*kernel_class_name=*/nullptr) + .ok(); +} + // Extracts the subgraph ending at 'target_node' that is statically computable // and inserts into 'out_graph'. If statically computable, 'is_constant_graph' // will be set to true. @@ -136,6 +143,12 @@ Status ExtractConstantSubgraph( return Status::OK(); } + // Since constant-folding runs on the CPU, do not attempt to constant-fold + // operators that have no CPU kernel. + if (!HasCpuKernel(target_node)) { + return Status::OK(); + } + // TODO(skyewm): should more of the filtering applied in input nodes below be // applied to target_node here? @@ -201,6 +214,11 @@ Status ExtractConstantSubgraph( return Status::OK(); } + if (!HasCpuKernel(*current_node)) { + *is_constant_graph = false; + return Status::OK(); + } + // If there is nothing more to recurse down, see if // the generator node is a constant. 
if (current_node->num_inputs() == 0) { diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc index 2c48084cab..40ec1502da 100644 --- a/tensorflow/core/common_runtime/executor.cc +++ b/tensorflow/core/common_runtime/executor.cc @@ -54,6 +54,7 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/context.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/mutex.h" @@ -1240,6 +1241,7 @@ class ExecutorState { StepStatsCollectorInterface* const stats_collector_; const tracing::TraceCollector* const trace_collector_; const tracing::EventCollector* const event_collector_; + Context context_; // QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper // instead of a pointer? (avoids having to delete). @@ -1367,6 +1369,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl) trace_collector_(tracing::GetTraceCollector()), event_collector_( tracing::GetEventCollector(tracing::EventCategory::kCompute)), + context_(ContextKind::kThread), slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper), call_frame_(args.call_frame), impl_(impl), @@ -1586,6 +1589,7 @@ bool MightTrace(const NodeItem& item, } void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) { + WithContext wc(context_); const GraphView& gview = impl_->gview_; TaggedNodeSeq ready; TaggedNodeReadyQueue inline_ready; diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h index 6cd4fd22ea..34bf73972f 100644 --- a/tensorflow/core/common_runtime/executor.h +++ b/tensorflow/core/common_runtime/executor.h @@ -97,12 +97,6 @@ class Executor { typedef std::function<void()> Closure; typedef std::function<void(Closure)> Runner; Runner runner = 
nullptr; - - // A callback that is invoked each time a node has finished executing. - typedef std::function<Status(const string& node_name, const int output_slot, - const Tensor* tensor, const bool is_ref, - OpKernelContext* ctx)> - NodeOutputsCallback; }; typedef std::function<void(const Status&)> DoneCallback; virtual void RunAsync(const Args& args, DoneCallback done) = 0; diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc index 96ecfb41d4..37a979a8f1 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.cc +++ b/tensorflow/core/common_runtime/graph_optimizer.cc @@ -38,7 +38,8 @@ void GraphOptimizer::Optimize( std::unique_ptr<Graph>* graph, const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map, - const std::function<bool(const Node*)>& cse_consider_fn) { + const std::function<bool(const Node*)>& cse_consider_fn, + const std::function<bool(const Node*)>& cf_consider_fn) { Graph* g = graph->get(); DumpGraph("Initial", g); @@ -62,6 +63,7 @@ void GraphOptimizer::Optimize( if (opts_.do_constant_folding()) { ConstantFoldingOptions cf_opts; cf_opts.shape_map = shape_map; + cf_opts.consider = cf_consider_fn; if (opts_.max_folded_constant_in_bytes() > 0) { cf_opts.max_constant_size_in_bytes = opts_.max_folded_constant_in_bytes(); diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h index 80246281cd..789cc56942 100644 --- a/tensorflow/core/common_runtime/graph_optimizer.h +++ b/tensorflow/core/common_runtime/graph_optimizer.h @@ -45,12 +45,15 @@ class GraphOptimizer { // // If cse_consider_fn is not null then only nodes for which cse_consider_fn // returns true will be considered for CSE. + // If cf_consider_fn is not null then only nodes for which cf_consider_fn + // returns true will be considered for CF. 
void Optimize( FunctionLibraryRuntime* runtime, Env* env, Device* device, std::unique_ptr<Graph>* graph, const std::unordered_map<string, std::vector<PartialTensorShape>>* shape_map, - const std::function<bool(const Node*)>& cse_consider_fn = nullptr); + const std::function<bool(const Node*)>& cse_consider_fn = nullptr, + const std::function<bool(const Node*)>& cf_consider_fn = nullptr); const OptimizerOptions& options() { return opts_; } diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc index dfce7c23e7..9306386117 100644 --- a/tensorflow/core/common_runtime/lower_if_op.cc +++ b/tensorflow/core/common_runtime/lower_if_op.cc @@ -38,11 +38,12 @@ class CondBuilder { public: enum Branch { kElseBranch = 0, kThenBranch = 1 }; - // Create a CondBuilder to create the lowering of If op. that has then and + // Create a CondBuilder to create the lowered form of `if_op` with then and // else functions named `then_fn_name` and `else_fn_name` respectively in the - // given graph. + // `graph`. The functions should be available in `flib`. CondBuilder(Node* if_op, const string& then_fn_name, - const string& else_fn_name, Graph* graph); + const string& else_fn_name, const FunctionLibraryDefinition& flib, + Graph* graph); // Constructs the basic conditional control flow using switch and merge nodes. 
Status CreatePivotNodes(); @@ -89,6 +90,7 @@ class CondBuilder { Node* then_call_node_; Node* else_call_node_; Graph* graph_; + const FunctionLibraryDefinition& flib_; string name_; NodeBuilder then_call_builder_; @@ -96,13 +98,17 @@ class CondBuilder { }; CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name, - const string& else_fn_name, Graph* graph) + const string& else_fn_name, + const FunctionLibraryDefinition& flib, Graph* graph) : if_op_(if_op), graph_(graph), + flib_(flib), name_(if_op->name()), then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()), else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) { TF_CHECK_OK(if_op_->input_node(0, &pred_)); + then_call_builder_.Device(if_op_->requested_device()); + else_call_builder_.Device(if_op_->requested_device()); } Status CondBuilder::CreatePivotNodes() { @@ -113,15 +119,18 @@ Status CondBuilder::CreatePivotNodes() { NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry()) .Input(NodeOut(pred_, 0)) .Input(NodeOut(pred_, 0)) + .Device(if_op_->requested_device()) .Finalize(graph_, &switch_pred)); control_predecessor_ = switch_pred; TF_RETURN_IF_ERROR( NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry()) .Input(switch_pred, kElseBranch) + .Device(if_op_->requested_device()) .Finalize(graph_, &pivot_f_)); TF_RETURN_IF_ERROR( NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry()) .Input(switch_pred, kThenBranch) + .Device(if_op_->requested_device()) .Finalize(graph_, &pivot_t_)); return Status::OK(); } @@ -136,6 +145,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) { NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry()) .Input(src, src_output) .Input(pred_, 0) + .Device(if_op_->requested_device()) .Finalize(graph_, &input)); then_call_builder_.Input(input, kThenBranch); else_call_builder_.Input(input, kElseBranch); @@ -174,6 +184,7 @@ Status CondBuilder::AddOutputs() { TF_RETURN_IF_ERROR( 
NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry()) .Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)}) + .Device(if_op_->requested_device()) .Finalize(graph_, &merges[i])); outputs_[i] = NodeOut(merges[i], 0); } @@ -193,15 +204,15 @@ Status CondBuilder::AddOutputs() { return Status::OK(); } -Status InlineCallInGraph(Node* n, Graph* g) { - const auto& lib = g->flib_def(); - const FunctionDef* fdef = lib.Find(n->type_string()); +Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib, + Graph* g) { + const FunctionDef* fdef = flib.Find(n->type_string()); CHECK(fdef != nullptr); FunctionBody* fbody; TF_RETURN_IF_ERROR( - FunctionDefToBodyHelper(*fdef, n->attrs(), &lib, - [&lib](const string& op, const OpDef** sig) { - return lib.LookUpOpDef(op, sig); + FunctionDefToBodyHelper(*fdef, n->attrs(), &flib, + [&flib](const string& op, const OpDef** sig) { + return flib.LookUpOpDef(op, sig); }, &fbody)); // TODO(jpienaar): Improve this interface to make the need to delete it @@ -214,13 +225,13 @@ Status InlineCallInGraph(Node* n, Graph* g) { Status CondBuilder::BuildLoweredIfOutput() { // Build the identity node output. 
NodeBuilder ib(name_, "IdentityN"); - ib.Input(outputs_); + ib.Input(outputs_).Device(if_op_->requested_device()); return ib.Finalize(graph_, &lowered_if_output_); } Status CondBuilder::InlineCallNodes() { - TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, graph_)); - TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, graph_)); + TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, flib_, graph_)); + TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, flib_, graph_)); return Status::OK(); } @@ -240,6 +251,12 @@ Status LowerIfOpPass::Run(const GraphOptimizationPassOptions& options) { return errors::Internal("Lowering If op requires a graph to be available."); } + FunctionLibraryDefinition* flib = options.flib_def; + if (flib == nullptr) { + return errors::Internal( + "Lowering If op requires a FunctionLibraryDefinition to be available."); + } + // Match all the nodes that need to be rewritten. gtl::InlinedVector<Node*, 2> matches; for (Node* n : g->op_nodes()) { @@ -251,12 +268,14 @@ Status LowerIfOpPass::Run(const GraphOptimizationPassOptions& options) { } } for (Node* n : matches) { - TF_RETURN_IF_ERROR(RewriteNode(n, g)); + TF_RETURN_IF_ERROR(RewriteNode(n, *flib, g)); } return Status::OK(); } -Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) { +Status LowerIfOpPass::RewriteNode(Node* n, + const FunctionLibraryDefinition& flib, + Graph* g) { const AttrValue* then_attr = n->attrs().Find("then_branch"); if (then_attr == nullptr) { return errors::InvalidArgument("Then branch function missing"); @@ -266,7 +285,8 @@ Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) { return errors::InvalidArgument("Else branch function missing"); } - CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), g); + CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), flib, + g); TF_RETURN_IF_ERROR(cb.CreatePivotNodes()); TF_RETURN_IF_ERROR(cb.AddInputs()); TF_RETURN_IF_ERROR(cb.AddOutputs()); diff --git 
a/tensorflow/core/common_runtime/lower_if_op.h b/tensorflow/core/common_runtime/lower_if_op.h index a9ef39ae5c..5ab1123e3f 100644 --- a/tensorflow/core/common_runtime/lower_if_op.h +++ b/tensorflow/core/common_runtime/lower_if_op.h @@ -29,8 +29,9 @@ class LowerIfOpPass : public GraphOptimizationPass { Status Run(const GraphOptimizationPassOptions& options) override; private: - // Rewrite the given If node `n` in graph `g` to use the switch-merge form. - Status RewriteNode(Node* n, Graph* g); + // Rewrite the given If node `n` in graph `g` to use the switch-merge + // form. `flib` should contain the branch functions referenced by `n`. + Status RewriteNode(Node* n, const FunctionLibraryDefinition& flib, Graph* g); }; } // namespace tensorflow diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc index 319a617b32..044a355d06 100644 --- a/tensorflow/core/common_runtime/lower_if_op_test.cc +++ b/tensorflow/core/common_runtime/lower_if_op_test.cc @@ -36,9 +36,7 @@ namespace tensorflow { namespace { Status Rewrite(std::unique_ptr<Graph>* graph) { - FunctionDefLibrary flib; - FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib); - + FunctionLibraryDefinition flib_def((*graph)->flib_def()); GraphOptimizationPassOptions opt_options; opt_options.graph = graph; opt_options.flib_def = &flib_def; diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h index 538a70668a..429b19599b 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h @@ -251,6 +251,7 @@ class MklCPUAllocator : public Allocator { // max_alloc_size from large_size_allocator would be the maximum // size allocated by MklCPUAllocator. 
stats->max_alloc_size = l_stats.max_alloc_size; + stats->bytes_limit = std::max(s_stats.bytes_limit, l_stats.bytes_limit); } void ClearStats() override { diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc index a67411cd2e..e08ab57638 100644 --- a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc +++ b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #include "tensorflow/core/common_runtime/mkl_cpu_allocator.h" @@ -50,4 +50,4 @@ TEST(MKLBFCAllocatorTest, TestMaxLimit) { } // namespace tensorflow -#endif // INTEL_MKL +#endif // INTEL_MKL && ENABLE_MKL diff --git a/tensorflow/core/common_runtime/process_util.cc b/tensorflow/core/common_runtime/process_util.cc index a5d31b75c7..e1dc08d645 100644 --- a/tensorflow/core/common_runtime/process_util.cc +++ b/tensorflow/core/common_runtime/process_util.cc @@ -28,6 +28,7 @@ limitations under the License. 
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/util.h" namespace tensorflow { @@ -56,24 +57,26 @@ int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) { const int32 inter_op = options.config.inter_op_parallelism_threads(); if (inter_op != 0) return inter_op; #ifdef INTEL_MKL - // MKL library executes ops in parallel using OMP threads - // Set inter_op conservatively to avoid thread oversubscription that could - // lead to severe perf degradations and OMP resource exhaustion - int mkl_intra_op = 1; + if (!DisableMKL()) { + // MKL library executes ops in parallel using OMP threads + // Set inter_op conservatively to avoid thread oversubscription that could + // lead to severe perf degradations and OMP resource exhaustion + int mkl_intra_op = 1; #ifdef _OPENMP - mkl_intra_op = omp_get_max_threads(); + mkl_intra_op = omp_get_max_threads(); #endif // _OPENMP - CHECK_GE(mkl_intra_op, 1); - const int32 mkl_inter_op = std::max( - (port::NumSchedulableCPUs() + mkl_intra_op - 1) / mkl_intra_op, 2); - VLOG(0) << "Creating new thread pool with default inter op setting: " - << mkl_inter_op - << ". Tune using inter_op_parallelism_threads for best performance."; - return mkl_inter_op; -#else + DCHECK_GE(mkl_intra_op, 1); + const int32 mkl_inter_op = std::max( + (port::NumSchedulableCPUs() + mkl_intra_op - 1) / mkl_intra_op, 2); + VLOG(0) + << "Creating new thread pool with default inter op setting: " + << mkl_inter_op + << ". Tune using inter_op_parallelism_threads for best performance."; + return mkl_inter_op; + } +#endif // INTEL_MKL // Default to using the number of cores available in the process. 
return port::NumSchedulableCPUs(); -#endif // INTEL_MKL } thread::ThreadPool* NewThreadPoolFromSessionOptions( diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc index a81f8650bf..b1fe928ba7 100644 --- a/tensorflow/core/common_runtime/ring_reducer.cc +++ b/tensorflow/core/common_runtime/ring_reducer.cc @@ -41,6 +41,16 @@ limitations under the License. // Set true for greater intelligibility of debug mode log messages. #define READABLE_KEYS false +// RingReduce algorithm exchanges chunks of tensor between devices. The chunk +// size depends on the number of subdivisions specified in the algorithm. If +// the user does not specify the number of subdivisions, we infer the number +// dynamically so that the resulting chunk size does not exceed +// kMaxChunkSizeBytes, empirically set at 4 MiB. +constexpr size_t kMaxChunkSizeBytes = (4 * 1024 * 1024); +// kMaxSubdivsPerDev is used to give an upper bound on the number of +// subdivisions dynamically generated. A reasonable value would be a small +// multiple of the number of NICs adjacent to each device. +constexpr int kMaxSubdivsPerDevice = 2; namespace tensorflow { namespace { @@ -92,7 +102,62 @@ RingReducer::RingReducer() RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); } +Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) { + if (col_params->instance.shape.num_elements() == 0) { + return errors::Internal("shape in CollectiveParams should be non-empty"); + } + const int kAvgDevPerTask = + col_params->group.group_size / col_params->group.num_tasks; + const int kMaxNumSubdivs = kMaxSubdivsPerDevice * kAvgDevPerTask; + if (kMaxNumSubdivs <= 0) { + return errors::Internal("Unexpected kMaxNumSubdivs ", kMaxNumSubdivs, + " in RingReducer"); + } + // NOTE(ayushd): If no subdiv_offsets have been specified, dynamically add + // as many offsets as needed so that the size of tensor chunks <= + // kMaxChunkSizeBytes. 
Empirically, chunks that are too small or too large + // lead to worse performance. + int num_subdivs = 0; + const size_t tensor_size = col_params->instance.shape.num_elements() * + DataTypeSize(col_params->instance.data_type); + size_t chunk_size; + do { + ++num_subdivs; + int num_chunks = col_params->group.group_size * num_subdivs; + chunk_size = tensor_size / num_chunks; + VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks + << " chunk_size " << chunk_size; + } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs); + if (num_subdivs <= 0) { + return errors::Internal("Unexpected num_subdivs ", num_subdivs, + " in RingReducer"); + } + + int subdiv_stride = kAvgDevPerTask / num_subdivs; + if (subdiv_stride == 0) subdiv_stride = 1; + col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs); + for (int sdi = 0; sdi < num_subdivs; ++sdi) { + int subdiv_offset = subdiv_stride * sdi; + if (sdi % 2 == 1) subdiv_offset *= -1; + col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset); + } + + if (VLOG_IS_ON(2)) { + string subdiv_buf; + for (const int subdiv_offset : + col_params->instance.impl_details.subdiv_offsets) { + strings::StrAppend(&subdiv_buf, " ", subdiv_offset); + } + VLOG(2) << "Dynamically generated " << num_subdivs + << " subdiv_offsets:" << subdiv_buf << " tensor_size " + << tensor_size << " chunk_size " << chunk_size; + } + + return Status::OK(); +} + Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) { + // TODO(b/113171733): change CHECKs to return errors. 
CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE); CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce"); const string& device_name = @@ -123,12 +188,11 @@ Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) { dev_per_task.push_back(dev_count); CHECK_EQ(col_params->group.num_tasks, dev_per_task.size()); - // Generate a ring permutation for each requested offset. if (col_params->instance.impl_details.subdiv_offsets.empty()) { - return errors::Internal( - "Subdiv offsets should be non-empty for ring reducer, size=", - col_params->instance.impl_details.subdiv_offsets.size()); + TF_RETURN_IF_ERROR(GenerateSubdivsInCollectiveParams(col_params)); } + + // Generate a ring permutation for requested offset. VLOG(2) << "Setting up perms for col_params " << col_params << " subdiv_permutations " << &col_params->instance.impl_details.subdiv_permutations; @@ -646,7 +710,8 @@ bool RingReducer::RunAsyncParts() { case RF_SEND: --send_pending_count; break; - default: {} // Ignore any other actions + default: { + } // Ignore any other actions } } } diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc index 28df85399e..75aba43572 100644 --- a/tensorflow/core/common_runtime/ring_reducer_test.cc +++ b/tensorflow/core/common_runtime/ring_reducer_test.cc @@ -549,37 +549,38 @@ class RingReducerTest : public ::testing::Test { int32 reduce_counter_ GUARDED_BY(mu_) = 0; }; -TEST_F(RingReducerTest, InitializeParams) { - static const int kNumDevsPerTask = 8; - static const int kNumTasks = 3; - static const int kNumDevs = kNumDevsPerTask * kNumTasks; +CollectiveParams SetUpCollectiveParams(const int num_devs_per_task, + const int num_tasks) { CollectiveParams cp; - std::vector<string> device_names; - std::vector<string> task_names; + const int kNumDevs = num_devs_per_task * num_tasks; cp.group.group_key = 1; cp.group.group_size = kNumDevs; cp.group.device_type = 
DeviceType("GPU"); - cp.group.num_tasks = kNumTasks; + cp.group.num_tasks = num_tasks; cp.instance.instance_key = 3; cp.instance.type = REDUCTION_COLLECTIVE; cp.instance.data_type = DataType(DT_FLOAT); - cp.instance.shape = TensorShape({5}); + cp.instance.shape = TensorShape({kNumDevs}); cp.instance.impl_details.collective_name = "RingReduce"; cp.instance.impl_details.subdiv_offsets.push_back(0); cp.is_source = false; for (int i = 0; i < kNumDevs; ++i) { - int task_id = i / kNumDevsPerTask; - int dev_id = i % kNumDevsPerTask; + int task_id = i / num_devs_per_task; + int dev_id = i % num_devs_per_task; string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id); - task_names.push_back(task_name); string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id); - device_names.push_back(device_name); cp.instance.task_names.push_back(task_name); cp.instance.device_names.push_back(device_name); } + return cp; +} - int test_rank = 0; - cp.default_rank = test_rank; +TEST_F(RingReducerTest, InitializeParams) { + const int kNumDevsPerTask = 8; + const int kNumTasks = 3; + CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + + cp.default_rank = 0; cp.instance.impl_details.subdiv_offsets = {0, 4}; RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, @@ -588,8 +589,15 @@ TEST_F(RingReducerTest, InitializeParams) { 8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}}, {0, 4}); - test_rank = 3; - cp.default_rank = test_rank; + cp.instance.impl_details.subdiv_offsets = {0, -4}; + RunSubdivPermsTest(&cp, + {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, + 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}}, + {0, 3}); + + cp.default_rank = 3; cp.instance.impl_details.subdiv_offsets = {3, -3}; RunSubdivPermsTest(&cp, {{3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14, @@ -599,6 +607,49 @@ TEST_F(RingReducerTest, InitializeParams) { {0, 1}); } 
+TEST_F(RingReducerTest, AutomaticSubdivs) { + const int kNumDevsPerTask = 8; + const int kNumTasks = 3; + const int kNumDevs = kNumDevsPerTask * kNumTasks; + CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + + // Test automatic generation of subdiv offsets. + cp.default_rank = 0; + cp.instance.impl_details.subdiv_offsets.clear(); + RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}}, + {0}); + + // Set shape so that with 2 subdivs chunk_size is 3 MiB. This should cause 2 + // offsets, {0, -4}, to be generated. + { + int num_subdivs = 2; + int num_chunks = kNumDevs * num_subdivs; + size_t chunk_size = 3 * 1048576; // 3 MB + size_t tensor_size = chunk_size * num_chunks; + cp.instance.shape = + TensorShape({static_cast<int64>(tensor_size / DataTypeSize(DT_FLOAT))}); + } + cp.instance.impl_details.subdiv_offsets.clear(); + RunSubdivPermsTest(&cp, + {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}, + {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8, + 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}}, + {0, 3}); +} + +TEST_F(RingReducerTest, AutomaticSubdivUpperBound) { + const int kNumDevsPerTask = 1; + const int kNumTasks = 4; + CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks); + + cp.default_rank = 0; + cp.instance.impl_details.subdiv_offsets.clear(); + cp.instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)}); + RunSubdivPermsTest(&cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0}); +} + // TODO(b/113171733): change to use TEST_P. #define DEF_TEST(B, T, W, D, S, L, A) \ TEST_F(RingReducerTest, \ diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc index 0fbc20b34b..6404d8bc6a 100644 --- a/tensorflow/core/common_runtime/threadpool_device.cc +++ b/tensorflow/core/common_runtime/threadpool_device.cc @@ -29,6 +29,7 @@ limitations under the License. 
#include "tensorflow/core/platform/tracing.h" #include "tensorflow/core/platform/types.h" #include "tensorflow/core/public/session_options.h" +#include "tensorflow/core/util/util.h" #ifdef INTEL_MKL #ifdef _OPENMP @@ -49,6 +50,8 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options, allocator_(allocator), scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) { #ifdef INTEL_MKL + // Early return when MKL is disabled + if (DisableMKL()) return; #ifdef _OPENMP const char* user_omp_threads = getenv("OMP_NUM_THREADS"); if (user_omp_threads == nullptr) { @@ -113,8 +116,12 @@ class MklCPUAllocatorFactory : public AllocatorFactory { } }; -REGISTER_MEM_ALLOCATOR("MklCPUAllocator", 200, MklCPUAllocatorFactory); +#ifdef ENABLE_MKL +REGISTER_MEM_ALLOCATOR("MklCPUAllocator", (DisableMKL() ? 50 : 200), + MklCPUAllocatorFactory); +#endif // ENABLE_MKL + } // namespace -#endif +#endif // INTEL_MKL } // namespace tensorflow diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 20a07d86a2..50403b4004 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -1306,6 +1306,113 @@ Status RandomShape(shape_inference::InferenceContext* c) { return Status::OK(); } +namespace { + +// This SliceHelper processes the output shape of the `slice` +// when the tensor of `sizes` is available. 
+template <typename T> +Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, + const Tensor* sizes_value, + std::vector<DimensionHandle>* dims) { + auto sizes_vec = sizes_value->vec<T>(); + for (int i = 0; i < sizes_value->NumElements(); ++i) { + DimensionHandle dim = c->Dim(c->input(0), i); + if (sizes_vec(i) != -1) { + auto dim_val = c->Value(dim); + if (sizes_vec(i) < 0) { + return errors::InvalidArgument( + "Out of bounds slicing on dimension ", i, " of length ", dim_val, + ": sizes vector cannot be < -1, but was ", sizes_vec(i)); + } + + dims->emplace_back(c->MakeDim(sizes_vec(i))); + } else { + DimensionHandle result; + TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result)); + dims->emplace_back(result); + } + } + + return Status::OK(); +} +} // namespace + +Status SliceShape(InferenceContext* c) { + ShapeHandle input = c->input(0); + ShapeHandle begin_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); + ShapeHandle sizes_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape)); + + // Merge to check compatibility of begin and sizes tensors. + TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape)); + + DimensionHandle ndims = c->Dim(begin_shape, 0); + if (c->ValueKnown(ndims)) { + TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input)); + } + + // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known + // values, even though the `begin` value does not represent a shape. + ShapeHandle begin_value; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value)); + + // We check the tensor value here and will only use + // `MakeShapeFromShapeTensor` when `sizes_value` is null. + // The reason is that `sizes` might contain -1, which can't + // be represented (-1 in the ShapeHandle would mean "unknown"). 
+ const Tensor* sizes_value = c->input_tensor(2); + + if (sizes_value != nullptr) { + TF_RETURN_IF_ERROR( + c->WithRank(begin_value, sizes_value->NumElements(), &begin_value)); + std::vector<DimensionHandle> dims; + // If the begin and sizes tensors are available, then + // we can be precise about the shape of the output. + if (sizes_value->dtype() == DT_INT64) { + TF_RETURN_IF_ERROR( + SliceHelper<int64>(c, begin_value, sizes_value, &dims)); + } else { + TF_RETURN_IF_ERROR( + SliceHelper<int32>(c, begin_value, sizes_value, &dims)); + } + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); + } else { + // In case `sizes` is not available (`sizes_value` is null), + // we could try to use `MakeShapeFromShapeTensor` here. + // If sizes contain -1, we will simply consider it as `Unknown`. + // This is less than ideal but still an improvement of shape inference. + // The following is an example that returns [None, 1, None] with this + // code path: + // z = tf.zeros((1, 2, 3)) + // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1]) + // m.get_shape().as_list() + ShapeHandle sizes_value; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value)); + if (c->RankKnown(sizes_value)) { + TF_RETURN_IF_ERROR( + c->WithRank(begin_value, c->Rank(sizes_value), &begin_value)); + std::vector<DimensionHandle> dims; + dims.reserve(c->Rank(sizes_value)); + for (int i = 0; i < c->Rank(sizes_value); ++i) { + dims.emplace_back(c->Dim(sizes_value, i)); + } + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); + } + // We might know the rank of the input. + if (c->RankKnown(input)) { + c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); + return Status::OK(); + } else { + return shape_inference::UnknownShape(c); + } + } + + return Status::OK(); +} + Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, ShapeHandle values_shape, ShapeHandle shape_shape) { // Validate ranks. 
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index e6f9f935f9..3a496e06ae 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -293,6 +293,9 @@ inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) { // Shape function for random operations. Status RandomShape(shape_inference::InferenceContext* c); +// Shape function for Slice opertaions. +Status SliceShape(shape_inference::InferenceContext* c); + // Validates the 3 component tensors of a sparse tensor have the proper // shapes. This mimics SparseTensor.__init__ in python/framework/ops.py. Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape, diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h index 697e0604bf..964a7d5f8c 100644 --- a/tensorflow/core/framework/dataset.h +++ b/tensorflow/core/framework/dataset.h @@ -278,15 +278,8 @@ class IteratorContext { // Function call support. std::function<void(std::function<void()>)> runner = nullptr; - // A function that returns the current `StatsAggregator` instance to be - // used when recording statistics about the iterator. - // - // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator` - // is a property of the `IteratorResource` (which this class does not know - // about), and (ii) it can change after the `IteratorContext` has been - // created. Better suggestions are welcome! - std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter = - nullptr; + // The `StatsAggregator` object to record statistics about the iterator. + std::shared_ptr<StatsAggregator> stats_aggregator = nullptr; // The FunctionLibraryRuntime object to be used to make function calls. 
FunctionLibraryRuntime* lib = nullptr; @@ -320,13 +313,6 @@ class IteratorContext { return ¶ms_.runner; } - std::shared_ptr<StatsAggregator> stats_aggregator() { - if (params_.stats_aggregator_getter) { - return params_.stats_aggregator_getter(); - } else { - return nullptr; - } - } std::shared_ptr<const FunctionLibraryDefinition> function_library() { return params_.function_library; @@ -344,8 +330,8 @@ class IteratorContext { return params_.allocator_getter; } - std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter() { - return params_.stats_aggregator_getter; + std::shared_ptr<StatsAggregator> stats_aggregator() { + return params_.stats_aggregator; } std::shared_ptr<model::Model> model() { return params_.model; } @@ -657,15 +643,15 @@ class DatasetBaseIterator : public IteratorBase { // When performance modeling is enabled, this method adds a tunable parameter // to the model node corresponding to this iterator. // - // The performance modeling logic may use `value` to set the value of the + // The performance modeling logic may use `state` to set the value of the // tunable parameter at any point during the lifetime of this iterator. When - // it does, it notifies `cond_var`. + // it does, it acquires `state->mu` and notifies `state->cond_var`. 
void AddTunableParameter(IteratorContext* ctx, const string& name, - std::atomic<int64>* value, int64 min, int64 max, - condition_variable* cond_var) { + std::shared_ptr<model::SharedState> state, int64 min, + int64 max) { if (ctx->model()) { - ctx->model()->AddTunableParameter(prefix(), name, value, min, max, - cond_var); + ctx->model()->AddTunableParameter(prefix(), name, std::move(state), min, + max); } } diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc index a17959a448..20f957190b 100644 --- a/tensorflow/core/framework/function.cc +++ b/tensorflow/core/framework/function.cc @@ -1101,6 +1101,14 @@ Status FunctionLibraryDefinition::ReplaceFunction(const string& func, return Status::OK(); } +Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) { + mutex_lock l(mu_); + bool added; + TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name())); + TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added)); + return Status::OK(); +} + Status FunctionLibraryDefinition::RemoveFunction(const string& func) { const auto& i = function_defs_.find(func); if (i == function_defs_.end()) { diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h index e01eb7503d..4d6d68e214 100644 --- a/tensorflow/core/framework/function.h +++ b/tensorflow/core/framework/function.h @@ -331,6 +331,11 @@ class FunctionLibraryDefinition : public OpRegistryInterface { // a non-OK status if "func" was not found in the library, OK otherwise. Status ReplaceFunction(const string& func, const FunctionDef& fdef); + // Replaces the gradient corresponding to `grad.function_name()`. Returns + // a non-OK status if "grad.function_name()" was not found in the library, OK + // otherwise. + Status ReplaceGradient(const GradientDef& grad); + // Adds the functions and gradients in 'other' to this function library. // Duplicate functions and gradients are ignored. // This operation is atomic. 
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc index d5c203d276..0445c242e9 100644 --- a/tensorflow/core/framework/function_testlib.cc +++ b/tensorflow/core/framework/function_testlib.cc @@ -93,7 +93,6 @@ FunctionDef IsZero() { FunctionDef RandomUniform() { const Tensor kZero = test::AsScalar<int64>(0); - const Tensor kTen = test::AsScalar<int64>(10); return FDH::Define( // Name @@ -108,19 +107,11 @@ FunctionDef RandomUniform() { "Const", {}, {{"value", kZero}, {"dtype", DT_INT64}}}, - {{"random_uniform/min"}, - "Const", - {}, - {{"value", kZero}, {"dtype", DT_INT64}}}, - {{"random_uniform/max"}, - "Const", - {}, - {{"value", kTen}, {"dtype", DT_INT64}}}, {{"random_uniform"}, - "RandomUniformInt", - {}, - {{"T", DT_INT64}, - {"Tout", DT_INT64}, + "RandomUniform", + {"random_uniform/shape"}, + {{"T", DT_INT32}, + {"Tout", DT_FLOAT}, {"seed", 87654321}, {"seed2", 42}}}}); } diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc index b0330ec990..bfdb3a6658 100644 --- a/tensorflow/core/framework/model.cc +++ b/tensorflow/core/framework/model.cc @@ -296,12 +296,12 @@ void Model::AddProcessingTime(const string& name, int64 delta) { void Model::AddTunableParameter(const string& node_name, const string& parameter_name, - std::atomic<int64>* value, int64 min, int64 max, - condition_variable* cond_var) { + std::shared_ptr<SharedState> state, int64 min, + int64 max) { tf_shared_lock l(mu_); auto node = *gtl::FindOrNull(lookup_table_, node_name); DCHECK(node); - node->add_tunable_param(parameter_name, value, min, max, cond_var); + node->add_tunable_param(parameter_name, std::move(state), min, max); } // The optimization algorithm starts by setting all tunable parallelism @@ -311,54 +311,55 @@ void Model::AddTunableParameter(const string& node_name, // is less than or equal to the processing time needed to produce an element // divided by CPU budget. 
void Model::Optimize(int64 cpu_budget) { - tf_shared_lock lock(mu_); std::vector<std::shared_ptr<Model::Node::Tunable>> tunables; - const int64 processing_time = ProcessingTime(); - tunables = CollectTunables(); - for (auto tunable : tunables) { - tunable->value = 1; - } - while (true) { - const int64 output_time = OutputTime(); - bool all_tunables = true; - for (auto& tunable : tunables) { - if (tunable->value < tunable->max) { - all_tunables = false; + { + tf_shared_lock lock(mu_); + const int64 processing_time = ProcessingTime(); + tunables = CollectTunables(); + for (auto tunable : tunables) { + tunable->value = 1; + } + while (true) { + const int64 output_time = OutputTime(); + bool all_tunables = true; + for (auto& tunable : tunables) { + if (tunable->value < tunable->max) { + all_tunables = false; + break; + } + } + if (output_time < processing_time / cpu_budget || all_tunables) { break; } - } - if (output_time < processing_time / cpu_budget || all_tunables) { - break; - } - int64 best_delta = -1; - Model::Node::Tunable* best_tunable = nullptr; - for (auto& tunable : tunables) { - if (tunable->value == tunable->max) { - continue; + int64 best_delta = -1; + Model::Node::Tunable* best_tunable = nullptr; + for (auto& tunable : tunables) { + if (tunable->value == tunable->max) { + continue; + } + tunable->value++; + int64 delta = output_time - OutputTime(); + if (delta > best_delta) { + best_delta = delta; + best_tunable = tunable.get(); + } + tunable->value--; } - tunable->value++; - int64 delta = output_time - OutputTime(); - if (delta > best_delta) { - best_delta = delta; - best_tunable = tunable.get(); + if (!best_tunable) { + // NOTE: This can happen because we are performing the optimization + // while the model data is changing. If this becomes an issue, we should + // look into performing the optimization using a model snapshot. 
+ break; } - tunable->value--; + best_tunable->value++; } - if (!best_tunable) { - // NOTE: This can happen because we are performing the optimization - // while the model data is changing. If this becomes an issue, we should - // look into performing the optimization using a model snapshot. - break; - } - best_tunable->value++; } VLOG(2) << "Number of knobs: " << tunables.size(); for (auto& tunable : tunables) { VLOG(2) << "Setting tunable parameter: " << tunable->value; - tunable->value_ptr->store(tunable->value); - if (tunable->cond_var) { - tunable->cond_var->notify_all(); - } + mutex_lock l(*tunable->state->mu); + tunable->state->value = tunable->value; + tunable->state->cond_var->notify_all(); } } diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h index 26402f5cd3..eae0fa70e8 100644 --- a/tensorflow/core/framework/model.h +++ b/tensorflow/core/framework/model.h @@ -33,6 +33,19 @@ namespace tensorflow { namespace data { namespace model { +// Represents thread-safe state that can be shared between an input pipeline and +// the performance model. +struct SharedState { + public: + explicit SharedState(int64 value, std::shared_ptr<mutex> mu, + std::shared_ptr<condition_variable> cond_var) + : value(value), mu(std::move(mu)), cond_var(std::move(cond_var)) {} + + std::shared_ptr<mutex> mu; + std::shared_ptr<condition_variable> cond_var; + int64 value; +}; + // Abstract representation of a TensorFlow input pipeline that can be used // for collecting runtime information and optimizing performance. It collects // runtime information about execution of the input pipeline that is used to @@ -62,8 +75,8 @@ class Model { // Adds a tunable parameter for the given node. 
void AddTunableParameter(const string& node_name, const string& parameter_name, - std::atomic<int64>* value, int64 min, int64 max, - condition_variable* cond_var) LOCKS_EXCLUDED(mu_); + std::shared_ptr<SharedState> value, int64 min, + int64 max) LOCKS_EXCLUDED(mu_); // Runs optimization. void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_); @@ -109,13 +122,8 @@ class Model { public: // Represents a tunable parameter. struct Tunable { - Tunable(std::atomic<int64>* value, int64 min, int64 max, - condition_variable* cond_var) - : value(*value), - min(min), - max(max), - value_ptr(value), - cond_var(cond_var) {} + Tunable(std::shared_ptr<SharedState> state, int64 min, int64 max) + : value(state->value), min(min), max(max), state(std::move(state)) {} // Identifies the model value of the parameter. This can be different from // the actual value (e.g. during optimization search). @@ -127,12 +135,8 @@ class Model { // Identifies the maximum value of the parameter. int64 max; - // Points to the actual value of the parameter. Not owned. - std::atomic<int64>* value_ptr; - - // If non-null, this condition variable is notified when the model updates - // the actual value of the parameter (via `value_ptr`). Not owned. - condition_variable* cond_var; + // Shared state of the parameter. + std::shared_ptr<SharedState> state; }; Node(int64 id, const string& name, std::shared_ptr<Node> output) @@ -158,12 +162,12 @@ class Model { } // Adds a tunable parameter. - void add_tunable_param(const string& name, std::atomic<int64>* value, - int64 min, int64 max, condition_variable* cond_var) - LOCKS_EXCLUDED(mu_) { + void add_tunable_param(const string& name, + std::shared_ptr<SharedState> state, int64 min, + int64 max) LOCKS_EXCLUDED(mu_) { mutex_lock l(mu_); tunable_params_[name] = - std::make_shared<Tunable>(value, min, max, cond_var); + std::make_shared<Tunable>(std::move(state), min, max); } // Returns the unique node ID. 
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h index 187bfa2c88..0ff67554eb 100644 --- a/tensorflow/core/framework/node_def_util.h +++ b/tensorflow/core/framework/node_def_util.h @@ -17,7 +17,6 @@ limitations under the License. #define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_ #include <string> -#include <unordered_map> #include <vector> #include "tensorflow/core/framework/attr_value_util.h" diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h index 25f8de8dcc..81ed5f95f0 100644 --- a/tensorflow/core/framework/op.h +++ b/tensorflow/core/framework/op.h @@ -209,16 +209,16 @@ template <> class OpDefBuilderWrapper<true> { public: OpDefBuilderWrapper(const char name[]) : builder_(name) {} - OpDefBuilderWrapper<true>& Attr(StringPiece spec) { - builder_.Attr(spec); + OpDefBuilderWrapper<true>& Attr(string spec) { + builder_.Attr(std::move(spec)); return *this; } - OpDefBuilderWrapper<true>& Input(StringPiece spec) { - builder_.Input(spec); + OpDefBuilderWrapper<true>& Input(string spec) { + builder_.Input(std::move(spec)); return *this; } - OpDefBuilderWrapper<true>& Output(StringPiece spec) { - builder_.Output(spec); + OpDefBuilderWrapper<true>& Output(string spec) { + builder_.Output(std::move(spec)); return *this; } OpDefBuilderWrapper<true>& SetIsCommutative() { @@ -237,12 +237,12 @@ class OpDefBuilderWrapper<true> { builder_.SetAllowsUninitializedInput(); return *this; } - OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) { - builder_.Deprecated(version, explanation); + OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) { + builder_.Deprecated(version, std::move(explanation)); return *this; } - OpDefBuilderWrapper<true>& Doc(StringPiece text) { - builder_.Doc(text); + OpDefBuilderWrapper<true>& Doc(string text) { + builder_.Doc(std::move(text)); return *this; } OpDefBuilderWrapper<true>& SetShapeFn( diff --git 
a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc index 34a7a43d38..8a9bb63182 100644 --- a/tensorflow/core/framework/op_def_builder.cc +++ b/tensorflow/core/framework/op_def_builder.cc @@ -526,32 +526,32 @@ void FinalizeDoc(const string& text, OpDef* op_def, } // namespace -OpDefBuilder::OpDefBuilder(StringPiece op_name) { - op_def()->set_name(string(op_name)); // NOLINT +OpDefBuilder::OpDefBuilder(string op_name) { + op_def()->set_name(std::move(op_name)); } -OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) { - attrs_.emplace_back(spec.data(), spec.size()); +OpDefBuilder& OpDefBuilder::Attr(string spec) { + attrs_.push_back(std::move(spec)); return *this; } -OpDefBuilder& OpDefBuilder::Input(StringPiece spec) { - inputs_.emplace_back(spec.data(), spec.size()); +OpDefBuilder& OpDefBuilder::Input(string spec) { + inputs_.push_back(std::move(spec)); return *this; } -OpDefBuilder& OpDefBuilder::Output(StringPiece spec) { - outputs_.emplace_back(spec.data(), spec.size()); +OpDefBuilder& OpDefBuilder::Output(string spec) { + outputs_.push_back(std::move(spec)); return *this; } #ifndef TF_LEAN_BINARY -OpDefBuilder& OpDefBuilder::Doc(StringPiece text) { +OpDefBuilder& OpDefBuilder::Doc(string text) { if (!doc_.empty()) { errors_.push_back( strings::StrCat("Extra call to Doc() for Op ", op_def()->name())); } else { - doc_.assign(text.data(), text.size()); + doc_ = std::move(text); } return *this; } @@ -577,14 +577,14 @@ OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() { return *this; } -OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) { +OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) { if (op_def()->has_deprecation()) { errors_.push_back( strings::StrCat("Deprecated called twice for Op ", op_def()->name())); } else { OpDeprecation* deprecation = op_def()->mutable_deprecation(); deprecation->set_version(version); - deprecation->set_explanation(string(explanation)); + 
deprecation->set_explanation(std::move(explanation)); } return *this; } diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h index 0b39d6e848..8077b20598 100644 --- a/tensorflow/core/framework/op_def_builder.h +++ b/tensorflow/core/framework/op_def_builder.h @@ -51,7 +51,7 @@ struct OpRegistrationData { class OpDefBuilder { public: // Constructs an OpDef with just the name field set. - explicit OpDefBuilder(StringPiece op_name); + explicit OpDefBuilder(string op_name); // Adds an attr to this OpDefBuilder (and returns *this). The spec has // format "<name>:<type>" or "<name>:<type>=<default>" @@ -84,7 +84,7 @@ class OpDefBuilder { // * Ability to restrict the type of the tensor like the existing // restrictions for type attrs. // Perhaps by linking the type of the tensor to a type attr? - OpDefBuilder& Attr(StringPiece spec); + OpDefBuilder& Attr(string spec); // Adds an input or output to this OpDefBuilder (and returns *this). // The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)" @@ -101,8 +101,8 @@ class OpDefBuilder { // in the spec? // TODO(josh11b): SparseInput() and SparseOutput() matching the Python // handling? - OpDefBuilder& Input(StringPiece spec); - OpDefBuilder& Output(StringPiece spec); + OpDefBuilder& Input(string spec); + OpDefBuilder& Output(string spec); // Turns on the indicated boolean flag in this OpDefBuilder (and // returns *this). @@ -112,7 +112,7 @@ class OpDefBuilder { OpDefBuilder& SetAllowsUninitializedInput(); // Deprecate the op at a certain GraphDef version. - OpDefBuilder& Deprecated(int version, StringPiece explanation); + OpDefBuilder& Deprecated(int version, string explanation); // Adds docs to this OpDefBuilder (and returns *this). // Docs have the format: @@ -128,9 +128,9 @@ class OpDefBuilder { // to suppress the automatically-generated type documentation in // generated output. 
#ifndef TF_LEAN_BINARY - OpDefBuilder& Doc(StringPiece text); + OpDefBuilder& Doc(string text); #else - OpDefBuilder& Doc(StringPiece text) { return *this; } + OpDefBuilder& Doc(string text) { return *this; } #endif // Sets the shape function to be used for shape inference. diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc index ebdaaec153..508a8d3149 100644 --- a/tensorflow/core/framework/resource_mgr.cc +++ b/tensorflow/core/framework/resource_mgr.cc @@ -288,4 +288,13 @@ Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) { return ctx->resource_manager()->Delete(p); } +Status ResourceHandlesShape(shape_inference::InferenceContext* c) { + int n; + TF_RETURN_IF_ERROR(c->GetAttr("N", &n)); + for (int i = 0; i < n; ++i) { + c->set_output(i, c->Scalar()); + } + return Status::OK(); +} + } // end namespace tensorflow diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h index d58deaa3fc..4a531648d9 100644 --- a/tensorflow/core/framework/resource_mgr.h +++ b/tensorflow/core/framework/resource_mgr.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_ #define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_ +#include <memory> #include <string> #include <typeindex> #include <typeinfo> @@ -127,6 +128,14 @@ class ResourceMgr { Status Lookup(const string& container, const string& name, T** resource) const TF_MUST_USE_RESULT; + // Similar to Lookup, but looks up multiple resources at once, with only a + // single lock acquisition. + template <typename T> + Status LookupMany(absl::Span<std::pair<const string*, const string*> const> + containers_and_names, + std::vector<std::unique_ptr<T, core::RefCountDeleter>>* + resource) const TF_MUST_USE_RESULT; + // If "container" has a resource "name", returns it in // "*resource". Otherwise, invokes creator() to create the resource. 
// The caller takes the ownership of one ref on "*resource". @@ -239,14 +248,31 @@ Status HandleFromInput(OpKernelContext* ctx, StringPiece input, ResourceHandle* handle); // Create a resource pointed by a given resource handle. +// +// If successful, the caller transfers the ownership of one ref on `resource` to +// `ctx->resource_mgr()`. template <typename T> Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value); // Looks up a resource pointed by a given resource handle. +// +// If the lookup is successful, the caller takes the ownership of one ref on +// `*value`, and must call its `Unref()` method when it has finished using it. template <typename T> Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value); +// Looks up multiple resources pointed by a sequence of resource handles. +template <typename T> +Status LookupResources( + OpKernelContext* ctx, absl::Span<ResourceHandle const> p, + std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values); + // Looks up or creates a resource. +// +// If successful, the caller takes the ownership of one ref on `*value`, and +// must call its `Unref()` method when it has finished using it. If the +// `creator` is invoked, its reference on the created resource is transferred +// to `ctx->resource_mgr()`. template <typename T> Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, T** value, std::function<Status(T**)> creator); @@ -358,6 +384,26 @@ class ResourceHandleOp : public OpKernel { std::atomic<bool> initialized_{false}; }; +// Utility op kernel to produce a handle to a resource of type T. 
+template <typename T> +class ResourceHandlesOp : public OpKernel { + public: + explicit ResourceHandlesOp(OpKernelConstruction* context); + + void Compute(OpKernelContext* ctx) override; + + bool IsExpensive() override { return false; } + + private: + std::vector<string> containers_; + std::vector<string> names_; + mutex mutex_; + std::vector<Tensor> resources_; + std::atomic<bool> initialized_{false}; +}; + +Status ResourceHandlesShape(shape_inference::InferenceContext* c); + // Registers a kernel for an op which produces a handle to a resource of the // specified type. #define REGISTER_RESOURCE_HANDLE_KERNEL(Type) \ @@ -390,6 +436,24 @@ Status ResourceMgr::Lookup(const string& container, const string& name, } template <typename T> +Status ResourceMgr::LookupMany( + absl::Span<std::pair<const string*, const string*> const> + containers_and_names, + std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const { + CheckDeriveFromResourceBase<T>(); + tf_shared_lock l(mu_); + resources->resize(containers_and_names.size()); + for (size_t i = 0; i < containers_and_names.size(); ++i) { + T* resource; + TF_RETURN_IF_ERROR(LookupInternal(*containers_and_names[i].first, + *containers_and_names[i].second, + &resource)); + (*resources)[i].reset(resource); + } + return Status::OK(); +} + +template <typename T> Status ResourceMgr::LookupInternal(const string& container, const string& name, T** resource) const { ResourceBase* found = nullptr; @@ -499,6 +563,19 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, } template <typename T> +Status LookupResources( + OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p, + std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values) { + std::vector<std::pair<const string*, const string*>> containers_and_names( + p.size()); + for (size_t i = 0; i < p.size(); ++i) { + TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i])); + containers_and_names[i] = {&p[i]->container(), 
&p[i]->name()}; + } + return ctx->resource_manager()->LookupMany(containers_and_names, values); +} + +template <typename T> Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p, T** value, std::function<Status(T**)> creator) { TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p)); @@ -555,6 +632,46 @@ void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) { ctx->set_output(0, resource_); } +template <typename T> +ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context) + : OpKernel(context) { + int n; + OP_REQUIRES_OK(context, context->GetAttr("N", &n)); + OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_)); + OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_)); + OP_REQUIRES( + context, containers_.size() == n, + errors::InvalidArgument("Number of containers (", containers_.size(), + ") must be equal to N (", n, ")")); + OP_REQUIRES(context, names_.size() == n, + errors::InvalidArgument("Number of names (", containers_.size(), + ") must be equal to N (", n, ")")); + resources_.resize(n); +} + +template <typename T> +void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) { + if (!initialized_.load()) { + mutex_lock ml(mutex_); + // Checking again to see if another thread has initialized the resource. 
+ if (!initialized_.load()) { + AllocatorAttributes attr; + attr.set_on_host(true); + for (size_t i = 0; i < resources_.size(); ++i) { + OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), + &resources_[i], attr)); + ResourceHandle h = + MakeResourceHandle<T>(ctx, containers_[i], names_[i]); + resources_[i].template scalar<ResourceHandle>()() = h; + } + initialized_.store(true); + } + } + for (size_t i = 0; i < resources_.size(); ++i) { + ctx->set_output(i, resources_[i]); + } +} + } // end namespace tensorflow #endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_ diff --git a/tensorflow/core/framework/run_handler.cc b/tensorflow/core/framework/run_handler.cc new file mode 100644 index 0000000000..0c4007eafc --- /dev/null +++ b/tensorflow/core/framework/run_handler.cc @@ -0,0 +1,249 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#define EIGEN_USE_THREADS + +#include "tensorflow/core/framework/run_handler.h" + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/framework/run_handler_util.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/util/ptr_util.h" + +namespace tensorflow { + +// Contains the concrete implementation of the RunHandler. +// Externally visible RunHandler class simply forwards the work to this one. 
+class RunHandler::Impl { + public: + explicit Impl(RunHandlerPool::Impl* pool_impl) : pool_impl_(pool_impl) { + Reset(); + } + + ~Impl() {} + + void set_inter_op_scheduling_range(std::uint_fast32_t start, + std::uint_fast32_t limit) { + inter_op_scheduling_range_.store(EncodePartition(start, limit), + std::memory_order_release); + } + + std::uint_fast32_t inter_op_scheduling_range() const { + return inter_op_scheduling_range_.load(std::memory_order_acquire); + } + + // Stores now time (in microseconds) since unix epoch when the handler is + // requested via RunHandlerPool::Get(). + uint64 start_time_us() const { return start_time_us_; } + + void ScheduleInterOpClosure(std::function<void()> fn); + + void Reset(); + + RunHandlerPool::Impl* pool_impl() { return pool_impl_; } + + private: + // Encoding/decoding logic for storing [start, limit) into a single + // uint_fast32_t int. We assume that pool_num_threads < (1 << 16). + const int kMaxPartitionBits = 16; + const int kMaxThreads = 1 << kMaxPartitionBits; + + std::uint_fast32_t EncodePartition(std::uint_fast32_t start, + std::uint_fast32_t limit) { + return (start << kMaxPartitionBits) | limit; + } + + void DecodePartition(std::uint_fast32_t val, std::uint_fast32_t* start, + std::uint_fast32_t* limit) { + *limit = val & (kMaxThreads - 1); + val >>= kMaxPartitionBits; + *start = val; + } + + std::atomic_uint_fast32_t inter_op_scheduling_range_; + RunHandlerPool::Impl* pool_impl_; // NOT OWNED. + uint64 start_time_us_; +}; + +// Contains shared state across all run handlers present in the pool. Also +// responsible for pool management decisions. +// This class is thread safe. 
+class RunHandlerPool::Impl { + public: + explicit Impl(int num_inter_op_threads) + : max_handlers_(128), + inter_op_thread_pool_(new thread::ThreadPool( + Env::Default(), ThreadOptions(), "inter_op", num_inter_op_threads)), + iterations_(0) { + VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_; + for (int i = 0; i < max_handlers_; ++i) { + handlers_.emplace_back(new RunHandler::Impl(this)); + free_handlers_.push_back(handlers_.back().get()); + } + } + + ~Impl() { + // Sanity check that all handlers have been returned back to the pool before + // destruction. + DCHECK_EQ(handlers_.size(), max_handlers_); + DCHECK_EQ(free_handlers_.size(), handlers_.size()); + DCHECK_EQ(sorted_active_handlers_.size(), 0); + } + + thread::ThreadPool* inter_op_thread_pool() const { + return inter_op_thread_pool_.get(); + } + + std::unique_ptr<RunHandler> Get() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + while (free_handlers_.empty()) { + one_handler_free_.wait(l); + } + // Remove the last entry from free_handlers_ and add to the end of + // sorted_active_handlers_. + auto* handler_impl = free_handlers_.back(); + handler_impl->Reset(); + // Sortedness isn't violated if we simply add at the end of the list, since + // handlers are expected to be obtained in increasing order of time. + sorted_active_handlers_.push_back(handler_impl); + DCHECK_LE(sorted_active_handlers_.size(), max_handlers_); + free_handlers_.pop_back(); + + RecomputePoolStatsLocked(); + return WrapUnique<RunHandler>(new RunHandler(handler_impl)); + } + + void ReleaseHandler(RunHandler::Impl* handler) LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + DCHECK_GT(sorted_active_handlers_.size(), 0); + + uint64 now = tensorflow::Env::Default()->NowMicros(); + double elapsed = (now - handler->start_time_us()) / 1000.0; + time_hist_.Add(elapsed); + + // Erase from and update sorted_active_handlers_. Add it to the end of + // free_handlers_. 
+ auto iter = std::find(sorted_active_handlers_.begin(), + sorted_active_handlers_.end(), handler); + DCHECK(iter != sorted_active_handlers_.end()) + << "Unexpected handler: " << handler + << " is being requested for release"; + + // Remove this handler from this list and add it to the list of free + // handlers. + sorted_active_handlers_.erase(iter); + free_handlers_.push_back(handler); + DCHECK_LE(free_handlers_.size(), max_handlers_); + + RecomputePoolStatsLocked(); + } + one_handler_free_.notify_one(); + } + + private: + void RecomputePoolStatsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_); + + // Maximum number of handlers pre-created during pool construction time. The + // number has been chosen expecting each handler might at least want 1 + // inter-op thread for execution (during compute intensive workloads like + // inference). + const int max_handlers_; + + // Thread safe part. + const std::unique_ptr<thread::ThreadPool> inter_op_thread_pool_; + + // Thread compatible part used only by lock under RunHandlerPool. + // Handlers are sorted by start time. + std::vector<RunHandler::Impl*> sorted_active_handlers_ GUARDED_BY(mu_); + std::vector<RunHandler::Impl*> free_handlers_ GUARDED_BY(mu_); + std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ GUARDED_BY(mu_); + // Histogram of elapsed runtime of every handler (in ms). 
+ histogram::Histogram time_hist_ GUARDED_BY(mu_); + std::vector<std::uint_fast32_t> inter_op_start_ GUARDED_BY(mu_); + std::vector<std::uint_fast32_t> inter_op_limit_ GUARDED_BY(mu_); + int64 iterations_ GUARDED_BY(mu_); + condition_variable one_handler_free_; + mutex mu_; +}; + +void RunHandlerPool::Impl::RecomputePoolStatsLocked() { + int num_active_requests = sorted_active_handlers_.size(); + if (num_active_requests == 0) return; + + int num_threads = inter_op_thread_pool_->NumThreads(); + + inter_op_start_.resize(num_active_requests); + inter_op_limit_.resize(num_active_requests); + + const int kMinThreadsPerRequest = 3; + ComputeInterOpSchedulingRanges(num_active_requests, num_threads, + kMinThreadsPerRequest, &inter_op_start_, + &inter_op_limit_); + + for (int i = 0; i < num_active_requests; ++i) { + sorted_active_handlers_[i]->set_inter_op_scheduling_range( + inter_op_start_[i], inter_op_limit_[i]); + } + + if (iterations_++ % 5000 == 0 && VLOG_IS_ON(1)) { + VLOG(1) << "Printing time histogram: " << time_hist_.ToString(); + VLOG(1) << "Active session runs: " << num_active_requests; + uint64 now = tensorflow::Env::Default()->NowMicros(); + string ranges_str = ""; + string times_str = ""; + for (int i = 0; i < num_active_requests; ++i) { + if (i > 0) { + times_str += " "; + ranges_str += " "; + } + + times_str += strings::StrCat( + (now - sorted_active_handlers_[i]->start_time_us()) / 1000.0, " ms."); + ranges_str += strings::StrCat("[", inter_op_start_[i], ", ", + inter_op_limit_[i], ")"); + } + VLOG(1) << "Elapsed times are: " << times_str; + VLOG(1) << "Ranges are: " << ranges_str; + } +} + +void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) { + std::uint_fast32_t start = 0, limit = 0; + DecodePartition(inter_op_scheduling_range(), &start, &limit); + pool_impl_->inter_op_thread_pool()->Schedule(std::move(fn)); +} + +void RunHandler::Impl::Reset() { + set_inter_op_scheduling_range( + 0, pool_impl_->inter_op_thread_pool()->NumThreads()); 
+ start_time_us_ = tensorflow::Env::Default()->NowMicros(); +} + +RunHandlerPool::RunHandlerPool(int num_inter_op_threads) + : impl_(new Impl(num_inter_op_threads)) {} + +RunHandlerPool::~RunHandlerPool() {} + +std::unique_ptr<RunHandler> RunHandlerPool::Get() { return impl_->Get(); } + +RunHandler::RunHandler(Impl* impl) : impl_(impl) {} + +void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) { + impl_->ScheduleInterOpClosure(std::move(fn)); +} + +RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); } +} // namespace tensorflow diff --git a/tensorflow/core/framework/run_handler.h b/tensorflow/core/framework/run_handler.h new file mode 100644 index 0000000000..72fa6301b4 --- /dev/null +++ b/tensorflow/core/framework/run_handler.h @@ -0,0 +1,95 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_ + +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/histogram/histogram.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/protobuf/config.pb.h" + +namespace tensorflow { + +class RunHandler; + +// RunHandlerPool is a fixed size pool of pre-allocated RunHandlers +// that can be used for tracking inter-op work for a given Session::Run(). +// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes +// 'active' when its unique_ptr is returned by Get() and is being used by a +// client. It becomes 'inactive' once more when its unique_ptr gets destroyed. +// +// Expected usage: +// +// * Create a single RunHandlerPool (say run_handler_pool_). +// +// * When a Session::Run() is invoked, obtain a handler by: +// auto handler = run_handler_pool_->Get(); +// +// * Use handler for scheduling all inter-op work by: +// handler->ScheduleInterOpClosure(closure); +// +// This class is thread safe. +class RunHandlerPool { + public: + explicit RunHandlerPool(int num_inter_op_threads); + ~RunHandlerPool(); + + // Returns an inactive RunHandler from the pool. + // + // RunHandlers in RunHandlerPool are initially 'inactive'. + // A RunHandler becomes 'active' when its unique_ptr its returned by Get() + // and is being used by a client. It becomes 'inactive' once more when the + // unique_ptr is destroyed. + // + // Will block unless there is an inactive handler. + std::unique_ptr<RunHandler> Get(); + + private: + class Impl; + friend class RunHandler; + + std::unique_ptr<Impl> impl_; +}; + +// RunHandler can be used to schedule inter-op closures to run on a global pool +// shared across all Session::Run(s). +// +// It can only be created via RunHandlerPool::Get(). 
+// +// This class can be used instead of directly scheduling closures on a global +// pool since it maintains a global view across all sessions and optimizes pool +// scheduling to improve (median and tail) latency. +// +// This class is thread safe. +class RunHandler { + public: + void ScheduleInterOpClosure(std::function<void()> fn); + + ~RunHandler(); + + private: + class Impl; + friend class RunHandlerPool::Impl; + + explicit RunHandler(Impl* impl); + + Impl* impl_; // NOT OWNED. +}; + +} // end namespace tensorflow. + +#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_ diff --git a/tensorflow/core/framework/run_handler_util.cc b/tensorflow/core/framework/run_handler_util.cc new file mode 100644 index 0000000000..3087998c69 --- /dev/null +++ b/tensorflow/core/framework/run_handler_util.cc @@ -0,0 +1,57 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/run_handler_util.h" + +#include <algorithm> +#include <cmath> +#include "tensorflow/core/platform/logging.h" + +namespace tensorflow { + +void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads, + int min_threads_per_request, + std::vector<std::uint_fast32_t>* start_vec, + std::vector<std::uint_fast32_t>* end_vec) { + // Each request is expected to have weight W[i] = num_active_requests - i. 
+ // Therefore, total_weight = sum of all request weights. + float total_weight = 0.5f * num_active_requests * (num_active_requests + 1); + float demand_factor = static_cast<float>(num_threads) / total_weight; + float last_cumulative_weight = 0.0; + min_threads_per_request = std::max(1, min_threads_per_request); + for (int i = 0; i != num_active_requests; i++) { + float cumulative_weight = + static_cast<float>(i + 1) * + (num_active_requests - static_cast<float>(i) * 0.5f); + float weight = cumulative_weight - last_cumulative_weight; + // Quantize thread_demand by rounding up, and also satisfying + // `min_threads_per_request` constraint. + // Note: We subtract a small epsilon (0.00001) to prevent ceil(..) from + // rounding weights like 4.0 to 5. + int demand = + std::max(min_threads_per_request, + static_cast<int>(ceil(weight * demand_factor - 0.00001f))); + // For the quantized range [start, end); compute the floor of real start, + // and expand downwards from there with length `demand` and adjust for + // boundary conditions. + int start = last_cumulative_weight * demand_factor; + int end = std::min(num_threads, start + demand); + start = std::max(0, std::min(start, end - demand)); + start_vec->at(i) = start; + end_vec->at(i) = end; + last_cumulative_weight = cumulative_weight; + } +} +} // namespace tensorflow diff --git a/tensorflow/core/framework/run_handler_util.h b/tensorflow/core/framework/run_handler_util.h new file mode 100644 index 0000000000..c0c36aeccb --- /dev/null +++ b/tensorflow/core/framework/run_handler_util.h @@ -0,0 +1,43 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_ +#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_ + +#include <cstdint> +#include <vector> + +namespace tensorflow { + +// Assign thread ranges to requests. +// Requests are numbered 0...num_active_requests-1, and +// threads are numbered 0...num_threads-1. +// On return, the range start_vec->at(i)...end_vec->at(i)-1 +// indicates the subrange of the threads available to request i. +// The ranges given to different requests may overlap. +// Lower numbered requests will tend to be assigned more threads. +// Thus, a client might associate older requests with lower +// array indices so they receive access to more threads. +// However, the routine ensures that each request is given access +// to at least min(min_threads_per_request, num_threads) threads. +// Every thread will be assigned to at least one request range, +// assuming there is at least one request. 
+void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads, + int min_threads_per_request, + std::vector<std::uint_fast32_t>* start_vec, + std::vector<std::uint_fast32_t>* end_vec); + +} // end namespace tensorflow +#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_ diff --git a/tensorflow/core/framework/run_handler_util_test.cc b/tensorflow/core/framework/run_handler_util_test.cc new file mode 100644 index 0000000000..a1928c132b --- /dev/null +++ b/tensorflow/core/framework/run_handler_util_test.cc @@ -0,0 +1,93 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/framework/run_handler_util.h" + +#include <vector> +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" +namespace tensorflow { +namespace { + +void VerifyFunction(int num_active_requests, int num_threads, + int min_threads_per_request, bool print_stats = false) { + if (print_stats) { + LOG(INFO) << "Test case# num_active_requests: " << num_active_requests + << " num_threads: " << num_threads + << " min_threads: " << min_threads_per_request; + } + std::vector<std::uint_fast32_t> start(num_active_requests); + std::vector<std::uint_fast32_t> end(num_active_requests); + + ComputeInterOpSchedulingRanges(num_active_requests, num_threads, + min_threads_per_request, &start, &end); + string range_str = ""; + for (int i = 0; i < num_active_requests; ++i) { + if (i > 0) range_str += " "; + range_str += strings::StrCat("[", start[i], ", ", end[i], ")"); + + ASSERT_GE(start[i], 0) << range_str; + ASSERT_LE(end[i], num_threads) << range_str; + if (i > 0) { + // Due to linearly decreasing demand, #threads(i - 1) >= #threads(i) + ASSERT_GE(end[i - 1] - start[i - 1], end[i] - start[i]) << range_str; + // No missing threads. + ASSERT_GE(end[i - 1], start[i]) << range_str; + } + // Each interval is at least of size 'min_threads_per_request'. + ASSERT_GE((end[i] - start[i]), min_threads_per_request) << range_str; + // Verify that assigned (quantized) threads is not overly estimated + // from real demand, when the demand is high (>= + // min_threads_per_request). 
+ float entry_weight = num_active_requests - i; + float total_weight = 0.5f * num_active_requests * (num_active_requests + 1); + float thread_demand = (entry_weight * num_threads) / total_weight; + if (thread_demand > min_threads_per_request) { + // We expect some over-estimation of threads due to quantization, + // but we hope it's not more than 1 extra thread. + ASSERT_NEAR(end[i] - start[i], thread_demand, 1.0) + << "Ranges: " << range_str << " thread_demand: " << thread_demand + << " i: " << i; + } + } + ASSERT_EQ(end[num_active_requests - 1], num_threads); + ASSERT_EQ(start[0], 0); + if (print_stats) { + LOG(INFO) << "Assigned ranges: " << range_str; + } +} + +TEST(RunHandlerUtilTest, TestComputeInterOpSchedulingRanges) { + const int kMinThreadsPerRequestBound = 12; + const int kMaxActiveRequests = 128; + const int kMaxThreads = 128; + + for (int min_threads_per_request = 1; + min_threads_per_request <= kMinThreadsPerRequestBound; + ++min_threads_per_request) { + for (int num_active_requests = 1; num_active_requests <= kMaxActiveRequests; + ++num_active_requests) { + for (int num_threads = min_threads_per_request; + num_threads <= kMaxThreads; ++num_threads) { + VerifyFunction(num_active_requests, num_threads, + min_threads_per_request); + } + } + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc index 3df677675e..1dea6da911 100644 --- a/tensorflow/core/framework/tensor.cc +++ b/tensorflow/core/framework/tensor.cc @@ -813,7 +813,7 @@ Tensor Tensor::Slice(int64 start, int64 limit) const { } Tensor Tensor::SubSlice(int64 index) const { - CHECK_GE(dims(), 2); // Crash ok. + CHECK_GE(dims(), 1); // Crash ok. CHECK_LE(0, index); // Crash ok. int64 dim0_size = shape_.dim_size(0); CHECK_LE(index, dim0_size); // Crash ok. 
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 8a0c70fef2..d0f9eb56e2 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -219,7 +219,7 @@ class Tensor { /// must check the returned tensor's alignment before calling certain /// methods that have alignment requirement (e.g., `flat()`, `tensor()`). /// - /// REQUIRES: `dims()` >= 2 + /// REQUIRES: `dims()` >= 1 /// REQUIRES: `0 <= dim0_start < dim_size(0)` Tensor SubSlice(int64 index) const; diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc index 0bfa53e6c5..c596604143 100644 --- a/tensorflow/core/framework/tensor_test.cc +++ b/tensorflow/core/framework/tensor_test.cc @@ -1246,6 +1246,9 @@ TEST(Tensor, SubSlice_Basic) { EXPECT_EQ(&tx(5, j, k), &ty(j, k)); } } + Tensor z = y.SubSlice(3).SubSlice(31); + auto tz = z.unaligned_flat<float>(); + EXPECT_EQ(*tz.data(), 5.0); } { // Test unaligned access via a SubSlice. 
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc index 1630ab7a15..7a4a0096fa 100644 --- a/tensorflow/core/graph/graph.cc +++ b/tensorflow/core/graph/graph.cc @@ -192,6 +192,11 @@ void Node::ClearAttr(const string& name) { (*props_->node_def.mutable_attr()).erase(name); } +void Node::set_name(string name) { + MaybeCopyOnWrite(); + props_->node_def.set_name(std::move(name)); +} + void Node::set_requested_device(const string& device) { MaybeCopyOnWrite(); props_->node_def.set_device(device); @@ -643,7 +648,7 @@ Status Graph::IsValidNode(const Node* node) const { Status Graph::IsValidOutputTensor(const Node* node, int idx) const { TF_RETURN_IF_ERROR(IsValidNode(node)); - if (idx >= node->num_outputs()) { + if (idx >= node->num_outputs() || idx < 0) { return errors::OutOfRange("Node '", node->name(), "' (type: '", node->op_def().name(), "', num of outputs: ", node->num_outputs(), @@ -654,7 +659,7 @@ Status Graph::IsValidOutputTensor(const Node* node, int idx) const { Status Graph::IsValidInputTensor(const Node* node, int idx) const { TF_RETURN_IF_ERROR(IsValidNode(node)); - if (idx >= node->num_inputs()) { + if (idx >= node->num_inputs() || idx < 0) { return errors::OutOfRange("Node '", node->name(), "' (type: '", node->op_def().name(), "', num of inputs: ", node->num_inputs(), diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h index 52e9f23a76..2944951f82 100644 --- a/tensorflow/core/graph/graph.h +++ b/tensorflow/core/graph/graph.h @@ -72,6 +72,7 @@ class Node { int id() const { return id_; } int cost_id() const { return cost_id_; } const string& name() const; + void set_name(string name); const string& type_string() const; // def() provides the NodeDef the user supplied, but the specifics @@ -590,12 +591,12 @@ class Graph { // Returns OK if `node` is non-null and belongs to this graph Status IsValidNode(const Node* node) const; - // Returns OK if IsValidNode(`node`) and `idx` is less than - // node->num_outputs() + 
// Returns OK if IsValidNode(`node`) and `idx` is a valid output. Does not + // accept control outputs. Status IsValidOutputTensor(const Node* node, int idx) const; - // Returns OK if IsValidNode(`node`) and `idx` is less than - // node->num_inputs() + // Returns OK if IsValidNode(`node`) and `idx` a valid input. Does not accept + // control inputs. Status IsValidInputTensor(const Node* node, int idx) const; // Create and return a new WhileContext owned by this graph. This is called diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc index f5b0105862..7394b1cddf 100644 --- a/tensorflow/core/graph/mkl_layout_pass.cc +++ b/tensorflow/core/graph/mkl_layout_pass.cc @@ -38,6 +38,7 @@ limitations under the License. #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/util/tensor_format.h" +#include "tensorflow/core/util/util.h" #include "tensorflow/core/graph/mkl_graph_util.h" #include "tensorflow/core/graph/mkl_layout_pass.h" @@ -977,7 +978,9 @@ std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_; // nodes. Do not change the ordering of the Mkl passes. const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = OptimizationPassRegistry::POST_PARTITIONING; +#ifdef ENABLE_MKL REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); +#endif // ENABLE_MKL ////////////////////////////////////////////////////////////////////////// // Helper functions for creating new node @@ -2448,6 +2451,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { csinfo_.tanh = "Tanh"; csinfo_.tanh_grad = "TanhGrad"; csinfo_.reshape = "Reshape"; + csinfo_.slice = "Slice"; csinfo_.softmax = "Softmax"; csinfo_.split = "Split"; // Element-wise ops. 
Ensure you also add any new ops to IsOpElementWise @@ -2555,6 +2559,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass { rinfo_.push_back({csinfo_.reshape, mkl_op_registry::GetMklOpName(csinfo_.reshape), CopyAttrsReshape, AlwaysRewrite}); + rinfo_.push_back({csinfo_.slice, + mkl_op_registry::GetMklOpName(csinfo_.slice), + CopyAttrsSlice, AlwaysRewrite}); rinfo_.push_back({csinfo_.softmax, mkl_op_registry::GetMklOpName(csinfo_.softmax), CopyAttrsDataType, AlwaysRewrite}); @@ -2674,6 +2681,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { string tanh; string tanh_grad; string reshape; + string slice; string softmax; string split; string squared_difference; @@ -3132,6 +3140,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass { static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb); + static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb); static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb); // Generate a graph node in graph 'g' representing a dummy Mkl tensor node, @@ -3150,7 +3159,9 @@ MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_; // nodes. Do not change the ordering of the Mkl passes. const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup = OptimizationPassRegistry::POST_PARTITIONING; +#ifdef ENABLE_MKL REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass); +#endif // ENABLE_MKL ////////////////////////////////////////////////////////////////////////// // Helper functions for creating new node @@ -3735,6 +3746,19 @@ void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node, nb->Attr("Tshape", Tshape); } +void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node, + NodeBuilder* nb) { + DataType T; + DataType Index; + + // Get all attributes from old node. 
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T)); + TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index)); + // Add attributes to new node. + nb->Attr("T", T); + nb->Attr("Index", Index); +} + void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb) { DataType T; @@ -4488,6 +4512,10 @@ Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) { if (options.graph == nullptr && options.partition_graphs == nullptr) { return Status::OK(); } + if (DisableMKL()) { + VLOG(2) << "TF-MKL: Disabling MKL"; + return Status::OK(); + } auto process_graph = [&](std::unique_ptr<Graph>* g) { // Get the ownership of a graph diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc index e8bac847e5..77640e287c 100644 --- a/tensorflow/core/graph/mkl_layout_pass_test.cc +++ b/tensorflow/core/graph/mkl_layout_pass_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #include "tensorflow/core/graph/mkl_layout_pass.h" #include "tensorflow/core/graph/mkl_graph_util.h" @@ -3510,6 +3510,26 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) { "B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1"); } +TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) { + InitGraph( + "node { name: 'A' op: 'Input'}" + "node { name: 'B' op: 'Int32Input'}" + "node { name: 'C' op: 'Int32Input'}" + "node { name: 'D' op: 'Slice'" + " attr { key: 'T' value { type: DT_FLOAT } }" + " attr { key: 'Index' value { type: DT_INT32 } }" + " input: ['A', 'B', 'C'] }" + "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }" + " input: ['A', 'D'] }"); + EXPECT_EQ(DoMklLayoutOptimizationPass(), + "A(Input);B(Int32Input);C(Int32Input);" + "D(_MklSlice);DMT/_0(Const);DMT/_1(Const);DMT/" + "_2(Const);E(Zeta)|A->D;A->E;" + "A:control->DMT/_0:control;A:control->DMT/" + "_1:control;A:control->DMT/_2:control;" + "B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5"); +} + ///////////////////////////////////////////////////////////////////// // Post-rewrite fixup pass test @@ -3586,4 +3606,4 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000); } // namespace tensorflow -#endif /* INTEL_MKL */ +#endif // INTEL_MKL && ENABLE_MKL diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc index b67a321fc1..6804ab84ce 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc @@ -31,6 +31,7 @@ limitations under the License. 
#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/util.h" #include "tensorflow/core/graph/mkl_graph_util.h" #include "tensorflow/core/graph/mkl_tfconversion_pass.h" @@ -133,7 +134,9 @@ class MklToTfConversionPass : public GraphOptimizationPass { // complete picture of inputs and outputs of the nodes in the graphs. const OptimizationPassRegistry::Grouping kMklTfConvPassGroup = OptimizationPassRegistry::POST_PARTITIONING; +#ifdef ENABLE_MKL REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass); +#endif // ENABLE_MKL Status MklToTfConversionPass::InsertConversionNodeOnEdge( std::unique_ptr<Graph>* g, Edge* e) { @@ -422,6 +425,10 @@ Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) { if (options.graph == nullptr && options.partition_graphs == nullptr) { return Status::OK(); } + if (DisableMKL()) { + VLOG(2) << "TF-MKL: Disabling MKL"; + return Status::OK(); + } auto process_graph = [&](std::unique_ptr<Graph>* g) { // Get the ownership of graph diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc index ebcb6de551..319437a801 100644 --- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc +++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #include "tensorflow/core/graph/mkl_tfconversion_pass.h" #include "tensorflow/core/graph/mkl_graph_util.h" @@ -304,4 +304,4 @@ BENCHMARK(BM_RunMklToTfConversionPass)->Arg(1000)->Arg(10000); } // namespace } // namespace tensorflow -#endif /* INTEL_MKL */ +#endif // INTEL_MKL && ENABLE_MKL diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc index a446e0d136..d92874909f 100644 --- a/tensorflow/core/graph/node_builder.cc +++ b/tensorflow/core/graph/node_builder.cc @@ -99,6 +99,11 @@ NodeBuilder& NodeBuilder::Device(StringPiece device_spec) { return *this; } +NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) { + assigned_device_ = string(device); + return *this; +} + Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const { // In case of error, set *created_node to nullptr. if (created_node != nullptr) *created_node = nullptr; @@ -115,6 +120,8 @@ Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const { Node* node = graph->AddNode(node_def, &status); if (!status.ok()) return status; + node->set_assigned_device_name(assigned_device_); + for (size_t i = 0; i < inputs_.size(); ++i) { if (inputs_[i].node != nullptr) { // Skip back edges. graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i); diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h index 4727ee7b56..d576985a23 100644 --- a/tensorflow/core/graph/node_builder.h +++ b/tensorflow/core/graph/node_builder.h @@ -100,6 +100,9 @@ class NodeBuilder { // "assigned device" in the Node). NodeBuilder& Device(StringPiece device_spec); + // Sets the device name in the "assigned device" field in tensorflow::Node. + NodeBuilder& AssignedDevice(StringPiece device); + // Set the value of an attr. 
attr_name must match the name of one of // attrs defined by the Op, and value must have the corresponding type // (see SetAttrValue() in ../framework/attr_value_util.h for legal @@ -141,6 +144,7 @@ class NodeBuilder { std::vector<NodeOut> inputs_; std::vector<Node*> control_inputs_; std::vector<string> errors_; + string assigned_device_; }; // IMPLEMENTATION ------------------------------------------------------------- diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index f716cd72c9..28fd7565cc 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -74,6 +74,10 @@ class GraphProperties { // shape information. void ClearInputProperties(const string& node_name); void ClearOutputProperties(const string& node_name); + // Returns true if we have *any* properties. + bool has_properties() const { + return input_properties_.size() > 0 || output_properties_.size() > 0; + } private: // Relaxes shapes <shapes_and_types>, determined from an EnqueueV2 node, into diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 362092a6cf..db10f586bc 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -1340,6 +1340,8 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) { Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {}); Output g = ops::Shape(s.WithOpName("g"), c); Output h = ops::Fill(s.WithOpName("h"), g, zero); + Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1}); + Output j = ops::Sum(s.WithOpName("j"), a, zero_idx); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -1382,6 +1384,10 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) { ASSERT_EQ(2, shape_f.dim_size()); EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size()); EXPECT_EQ(shape_h.dim(1).size(), 
shape_c.dim(1).size()); + + const auto shape_j = properties.GetOutputProperties("j").at(0).shape(); + ASSERT_EQ(1, shape_j.dim_size()); + EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size()); } TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) { diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt index c94ee2f227..0ec95dd684 100644 --- a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt +++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt @@ -88,6 +88,13 @@ library { } } } + attr { + key: "output_shapes" + value { + list { + } + } + } } ret { key: "while" diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc index 2619a9a8f3..de0a63fc4e 100644 --- a/tensorflow/core/grappler/graph_view.cc +++ b/tensorflow/core/grappler/graph_view.cc @@ -20,23 +20,25 @@ limitations under the License. namespace tensorflow { namespace grappler { -int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) { - for (int output_arg_id = 0; output_arg_id < op.output_arg_size(); - ++output_arg_id) { +namespace { +int OpPortIdToArgId(const NodeDef& node, + const protobuf::RepeatedPtrField<OpDef::ArgDef>& args, + int port_id) { + for (int arg_id = 0; arg_id < args.size(); ++arg_id) { if (port_id < 0) { return -1; } else if (port_id == 0) { - return output_arg_id; + return arg_id; } - // Default is 1 port per output arg. + // Default is 1 port per arg. 
int n = 1; - const auto& output_arg = op.output_arg(output_arg_id); - if (!output_arg.number_attr().empty()) { - n = node.attr().at(output_arg.number_attr()).i(); - } else if (!output_arg.type_list_attr().empty()) { - n = node.attr().at(output_arg.type_list_attr()).list().type_size(); + const auto& arg = args.Get(arg_id); + if (!arg.number_attr().empty()) { + n = node.attr().at(arg.number_attr()).i(); + } else if (!arg.type_list_attr().empty()) { + n = node.attr().at(arg.type_list_attr()).list().type_size(); } if (n < 0) { @@ -44,13 +46,22 @@ int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) { DCHECK_GE(n, 0); return -1; } else if (port_id < n) { - return output_arg_id; + return arg_id; } port_id -= n; } return -1; } +} // end namespace + +int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) { + return OpPortIdToArgId(node, op.output_arg(), port_id); +} + +int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) { + return OpPortIdToArgId(node, op.input_arg(), port_id); +} GraphView::GraphView(GraphDef* graph) : graph_(graph) { for (int i = 0; i < graph_->node_size(); i++) { @@ -72,7 +83,7 @@ void GraphView::AddUniqueNodeOrDie(NodeDef* node) { void GraphView::AddFanouts(NodeDef* node) { for (int i = 0; i < node->input_size(); ++i) { OutputPort fanin; - string fanin_name = ParseNodeName(node->input(i), &fanin.port_id); + const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id); fanin.node = nodes_[fanin_name]; InputPort input; diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h index ec946ca3b5..09c36a1368 100644 --- a/tensorflow/core/grappler/graph_view.h +++ b/tensorflow/core/grappler/graph_view.h @@ -26,7 +26,7 @@ limitations under the License. namespace tensorflow { namespace grappler { -// Map a node/op's output port_id to arg_id. +// Map a node/op's input/output port_id to arg_id. 
// // The port_id refers to the n-th tensor of the node, while the arg_id refers to // the n-th arg of the op. These two can be different if an op's arg is a list @@ -34,6 +34,7 @@ namespace grappler { // // We return -1 for any invalid port_id (i.e., no corresponding arg_id). int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id); +int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id); // A utility class to simplify the traversal of a GraphDef. class GraphView { diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc index 3d7d2faf7c..f90e2c8cfc 100644 --- a/tensorflow/core/grappler/graph_view_test.cc +++ b/tensorflow/core/grappler/graph_view_test.cc @@ -26,7 +26,7 @@ namespace { class GraphViewTest : public ::testing::Test {}; -TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) { +TEST_F(GraphViewTest, OpPortIdToArgIdShapeN) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10}); ops::ShapeN b(s.WithOpName("b"), {a, a, a}); @@ -45,9 +45,16 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) { EXPECT_TRUE( OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok()); - EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0)); - EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1)); + // Const has 0 inputs, 1 output. + EXPECT_EQ(-1, OpInputPortIdToArgId(a_node_def, *a_op_def, 0)); + EXPECT_EQ(0, OpOutputPortIdToArgId(a_node_def, *a_op_def, 0)); + EXPECT_EQ(-1, OpOutputPortIdToArgId(a_node_def, *a_op_def, 1)); + // ShapeN has N=3 inputs and outputs. 
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0)); + EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 1)); + EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 2)); + EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 3)); EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0)); EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1)); EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2)); @@ -55,7 +62,7 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) { EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4)); } -TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) { +TEST_F(GraphViewTest, OpPortIdToArgIdSparseSplit) { for (int num_splits : {1, 2}) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10}); @@ -70,6 +77,13 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) { EXPECT_TRUE( OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok()); + // We have 4 inputs. 
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0)); + EXPECT_EQ(1, OpInputPortIdToArgId(b_node_def, *b_op_def, 1)); + EXPECT_EQ(2, OpInputPortIdToArgId(b_node_def, *b_op_def, 2)); + EXPECT_EQ(3, OpInputPortIdToArgId(b_node_def, *b_op_def, 3)); + EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 4)); + for (int port_id = 0; port_id <= num_splits * 3; ++port_id) { int arg_id = -1; if (port_id < num_splits * 3) { diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc index bbc0fedd22..2c490f3966 100644 --- a/tensorflow/core/grappler/grappler_item.cc +++ b/tensorflow/core/grappler/grappler_item.cc @@ -38,6 +38,7 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) { restore_op = other.restore_op; save_restore_loc_tensor = other.save_restore_loc_tensor; queue_runners = other.queue_runners; + allowed_optimizations = other.allowed_optimizations; graph.Swap(graph_def); } diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h index 939e5fa046..a0748abfe6 100644 --- a/tensorflow/core/grappler/grappler_item.h +++ b/tensorflow/core/grappler/grappler_item.h @@ -77,6 +77,15 @@ struct GrapplerItem { // Return a set of node names that must be preserved. This includes feed and // fetch nodes, keep_ops, init_ops. std::unordered_set<string> NodesToPreserve() const; + + // Restrict types of optimizations that are allowed for this GrapplerItem. + struct AllowedOptimizations { + // Is it allowed to add nodes to the graph that do not have registered + // gradient function. + bool non_differentiable_rewrites = true; + }; + + AllowedOptimizations allowed_optimizations; }; // Return the transitive fanin of a set of terminal nodes. 
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc index 029515ad3c..369046666d 100644 --- a/tensorflow/core/grappler/grappler_item_builder.cc +++ b/tensorflow/core/grappler/grappler_item_builder.cc @@ -192,9 +192,13 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef( const string feed_name = NodeName(feed_node); new_item->feed.emplace_back(feed_name, Tensor()); } + for (const auto& fetch_node : cfg.fetch_nodes) { + new_item->fetch.emplace_back(NodeName(fetch_node)); + } - // Attempt to detect the fetch node(s). - if (meta_graph.collection_def().count("train_op") > 0) { + // Attempt to detect the fetch node(s) if they were not set explicitly. + if (new_item->fetch.empty() && + meta_graph.collection_def().count("train_op") > 0) { const CollectionDef& nodes = meta_graph.collection_def().at("train_op"); if (nodes.has_node_list()) { for (const auto& node : nodes.node_list().value()) { diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h index aafd2fdcda..1698587f8c 100644 --- a/tensorflow/core/grappler/grappler_item_builder.h +++ b/tensorflow/core/grappler/grappler_item_builder.h @@ -49,6 +49,8 @@ struct ItemConfig { bool prune_graph = false; // Override feed nodes list. std::set<string> feed_nodes; + // Override fetch nodes list. + std::set<string> fetch_nodes; }; // Factory method for creating a GrapplerItem from a MetaGraphDef. 
diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc index 4b90bf3038..d00981f174 100644 --- a/tensorflow/core/grappler/grappler_item_builder_test.cc +++ b/tensorflow/core/grappler/grappler_item_builder_test.cc @@ -313,6 +313,29 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithUnknownDimInSignatureInput) { EXPECT_EQ(item2->feed[0].second.NumElements(), 1); } +TEST_F(GrapplerItemBuilderTest, ExplicitFeedAndFetch) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + auto x = ops::Const(s.WithOpName("x"), 0); + auto y = ops::Const(s.WithOpName("y"), 1); + auto z = ops::Add(s.WithOpName("z"), x, y); + + MetaGraphDef meta_graph; + TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def())); + + ItemConfig config; + config.feed_nodes.insert("x"); + config.fetch_nodes.insert("z"); + + std::unique_ptr<GrapplerItem> item = + GrapplerItemFromMetaGraphDef("0", meta_graph, config); + ASSERT_TRUE(item != nullptr); + + EXPECT_EQ(item->feed.size(), 1); + EXPECT_EQ(item->fetch.size(), 1); + EXPECT_EQ(item->feed[0].first, "x"); + EXPECT_EQ(item->fetch[0], "z"); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 3521669b63..cbf5c8e038 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -13,14 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include <unordered_set> - +#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/types.h" -#include "tensorflow/core/grappler/op_types.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/logging.h" @@ -102,6 +101,22 @@ bool IsConjugateTranspose(const NodeDef& node) { return node.op() == "ConjugateTranspose"; } +bool IsControlFlow(const NodeDef& node) { + // TODO(williamchan): Add a microbenchmark to compare FlatSet vs. iterative + // string comparison. + static const gtl::FlatSet<string>* const kControFlowOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{ + "ControlTrigger", + "Enter", + "Exit", + "LoopCond", + "Merge", + "NextIteration", + "Switch", + })); + return kControFlowOps->count(node.op()) > 0; +} + bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; } bool IsConv2DBackpropFilter(const NodeDef& node) { @@ -140,26 +155,26 @@ bool IsDiv(const NodeDef& node) { return node.op() == "Div"; } // e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing, // e.g. inv. 
bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) { - static const std::unordered_set<string>* monotonic_non_decreasing_ops = - CHECK_NOTNULL((new std::unordered_set<string>{ + static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{ "Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1", "Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint", "Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh", })); - static const std::unordered_set<string>* monotonic_non_increasing_ops = - CHECK_NOTNULL((new std::unordered_set<string>{ + static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{ "Inv", "Reciprocal", "Erfc", "Rsqrt", "Neg", })); - if (monotonic_non_decreasing_ops->count(node.op()) > 0) { + if (kMonotonicNonDecreasingOps->count(node.op()) > 0) { if (is_non_decreasing) { *is_non_decreasing = true; } return true; - } else if (monotonic_non_increasing_ops->count(node.op()) > 0) { + } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) { if (is_non_decreasing) { *is_non_decreasing = false; } @@ -425,8 +440,44 @@ bool IsSwitch(const NodeDef& node) { return op == "Switch" || op == "RefSwitch"; } +bool IsSymbolicGradient(const NodeDef& node) { + return node.op() == "SymbolicGradient"; +} + bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; } +bool IsTensorArray(const NodeDef& node) { + static const gtl::FlatSet<string>* const kTensorArrayOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{ + "TensorArray", + "TensorArrayV2", + "TensorArrayV3", + "TensorArrayGrad", + "TensorArrayGradV2", + "TensorArrayGradV3", + "TensorArrayGradWithShape", + "TensorArrayWrite", + "TensorArrayWriteV2", + "TensorArrayWriteV3", + "TensorArrayRead", + "TensorArrayReadV2", + "TensorArrayReadV3", + "TensorArrayConcat", + "TensorArrayConcatV2", + "TensorArrayConcatV3", + "TensorArraySplit", + "TensorArraySplitV2", + "TensorArraySplitV3", + 
"TensorArraySize", + "TensorArraySizeV2", + "TensorArraySizeV3", + "TensorArrayClose", + "TensorArrayCloseV2", + "TensorArrayCloseV3", + })); + return kTensorArrayOps->count(node.op()) > 0; +} + bool IsTile(const NodeDef& node) { return node.op() == "Tile"; } bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; } @@ -538,30 +589,29 @@ OPDEF_PROPERTY_HELPER(Aggregate, aggregate) OPDEF_PROPERTY_HELPER(Commutative, commutative) bool IsInvolution(const NodeDef& node) { - static const std::unordered_set<string>* involution_ops = - CHECK_NOTNULL((new std::unordered_set<string>{ - "Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"})); - return involution_ops->count(node.op()) > 0; + static const gtl::FlatSet<string>* const kInvolutionOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert", + "Neg", "LogicalNot"})); + return kInvolutionOps->count(node.op()) > 0; } bool IsValueAndOrderAndShapePreserving(const NodeDef& node) { if (NumNonControlInputs(node) == 1 && IsAggregate(node)) { return true; } - static const std::unordered_set<string>* - value_and_order_and_shape_preserving_ops = - CHECK_NOTNULL((new const std::unordered_set<string>{ - "CheckNumerics", - "DebugGradientIdentity", - "DeepCopy" - "Enter", - "Exit", - "PreventGradient", - "Print", - "Snapshot", - "StopGradient", - })); - return value_and_order_and_shape_preserving_ops->count(node.op()) > 0 || + static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps = + CHECK_NOTNULL((new const gtl::FlatSet<string>{ + "CheckNumerics", + "DebugGradientIdentity", + "DeepCopy", + "Enter", + "Exit", + "PreventGradient", + "Print", + "Snapshot", + "StopGradient", + })); + return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 || IsIdentity(node); } @@ -569,31 +619,31 @@ bool IsValueAndOrderPreserving(const NodeDef& node) { if (NumNonControlInputs(node) == 1 && IsAggregate(node)) { return true; } - static const std::unordered_set<string>* 
value_and_order_preserving_ops = - CHECK_NOTNULL((new const std::unordered_set<string>{ + static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps = + CHECK_NOTNULL((new const gtl::FlatSet<string>{ "ExpandDims", "Reshape", "Squeeze", })); - return value_and_order_preserving_ops->count(node.op()) > 0 || + return kValueAndOrderPreservingOps->count(node.op()) > 0 || IsValueAndOrderAndShapePreserving(node); } bool IsValuePreserving(const NodeDef& node) { - static const std::unordered_set<string>* value_preserving_ops = - CHECK_NOTNULL((new std::unordered_set<string>{ + static const gtl::FlatSet<string>* const kValuePreservingOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{ "InvertPermutation", "Reverse", "Roll", "Transpose", })); return IsValueAndOrderPreserving(node) || - value_preserving_ops->count(node.op()) > 0; + kValuePreservingOps->count(node.op()) > 0; } bool IsUnaryElementWise(const NodeDef& node) { - static const std::unordered_set<string>* element_wise_ops = - CHECK_NOTNULL((new std::unordered_set<string>{ + static const gtl::FlatSet<string>* const kElementWiseOps = + CHECK_NOTNULL((new gtl::FlatSet<string>{ "Abs", "Acos", "Acosh", @@ -642,7 +692,7 @@ bool IsUnaryElementWise(const NodeDef& node) { "Tan" "Tanh", })); - return element_wise_ops->count(node.op()) > 0 || + return kElementWiseOps->count(node.op()) > 0 || IsValueAndOrderAndShapePreserving(node); } diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 25ab6b65ac..d4e0159e81 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -46,6 +46,7 @@ bool IsConjugateTranspose(const NodeDef& node); bool IsConcat(const NodeDef& node); bool IsConcatOffset(const NodeDef& node); bool IsConstant(const NodeDef& node); +bool IsControlFlow(const NodeDef& node); bool IsConv2D(const NodeDef& node); bool IsConv2DBackpropFilter(const NodeDef& node); bool IsConv2DBackpropInput(const NodeDef& node); @@ -149,7 +150,9 @@ bool 
IsStridedSliceGrad(const NodeDef& node); bool IsSub(const NodeDef& node); bool IsSum(const NodeDef& node); bool IsSwitch(const NodeDef& node); +bool IsSymbolicGradient(const NodeDef& node); bool IsTanhGrad(const NodeDef& node); +bool IsTensorArray(const NodeDef& node); bool IsTile(const NodeDef& node); bool IsTranspose(const NodeDef& node); bool IsTruncateDiv(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD index 960d1addb3..c708f84948 100644 --- a/tensorflow/core/grappler/optimizers/BUILD +++ b/tensorflow/core/grappler/optimizers/BUILD @@ -525,6 +525,7 @@ cc_library( "//tensorflow/core:core_cpu_base", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/utils:colocation", @@ -541,6 +542,7 @@ tf_cuda_cc_test( ":custom_graph_optimizer_registry", ":meta_optimizer", "//tensorflow/cc:cc_ops", + "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", "//tensorflow/core:test", diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 75ed12635e..7d5014ee0a 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -276,7 +276,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> { for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) { for (int i = 0; i < output->input_size(); ++i) { auto input = output->input(i); - string name = ParseNodeName(input, &position); + StringPiece name = ParseNodeNameAsStringPiece(input, &position); if (name == node.name() && /*control input*/ position < 0) { return true; } @@ -1568,7 +1568,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { for (NodeDef* output : 
outputs) { if (IsControlInput(output->input(0))) continue; int port; - const string node_name = ParseNodeName(output->input(0), &port); + const StringPiece node_name = + ParseNodeNameAsStringPiece(output->input(0), &port); if (node_name == node.name()) { tails->insert(ChainLink(output, port)); } else { @@ -1618,7 +1619,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { } else { for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) { int port; - const string node_name = ParseNodeName(new_tail->input(0), &port); + const StringPiece node_name = + ParseNodeNameAsStringPiece(new_tail->input(0), &port); if (node_name != tail->name()) { return Status::OK(); } @@ -2929,8 +2931,8 @@ uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const { for (const auto& input : node.input()) { int pos; - string node_name = ParseNodeName(input, &pos); - h = Hash64CombineUnordered(Hash64(node_name), h); + const StringPiece node_name = ParseNodeNameAsStringPiece(input, &pos); + h = Hash64CombineUnordered(Hash64(node_name.data(), node_name.size()), h); h = Hash64CombineUnordered(std::hash<int>()(pos), h); } for (const auto& attr : node.attr()) { @@ -3247,6 +3249,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, optimized_graph_ = &optimized_item.graph; node_map_.reset(new NodeMap(optimized_graph_)); + // Disable restricted graph rewrites. + options_.unary_ops_composition &= + item.allowed_optimizations.non_differentiable_rewrites; + if (options_.dedup_computations) { DedupComputations(); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index ca5d3a6dfd..3d0d95bba7 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -616,28 +616,37 @@ Status ConstantFolding::MaterializeReductionIndices( // We can't do anything if we don't know the rank of the input. 
return Status::OK(); } - const int rank = input_prop.shape().dim_size(); - if (rank == 0) { + const int input_rank = input_prop.shape().dim_size(); + if (input_rank < 1) { // Unexpected graph, don't try to change it. return Status::OK(); } + const OpInfo::TensorProperties& reduction_indices_prop = input_props[1]; + DataType dtype = reduction_indices_prop.dtype(); + if (dtype != DT_INT32 && dtype != DT_INT64) { + return Status::OK(); + } + PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape()); + const int num_reduction_indices = reduction_indices_shape.num_elements(); + const std::vector<OpInfo::TensorProperties>& output_props = properties.GetOutputProperties(node->name()); if (output_props.size() != 1) { return Status::OK(); } - const bool keep_dims = - node->attr().count("keep_dims") && node->attr().at("keep_dims").b(); const OpInfo::TensorProperties& output_prop = output_props[0]; - PartialTensorShape output_shape(output_prop.shape()); - if (output_shape.num_elements() != 1) { - bool full_reduction = false; + const int output_rank = + output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size(); + + bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank; + if (!full_reduction) { + // A full reduction will generate a tensor of one of the shapes + // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of + // elements in the output of the reduction, we may deduce it from reshape + // nodes following it. for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) { - if (!IsReshape(*fanout) && !keep_dims) { - // Depending on how it's setup, a full reduction will generate a tensor - // of shape [], [1], [1, 1], [1, 1, ...]. If keep_dims isn't true, we - // rely on the existence of a reshape node following the reduction to - // ensure that the fanout is fed a scalar of the right shape. 
+ full_reduction = false; + if (!IsReshape(*fanout)) { return Status::OK(); } const std::vector<OpInfo::TensorProperties>& reshape_props = @@ -658,20 +667,15 @@ Status ConstantFolding::MaterializeReductionIndices( } } - const OpInfo::TensorProperties& reduction_prop = input_props[1]; - DataType dtype = reduction_prop.dtype(); - if (dtype != DT_INT32 && dtype != DT_INT64) { - return Status::OK(); - } - // We know it's a full reduction. We can generate the set of indices to - // reduce. + // We know it's a full reduction. We can generate the full set of indices to + // reduce as a constant node. string const_name = OptimizedNodeName(*node, "-reduction_indices"); if (node_map_->GetNode(const_name)) { return Status::OK(); } NodeDef* reduction_indices = graph_->add_node(); - Tensor value(dtype, TensorShape({rank})); - for (int i = 0; i < rank; ++i) { + Tensor value(dtype, TensorShape({input_rank})); + for (int i = 0; i < input_rank; ++i) { if (dtype == DT_INT32) { value.vec<int32>()(i) = i; } else { @@ -680,6 +684,7 @@ Status ConstantFolding::MaterializeReductionIndices( } TF_RETURN_IF_ERROR( CreateNodeDef(const_name, TensorValue(&value), reduction_indices)); + reduction_indices->set_device(node->device()); string ctrl_dep = AddControlDependency(node->input(1), graph_, node_map_.get()); diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index b09360a2c2..fab01edfed 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -2591,58 +2591,100 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) { } TEST_F(ConstantFoldingTest, MaterializeReductionIndices) { - tensorflow::Scope s = tensorflow::Scope::NewRootScope(); - Output input = - ops::Placeholder(s.WithOpName("input"), DT_FLOAT, - ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); - Output indices = 
ops::Placeholder(s.WithOpName("indices"), DT_INT32); - Output sum = ops::Sum(s.WithOpName("sum"), input, indices); - Output size = ops::Const(s.WithOpName("size"), 1, {1}); - Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size); + for (bool use_reshape : {true, false}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = + ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape(PartialTensorShape({-1, -1}))); + // If use_reshape is false, we need to know the number of indices to apply + // the rewrite. + Output indices = ops::Placeholder( + s.WithOpName("indices"), DT_INT32, + ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2}))); + Output sum = ops::Sum(s.WithOpName("sum"), input, indices); + if (use_reshape) { + Output size = ops::Const(s.WithOpName("size"), 1, {1}); + Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size); + } - GrapplerItem item; - TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch.push_back("reshape"); + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back(use_reshape ? 
"reshape" : "sum"); - auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4})); - Tensor indices_t(DT_INT32, TensorShape({2})); - indices_t.flat<int>()(0) = 0; - indices_t.flat<int>()(1) = 1; - auto tensors_expected = EvaluateNodes( - item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}}); - EXPECT_EQ(1, tensors_expected.size()); + auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4})); + Tensor indices_t(DT_INT32, TensorShape({2})); + indices_t.flat<int>()(0) = 0; + indices_t.flat<int>()(1) = 1; + auto tensors_expected = EvaluateNodes( + item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}}); + EXPECT_EQ(1, tensors_expected.size()); - ConstantFolding optimizer(nullptr /* cpu_device */); - GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + // Use aggressive mode to force the shape inference to propagate placeholder + // shapes. + ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, + nullptr /* cpu_device */); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); - // Run a second time to make sure the optimization is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + // Run a second time to make sure the optimization is idempotent. 
+ item.graph.Swap(&output); + status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); - int found = 0; - for (const auto& node : output.node()) { - if (node.name() == "ConstantFolding/sum-reduction_indices") { - ++found; - EXPECT_EQ("Const", node.op()); - EXPECT_EQ("^indices", node.input(0)); - EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape()) - .num_elements()); - } else if (node.name() == "sum") { - ++found; - EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1)); - } else if (node.name() == "indices") { - ++found; + int found = 0; + for (const auto& node : output.node()) { + if (node.name() == "ConstantFolding/sum-reduction_indices") { + ++found; + EXPECT_EQ("Const", node.op()); + EXPECT_EQ("^indices", node.input(0)); + EXPECT_EQ(2, + TensorShape(node.attr().at("value").tensor().tensor_shape()) + .num_elements()); + } else if (node.name() == "sum") { + ++found; + EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1)); + } else if (node.name() == "indices") { + ++found; + } } + EXPECT_EQ(3, found); + + auto tensors = EvaluateNodes(output, item.fetch, + {{"input", input_t}, {"indices", indices_t}}); + EXPECT_EQ(1, tensors.size()); + test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5); } - EXPECT_EQ(3, found); +} - auto tensors = EvaluateNodes(output, item.fetch, - {{"input", input_t}, {"indices", indices_t}}); - EXPECT_EQ(1, tensors.size()); - test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5); +TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) { + for (bool input_rank_known : {true, false}) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + Output input = + (input_rank_known ? 
ops::Placeholder(s.WithOpName("input"), DT_FLOAT, + ops::Placeholder::Shape( + PartialTensorShape({-1, -1}))) + : ops::Placeholder(s.WithOpName("input"), DT_FLOAT)); + Output indices = + ops::Placeholder(s.WithOpName("indices"), DT_INT32, + ops::Placeholder::Shape( + PartialTensorShape({input_rank_known ? 1 : 2}))); + Output sum = ops::Sum(s.WithOpName("sum"), input, indices); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + item.fetch.push_back("sum"); + + // Use aggressive mode to force the shape inference to propagate placeholder + // shapes. + ConstantFolding optimizer(RewriterConfig::AGGRESSIVE, + nullptr /* cpu_device */); + GraphDef output; + Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + + CompareGraphs(item.graph, output); + } } TEST_F(ConstantFoldingTest, LargeConstant) { diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD index cf305cebe1..ee7c14e3ab 100644 --- a/tensorflow/core/grappler/optimizers/data/BUILD +++ b/tensorflow/core/grappler/optimizers/data/BUILD @@ -22,6 +22,7 @@ cc_library( "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -31,6 +32,7 @@ tf_cc_test( visibility = ["//visibility:public"], deps = [ ":filter_fusion", + ":graph_test_utils", ":graph_utils", "//tensorflow/core:framework", "//tensorflow/core:test", @@ -87,11 +89,12 @@ cc_library( ], visibility = ["//visibility:public"], deps = [ + ":graph_utils", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/grappler:mutable_graph_view", - "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -121,11 +124,12 @@ cc_library( ], visibility = ["//visibility:public"], deps = 
[ + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/grappler:mutable_graph_view", - "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler:utils", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -135,6 +139,7 @@ tf_cc_test( visibility = ["//visibility:public"], deps = [ ":graph_utils", + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -146,6 +151,62 @@ tf_cc_test( ) cc_library( + name = "graph_test_utils", + testonly = 1, + srcs = ["graph_test_utils.cc"], + hdrs = [ + "graph_test_utils.h", + ], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core:testlib", + ] + tf_protos_all(), +) + +cc_library( + name = "hoist_random_uniform", + srcs = ["hoist_random_uniform.cc"], + hdrs = [ + "hoist_random_uniform.h", + ], + visibility = ["//visibility:public"], + deps = [ + ":function_utils", + ":graph_utils", + "//tensorflow/core:lib", + "//tensorflow/core/grappler:mutable_graph_view", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:op_types", + "//tensorflow/core/grappler:utils", + "//tensorflow/core/grappler/clusters:cluster", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", + "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core:lib_internal", + ] + tf_protos_all(), +) + +tf_cc_test( + name = "hoist_random_uniform_test", + srcs = ["hoist_random_uniform_test.cc"], + visibility = ["//visibility:public"], + deps = [ + ":graph_test_utils", + ":graph_utils", + ":hoist_random_uniform", + "//tensorflow/core:framework", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + 
"//tensorflow/core/grappler:grappler_item", + ] + tf_protos_all(), +) + +cc_library( name = "latency_all_edges", srcs = ["latency_all_edges.cc"], hdrs = [ @@ -256,7 +317,7 @@ cc_library( "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", - "//tensorflow/core:ptr_util", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -265,6 +326,7 @@ tf_cc_test( srcs = ["map_and_filter_fusion_test.cc"], visibility = ["//visibility:public"], deps = [ + ":graph_test_utils", ":graph_utils", ":map_and_filter_fusion", "//tensorflow/core:framework", @@ -294,6 +356,7 @@ cc_library( "//tensorflow/core/grappler/utils:topological_sort", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + "//tensorflow/core:lib_internal", ] + tf_protos_all(), ) @@ -302,6 +365,7 @@ tf_cc_test( srcs = ["map_fusion_test.cc"], visibility = ["//visibility:public"], deps = [ + ":graph_test_utils", ":graph_utils", ":map_fusion", "//tensorflow/core:framework", @@ -339,6 +403,7 @@ tf_cc_test( srcs = ["map_parallelization_test.cc"], visibility = ["//visibility:public"], deps = [ + ":graph_test_utils", ":graph_utils", ":map_parallelization", "//tensorflow/core:framework", @@ -422,6 +487,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":filter_fusion", + ":hoist_random_uniform", ":latency_all_edges", ":map_and_batch_fusion", ":map_and_filter_fusion", @@ -458,7 +524,9 @@ cc_library( deps = [ ":function_utils", ":graph_utils", + "//tensorflow/cc:ops", "@com_google_absl//absl/strings", + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -474,6 +542,7 @@ tf_cc_test( srcs = ["vectorization_utils_test.cc"], visibility = ["//visibility:public"], deps = [ + ":graph_utils", ":function_utils", 
":vectorization_utils", "//tensorflow/core:framework", @@ -483,7 +552,10 @@ tf_cc_test( "//tensorflow/core/grappler:grappler_item", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", + # For ops we need registered + "//tensorflow/core/kernels/data:dataset_ops", "//tensorflow/core/kernels:cast_op", + "//tensorflow/core/kernels:logging_ops", "//tensorflow/tools/graph_transforms:transform_utils", ] + tf_protos_all(), ) diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc index c71aa6e804..1ad495bbad 100644 --- a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc @@ -43,19 +43,14 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node, fused_node.set_op("FilterDataset"); fused_node.add_input(first_filter_node.input(0)); - auto copy_attribute = [](const string& attribute_name, const NodeDef& from, - NodeDef* to) { - (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name); - }; - auto attr = first_filter_node.attr().at("predicate"); *attr.mutable_func()->mutable_name() = fused_function.signature().name(); (*fused_node.mutable_attr())["predicate"] = std::move(attr); - copy_attribute("Targuments", first_filter_node, &fused_node); + graph_utils::CopyAttribute("Targuments", first_filter_node, &fused_node); for (auto key : {"output_shapes", "output_types"}) - copy_attribute(key, second_filter_node, &fused_node); + graph_utils::CopyAttribute(key, second_filter_node, &fused_node); return fused_node; } @@ -120,8 +115,8 @@ Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item, // functions, or make sure that optimization passes run after filter // fusion. 
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_predicate)); - // TODO(prazek): we could also remove map functions from library if they - // are not used anymore. + // TODO(b/116285210): we could also remove map functions from library if + // they are not used anymore. nodes_to_delete.insert(first_filter_node->name()); nodes_to_delete.insert(second_filter_node->name()); } diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc index 12b1924efd..c8becc5cc0 100644 --- a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" - #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -28,14 +28,7 @@ namespace tensorflow { namespace grappler { namespace { -NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) { - return test::function::NDef( - name, "FilterDataset", {string(input_node_name)}, - {{"predicate", FunctionDefHelper::FunctionRef("IsZero")}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} +using graph_tests_utils::MakeFilterNode; TEST(FilterFusionTest, FuseTwoFilterIntoOne) { using test::function::NDef; diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc index e95ea1a4c1..311df15bc2 100644 --- a/tensorflow/core/grappler/optimizers/data/function_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc @@ -14,31 +14,16 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/strings/scanner.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace grappler { namespace function_utils { -namespace { - -template <typename Predicate, typename Collection> -std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate, - const Collection& collection) { - std::vector<int> indices = {}; - unsigned idx = 0; - for (auto&& element : collection) { - if (predicate(element)) { - indices.push_back(idx); - } - idx++; - } - return indices; -} - -} // namespace FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name, const string& output, int position) @@ -152,32 +137,27 @@ bool ContainsFunctionOutputWithName(StringPiece name, } int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return graph_utils::GetFirstElementIndexWithPredicate( [&name](const OpDef_ArgDef& arg) { return arg.name() == name; }, function.signature().input_arg()); - return indices.empty() ? -1 : indices.front(); } int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return graph_utils::GetFirstElementIndexWithPredicate( [&name](const OpDef_ArgDef& arg) { return arg.name() == name; }, function.signature().output_arg()); - return indices.empty() ? 
-1 : indices.front(); } int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return graph_utils::GetFirstElementIndexWithPredicate( [&name](const NodeDef& node) { return node.name() == name; }, function.node_def()); - return indices.empty() ? -1 : indices.front(); } int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return graph_utils::GetFirstElementIndexWithPredicate( [&op](const NodeDef& node) { return node.op() == op; }, function.node_def()); - - return indices.empty() ? -1 : indices.front(); } void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function, diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc new file mode 100644 index 0000000000..b2eec7220e --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc @@ -0,0 +1,49 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" + +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" + +namespace tensorflow { +namespace grappler { +namespace graph_tests_utils { + +NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name) { + return test::function::NDef( + name, "MapDataset", {string(input_node_name)}, + {{"f", FunctionDefHelper::FunctionRef(string(function_name))}, + {"Targuments", {}}, + {"output_shapes", gtl::ArraySlice<TensorShape>{}}, + {"output_types", gtl::ArraySlice<DataType>{}}}); +} + +NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name) { + return test::function::NDef( + name, "FilterDataset", {string(input_node_name)}, + {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))}, + {"Targuments", {}}, + {"output_shapes", gtl::ArraySlice<TensorShape>{}}, + {"output_types", gtl::ArraySlice<DataType>{}}}); +} + +} // end namespace graph_tests_utils +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h new file mode 100644 index 0000000000..ca0fde997d --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h @@ -0,0 +1,36 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_ + +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" + +namespace tensorflow { +namespace grappler { +namespace graph_tests_utils { + +NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name = "XTimesTwo"); + +NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name = "IsZero"); + +} // end namespace graph_tests_utils +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_ diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc index 2dd9ee822e..b863a25dc5 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -201,25 +202,22 @@ bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) { int FindGraphFunctionWithName(StringPiece name, const FunctionDefLibrary& library) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return GetFirstElementIndexWithPredicate( [&name](const FunctionDef& function) { return function.signature().name() == name; }, library.function()); - return indices.empty() ? -1 : indices.front(); } int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return GetFirstElementIndexWithPredicate( [&name](const NodeDef& node) { return node.name() == name; }, graph.node()); - return indices.empty() ? -1 : indices.front(); } int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) { - std::vector<int> indices = GetElementIndicesWithPredicate( + return GetFirstElementIndexWithPredicate( [&op](const NodeDef& node) { return node.op() == op; }, graph.node()); - return indices.empty() ? 
-1 : indices.front(); } std::vector<int> FindAllGraphNodesWithOp(const string& op, @@ -260,6 +258,41 @@ void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library, } function->mutable_signature()->set_name(std::move(name)); } + +void CopyAttribute(const string& attribute_name, const NodeDef& from, + NodeDef* to_node) { + (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name); +} + +void ConcatAttributeList(const string& attribute_name, const NodeDef& first, + const NodeDef& second, NodeDef* to_node) { + CopyAttribute(attribute_name, first, to_node); + (*to_node->mutable_attr()) + .at(attribute_name) + .mutable_list() + ->MergeFrom(second.attr().at(attribute_name).list()); +} + +Status EnsureNodeNamesUnique(Graph* g) { + // Modeled after Scope::Impl::GetUniqueName + std::unordered_map<string, int> name_map; + + for (auto node : g->op_nodes()) { + const string& prefix = node->name(); + if (auto entry = gtl::FindOrNull(name_map, prefix)) { + string unique_name; + do { + unique_name = strings::StrCat(prefix, "_", ++(*entry)); + } while (name_map.find(unique_name) != name_map.end()); + name_map.insert({unique_name, 0}); + node->set_name(std::move(unique_name)); + } else { + name_map.insert({node->name(), 0}); + } + } + + return Status::OK(); +} } // end namespace graph_utils } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h index b117482db2..d130fee204 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils.h +++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h @@ -23,6 +23,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/lib/core/errors.h" @@ -31,6 +32,21 @@ namespace tensorflow { namespace grappler { namespace graph_utils { +// Returns the index of the first element in collection that fulfills predicate. +// If no such element exists, returns -1. +template <typename Predicate, typename Collection> +int GetFirstElementIndexWithPredicate(const Predicate& predicate, + const Collection& collection) { + unsigned idx = 0; + for (auto&& element : collection) { + if (predicate(element)) { + return idx; + } + idx++; + } + return -1; +} + // Adds a node to the graph. NodeDef* AddNode(StringPiece name, StringPiece op, const std::vector<string>& inputs, @@ -101,11 +117,29 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op, // is unique across the graph. void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node); -// Sets the node name using the `prefix` name as a prefix while guaranteeing the -// name is unique across the graph. +// Sets the function name using the `prefix` name as a prefix while guaranteeing +// the name is unique across the function library. void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library, FunctionDef* function); +// Copies attribute having name `attribute_name` from node `from` to node +// `to_node`. +void CopyAttribute(const string& attribute_name, const NodeDef& from, + NodeDef* to_node); + +// Concatenates list attribute having name `attribute_name` from `first` and +// `second` node, setting it to `to_node`. 
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first, + const NodeDef& second, NodeDef* to_node); + +// Checks that all nodes in the graphs have unique names, and sets their names +// to be unique if they are not already. This is necessary as Graph does not +// have the provisions to deduplicate names, and name deduplication elsewhere +// in tensorflow happens in other layers (for example, in the Scope class of the +// C++ API). Note that the nodes in the graph are identified by their id, +// and renaming nodes does not mutate any edges. +Status EnsureNodeNamesUnique(Graph* g); + } // end namespace graph_utils } // end namespace grappler } // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc index 6877c207c4..4ab6d71532 100644 --- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -24,6 +25,18 @@ namespace grappler { namespace graph_utils { namespace { +TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) { + std::vector<int> vec({1, 2, 3, 4, 5, 6}); + auto result = GetFirstElementIndexWithPredicate( + [](int elem) { return elem % 3 == 0; }, vec); + + EXPECT_EQ(result, 2); + + result = GetFirstElementIndexWithPredicate( + [](int elem) { return elem % 7 == 0; }, vec); + EXPECT_EQ(result, -1); +} + TEST(GraphUtilsTest, AddScalarConstNodeBool) { GraphDef graph_def; MutableGraphView graph(&graph_def); @@ -217,6 +230,33 @@ TEST(GraphUtilsTest, GetInputNode) { EXPECT_EQ(GetInputNode(*node1, graph), nullptr); } +TEST(GraphUtilsTest, EnsureNodeNamesUnique) { + Graph g(OpRegistry::Global()); + + Node *const_0, *const_1, *const_2; + + // Arbitrary const + Tensor tensor(DT_INT32, {}); + tensor.scalar<int32>()() = 5; + + for (auto node : {&const_0, &const_1}) { + TF_EXPECT_OK(NodeBuilder("Const", "Const") + .Attr("value", tensor) + .Attr("dtype", DT_INT32) + .Finalize(&g, node)); + } + // Make sure generated name doesn't clash with existing name either + TF_EXPECT_OK(NodeBuilder("Const_1", "Const") + .Attr("value", tensor) + .Attr("dtype", DT_INT32) + .Finalize(&g, &const_2)); + + TF_EXPECT_OK(EnsureNodeNamesUnique(&g)); + EXPECT_NE(const_0->name(), const_1->name()); + EXPECT_NE(const_1->name(), const_2->name()); + EXPECT_NE(const_0->name(), const_2->name()); +} + } // namespace } // namespace graph_utils } // namespace grappler diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc new file mode 100644 index 0000000000..ce0b2db039 --- /dev/null +++ 
b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc @@ -0,0 +1,289 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h" + +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/grappler/clusters/cluster.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/mutable_graph_view.h" +#include "tensorflow/core/grappler/op_types.h" +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h" +#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/grappler/utils.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/protobuf.h" + +namespace tensorflow { +namespace grappler { +namespace { + +NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node, + const FunctionDef& stateless_function, + MutableGraphView* graph) { + NodeDef stateless_map; + graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(), + &stateless_map); 
+ + stateless_map.set_op("MapDataset"); + stateless_map.add_input(zip_node.name()); + // Add placeholders. + for (int i = 1; i < map_node.input_size(); i++) + stateless_map.add_input(map_node.input(i)); + + auto attr = map_node.attr().at("f"); + *attr.mutable_func()->mutable_name() = stateless_function.signature().name(); + *attr.mutable_func()->mutable_attr() = stateless_function.attr(); + (*stateless_map.mutable_attr())["f"] = std::move(attr); + + graph_utils::CopyAttribute("Targuments", map_node, &stateless_map); + for (auto key : {"output_shapes", "output_types"}) + graph_utils::CopyAttribute(key, map_node, &stateless_map); + + if (const auto* attr = + gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism")) + (*stateless_map.mutable_attr())["use_inter_op_parallelism"] = *attr; + + return stateless_map; +} + +NodeDef MakeRandomDataset(const NodeDef& random_uniform_node, + MutableGraphView* graph) { + NodeDef random_dataset; + random_dataset.set_op("RandomDataset"); + graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(), + &random_dataset); + + const auto* seed = graph_utils::AddScalarConstNode<int64>( + random_uniform_node.attr().at("seed").i(), graph); + const auto* seed2 = graph_utils::AddScalarConstNode<int64>( + random_uniform_node.attr().at("seed2").i(), graph); + + random_dataset.add_input(seed->name()); + random_dataset.add_input(seed2->name()); + + (*random_dataset.mutable_attr())["output_shapes"].mutable_list()->add_shape(); + (*random_dataset.mutable_attr())["output_types"].mutable_list()->add_type( + DT_INT64); + + return random_dataset; +} + +NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) { + NodeDef batch_dataset; + batch_dataset.set_op("BatchDatasetV2"); + graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(), + &batch_dataset); + const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph); + const auto* drop_reminder = graph_utils::AddScalarConstNode(false, 
graph); + batch_dataset.add_input(random_dataset.name()); + batch_dataset.add_input(batch_size->name()); + batch_dataset.add_input(drop_reminder->name()); + + (*batch_dataset.mutable_attr())["output_shapes"] + .mutable_list() + ->add_shape() + ->mutable_dim() + ->Add() + ->set_size(-1); + (*batch_dataset.mutable_attr())["output_types"].mutable_list()->add_type( + DT_INT64); + + return batch_dataset; +} + +NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node, + MutableGraphView* graph) { + NodeDef zip_node; + graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(), + &zip_node); + + zip_node.set_op("ZipDataset"); + zip_node.add_input(first_node.name()); + zip_node.add_input(second_node.name()); + + for (auto key : {"output_shapes", "output_types"}) + graph_utils::ConcatAttributeList(key, first_node, second_node, &zip_node); + + (*zip_node.mutable_attr())["N"].set_i(2); + + return zip_node; +} + +// We need to insert our argument before the placeholders, which are the last +// arguments. +OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) { + int new_argument_idx = signature->input_arg_size() - num_placeholders; + signature->add_input_arg(); + for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) { + signature->mutable_input_arg()->SwapElements(i - 1, i); + } + auto* seed_arg = signature->mutable_input_arg(new_argument_idx); + seed_arg->set_name(strings::StrCat("seed_arg", new_argument_idx)); + seed_arg->set_type(DT_INT64); + + return seed_arg; +} + +// Make function that uses `StatelessRandomUniform` instead of `RandomUniform` +// to make it less statefull. The function can still be stateful, but in when +// other stateful ops are e.g. `Assert`, then it will be parallelizable. 
+const FunctionDef* MakeLessStatefulFunction(const FunctionDef& map_function, + bool is_stateful, + int num_placeholders, + FunctionDefLibrary* library) { + FunctionDef* stateless_function = library->add_function(); + *stateless_function = map_function; + if (is_stateful) + stateless_function->mutable_signature()->set_is_stateful(is_stateful); + graph_utils::SetUniqueGraphFunctionName("stateless_function", library, + stateless_function); + + auto* seed_arg = InsertSeedArgument(stateless_function->mutable_signature(), + num_placeholders); + + auto* const random_uniform = stateless_function->mutable_node_def( + function_utils::FindFunctionNodeWithOp("RandomUniform", + *stateless_function)); + + // Replace RandomUniform node with StatelessRandomUniform. + random_uniform->set_op("StatelessRandomUniform"); + random_uniform->add_input(seed_arg->name()); + (*random_uniform->mutable_attr())["Tseed"].set_type(DT_INT64); + random_uniform->mutable_attr()->erase("seed"); + random_uniform->mutable_attr()->erase("seed2"); + + return stateless_function; +} +// This function returns true if function is stateful and has single +// RandomUniform op and no other stateful ops except Assert. +// `is_stateful_after_hoisting` is set to true if RandomUniform is the only +// stateful op and hoisting can be performed. +bool CanHoistRandomUniform(const FunctionDef& map_function, + const FunctionLibraryDefinition& library, + bool* is_stateful_after_hoisting, + const NodeDef** random_uniform_op) { + if (!map_function.signature().is_stateful()) return false; + *is_stateful_after_hoisting = true; + + bool have_other_stateful_ops = false; + + for (const auto& node : map_function.node_def()) { + const OpDef* op_def; + TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def)); + // Skip stateless nodes and assert, as it does not actually have a state. 
+ if (!op_def->is_stateful()) continue; + + if (op_def->name() == "Assert") { + have_other_stateful_ops = true; + continue; + } + + // TODO(prazek): For now we only handle RandomUniform, we should handle + // RandomUniformInt as well. + if (op_def->name() != "RandomUniform") return false; + + // TODO(prazek): For now we can only hoist single RandomUniform. + if (*random_uniform_op != nullptr) return false; + + *random_uniform_op = &node; + } + + if (!have_other_stateful_ops) *is_stateful_after_hoisting = false; + + // Have we found single RandomUniform? + return *random_uniform_op != nullptr; +} + +int NumberOfPlaceholders(const NodeDef& map_node) { + // First input of MapDataset is the argument to the function. Rest of the + // inputs are placeholders. + return map_node.input_size() - 1; +} + +} // namespace + +Status HoistRandomUniform::Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) { + *output = item.graph; + + MutableGraphView graph(output); + std::set<string> nodes_to_delete; + FunctionLibraryDefinition function_library(OpRegistry::Global(), + item.graph.library()); + + auto get_map_node = [](const NodeDef& node) -> const NodeDef* { + // TODO(prazek): we could also handle ParallelMapDataset and + // MapAndBatchDataset. 
+ if (node.op() == "MapDataset") return &node; + return nullptr; + }; + + for (const NodeDef& node : item.graph.node()) { + const NodeDef* map_node = get_map_node(node); + if (!map_node) continue; + + const auto& fun = map_node->attr().at("f"); + const FunctionDef* func = function_library.Find(fun.func().name()); + + const NodeDef* random_uniform_op = nullptr; + bool is_stateful_after_hoisting = true; + if (!CanHoistRandomUniform(*func, function_library, + &is_stateful_after_hoisting, &random_uniform_op)) + continue; + const auto* random_seed_dataset = + graph.AddNode(MakeRandomDataset(*random_uniform_op, &graph)); + + const auto* batch_dataset = + graph.AddNode(MakeBatchTwo(*random_seed_dataset, &graph)); + + const NodeDef& parent_node = *graph_utils::GetInputNode(*map_node, graph); + + const auto* zip_node = + graph.AddNode(MakeZipNode(parent_node, *batch_dataset, &graph)); + + const auto* stateless_func = MakeLessStatefulFunction( + *func, is_stateful_after_hoisting, NumberOfPlaceholders(*map_node), + output->mutable_library()); + + const auto* stateless_map = graph.AddNode( + MakeStatelessMap(*map_node, *zip_node, *stateless_func, &graph)); + + graph.ReplaceInput(*map_node, *stateless_map); + + // TODO(b/116285210): we could also remove map functions from library if + // they are not used anymore. 
+ nodes_to_delete.insert(map_node->name()); + } + + graph.DeleteNodes(nodes_to_delete); + return Status::OK(); +} + +void HoistRandomUniform::Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, + double result) { + // no-op +} + +REGISTER_GRAPH_OPTIMIZER_AS(HoistRandomUniform, "hoist_random_uniform"); + +} // end namespace grappler +} // end namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h new file mode 100644 index 0000000000..d1bcf6782d --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h @@ -0,0 +1,55 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_ + +#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h" + +namespace tensorflow { +namespace grappler { + +// This optimization hoists instances of `random_uniform` out of a function +// with the aim of making it stateless. It creates a new function that takes a +// random seed as an extra argument and uses `stateless_random_uniform` instead +// of `random_uniform` to make it stateless. 
+// It also creates RandomDataset(seed).batch(2), which is zipped with old input +// to the map. The batching in RandomDataset is because we need 2 seeds for +// `stateless_random_uniform`. +// TODO(prazek): for now only `RandomUniform` is handled, but we could handle +// `RandomUniformInt` similarly. +class HoistRandomUniform : public CustomGraphOptimizer { + public: + HoistRandomUniform() = default; + ~HoistRandomUniform() override = default; + + string name() const override { return "hoist_random_uniform"; }; + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* output) override; + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimize_output, double result) override; +}; + +} // end namespace grappler +} // end namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_ diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc new file mode 100644 index 0000000000..455459e3f6 --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc @@ -0,0 +1,84 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h" + +#include "tensorflow/core/framework/function_testlib.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace grappler { +namespace { + +TEST(HoistRandomUniform, SimpleHoisting) { + using test::function::NDef; + GrapplerItem item; + item.graph = test::function::GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}), + NDef("range", "RangeDataset", {"start", "stop", "step"}, + {{"output_shapes", gtl::ArraySlice<TensorShape>{}}, + {"output_types", gtl::ArraySlice<DataType>{}}}), + graph_tests_utils::MakeMapNode("map1", "range", "RandomUniform"), + NDef("cache", "CacheDataset", {"map1", "filename"}, {})}, + // FunctionLib + { + test::function::RandomUniform(), + }); + + HoistRandomUniform optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output)); + const int new_map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output); + const int zip_dataset_id = + graph_utils::FindGraphNodeWithOp("ZipDataset", output); + const int random_dataset_id = + graph_utils::FindGraphNodeWithOp("RandomDataset", output); + const int batch_random_id = + graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output); + 
ASSERT_NE(random_dataset_id, -1); + ASSERT_NE(zip_dataset_id, -1); + ASSERT_NE(new_map_id, -1); + ASSERT_NE(batch_random_id, -1); + + const auto& new_map = output.node(new_map_id); + const auto& zip = output.node(zip_dataset_id); + const auto& random = output.node(random_dataset_id); + const auto& batch = output.node(batch_random_id); + + ASSERT_EQ(new_map.input_size(), 1); + EXPECT_EQ(new_map.input(0), zip.name()); + + ASSERT_EQ(zip.input_size(), 2); + EXPECT_EQ(zip.input(0), "range"); + EXPECT_EQ(zip.input(1), batch.name()); + + ASSERT_EQ(batch.input_size(), 3); + EXPECT_EQ(batch.input(0), random.name()); +} + +} // namespace +} // namespace grappler +} // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc index 63945b8b9e..e66766eb23 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc @@ -80,11 +80,12 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node, // Set `f` and `Targuments` attributes. for (auto key : {"f", "Targuments"}) { - (*new_node.mutable_attr())[key] = map_node.attr().at(key); + graph_utils::CopyAttribute(key, map_node, &new_node); } + // Set `output_types` and `output_shapes` attributes. for (auto key : {"output_shapes", "output_types"}) { - (*new_node.mutable_attr())[key] = batch_node.attr().at(key); + graph_utils::CopyAttribute(key, batch_node, &new_node); } return new_node; } diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc index f1844a141c..c4868eacbb 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { @@ -41,19 +42,18 @@ NodeDef MakeFusedNode(const NodeDef& map_node, fused_node.set_op("MapDataset"); fused_node.add_input(map_node.input(0)); - auto copy_attribute = [](const string& attribute_name, const NodeDef& from, - NodeDef* to) { - (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name); - }; - auto attr = map_node.attr().at("f"); attr.mutable_func()->set_name(fused_function.signature().name()); (*fused_node.mutable_attr())["f"] = std::move(attr); - copy_attribute("Targuments", map_node, &fused_node); + graph_utils::CopyAttribute("Targuments", map_node, &fused_node); for (auto key : {"output_shapes", "output_types"}) - copy_attribute(key, map_node, &fused_node); + graph_utils::CopyAttribute(key, map_node, &fused_node); + + if (const auto* attr = + gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism")) + (*fused_node.mutable_attr())["use_inter_op_parallelism"] = *attr; // Add the predicate output attributes. (*fused_node.mutable_attr())["output_types"] diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc index f029a093fa..6e6da37d7c 100644 --- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -27,24 +28,8 @@ limitations under the License. namespace tensorflow { namespace grappler { namespace { - -NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) { - return test::function::NDef( - name, "MapDataset", {string(input_node_name)}, - {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} - -NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) { - return test::function::NDef( - name, "FilterDataset", {string(input_node_name)}, - {{"predicate", FunctionDefHelper::FunctionRef("IsZero")}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} +using graph_tests_utils::MakeFilterNode; +using graph_tests_utils::MakeMapNode; TEST(MapAndFilterFusionTest, FuseMapAndFilter) { using test::function::NDef; diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc index a78ecb09f7..bd943342e8 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/topological_sort.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/protobuf.h" namespace tensorflow { @@ -40,24 +41,31 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node, NodeDef fused_node; graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(), &fused_node); - fused_node.set_op("MapDataset"); fused_node.add_input(parent_map_node.input(0)); - auto copy_attribute = [](const string& attribute_name, const NodeDef& from, - NodeDef* to) { - (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name); - }; - auto attr = parent_map_node.attr().at("f"); *attr.mutable_func()->mutable_name() = fused_function.signature().name(); (*fused_node.mutable_attr())["f"] = std::move(attr); - copy_attribute("Targuments", parent_map_node, &fused_node); - + graph_utils::CopyAttribute("Targuments", parent_map_node, &fused_node); for (auto key : {"output_shapes", "output_types"}) - copy_attribute(key, map_node, &fused_node); + graph_utils::CopyAttribute(key, map_node, &fused_node); + auto value_or_false = [](const AttrValue* attr) { + if (!attr) return false; + return attr->b(); + }; + + const auto* first_parallelism = + gtl::FindOrNull(parent_map_node.attr(), "use_inter_op_parallelism"); + const auto* second_parallelism = + gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"); + // Some graphs cannot execute with use_inter_op_parallelism=False, so we need + // to set it to true if one of the ops have it set to true. + if (value_or_false(first_parallelism) || value_or_false(second_parallelism)) { + (*fused_node.mutable_attr())["use_inter_op_parallelism"].set_b(true); + } return fused_node; } @@ -123,8 +131,8 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item, // fusion. 
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function)); - // TODO(prazek): we could also remove map functions from library if they - // are not used anymore. + // TODO(b/116285210): we could also remove map functions from library if + // they are not used anymore. nodes_to_delete.insert(parent_map_node->name()); nodes_to_delete.insert(map_node->name()); } diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc index b25dfbd0b8..8889f9dddd 100644 --- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc @@ -19,6 +19,7 @@ limitations under the License. #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" @@ -28,14 +29,7 @@ namespace tensorflow { namespace grappler { namespace { -NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) { - return test::function::NDef( - name, "MapDataset", {string(input_node_name)}, - {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} +using graph_tests_utils::MakeMapNode; TEST(MapFusionTest, FuseTwoMapNodesIntoOne) { using test::function::NDef; diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc index 305325e434..782c9f48b7 100644 --- a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc @@ -84,9 +84,6 @@ Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item, auto* parallel_map = 
graph.AddNode(MakeParallelMap(*map_node, &graph)); graph.ReplaceInput(*map_node, *parallel_map); - - // TODO(prazek): we could also remove map functions from library if they - // are not used anymore. nodes_to_delete.insert(map_node->name()); } diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc index b2a5d9b6af..9fdfe8af30 100644 --- a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc @@ -19,8 +19,8 @@ limitations under the License. #include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/grappler/grappler_item.h" +#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" - #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" @@ -28,16 +28,7 @@ namespace tensorflow { namespace grappler { namespace { -NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, - StringPiece function_name) { - return test::function::NDef( - name, "MapDataset", {string(input_node_name)}, - {{"f", FunctionDefHelper::FunctionRef(string(function_name))}, - {"Targuments", {}}, - {"output_shapes", {}}, - {"output_types", {}}}); -} - +using graph_tests_utils::MakeMapNode; const char stateless_fun_name[] = "XTimesTwo"; const char stateful_fun_name[] = "RandomUniform"; diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc index 7a2f1910da..a9254ed58b 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc @@ -35,10 +35,6 @@ namespace tensorflow { namespace grappler { namespace { -void CopyAttribute(const string& attr_name, 
const NodeDef& from, NodeDef* to) { - (*to->mutable_attr())[attr_name] = from.attr().at(attr_name); -} - // Returns a FunctionDef containing a MapDefun op that wraps the original // function. FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, @@ -48,7 +44,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, // Function inputs and outputs are the same as original, just // with different shapes. *vectorized_func->mutable_signature() = orig_func.signature(); - graph_utils::SetUniqueGraphFunctionName("vectorized_function", library, + graph_utils::SetUniqueGraphFunctionName("naively_vectorized_fn", library, vectorized_func); // Add MapDefun node @@ -61,7 +57,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, for (const string& k : {"f", "output_types", "output_shapes"}) { // Function, output types and (unbatched) shapes are the same as the // original map node. - CopyAttribute(k, map_node, map_defun_node); + graph_utils::CopyAttribute(k, map_node, map_defun_node); } // Get types of input arguments from original map function @@ -71,6 +67,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node, map_defun_node->add_input(input.name()); } (*map_defun_node->mutable_attr())["Targuments"] = t_args; + AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node); // Set return values to match output names string output_prefix = strings::StrCat(map_defun_node->name(), ":output:"); @@ -90,21 +87,19 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node, // efficient vectorization with VectorizeMapDefun. FunctionDef* vectorized_func = CreateMapDefunWrapper(map_node, orig_func, library); - NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0); - DCHECK_EQ(map_defun_node->op(), "MapDefun"); - - // Create a copy of the original function so that we can mutate it, and - // attach that to the map defun node. 
- FunctionDef* map_defun_fn = library->add_function(); - *map_defun_fn = orig_func; - graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library, - map_defun_fn); - (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name( - map_defun_fn->signature().name()); - - vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn, - map_defun_node); - return vectorized_func; + const NodeDef& map_defun_node = vectorized_func->node_def(0); + DCHECK_EQ(map_defun_node.op(), "MapDefun"); + + // TODO(b/116285210): Unreferenced functions should get cleaned up later + FunctionDef* result; + Status s = vectorization_utils::VectorizeMapDefun( + *vectorized_func, map_defun_node, library, &result); + + if (!s.ok()) { + LOG(ERROR) << "VectorizeMapDefun failed: " << s; + return vectorized_func; + } + return result; } bool IsOutputShapesFullyDefined(const NodeDef& node) { @@ -195,13 +190,16 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node, } // Set attrs - CopyAttribute("Targuments", old_map_node, &map_node); + graph_utils::CopyAttribute("Targuments", old_map_node, &map_node); auto& func_attr = (*map_node.mutable_attr())["f"]; func_attr.mutable_func()->set_name(vectorized_func.signature().name()); for (auto key : {"output_shapes", "output_types"}) { - CopyAttribute(key, old_batch_node, &map_node); + graph_utils::CopyAttribute(key, old_batch_node, &map_node); } + + (*map_node.mutable_attr())["use_inter_op_parallelism"].set_b(true); + return map_node; } diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc index ed1bd6bc97..f4faf41549 100644 --- a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc +++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc @@ -30,72 +30,51 @@ namespace { using test::function::GDef; using test::function::NDef; -void MakeTensorShapeProtoHelper(const gtl::ArraySlice<int> dims, - TensorShapeProto* 
t) { - for (size_t i = 0; i < dims.size(); ++i) { - auto* d = t->add_dim(); - d->set_size(dims[i]); - } -} - -AttrValue MakeShapeListAttr( - const gtl::ArraySlice<const gtl::ArraySlice<int>>& shapes) { - AttrValue shapes_attr; - for (size_t i = 0; i < shapes.size(); ++i) { - MakeTensorShapeProtoHelper(shapes[i], - shapes_attr.mutable_list()->add_shape()); - } - - return shapes_attr; -} - -NodeDef MakeMapNodeHelper( - StringPiece name, StringPiece input_node_name, StringPiece function_name, - StringPiece map_op_name, - const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, - const gtl::ArraySlice<DataType>& output_types) { +NodeDef MakeMapNodeHelper(StringPiece name, StringPiece input_node_name, + StringPiece function_name, StringPiece map_op_name, + gtl::ArraySlice<PartialTensorShape> output_shapes, + gtl::ArraySlice<DataType> output_types) { return test::function::NDef( name, map_op_name, {string(input_node_name)}, {{"f", FunctionDefHelper::FunctionRef(string(function_name))}, {"Targuments", {}}, - {"output_shapes", MakeShapeListAttr(output_shapes)}, + {"output_shapes", output_shapes}, {"output_types", output_types}}); } -NodeDef MakeMapNode( - StringPiece name, StringPiece input_node_name, StringPiece function_name, - const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, - const gtl::ArraySlice<DataType>& output_types) { +NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name, + StringPiece function_name, + gtl::ArraySlice<PartialTensorShape> output_shapes, + gtl::ArraySlice<DataType> output_types) { return MakeMapNodeHelper(name, input_node_name, function_name, "MapDataset", output_shapes, output_types); } -NodeDef MakeBatchNode( - StringPiece name, StringPiece input_node_name, - StringPiece input_batch_size_name, - const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, - const gtl::ArraySlice<DataType>& output_types) { - return NDef(name, "BatchDataset", - {string(input_node_name), string(input_batch_size_name)}, - 
{{"output_types", output_types}, - {"output_shapes", MakeShapeListAttr(output_shapes)}}); +NodeDef MakeBatchNode(StringPiece name, StringPiece input_node_name, + StringPiece input_batch_size_name, + gtl::ArraySlice<PartialTensorShape> output_shapes, + gtl::ArraySlice<DataType> output_types) { + return NDef( + name, "BatchDataset", + {string(input_node_name), string(input_batch_size_name)}, + {{"output_types", output_types}, {"output_shapes", output_shapes}}); } -NodeDef MakeBatchV2Node( - StringPiece name, StringPiece input_node_name, - StringPiece input_batch_size_name, StringPiece input_drop_remainder_name, - const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes, - const gtl::ArraySlice<DataType>& output_types) { - return NDef(name, "BatchDatasetV2", - {string(input_node_name), string(input_batch_size_name), - string(input_drop_remainder_name)}, - {{"output_types", output_types}, - {"output_shapes", MakeShapeListAttr(output_shapes)}}); +NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name, + StringPiece input_batch_size_name, + StringPiece input_drop_remainder_name, + gtl::ArraySlice<PartialTensorShape> output_shapes, + gtl::ArraySlice<DataType> output_types) { + return NDef( + name, "BatchDatasetV2", + {string(input_node_name), string(input_batch_size_name), + string(input_drop_remainder_name)}, + {{"output_types", output_types}, {"output_shapes", output_shapes}}); } -NodeDef MakeRangeNode(StringPiece name, const gtl::ArraySlice<string>& inputs) { +NodeDef MakeRangeNode(StringPiece name, gtl::ArraySlice<string> inputs) { return NDef(name, "RangeDataset", inputs, - {{"output_shapes", MakeShapeListAttr({{}})}, + {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})}, {"output_types", gtl::ArraySlice<DataType>({DT_INT64})}}); } @@ -184,7 +163,7 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) { item.graph = GDef( {NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), NDef("input", "InputDataset", {}, - 
{{"output_shapes", MakeShapeListAttr({{}})}}), + {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})}}), MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}), MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})}, // FunctionLib @@ -196,6 +175,37 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) { TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); } +TEST(MapVectorizationTest, VectorizeWithFullyDefinedFunction) { + GrapplerItem item; + item.graph = GDef( + {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}), + NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}), + NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}), + MakeRangeNode("range", {"start", "stop", "step"}), + MakeMapNode("map", "range", "Func", {{}}, {DT_INT32}), + MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})}, + // FunctionLib + {FunctionDefHelper::Create( + "Func", {"x: int64", "y: int64"}, {"res: int64", "res2: int64"}, {}, + {{{"o"}, "Mul", {"x", "x"}, {{"T", DT_INT64}}}}, + {{"res", "o:z"}, {"res2", "o:z"}})}); + MapVectorization optimizer; + GraphDef output; + TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output)); + + EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(), + 1); + EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(), + 1); + const NodeDef& map_node = + output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output)); + const NodeDef& batch_node = + output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output)); + EXPECT_EQ(map_node.input(0), batch_node.name()); + EXPECT_EQ(batch_node.input(0), "range"); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc index cb0ff670e8..99c4afa634 100644 
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc +++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc @@ -64,7 +64,7 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster, // Set `output_types` and `output_shapes` attributes. for (auto key : {"output_shapes", "output_types"}) { - (*new_node.mutable_attr())[key] = repeat_node.attr().at(key); + graph_utils::CopyAttribute(key, repeat_node, &new_node); } return new_node; }; diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD index 1462cb234d..985d6c6c3a 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD +++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD @@ -9,13 +9,24 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all") VECTORIZER_DEPS = [ ":vectorizer_registry", - "//tensorflow/core/grappler/optimizers/data:function_utils", + "//tensorflow/core/grappler/optimizers/data:graph_utils", ] + tf_protos_all() cc_library( + name = "wrapped_tensor", + hdrs = ["wrapped_tensor.h"], + deps = [ + "//tensorflow/core:core_cpu", + "//tensorflow/core:lib", + ], +) + +cc_library( name = "vectorizer", hdrs = ["vectorizer.h"], deps = [ + ":wrapped_tensor", + "//tensorflow/core:core_cpu", "//tensorflow/core:lib", ] + tf_protos_all(), ) diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc index c1739737a0..f445157531 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc @@ -14,41 +14,38 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" namespace tensorflow { namespace grappler { -namespace vectorization_utils { +namespace { class CastVectorizer : public Vectorizer { public: - Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, - FunctionDef* outer_scope, - std::map<string, string>* conversion_map) override { - if (inputs.size() != 1) { + Status Vectorize(const Node& node, Graph* outer_scope, + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { + Status s; + if (node.num_inputs() != 1) { return errors::Internal("Cast op should only have one input."); } - // Add new Cast node - NodeDef* new_cast_node = outer_scope->add_node_def(); - *new_cast_node = node; - new_cast_node->clear_name(); - function_utils::SetUniqueFunctionNodeName( - strings::StrCat("vectorized/", node.name()), outer_scope, - new_cast_node); - new_cast_node->set_input(0, inputs[0]); + // Add new Cast node with the same op and attrs as the original node + auto new_cast_node = outer_scope->AddNode(node.def(), &s); + TF_RETURN_IF_ERROR(s); - // Add the output mapping to conversion map - (*conversion_map)[strings::StrCat(node.name(), ":y:0")] = - strings::StrCat(new_cast_node->name(), ":y:0"); + outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node, + 0); + // Add output mappings + outputs->push_back({new_cast_node, 0, true}); return Status::OK(); } }; REGISTER_VECTORIZER("Cast", CastVectorizer); -} // namespace vectorization_utils +} // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc 
b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc index 776d3179c5..f1ba741821 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc @@ -14,40 +14,38 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" namespace tensorflow { namespace grappler { -namespace vectorization_utils { +namespace { class UnpackVectorizer : public Vectorizer { public: - Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, - FunctionDef* outer_scope, - std::map<string, string>* conversion_map) override { - if (inputs.size() != 1) { + Status Vectorize(const Node& node, Graph* outer_scope, + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { + Status s; + if (node.num_inputs() != 1 || inputs.size() != 1) { return errors::Internal("Unpack op should only have one input."); } - // Add new Unpack node - NodeDef* new_unpack_node = outer_scope->add_node_def(); - *new_unpack_node = node; - new_unpack_node->clear_name(); - function_utils::SetUniqueFunctionNodeName( - strings::StrCat("vectorized/", node.name()), outer_scope, - new_unpack_node); + // Add new Unpack node with the same op and attrs as the original node + auto new_unpack_node = outer_scope->AddNode(node.def(), &s); + TF_RETURN_IF_ERROR(s); // Increment "axis" attr by 1: - (*new_unpack_node->mutable_attr())["axis"].set_i( - node.attr().at("axis").i() + 1); - new_unpack_node->set_input(0, inputs[0]); + int new_axis = node.def().attr().at("axis").i() + 1; + new_unpack_node->AddAttr("axis", new_axis); - // Add the 
output mappings to conversion map - int num = new_unpack_node->attr().at("num").i(); + outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, + new_unpack_node, 0); + + // Add the output mappings + int num = node.def().attr().at("num").i(); for (int i = 0; i < num; ++i) { - (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] = - strings::StrCat(new_unpack_node->name(), ":output:", i); + outputs->push_back({new_unpack_node, i, true}); } return Status::OK(); @@ -56,6 +54,6 @@ class UnpackVectorizer : public Vectorizer { REGISTER_VECTORIZER("Unpack", UnpackVectorizer); -} // namespace vectorization_utils +} // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h index d341dbba7d..8d4676aae0 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h @@ -17,13 +17,13 @@ limitations under the License. #define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_ #include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/lib/gtl/array_slice.h" namespace tensorflow { namespace grappler { -namespace vectorization_utils { // Interface for vectorization of TensorFlow operations. See `CastVectorizer` // for an example. 
@@ -31,19 +31,19 @@ class Vectorizer { public: virtual ~Vectorizer() {} - // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope` + // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope` // that produce the same vector output(s) as executing `node`'s op - // on elements of the vector inputs, and adding mappings to `conversion_map` - // from old output tensor names to new (vectorized) output tensor names. - // The new node(s) collectively have the same number of inputs and outputs as - // the node being converted, and use the tensor names in `inputs` as their - // inputs. - virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, - FunctionDef* outer_scope, - std::map<string, string>* conversion_map) = 0; + // on elements of `inputs`. The new Node(s) collectively have the + // same number of input and output ports as the node being converted. + // Adds edges between the newly created nodes and nodes in `inputs`, and adds + // mappings to the new nodes' output ports to `outputs`, where the i'th + // value in `outputs` corresponds to the i'th output port of the node + // to be converted. + virtual Status Vectorize(const Node& node, Graph* outer_scope, + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) = 0; }; -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow #endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_ diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc index a6551e36ac..e1cf77a7d5 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc @@ -19,7 +19,6 @@ limitations under the License. 
namespace tensorflow { namespace grappler { -namespace vectorization_utils { VectorizerRegistry* VectorizerRegistry::Global() { static VectorizerRegistry* registry = new VectorizerRegistry; @@ -42,6 +41,5 @@ void VectorizerRegistry::Register(const string& op_type, vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>( op_type, std::move(vectorizer))); } -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h index 16159d47ca..ad54c74933 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h @@ -23,7 +23,6 @@ limitations under the License. namespace tensorflow { namespace grappler { -namespace vectorization_utils { // A global VectorizerRegistry is used to hold all the vectorizers. 
class VectorizerRegistry { @@ -59,16 +58,12 @@ class VectorizerRegistration { #define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \ REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) -#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \ - static ::tensorflow::grappler::vectorization_utils:: \ - vectorizer_registration::VectorizerRegistration \ - vectorizer_registration_##ctr( \ - op_type, \ - ::std::unique_ptr< \ - ::tensorflow::grappler::vectorization_utils::Vectorizer>( \ - new vectorizer())) +#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \ + static ::tensorflow::grappler::vectorizer_registration:: \ + VectorizerRegistration vectorizer_registration_##ctr( \ + op_type, ::std::unique_ptr<::tensorflow::grappler::Vectorizer>( \ + new vectorizer())) -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc index 86e303564b..054aeb9a8f 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc @@ -20,13 +20,12 @@ limitations under the License. 
namespace tensorflow { namespace grappler { -namespace vectorization_utils { class TestVectorizer : public Vectorizer { public: - Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs, - FunctionDef* outer_scope, - std::map<string, string>* conversion_map) override { + Status Vectorize(const Node& node, Graph* outer_scope, + std::vector<WrappedTensor>&& inputs, + std::vector<WrappedTensor>* outputs) override { return Status::OK(); } }; @@ -39,12 +38,14 @@ TEST(TestVectorizer, TestTestVectorizer) { auto vectorizer = VectorizerRegistry::Global()->Get("test_op"); EXPECT_NE(vectorizer, nullptr); - FunctionDef function; - NodeDef node; - std::map<string, string> conversion_map; - EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok()); + Graph g(OpRegistry::Global()); + NodeDef node_def; + Status s; + Node* node = g.AddNode(node_def, &s); + std::vector<WrappedTensor> inputs, outputs; + EXPECT_TRUE( + vectorizer->Vectorize(*node, &g, std::move(inputs), &outputs).ok()); } -} // namespace vectorization_utils } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h new file mode 100644 index 0000000000..4439b4ab4e --- /dev/null +++ b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h @@ -0,0 +1,44 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_ +#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_ + +#include "tensorflow/core/graph/graph.h" + +namespace tensorflow { +namespace grappler { + +// Represents a tensor that has been vectorized. +struct WrappedTensor { + Node* const node; + const int output_index; + + // Whether the tensor is stacked, i.e. represents the results of applying + // the operation on all slices of the input, where each row i of the + // tensor corresponds to the op's output on slice i of the input. False + // if the tensor is not stacked, i.e. represents the result of the op on + // a single slice of the input, where the result does not vary between + // slices. + bool stacked; + + WrappedTensor(Node* node, int output_index, bool stacked) + : node(node), output_index(output_index), stacked(stacked) {} +}; + +} // namespace grappler +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_ diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc index cb56b65985..ba857ab5d9 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc @@ -17,274 +17,588 @@ limitations under the License. 
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h" #include "absl/strings/str_join.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/graph_to_functiondef.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op_def.pb.h" #include "tensorflow/core/framework/types.h" +#include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/grappler/mutable_graph_view.h" #include "tensorflow/core/grappler/optimizers/data/function_utils.h" #include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/lib/gtl/map_util.h" -#include "tensorflow/core/lib/strings/scanner.h" namespace tensorflow { namespace grappler { namespace vectorization_utils { -using function_utils::FunctionDefTensorDesc; - namespace { -void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node, - const string& output_retval, const DataType t) { - // Set to unknown shape - TensorShapeProto tensor_shape_proto; - PartialTensorShape().AsProto(&tensor_shape_proto); +// Describes a tensor with its operation Node and output position +typedef std::pair<Node*, int> TensorDesc; - function_utils::AddFunctionOutputWithUniqueName( - "vectorized_out", output_retval, map_defun_fn, t); +const char* const kRetValOp = "_Retval"; - *(*map_defun_node->mutable_attr())["output_shapes"] - .mutable_list() - ->add_shape() = tensor_shape_proto; - (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t); +void ReplaceEdgeSources(const TensorDesc& 
old_src, const TensorDesc& new_src, + Graph* graph) { + // NOTE: We need two for loops here because we can't mutate the set of output + // edges as we iterate over them. + std::vector<const Edge*> edges_to_replace; + for (auto edge : old_src.first->out_edges()) { + if (edge->src_output() == old_src.second) { + edges_to_replace.push_back(edge); + } + } + for (auto edge : edges_to_replace) { + graph->AddEdge(new_src.first, new_src.second, edge->dst(), + edge->dst_input()); + graph->RemoveEdge(edge); + } } -void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node, int output_position) { - DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size()) - << "Trying to remove output that doesn't exist. Output number: " - << output_position; +Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node, + const TensorDesc& output) { + // Note that we don't update MapDefun attrs as we go, only when we are done + DataType type = output.first->output_type(output.second); + int index = map_defun_fn->ret_nodes.size(); - int num_later_outputs = - map_defun_fn->signature().output_arg_size() - output_position - 1; + NodeDef ret_node_def; + ret_node_def.set_name("map_out"); + ret_node_def.set_op(kRetValOp); + AddNodeAttr("T", type, &ret_node_def); + AddNodeAttr("index", index, &ret_node_def); - // Remove from map_defun_fn's ret dict and output args - map_defun_fn->mutable_ret()->erase( - map_defun_fn->signature().output_arg(output_position).name()); - map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange( - output_position, 1); + Status s; + Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s); + TF_RETURN_IF_ERROR(s); - // Renumber outputs that come after - for (int i = 0; i < num_later_outputs; ++i) { - function_utils::ReplaceReferences( - strings::StrCat(map_defun_node->name(), - ":output:", output_position + i + 1), - strings::StrCat(map_defun_node->name(), - ":output:", 
output_position + i), - outer_scope); - } - map_defun_node->mutable_attr() - ->at("output_shapes") - .mutable_list() - ->mutable_shape() - ->DeleteSubrange(output_position, 1); - map_defun_node->mutable_attr() - ->at("output_types") - .mutable_list() - ->mutable_type() - ->ExtractSubrange(output_position, 1, nullptr); + map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0); + map_defun_fn->ret_nodes.push_back(ret_node); + map_defun_fn->ret_types.push_back(type); + + return s; } -int FindOutputToConvert(const FunctionDef& function, - const std::set<string>& unconvertible, - FunctionDefTensorDesc* f) { - for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) { - const string& ret_key = function.signature().output_arg(i).name(); - *f = FunctionDefTensorDesc(function.ret().at(ret_key)); +void RemoveMapDefunOutput(int output_position, Graph* outer_scope, + FunctionBody* map_defun_fn, Node* map_defun_node) { + // Note that we don't update MapDefun attrs as we go, only when we are done + DCHECK_LT(output_position, map_defun_fn->ret_nodes.size()) + << "Trying to remove output that doesn't exist. 
Output number: " + << output_position; - if (unconvertible.find(f->node_name) == unconvertible.end()) { - return i; - } + int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1; + + // Modify map_defun_fn's signature and remove the output node from its graph + map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]); + map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() + + output_position); + map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() + + output_position); + + // Renumber the nodes and edges that come after + for (int i = 0; i < num_later_outputs; ++i) { + ReplaceEdgeSources({map_defun_node, output_position + i + 1}, + {map_defun_node, output_position + i}, outer_scope); + // Each ret node has an "index" attr that has to be updated + map_defun_fn->ret_nodes[output_position + i]->AddAttr("index", + output_position + i); } - return -1; } // Helper class that vectorizes the body of a MapDefun node, adding new // operations to the graph that collectively compute the same value as what // running the MapDefun function on slices of the input would produce. -// Each instance of the class encapsulates all the data necessary to vectorize a -// MapDefun op in place. +// This class transforms the input FunctionDefs into their corresponding +// Graph objects and works on the graphs directly, then converts them back +// to FunctionDefs when GetResult is called. class Vectorization { public: - Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node) - : outer_scope_(outer_scope), - map_defun_fn_(map_defun_fn), - map_defun_node_(map_defun_node) {} + explicit Vectorization(FunctionDefLibrary* lib) + : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {} - // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in - // the outer_scope_, until there are no convertible outputs remaining. - // This method is idempotent. 
- void Vectorize(); + // Adds the vectorized function and new map_defun_fn to lib, and points + // vectorized_function to the former. Returns an error status if + // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere + // along the way. + Status Vectorize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, FunctionDef** result); private: - // Vectorizes the map defun function's output at output_position - Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc); - // Given a descriptor of the original output tensor, gets a string - // corresponding to the converted output tensor. - Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc, - string* converted); - Status AddConversionMappingFromInput( - const FunctionDefTensorDesc& output_desc); + // Converts FunctionDefs to Graphs and adds mappings from + // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_. + Status Initialize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node); + + // Converts Graphs back to FunctionDefs and adds them to `lib_`. + Status GetResult(FunctionDef** vectorized_function); + + // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in + // `outer_scope_`, until there are no convertible outputs remaining. + void VectorizeHelper(); + + // Vectorizes map_defun_fn's output at output_position. + Status ConvertOutput(int output_position); // Adds mappings from node's outputs tensors to converted output tensors, // creating the necessary new node(s). Generally, the steps to convert an op // are: - // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_, - // and modify map_defun_node_ attrs accordingly - // 2) Create new node(s) in outer_scope_ that act on batched input tensors. + // 1) Create new node(s) in `outer_scope_` that act on batched input tensors. 
// These operations collectively compute the same value as what running // the original operation on slices of the input tensors would produce. // For example, a Cast op in MapDefun translates to a Cast op in - // outer_scope_, since the vectorized version of Cast is itself. - // 3) Set inputs of new node(s) to the corresponding converted inputs (that - // are now outputs of map_defun_node_) - // 4) For each output of the old node, add the mapping of output strings to - // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0") - Status AddConversionMappingFromOp(const NodeDef& node, - const FunctionDefTensorDesc& output_desc); - - // Maps a tensor name to the name of the corresponding vectorized tensor. For - // example, "Cast:y:0" -> "Vectorize/Cast:y:0" - std::map<string, string> conversion_map_; - // Unconvertible node names - std::set<string> unconvertible_; - - FunctionDef* outer_scope_; - FunctionDef* map_defun_fn_; - NodeDef* map_defun_node_; + // `outer_scope_`, since the vectorized version of Cast is itself. + // 2) Promote the inputs of the op inputs to outputs of the + // `map_defun_node_` and `map_defun_fn_`. + // 3) Add edges between the promoted inputs (that are now outputs of + // `map_defun_node`) and the inputs ports of the new node(s). + // 4) For each output of the old node, add the mapping of output tensors to + // the conversion map. + Status AddConversionMapping(Node* op_node); + + // Given a tensor t in `unstacked`, stacks it by doing the equivalent of + // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of + // inputs to `map_defun_node_`. This stacked tensor will be compatible with + // the expected output shape of `map_defun_node_`. + // This is equivalent to the _stack function in python Pfor. + Status StackTensor(WrappedTensor* unstacked, TensorDesc* result); + + // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by + // doing a depth-first search from the ret nodes. 
Lifts nodes that are + // unstacked (i.e. don't derive from arg nodes) into `outer_scope_` directly + // and add mappings to `conversion_map_`. + Status AddUnstackedNodeMappings(); + + // Recursive helper for `AddUnstackedNodeMappings`, returns true if tensor + // is unstacked. + bool AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, Status* status); + + // Add mappings from `map_defun_fn_` arg nodes to `map_defun_node_` input + // nodes to `conversion_map_`. + Status AddArgNodeMappings(); + + // Maps a tensor to the corresponding WrappedTensor. For example, + // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true) + std::map<TensorDesc, WrappedTensor> conversion_map_; + + // Unconvertible ret nodes + std::set<Node*> unconvertible_; + + FunctionDefLibrary* lib_; // Not owned + FunctionLibraryDefinition lib_def_; + // Note that FunctionBody has a pointer to a Graph object that corresponds + // to the function's subgraph, with additional kArgOp and kRetValOp nodes + // that denote that function arguments and return values. These nodes have the + // attrs "T" for the type, and "index" for the argument / retval index + // respectively. FunctionBody also keeps track of arg/ret_nodes and + // arg/ret_types, that should be ordered according to argument/output indices. + std::unique_ptr<Graph> outer_scope_; + std::unique_ptr<FunctionBody> map_defun_fn_; + Node* map_defun_node_ = nullptr; // Owned by `outer_scope` + + // Caches the loop_len_node_ needed for tiling unstacked output. This + // corresponds to a vector with one element. 
+ Node* loop_len_node_ = nullptr; // Owned by `outer_scope` + Status status_; }; -Status Vectorization::AddConversionMappingFromOp( - const NodeDef& node, const FunctionDefTensorDesc& output_desc) { - for (const string& input_name : node.input()) { - if (IsControlInput(input_name)) { +Status Vectorization::AddConversionMapping(Node* op_node) { + for (auto edge : op_node->in_edges()) { + if (edge->IsControlEdge()) { return errors::InvalidArgument( "Vectorizing outputs with control inputs is currently not " "supported."); } } - // TODO(rachelim): Have some mechanism for registering converters and some - // uniform, simpler way to represent them. + auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string()); + if (vectorizer == nullptr) { + return errors::Unimplemented("No vectorizer registered for op: ", + op_node->type_string()); + } + std::vector<WrappedTensor> inputs, outputs; + inputs.reserve(op_node->num_inputs()); + outputs.reserve(op_node->num_outputs()); + + std::vector<const Edge*> input_edges; + TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges)); + + // The inputs for the node to be converted may already have been converted + // themselves. For those that are not, we promote them to MapDefun outputs. + for (size_t i = 0; i < op_node->num_inputs(); ++i) { + auto edge = input_edges[i]; + if (auto found = gtl::FindOrNull(conversion_map_, + {edge->src(), edge->src_output()})) { + inputs.push_back(*found); + } else { + // TODO(rachelim): Handle the case where unconverted inputs are unstacked. + // We assume that all unconverted inputs will be stacked, since we + // converted all unstacked nodes in `Initialize`. However, it's actually + // possible that yet-unconverted nodes may produce unstacked outputs after + // they are vectorized. (For example, see the "Shape" converter in + // tensorflow/python/ops/parallel_for/pfor.py). 
If a vectorizer expects + // an unstacked input but receives a stacked one, vectorizer->Vectorize + // will return an error. + TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_, + {edge->src(), edge->src_output()})); + int output_index = map_defun_fn_->ret_nodes.size() - 1; + inputs.push_back({map_defun_node_, output_index, true}); + } + } - DataTypeVector types; - const OpDef* op_def = nullptr; - TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def)); - TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types)); + TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(), + std::move(inputs), &outputs)); - std::vector<string> promoted_inputs; - promoted_inputs.reserve(node.input_size()); - for (int i = 0; i < node.input_size(); ++i) { - promoted_inputs.push_back(strings::StrCat( - map_defun_node_->name(), - ":output:", map_defun_fn_->signature().output_arg_size() + i)); + if (op_node->num_outputs() != outputs.size()) { + return errors::Internal( + "Number of vectorizer outputs does not match. Expected: ", + op_node->num_outputs(), " Actual: ", outputs.size()); } - auto vectorizer = VectorizerRegistry::Global()->Get(node.op()); - if (vectorizer == nullptr) { - return errors::Unimplemented("No vectorizer registered for op: ", - node.op()); + // Add output mappings. 
+ for (size_t i = 0; i < op_node->num_outputs(); ++i) { + conversion_map_.insert({{op_node, i}, outputs[i]}); } - TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_, - &conversion_map_)); + return Status::OK(); +} + +Status Vectorization::ConvertOutput(int output_position) { + // ret_edge->src() is the actual op that generated the retval, and + // ret_edge->dst() is the retval node whose op is "_Retval" + const Edge* ret_edge; + TF_RETURN_IF_ERROR( + map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge)); + + TensorDesc output({ret_edge->src(), ret_edge->src_output()}); + TensorDesc converted_output; + + // It's possible the output already has a mapping, if it comes from a node + // that has already been converted. + auto found = gtl::FindOrNull(conversion_map_, output); + if (!found) { + TF_RETURN_IF_ERROR(AddConversionMapping(output.first)); + found = &conversion_map_.at(output); + } - // If we get here, the conversion was successful, so we promote the inputs - // of the ops to MapDefun outputs. - for (int i = 0; i < types.size(); ++i) { - AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]); + if (found->stacked) { + converted_output = {found->node, found->output_index}; + } else { + // Some outputs may be unstacked if they don't derive from arg nodes + // (for example, if a function returns a constant). For these, we + // have to add extra nodes to tile it in the 0th dimension. 
+ TF_RETURN_IF_ERROR(StackTensor(found, &converted_output)); } + ReplaceEdgeSources({map_defun_node_, output_position}, converted_output, + outer_scope_.get()); + RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(), + map_defun_node_); + return Status::OK(); } -Status Vectorization::AddConversionMappingFromInput( - const FunctionDefTensorDesc& output_desc) { - int input_index = function_utils::FindFunctionInputWithName( - output_desc.node_name, *map_defun_fn_); - if (input_index == -1) { - return errors::Internal("Cannot convert non-existent input."); +Status Vectorization::Vectorize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, + FunctionDef** result) { + TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node)); + VectorizeHelper(); + return GetResult(result); +} + +void Vectorization::VectorizeHelper() { + while (true) { + int output_position = graph_utils::GetFirstElementIndexWithPredicate( + [this](Node* n) { + return this->unconvertible_.find(n) == this->unconvertible_.end(); + }, + map_defun_fn_->ret_nodes); + + // No outputs left to convert + if (output_position == -1) break; + + Status s = ConvertOutput(output_position); + if (!s.ok()) { + Node* output_node = map_defun_fn_->ret_nodes.at(output_position); + VLOG(2) << "Could not convert the output at node: " + << output_node->DebugString() << "\nError: " << s; + unconvertible_.insert(output_node); + } + } + + // If we've converted all the outputs of the MapDefun function, we no longer + // need the MapDefun node and can delete it. 
+ if (map_defun_fn_->ret_nodes.empty()) { + outer_scope_->RemoveNode(map_defun_node_); + } else { + // Update MapDefun node attrs accordingly + DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size()); + map_defun_node_->AddAttr( + "output_shapes", + std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size())); + map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types); + } +} + +Status Vectorization::Initialize(const FunctionDef& outer_scope, + const NodeDef& map_defun_node) { + // Convert outer_scope and map_defun_fn to FunctionBodys so we can + // work on Graphs directly. + const FunctionDef* map_defun_fn = + lib_def_.Find(map_defun_node.attr().at("f").func().name()); + + if (map_defun_fn == nullptr) { + return errors::NotFound("Could not find function with name ", + map_defun_node.attr().at("f").func().name(), + " in function library."); + } + + auto get_func_sig = [this](const string& op, const OpDef** sig) { + return this->lib_def_.LookUpOpDef(op, sig); + }; + + FunctionBody* outer_fn; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_, + get_func_sig, &outer_fn)); + // We don't need outer_fn, just the graph + outer_scope_.reset(outer_fn->graph); + outer_fn->graph = nullptr; + delete outer_fn; + + FunctionBody* tmp; + TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_, + get_func_sig, &tmp)); + map_defun_fn_.reset(tmp); + + // Find the MapDefun node in outer_scope_ + int node_id = graph_utils::GetFirstElementIndexWithPredicate( + [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); }, + outer_scope_->nodes()); + if (node_id == -1) { + return errors::NotFound("Could not find node with name ", + map_defun_node.name(), " in outer_scope."); } + map_defun_node_ = outer_scope_->FindNodeId(node_id); + + TF_RETURN_IF_ERROR(AddArgNodeMappings()); + + TF_RETURN_IF_ERROR(AddUnstackedNodeMappings()); + loop_len_node_ = nullptr; - conversion_map_[output_desc.full_str] = 
map_defun_node_->input(input_index); return Status::OK(); } -Status Vectorization::ConvertOutputHelper( - const FunctionDefTensorDesc& output_desc, string* converted) { - // It's possible the output already has a mapping, if it comes from a node - // that has already been converted. - if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) { - *converted = *found; - return Status::OK(); +// TODO(rachelim): It might be profitable to use the C++ API for this instead of +// NodeBuilder +Status Vectorization::StackTensor(WrappedTensor* unstacked, + TensorDesc* result) { + // Note that all these nodes are necessary as the size of the batch may not be + // constant. + if (unstacked->stacked) { + return errors::Internal("Can only stack unstacked tensor."); } - int index = function_utils::FindFunctionNodeWithName(output_desc.node_name, - *map_defun_fn_); - if (index == -1) { // The output comes from an input - TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc)); - } else { - TF_RETURN_IF_ERROR(AddConversionMappingFromOp( - map_defun_fn_->node_def(index), output_desc)); + Graph* g = outer_scope_.get(); + auto node_builder = [](StringPiece op) { + return NodeBuilder(strings::StrCat("vectorized/stack/", op), op); + }; + + auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph, + Node** result) { + TF_RETURN_IF_ERROR(val.status); + return node_builder("Const") + .Attr("value", val.tensor) + .Attr("dtype", val.tensor.dtype()) + .Finalize(graph, result); + }; + + // If loop_len_node_ hasn't been created yet, add the node and cache it. 
+ if (loop_len_node_ == nullptr) { + Node* input_node; + TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node)); + + Node* shape_node; + TF_RETURN_IF_ERROR( + node_builder("Shape").Input(input_node).Finalize(g, &shape_node)); + + Node* const_vec_0; + TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0)); + Node* const_vec_1; + TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1)); + + Node* strided_slice_node; + TF_RETURN_IF_ERROR(node_builder("StridedSlice") + .Input(shape_node) // input + .Input(const_vec_0) // begin + .Input(const_vec_1) // end + .Input(const_vec_1) // strides + .Finalize(g, &strided_slice_node)); + + // Produces a vector of length 1 + TF_RETURN_IF_ERROR(node_builder("Reshape") + .Input(strided_slice_node) // tensor + .Input(const_vec_1) // shape + .Finalize(g, &loop_len_node_)); } - *converted = conversion_map_.at(output_desc.full_str); + + Node* ones_shape; + TF_RETURN_IF_ERROR(node_builder("Shape") + .Input(unstacked->node) // input + .Finalize(g, &ones_shape)); + + Node* ones; + TF_RETURN_IF_ERROR( + node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones)); + + Node* const_0; + TF_RETURN_IF_ERROR(make_const(0, g, &const_0)); + + Node* multiples; + TF_RETURN_IF_ERROR(node_builder("Concat") + .Input(const_0) // concat_dim + .Input({{loop_len_node_, 0}, {ones, 0}}) // values + .Finalize(g, &multiples)); + + Node* expand_dims; + TF_RETURN_IF_ERROR(node_builder("ExpandDims") + .Input(unstacked->node) // input + .Input(const_0) // dim + .Finalize(g, &expand_dims)); + + TF_RETURN_IF_ERROR(node_builder("Tile") + .Input(expand_dims) // input + .Input(multiples) // multiples + .Finalize(g, &result->first)); + result->second = 0; return Status::OK(); } -Status Vectorization::ConvertOutput(int output_position, - const FunctionDefTensorDesc& output_desc) { - string converted_output_name; - TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name)); +Status Vectorization::AddArgNodeMappings() { + for (auto arg_node : 
map_defun_fn_->arg_nodes) { + Node* input_node; + TF_RETURN_IF_ERROR(map_defun_node_->input_node( + arg_node->attrs().Find("index")->i(), &input_node)); - // Remove the old output and make everything that referenced it point - // to the new string - function_utils::ReplaceReferences( - strings::StrCat(map_defun_node_->name(), ":output:", output_position), - converted_output_name, outer_scope_); - RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_, - output_position); + conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}}); + // Control inputs + conversion_map_.insert({{arg_node, Graph::kControlSlot}, + {input_node, Graph::kControlSlot, true}}); + } return Status::OK(); } -void Vectorization::Vectorize() { - while (true) { - FunctionDefTensorDesc desc; - int output_position = - FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc); - if (output_position == -1) break; +bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, + Status* status) { + if (auto found = gtl::FindOrNull(conversion_map_, tensor)) { + return !found->stacked; + } - if (!ConvertOutput(output_position, desc).ok()) { - unconvertible_.insert(desc.node_name); + if (tensor.first->op_def().is_stateful()) { + // We don't lift stateful nodes directly out of the MapDefun, since they may + // have to be executed N times. + return false; + } + + bool is_unstacked = true; + for (auto edge : tensor.first->in_edges()) { + // Ignore Source nodes. Note that these are also ignored in the + // GraphToFunctionDef conversion. + if (edge->src()->IsSource()) continue; + + // A node is unstacked if all of its inputs are unstacked + is_unstacked &= AddUnstackedNodeMappingsHelper( + {edge->src(), edge->src_output()}, status); + } + + if (!is_unstacked) { + return false; + } + + // If the node is unstacked, we copy it into outer_scope_ and + // add it to the map. 
Note that we don't clean up the nodes that are copied + // in map_defun_fn_, and rely on them being pruned out later. + Node* node = outer_scope_->AddNode(tensor.first->def(), status); + if (!status->ok()) return true; + + // Add input edges to nodes that should already have been lifted. + for (auto edge : tensor.first->in_edges()) { + // Ignore Source nodes. Note that these are also ignored in the + // GraphToFunctionDef conversion. + if (edge->src()->IsSource()) continue; + + if (auto found = gtl::FindOrNull(conversion_map_, + {edge->src(), edge->src_output()})) { + outer_scope_->AddEdge(found->node, found->output_index, node, + edge->dst_input()); + } else { + status->Update(errors::Internal( + "Could not find input conversion even though we did depth first " + "conversion.")); } } - // If we've converted all the outputs of the MapDefun function, we no longer - // need the MapDefun node and can delete it. - if (map_defun_fn_->signature().output_arg_size() == 0) { - outer_scope_->mutable_node_def()->DeleteSubrange( - function_utils::FindFunctionNodeWithName(map_defun_node_->name(), - *outer_scope_), - 1); + // Add output mappings + for (int i = 0; i < tensor.first->num_outputs(); ++i) { + conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)}); + } + conversion_map_.insert({{tensor.first, Graph::kControlSlot}, + WrappedTensor(node, Graph::kControlSlot, false)}); + + return true; +} + +Status Vectorization::AddUnstackedNodeMappings() { + SetVector<Node*> unstacked_nodes; + Status s; + for (const auto& ret_node : map_defun_fn_->ret_nodes) { + const Edge* in_edge = nullptr; + TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge)); + AddUnstackedNodeMappingsHelper({in_edge->src(), in_edge->src_output()}, &s); + TF_RETURN_IF_ERROR(s); } + return Status::OK(); +} - if (!unconvertible_.empty()) { - VLOG(2) << "The following nodes could not be converted: [" - << absl::StrJoin(unconvertible_, ", ") << "]."; +Status Vectorization::GetResult(FunctionDef** 
vectorized_function) { + TF_RETURN_IF_ERROR(status_); + TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(outer_scope_.get())); + TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(map_defun_fn_->graph)); + + if (!map_defun_fn_->ret_nodes.empty()) { + FunctionDef* map_defun_fn = lib_->add_function(); + graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn); + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn)); + + AttrValue func_attr; + func_attr.mutable_func()->set_name(map_defun_fn->signature().name()); + map_defun_node_->AddAttr("f", func_attr); } + + *vectorized_function = lib_->add_function(); + graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_, + *vectorized_function); + TF_RETURN_IF_ERROR(GraphToFunctionDef( + *outer_scope_, (*vectorized_function)->signature().name(), + *vectorized_function)); + return Status::OK(); } + } // namespace -void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node) { - Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize(); +Status VectorizeMapDefun(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, FunctionDefLibrary* lib, + FunctionDef** result) { + *result = nullptr; + return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result); } } // end namespace vectorization_utils diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h index bb405faa77..bd7d390900 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h @@ -24,22 +24,28 @@ namespace tensorflow { namespace grappler { namespace vectorization_utils { -// Given a function, `map_defun_fn`, that is mapped across some input vector -// elements via a MapDefun operation, `VectorizeMapDefun` attempts to -// vectorize the 
MapDefun by "lifting" operations from the `map_defun_fn` to the -// `outer_scope`; that is, replacing `map_defun_fn` operations with new -// `outer_scope` operations that produce the same vector output(s) as executing -// the `map_defun_fn` operations on elements of vector input(s) would. If all -// `map_defun_fn` operations are successfully lifted, `map_defun_node` is -// eliminated from `outer_scope` altogether. However, if some operations cannot -// be lifted, and this vectorization only succeeds partially, `map_defun_node` -// remains to be used for operations that were not lifted. +// Given a MapDefun node (`map_defun_node`) in a FunctionDef (`outer_scope`) +// that maps a function in lib across some input vector elements, +// `VectorizeMapDefun` attempts to create a vectorized version of `outer_scope` +// by "lifting" operations from the MapDefun function to the new function +// (`result`); that is, replacing operations in the MapDefun function with +// operations that produce the same vector output(s) as executing the original +// operations on elements of vector input(s) would. If all operations in the +// MapDefun function are successfully lifted, `result` has no MapDefun node +// altogether. However, if some operations cannot be lifted, and this +// vectorization only succeeds partially, a MapDefun node remains in `result` to +// be used for operations that were not lifted, and the modified MapDefun +// function is added to `lib`. The newly vectorized function `result` is also +// added to `lib`. +// +// Returns Status::OK() if the vectorization is completely or partially +// successful. Otherwise, returns an error, and sets `result` to nullptr. // // Example: // If the input to the `VectorizeMapDefun` function is a MapDefun // whose `map_defun_fn` performs the Cast operation, the vectorization will // eliminate the MapDefun. This is because the Cast operation supports -// any tensor shape and can thus be lifted to the `outer_scope`. 
+// any tensor shape and can thus be lifted to `result`. // // Before: // @@ -68,7 +74,7 @@ namespace vectorization_utils { // // After: // -// outer_scope +------+ +// result +------+ // +---------------+ Arg0 +---------+ // | +---+--+ | // | | | @@ -80,8 +86,9 @@ namespace vectorization_utils { // +---------------+ Ret0 +---------+ // +------+ // -void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn, - NodeDef* map_defun_node); +Status VectorizeMapDefun(const FunctionDef& outer_scope, + const NodeDef& map_defun_node, FunctionDefLibrary* lib, + FunctionDef** result); } // end namespace vectorization_utils } // end namespace grappler diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc index e129fa9237..a6020e36bb 100644 --- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc +++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include "tensorflow/core/framework/function.h" #include "tensorflow/core/grappler/optimizers/data/function_utils.h" +#include "tensorflow/core/grappler/optimizers/data/graph_utils.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/tools/graph_transforms/transform_utils.h" @@ -54,12 +55,18 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs, func.set_name(function_name); NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn); graph_transforms::SetNodeAttr("Targuments", t_arguments, node); + graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node); graph_transforms::SetNodeAttr("output_types", output_types, node); graph_transforms::SetNodeAttr("output_shapes", output_shapes, node); graph_transforms::SetNodeAttr("f", func, node); return node; } +string GetRetval(const FunctionDef& function_def, int index) { + return function_def.ret().at( + function_def.signature().output_arg(index).name()); +} + // TODO(rachelim): Use FunctionDefHelper::Create instead FunctionDef CreateFunction( StringPiece name, const std::vector<std::pair<string, DataType>>& inputs, @@ -85,7 +92,6 @@ FunctionDef CreateFunction( return func; } -TEST(FunctionDefInputDescTest, ConstructedCorrectly) {} // Before: // @@ -133,10 +139,17 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { {{}, {}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - EXPECT_EQ(outer.ret().at("mapdefun"), "ret0"); - EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1"); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + Status s = VectorizeMapDefun(outer, *map_defun, &lib, &vectorized); + LOG(ERROR) << s; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + 
!function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + EXPECT_EQ(GetRetval(*vectorized, 0), "ret0"); + EXPECT_EQ(GetRetval(*vectorized, 1), "ret1"); } // Before: @@ -149,12 +162,12 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { // | +-----------+ Arg0 +---+ Arg1 +----+ | // | | +---+--+ +---+--+ | | // | | | | | | -// | | +------+ | +---v--+ | | -// | | |Const | | | Op0 | | | -// | | +---v--+ | +---+--+ | | +// | | +------+ | | | | +// | | |Const | | | | | +// | | +---v--+ | | | | // | | | | | | | // | | | +---v--+ +---v--+ | | -// | | +---| XOp1 | | XOp2 | | | +// | | +---| XOp1 | | Cast | | | // | | +---+--+ +---+--+ | | // | | | | | | // | | MapDefun +---v--+ +---v--+ | | @@ -165,23 +178,50 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) { // +---------------+ Ret0 +---+ Ret1 +--------+ // +------+ +------+ // -// where XOp1 and XOp2 are not convertible. +// where XOp1 is not convertible. // // After: // -// No change because the ops are not convertible. +// +// +------+ +------+ +// +---------------+ Arg0 +---+ Arg1 +--------+ +// | +---+--+ +---+--+ | +// | | | | +// | +---v--+ | | +// | +-----------+ Arg0 +-+ | | +// | | +---+--+ | | | +// | | | | | | +// | | +------+ | | | | +// | | |Const | | | | | +// | | +---v--+ | | | | +// | | | | | | | +// | | | +---v--+ | +---v--+ | +// | | +---| XOp1 | | | Cast | | +// | | +---+--+ | +---+--+ | +// | | | | | | +// | | MapDefun +---v--+ | | | +// | +-----------+ Ret0 +-+ | | +// | +---+--+ | | +// | | | | +// | +---v--+ +---v--+ | +// +---------------+ Ret0 +---+ Ret1 +--------+ +// +------+ +------+ // TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) { FunctionDef inner = CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}}, {{"ret0", DT_INT32}, {"ret1", DT_INT32}}, - {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}}); + {{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}}); + // TODO(rachelim): If we ever write a converter for MatMul, we have to + // 
change this test. NodeDef* x_op1 = - function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner); + function_utils::AddNode("MatMul", "MatMul", {"arg0", "arg0"}, {}, &inner); CHECK_NOTNULL(x_op1); + graph_transforms::SetNodeAttr("T", DT_INT32, x_op1); - NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner); - CHECK_NOTNULL(x_op2); + NodeDef* cast_node = + AddCastNode("Cast", {"arg1"}, DT_INT32, DT_INT32, false, &inner); + CHECK_NOTNULL(cast_node); FunctionDef outer = CreateFunction( "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}}, @@ -193,12 +233,22 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) { {{}, {}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - FunctionDef outer_copy(outer); - FunctionDef inner_copy(inner); - VectorizeMapDefun(&outer, &inner, map_defun); - // They should be unchanged - EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer)); - EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + + auto map_defun_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized)); + // The Cast node should be converted just fine. + EXPECT_EQ(GetRetval(*vectorized, 1), "Cast:y:0"); + + // The inner function should only have one retval. 
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib); + const FunctionDef* map_defun_fn = + lib_def.Find(map_defun_node.attr().at("f").func().name()); + EXPECT_EQ(map_defun_fn->signature().output_arg_size(), 1); } // Before: @@ -257,14 +307,19 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) { inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& cast_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); EXPECT_EQ(cast_node.input(0), "x"); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(cast_node.name(), ":y:0")); - EXPECT_EQ(outer.node_def_size(), 1); + EXPECT_EQ(vectorized->node_def_size(), 1); } // Before: @@ -330,16 +385,21 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) { {{}, {}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& cast_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& cast_node = vectorized->node_def( + 
function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); EXPECT_EQ(cast_node.input(0), "x"); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(cast_node.name(), ":y:0")); - EXPECT_EQ(outer.ret().at("mapdefun_0"), + EXPECT_EQ(GetRetval(*vectorized, 1), strings::StrCat(cast_node.name(), ":y:0")); - EXPECT_EQ(outer.node_def_size(), 1); + EXPECT_EQ(vectorized->node_def_size(), 1); } // Before: @@ -411,21 +471,26 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) { {{1}, {1}, {1}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& unpack_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& unpack_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Unpack", *vectorized)); EXPECT_EQ(unpack_node.input(0), "x"); EXPECT_EQ(unpack_node.attr().at("axis").i(), 1); EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32); EXPECT_EQ(unpack_node.attr().at("num").i(), 3); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(unpack_node.name(), ":output:0")); - EXPECT_EQ(outer.ret().at("mapdefun_0"), + EXPECT_EQ(GetRetval(*vectorized, 1), strings::StrCat(unpack_node.name(), ":output:1")); - EXPECT_EQ(outer.ret().at("mapdefun_1"), + EXPECT_EQ(GetRetval(*vectorized, 2), strings::StrCat(unpack_node.name(), ":output:2")); - EXPECT_EQ(outer.node_def_size(), 1); + EXPECT_EQ(vectorized->node_def_size(), 1); } // Before: @@ -486,7 +551,7 @@ TEST(VectorizeMapDefunTest, 
VectorizeDefunChainedConvertibleOps) { {"ret1", "MyUnstack:output:1"}, {"ret2", "MyUnstack:output:2"}}); NodeDef* cast_op = - AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner); + AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT32, false, &inner); CHECK_NOTNULL(cast_op); NodeDef* unstack_op = AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner); @@ -505,25 +570,30 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) { {{1}, {1}, {1}}, inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - VectorizeMapDefun(&outer, &inner, map_defun); - EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer)); - const NodeDef& cast_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer)); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + const NodeDef& cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); EXPECT_EQ(cast_node.input(0), "x"); - const NodeDef& unpack_node = - outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer)); + const NodeDef& unpack_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Unpack", *vectorized)); EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0")); EXPECT_EQ(unpack_node.attr().at("axis").i(), 1); EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32); EXPECT_EQ(unpack_node.attr().at("num").i(), 3); - EXPECT_EQ(outer.ret().at("mapdefun"), + EXPECT_EQ(GetRetval(*vectorized, 0), strings::StrCat(unpack_node.name(), ":output:0")); - EXPECT_EQ(outer.ret().at("mapdefun_0"), + EXPECT_EQ(GetRetval(*vectorized, 1), strings::StrCat(unpack_node.name(), ":output:1")); - EXPECT_EQ(outer.ret().at("mapdefun_1"), + EXPECT_EQ(GetRetval(*vectorized, 2), 
strings::StrCat(unpack_node.name(), ":output:2")); - EXPECT_EQ(outer.node_def_size(), 2); + EXPECT_EQ(vectorized->node_def_size(), 2); } // Before: @@ -561,9 +631,11 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { FunctionDef inner = CreateFunction("inner_function", {{"arg0", DT_INT32}}, {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}}); - // The attrs aren't relevant - NodeDef* print_op = - function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner); + NodeDef* print_op = function_utils::AddNode( + "Print", "Print", {"arg0", "arg0"}, {/*attrs*/}, &inner); + graph_transforms::SetNodeAttr("T", DT_INT32, print_op); + graph_transforms::SetNodeAttr("U", gtl::ArraySlice<DataType>({DT_INT32}), + print_op); CHECK_NOTNULL(print_op); NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64, false, &inner); @@ -578,11 +650,278 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) { inner.signature().name(), &outer); CHECK_NOTNULL(map_defun); - FunctionDef outer_copy(outer); - FunctionDef inner_copy(inner); - VectorizeMapDefun(&outer, &inner, map_defun); + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); // They should be unchanged - EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer)); + // We check this somewhat manually as the names of nodes may have changed + EXPECT_EQ(vectorized->node_def_size(), 1); + const NodeDef& map_defun_node = vectorized->node_def(0); + EXPECT_EQ(map_defun_node.op(), "MapDefun"); + FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib); + const FunctionDef* map_defun_fn = + lib_def.Find(map_defun_node.attr().at("f").func().name()); + + const NodeDef& print_node = map_defun_fn->node_def( + function_utils::FindFunctionNodeWithOp("Print", *map_defun_fn)); + const NodeDef& cast_node = map_defun_fn->node_def( + 
function_utils::FindFunctionNodeWithOp("Cast", *map_defun_fn)); + string control_input = strings::StrCat("^", print_node.name()); + EXPECT_TRUE(cast_node.input(0) == control_input || + cast_node.input(1) == control_input); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | | | +// | | +------+ | | +// | | |Const | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | +------+ | +// | |Const | | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// | |Stack*| | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. 
+// +TEST(VectorizeMapDefunTest, VectorizeConst) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2)}, + {{"ret0", "Const:output:0"}}); + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized)); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | | | +// | | +------+ | | +// | | |Const | | | +// | | +---+--+ | | +// | | | | | +// | | +---v--+ | | +// | | | Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | +------+ | +// | |Const | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | | Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | |Stack*| | +// | +---+--+ | +// | | | +// | | | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. 
+// +TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2)}, + {{"ret0", "Cast:y:0"}}); + AddCastNode("Cast", {"Const:output:0"}, DT_INT32, DT_INT64, false, &inner); + + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + auto const_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Const", *vectorized)); + auto cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); + EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')), + const_node.name()); +} + +// Before: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +---+--+ | +// | | | +// | +---v--+ | +// | +-----------+ Arg0 +-----+ | +// | | +------+ | | +// | | | | +// | | +------+ +------+ | | +// | | |Const | |Const | | | +// | | +---+--+ +---+--+ | | +// | | : +---v--+ | | +// | | ::::::> Cast | | | +// | | +---+--+ | | +// | | | | | +// | | MapDefun +---v--+ | | +// | +-----------+ Ret0 +-----+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// +// +// After: +// +// +// +------+ +// +---------------+ Arg0 +---------+ +// | +------+ | +// | | +// | | +// | +------+ | +// | +------+ |Const | | +// | |Const | +---+--+ | +// | +---+--+ | | +// | : +---v--+ | +// | 
::::::> Cast | | +// | +---+--+ | +// | | | +// | +---v--+ | +// | +Stack*+ | +// | +---+--+ | +// | | | +// | +---v--+ | +// +---------------+ Ret0 +---------+ +// +------+ +// *Not actually a Stack node, but does the equivalent. +// +TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) { + FunctionDef inner = FunctionDefHelper::Create( + "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */}, + {/* nodes */ FunctionDefHelper::Const("Const", 2), + FunctionDefHelper::Const("ConstDep", 3)}, + {{"ret0", "Cast:y:0"}}); + AddCastNode("Cast", {"Const:output:0", "^ConstDep"}, DT_INT32, DT_INT64, + false, &inner); + + FunctionDef outer = FunctionDefHelper::Create( + "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"}, + {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}}); + + NodeDef* map_defun = + AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}}, + inner.signature().name(), &outer); + + FunctionDefLibrary lib; + *lib.add_function() = outer; + *lib.add_function() = inner; + + FunctionDef* vectorized; + EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok()); + + auto find_const = [vectorized](int val) -> const NodeDef* { + for (const auto& n : vectorized->node_def()) { + if (n.attr().at("value").tensor().int_val(0) == val) { + return &n; + } + } + return nullptr; + }; + + EXPECT_TRUE( + !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized)); + auto const_node = find_const(2); + auto const_dep_node = find_const(3); + auto cast_node = vectorized->node_def( + function_utils::FindFunctionNodeWithOp("Cast", *vectorized)); + EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')), + const_node->name()); + EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name())); } // TODO(rachelim): More test cases when we get around to implementing them: diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc 
b/tensorflow/core/grappler/optimizers/meta_optimizer.cc index c59645e5f2..3f33b16ba8 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/grappler/utils/functions.h" #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/util/ptr_util.h" namespace tensorflow { @@ -106,7 +107,8 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer( MK_OPT("scoped_allocator", new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts())); - MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization())); + MK_OPT("pin_to_host", + new PinToHostOptimizer(cfg_.pin_to_host_optimization())); return std::unique_ptr<GraphOptimizer>(); } @@ -115,6 +117,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer( Status MetaOptimizer::InitializeOptimizers( std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const { + if (cfg_.disable_meta_optimizer()) { + return Status::OK(); + } if (!cfg_.disable_model_pruning()) { optimizers->push_back(MakeUnique<ModelPruner>()); } @@ -172,11 +177,12 @@ Status MetaOptimizer::InitializeOptimizers( optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>( cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts())); } - return InitializeCustomGraphOptimizers(optimizers); + return InitializeCustomGraphOptimizers(std::set<string>(), optimizers); } Status MetaOptimizer::InitializeOptimizersByName( std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const { + std::set<string> initialized_custom_optimizers; for (const string& optimizer_name : cfg_.optimizers()) { auto optimizer = MakeNewOptimizer(optimizer_name); if (optimizer) { @@ -190,18 +196,26 @@ Status MetaOptimizer::InitializeOptimizersByName( if 
(custom_optimizer) { VLOG(2) << "Registered custom graph optimizer: " << optimizer_name; - TF_RETURN_IF_ERROR(custom_optimizer->Init()); + TF_RETURN_IF_ERROR(custom_optimizer->Init( + GetCustomGraphOptimizerConfig(optimizer_name))); optimizers->push_back(std::move(custom_optimizer)); + initialized_custom_optimizers.insert(optimizer_name); } else { VLOG(2) << "Can't register an optimizer by name: " << optimizer_name; } } - return InitializeCustomGraphOptimizers(optimizers); + return InitializeCustomGraphOptimizers(initialized_custom_optimizers, + optimizers); } Status MetaOptimizer::InitializeCustomGraphOptimizers( + const std::set<string>& pre_initialized_optimizers, std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const { for (const auto& optimizer_config : cfg_.custom_optimizers()) { + if (pre_initialized_optimizers.find(optimizer_config.name()) != + pre_initialized_optimizers.end()) { + continue; + } // Initialize the ExperimentalImplementationSelector here instead of // CustomizeOptimizer registry, due the static link issue in TensorRT for // double registry. @@ -237,6 +251,16 @@ Status MetaOptimizer::InitializeCustomGraphOptimizers( return Status::OK(); } +const RewriterConfig::CustomGraphOptimizer* +MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const { + for (const auto& config : cfg_.custom_optimizers()) { + if (config.name() == name) { + return &config; + } + } + return nullptr; +} + Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item, GraphDef* optimized_graph) { int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes @@ -391,6 +415,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, FunctionLibraryDefinition flib(OpRegistry::Global(), optimized_graph->library()); + // Find functions for which we might need to compute a gradient at runtime. 
+ gtl::FlatSet<string> differentiable_functions; + for (const NodeDef& node : optimized_graph->node()) { + if (IsSymbolicGradient(node)) { + const auto* f_attr = gtl::FindOrNull(node.attr(), "f"); + if (f_attr) differentiable_functions.insert(f_attr->func().name()); + } + } + // Optimize each function only once. std::unordered_set<string> optimized_funcs; bool optimize_function_library = true; @@ -406,6 +439,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // Skip parametrized functions (function type or body is defined only at // function call time by caller node attributes). + // They should be specialized to their instantiation type parameters by + // the function optimizer, before we can optimize function body. if (IsParametrized(func)) continue; VLOG(3) << "Optimize function: function=" << func_name; @@ -420,6 +455,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem( func, flib, item.graph.versions().producer(), &func_item)); + // If we need to compute the gradient of optimized function at runtime, we + // can't perform non-differentiable rewrites. + if (differentiable_functions.find(func_name) != + differentiable_functions.end()) { + func_item.allowed_optimizations.non_differentiable_rewrites = false; + } + // Optimize function body graph. 
GraphDef optimized_func_graph; TF_RETURN_IF_ERROR( @@ -470,6 +512,9 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item, } bool MetaOptimizerEnabled(const RewriterConfig& cfg) { + if (cfg.disable_meta_optimizer()) { + return false; + } return !cfg.disable_model_pruning() || cfg.layout_optimizer() != RewriterConfig::OFF || cfg.function_optimization() != RewriterConfig::OFF || diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h index 831c5e37c0..99a0a33ffa 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer.h +++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h @@ -54,7 +54,11 @@ class MetaOptimizer : public GraphOptimizer { std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; // Initialize active optimizers from RewriterConfig.custom_optimizers. Status InitializeCustomGraphOptimizers( + const std::set<string>& pre_initialized_optimizers, std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const; + // Returns the config for a custom graph optimizer. Null if none was found. + const RewriterConfig::CustomGraphOptimizer* GetCustomGraphOptimizerConfig( + const string& name) const; // Run optimization pass over a single GrapplerItem. Meta optimizer might run // multiple such passes: 1) for the main graph 2) for the function library diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc index e74e0f7501..3f3f43382f 100644 --- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc @@ -25,6 +25,7 @@ limitations under the License. 
#include "tensorflow/core/grappler/utils.h" #include "tensorflow/core/grappler/utils/grappler_test.h" #include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -71,6 +72,59 @@ class TestGraphOptimizer : public TestOptimizer { REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer); +class TestOptimizerWithParams : public TestOptimizer { + public: + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + CHECK(config != nullptr); + return Status::OK(); + } +}; + +REGISTER_GRAPH_OPTIMIZER(TestOptimizerWithParams); + +// Record various properties of the GrapplerItems passed for optimization. +class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer { + public: + static void SetAllowedOptimizations( + gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>* + allowed_optimizations) { + allowed_optimizations_ = allowed_optimizations; + } + static void ResetAllowedOptimizations() { allowed_optimizations_ = nullptr; } + + GrapplerItemPropertiesAccumulator() {} + string name() const override { + return "grappler_item_properties_accumulator"; + } + + Status Init( + const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override { + return Status::OK(); + } + + Status Optimize(Cluster* cluster, const GrapplerItem& item, + GraphDef* optimized_graph) override { + *optimized_graph = item.graph; + if (allowed_optimizations_) { + allowed_optimizations_->insert({item.id, item.allowed_optimizations}); + } + return Status::OK(); + } + + void Feedback(Cluster* cluster, const GrapplerItem& item, + const GraphDef& optimized_graph, double result) override {} + + private: + static gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>* + allowed_optimizations_; +}; + +gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>* + GrapplerItemPropertiesAccumulator::allowed_optimizations_; + 
+REGISTER_GRAPH_OPTIMIZER(GrapplerItemPropertiesAccumulator); + class MetaOptimizerTest : public GrapplerTest {}; TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { @@ -90,6 +144,25 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) { EXPECT_TRUE(TestOptimizer::IsOptimized()); } +TEST_F(MetaOptimizerTest, RunsCustomOptimizerWithParams) { + TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); + GrapplerItem item; + CHECK(fake_input.NextItem(&item)); + + TestOptimizer::SetOptimized(false); + RewriterConfig rewriter_config; + rewriter_config.add_optimizers("TestOptimizerWithParams"); + auto* custom_config = rewriter_config.add_custom_optimizers(); + custom_config->set_name("TestOptimizerWithParams"); + (*custom_config->mutable_parameter_map())["foo"] = AttrValue(); + + MetaOptimizer optimizer(nullptr, rewriter_config); + GraphDef output; + const Status status = optimizer.Optimize(nullptr, item, &output); + TF_EXPECT_OK(status); + EXPECT_TRUE(TestOptimizer::IsOptimized()); +} + TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) { TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"}); GrapplerItem item; @@ -305,6 +378,89 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) { test::ExpectTensorEqual<int>(tensors_expected[1], tensors[1]); } +TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) { + using test::function::NDef; + using FDH = FunctionDefHelper; + + // We will record what type of optimizations meta optimizer allows for each + // GrapplerItem (main graph and graphs for each function). + gtl::FlatMap<string, GrapplerItem::AllowedOptimizations> + allowed_optimizations; + GrapplerItemPropertiesAccumulator::SetAllowedOptimizations( + &allowed_optimizations); + + // Just record properties of optimized Grappler items. 
+ RewriterConfig rewriter_config; + rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO); + rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator"); + rewriter_config.set_min_graph_nodes(-1); + + MetaOptimizer optimizer(nullptr, rewriter_config); + + // Define simple function library with two identical mul functions. + FunctionDef mul_func_1 = FunctionDefHelper::Create( + "MyMul1", {"x:float", "y:float"}, {"z:float"}, {}, + {{{"mul"}, "Mul", {"x", "y"}, {}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "mul:z:0"}}); + + FunctionDef mul_func_2 = FunctionDefHelper::Create( + "MyMul2", {"x:float", "y:float"}, {"z:float"}, {}, + {{{"mul"}, "Mul", {"x", "y"}, {}}}, + /* Mapping between function returns and function node outputs. */ + {{"z", "mul:z:0"}}); + + // Tensorflow graph: + // + // x0 = tf.Placeholder(tf.float); + // x1 = tf.Placeholder(tf.float); + // dy = tf.Placeholder(tf.float); + // + // mul_1 = MyMul1(x0, x1); + // mul_2 = MyMul2(x0, x1); + // dx = SymbolicGradient({x0, x1, dy}, f=MyMul2) + GrapplerItem item; + item.id = "main"; + item.graph = test::function::GDef( + {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), + NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), + NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice), + // Calls into function library + NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice), + NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice), + // Symbolic gradient of a MyMul2 + NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"}, + {{"f", FDH::FunctionRef("MyMul2", {})}, + {"Tin", DataTypeSlice{DT_FLOAT}}, + {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}}, + kDevice)}, + // FunctionLib + {mul_func_1, mul_func_2}); + item.fetch = {"mul_1", "mul_2", "dx"}; + + GraphDef output; + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + // Our custom optimizer must be called for the main graph and for the two + // functions. 
+ ASSERT_EQ(allowed_optimizations.size(), 3); + + auto allowed_optimizations_main = + gtl::FindOrNull(allowed_optimizations, "main"); + ASSERT_NE(allowed_optimizations_main, nullptr); + EXPECT_TRUE(allowed_optimizations_main->non_differentiable_rewrites); + + auto allowed_optimizations_my_mul_1 = + gtl::FindOrNull(allowed_optimizations, "MyMul1"); + ASSERT_NE(allowed_optimizations_my_mul_1, nullptr); + EXPECT_TRUE(allowed_optimizations_my_mul_1->non_differentiable_rewrites); + + auto allowed_optimizations_my_mul_2 = + gtl::FindOrNull(allowed_optimizations, "MyMul2"); + ASSERT_NE(allowed_optimizations_my_mul_2, nullptr); + EXPECT_FALSE(allowed_optimizations_my_mul_2->non_differentiable_rewrites); +} + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc index 2190d38937..29a3b2b74c 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc @@ -25,23 +25,67 @@ limitations under the License. #include "tensorflow/core/grappler/utils/topological_sort.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/strings/str_util.h" namespace tensorflow { namespace grappler { + namespace internal { +namespace { // TODO(williamchan): Change this constant to be something smarter, maybe // dynamically determined. constexpr int64 kTensorMaxSize = 64; -// Find KernelDef for `node`. -Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) { - // Try find KernelDef for node.device, else GPU or CPU. 
- for (const DeviceType& device : - {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) { - Status s = FindKernelDef(device, node, kdef, nullptr); +struct OpDevicePortHasher { + std::size_t operator()(const std::tuple<string, string, int>& x) const { + uint64 code = Hash64Combine(Hash64(std::get<0>(x)), Hash64(std::get<1>(x))); + + return Hash64Combine(code, hash<int>()(std::get<2>(x))); + } +}; +using OpDevicePortOnHostMap = + gtl::FlatMap<std::tuple<string, string, int>, bool, OpDevicePortHasher>; + +// All the nodes that should be blacklisted and not swapped. +bool IsBlacklisted(const NodeDef& node) { + return + // Collective ops should not be swapped. + IsCollective(node) || + // ControlFlow ops should not be swapped. + IsControlFlow(node) || + // NoOp ops should not be swapped (due to group dependencies). + IsNoOp(node); +} + +// Check if Tensor is integer and small size. +bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) { + // Check type to be int32 or int64. + if (prop.dtype() != DataType::DT_INT32 && + prop.dtype() != DataType::DT_INT64) { + return false; + } + + // Check size known and small. + const int64 size = NumCoefficients(prop.shape()); + if (size < 0 || size > kTensorMaxSize) { + return false; + } + + return true; +} + +// Find KernelDef for `node`, greedily return first found from `devices`. +Status TryFindKernelDef(const std::vector<DeviceType>& devices, + const NodeDef& node, const KernelDef** kdef) { + for (const DeviceType& device : devices) { + const KernelDef* kernel = nullptr; + Status s = FindKernelDef(device, node, &kernel, nullptr); if (s.ok()) { + if (kdef) { + *kdef = kernel; + } return Status::OK(); } } @@ -49,96 +93,239 @@ Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) { return errors::NotFound("Could not find KernelDef for op: ", node.op()); } -// Check if all node's inputs are pinned to CPU memory. 
-bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) { - // Loop through all the inputs excluding the controlling nodes. - for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) { - // Check if (the fanin) op's device is on CPU. - if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) { - continue; - } - - // Check if (the fanin) op's output port is pinned to HostMemory. - const OpDef* fanin_odef = nullptr; - Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef); - if (!s.ok()) { - LOG(INFO) << "Could not find OpDef for : " << fanin.node->op(); - return false; - } +// Checks if a node's output port is host friendly. +// Roughly this means checking if the output port is on Host memory. +Status IsNodeOutputPortHostFriendly( + const GraphView& graph, GraphProperties* properties, const NodeDef& node, + int port_id, OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache, + bool* is_candidate) { + *is_candidate = false; - const int output_arg_id = - OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id); - if (output_arg_id < 0) { - LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n" - << node.DebugString() << "\n" - << fanin.node->DebugString() << "\n" - << fanin_odef->DebugString(); - return false; - } + // Make sure we are not a blacklisted op. + if (IsBlacklisted(node)) { + return Status::OK(); + } - const KernelDef* fanin_kdef = nullptr; - s = TryFindKernelDef(*fanin.node, &fanin_kdef); - if (!s.ok()) { - LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op(); - return false; - } + // Check to make sure we have the right properties (i.e., statically shaped). + if (!properties->has_properties()) { + // This is an expensive call, call it lazily. 
+ TF_RETURN_IF_ERROR(properties->InferStatically( + /*assume_valid_feeds=*/false)); + } + const auto& output_properties = properties->GetOutputProperties(node.name()); + if (port_id >= output_properties.size()) { + LOG(WARNING) << "port_id=" << port_id + << " but output_properties.size()=" << output_properties.size() + << "\n" + << node.DebugString(); + return Status::OK(); + } + if (!IsTensorIntegerAndSmall(output_properties[port_id])) { + return Status::OK(); + } - bool fanin_pinned = false; - for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) { - if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) { - fanin_pinned = true; - break; + // These nodes may be optimized away downstream (even if pinned to Host), we + // should (recusively) check their source. + if (IsIdentity(node)) { + for (const auto& fanin : graph.GetFanins(node, false)) { + bool fanin_candidate = false; + TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( + graph, properties, *fanin.node, fanin.port_id, + op_device_outport_pinned_to_host_cache, &fanin_candidate)); + if (!fanin_candidate) { + return Status::OK(); } } + *is_candidate = true; + return Status::OK(); + } + + // Check if op's device is on CPU. + if (str_util::StrContains(node.device(), DEVICE_CPU)) { + *is_candidate = true; + return Status::OK(); + } + + // Check `op_device_outport_pinned_to_host_cache` for our + // {op, device, port_id} combo to see if the arg is pinned on Host. + const std::tuple<string, string, int> cache_key(node.op(), node.device(), + port_id); + auto it = op_device_outport_pinned_to_host_cache->find(cache_key); + if (it != op_device_outport_pinned_to_host_cache->end()) { + *is_candidate = it->second; + return Status::OK(); + } + + // Check if op's output port is pinned to HostMemory. 
+ const OpDef* op = nullptr; + Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op); + if (!s.ok()) { + LOG(WARNING) << "Could not find OpDef for : " << node.op(); + op_device_outport_pinned_to_host_cache->emplace(cache_key, false); + return Status::OK(); + } + + // Map the port_id to output_arg_id. + const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id); + if (output_arg_id < 0) { + LOG(WARNING) << "Invalid port: " << port_id << "!\n" + << node.DebugString() << "\n" + << op->DebugString(); + op_device_outport_pinned_to_host_cache->emplace(cache_key, false); + return Status::OK(); + } - if (!fanin_pinned) { - return false; + // Find the kernel. + const KernelDef* kernel = nullptr; + s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, + &kernel); + if (!s.ok()) { + LOG(INFO) << "Could not find KernelDef for: " << node.op(); + op_device_outport_pinned_to_host_cache->emplace(cache_key, false); + return Status::OK(); + } + + // Check if the output_arg is pinned to Host. + for (const string& host_memory_arg : kernel->host_memory_arg()) { + if (op->output_arg(output_arg_id).name() == host_memory_arg) { + *is_candidate = true; + break; } } - return true; + op_device_outport_pinned_to_host_cache->emplace(cache_key, *is_candidate); + + return Status::OK(); } -bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) { - // Check if Tensor is integer and small size. +// Checks if a node's input port is Host friendly. +// Roughly this means checking if the input port is on Host memory. +bool IsNodeInputPortHostFriendly( + const NodeDef& node, int port_id, + OpDevicePortOnHostMap* op_device_inport_pinned_to_host_cache) { + // If node is on Host, assume its inputs are Host friendly. + if (str_util::StrContains(node.device(), DEVICE_CPU)) { + return true; + } - // Check type to be int32 or int64. 
- if (prop.dtype() != DataType::DT_INT32 && - prop.dtype() != DataType::DT_INT64) { - return false; + // Check `op_device_inport_pinned_to_host_cache` for our + // {op, device, port_id} combo to see if the arg is pinned on Host. + std::tuple<string, string, int> cache_key(node.op(), node.device(), port_id); + auto it = op_device_inport_pinned_to_host_cache->find(cache_key); + if (it != op_device_inport_pinned_to_host_cache->end()) { + return it->second; } - // Check size known and small. - const int64 size = NumCoefficients(prop.shape()); - if (size < 0 || size > kTensorMaxSize) { + // Check if op's input port is pinned to HostMemory. + const OpDef* op = nullptr; + Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op); + if (!s.ok()) { + LOG(WARNING) << "Could not find OpDef for : " << node.op(); + op_device_inport_pinned_to_host_cache->emplace(cache_key, false); + return false; + } + const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id); + + // Find the kernel. + const KernelDef* kernel = nullptr; + s = internal::TryFindKernelDef( + {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel); + if (!s.ok()) { + LOG(INFO) << "Could not find KernelDef for: " << node.op(); + op_device_inport_pinned_to_host_cache->emplace(cache_key, false); return false; } - return true; + // Check if the input_arg is pinned to Host. + for (const string& host_memory_arg : kernel->host_memory_arg()) { + if (op->input_arg(input_arg_id).name() == host_memory_arg) { + op_device_inport_pinned_to_host_cache->emplace(cache_key, true); + return true; + } + } + + op_device_inport_pinned_to_host_cache->emplace(cache_key, false); + + return false; } -bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties, - const NodeDef& node) { - for (const auto& prop : properties.GetInputProperties(node.name())) { +// Checks if a node is a candidate to pin to Host. +// The rough algorithm is as follows: +// 1] Check if node is blacklisted. 
+// 2] Check if node can run on Host. +// 3] Check all input/outputs are Host "friendly" (atm, friendly means small, +// ints, and pinned to Host). +Status IsNodeHostCandidate( + const GraphView& graph, GraphProperties* properties, const NodeDef& node, + OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache, + bool* is_candidate) { + *is_candidate = false; + + // Skip these node types. + if (IsBlacklisted(node)) { + return Status::OK(); + } + + // Check if node already on CPU. + if (str_util::StrContains(node.device(), DEVICE_CPU)) { + *is_candidate = true; + return Status::OK(); + } + + // Check the node can be run on CPU. + Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr); + if (!s.ok()) { + return Status::OK(); + } + + // Check all outputs are Host friendly. + if (!properties->has_properties()) { + // This is an expensive call, call it lazily. + TF_RETURN_IF_ERROR(properties->InferStatically( + /*assume_valid_feeds=*/false)); + } + for (const auto& prop : properties->GetOutputProperties(node.name())) { if (!IsTensorIntegerAndSmall(prop)) { - return false; + return Status::OK(); } } - for (const auto& prop : properties.GetOutputProperties(node.name())) { - if (!IsTensorIntegerAndSmall(prop)) { - return false; + // Check all inputs are Host friendly. 
+ for (const GraphView::OutputPort& fanin : + graph.GetFanins(node, /*include_controlling_nodes=*/false)) { + bool fanin_candidate = false; + TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly( + graph, properties, *fanin.node, fanin.port_id, + op_device_outport_pinned_to_host_cache, &fanin_candidate)); + if (!fanin_candidate) { + return Status::OK(); } } - return true; + + *is_candidate = true; + return Status::OK(); } -string TryFindHostDevice(const gtl::FlatSet<string>& devices, - bool has_device_cpu, const string& device) { +bool IsTPUGraphDef(const GraphDef& def) { + for (const auto& node : def.node()) { + if (node.op() == "TPUCompile" || node.op() == "TPUExecute" || + node.op() == "TPUPartitionedCall") { + return true; + } + } + return false; +} +} // end namespace + +// Tries to swap `device` to a Host device from `devices`. Returns true iff +// there was a swap. +bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices, + bool has_device_cpu, string* device) { // Force this node onto the CPU. - if (device.empty() && has_device_cpu) { - return "/device:CPU:0"; - } else if (str_util::StrContains(device, DEVICE_GPU)) { + if (device->empty() && has_device_cpu) { + *device = "/device:CPU:0"; + return true; + } else if (str_util::StrContains(*device, DEVICE_GPU)) { // Sometimes the cluster can have: // devices = {"/device:CPU:0", "/device:XLA_GPU:0"} // and we need to handle them properly. @@ -146,30 +333,19 @@ string TryFindHostDevice(const gtl::FlatSet<string>& devices, {std::pair<string, string>("GPU", "CPU:0"), std::pair<string, string>("/device", "/device:CPU:0")}) { const string device_host = - strings::StrCat(device.substr(0, device.rfind(device_match.first)), + strings::StrCat(device->substr(0, device->rfind(device_match.first)), device_match.second); if (devices.find(device_host) != devices.end()) { - return device_host; + *device = device_host; + return true; } } } - // We couldn't find an appropriate Host device, return original device. 
- return device; -} - -bool IsTPUGraphDef(const GraphDef& def) { - for (const auto& node : def.node()) { - if (node.op() == "TPUCompile" || node.op() == "TPUExecute" || - node.op() == "TPUPartitionedCall") { - return true; - } - } + // We couldn't find an appropriate Host device, return false. return false; } -// All the nodes that should be blacklisted and not swapped. -bool IsBlacklisted(const NodeDef& node) { return IsCollective(node); } } // end namespace internal Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, @@ -182,7 +358,6 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } GraphProperties properties(item); - bool has_properties = false; GraphView graph(optimized_graph); gtl::FlatSet<string> devices; @@ -202,45 +377,26 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, // All the Const nodes, and their original devices in topological order. std::vector<std::pair<NodeDef*, string>> const_nodes; - for (auto& node : *optimized_graph->mutable_node()) { - // Check if node already on CPU. - if (str_util::StrContains(node.device(), DEVICE_CPU)) { - continue; - } - - // Skip these node types. - if (internal::IsBlacklisted(node)) { - continue; - } + // Cache to map {op, device, port} -> bool on whether it is pinned to host. + internal::OpDevicePortOnHostMap op_device_outport_pinned_to_host_cache; + internal::OpDevicePortOnHostMap op_device_inport_pinned_to_host_cache; - // Check the node can be run on CPU. - Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr); - if (!s.ok()) { - continue; - } - - // Check all input's are pinned to CPU. - if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) { - continue; - } - - if (!has_properties) { - // This is an expensive call, call it lazily. - TF_RETURN_IF_ERROR(properties.InferStatically(false)); - has_properties = true; - } - - // Check all inputs and outputs are integers and small. 
- if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) { + for (auto& node : *optimized_graph->mutable_node()) { + bool is_candidate = false; + TF_RETURN_IF_ERROR(internal::IsNodeHostCandidate( + graph, &properties, node, &op_device_outport_pinned_to_host_cache, + &is_candidate)); + if (!is_candidate) { continue; } - if (IsConstant(node)) { - const_nodes.emplace_back(&node, node.device()); + const string original_device = node.device(); + const bool swapped = internal::TrySwapToHostDevice(devices, has_device_cpu, + node.mutable_device()); + // Keep track of all Const nodes that we swapped. + if (swapped && IsConstant(node)) { + const_nodes.emplace_back(&node, original_device); } - // Try and swap the device to Host. - node.set_device( - internal::TryFindHostDevice(devices, has_device_cpu, node.device())); } // Traverse all `const_nodes`, and map them back to GPU greedily. @@ -248,10 +404,13 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, NodeDef* node = it.first; const string& device = it.second; - // Check all the consumers of this node, if any of them are on the original - // device, swap this node back onto the original device. + // Check all the consumers of this node, if any of them are not on CPU, swap + // this node back onto the original device. for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) { - if (fanout.node->device() == device) { + // The consumer is not Host friendly, swap it back to the original device. 
+ if (!internal::IsNodeInputPortHostFriendly( + *fanout.node, fanout.port_id, + &op_device_inport_pinned_to_host_cache)) { node->set_device(device); break; } diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h index d557a03463..bed4a9ef95 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h @@ -26,8 +26,8 @@ namespace tensorflow { namespace grappler { namespace internal { // Try and find an appropriate Host device in `devices` given `device`. -string TryFindHostDevice(const gtl::FlatSet<string>& devices, - bool has_device_cpu, const string& device); +bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices, + bool has_device_cpu, string* device); } // end namespace internal // Optimize TensorFlow ops that should be swapped into the CPU to avoid diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc index 173cb3fe3c..9bb030b220 100644 --- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc @@ -28,30 +28,60 @@ namespace { class PinToHostOptimizerTest : public GrapplerTest {}; -TEST_F(PinToHostOptimizerTest, TryFindHostDevice) { +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) { gtl::FlatSet<string> devices = {}; - EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC")); - - devices = {"/device:CPU:0", "/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"), - "/device:CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"), - "/device:CPU:0"); - - devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, 
false, ""), ""); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"), - "/device:XLA_CPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"), - "/device:XLA_CPU:0"); - - devices = {"/device:XLA_GPU:0"}; - EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), ""); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"), - "/device:XLA_GPU:0"); - EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"), - "/device:XLA_GPU:*"); + + string device = "ABC"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "ABC"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) { + gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); + + device = "/device:XLA_GPU:0"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device)); + EXPECT_EQ(device, "/device:CPU:0"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) { + gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_TRUE(device.empty()); + + device = "/device:XLA_GPU:0"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_CPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_CPU:0"); +} + +TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) { + gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"}; + + string device = ""; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, 
&device)); + EXPECT_TRUE(device.empty()); + + device = "/device:XLA_GPU:0"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_GPU:0"); + + device = "/device:XLA_GPU:*"; + EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device)); + EXPECT_EQ(device, "/device:XLA_GPU:*"); } TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) { @@ -160,6 +190,48 @@ TEST_F(PinToHostOptimizerTest, NoSwap) { EXPECT_EQ(found, 3); } +TEST_F(PinToHostOptimizerTest, Identity) { + tensorflow::Scope s = tensorflow::Scope::NewRootScope(); + // `a,c` is on GPU, `e` is on CPU, consequently `e` should not be swapped. + // `b` should be placed onto Host since `c` pins the input to Host memory. + Output a = + ops::Const(s.WithOpName("a").WithDevice("/device:GPU:0"), 1, {64, 64}); + Output b = ops::Const(s.WithOpName("b"), {0, 1}, {2}); + Output c = + ops::ReduceProd(s.WithOpName("c").WithDevice("/device:GPU:0"), a, b); + Output d = ops::Identity(s.WithDevice("/device:CPU:0").WithOpName("d"), c); + Output e = ops::Multiply(s.WithOpName("e"), d, d); + + GrapplerItem item; + TF_CHECK_OK(s.ToGraphDef(&item.graph)); + + GraphDef output; + PinToHostOptimizer optimizer(RewriterConfig::ON); + TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output)); + + int found = 0; + for (const NodeDef& node : output.node()) { + if (node.name() == "a" || node.name() == "c") { + EXPECT_EQ(node.device(), "/device:GPU:0"); + } else if (node.name() == "b") { + // If CUDA, then there is a GPU kernel registration that is pinned to Host + // memory. Consequently, `b` will be mapped to Host correct if there is + // a GPU kernel registered. 
+#if GOOGLE_CUDA + EXPECT_EQ(node.device(), "/device:CPU:0"); +#else + EXPECT_TRUE(node.device().empty()); +#endif + } else if (node.name() == "d") { + EXPECT_EQ(node.device(), "/device:CPU:0"); + } else if (node.name() == "e") { + EXPECT_TRUE(node.device().empty()); + } + ++found; + } + EXPECT_EQ(found, 5); +} + TEST_F(PinToHostOptimizerTest, PortIdToArgId) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3}); diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 008a289cfd..9ada8b7ff9 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -168,11 +168,12 @@ void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) { Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, GraphDef* optimized_graph) { GraphProperties properties(item); - TF_RETURN_IF_ERROR(properties.InferStatically(false)); + bool inferred_properties = false; GraphView graph(const_cast<GraphDef*>(&item.graph)); // During inference, most of the inputs to FusedBatchNorm are constant, and we // can therefore replace the op with a much cheaper set of primitives. + optimized_graph->mutable_node()->Reserve(item.graph.node_size()); for (const NodeDef& node : item.graph.node()) { if (node.op() == "FusedBatchNorm" || node.op() == "FusedBatchNormV2") { bool optimizable = (node.attr().count("T") == 0 || @@ -181,6 +182,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item, !node.attr().at("is_training").b()); if (optimizable) { int const_inputs = 0; + if (!inferred_properties) { + // Infer properties lazily in case they are not needed. 
+ TF_RETURN_IF_ERROR(properties.InferStatically(false)); + inferred_properties = true; + } const auto& props = properties.GetInputProperties(node.name()); for (const auto& prop : props) { if (prop.has_value()) { diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc index 4542d17ccc..6ccb1cd783 100644 --- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc @@ -33,7 +33,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, *optimized_graph = item.graph; GraphProperties properties(item); - TF_RETURN_IF_ERROR(properties.InferStatically(false)); + bool inferred_properties = false; GraphView graph(optimized_graph); // The product of all the dimensions in a tensor shape can be expressed more @@ -55,6 +55,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, } const GraphView::OutputPort reduce_indices = graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1)); + if (!inferred_properties) { + // Infer properties lazily in case they are not needed. + TF_RETURN_IF_ERROR(properties.InferStatically(false)); + inferred_properties = true; + } const auto& prop = properties.GetOutputProperties(reduce_indices.node->name()); if (prop.size() < reduce_indices.port_id) { @@ -92,6 +97,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item, if (!IsSize(*input1.node) || !IsSize(*input2.node)) { continue; } + if (!inferred_properties) { + // Infer properties lazily in case they are not needed. 
+ TF_RETURN_IF_ERROR(properties.InferStatically(false)); + inferred_properties = true; + } const auto& prop1 = properties.GetInputProperties(input1.node->name()); const auto& prop2 = properties.GetInputProperties(input2.node->name()); if (prop1.size() != 1 || prop2.size() != 1) { diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc index db6e4e6852..5867d01324 100644 --- a/tensorflow/core/grappler/utils.cc +++ b/tensorflow/core/grappler/utils.cc @@ -156,45 +156,6 @@ bool IsControlInput(const string& name) { return !name.empty() && name[0] == '^'; } -string NodeName(const string& name) { - int position; - return ParseNodeName(name, &position); -} - -int NodePosition(const string& name) { - int position; - ParseNodeNameAsStringPiece(name, &position); - return position; -} - -int NodePositionIfSameNode(const string& input_name, const string& node_name) { - const bool is_ctrl = input_name[0] == '^'; - auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin(); - auto node_it = node_name.begin(); - if (node_name.empty() || - std::distance(input_it, input_name.end()) < node_name.size()) { - return -2; - } - while (node_it != node_name.end()) { - if (*input_it++ != *node_it++) { - return -2; - } - } - if (input_it == input_name.end()) { - return is_ctrl ? -1 : 0; - } else if (*input_it++ == ':') { - StringPiece remaining(&(*input_it), - std::distance(input_it, input_name.end())); - int position; - if (!strings::safe_strto32(remaining, &position)) { - return -2; - } - return is_ctrl ? -1 : position; - } else { - return -2; - } -} - string AddPrefixToNodeName(const string& name, const string& prefix, const string& delimiter) { if (!name.empty()) { diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h index 296ee1678e..95126d470c 100644 --- a/tensorflow/core/grappler/utils.h +++ b/tensorflow/core/grappler/utils.h @@ -29,7 +29,6 @@ limitations under the License. 
#include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/core/threadpool.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" -#include "tensorflow/core/lib/strings/scanner.h" namespace tensorflow { namespace grappler { @@ -102,40 +101,92 @@ bool IsControlInput(const string& name); // True iff 'name1' and 'name2' refer to the same input. bool IsSameInput(const string& name1, const string& name2); +// Returns the trailing position number (or zero if no number is present) if +// NodeName(input_name) is equal to node_name. Returns -1 for control inputs. +// Returns -2 if NodeName(input_name) is not equal to node_name. +// Note: This function is used very heavily, and this hand-optimized +// version is 3-4x faster than the version using Scanner, which it replaced. +// This is worth the reduction in readability. +inline int NodePositionIfSameNode(const string& input_name, + const string& node_name) { + if (input_name.empty()) return -2; + const bool is_ctrl = input_name[0] == '^'; + auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin(); + auto node_it = node_name.begin(); + if (node_name.empty() || + std::distance(input_it, input_name.end()) < node_name.size()) { + return -2; + } + while (node_it != node_name.end()) { + if (*input_it++ != *node_it++) { + return -2; + } + } + if (input_it == input_name.end()) { + return is_ctrl ? -1 : 0; + } else if (*input_it++ == ':') { + StringPiece remaining(&(*input_it), + std::distance(input_it, input_name.end())); + int position; + if (!strings::safe_strto32(remaining, &position)) { + return -2; + } + return is_ctrl ? -1 : position; + } else { + return -2; + } +} + // Return the node name corresponding to 'name' if name is valid, or the empty // string otherwise. -string NodeName(const string& name); +inline StringPiece NodeNameAsStringPiece(const string& name) { + static const string empty; + if (name.empty()) return StringPiece(empty); + const auto begin_it = name[0] == '^' ? 
name.begin() + 1 : name.begin(); + auto end_it = begin_it; + while (end_it != name.end() && *end_it != ':') { + ++end_it; + } + if (end_it != name.end() && *end_it != ':') { + return StringPiece(empty); + } + return StringPiece(&(*begin_it), std::distance(begin_it, end_it)); +} -// Get the trailing position number ":{digits}" (if any) of a node name. -// Returns -1 for control inputs. -int NodePosition(const string& name); +// Return the node name corresponding to 'name' if name is valid, or the empty +// string otherwise. +inline string NodeName(const string& name) { + return string(NodeNameAsStringPiece(name)); +} +// Returns the node name and position in a single call. inline StringPiece ParseNodeNameAsStringPiece(const string& name, int* position) { - // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any) - // to get a node name. - strings::Scanner scan(name); - scan.ZeroOrOneLiteral("^") - .RestartCapture() - .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE) - .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE); - StringPiece capture; - StringPiece remaining; - if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) { + static const string empty; + if (name.empty()) { *position = 0; - static const string empty; return StringPiece(empty); - } else { - if (name[0] == '^') { - *position = -1; - } else if (remaining.empty()) { - *position = 0; - } else { - // Skip the first ':' character. - CHECK(strings::safe_strto32(remaining.substr(1), position)); + } + const bool is_ctrl = name[0] == '^'; + const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin(); + *position = is_ctrl ? 
-1 : 0; + auto end_it = begin_it; + while (end_it != name.end() && *end_it != ':') { + ++end_it; + } + const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it)); + if (end_it != name.end()) { + if (*end_it != ':') { + return StringPiece(empty); + } else if (!is_ctrl) { + ++end_it; + StringPiece remaining(&(*end_it), std::distance(end_it, name.end())); + if (!strings::safe_strto32(remaining, position)) { + return StringPiece(empty); + } } - return capture; } + return node_name; } // Returns the node name and position in a single call. @@ -143,10 +194,11 @@ inline string ParseNodeName(const string& name, int* position) { return string(ParseNodeNameAsStringPiece(name, position)); } -// Returns NodePosition(input_name) if NodeName(input_name) == node_name. -// Otherwise returns -2; -// REQUIRES: inputs_name.size() > 0 && node_name.size() > 0. -int NodePositionIfSameNode(const string& input_name, const string& node_name); +inline int NodePosition(const string& name) { + int position; + ParseNodeNameAsStringPiece(name, &position); + return position; +} // Add a prefix to a node name with a custom delimiter. 
string AddPrefixToNodeName(const string& name, const string& prefix, diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc index a428aea7f5..6861fb423c 100644 --- a/tensorflow/core/grappler/utils/functions.cc +++ b/tensorflow/core/grappler/utils/functions.cc @@ -41,7 +41,8 @@ Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration, tensorflow::NameRangeMap outputs_range_map; TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode( node, registration.op_def, nullptr, &outputs_range_map)); - connectivity->RegisterFunctionBodyOutputs(node.name(), outputs_range_map); + connectivity->RegisterFunctionBodyOutputs(node.name(), + std::move(outputs_range_map)); return Status::OK(); } @@ -75,20 +76,22 @@ Status ResolveFunctionBodyNodeAttrPlaceholders( } // namespace void GrapplerFunctionConnectivity::RegisterInputArgExpansion( - const InputArgExpansion& input_arg_expansion) { - const auto& input_name = input_arg_expansion.input_name; + InputArgExpansion input_arg_expansion) { + string input_name = input_arg_expansion.input_name; const auto& placeholders = input_arg_expansion.placeholders; - input_arg_expansions_.emplace(input_name, input_arg_expansion); + for (int i = 0; i < placeholders.size(); ++i) { const string& placeholder = input_arg_expansion.placeholders[i]; - input_arg_placeholders_.emplace( - placeholder, InputArgPlaceholder{input_name, /*position=*/i}); + input_arg_placeholders_.insert( + {placeholder, InputArgPlaceholder{input_name, /*position=*/i}}); } + input_arg_expansions_.insert( + {std::move(input_name), std::move(input_arg_expansion)}); } void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs( - const string& node_name, const tensorflow::NameRangeMap& outputs) { - function_body_outputs_[node_name] = outputs; + const string& node_name, tensorflow::NameRangeMap&& outputs) { + function_body_outputs_[node_name] = std::move(outputs); } Status GrapplerFunctionConnectivity::ExpandFunctionDefInput( 
@@ -174,11 +177,12 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput( const auto& output_range = output->second; if (position == -1) { + graph_def_inputs->reserve(graph_def_inputs->size() + + output_range.second - output_range.first); // If position is not defined expand node output range for (int i = output_range.first; i < output_range.second; ++i) { - i == 0 ? graph_def_inputs->push_back(node_name) - : graph_def_inputs->push_back( - strings::StrCat(node_name, ":", i)); + graph_def_inputs->push_back( + i == 0 ? node_name : strings::StrCat(node_name, ":", i)); } } else { if (position > (output_range.second - output_range.first)) { @@ -187,9 +191,8 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput( " position: ", position, " (out of range)"); } int pos = output_range.first + position; - pos == 0 ? graph_def_inputs->push_back(node_name) - : graph_def_inputs->push_back( - strings::StrCat(node_name, ":", pos)); + graph_def_inputs->push_back( + pos == 0 ? node_name : strings::StrCat(node_name, ":", pos)); } return Status::OK(); @@ -211,8 +214,8 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs( } function_body_node->clear_input(); - for (const string& expanded_input : expanded_inputs) - function_body_node->add_input(expanded_input); + for (string& expanded_input : expanded_inputs) + function_body_node->add_input(std::move(expanded_input)); return Status::OK(); } @@ -323,7 +326,7 @@ GrapplerFunctionItem::GrapplerFunctionItem( // Fill the feed nodes with input placeholders. 
for (const InputArgExpansion& input_arg : input_arg_expansions_) { for (const string& placeholder : input_arg.placeholders) { - feed.emplace_back(placeholder, Tensor()); + feed.push_back({placeholder, Tensor()}); input_arg_placeholders_.insert(placeholder); } } @@ -460,7 +463,7 @@ Status InstantiationBodyParameters( auto it = func_instantiation_attr.find(placeholder); if (it != func_instantiation_attr.end()) { - body_parameters->emplace(placeholder, it->second); + body_parameters->insert({placeholder, it->second}); } else { return errors::InvalidArgument("Can't resolve placeholder: ", placeholder); @@ -498,10 +501,6 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, // GraphDef input format (name[:position]) GrapplerFunctionConnectivity connectivity; - std::vector<InputArgExpansion> inputs; - std::vector<OutputArgExpansion> outputs; - std::vector<string> keep_nodes; - // Function body shares the library with the graph that instantiated it. GraphDef function_body; *function_body.mutable_library() = flib.ToProto(); @@ -518,6 +517,9 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, } } + std::vector<InputArgExpansion> inputs; + inputs.reserve(signature.input_arg_size()); + // For each input argument create a placeholder in function body. 
for (const OpDef::ArgDef& input : signature.input_arg()) { if (!input.type_list_attr().empty() || !input.number_attr().empty()) { @@ -542,9 +544,10 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, /*is_ref*/ input.is_ref(), /*placeholders=*/{input.name()}}; connectivity.RegisterInputArgExpansion(input_expansion); - inputs.push_back(input_expansion); + inputs.push_back(std::move(input_expansion)); } + std::vector<string> keep_nodes; // Add all function nodes to the function body for (const NodeDef& func_def_node : func.node_def()) { NodeDef* new_node = function_body.add_node(); @@ -572,6 +575,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node)); } + std::vector<OutputArgExpansion> outputs; + outputs.reserve(signature.output_arg_size()); // Add function outputs for (const OpDef::ArgDef& out : signature.output_arg()) { std::vector<string> output_tensors; @@ -589,8 +594,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func, OutputArgExpansion output{/*output_name=*/out.name(), /*data_type=*/output_data_type, /*is_ref=*/out.is_ref(), - /*output_tensors=*/output_tensors}; - outputs.push_back(output); + /*output_tensors=*/std::move(output_tensors)}; + outputs.push_back(std::move(output)); } bool is_stateful = signature.is_stateful(); diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h index 733caf325f..ef944ced09 100644 --- a/tensorflow/core/grappler/utils/functions.h +++ b/tensorflow/core/grappler/utils/functions.h @@ -18,6 +18,7 @@ limitations under the License. #include <memory> #include <string> +#include <unordered_map> #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/function.pb.h" @@ -70,9 +71,9 @@ struct OutputArgExpansion { // and fold it back when doing backward conversion. 
class GrapplerFunctionConnectivity { public: - void RegisterInputArgExpansion(const InputArgExpansion& input_arg_expansion); + void RegisterInputArgExpansion(InputArgExpansion input_arg_expansion); void RegisterFunctionBodyOutputs(const string& node_name, - const tensorflow::NameRangeMap& outputs); + tensorflow::NameRangeMap&& outputs); // Expand input encoded in FunctionDef format (name[:output][:position]) into // multiple inputs in GraphDef format (name[:position]). diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc index 6b787a6910..9b6c1f690b 100644 --- a/tensorflow/core/grappler/utils_test.cc +++ b/tensorflow/core/grappler/utils_test.cc @@ -371,6 +371,25 @@ BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl); BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0); BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end); +#define BM_ParseNodeNameAsStringPiece(I, NAME) \ + static void BM_ParseNodeNameAsStringPiece_##NAME(int iters) { \ + string input = I; \ + for (int i = 0; i < iters; ++i) { \ + int position; \ + const StringPiece name = ParseNodeNameAsStringPiece(input, &position); \ + CHECK_GE(position, -1); \ + CHECK(!name.empty()); \ + } \ + } \ + BENCHMARK(BM_ParseNodeNameAsStringPiece_##NAME) + +BM_ParseNodeNameAsStringPiece("foo", foo); +BM_ParseNodeNameAsStringPiece("foo/bar/baz", foo_bar_baz); +BM_ParseNodeNameAsStringPiece("^foo/bar/baz", foo_bar_baz_ctrl); +BM_ParseNodeNameAsStringPiece("foo:123", foo123); +BM_ParseNodeNameAsStringPiece("foo/bar/baz:123", foo_bar_baz_123); +BM_ParseNodeNameAsStringPiece("^foo/bar/baz:123", foo_bar_baz_123_ctrl); + } // namespace } // namespace grappler } // namespace tensorflow diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 1a3db2c7cd..3a920f26f3 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -1197,8 +1197,10 @@ tf_cc_test( tf_cc_test( name = 
"example_parsing_ops_test", - size = "large", + size = "medium", srcs = ["example_parsing_ops_test.cc"], + shard_count = 4, + tags = ["optonly"], deps = [ ":example_parsing_ops", ":ops_testutil", @@ -2028,8 +2030,8 @@ tf_kernel_library( ":variable_ops", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:resource_variable_ops_op_lib", - "//third_party/eigen3", ], ) @@ -4049,11 +4051,6 @@ cc_library( ) SPARSE_DEPS = [ - ":bounds_check", - ":cwise_op", - ":fill_functor", - ":scatter_functor", - "//third_party/eigen3", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:sparse_ops_op_lib", @@ -4086,7 +4083,9 @@ tf_kernel_library( tf_kernel_library( name = "sparse_cross_op", prefix = "sparse_cross_op", - deps = SPARSE_DEPS, + deps = SPARSE_DEPS + [ + "//third_party/eigen3", + ], ) tf_kernel_library( @@ -4098,13 +4097,19 @@ tf_kernel_library( tf_kernel_library( name = "sparse_dense_binary_op_shared", prefix = "sparse_dense_binary_op_shared", - deps = SPARSE_DEPS, + deps = SPARSE_DEPS + [ + ":cwise_op", + "//third_party/eigen3", + ], ) tf_kernel_library( name = "sparse_sparse_binary_op_shared", prefix = "sparse_sparse_binary_op_shared", - deps = SPARSE_DEPS, + deps = SPARSE_DEPS + [ + ":cwise_op", + "//third_party/eigen3", + ], ) tf_kernel_library( @@ -4136,7 +4141,9 @@ tf_kernel_library( tf_kernel_library( name = "sparse_softmax", prefix = "sparse_softmax", - deps = SPARSE_DEPS, + deps = SPARSE_DEPS + [ + "//third_party/eigen3", + ], ) tf_kernel_library( @@ -4148,25 +4155,37 @@ tf_kernel_library( tf_kernel_library( name = "sparse_tensor_dense_add_op", prefix = "sparse_tensor_dense_add_op", - deps = SPARSE_DEPS, + deps = SPARSE_DEPS + [ + ":scatter_functor", + "//third_party/eigen3", + ], ) tf_kernel_library( name = "sparse_tensor_dense_matmul_op", prefix = "sparse_tensor_dense_matmul_op", - deps = SPARSE_DEPS, + deps = SPARSE_DEPS + [ + ":bounds_check", + ":fill_functor", + 
"//third_party/eigen3", + ], ) tf_kernel_library( name = "sparse_to_dense_op", prefix = "sparse_to_dense_op", - deps = SPARSE_DEPS, + deps = SPARSE_DEPS + [ + "//third_party/eigen3", + ], ) tf_kernel_library( name = "sparse_xent_op", prefix = "sparse_xent_op", - deps = SPARSE_DEPS, + deps = SPARSE_DEPS + [ + ":bounds_check", + "//third_party/eigen3", + ], ) tf_kernel_library( @@ -4431,6 +4450,7 @@ cc_library( ":string_strip_op", ":string_to_hash_bucket_op", ":substr_op", + ":unicode_script_op", ], ) @@ -4438,7 +4458,12 @@ cc_library( name = "string_util", srcs = ["string_util.cc"], hdrs = ["string_util.h"], - deps = ["//tensorflow/core:lib"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@icu//:common", + ], ) STRING_DEPS = [ @@ -5254,6 +5279,8 @@ filegroup( "cwise_op_squared_difference.cc", "cwise_op_sub.cc", "cwise_op_tanh.cc", + "cwise_op_xlogy.cc", + "cwise_op_xdivy.cc", "data_format_ops.cc", "decode_wav_op.cc", "deep_conv2d.cc", @@ -5469,6 +5496,7 @@ filegroup( "batch_kernels.*", "regex_full_match_op.cc", "regex_replace_op.cc", + "unicode_script_op.cc", # Ops that are inherently incompatible with Android (e.g. tied to x86 platform). "mkl_*", "xsmm_*", @@ -6414,6 +6442,12 @@ tf_mkl_kernel_library( ) tf_mkl_kernel_library( + name = "mkl_slice_op", + prefix = "mkl_slice_op", + deps = ARRAY_DEPS + mkl_deps(), +) + +tf_mkl_kernel_library( name = "mkl_identity_op", prefix = "mkl_identity_op", deps = ARRAY_DEPS + mkl_deps(), @@ -6557,6 +6591,16 @@ tf_kernel_library( ], ) +tf_kernel_library( + name = "unicode_script_op", + srcs = ["unicode_script_op.cc"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:string_ops_op_lib", + "@icu//:common", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets. These must be at the end for syncrepo. 
diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc index 54c45bfe63..f48bd0c318 100644 --- a/tensorflow/core/kernels/batch_matmul_op_complex.cc +++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc @@ -17,14 +17,18 @@ limitations under the License. namespace tensorflow { -#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) +// MKL_ML registers its own complex64/128 kernels in mkl_batch_matmul_op.cc +// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL). +// Anything else (the complement) should register the TF ones. +// (MKL-DNN doesn't implement these kernels either.) +#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL) TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU); TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU); -#endif +#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL #if GOOGLE_CUDA TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU); TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU); -#endif +#endif // GOOGLE_CUDA } // namespace tensorflow diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc index 584b507c70..25ae795d8e 100644 --- a/tensorflow/core/kernels/batch_matmul_op_real.cc +++ b/tensorflow/core/kernels/batch_matmul_op_real.cc @@ -21,10 +21,15 @@ limitations under the License. namespace tensorflow { -#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) +// MKL_ML registers its own float and double kernels in mkl_batch_matmul_op.cc +// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL). +// Anything else (the complement) should register the TF ones. +// (MKL-DNN doesn't implement these kernels either.) 
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL) TF_CALL_float(REGISTER_BATCH_MATMUL_CPU); TF_CALL_double(REGISTER_BATCH_MATMUL_CPU); -#endif +#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL + TF_CALL_half(REGISTER_BATCH_MATMUL_CPU); TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU); diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD index 792eb74e31..0d53240330 100644 --- a/tensorflow/core/kernels/batching_util/BUILD +++ b/tensorflow/core/kernels/batching_util/BUILD @@ -1,7 +1,7 @@ # Description: Utilities. package( - default_visibility = ["//tensorflow:internal"], + default_visibility = ["//visibility:public"], ) licenses(["notice"]) # Apache 2.0 @@ -12,7 +12,6 @@ cc_library( name = "periodic_function_dynamic", srcs = ["periodic_function.cc"], hdrs = ["periodic_function.h"], - visibility = ["//visibility:public"], deps = [ "//tensorflow/core:framework_headers_lib", "//tensorflow/core:protos_all_cc", @@ -21,7 +20,6 @@ cc_library( cc_library( name = "periodic_function", - visibility = ["//visibility:public"], deps = [ ":periodic_function_dynamic", "//tensorflow/core:lib", @@ -190,7 +188,6 @@ cc_library( testonly = 1, srcs = ["fake_clock_env.cc"], hdrs = ["fake_clock_env.h"], - visibility = ["//visibility:public"], deps = [ "//tensorflow/core:lib", "//tensorflow/core:tensorflow", diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc index e0da91125b..82e2913b64 100644 --- a/tensorflow/core/kernels/collective_ops.cc +++ b/tensorflow/core/kernels/collective_ops.cc @@ -143,6 +143,7 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel { c->forward_input_or_allocate_output( {0}, 0, c->input(0).shape(), &output), done); + col_params_.instance.shape = c->input(0).shape(); } if (!CanProceedWithCompute(c, col_exec, done)) return; auto actual_done = [c, col_exec, done](const Status& s) { @@ -171,7 +172,7 @@ class CollectiveBcastSendOpKernel 
: public CollectiveOpKernel { OP_REQUIRES_OK( c, c->GetAttr("instance_key", &col_params_.instance.instance_key)); OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type)); - OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_)); + OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape)); col_params_.is_source = true; col_params_.instance.impl_details.subdiv_offsets = {0}; @@ -195,13 +196,14 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel { if (c->mutable_output(0) == nullptr) { // Allocate the output tensor, trying to reuse the input. Tensor* output = nullptr; - OP_REQUIRES_OK_ASYNC( - c, c->forward_input_or_allocate_output({0}, 0, shape_, &output), - done); + OP_REQUIRES_OK_ASYNC(c, + c->forward_input_or_allocate_output( + {0}, 0, col_params_.instance.shape, &output), + done); } if (!CanProceedWithCompute(c, col_exec, done)) return; OP_REQUIRES_ASYNC( - c, shape_.IsSameSize(c->input(0).shape()), + c, col_params_.instance.shape.IsSameSize(c->input(0).shape()), errors::Internal("Declared shape of op ", col_params_.name, " does not match shape of input"), done); @@ -214,8 +216,6 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel { } private: - TensorShape shape_; - TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel); }; @@ -234,7 +234,7 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel { OP_REQUIRES_OK( c, c->GetAttr("instance_key", &col_params_.instance.instance_key)); OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type)); - OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_)); + OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape)); col_params_.is_source = false; col_params_.instance.impl_details.subdiv_offsets = {0}; @@ -258,7 +258,8 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel { if (c->mutable_output(0) == nullptr) { // No input, so must allocate output. 
Tensor* output = nullptr; - OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done); + OP_REQUIRES_OK_ASYNC( + c, c->allocate_output(0, col_params_.instance.shape, &output), done); } if (!CanProceedWithCompute(c, col_exec, done)) return; @@ -270,8 +271,6 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel { } private: - TensorShape shape_; - TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel); }; diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc index 717a9f40a9..78856c4a99 100644 --- a/tensorflow/core/kernels/conv_ops.cc +++ b/tensorflow/core/kernels/conv_ops.cc @@ -264,150 +264,198 @@ class LaunchXsmmConvOp<CPUDevice, float> { }; #endif +#define TF_REQUIRES(EXP, STATUS) \ + do { \ + if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \ + } while (false) + +Status InitConv2DParameters(const OpKernelConstruction* context, + Conv2DParameters* params) { + TF_RETURN_IF_ERROR(context->GetAttr("dilations", ¶ms->dilations)); + TF_RETURN_IF_ERROR(context->GetAttr("strides", ¶ms->strides)); + TF_RETURN_IF_ERROR(context->GetAttr("padding", ¶ms->padding)); + string data_format_string; + TF_RETURN_IF_ERROR(context->GetAttr("data_format", &data_format_string)); + TF_REQUIRES(FormatFromString(data_format_string, ¶ms->data_format), + errors::InvalidArgument("Invalid data format")); + + const auto& strides = params->strides; + const auto& dilations = params->dilations; + const auto& data_format = params->data_format; + + TF_REQUIRES(dilations.size() == 4, + errors::InvalidArgument("Sliding window dilations field must " + "specify 4 dimensions")); + TF_REQUIRES(strides.size() == 4, + errors::InvalidArgument("Sliding window strides field must " + "specify 4 dimensions")); + const int64 stride_n = GetTensorDim(strides, data_format, 'N'); + const int64 stride_c = GetTensorDim(strides, data_format, 'C'); + const int64 stride_h = GetTensorDim(strides, data_format, 'H'); + const int64 stride_w = GetTensorDim(strides, 
data_format, 'W'); + TF_REQUIRES( + stride_n == 1 && stride_c == 1, + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + TF_REQUIRES(stride_h > 0 && stride_w > 0, + errors::InvalidArgument( + "Row and column strides should be larger than 0.")); + + const int64 dilation_n = GetTensorDim(dilations, data_format, 'N'); + const int64 dilation_c = GetTensorDim(dilations, data_format, 'C'); + const int64 dilation_h = GetTensorDim(dilations, data_format, 'H'); + const int64 dilation_w = GetTensorDim(dilations, data_format, 'W'); + TF_REQUIRES( + dilation_n == 1 && dilation_c == 1, + errors::InvalidArgument("Current implementation does not yet support " + "dilations in the batch and depth dimensions.")); + TF_REQUIRES( + dilation_h > 0 && dilation_w > 0, + errors::InvalidArgument("Dilated rates should be larger than 0.")); + + return Status::OK(); +} + +Status ComputeConv2DDimension(const Conv2DParameters& params, + const Tensor& input, const Tensor& filter, + Conv2DDimensions* dimensions) { + // Check that 2D convolution input and filter have exactly 4 dimensions. + TF_REQUIRES(input.dims() == 4, + errors::InvalidArgument("input must be 4-dimensional", + input.shape().DebugString())); + TF_REQUIRES(filter.dims() == 4, + errors::InvalidArgument("filter must be 4-dimensional: ", + filter.shape().DebugString())); + for (int i = 0; i < 3; i++) { + TF_REQUIRES( + FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()), + errors::InvalidArgument("filter too large")); + } + + // The last dimension for input is in_depth. Check that it is the same as the + // filter's in_depth or it is evenly divisible by filter's in_depth. 
+ const int64 in_depth_raw = GetTensorDim(input, params.data_format, 'C'); + const int64 patch_depth_raw = filter.dim_size(2); + TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("Input depth too large")); + TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("Patch depth too large")); + const int in_depth = static_cast<int>(in_depth_raw); + const int patch_depth = static_cast<int>(patch_depth_raw); + TF_REQUIRES(in_depth % patch_depth == 0, + errors::InvalidArgument( + "input depth must be evenly divisible by filter depth: ", + in_depth, " vs ", patch_depth)); + + // The last dimension for filter is out_depth. + const int out_depth = static_cast<int>(filter.dim_size(3)); + + // The second dimension for input is rows/height. + // The first dimension for filter is rows/height. + const int64 input_rows_raw = GetTensorDim(input, params.data_format, 'H'); + TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("Input rows too large")); + const int input_rows = static_cast<int>(input_rows_raw); + const int filter_rows = static_cast<int>(filter.dim_size(0)); + + // The third dimension for input is columns/width. + // The second dimension for filter is columns/width. + const int64 input_cols_raw = GetTensorDim(input, params.data_format, 'W'); + TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("Input cols too large")); + const int input_cols = static_cast<int>(input_cols_raw); + const int filter_cols = static_cast<int>(filter.dim_size(1)); + + // The first dimension for input is batch. 
+ const int64 batch_raw = GetTensorDim(input, params.data_format, 'N'); + TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()), + errors::InvalidArgument("batch is too large")); + const int batch = static_cast<int>(batch_raw); + + // Take the stride and dilation from the second and third dimensions only (we + // do not support striding or dilation on the batch or depth dimension). + const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H'); + const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W'); + const int dilation_rows = + GetTensorDim(params.dilations, params.data_format, 'H'); + const int dilation_cols = + GetTensorDim(params.dilations, params.data_format, 'W'); + + // Compute windowed output sizes for rows and columns. + int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2( + input_rows, filter_rows, dilation_rows, stride_rows, params.padding, + &out_rows, &pad_rows)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2( + input_cols, filter_cols, dilation_cols, stride_cols, params.padding, + &out_cols, &pad_cols)); + + dimensions->batch = batch; + dimensions->input_rows = input_rows; + dimensions->input_cols = input_cols; + dimensions->in_depth = in_depth; + dimensions->filter_rows = filter_rows; + dimensions->filter_cols = filter_cols; + dimensions->patch_depth = patch_depth; + dimensions->out_depth = out_depth; + dimensions->stride_rows = stride_rows; + dimensions->stride_cols = stride_cols; + dimensions->dilation_rows = dilation_rows; + dimensions->dilation_cols = dilation_cols; + dimensions->out_rows = out_rows; + dimensions->out_cols = out_cols; + dimensions->pad_rows = pad_rows; + dimensions->pad_cols = pad_cols; + + return Status::OK(); +} + +#undef TF_REQUIRES + template <typename Device, typename T> class Conv2DOp : public BinaryOp<T> { public: explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) { - 
OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_)); - OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_)); - string data_format; - OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); - OP_REQUIRES(context, FormatFromString(data_format, &data_format_), - errors::InvalidArgument("Invalid data format")); + OP_REQUIRES_OK(context, InitConv2DParameters(context, ¶ms_)); + OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_)); use_cudnn_ &= CanUseCudnn(); cudnn_use_autotune_ = CudnnUseAutotune(); - OP_REQUIRES(context, dilations_.size() == 4, - errors::InvalidArgument("Sliding window dilations field must " - "specify 4 dimensions")); - OP_REQUIRES(context, strides_.size() == 4, - errors::InvalidArgument("Sliding window strides field must " - "specify 4 dimensions")); - const int64 stride_n = GetTensorDim(strides_, data_format_, 'N'); - const int64 stride_c = GetTensorDim(strides_, data_format_, 'C'); - const int64 stride_h = GetTensorDim(strides_, data_format_, 'H'); - const int64 stride_w = GetTensorDim(strides_, data_format_, 'W'); - OP_REQUIRES( - context, stride_n == 1 && stride_c == 1, - errors::InvalidArgument("Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - OP_REQUIRES(context, stride_h > 0 && stride_w > 0, - errors::InvalidArgument( - "Row and column strides should be larger than 0.")); - - const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N'); - const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C'); - const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H'); - const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W'); - OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1, - errors::InvalidArgument( - "Current implementation does not yet support " - "dilations in the batch and depth dimensions.")); - OP_REQUIRES( - context, dilation_h > 0 && dilation_w > 0, - 
errors::InvalidArgument("Dilated rates should be larger than 0.")); - OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); } void Compute(OpKernelContext* context) override { // Input tensor is of the following dimensions: // [ batch, in_rows, in_cols, in_depth ] - const Tensor& input = context->input(0); // Input filter is of the following dimensions: // [ filter_rows, filter_cols, in_depth, out_depth] const Tensor& filter = context->input(1); - // For 2D convolution, there should be 4 dimensions. - OP_REQUIRES(context, input.dims() == 4, - errors::InvalidArgument("input must be 4-dimensional", - input.shape().DebugString())); - OP_REQUIRES(context, filter.dims() == 4, - errors::InvalidArgument("filter must be 4-dimensional: ", - filter.shape().DebugString())); - - for (int i = 0; i < 3; i++) { - OP_REQUIRES( - context, - FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()), - errors::InvalidArgument("filter too large")); - } + Conv2DDimensions dimensions; + OP_REQUIRES_OK(context, + ComputeConv2DDimension(params_, input, filter, &dimensions)); - // The last dimension for input is in_depth. It must be the same as the - // filter's in_depth or be evenly divisible by filter's in_depth. - const int64 in_depth = GetTensorDim(input, data_format_, 'C'); - const int64 patch_depth = filter.dim_size(2); - OP_REQUIRES(context, in_depth % patch_depth == 0, - errors::InvalidArgument( - "input depth must be evenly divisible by filter depth: ", - in_depth, " vs ", patch_depth)); - - // The last dimension for filter is out_depth. - const int out_depth = static_cast<int>(filter.dim_size(3)); - - // The second dimension for input is rows/height. - // The first dimension for filter is rows/height. 
- const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H'); - OP_REQUIRES( - context, - FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()), - errors::InvalidArgument("Input rows too large")); - const int input_rows = static_cast<int>(input_rows_raw); - const int filter_rows = static_cast<int>(filter.dim_size(0)); - - // The third dimension for input is columns/width. - // The second dimension for filter is columns/width. - const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W'); - OP_REQUIRES( - context, - FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()), - errors::InvalidArgument("Input cols too large")); - const int input_cols = static_cast<int>(input_cols_raw); - const int filter_cols = static_cast<int>(filter.dim_size(1)); - - // The first dimension for input is batch. - const int64 batch_raw = GetTensorDim(input, data_format_, 'N'); - OP_REQUIRES(context, - FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()), - errors::InvalidArgument("batch is too large")); - const int batch = static_cast<int>(batch_raw); - - // For now we take the stride and dilation from the second and third - // dimensions only (we do not support striding or dilation on the batch or - // depth dimension). 
- const int stride_rows = GetTensorDim(strides_, data_format_, 'H'); - const int stride_cols = GetTensorDim(strides_, data_format_, 'W'); - - const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H'); - const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W'); - - int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0; - OP_REQUIRES_OK(context, GetWindowedOutputSizeV2( - input_rows, filter_rows, dilation_rows, - stride_rows, padding_, &out_rows, &pad_rows)); - OP_REQUIRES_OK(context, GetWindowedOutputSizeV2( - input_cols, filter_cols, dilation_cols, - stride_cols, padding_, &out_cols, &pad_cols)); - TensorShape out_shape = - ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth); + TensorShape out_shape = ShapeFromFormat( + params_.data_format, dimensions.batch, dimensions.out_rows, + dimensions.out_cols, dimensions.out_depth); // Output tensor is of the following dimensions: // [ in_batch, out_rows, out_cols, out_depth ] Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output)); - VLOG(2) << "Conv2D: in_depth = " << in_depth - << ", patch_depth = " << patch_depth - << ", input_cols = " << input_cols - << ", filter_cols = " << filter_cols - << ", input_rows = " << input_rows - << ", filter_rows = " << filter_rows - << ", stride_rows = " << stride_rows - << ", stride_cols = " << stride_cols - << ", dilation_rows = " << dilation_rows - << ", dilation_cols = " << dilation_cols - << ", out_depth = " << out_depth; + VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth + << ", patch_depth = " << dimensions.patch_depth + << ", input_cols = " << dimensions.input_cols + << ", filter_cols = " << dimensions.filter_cols + << ", input_rows = " << dimensions.input_rows + << ", filter_rows = " << dimensions.filter_rows + << ", stride_rows = " << dimensions.stride_rows + << ", stride_cols = " << dimensions.stride_cols + << ", dilation_rows = " << dimensions.dilation_rows + << ", dilation_cols = 
" << dimensions.dilation_cols + << ", out_depth = " << dimensions.out_depth; // If there is nothing to compute, return. if (out_shape.num_elements() == 0) { @@ -416,36 +464,41 @@ class Conv2DOp : public BinaryOp<T> { #ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS if (LaunchXsmmConvOp<Device, T>::Run( - context, input, filter, batch, input_rows, input_cols, in_depth, - filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols, - out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols, - output, data_format_)) { + context, input, filter, dimensions.batch, dimensions.input_rows, + dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows, + dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols, + dimensions.out_rows, dimensions.out_cols, dimensions.out_depth, + dimensions.dilation_rows, dimensions.dilation_cols, + dimensions.stride_rows, dimensions.stride_cols, output, + params_.data_format)) { return; } #endif if (LaunchDeepConvOp<Device, T>::Run( - context, input, filter, batch, input_rows, input_cols, in_depth, - filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols, - out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols, - output, data_format_)) { + context, input, filter, dimensions.batch, dimensions.input_rows, + dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows, + dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols, + dimensions.out_rows, dimensions.out_cols, dimensions.out_depth, + dimensions.dilation_rows, dimensions.dilation_cols, + dimensions.stride_rows, dimensions.stride_cols, output, + params_.data_format)) { return; } launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter, - dilation_rows, dilation_cols, stride_rows, stride_cols, padding_, - output, data_format_); + dimensions.dilation_rows, dimensions.dilation_cols, + dimensions.stride_rows, dimensions.stride_cols, params_.padding, + output, params_.data_format); } private: - std::vector<int32> dilations_; 
- std::vector<int32> strides_; + Conv2DParameters params_; bool use_cudnn_; - Padding padding_; - TensorFormat data_format_; - LaunchConv2DOp<Device, T> launcher_; bool cudnn_use_autotune_; + LaunchConv2DOp<Device, T> launcher_; + TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp); }; diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h index adf4601b43..7ec878e0b2 100644 --- a/tensorflow/core/kernels/conv_ops.h +++ b/tensorflow/core/kernels/conv_ops.h @@ -66,6 +66,50 @@ struct Im2ColBufferResource : public ResourceBase { string DebugString() { return "Im2ColBufferResource"; } }; +// Convolution parameters specified by Op attributes. +struct Conv2DParameters { + std::vector<int32> dilations; + std::vector<int32> strides; + Padding padding; + TensorFormat data_format; +}; + +// Convolution dimensions inferred from parameters, input and filter tensors. +struct Conv2DDimensions { + int batch; + int input_rows; + int input_cols; + int in_depth; + + int filter_rows; + int filter_cols; + int patch_depth; + int out_depth; + + int stride_rows; + int stride_cols; + + int dilation_rows; + int dilation_cols; + + int64 out_rows; + int64 out_cols; + int64 pad_rows; + int64 pad_cols; +}; + +// Initializes and validates Conv2D parameters configured by OpKernel +// attributes. +Status InitConv2DParameters(const OpKernelConstruction* context, + Conv2DParameters* params); + +// Computes and validates convolutions dimensions from Conv2D parameters. If +// parameters are valid, dimensions will be updated with derived convolution +// dimensions, otherwise error will be returned. 
+Status ComputeConv2DDimension(const Conv2DParameters& params, + const Tensor& input, const Tensor& filter, + Conv2DDimensions* dimensions); + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_H_ diff --git a/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc new file mode 100644 index 0000000000..e4b21a66c6 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc @@ -0,0 +1,26 @@ +/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY5(xdivy, Eigen::half, float, double, complex64, complex128); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc new file mode 100644 index 0000000000..1e1b5a426e --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc @@ -0,0 +1,26 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h" + +namespace tensorflow { +namespace functor { +DEFINE_BINARY5(xlogy, Eigen::half, float, double, complex64, complex128); +} // namespace functor +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/cwise_op_xdivy.cc b/tensorflow/core/kernels/cwise_op_xdivy.cc new file mode 100644 index 0000000000..6a6aec5e86 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_xdivy.cc @@ -0,0 +1,38 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER5(BinaryOp, CPU, "Xdivy", functor::xdivy, float, Eigen::half, double, + complex64, complex128); + +#if TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Xdivy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \ + BinaryOp<SYCLDevice, functor::xdivy<TYPE>>); +REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); +#undef REGISTER_SYCL_KERNEL + +#endif // TENSORFLOW_USE_SYCL + +#if GOOGLE_CUDA +REGISTER5(BinaryOp, GPU, "Xdivy", functor::xdivy, float, Eigen::half, double, + complex64, complex128); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_op_xlogy.cc b/tensorflow/core/kernels/cwise_op_xlogy.cc new file mode 100644 index 0000000000..e71a9109b2 --- /dev/null +++ b/tensorflow/core/kernels/cwise_op_xlogy.cc @@ -0,0 +1,41 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/cwise_ops_common.h" + +namespace tensorflow { +REGISTER5(BinaryOp, CPU, "Xlogy", functor::xlogy, float, Eigen::half, double, + complex64, complex128); + +#if TENSORFLOW_USE_SYCL +#define REGISTER_SYCL_KERNEL(TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("Xlogy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \ + BinaryOp<SYCLDevice, functor::xlogy<TYPE>>); +REGISTER_SYCL_KERNEL(Eigen::half); +REGISTER_SYCL_KERNEL(float); +REGISTER_SYCL_KERNEL(double); +REGISTER_SYCL_KERNEL(complex64); +REGISTER_SYCL_KERNEL(complex128); +#undef REGISTER_SYCL_KERNEL + +#endif // TENSORFLOW_USE_SYCL + +#if GOOGLE_CUDA +REGISTER5(BinaryOp, GPU, "Xlogy", functor::xlogy, float, Eigen::half, double, + complex64, complex128); +#endif // GOOGLE_CUDA + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h index 22eb66e979..66ba827a90 100644 --- a/tensorflow/core/kernels/cwise_ops.h +++ b/tensorflow/core/kernels/cwise_ops.h @@ -471,6 +471,45 @@ struct functor_traits<bitwise_xor_op<Scalar>> { enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true }; }; +// TODO(srvasude): Add packet versions of this operation. +template <typename Scalar> +struct xlogy_op { + EIGEN_EMPTY_STRUCT_CTOR(xlogy_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar + operator()(const Scalar& x, const Scalar& y) const { + if (x == Scalar(0.)) { + return Scalar(0.); + } + return x * numext::log(y); + } +}; + +template <typename Scalar> +struct functor_traits<xlogy_op<Scalar>> { + enum { + Cost = (sizeof(Scalar) == 4 ? 40 : 85) + Eigen::NumTraits<Scalar>::MulCost, + PacketAccess = false + }; +}; + +template <typename Scalar> +// TODO(srvasude): Add packet versions of this operation. 
+struct xdivy_op { + EIGEN_EMPTY_STRUCT_CTOR(xdivy_op) + EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar + operator()(const Scalar& x, const Scalar& y) const { + if (x == Scalar(0.)) { + return Scalar(0.); + } + return x / y; + } +}; + +template <typename Scalar> +struct functor_traits<xdivy_op<Scalar>> { + enum { Cost = Eigen::NumTraits<Scalar>::MulCost, PacketAccess = false }; +}; + } // end namespace internal } // end namespace Eigen @@ -830,6 +869,12 @@ struct squared_difference Eigen::internal::scalar_difference_op<T>>> {}; template <typename T> +struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {}; + +template <typename T> +struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {}; + +template <typename T> struct less : base<T, Eigen::internal::less<T>, bool> {}; template <typename T> diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc index 980edffceb..8ad3b4d1fc 100644 --- a/tensorflow/core/kernels/cwise_ops_common.cc +++ b/tensorflow/core/kernels/cwise_ops_common.cc @@ -20,9 +20,9 @@ namespace tensorflow { BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out, DataType in) : OpKernel(ctx) { -#ifndef INTEL_MKL +#if !defined(INTEL_MKL) || !defined(ENABLE_MKL) OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out})); -#endif +#endif // !INTEL_MKL || !ENABLE_MKL } void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) { diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 87efdff789..37c1c54786 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -45,6 +45,16 @@ cc_library( ], ) +tf_cc_test( + name = "dataset_utils_test", + srcs = ["dataset_utils_test.cc"], + deps = [ + ":dataset_utils", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "captured_function", srcs = ["captured_function.cc"], @@ -205,6 +215,7 @@ tf_kernel_library( deps = [ ":captured_function", 
":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -232,6 +243,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -245,6 +257,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", "//tensorflow/core:framework", @@ -285,6 +298,7 @@ tf_kernel_library( deps = [ ":captured_function", ":dataset", + ":dataset_utils", ":parallel_map_iterator", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:dataset_ops_op_lib", @@ -458,6 +472,7 @@ tf_kernel_library( srcs = ["stats_aggregator_dataset_op.cc"], deps = [ ":dataset", + "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:lib_internal", ], @@ -765,6 +780,7 @@ tf_kernel_library( ":window_dataset_op", ":writer_ops", ":zip_dataset_op", + "//tensorflow/core/kernels/data/experimental:dataset_kernels", ], ) diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc index a04f150e71..9607e9444c 100644 --- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc +++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc @@ -171,16 +171,16 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel { static PartialTensorShape MostSpecificCompatibleShape( const PartialTensorShape& ts1, const PartialTensorShape& ts2) { - PartialTensorShape output_tensorshape; if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) - return output_tensorshape; + return PartialTensorShape(); + PartialTensorShape output_tensorshape({}); auto dims1 = ts1.dim_sizes(); auto dims2 = ts2.dim_sizes(); for (int d = 0; d < ts1.dims(); d++) { if (dims1[d] == dims2[d]) - 
output_tensorshape.Concatenate(dims1[d]); + output_tensorshape.AddDim(dims1[d]); else - output_tensorshape.Concatenate(-1); + output_tensorshape.AddDim(-1); } return output_tensorshape; } diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc index e10833f525..a40f7f2146 100644 --- a/tensorflow/core/kernels/data/dataset_utils.cc +++ b/tensorflow/core/kernels/data/dataset_utils.cc @@ -15,10 +15,57 @@ limitations under the License. #include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/gtl/cleanup.h" namespace tensorflow { namespace data { +Status ComputeShortCircuitIndices(OpKernelContext* ctx, + const NameAttrList& func, + std::vector<int>* indices) { + FunctionLibraryRuntime::Handle fn_handle; + TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate( + func.name(), AttrSlice(&func.attr()), &fn_handle)); + auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() { + Status s = ctx->function_library()->ReleaseHandle(fn_handle); + if (!s.ok()) { + LOG(WARNING) << "Failed to release handle: " << s.error_message(); + } + }); + + const FunctionBody* fn_body = + ctx->function_library()->GetFunctionBody(fn_handle); + indices->resize(fn_body->ret_nodes.size()); + for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) { + Node* ret_node = fn_body->ret_nodes[i]; + Node* ret_input_node; + TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node)); + if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) { + TF_RETURN_IF_ERROR( + GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i]))); + } else { + indices->clear(); + break; + } + } + return Status::OK(); +} + +std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) { + std::map<int, int> last_use; + for (size_t i = 0; i < indices.size(); ++i) { + 
last_use[indices[i]] = i; + } + std::vector<bool> can_move; + can_move.resize(indices.size()); + for (size_t i = 0; i < indices.size(); ++i) { + can_move[i] = last_use[indices[i]] == i; + } + return can_move; +} + Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector<Tensor>& input_element, int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h index 6ec1350cd4..d777062293 100644 --- a/tensorflow/core/kernels/data/dataset_utils.h +++ b/tensorflow/core/kernels/data/dataset_utils.h @@ -22,6 +22,26 @@ limitations under the License. namespace tensorflow { namespace data { +// This method is used to determine whether we can short-circuit the evaluation +// of the user-defined function `func`. Short-circuting is possible if every +// function output corresponds to one of its inputs (e.g. `f(x) = x`, `f(x,y) = +// (y,x)`, or `f(x) = (x,x)`). +// +// If short-circuiting is possible, the method stores the mapping from output +// indices to input indices in `indices`. Otherwise, `indices` will be empty. +// +// Returns non-ok status if analysis of the function fails. +// +// TODO(jsimsa): Extend this to support constants as well. +Status ComputeShortCircuitIndices(OpKernelContext* ctx, + const NameAttrList& func, + std::vector<int>* indices); + +// Given a vector that maps output indices to input indices, return a vector +// that identifies for which output indices can we move the input (assuming +// output indices are processed left to right). 
+std::vector<bool> ComputeMoveVector(const std::vector<int>& indices); + Status MakeIteratorFromInputElement( IteratorContext* ctx, const std::vector<Tensor>& input_element, int64 thread_index, CapturedFunction* captured_func, StringPiece prefix, diff --git a/tensorflow/core/kernels/data/dataset_utils_test.cc b/tensorflow/core/kernels/data/dataset_utils_test.cc new file mode 100644 index 0000000000..43295b8ebb --- /dev/null +++ b/tensorflow/core/kernels/data/dataset_utils_test.cc @@ -0,0 +1,46 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/data/dataset_utils.h" + +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace data { +namespace { + +TEST(DatasetUtils, ComputeMoveVector) { + struct TestCase { + std::vector<int> indices; + std::vector<bool> expected; + }; + + TestCase test_cases[] = { + TestCase{{}, {}}, + TestCase{{1}, {true}}, + TestCase{{1, 1}, {false, true}}, + TestCase{{1, 2}, {true, true}}, + TestCase{{1, 1, 2}, {false, true, true}}, + TestCase{{1, 2, 2}, {true, false, true}}, + }; + + for (auto& test_case : test_cases) { + EXPECT_EQ(test_case.expected, ComputeMoveVector(test_case.indices)); + } +} + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD new file mode 100644 index 0000000000..43406db3ed --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/BUILD @@ -0,0 +1,139 @@ +# Description: +# Contains experimental kernels for datasets and iterators. 
+package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load( + "//tensorflow:tensorflow.bzl", + "tf_kernel_library", +) + +cc_library( + name = "indexed_dataset_headers", + hdrs = ["indexed_dataset.h"], + deps = [ + "//tensorflow/core:framework", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "indexed_dataset", + srcs = [ + "identity_indexed_dataset.cc", + "indexed_dataset.cc", + ], + deps = [ + ":indexed_dataset_headers", + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "prefetching_kernels", + srcs = ["prefetching_kernels.cc"], + deps = [ + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_kernel_library( + name = "directed_interleave_dataset_op", + srcs = ["directed_interleave_dataset_op.cc"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "csv_dataset_op", + srcs = ["csv_dataset_op.cc"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_kernel_library( + name = "ignore_errors_dataset_op", + srcs = ["ignore_errors_dataset_op.cc"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "lmdb_dataset_op", + srcs = ["lmdb_dataset_op.cc"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//third_party/eigen3", + "@lmdb", + ], +) + 
+tf_kernel_library( + name = "threadpool_dataset_op", + srcs = ["threadpool_dataset_op.cc"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "unique_dataset_op", + srcs = ["unique_dataset_op.cc"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "assert_next_dataset_op", + srcs = ["assert_next_dataset_op.cc"], + deps = [ + "//tensorflow/core:experimental_dataset_ops_op_lib", + "//tensorflow/core:framework", + "//third_party/eigen3", + ], +) + +tf_kernel_library( + name = "dataset_kernels", + deps = [ + ":assert_next_dataset_op", + ":csv_dataset_op", + ":directed_interleave_dataset_op", + ":ignore_errors_dataset_op", + ":indexed_dataset", + ":lmdb_dataset_op", + ":prefetching_kernels", + ":threadpool_dataset_op", + ":unique_dataset_op", + ], +) diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc new file mode 100644 index 0000000000..3511cca0f5 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc @@ -0,0 +1,156 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include <map> + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace data { +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. +class AssertNextDatasetOp : public UnaryDatasetOpKernel { + public: + explicit AssertNextDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + protected: + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + std::vector<string> transformations; + OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "transformations", + &transformations)); + *output = + new Dataset(ctx, input, transformations, output_types_, output_shapes_); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const std::vector<string>& transformations, + const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes) + : DatasetBase(DatasetContext(ctx)), + input_(input), + transformations_(transformations), + output_types_(output_types), + output_shapes_(output_shapes) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Assert")})); + } + + const DataTypeVector& output_dtypes() const override { + return output_types_; + } + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return 
"AssertNextDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* transformations_node = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {input_graph_node, transformations_node}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status Initialize(IteratorContext* ctx) override { + std::vector<string> tokens = + str_util::Split(prefix(), ':', str_util::SkipEmpty()); + if (dataset()->transformations_.size() > tokens.size() - 2) { + return errors::InvalidArgument( + "Asserted next ", dataset()->transformations_.size(), + " transformations but encountered only ", tokens.size() - 2, "."); + } + int n = tokens.size(); + for (size_t i = 0; i < dataset()->transformations_.size(); ++i) { + if (dataset()->transformations_[i] != tokens[n - 2 - i]) { + return errors::InvalidArgument( + "Asserted ", dataset()->transformations_[i], + " transformation at offset ", i, " but encountered ", + tokens[n - 2 - i], " transformation instead."); + } + } + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + return input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + return Status::OK(); + } 
+ + private: + std::unique_ptr<IteratorBase> input_impl_; + }; + + const DatasetBase* input_; + const std::vector<string> transformations_; + const DataTypeVector output_types_; + const std::vector<PartialTensorShape> output_shapes_; + }; + + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; +}; + +REGISTER_KERNEL_BUILDER( + Name("ExperimentalAssertNextDataset").Device(DEVICE_CPU), + AssertNextDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc new file mode 100644 index 0000000000..7451ca4cb1 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc @@ -0,0 +1,860 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/parsing_ops.cc. 
+#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/io/inputstream_interface.h" +#include "tensorflow/core/lib/io/random_inputstream.h" +#include "tensorflow/core/lib/io/zlib_compression_options.h" +#include "tensorflow/core/lib/io/zlib_inputstream.h" + +namespace tensorflow { +namespace data { +namespace { + +class CSVDatasetOp : public DatasetOpKernel { + public: + explicit CSVDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + string compression_type; + OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type", + &compression_type)); + + OpInputList record_defaults_list; + OP_REQUIRES_OK(ctx, + ctx->input_list("record_defaults", &record_defaults_list)); + for (int i = 0; i < record_defaults_list.size(); ++i) { + OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1, + errors::InvalidArgument( + "Each record default should be at most rank 1")); + OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2, + errors::InvalidArgument( + "There should only be 1 default per field but field ", i, + " has ", record_defaults_list[i].NumElements())); + } + + const Tensor* select_cols_tensor; + OP_REQUIRES_OK(ctx, ctx->input("select_cols", &select_cols_tensor)); + OP_REQUIRES(ctx, select_cols_tensor->dims() == 1, + errors::InvalidArgument("`select_cols` must be a vector.")); + + int64 buffer_size; + OP_REQUIRES_OK( + ctx, 
ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size)); + OP_REQUIRES(ctx, buffer_size > 0, + errors::InvalidArgument("buffer_size should be positive")); + + string delim; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<string>(ctx, "field_delim", &delim)); + OP_REQUIRES(ctx, delim.size() == 1, + errors::InvalidArgument("field_delim should be only 1 char")); + + bool header; + OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "header", &header)); + + bool use_quote_delim; + OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "use_quote_delim", + &use_quote_delim)); + string na_value; + OP_REQUIRES_OK(ctx, + ParseScalarArgument<string>(ctx, "na_value", &na_value)); + + std::vector<Tensor> record_defaults; + record_defaults.reserve(record_defaults_list.size()); + for (const Tensor& t : record_defaults_list) { + record_defaults.push_back(t); + } + + std::vector<string> filenames; + filenames.reserve(filenames_tensor->NumElements()); + for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + filenames.push_back(filenames_tensor->flat<string>()(i)); + } + + io::ZlibCompressionOptions zlib_compression_options = + io::ZlibCompressionOptions::DEFAULT(); + if (compression_type == "ZLIB") { + zlib_compression_options = io::ZlibCompressionOptions::DEFAULT(); + } else if (compression_type == "GZIP") { + zlib_compression_options = io::ZlibCompressionOptions::GZIP(); + } else { + OP_REQUIRES(ctx, compression_type.empty(), + errors::InvalidArgument( + "Unsupported compression_type: ", compression_type, ".")); + } + zlib_compression_options.input_buffer_size = buffer_size; + + std::vector<int64> select_cols; + select_cols.reserve(select_cols_tensor->NumElements()); + for (int i = 0; i < select_cols_tensor->NumElements(); ++i) { + select_cols.push_back(select_cols_tensor->flat<int64>()(i)); + } + OP_REQUIRES( + ctx, output_types_.size() == select_cols.size() || select_cols.empty(), + errors::InvalidArgument("select_cols should match output size")); + for (int i = 1; i < 
select_cols.size(); i++) { + OP_REQUIRES(ctx, select_cols[i - 1] < select_cols[i], + errors::InvalidArgument( + "select_cols should be strictly increasing indices")); + } + OP_REQUIRES( + ctx, select_cols.empty() || select_cols.front() >= 0, + errors::InvalidArgument("select_cols should be non-negative indices")); + + *output = new Dataset(ctx, std::move(filenames), header, + std::move(compression_type), zlib_compression_options, + output_types_, output_shapes_, + std::move(record_defaults), std::move(select_cols), + use_quote_delim, delim[0], std::move(na_value)); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, std::vector<string> filenames, bool header, + string compression_type, io::ZlibCompressionOptions options, + const DataTypeVector& output_types, + const std::vector<PartialTensorShape>& output_shapes, + std::vector<Tensor> record_defaults, std::vector<int64> select_cols, + bool use_quote_delim, char delim, string na_value) + : DatasetBase(DatasetContext(ctx)), + filenames_(std::move(filenames)), + header_(header), + out_type_(output_types), + output_shapes_(output_shapes), + record_defaults_(std::move(record_defaults)), + select_cols_(std::move(select_cols)), + use_quote_delim_(use_quote_delim), + delim_(delim), + na_value_(std::move(na_value)), + use_compression_(!compression_type.empty()), + compression_type_(std::move(compression_type)), + options_(options) {} + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::CSV")})); + } + + const DataTypeVector& output_dtypes() const override { return out_type_; } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { return "CSVDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* 
b, + Node** output) const override { + Node* filenames = nullptr; + Node* compression_type = nullptr; + Node* buffer_size = nullptr; + Node* header = nullptr; + Node* delim = nullptr; + Node* use_quote_delim = nullptr; + Node* na_value = nullptr; + Node* select_cols = nullptr; + + std::vector<Node*> record_defaults; + record_defaults.reserve(record_defaults_.size()); + for (const Tensor& t : record_defaults_) { + Node* node; + TF_RETURN_IF_ERROR(b->AddTensor(t, &node)); + record_defaults.emplace_back(node); + } + + TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); + TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type)); + TF_RETURN_IF_ERROR( + b->AddScalar(options_.input_buffer_size, &buffer_size)); + TF_RETURN_IF_ERROR(b->AddScalar(header_, &header)); + + string delim_string(1, delim_); + TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim)); + TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim)); + TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value)); + TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols)); + + TF_RETURN_IF_ERROR(b->AddDataset( + this, + {std::make_pair(0, filenames), std::make_pair(1, compression_type), + std::make_pair(2, buffer_size), std::make_pair(3, header), + std::make_pair(4, delim), std::make_pair(5, use_quote_delim), + std::make_pair(6, na_value), + std::make_pair(7, select_cols)}, // Single tensor inputs + {std::make_pair(8, record_defaults)}, // Tensor list inputs + {}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + bool select_all = dataset()->select_cols_.empty(); + do { + // We are currently processing a file, so try to read the next record + if (input_stream_) { + Status s = ReadRecord(ctx, 
out_tensors, select_all, + dataset()->select_cols_); + if (s.ok()) { + // Validate output + if (out_tensors->size() != dataset()->out_type_.size()) { + return errors::InvalidArgument( + "Expect ", dataset()->out_type_.size(), " fields but have ", + out_tensors->size(), " in record"); + } + + *end_of_sequence = false; + return s; + } + if (!errors::IsOutOfRange(s)) { + // Not at the end of file, return OK or non-EOF errors to caller. + *end_of_sequence = false; + return s; + } + // We have reached the end of the current file, so maybe + // move on to next file. + ResetStreamsLocked(); + ++current_file_index_; + } + // Iteration ends when there are no more files to process. + if (current_file_index_ == dataset()->filenames_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"), + current_file_index_)); + // `input_stream_` is empty if + // 1. GetNext has not been called even once. + // 2. All files have been read and the iterator has been exhausted. + if (input_stream_ && num_buffer_reads_ > 0) { + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_)); + // If num_buffer_reads_ == 0, the buffer hasn't been filled even once. + TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"), + num_buffer_reads_)); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + ResetStreamsLocked(); + int64 current_file_index; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"), + &current_file_index)); + current_file_index_ = size_t(current_file_index); + // The keys "pos" and "num_buffer_reads" are written only if + // the iterator was saved with an open, partially read file. 
+ if (reader->Contains(full_name("pos"))) { + int64 pos, num_buffer_reads; + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos)); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"), + &num_buffer_reads)); + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + + num_buffer_reads_ = size_t(num_buffer_reads - 1); + + // Restores the most recently held buffer + Status s = input_stream_->SkipNBytes( + num_buffer_reads_ * dataset()->options_.input_buffer_size); + if (!s.ok() && !errors::IsOutOfRange(s)) { + // We might get out of range error here if the size of the file + // is not an exact multiple of the buffer size, and the last buffer + // read is < buffer_size. This is valid and we do not surface the + // error. + return s; + } + + Status s2 = FillBuffer(&buffer_); + if (!s2.ok() && !errors::IsOutOfRange(s2)) { + return s2; + } + pos_ = size_t(pos); + } + return Status::OK(); + } + + private: + // Reads an entire CSV row from the input stream, either from the + // existing buffer or by filling the buffer as needed. Converts extracted + // fields to output tensors as we go. + // + // When this function is called, pos_ should be the index of the first + // character of the record in buffer_, or past the end of the buffer. + // Note: ctx and out_tensors are only used in this function + // when fields are included in the record. + Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors, + bool select_all, const std::vector<int64>& selected) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + // At the end of the file, this will return errors::OutOfRange + TF_RETURN_IF_ERROR(FillBuffer(&buffer_)); + pos_ = 0; + } + + // The first character may be \n if this is the continuation of a + // \r\n linebreak between this and the previous record. If so, skip it. 
+ + bool end_of_record = false; // Keep track of when we find \n, \r or EOF + size_t num_parsed = 0; + size_t num_selected_parsed = 0; + + Status result; + + while (!end_of_record) { // Read till we reach \n, \r or EOF + bool include = + select_all || (num_selected_parsed < selected.size() && + selected[num_selected_parsed] == num_parsed); + + // Don't fail fast, so that the next call to GetNext may still return + // a valid record + result.Update( + ParseOneField(ctx, out_tensors, &end_of_record, include)); + + num_parsed++; + if (include) num_selected_parsed++; + } + + return result; + } + + // Parses one field from position pos_ in the buffer. Fields are + // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of + // the next field. + Status ParseOneField(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + // If we get here, this means the previous field's end coincided + // with the end of the buffer. We can fill the buffer without abandon. 
+ Status s = FillBuffer(&buffer_); + + if (errors::IsOutOfRange(s)) { + // Reached EOF, and last field is empty + *end_of_record = true; + if (include) { + return FieldToOutput(ctx, StringPiece(), out_tensors); + } else { + return Status::OK(); + } + } else if (!s.ok()) { + return s; // Surface other errors back to caller + } + + pos_ = 0; + } + + if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') { + return ParseQuotedField(ctx, out_tensors, end_of_record, include); + } + + return ParseUnquotedField(ctx, out_tensors, end_of_record, include); + } + + // For keeping track of relevant parts of a field from a previous buffer + struct Piece { + size_t start; + size_t len; + string buffer; + + Piece(string buffer, size_t start, size_t len) + : start(start), len(len), buffer(std::move(buffer)) {} + }; + + // Given that pos_ exceeds the buffer, saves the relevant part of the + // current buffer (if necessary), fills the buffer, and resets indices to + // 0. + Status SaveAndFillBuffer(std::vector<Piece>* earlier_pieces, + size_t* start, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + string temp_buffer; + + buffer_.swap(temp_buffer); + if (include && pos_ > *start) { + earlier_pieces->push_back( + Piece(std::move(temp_buffer), *start, pos_ - *start)); + } + pos_ = 0; + *start = 0; + return FillBuffer(&buffer_); + } + + // Parses quoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. 
+ Status ParseQuotedField(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector<Piece> earlier_pieces; + size_t start = pos_; + pos_++; // Starting quotation mark + + Status parse_result; + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + return errors::InvalidArgument( + "Reached end of file without closing quoted field in " + "record"); + } else if (!s.ok()) { + return s; // Surface all other errors to caller + } + } + + char ch = buffer_[pos_]; + if (ch == '"') { + // When we encounter a quote, we look ahead to the next character to + // decide what to do + pos_++; + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + if (errors::IsOutOfRange(s)) { + // This was the last field. We are done + *end_of_record = true; + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(), out_tensors, earlier_pieces, include)); + return parse_result; + } else if (!s.ok()) { + return s; + } + } + + char next = buffer_[pos_]; + pos_++; + if (next == dataset()->delim_) { + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include)); + return parse_result; + + } else if (next == '\n' || next == '\r') { + *end_of_record = true; + parse_result.Update(QuotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - 1 - start), + out_tensors, earlier_pieces, include)); + if (next == '\r') SkipNewLineIfNecessary(); + return parse_result; + } else if (next != '"') { + // Take note of the error, but keep going to end of field. + include = false; // So we don't get funky errors when trying to + // unescape the quotes. 
+ parse_result.Update(errors::InvalidArgument( + "Quote inside a string has to be escaped by another quote")); + } + + } else { + pos_++; + } + } + } + + // Converts quoted field to an output tensor, removing the starting + // and ending quotes from it and unescaping double quotations if + // necessary. + Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector<Tensor>* out_tensors, + const std::vector<Piece>& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + if (field.find('\"', 1) == field.size() - 1) { + // `field` contains no escaped quotation marks. + // Exclude framing quotation marks + field.remove_prefix(1); + field.remove_suffix(1); + return FieldToOutput(ctx, field, out_tensors); + } + } + string field_complete; + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + field_complete.reserve(str_len); + + // This bool flips every time we see a quote, so that we skip the second + // quote of every pair of adjacent quotes in the field. We need to track + // this across iterations of the for loop because adjacent double quotes + // may be in different buffers. Initialize to true because we also skip + // the opening quotation mark of the quoted field. 
+ bool skip_next_quote = true; + for (const Piece& p : earlier_pieces) { + AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len), + &field_complete, &skip_next_quote); + } + AppendUnescapedPiece(field, &field_complete, &skip_next_quote); + StringPiece result = StringPiece(field_complete); + result.remove_suffix(1); // Skip final quote + + return FieldToOutput(ctx, result, out_tensors); + } + + void AppendUnescapedPiece(StringPiece piece, string* field_complete, + bool* skip_next_quote) { + size_t from = 0; + size_t found = piece.find('\"', from); + while (found != string::npos) { + if (!*skip_next_quote) { + // This is the first quote in a pair of adjacent double quotes + field_complete->append(piece.data() + from, found + 1 - from); + } + *skip_next_quote = !*skip_next_quote; + from = found + 1; + found = piece.find('\"', from); + } + // Include the chunk after the last quotation mark in the string + if (from < piece.size()) { + field_complete->append(piece.data() + from, piece.size() - from); + } + } + + // Parses unquoted field from position pos_ in the buffer. Continually + // reads from buffer until end of field is reached (delim, CRLF, or EOF). + // Advances pos_ to keep track of our position in the buffer as we go, + // stopping at the first character of the next field. 
+ Status ParseUnquotedField(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_record, bool include) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + std::vector<Piece> earlier_pieces; + size_t start = pos_; + Status parse_result; + + while (true) { // Each iter reads 1 char, filling buffer if necessary + if (pos_ >= buffer_.size()) { + Status s = SaveAndFillBuffer(&earlier_pieces, &start, include); + // Handle errors + if (errors::IsOutOfRange(s)) { + // Whatever we have is the last field of the last record + *end_of_record = true; + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + return parse_result; + } else if (!s.ok()) { + return s; // Surface all other errors to caller + } + } + + char ch = buffer_[pos_]; + + if (ch == dataset()->delim_) { + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + pos_++; + return parse_result; + } + if (ch == '\n' || ch == '\r') { + // need special case to skip over first \n of record if the line + // breaks are \r\n + parse_result.Update(UnquotedFieldToOutput( + ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors, + earlier_pieces, include)); + *end_of_record = true; + pos_++; + if (ch == '\r') SkipNewLineIfNecessary(); + return parse_result; + } + if (dataset()->use_quote_delim_ && ch == '"') { + // Take note of the error, but keep going to end of field. + parse_result.Update(errors::InvalidArgument( + "Unquoted fields cannot have quotes inside")); + } + // Otherwise, go to next character + pos_++; + } + } + + Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + result->clear(); + ++num_buffer_reads_; + Status s = input_stream_->ReadNBytes( + dataset()->options_.input_buffer_size, result); + + if (errors::IsOutOfRange(s) && !result->empty()) { + // Ignore OutOfRange error when ReadNBytes read < N bytes. 
+ return Status::OK(); + } + return s; + } + + // Given a field, converts it to the right output tensor type + Status FieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector<Tensor>* out_tensors) { + size_t output_idx = out_tensors->size(); + if (output_idx >= dataset()->out_type_.size()) { + // We can get here if we're selecting all columns, but the number of + // fields exceeds the number of defaults provided + return errors::InvalidArgument("Expect ", dataset()->out_type_.size(), + " fields but have more in record"); + } + const DataType& dtype = dataset()->out_type_[output_idx]; + Tensor component(ctx->allocator({}), dtype, {}); + if ((field.empty() || field == dataset()->na_value_) && + dataset()->record_defaults_[output_idx].NumElements() != 1) { + // If the field is empty or NA value, and default is not given, + // report error. + return errors::InvalidArgument("Field ", output_idx, + " is required but missing in record!"); + } + + switch (dtype) { + // For each case, if the field is empty, we use the default. + // Otherwise, we convert it to the right type. 
+ case DT_INT32: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar<int32>()() = + dataset()->record_defaults_[output_idx].flat<int32>()(0); + } else { + int32 value; + if (!strings::safe_strto32(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid int32: ", field); + } + component.scalar<int32>()() = value; + } + break; + } + case DT_INT64: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar<int64>()() = + dataset()->record_defaults_[output_idx].flat<int64>()(0); + } else { + int64 value; + if (!strings::safe_strto64(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid int64: ", field); + } + component.scalar<int64>()() = value; + } + break; + } + case DT_FLOAT: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar<float>()() = + dataset()->record_defaults_[output_idx].flat<float>()(0); + } else { + float value; + if (!strings::safe_strtof(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid float: ", field); + } + component.scalar<float>()() = value; + } + break; + } + case DT_DOUBLE: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar<double>()() = + dataset()->record_defaults_[output_idx].flat<double>()(0); + } else { + double value; + if (!strings::safe_strtod(field, &value)) { + return errors::InvalidArgument( + "Field ", output_idx, + " in record is not a valid double: ", field); + } + component.scalar<double>()() = value; + } + break; + } + case DT_STRING: { + if (field.empty() || field == dataset()->na_value_) { + component.scalar<string>()() = + dataset()->record_defaults_[output_idx].flat<string>()(0); + } else { + component.scalar<string>()() = string(field); + } + break; + } + default: + return errors::InvalidArgument("csv: data type ", dtype, + " not supported in field ", + output_idx); + } + 
out_tensors->push_back(std::move(component)); + return Status::OK(); + } + + // Records can be delimited by "\r\n" line breaks. When we encounter a + // '\r', we have to check the next character to see if it is part of the + // linebreak, and ignore it if so. + void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (pos_ >= buffer_.size()) { + Status s = FillBuffer(&buffer_); + pos_ = 0; + // If we failed to fill buffer, it doesn't matter because we're done + // with the record + if (!s.ok()) return; + } + if (buffer_[pos_] == '\n') { + pos_++; + } + } + + // Given a string field, and its index in the output, + // converts it to a Tensor of the right type and adds it to the + // out_tensors vector. + Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field, + std::vector<Tensor>* out_tensors, + const std::vector<Piece>& earlier_pieces, + bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (!include) return Status::OK(); + + if (earlier_pieces.empty()) { + return FieldToOutput(ctx, field, out_tensors); + } + + size_t str_len = field.size(); + for (const Piece& p : earlier_pieces) { + str_len += p.len; + } + string field_complete; + field_complete.reserve(str_len); + + for (const Piece& p : earlier_pieces) { + field_complete.append(p.buffer, p.start, p.len); + } + + field_complete.append(field.data(), field.size()); + return FieldToOutput(ctx, field_complete, out_tensors); + } + + // Sets up reader streams to read from the file at `current_file_index_`. + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_file_index_ >= dataset()->filenames_.size()) { + return errors::InvalidArgument( + "current_file_index_:", current_file_index_, + " >= filenames_.size():", dataset()->filenames_.size()); + } + + // Actually move on to next file. 
+ TF_RETURN_IF_ERROR(env->NewRandomAccessFile( + dataset()->filenames_[current_file_index_], &file_)); + random_access_input_stream_ = + std::make_shared<io::RandomAccessInputStream>(file_.get(), false); + + if (dataset()->use_compression_) { + input_stream_ = std::make_shared<io::ZlibInputStream>( + random_access_input_stream_.get(), + dataset()->options_.input_buffer_size, + dataset()->options_.input_buffer_size, dataset()->options_); + } else { + input_stream_ = random_access_input_stream_; + } + buffer_.clear(); + pos_ = 0; + num_buffer_reads_ = 0; + if (dataset()->header_) { + // Read one line, but don't include it. Pass nullptrs as dummy + // pointers to objects that shouldn't be invoked anyway + // We need to process this as a record here instead of just finding + // the first newline because it might contain quoted fields with + // newlines in the header as well + std::vector<int64> empty; + Status s = ReadRecord(nullptr, nullptr, false, empty); + if (!s.ok()) { + return errors::InvalidArgument("Can't read header of file"); + } + } + return Status::OK(); + } + + // Resets all reader streams. 
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + input_stream_.reset(); + file_.reset(); + } + + mutex mu_; + string buffer_ GUARDED_BY(mu_); // Maintain our own buffer + size_t pos_ GUARDED_BY( + mu_); // Index into the buffer must be maintained between iters + size_t num_buffer_reads_ GUARDED_BY(mu_); + std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_ + GUARDED_BY(mu_); + std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_); + size_t current_file_index_ GUARDED_BY(mu_) = 0; + std::unique_ptr<RandomAccessFile> file_ + GUARDED_BY(mu_); // must outlive input_stream_ + }; // class Iterator + + const std::vector<string> filenames_; + const bool header_; + const DataTypeVector out_type_; + const std::vector<PartialTensorShape> output_shapes_; + const std::vector<Tensor> record_defaults_; + const std::vector<int64> select_cols_; + const bool use_quote_delim_; + const char delim_; + const string na_value_; + const bool use_compression_; + const string compression_type_; + const io::ZlibCompressionOptions options_; + }; // class Dataset + + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; +}; // class CSVDatasetOp + +// Register the kernel implementation for CSVDataset. +REGISTER_KERNEL_BUILDER(Name("ExperimentalCSVDataset").Device(DEVICE_CPU), + CSVDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc new file mode 100644 index 0000000000..c47a9099c4 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc @@ -0,0 +1,281 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/hash/hash.h" + +namespace tensorflow { +namespace data { +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. + +class DirectedInterleaveDatasetOp : public DatasetOpKernel { + public: + explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + DatasetBase* selector_input; + OP_REQUIRES_OK(ctx, + GetDatasetFromVariantTensor(ctx->input(0), &selector_input)); + + OP_REQUIRES( + ctx, + selector_input->output_dtypes().size() == 1 && + selector_input->output_dtypes()[0] == DT_INT64 && + selector_input->output_shapes().size() == 1 && + selector_input->output_shapes()[0].IsCompatibleWith( + PartialTensorShape({})), + errors::InvalidArgument( + "The selector input must be a dataset of scalar int64 elements.")); + + std::vector<DatasetBase*> data_inputs; + for (size_t i = 1; i < ctx->num_inputs(); ++i) { + DatasetBase* input; + OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input)); + data_inputs.push_back(input); + + OP_REQUIRES( + ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(), + errors::InvalidArgument( + "All inputs must have the same output_dtypes. 
First input " + "has types ", + DataTypeVectorString(data_inputs[0]->output_dtypes()), + ", and input ", i - 1, " has types ", + DataTypeVectorString(input->output_dtypes()))); + } + *output = new Dataset(ctx, selector_input, std::move(data_inputs)); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* selector_input, + std::vector<DatasetBase*> data_inputs) + : DatasetBase(DatasetContext(ctx)), + selector_input_(selector_input), + data_inputs_(std::move(data_inputs)) { + selector_input_->Ref(); + + output_shapes_ = data_inputs_[0]->output_shapes(); + data_inputs_[0]->Ref(); + for (size_t i = 1; i < data_inputs_.size(); ++i) { + const DatasetBase* data_input = data_inputs_[i]; + data_input->Ref(); + for (size_t j = 0; j < output_shapes_.size(); ++j) { + output_shapes_[j] = MostSpecificCompatibleShape( + output_shapes_[j], data_input->output_shapes()[j]); + } + } + } + + ~Dataset() override { + selector_input_->Unref(); + for (DatasetBase* data_input : data_inputs_) { + data_input->Unref(); + } + } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>(new Iterator( + {this, strings::StrCat(prefix, "::DirectedInterleave")})); + } + + const DataTypeVector& output_dtypes() const override { + return data_inputs_[0]->output_dtypes(); + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return strings::StrCat("DirectedInterleaveDatasetOp::Dataset"); + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* selector_input_node; + TF_RETURN_IF_ERROR( + b->AddInputDataset(ctx, selector_input_, &selector_input_node)); + std::vector<Node*> data_input_nodes(data_inputs_.size()); + for (size_t i = 0; i < data_inputs_.size(); ++i) { + 
 TF_RETURN_IF_ERROR( + b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i])); + } + TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}}, + {{1, data_input_nodes}}, {}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params), + num_active_inputs_(params.dataset->data_inputs_.size()) {} + + Status Initialize(IteratorContext* ctx) override { + mutex_lock l(mu_); + TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator( + ctx, strings::StrCat(prefix(), ".selector"), + &selector_input_impl_)); + data_input_impls_.resize(dataset()->data_inputs_.size()); + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + const DatasetBase* data_input = dataset()->data_inputs_[i]; + TF_RETURN_IF_ERROR(data_input->MakeIterator( + ctx, strings::StrCat(prefix(), "[", i, "]"), + &data_input_impls_[i])); + } + return Status::OK(); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (!selector_input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } + + while (true) { + std::vector<Tensor> selector_result; + *end_of_sequence = false; + TF_RETURN_IF_ERROR(selector_input_impl_->GetNext( + ctx, &selector_result, end_of_sequence)); + if (*end_of_sequence) { + selector_input_impl_.reset(); + for (auto& data_input_impl : data_input_impls_) { + data_input_impl.reset(); + } + return Status::OK(); + } + + int64 selected_input = selector_result[0].scalar<int64>()(); + if (selected_input < 0 || selected_input >= data_input_impls_.size()) { + return errors::InvalidArgument( + "Selector index out of range: ", selected_input, + " >= ", data_input_impls_.size()); + } + + if (data_input_impls_[selected_input]) { + bool end_of_selected_input = false; + TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext( + ctx, out_tensors, 
&end_of_selected_input)); + + if (!end_of_selected_input) { + return Status::OK(); + } + + data_input_impls_[selected_input].reset(); + --num_active_inputs_; + + if (num_active_inputs_ == 0) { + selector_input_impl_.reset(); + *end_of_sequence = true; + return Status::OK(); + } + } + + LOG(WARNING) << "DirectedInterleave selected an exhausted input: " + << selected_input; + } + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + if (selector_input_impl_) { + TF_RETURN_IF_ERROR(SaveInput(writer, selector_input_impl_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("selector_input_impl_empty"), "")); + } + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + const auto& data_input_impl = data_input_impls_[i]; + if (data_input_impl) { + TF_RETURN_IF_ERROR(SaveInput(writer, data_input_impl)); + } else { + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name(strings::StrCat("data_input_impl_empty[", i, "]")), + "")); + } + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (!reader->Contains(full_name("selector_input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_)); + } else { + selector_input_impl_.reset(); + } + for (size_t i = 0; i < data_input_impls_.size(); ++i) { + if (!reader->Contains(full_name( + strings::StrCat("data_input_impl_empty[", i, "]")))) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i])); + } else { + data_input_impls_[i].reset(); + } + } + return Status::OK(); + } + + private: + mutex mu_; + std::unique_ptr<IteratorBase> selector_input_impl_ GUARDED_BY(mu_); + std::vector<std::unique_ptr<IteratorBase>> data_input_impls_ + GUARDED_BY(mu_); + int64 num_active_inputs_ GUARDED_BY(mu_); + }; + + static PartialTensorShape MostSpecificCompatibleShape( + const PartialTensorShape& ts1, const PartialTensorShape& ts2) { + 
PartialTensorShape output_tensorshape; + if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank()) + return output_tensorshape; + auto dims1 = ts1.dim_sizes(); + auto dims2 = ts2.dim_sizes(); + for (int d = 0; d < ts1.dims(); d++) { + if (dims1[d] == dims2[d]) + output_tensorshape.Concatenate(dims1[d]); + else + output_tensorshape.Concatenate(-1); + } + return output_tensorshape; + } + + const DatasetBase* const selector_input_; + const std::vector<DatasetBase*> data_inputs_; + std::vector<PartialTensorShape> output_shapes_; + }; +}; + +REGISTER_KERNEL_BUILDER( + Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU), + DirectedInterleaveDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc new file mode 100644 index 0000000000..2141f118ca --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc @@ -0,0 +1,156 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { +namespace data { +namespace { + +class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel { + public: + using IndexedDatasetOpKernel::IndexedDatasetOpKernel; + + void MakeIndexedDataset(OpKernelContext* ctx, + IndexedDataset** output) override { + uint64 size = -1; + OP_REQUIRES_OK(ctx, ParseScalarArgument<uint64>(ctx, "size", &size)); + OP_REQUIRES(ctx, size > 0, errors::InvalidArgument("`size` must be > 0")); + *output = new Dataset(ctx, size); + } + + class Dataset : public IndexedDataset { + public: + Dataset(OpKernelContext* ctx, uint64 size) + : IndexedDataset(DatasetContext(ctx)), size_(size) {} + + Status MaterializeDataset( + std::shared_ptr<MaterializedIndexedDataset>* materialized) override { + materialized->reset(new Materialized(this)); + return Status::OK(); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = new DataTypeVector({DT_UINT64}); + return *dtypes; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + static std::vector<PartialTensorShape>* shapes = + new std::vector<PartialTensorShape>({{}}); + return *shapes; + } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>(new Iterator( + {this, strings::StrCat(prefix, "::IdentityIndexedDataset")})); + } + + string DebugString() const override { + return "IdentityIndexedDataset::Dataset"; + } + + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** node) const override { + return errors::Unimplemented( + "identity_indexed_dataset.AsGraphDefInternal"); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : 
 DatasetIterator<Dataset>(params) {} + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + if (cur_ < dataset()->size_) { + Tensor result_tensor(ctx->allocator({}), DT_UINT64, {}); + result_tensor.scalar<uint64>()() = cur_++; + out_tensors->emplace_back(std::move(result_tensor)); + *end_of_sequence = false; + return Status::OK(); + } + *end_of_sequence = true; + return Status::OK(); + } + + private: + mutex mu_; + uint64 cur_ GUARDED_BY(mu_) = 0; + }; + + class Materialized : public MaterializedIndexedDataset { + public: + explicit Materialized(Dataset* dataset) : dataset_(dataset) { + dataset->Ref(); + } + + ~Materialized() override { + // TODO(saeta): Pull this into MaterializedIndexedDataset + dataset_->Unref(); + } + + const DataTypeVector& output_dtypes() const override { + return dataset_->output_dtypes(); + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return dataset_->output_shapes(); + } + + Status Get(IteratorContext&& ctx, uint64 index, + std::vector<Tensor>* out_tensors) const override { + LOG(INFO) << "Materialized(" << dataset_->size_ << ")::Get(" << index + << ")"; + if (index >= dataset_->size_) { + // Note: use InvalidArgument instead of OutOfRange error because many + // things consider OutOfRange to be a "clean termination" error. + return errors::InvalidArgument( + "Index ", index, + " is out of range for this dataset. (Size is: ", dataset_->size_, + ".)"); + } + Tensor result_tensor(ctx.allocator({}), DT_UINT64, {}); + result_tensor.scalar<uint64>()() = index; + out_tensors->emplace_back(std::move(result_tensor)); + return Status::OK(); + } + + Status Size(uint64* size) const override { + *size = dataset_->size_; + return Status::OK(); + } + + private: + const Dataset* const dataset_; // Not owned. 
+ }; + + const uint64 size_; + std::shared_ptr<Materialized> materialized_; + }; +}; + +REGISTER_KERNEL_BUILDER( + Name("ExperimentalIdentityIndexedDataset").Device(DEVICE_CPU), + IdentityIndexedDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc new file mode 100644 index 0000000000..b34377c642 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc @@ -0,0 +1,141 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { +namespace data { +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. 
+ +class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel { + public: + explicit IgnoreErrorsDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + *output = new Dataset(ctx, input); + } + + private: + class Dataset : public DatasetBase { + public: + explicit Dataset(OpKernelContext* ctx, const DatasetBase* input) + : DatasetBase(DatasetContext(ctx)), input_(input) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")})); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + const std::vector<PartialTensorShape>& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { + return "IgnoreErrorsDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + { + tf_shared_lock l(mu_); + if (!input_impl_) { + *end_of_sequence = true; + return Status::OK(); + } + Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + while 
(!s.ok()) { + out_tensors->clear(); + s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence); + } + } + if (*end_of_sequence) { + mutex_lock l(mu_); + input_impl_.reset(); + } + return Status::OK(); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + if (input_impl_) + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + else + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impls_empty"), "")); + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (reader->Contains(full_name("input_impls_empty"))) + input_impl_.reset(); + else + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + return Status::OK(); + } + + private: + mutex mu_; + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + }; + + const DatasetBase* const input_; + }; +}; + +REGISTER_KERNEL_BUILDER( + Name("ExperimentalIgnoreErrorsDataset").Device(DEVICE_CPU), + IgnoreErrorsDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc new file mode 100644 index 0000000000..75ea462f40 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc @@ -0,0 +1,375 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h" + +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/gtl/cleanup.h" + +namespace tensorflow { +namespace data { +namespace { + +Status VerifyTypesMatch(const DataTypeVector& expected, + const DataTypeVector& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " types but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); ++i) { + if (expected[i] != received[i]) { + return errors::InvalidArgument("Data type mismatch at component ", i, + ": expected ", DataTypeString(expected[i]), + " but got ", DataTypeString(received[i]), + "."); + } + } + return Status::OK(); +} + +Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected, + const std::vector<PartialTensorShape>& received) { + if (expected.size() != received.size()) { + return errors::InvalidArgument( + "Number of components does not match: expected ", expected.size(), + " shapes but got ", received.size(), "."); + } + for (size_t i = 0; i < expected.size(); ++i) { + if (!expected[i].IsCompatibleWith(received[i])) { + return errors::InvalidArgument("Incompatible shapes at component ", i, + ": expected ", expected[i].DebugString(), + " but got ", received[i].DebugString(), + "."); + } + } + + return Status::OK(); +} + +class MaterializedDatasetResource : public ResourceBase { + public: + MaterializedDatasetResource( + const DataTypeVector& output_dtypes, + const std::vector<PartialTensorShape>& output_shapes) + : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {} + + string DebugString() override { + return "Materialized IndexedDataset resource"; + } + + Status 
Get(IteratorContext&& ctx, uint64 index, + std::vector<Tensor>* out_tensors) { + std::shared_ptr<MaterializedIndexedDataset> captured(materialized_); + if (captured) { + return captured->Get(std::move(ctx), index, out_tensors); + } else { + return errors::FailedPrecondition( + "Get() failed because the MaterializedIndexedDataset has not been " + "initialized. Ensure that you have run the materialization operation " + "for this MaterializedIndexedDataset before retrieving elements."); + } + } + + // TODO(saeta): Implement Save and Restore + + const DataTypeVector& output_dtypes() const { return output_dtypes_; } + const std::vector<PartialTensorShape>& output_shapes() const { + return output_shapes_; + } + + Status set_materialized_dataset( + const std::shared_ptr<MaterializedIndexedDataset>& dataset) { + if (dataset) { + TF_RETURN_IF_ERROR( + VerifyTypesMatch(output_dtypes_, dataset->output_dtypes())); + TF_RETURN_IF_ERROR( + VerifyShapesCompatible(output_shapes_, dataset->output_shapes())); + } + materialized_ = dataset; + return Status::OK(); + } + + private: + std::shared_ptr<MaterializedIndexedDataset> materialized_; + const DataTypeVector output_dtypes_; + const std::vector<PartialTensorShape> output_shapes_; +}; + +// A wrapper class for storing an `IndexedDataset` instance in a DT_VARIANT +// tensor. Objects of the wrapper class own a reference on an instance of an +// `IndexedTensor` and the wrapper's copy constructor and desctructor take care +// of managing the reference count. +// +// NOTE: This is not a feature-complete implementation of the DT_VARIANT +// specification. In particular, we cannot currently serialize an arbitrary +// `IndexedDataset` object, so the `Encode()` and `Decode()` methods are not +// implemented. +// +// NOTE(saeta): When `IndexedDataset`s get merged into core, we can instead just +// use `tensorflow::DatasetVariantWrapper`. 
+class IndexedDatasetVariantWrapper { + public: + IndexedDatasetVariantWrapper() : dataset_(nullptr) {} + + // Transfers ownership of `dataset` to `*this`. + explicit IndexedDatasetVariantWrapper(IndexedDataset* dataset) + : dataset_(dataset) {} + + IndexedDatasetVariantWrapper(const IndexedDatasetVariantWrapper& other) + : dataset_(other.dataset_) { + if (dataset_) dataset_->Ref(); + } + + ~IndexedDatasetVariantWrapper() { + if (dataset_) dataset_->Unref(); + } + + IndexedDataset* get() const { return dataset_; } + + string TypeName() const { return "tensorflow::IndexedDatasetVariantWrapper"; } + string DebugString() const { + if (dataset_) { + return dataset_->DebugString(); + } else { + return "<Uninitialized IndexedDatasetVariantWrapper>"; + } + } + + void Encode(VariantTensorData* data) const { + LOG(ERROR) << "The Encode() method is not implemented for " + "IndexedDatasetVariantWrapper objects."; + } + + bool Decode(const VariantTensorData& data) { + LOG(ERROR) << "The Decode() method is not implemented for " + "IndexedDatasetVariantWrapper objects."; + return false; + } + + private: + IndexedDataset* const dataset_; // Owns one reference. 
+}; + +} // namespace + +Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor, + IndexedDataset** out_dataset) { + if (!(tensor.dtype() == DT_VARIANT || + TensorShapeUtils::IsScalar(tensor.shape()))) { + return errors::InvalidArgument( + "IndexedDataset tensor must be a scalar of dtype DT_VARIANT."); + } + const Variant& variant = tensor.scalar<Variant>()(); + const IndexedDatasetVariantWrapper* wrapper = + variant.get<IndexedDatasetVariantWrapper>(); + if (wrapper == nullptr) { + return errors::InvalidArgument("Tensor must be an IndexedDataset object."); + } + *out_dataset = wrapper->get(); + if (*out_dataset == nullptr) { + return errors::Internal("Read uninitialized IndexedDataset variant."); + } + return Status::OK(); +} + +Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset, + Tensor* tensor) { + if (!(tensor->dtype() == DT_VARIANT || + TensorShapeUtils::IsScalar(tensor->shape()))) { + return errors::InvalidArgument( + "Dataset tensor must be a scalar of dtype DT_VARIANT."); + } + tensor->scalar<Variant>()() = IndexedDatasetVariantWrapper(dataset); + return Status::OK(); +} + +void IndexedDatasetOpKernel::Compute(OpKernelContext* ctx) { + IndexedDataset* dataset = nullptr; + MakeIndexedDataset(ctx, &dataset); + + if (ctx->status().ok()) { + OP_REQUIRES(ctx, dataset != nullptr, + errors::Internal("MakeIndexedDataset did not correctly " + "construct the IndexedDataset")); + Tensor* output = nullptr; + OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output)); + OP_REQUIRES_OK(ctx, StoreIndexedDatasetInVariantTensor(dataset, output)); + } +} + +namespace { + +class MaterializedHandleOp : public OpKernel { + public: + explicit MaterializedHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + } + + ~MaterializedHandleOp() override { + if (resource_ != nullptr) { + resource_->Unref(); + if 
(cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->template Delete<MaterializedDatasetResource>( + cinfo_.container(), cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. + // Note: cargo-culted from $tf/core/framework/resource_op_kernel.h + } + } + } + } + + void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) { + { + mutex_lock l(mu_); + if (resource_ == nullptr) { + ResourceMgr* mgr = context->resource_manager(); + OP_REQUIRES_OK(context, cinfo_.Init(mgr, def())); + + MaterializedDatasetResource* resource; + OP_REQUIRES_OK(context, + mgr->LookupOrCreate<MaterializedDatasetResource>( + cinfo_.container(), cinfo_.name(), &resource, + [this](MaterializedDatasetResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new MaterializedDatasetResource( + output_dtypes_, output_shapes_); + return Status::OK(); + })); + Status s = VerifyResource(resource); + if (TF_PREDICT_FALSE(!s.ok())) { + resource->Unref(); + context->SetStatus(s); + return; + } + + resource_ = resource; + } + } + OP_REQUIRES_OK(context, MakeResourceHandleToOutput( + context, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex<MaterializedDatasetResource>())); + } + + private: + // During the first Compute(), resource is either created or looked up using + // shared_name. In the latter case, the resource found should be verified if + // it is compatible with this op's configuration. The verification may fail in + // cases such as two graphs asking queues of the same shared name to have + // inconsistent capacities. + Status VerifyResource(MaterializedDatasetResource* resource) { + TF_RETURN_IF_ERROR( + VerifyTypesMatch(output_dtypes_, resource->output_dtypes())); + TF_RETURN_IF_ERROR( + VerifyShapesCompatible(output_shapes_, resource->output_shapes())); + return Status::OK(); + } + + mutex mu_; + ContainerInfo cinfo_; // Written once under mu_ then constant afterwards. 
+ MaterializedDatasetResource* resource_ GUARDED_BY(mu_) = nullptr; + DataTypeVector output_dtypes_; + std::vector<PartialTensorShape> output_shapes_; +}; + +// TODO(saeta): Make async. +class MaterializeDatasetOp : public OpKernel { + public: + explicit MaterializeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + IndexedDataset* dataset; + OP_REQUIRES_OK(ctx, + GetIndexedDatasetFromVariantTensor(ctx->input(0), &dataset)); + + MaterializedDatasetResource* materialized_resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), + &materialized_resource)); + core::ScopedUnref unref(materialized_resource); + std::shared_ptr<MaterializedIndexedDataset> materialized; + OP_REQUIRES_OK(ctx, dataset->MaterializeDataset(&materialized)); + OP_REQUIRES_OK( + ctx, materialized_resource->set_materialized_dataset(materialized)); + } +}; + +// TODO(saeta): Make async +class IndexedDatasetGet : public OpKernel { + public: + explicit IndexedDatasetGet(OpKernelConstruction* ctx) : OpKernel(ctx) {} + + void Compute(OpKernelContext* ctx) override { + MaterializedDatasetResource* materialized_resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), + &materialized_resource)); + auto cleanup = gtl::MakeCleanup([materialized_resource] { + materialized_resource->Unref(); // Note: can't use core::ScopedUnref. + }); + + const Tensor* index_t; + OP_REQUIRES_OK(ctx, ctx->input("index", &index_t)); + // TODO(saeta): Support batch reads (indexes should be non-scalar!) + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(index_t->shape()), + errors::InvalidArgument("index must be a scalar")); + const uint64 index = index_t->scalar<uint64>()(); + + std::vector<Tensor> out_tensors; + Status s = + materialized_resource->Get(IteratorContext(ctx), index, &out_tensors); + + // Note: Unref materialized_resource to avoid destruction races. (Important + // in a [future] async op implementation.) 
+ cleanup.release()(); + + if (!s.ok()) { + ctx->SetStatus(s); + } else { + auto expected_shapes = materialized_resource->output_shapes(); + auto expected_types = materialized_resource->output_dtypes(); + for (size_t i = 0; i < out_tensors.size(); ++i) { + OP_REQUIRES( + ctx, expected_shapes[i].IsCompatibleWith(out_tensors[i].shape()), + errors::Internal( + "Materialized dataset output at index ", i, + " is incompatible with the expected shape. (Expected: ", + expected_shapes[i], ", got: ", out_tensors[i].shape(), ")")); + OP_REQUIRES(ctx, out_tensors[i].dtype() == expected_types[i], + errors::Internal("Materialized dataset output at index ", i, + " was not the expected dtype. (Expected: ", + expected_types[i], + ", got: ", out_tensors[i].dtype(), ")")); + ctx->set_output(i, out_tensors[i]); + } + } + } +}; + +REGISTER_KERNEL_BUILDER( + Name("ExperimentalMaterializedIndexDatasetHandle").Device(DEVICE_CPU), + MaterializedHandleOp); +REGISTER_KERNEL_BUILDER( + Name("ExperimentalIndexedDatasetMaterialize").Device(DEVICE_CPU), + MaterializeDatasetOp); +REGISTER_KERNEL_BUILDER( + Name("ExperimentalIndexedDatasetGet").Device(DEVICE_CPU), + IndexedDatasetGet); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/indexed_dataset.h b/tensorflow/core/kernels/data/experimental/indexed_dataset.h new file mode 100644 index 0000000000..27a8360cbc --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.h @@ -0,0 +1,119 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +namespace data { + +// TODO(saeta): Urgh, this is ugly. +class MaterializedIndexedDataset { + public: + virtual ~MaterializedIndexedDataset() = default; + + // Retrieve the element at a given index. The output tensors are stored in + // out_tensors. + // + // If `index` is greater than `Size()`, tensorflow::errors::OutOfRangeError is + // returned. + // + // Get is thread-safe. + virtual Status Get(IteratorContext&& ctx, uint64 index, + std::vector<Tensor>* out_tensors) const = 0; + + // Size determines the number of elements in this IndexedDataset. + // + // Size is thread-safe. + virtual Status Size(uint64* size) const = 0; + + // Returns a vector of DataType values, representing the respective + // element types of each tuple component in the outputs of this dataset. + virtual const DataTypeVector& output_dtypes() const = 0; + + // Returns a vector of tensor shapes, representing the respective + // (and possibly partially defined) shapes of each tuple component + // in the outputs of this dataset. + virtual const std::vector<PartialTensorShape>& output_shapes() const = 0; +}; + +// IndexedDataset represents a dataset that supports random access in addition +// to iterator-based sequential access. 
+// +// Note: IndexedDatasets are HIGHLY experimental at this time. Expect +// significant (backwards incompatible) changes! +class IndexedDataset : public DatasetBase { + public: + IndexedDataset(DatasetContext&& ctx) : DatasetBase(std::move(ctx)) {} + + // Materialize (if necessary) the dataset, and return a pointer. + // TODO(saeta): Add in `IteratorContext* ctx` when materializing. + virtual Status MaterializeDataset( + std::shared_ptr<MaterializedIndexedDataset>* materialized) = 0; +}; + +// IndexedDatasetOpKernel abstracts away interfacing IndexedDatasets with the +// rest of the TensorFlow runtime. +// +// Most IndexedDataset's will be private members of classes inheriting from this +// class. +class IndexedDatasetOpKernel : public OpKernel { + public: + IndexedDatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {} + void Compute(OpKernelContext* ctx) final; + + protected: + // Subclasses should implement this method. It will be called during Compute + // execution. + virtual void MakeIndexedDataset(OpKernelContext* ctx, + IndexedDataset** output) = 0; + + template <typename T> + Status ParseScalarArgument(OpKernelContext* ctx, + const StringPiece& argument_name, T* output) { + const Tensor* argument_t; + TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t)); + if (!TensorShapeUtils::IsScalar(argument_t->shape())) { + return errors::InvalidArgument(argument_name, " must be a scalar"); + } + *output = argument_t->scalar<T>()(); + return Status::OK(); + } +}; + +// Validates and extracts an `IndexedDataset` object from `tensor`. +// +// `tensor` must have been written by a call to +// `StoreIndexedDatasetInVariantTensor` +// +// The retrieved pointer isa borrowed reference to the dataset, which is owned +// by the tensor. The consumer must either acquire its own reference to the +// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not +// destroyed or mutated while the retrieved pointer is in use. 
+Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor, + IndexedDataset** out_dataset); + +// Stores an `IndexedDataset` object in `tensor.` +// +// The ownership of `dataset` is transferred to `tensor`. +Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset, + Tensor* tensor); + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_ diff --git a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc new file mode 100644 index 0000000000..8a88d32f0c --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc @@ -0,0 +1,218 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include <sys/stat.h> + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" +#include "tensorflow/core/platform/file_system.h" + +#include "lmdb.h" // NOLINT(build/include) + +namespace tensorflow { +namespace data { +namespace { + +class LMDBDatasetOp : public DatasetOpKernel { + public: + using DatasetOpKernel::DatasetOpKernel; + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override { + const Tensor* filenames_tensor; + OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor)); + OP_REQUIRES( + ctx, filenames_tensor->dims() <= 1, + errors::InvalidArgument("`filenames` must be a scalar or a vector.")); + + std::vector<string> filenames; + filenames.reserve(filenames_tensor->NumElements()); + for (int i = 0; i < filenames_tensor->NumElements(); ++i) { + filenames.push_back(filenames_tensor->flat<string>()(i)); + } + + *output = new Dataset(ctx, filenames); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const std::vector<string>& filenames) + : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {} + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::LMDB")})); + } + + const DataTypeVector& output_dtypes() const override { + static DataTypeVector* dtypes = + new DataTypeVector({DT_STRING, DT_STRING}); + return *dtypes; + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + static std::vector<PartialTensorShape>* shapes = + new std::vector<PartialTensorShape>({{}, {}}); + return *shapes; + } + + string DebugString() const override { return "LMDBDatasetOp::Dataset"; } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { 
+ Node* filenames = nullptr; + TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + do { + if (mdb_cursor_) { + Tensor key_tensor(ctx->allocator({}), DT_STRING, {}); + key_tensor.scalar<string>()() = string( + static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size); + out_tensors->emplace_back(std::move(key_tensor)); + + Tensor value_tensor(ctx->allocator({}), DT_STRING, {}); + value_tensor.scalar<string>()() = + string(static_cast<const char*>(mdb_value_.mv_data), + mdb_value_.mv_size); + out_tensors->emplace_back(std::move(value_tensor)); + + int val; + val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT); + if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { + return errors::InvalidArgument(mdb_strerror(val)); + } + if (val == MDB_NOTFOUND) { + ResetStreamsLocked(); + ++current_file_index_; + } + *end_of_sequence = false; + return Status::OK(); + } + if (current_file_index_ == dataset()->filenames_.size()) { + *end_of_sequence = true; + return Status::OK(); + } + + TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env())); + } while (true); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return errors::Unimplemented( + "Checkpointing is currently not supported for LMDBDataset."); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented( + "Checkpointing is currently not supported for LMDBDataset."); + } + + private: + Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (current_file_index_ >= dataset()->filenames_.size()) { + return 
errors::InvalidArgument( + "current_file_index_:", current_file_index_, + " >= filenames_.size():", dataset()->filenames_.size()); + } + const string& filename = dataset()->filenames_[current_file_index_]; + + int val = mdb_env_create(&mdb_env_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK; + + struct stat source_stat; + if (stat(filename.c_str(), &source_stat) == 0 && + (source_stat.st_mode & S_IFREG)) { + flags |= MDB_NOSUBDIR; + } + val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_); + if (val != MDB_SUCCESS) { + return errors::InvalidArgument(mdb_strerror(val)); + } + val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST); + if (val != MDB_SUCCESS && val != MDB_NOTFOUND) { + return errors::InvalidArgument(mdb_strerror(val)); + } + if (val == MDB_NOTFOUND) { + ResetStreamsLocked(); + } + return Status::OK(); + } + void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (mdb_env_ != nullptr) { + if (mdb_cursor_) { + mdb_cursor_close(mdb_cursor_); + mdb_cursor_ = nullptr; + } + mdb_dbi_close(mdb_env_, mdb_dbi_); + mdb_txn_abort(mdb_txn_); + mdb_env_close(mdb_env_); + mdb_txn_ = nullptr; + mdb_dbi_ = 0; + mdb_env_ = nullptr; + } + } + mutex mu_; + size_t current_file_index_ GUARDED_BY(mu_) = 0; + MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr; + MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr; + MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0; + MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr; + + MDB_val mdb_key_ GUARDED_BY(mu_); + 
MDB_val mdb_value_ GUARDED_BY(mu_); + }; + + const std::vector<string> filenames_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("ExperimentalLMDBDataset").Device(DEVICE_CPU), + LMDBDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc new file mode 100644 index 0000000000..2c6179d9f5 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc @@ -0,0 +1,482 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include <deque> + +#include "tensorflow/core/common_runtime/process_function_library_runtime.h" +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { +namespace data { +namespace { + +struct BufferElement { + // The producer sets `status` if getting the input element fails. + Status status; + // The buffered data element. 
+ std::vector<Tensor> value; +}; + +using FunctionBufferCallback = std::function<void(const BufferElement&)>; + +class FunctionBufferingResource : public ResourceBase { + public: + FunctionBufferingResource(FunctionLibraryRuntime* lib, + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr, + const NameAttrList& func, int64 buffer_size, + const string& source_device, + const string& target_device, + const std::vector<Tensor>& func_args, + const DataTypeVector& output_types) + : lib_(lib), + pflr_(std::move(pflr)), + func_(func), + buffer_size_(buffer_size), + source_device_(source_device), + target_device_(target_device), + func_args_(func_args), + output_types_(output_types), + handle_(kInvalidHandle), + is_buffering_(false), + end_of_sequence_(false), + cancelled_(false) {} + + ~FunctionBufferingResource() override { + Cancel(); + } + + string DebugString() override { + return strings::StrCat("FunctionBufferingResource. Size: ", buffer_size_, + "; target_device: ", target_device_); + } + + // Instantiates the function the first time it's called. After that it caches + // the handle. + Status Instantiate() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + // Re-use existing handle if it's been set, effectively caching it. + if (handle_ != kInvalidHandle) { + return Status::OK(); + } + AttrValueMap attr_values = func_.attr(); + FunctionLibraryRuntime::InstantiateOptions opts; + opts.target = target_device_; + return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), opts, + &handle_); + } + + // Returns true if we've got to the end of the sequence and exhausted the + // buffer. + bool Finished() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + return end_of_sequence_ && buffer_.empty(); + } + + // Cancels any buffering / prefetching going on. + void Cancel() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + cancelled_ = true; + while (is_buffering_) { + cond_var_.wait(l); + } + } + + // Cancels all pending operations and then clears out the state. 
+ void Reset() LOCKS_EXCLUDED(mu_) { + Cancel(); + mutex_lock l(mu_); + buffer_.clear(); + requests_.clear(); + is_buffering_ = false; + end_of_sequence_ = false; + cancelled_ = false; + } + + // If the buffer has anything, runs `callback` on the first element in the + // buffer, else schedules the `callback` to be called. Requires `args` and + // `lib` in case more function calls need to be scheduled. + void MaybeGet(FunctionBufferCallback callback) LOCKS_EXCLUDED(mu_) { + bool start_buffering = false; + bool produced_output = false; + BufferElement buffer_element; + { + mutex_lock l(mu_); + if (!is_buffering_ && !end_of_sequence_) { + start_buffering = true; + } + if (!buffer_.empty()) { + produced_output = true; + std::swap(buffer_element, buffer_.front()); + buffer_.pop_front(); + } else { + produced_output = false; + requests_.push_back(std::move(callback)); + } + } + if (produced_output) { + callback(buffer_element); + } + if (start_buffering) { + FillBuffer(); + } + } + + private: + void FillBuffer() LOCKS_EXCLUDED(mu_) { + FunctionLibraryRuntime::Handle handle; + std::vector<FunctionBufferCallback> cancellation_callbacks; + std::vector<BufferElement> cancellation_buffer_elements; + bool cancelled = false; + { + mutex_lock l(mu_); + handle = handle_; + if (cancelled_) { + cancelled = true; + // Run through and fulfill all pending requests, if possible. 
+ while (!requests_.empty()) { + if (!buffer_.empty()) { + cancellation_buffer_elements.push_back(std::move(buffer_.front())); + buffer_.pop_front(); + cancellation_callbacks.push_back(std::move(requests_.front())); + requests_.pop_front(); + } else { + LOG(ERROR) << "Buffer ran out of elements and we couldn't satisfy: " + << requests_.size() << " requests"; + break; + } + } + is_buffering_ = false; + } else { + is_buffering_ = true; + } + } + if (cancelled) { + for (int i = 0; i < cancellation_callbacks.size(); ++i) { + cancellation_callbacks[i](cancellation_buffer_elements[i]); + } + cond_var_.notify_all(); + return; + } + FunctionLibraryRuntime::Options opts; + // Copied from CapturedFunction::generate_step_id(); + opts.step_id = -std::abs(static_cast<int64>(random::New64())); + opts.source_device = source_device_; + AllocatorAttributes arg_alloc_attr; + arg_alloc_attr.set_on_host(true); + opts.args_alloc_attrs.push_back(arg_alloc_attr); + for (const auto& dtype : output_types_) { + AllocatorAttributes ret_alloc_attrs; + if (DataTypeAlwaysOnHost(dtype)) { + ret_alloc_attrs.set_on_host(true); + } + opts.rets_alloc_attrs.push_back(ret_alloc_attrs); + } + if (opts.source_device != target_device_) { + opts.remote_execution = true; + } + opts.create_rendezvous = true; + auto* rets = new std::vector<Tensor>; + lib_->Run(opts, handle, func_args_, rets, + [this, rets](const Status& status) { + FunctionBufferCallback callback = nullptr; + BufferElement buffer_front; + bool restart_buffering = false; + { + mutex_lock l(mu_); + BufferElement buffer_element; + buffer_element.status = status; + if (status.ok()) { + buffer_element.value.swap(*rets); + } else { + end_of_sequence_ = true; + is_buffering_ = false; + } + buffer_.push_back(std::move(buffer_element)); + if (!requests_.empty()) { + buffer_front = std::move(buffer_.front()); + buffer_.pop_front(); + callback = std::move(requests_.front()); + requests_.pop_front(); + } + if (buffer_.size() < buffer_size_ && 
!end_of_sequence_) { + restart_buffering = true; + } else { + // When the buffer is full, we don't want to call + // FillBuffer() unless we're in cancellation phase in which + // case FillBuffer() will do the final cleanup post + // cancellation. + if (cancelled_) { + restart_buffering = true; + } + is_buffering_ = false; + } + } + if (callback != nullptr) { + callback(buffer_front); + } + if (restart_buffering) { + FillBuffer(); + } + }); + } + + mutex mu_; + FunctionLibraryRuntime* lib_; + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_; + NameAttrList func_; + const int64 buffer_size_; + const string source_device_; + const string target_device_; + const std::vector<Tensor> func_args_; + const DataTypeVector output_types_; + FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_); + std::deque<BufferElement> buffer_ GUARDED_BY(mu_); + std::deque<FunctionBufferCallback> requests_ GUARDED_BY(mu_); + bool is_buffering_ GUARDED_BY(mu_); + bool end_of_sequence_ GUARDED_BY(mu_); + bool cancelled_ GUARDED_BY(mu_); + condition_variable cond_var_; +}; + +class FunctionBufferResourceHandleOp : public OpKernel { + public: + explicit FunctionBufferResourceHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx), flib_def_(nullptr) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + } + + ~FunctionBufferResourceHandleOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete<FunctionBufferingResource>(cinfo_.container(), + cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. 
+ } + } + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* string_arg; + OP_REQUIRES_OK(ctx, ctx->input("string_arg", &string_arg)); + std::vector<Tensor> func_args; + func_args.push_back(*string_arg); + + const string& source_device = ctx->device()->name(); + + // Obtain and canonicalize target_device. + const Tensor* target_arg; + OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg)); + string target_device; + OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName( + target_arg->scalar<string>()(), source_device, + &target_device)); + + FunctionLibraryRuntime* lib = ctx->function_library(); + OP_REQUIRES(ctx, lib != nullptr, + errors::Internal("No function library is provided.")); + + mutex_lock l(mu_); + if (!initialized_) { + OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def())); + FunctionLibraryRuntime* clone_lib; + std::unique_ptr<ProcessFunctionLibraryRuntime> pflr; + OP_REQUIRES_OK(ctx, lib->Clone(&flib_def_, &pflr, &clone_lib)); + // Create the resource. 
+ FunctionBufferingResource* buffer; + OP_REQUIRES_OK( + ctx, + ctx->resource_manager()->LookupOrCreate<FunctionBufferingResource>( + cinfo_.container(), cinfo_.name(), &buffer, + [clone_lib, &pflr, &source_device, &target_device, func_args, + this](FunctionBufferingResource** ptr) { + *ptr = new FunctionBufferingResource( + clone_lib, std::move(pflr), func_, buffer_size_, + source_device, target_device, func_args, output_types_); + return Status::OK(); + })); + core::ScopedUnref s(buffer); + OP_REQUIRES_OK(ctx, buffer->Instantiate()); + initialized_ = true; + } + + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex<FunctionBufferingResource>())); + } + + private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; + std::unique_ptr<FunctionLibraryDefinition> flib_def_; + NameAttrList func_; + int64 buffer_size_; + string container_; + string name_; + DataTypeVector output_types_; +}; + +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource") + .Device(DEVICE_CPU) + .HostMemory("resource") + .HostMemory("string_arg") + .HostMemory("target_device"), + FunctionBufferResourceHandleOp); +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource") + .Device(DEVICE_GPU) + .HostMemory("resource") + .HostMemory("string_arg") + .HostMemory("target_device"), + FunctionBufferResourceHandleOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource") + .Device(DEVICE_SYCL) + .HostMemory("resource") + .HostMemory("string_arg") + .HostMemory("target_device"), + FunctionBufferResourceHandleOp); +#endif // TENSORFLOW_USE_SYCL + +// Prefetches and fills up a buffer by calling a function that provides the +// elements to buffer. 
+class FunctionBufferingResourceGetNextOp : public AsyncOpKernel { + public: + explicit FunctionBufferingResourceGetNextOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx) {} + + ~FunctionBufferingResourceGetNextOp() override {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + ResourceHandle handle; + OP_REQUIRES_OK_ASYNC( + ctx, HandleFromInput(ctx, "function_buffer_resource", &handle), done); + FunctionBufferingResource* buffer = nullptr; + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer), + done); + + if (buffer->Finished()) { + buffer->Unref(); + ctx->SetStatus(errors::OutOfRange("end_of_sequence")); + done(); + return; + } + + FunctionBufferCallback callback = + [ctx, buffer, done](const BufferElement& buffer_element) { + Status s = buffer_element.status; + if (!s.ok()) { + ctx->SetStatus(s); + buffer->Unref(); + done(); + return; + } + for (size_t i = 0; i < buffer_element.value.size(); ++i) { + ctx->set_output(i, buffer_element.value[i]); + } + buffer->Unref(); + done(); + }; + buffer->MaybeGet(std::move(callback)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext") + .Device(DEVICE_CPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceGetNextOp); +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext") + .Device(DEVICE_GPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceGetNextOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext") + .Device(DEVICE_SYCL) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceGetNextOp); +#endif // TENSORFLOW_USE_SYCL + +// Resets the FunctionBufferingResource, cancelling all pending requests and +// clearing out the buffer. 
+class FunctionBufferingResourceResetOp : public OpKernel { + public: + explicit FunctionBufferingResourceResetOp(OpKernelConstruction* ctx) + : OpKernel(ctx) {} + + ~FunctionBufferingResourceResetOp() override {} + + void Compute(OpKernelContext* ctx) override { + ResourceHandle handle; + OP_REQUIRES_OK(ctx, + HandleFromInput(ctx, "function_buffer_resource", &handle)); + FunctionBufferingResource* buffer = nullptr; + OP_REQUIRES_OK( + ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer)); + core::ScopedUnref s(buffer); + + buffer->Reset(); + } +}; + +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset") + .Device(DEVICE_CPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset") + .Device(DEVICE_GPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset") + .Device(DEVICE_SYCL) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceResetOp); +#endif // TENSORFLOW_USE_SYCL + +class IteratorGetDeviceOp : public OpKernel { + public: + using OpKernel::OpKernel; + + void Compute(OpKernelContext* ctx) override { + // NOTE(mrry): We do not currently Validate that the handle + // corresponds to a real IteratorResource, because that symbol is + // not exposed from the framework library. + Tensor* device_name_t; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, TensorShape({}), &device_name_t)); + // NOTE(mrry): Since the operation's input is a resource, we must be + // colocated with it, and so we can simply return the current device's + // name without looking at the input. 
+ device_name_t->scalar<string>()() = ctx->device()->name(); + } +}; + +REGISTER_KERNEL_BUILDER( + Name("ExperimentalIteratorGetDevice").Device(DEVICE_CPU), + IteratorGetDeviceOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc new file mode 100644 index 0000000000..8d561ca0e3 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc @@ -0,0 +1,220 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/util/work_sharder.h" + +namespace tensorflow { +namespace data { +namespace { + +class ThreadPoolResource : public ResourceBase { + public: + ThreadPoolResource(Env* env, const ThreadOptions& thread_options, + const string& name, int num_threads, bool low_latency_hint, + int max_intra_op_parallelism) + : thread_pool_(env, thread_options, name, num_threads, low_latency_hint), + max_intra_op_parallelism_(max_intra_op_parallelism) {} + + // Schedules fn() for execution in the pool of threads. 
+ void Schedule(std::function<void()> fn) { + if (max_intra_op_parallelism_ < 0) { + thread_pool_.Schedule(std::move(fn)); + } else { + thread_pool_.Schedule(std::bind( + [this](std::function<void()> bound_fn) { + // TODO(mrry): Consider moving this thread-local configuration to + // the threads themselves. + ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_); + bound_fn(); + }, + std::move(fn))); + } + } + + string DebugString() override { return "ThreadPoolResource"; } + + private: + thread::ThreadPool thread_pool_; + const int max_intra_op_parallelism_; +}; + +// Creates a handle to a ThreadPool resource. Note that we don't use +// ResourceOpKernel here because the ThreadPoolResource constructor requires +// access to `OpKernelContext::env()`, which isn't provided by +// `ResourceOpKernel<T>::CreateResource()`. +class ThreadPoolHandleOp : public OpKernel { + public: + explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism", + &max_intra_op_parallelism_)); + OP_REQUIRES( + ctx, num_threads_ > 0, + errors::InvalidArgument("`num_threads` must be greater than zero.")); + } + + // The resource is deleted from the resource manager only when it is private + // to kernel. Ideally the resource should be deleted when it is no longer held + // by anyone, but it would break backward compatibility. + ~ThreadPoolHandleOp() override { + if (cinfo_.resource_is_private_to_kernel()) { + if (!cinfo_.resource_manager() + ->Delete<ThreadPoolResource>(cinfo_.container(), cinfo_.name()) + .ok()) { + // Do nothing; the resource can have been deleted by session resets. 
+ } + } + } + + void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + if (!initialized_) { + ResourceMgr* mgr = ctx->resource_manager(); + OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def())); + ThreadPoolResource* resource; + OP_REQUIRES_OK(ctx, mgr->LookupOrCreate<ThreadPoolResource>( + cinfo_.container(), cinfo_.name(), &resource, + [this, ctx](ThreadPoolResource** ret) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + *ret = new ThreadPoolResource( + ctx->env(), {}, display_name_, + num_threads_, max_intra_op_parallelism_, + false /* low_latency_hint */); + return Status::OK(); + })); + initialized_ = true; + } + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo_.container(), cinfo_.name(), + MakeTypeIndex<ThreadPoolResource>())); + } + + private: + mutex mu_; + ContainerInfo cinfo_ GUARDED_BY(mu_); + bool initialized_ GUARDED_BY(mu_) = false; + string display_name_; + int num_threads_; + int max_intra_op_parallelism_; +}; + +class ThreadPoolDatasetOp : public UnaryDatasetOpKernel { + public: + explicit ThreadPoolDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + ThreadPoolResource* threadpool_resource; + OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), + &threadpool_resource)); + core::ScopedUnref unref_iterator(threadpool_resource); + + *output = new Dataset(ctx, input, threadpool_resource); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input, + ThreadPoolResource* threadpool) + : DatasetBase(DatasetContext(ctx)), + input_(input), + threadpool_(threadpool) { + input_->Ref(); + threadpool_->Ref(); + } + + ~Dataset() override { + input_->Unref(); + threadpool_->Unref(); + } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, 
strings::StrCat(prefix, "::ThreadPool")})); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + const std::vector<PartialTensorShape>& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { + return "ThreadPoolDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + return errors::Unimplemented("%s does not support serialization", + DebugString()); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const Params& params) + : DatasetIterator<Dataset>(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + ThreadPoolResource* pool = dataset()->threadpool_; + IteratorContext::Params params; + params.env = ctx->env(); + params.runner = [pool](std::function<void()> c) { + pool->Schedule(std::move(c)); + }; + params.stats_aggregator = ctx->stats_aggregator(); + params.lib = ctx->lib(); + params.function_library = ctx->function_library(); + params.allocator_getter = ctx->allocator_getter(); + IteratorContext threadpool_ctx(params); + return input_impl_->GetNext(&threadpool_ctx, out_tensors, + end_of_sequence); + } + + private: + std::unique_ptr<IteratorBase> input_impl_; + }; + + const DatasetBase* const input_; + ThreadPoolResource* const threadpool_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU), + ThreadPoolHandleOp); +REGISTER_KERNEL_BUILDER( + Name("ExperimentalThreadPoolDataset").Device(DEVICE_CPU), + ThreadPoolDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc 
b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc new file mode 100644 index 0000000000..cd612e0eb2 --- /dev/null +++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc @@ -0,0 +1,224 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/hash/hash.h" + +namespace tensorflow { +namespace data { +namespace { + +// See documentation in ../ops/dataset_ops.cc for a high-level +// description of the following op. 
+ +class UniqueDatasetOp : public UnaryDatasetOpKernel { + public: + explicit UniqueDatasetOp(OpKernelConstruction* ctx) + : UnaryDatasetOpKernel(ctx) {} + + void MakeDataset(OpKernelContext* ctx, DatasetBase* input, + DatasetBase** output) override { + OP_REQUIRES(ctx, input->output_dtypes().size() == 1, + errors::InvalidArgument("UniqueDataset only supports " + "inputs with a single component.")); + + DataType input_dtype = input->output_dtypes()[0]; + OP_REQUIRES(ctx, + input_dtype == DT_INT32 || input_dtype == DT_INT64 || + input_dtype == DT_STRING, + errors::InvalidArgument( + "UniqueDataset only supports inputs with a single " + "`tf.int32`, `tf.int64`, or `tf.string` component.")); + + *output = new Dataset(ctx, input); + } + + private: + class Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const DatasetBase* input) + : DatasetBase(DatasetContext(ctx)), input_(input) { + input_->Ref(); + } + + ~Dataset() override { input_->Unref(); } + + std::unique_ptr<IteratorBase> MakeIteratorInternal( + const string& prefix) const override { + return std::unique_ptr<IteratorBase>( + new Iterator({this, strings::StrCat(prefix, "::Unique")})); + } + + const DataTypeVector& output_dtypes() const override { + return input_->output_dtypes(); + } + + const std::vector<PartialTensorShape>& output_shapes() const override { + return input_->output_shapes(); + } + + string DebugString() const override { + return strings::StrCat("UniqueDatasetOp::Dataset"); + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output)); + return Status::OK(); + } + + private: + class Iterator : public DatasetIterator<Dataset> { + public: + explicit Iterator(const typename Iterator::Params& params) + : 
DatasetIterator<Dataset>(params) {} + + Status Initialize(IteratorContext* ctx) override { + return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_); + } + + Status GetNextInternal(IteratorContext* ctx, + std::vector<Tensor>* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + bool saw_new_value; + do { + saw_new_value = false; + out_tensors->clear(); + TF_RETURN_IF_ERROR( + input_impl_->GetNext(ctx, out_tensors, end_of_sequence)); + if (*end_of_sequence) { + break; + } + DCHECK_EQ(1, out_tensors->size()); + saw_new_value = unique_elements_.insert((*out_tensors)[0]).second; + } while (!saw_new_value); + return Status::OK(); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + mutex_lock l(mu_); + if (input_impl_) { + TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); + } else { + TF_RETURN_IF_ERROR( + writer->WriteScalar(full_name("input_impl_empty"), "")); + } + TF_RETURN_IF_ERROR(writer->WriteScalar( + full_name("unique_elements_size"), unique_elements_.size())); + size_t i = 0; + for (const Tensor& t : unique_elements_) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("unique_elements[", i++, "]")), t)); + } + return Status::OK(); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + mutex_lock l(mu_); + if (!reader->Contains(full_name("input_impl_empty"))) { + TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); + } else { + input_impl_.reset(); + } + int64 num_unique_elements; + unique_elements_.clear(); + TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("unique_elements_size"), + &num_unique_elements)); + for (int64 i = 0; i < num_unique_elements; ++i) { + Tensor unique_element; + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat("unique_elements[", i, "]")), + &unique_element)); + auto insert_result = unique_elements_.insert(unique_element); + if (!insert_result.second) { + return errors::InvalidArgument( + 
"Checkpoint contained two unique elements with the same " + "value."); + } + } + return Status::OK(); + } + + private: + struct TensorHash { + size_t operator()(const Tensor& t) const { + if (t.dtype() == DT_INT32 || t.dtype() == DT_INT64) { + return Hash64(t.tensor_data().data(), t.tensor_data().size()); + } else { + DCHECK_EQ(DT_STRING, t.dtype()); + auto flat_t = t.flat<string>(); + uint64 hash = 0; + for (int64 i = 0; i < t.NumElements(); ++i) { + hash = Hash64Combine(hash, Hash64(flat_t(i))); + } + return static_cast<size_t>(hash); + } + } + }; + + struct TensorKeyEqual { + bool operator()(const Tensor& lhs, const Tensor& rhs) const { + if (lhs.shape() != rhs.shape() || lhs.dtype() != rhs.dtype()) { + return false; + } + switch (lhs.dtype()) { +#define HANDLE_TYPE(T) \ + case T: \ + do { \ + auto lhs_flat = lhs.flat<EnumToDataType<T>::Type>(); \ + auto rhs_flat = rhs.flat<EnumToDataType<T>::Type>(); \ + for (int64 i = 0; i < lhs.NumElements(); ++i) { \ + if (lhs_flat(i) != rhs_flat(i)) { \ + return false; \ + } \ + } \ + return true; \ + } while (0) + + HANDLE_TYPE(DT_INT32); + HANDLE_TYPE(DT_INT64); + HANDLE_TYPE(DT_STRING); + default: + DCHECK(false) << "UniqueDataset unhandled data type: " + << DataTypeString(lhs.dtype()); + return false; + } + } + }; + + mutex mu_; + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + std::unordered_set<Tensor, TensorHash, TensorKeyEqual> unique_elements_ + GUARDED_BY(mu_); + }; + + const DatasetBase* const input_; + }; +}; + +REGISTER_KERNEL_BUILDER(Name("ExperimentalUniqueDataset").Device(DEVICE_CPU), + UniqueDatasetOp); + +} // namespace +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc index 00884314a9..be7d182a1f 100644 --- a/tensorflow/core/kernels/data/filter_dataset_op.cc +++ b/tensorflow/core/kernels/data/filter_dataset_op.cc @@ -18,9 +18,11 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -31,67 +33,84 @@ namespace { class FilterDatasetOp : public UnaryDatasetOpKernel { public: + using FilterIteratorPredicate = + std::function<Status(IteratorContext*, std::vector<Tensor>, bool*)>; + explicit FilterDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_)); } void MakeDataset(OpKernelContext* ctx, DatasetBase* input, DatasetBase** output) override { - FunctionLibraryRuntime::Handle pred_handle; - OP_REQUIRES_OK(ctx, - ctx->function_library()->Instantiate( - func_.name(), AttrSlice(&func_.attr()), &pred_handle)); - auto cleanup = gtl::MakeCleanup([ctx, pred_handle]() { - OP_REQUIRES_OK(ctx, ctx->function_library()->ReleaseHandle(pred_handle)); - }); - - const FunctionBody* pred_body = - ctx->function_library()->GetFunctionBody(pred_handle); - OP_REQUIRES(ctx, pred_body->ret_nodes.size() == 1, - errors::InvalidArgument( - "predicate function must have a single return value.")); - Node* ret_node = pred_body->ret_nodes[0]; - Node* ret_input_node; - OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node)); - std::unique_ptr<CapturedFunction> captured_func; OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", &captured_func)); - if (ret_input_node->def().op() == "_Arg") { - int32 index = -1; - OP_REQUIRES_OK(ctx, GetNodeAttr(ret_input_node->def(), "index", &index)); - *output = new FilterTensorDataset(ctx, input, func_, - std::move(captured_func), 
index); + std::vector<int> indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + OP_REQUIRES(ctx, indices.size() <= 1, + errors::InvalidArgument( + "predicate function has more than one return value.")); + + FilterIteratorPredicate filter_pred; + if (indices.empty()) { + CapturedFunction* raw_captured_func = captured_func.get(); + filter_pred = [raw_captured_func](IteratorContext* ctx, + const std::vector<Tensor>& args, + bool* out_matched) { + std::vector<Tensor> result; + TF_RETURN_IF_ERROR( + raw_captured_func->RunWithBorrowedArgs(ctx, args, &result)); + + if (result.size() != 1 || result[0].dtype() != DT_BOOL || + result[0].NumElements() != 1) { + return errors::InvalidArgument( + "Filter predicate `f` must return a scalar bool."); + } + *out_matched = result[0].scalar<bool>()(); + return Status::OK(); + }; } else { - *output = new FilterFunctionDataset(ctx, input, func_, - std::move(captured_func)); + filter_pred = [indices](IteratorContext* ctx, + const std::vector<Tensor>& args, + bool* out_matched) { + const Tensor& predicate = args[indices[0]]; + if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { + return errors::InvalidArgument( + "Filter predicate `f` must return a scalar bool."); + } + *out_matched = predicate.scalar<bool>()(); + return Status::OK(); + }; } + + *output = new Dataset(ctx, input, func_, std::move(captured_func), + std::move(filter_pred)); } private: - const int graph_def_version_; - - class FilterDatasetBase : public DatasetBase { + class Dataset : public DatasetBase { public: - FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, - std::unique_ptr<CapturedFunction> captured_func) + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, + std::unique_ptr<CapturedFunction> captured_func, + FilterIteratorPredicate filter_pred) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), - 
captured_func_(std::move(captured_func)) { + captured_func_(std::move(captured_func)), + filter_pred_(std::move(filter_pred)) { input_->Ref(); } - ~FilterDatasetBase() override { input_->Unref(); } + ~Dataset() override { input_->Unref(); } std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::Filter")})); + return MakeUnique<Iterator>( + Iterator::Params{this, strings::StrCat(prefix, "::Filter")}, + filter_pred_); } const DataTypeVector& output_dtypes() const override { @@ -133,17 +152,15 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - virtual Status EvaluatePredicate(IteratorContext* ctx, - const std::vector<Tensor>& element, - bool* out_matched) const = 0; - private: - class Iterator : public DatasetIterator<FilterDatasetBase> { + class Iterator : public DatasetIterator<Dataset> { public: - explicit Iterator(const Params& params) - : DatasetIterator<FilterDatasetBase>(params), + explicit Iterator(const Params& params, + FilterIteratorPredicate filter_pred) + : DatasetIterator<Dataset>(params), filtered_elements_(0), - dropped_elements_(0) { + dropped_elements_(0), + filter_pred_(std::move(filter_pred)) { std::vector<string> components = str_util::Split(params.prefix, "::", str_util::SkipEmpty()); prefix_end_ = components.back(); @@ -180,8 +197,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - TF_RETURN_IF_ERROR( - dataset()->EvaluatePredicate(ctx, *out_tensors, &matched)); + TF_RETURN_IF_ERROR(filter_pred_(ctx, *out_tensors, &matched)); if (!matched) { // Clear the output tensor list since it didn't match. 
out_tensors->clear(); @@ -251,64 +267,14 @@ class FilterDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); int64 filtered_elements_ GUARDED_BY(mu_); int64 dropped_elements_ GUARDED_BY(mu_); + const FilterIteratorPredicate filter_pred_; string prefix_end_; }; const DatasetBase* const input_; const NameAttrList func_; - - protected: const std::unique_ptr<CapturedFunction> captured_func_; - }; - - class FilterFunctionDataset : public FilterDatasetBase { - public: - using FilterDatasetBase::FilterDatasetBase; - - protected: - Status EvaluatePredicate(IteratorContext* ctx, - const std::vector<Tensor>& element, - bool* out_matched) const override { - // TODO(mrry): Avoid blocking a threadpool thread. We will need to - // stack-rip the iterators and use async kernels. - std::vector<Tensor> result; - TF_RETURN_IF_ERROR( - captured_func_->RunWithBorrowedArgs(ctx, element, &result)); - - if (result.size() != 1 || result[0].dtype() != DT_BOOL || - result[0].NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = result[0].scalar<bool>()(); - return Status::OK(); - } - }; - - class FilterTensorDataset : public FilterDatasetBase { - public: - FilterTensorDataset(OpKernelContext* ctx, const DatasetBase* input, - const NameAttrList& func, - std::unique_ptr<CapturedFunction> captured_func, - int32 index) - : FilterDatasetBase(ctx, input, func, std::move(captured_func)), - index_(index) {} - - protected: - Status EvaluatePredicate(IteratorContext* ctx, - const std::vector<Tensor>& element, - bool* out_matched) const override { - const Tensor& predicate = element[index_]; - if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) { - return errors::InvalidArgument( - "Filter predicate `f` must return a scalar bool."); - } - *out_matched = predicate.scalar<bool>()(); - return Status::OK(); - } - - private: - const int32 index_; + const 
FilterIteratorPredicate filter_pred_; }; private: diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc index 71a36314a0..b4367d5a11 100644 --- a/tensorflow/core/kernels/data/generator_dataset_op.cc +++ b/tensorflow/core/kernels/data/generator_dataset_op.cc @@ -86,8 +86,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx)); TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx)); TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx)); - TF_RETURN_IF_ERROR( - dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); return Status::OK(); } @@ -96,6 +94,12 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { bool* end_of_sequence) override { mutex_lock l(mu_); + if (!initialized_) { + TF_RETURN_IF_ERROR( + dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_)); + initialized_ = true; + } + if (finalized_) { *end_of_sequence = true; return Status::OK(); @@ -123,6 +127,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase { private: mutex mu_; + bool initialized_ GUARDED_BY(mu_) = false; bool finalized_ GUARDED_BY(mu_) = false; std::vector<Tensor> state_ GUARDED_BY(mu_); }; diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc index d6ee42a7c6..e7244ee208 100644 --- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc @@ -30,8 +30,7 @@ namespace { class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { public: explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_)); OP_REQUIRES_OK(ctx, 
ctx->GetAttr("reduce_func", &reduce_func_)); @@ -421,7 +420,6 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel { const std::vector<PartialTensorShape> output_shapes_; }; - const int graph_def_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; NameAttrList key_func_; diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc index 8b417bb1c2..14aefe5d54 100644 --- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc +++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc @@ -31,8 +31,7 @@ namespace { class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { public: explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size_func", &window_size_func_)); @@ -507,7 +506,6 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel { const std::vector<PartialTensorShape> output_shapes_; }; - const int graph_def_version_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; NameAttrList key_func_; diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc index c0bc507ec0..7a833668ac 100644 --- a/tensorflow/core/kernels/data/iterator_ops.cc +++ b/tensorflow/core/kernels/data/iterator_ops.cc @@ -659,6 +659,115 @@ class ToSingleElementOp : public AsyncOpKernel { BackgroundWorker background_worker_; }; +class ReduceDatasetOp : public AsyncOpKernel { + public: + explicit ReduceDatasetOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx), + background_worker_( + ctx->env(), + strings::StrCat("reduce_thread_", SanitizeThreadSuffix(name()))) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("f", 
&reduce_func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism", + &use_inter_op_parallelism_)); + } + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + // The call to `iterator->GetNext()` may block and depend on an + // inter-op thread pool thread, so we issue the call from the + // owned thread pool. + background_worker_.Schedule([this, ctx, done]() { + DatasetBase* dataset; + OP_REQUIRES_OK_ASYNC( + ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done); + OpInputList inputs; + OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("initial_state", &inputs), + done); + std::vector<Tensor> state(inputs.begin(), inputs.end()); + + std::unique_ptr<CapturedFunction> captured_func; + OP_REQUIRES_OK_ASYNC( + ctx, + CapturedFunction::Create(reduce_func_, ctx, "other_arguments", + use_inter_op_parallelism_, &captured_func), + done); + + IteratorContext iter_ctx(ctx); + OP_REQUIRES_OK_ASYNC(ctx, captured_func->Instantiate(&iter_ctx), done); + + std::unique_ptr<IteratorBase> iterator; + OP_REQUIRES_OK_ASYNC( + ctx, dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator), + done); + + // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to + // avoid destruction races. + IteratorBase* raw_iterator = iterator.release(); + auto cleanup = gtl::MakeCleanup([raw_iterator, done] { + delete raw_iterator; + done(); + }); + + // Iterate through the input dataset. + Status status; + while (true) { + std::vector<Tensor> next_input_element; + bool end_of_input; + status = raw_iterator->GetNext(&iter_ctx, &next_input_element, + &end_of_input); + if (!status.ok() || end_of_input) { + break; + } + + // Run the reduce function to update the current state. 
+ std::vector<Tensor> args; + args.reserve(state.size() + next_input_element.size()); + std::copy(state.begin(), state.end(), std::back_inserter(args)); + std::copy(next_input_element.begin(), next_input_element.end(), + std::back_inserter(args)); + + std::vector<Tensor> reduce_func_output; + status = + captured_func->Run(&iter_ctx, std::move(args), &reduce_func_output); + if (!status.ok()) { + break; + } + std::swap(reduce_func_output, state); + } + + if (!status.ok()) { + ctx->SetStatus(status); + return; + } + for (int i = 0; i < state.size(); ++i) { + OP_REQUIRES_ASYNC( + ctx, state[i].dtype() == output_types_[i], + errors::InvalidArgument( + "The result does not match the expected type for component ", i, + ". Expected: ", DataTypeString(output_types_[i]), + ". Actual: ", DataTypeString(state[i].dtype()), "."), + done); + OP_REQUIRES_ASYNC( + ctx, output_shapes_[i].IsCompatibleWith(state[i].shape()), + errors::InvalidArgument( + "The result does not match the expected shape for component ", + i, ". Expected: ", output_shapes_[i].DebugString(), + ". 
Actual: ", state[i].shape().DebugString(), "."), + done); + ctx->set_output(i, state[i]); + } + }); + } + + private: + NameAttrList reduce_func_; + DataTypeVector output_types_; + std::vector<PartialTensorShape> output_shapes_; + bool use_inter_op_parallelism_; + BackgroundWorker background_worker_; +}; + class OneShotIteratorOp : public AsyncOpKernel { public: explicit OneShotIteratorOp(OpKernelConstruction* ctx) @@ -1146,6 +1255,8 @@ REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_GPU), AnonymousIteratorHandleOp); REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU), ToSingleElementOp); +REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU), + ReduceDatasetOp); REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU), OneShotIteratorOp); REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU), diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc index 2bbf4af664..f45a239793 100644 --- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/inplace_ops_functor.h" #include "tensorflow/core/lib/core/blocking_counter.h" #include "tensorflow/core/lib/gtl/cleanup.h" @@ -29,6 +30,7 @@ limitations under the License. #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cpu_info.h" #include "tensorflow/core/platform/tracing.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -37,8 +39,14 @@ namespace { // See documentation in ../ops/dataset_ops.cc for a high-level // description of the following op. 
+// TODO(b/116852688): Make coordination between the performance model and this +// transformation more robust. class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { public: + using MapAndBatchIteratorFunction = + std::function<void(IteratorContext*, const string&, std::vector<Tensor>, + std::shared_ptr<std::vector<Tensor>>, StatusCallback)>; + explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx), op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) { @@ -89,31 +97,73 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments", &captured_func)); - *output = new Dataset(ctx, input, batch_size, num_parallel_calls, - drop_remainder, output_types_, output_shapes_, func_, - std::move(captured_func), &ctx->eigen_cpu_device()); + std::vector<int> indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + + MapAndBatchIteratorFunction map_func; + CapturedFunction* raw_captured_func = captured_func.get(); + if (indices.empty()) { + map_func = [raw_captured_func]( + IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, + std::shared_ptr<std::vector<Tensor>> out_tensors, + StatusCallback done) { + raw_captured_func->RunAsync(ctx, std::move(args), out_tensors.get(), + std::move(done), prefix); + }; + } else { + std::vector<bool> can_move = ComputeMoveVector(indices); + map_func = [raw_captured_func, indices, can_move]( + IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, + std::shared_ptr<std::vector<Tensor>> out_tensors, + StatusCallback done) { + const std::vector<Tensor>& captured_inputs = + raw_captured_func->captured_inputs(); + size_t num_args = args.size(); + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] < num_args) { + if (can_move[i]) { + out_tensors->push_back(std::move(args[indices[i]])); + } else { + out_tensors->push_back(args[indices[i]]); + } + } else { + 
out_tensors->push_back(captured_inputs[indices[i] - num_args]); + } + } + done(Status::OK()); + }; + } + + *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls, + drop_remainder, output_types_, output_shapes_, + std::move(captured_func), &ctx->eigen_cpu_device(), + std::move(map_func)); } private: class Dataset : public DatasetBase { public: - Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size, + Dataset(OpKernelContext* ctx, const DatasetBase* input, + const NameAttrList& func, int64 batch_size, int64 num_parallel_calls, bool drop_remainder, const DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes, - const NameAttrList& func, std::unique_ptr<CapturedFunction> captured_func, - const Eigen::ThreadPoolDevice* device) + const Eigen::ThreadPoolDevice* device, + MapAndBatchIteratorFunction map_func) : DatasetBase(DatasetContext(ctx)), input_(input), + func_(func), batch_size_(batch_size), num_parallel_calls_(num_parallel_calls), drop_remainder_(drop_remainder), output_types_(output_types), output_shapes_(output_shapes), - map_fn_(func), captured_func_(std::move(captured_func)), - device_(device) { + device_(device), + map_func_(std::move(map_func)) { input_->Ref(); } @@ -121,8 +171,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")})); + return MakeUnique<Iterator>( + Iterator::Params{this, strings::StrCat(prefix, "::MapAndBatch")}, + map_func_); } const DataTypeVector& output_dtypes() const override { @@ -141,7 +192,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name())); + TF_RETURN_IF_ERROR(b->AddFunction(ctx, 
func_.name())); Node* input_graph_node = nullptr; TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); Node* batch_size_node; @@ -163,7 +214,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { other_arguments_types.emplace_back(t.dtype()); } AttrValue f; - b->BuildAttrValue(map_fn_, &f); + b->BuildAttrValue(func_, &f); AttrValue other_arguments_types_attr; b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr); @@ -183,31 +234,35 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator<Dataset> { public: - explicit Iterator(const Params& params) + explicit Iterator(const Params& params, + MapAndBatchIteratorFunction map_func) : DatasetIterator<Dataset>(params), - num_parallel_calls_(params.dataset->num_parallel_calls_) {} + mu_(std::make_shared<mutex>()), + cond_var_(std::make_shared<condition_variable>()), + num_parallel_calls_(std::make_shared<model::SharedState>( + params.dataset->num_parallel_calls_, mu_, cond_var_)), + map_func_(std::move(map_func)) {} ~Iterator() override { - mutex_lock l(mu_); + mutex_lock l(*mu_); // Cancel the runner thread. cancelled_ = true; - cond_var_.notify_all(); + cond_var_->notify_all(); // Wait for all in-flight calls to complete. 
while (num_calls_ > 0) { - cond_var_.wait(l); + cond_var_->wait(l); } } Status Initialize(IteratorContext* ctx) override { - mutex_lock l(mu_); + mutex_lock l(*mu_); AddConstantParameter(ctx, "batch_size", dataset()->batch_size_); - if (num_parallel_calls_ == kAutoTune) { - num_parallel_calls_ = 1; - AddTunableParameter(ctx, "parallelism", - &num_parallel_calls_ /* value */, 1 /* min */, - port::NumSchedulableCPUs() /* max */, &cond_var_); + if (num_parallel_calls_->value == kAutoTune) { + num_parallel_calls_->value = 1; + AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1, + port::NumSchedulableCPUs()); } else { - AddConstantParameter(ctx, "parallelism", num_parallel_calls_); + AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value); } TF_RETURN_IF_ERROR( dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_)); @@ -219,27 +274,27 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { bool* end_of_sequence) override { std::shared_ptr<BatchResult> result; { - mutex_lock l(mu_); + mutex_lock l(*mu_); EnsureRunnerThreadStarted(ctx); while (batch_results_.empty() || batch_results_.front()->num_calls > 0) { RecordStop(ctx); - cond_var_.wait(l); + cond_var_->wait(l); RecordStart(ctx); } std::swap(result, batch_results_.front()); batch_results_.pop_front(); - cond_var_.notify_all(); + cond_var_->notify_all(); } return ProcessResult(ctx, result, out_tensors, end_of_sequence); } protected: Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); + mutex_lock l(*mu_); // Wait for all in-flight calls to complete. 
while (num_calls_ > 0) { - cond_var_.wait(l); + cond_var_->wait(l); } CHECK_EQ(num_calls_, 0); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); @@ -255,7 +310,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - mutex_lock l(mu_); + mutex_lock l(*mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); TF_RETURN_IF_ERROR( reader->ReadScalar(full_name("call_counter"), &call_counter_)); @@ -293,55 +348,17 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { int64 num_calls; // access guarded by owner's mutex }; - void Callback(const std::shared_ptr<IteratorContext>& ctx, - const std::shared_ptr<BatchResult>& result, - const std::shared_ptr<std::vector<Tensor>>& return_values, - int64 offset, const Status& status) LOCKS_EXCLUDED(mu_) { - result->UpdateStatus(status); - if (status.ok()) { - EnsureOutputAllocated(ctx, result, return_values); - for (size_t i = 0; i < return_values->size(); ++i) { - const Tensor& tensor = return_values->at(i); - Tensor* batch = &(result->output)[i]; - if (tensor.NumElements() != - (batch->NumElements() / batch->dim_size(0))) { - TensorShape batch_shape = batch->shape(); - batch_shape.RemoveDim(0); - result->UpdateStatus(errors::InvalidArgument( - "Cannot add tensor to the batch: number of elements does not " - "match. Shapes are: [tensor]: ", - tensor.shape().DebugString(), - ", [batch]: ", batch_shape.DebugString())); - break; - } - // TODO(mrry): Add a version of DoParallelConcat that allows us to - // move `tensor` where possible, to speed up string tensor batching. 
- Status copy_status = ::tensorflow::functor::DoParallelConcat( - *dataset()->device_, tensor, offset, batch); - if (!copy_status.ok()) { - result->UpdateStatus(copy_status); - break; - } - } - { - mutex_lock l(result->mu); - result->num_elements++; - } - } - CallCompleted(result); - } - void CallCompleted(const std::shared_ptr<BatchResult>& result) - LOCKS_EXCLUDED(mu_) { - mutex_lock l(mu_); + LOCKS_EXCLUDED(*mu_) { + mutex_lock l(*mu_); num_calls_--; result->num_calls--; - cond_var_.notify_all(); + cond_var_->notify_all(); } void CallFunction(std::shared_ptr<IteratorContext> ctx, const std::shared_ptr<BatchResult>& result, - int64 offset) LOCKS_EXCLUDED(mu_) { + int64 offset) LOCKS_EXCLUDED(*mu_) { // Get the next input element. std::vector<Tensor> input_element; bool end_of_input; @@ -359,21 +376,48 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { return; } - // Call `captured_func_(input_element)`, using `Callback` to store the - // result in `result`. - (*ctx->runner())(std::bind( - [this, result, offset](std::shared_ptr<IteratorContext> ctx, - std::vector<Tensor> input_element) { - std::shared_ptr<std::vector<Tensor>> return_values( - new std::vector<Tensor>()); - dataset()->captured_func_->RunAsync( - ctx.get(), std::move(input_element), return_values.get(), - [this, ctx, result, return_values, offset](Status status) { - Callback(ctx, result, return_values, offset, status); - }, - prefix()); - }, - ctx, std::move(input_element))); + std::shared_ptr<std::vector<Tensor>> return_values = + std::make_shared<std::vector<Tensor>>(); + auto done = [this, ctx, result, return_values, offset](Status status) { + result->UpdateStatus(status); + if (status.ok()) { + EnsureOutputAllocated(ctx, result, return_values); + for (size_t i = 0; i < return_values->size(); ++i) { + const Tensor& tensor = return_values->at(i); + Tensor* batch = &(result->output)[i]; + if (tensor.NumElements() != + (batch->NumElements() / batch->dim_size(0))) { + TensorShape 
batch_shape = batch->shape(); + batch_shape.RemoveDim(0); + result->UpdateStatus(errors::InvalidArgument( + "Cannot add tensor to the batch: number of elements does " + "not match. Shapes are: [tensor]: ", + tensor.shape().DebugString(), + ", [batch]: ", batch_shape.DebugString())); + break; + } + // TODO(mrry): Add a version of DoParallelConcat that allows us to + // move `tensor` where possible, to speed up string tensor + // batching. + Status copy_status = ::tensorflow::functor::DoParallelConcat( + *dataset()->device_, tensor, offset, batch); + if (!copy_status.ok()) { + result->UpdateStatus(copy_status); + break; + } + } + { + mutex_lock l(result->mu); + result->num_elements++; + } + } + CallCompleted(result); + }; + + // Apply the map function on `input_element`, storing the result in + // `return_values`, and invoking `done` when finished. + map_func_(ctx.get(), prefix(), std::move(input_element), + std::move(return_values), std::move(done)); } Status CopyPartialBatch(Tensor* output, const Tensor& value, @@ -398,9 +442,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } void EnsureRunnerThreadStarted(IteratorContext* ctx) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); + auto ctx_copy = std::make_shared<IteratorContext>(*ctx); runner_thread_.reset(ctx->env()->StartThread( {}, "runner_thread", std::bind(&Iterator::RunnerThread, this, ctx_copy))); @@ -474,14 +518,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) - LOCKS_EXCLUDED(mu_) { + LOCKS_EXCLUDED(*mu_) { std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls; RecordStart(ctx.get()); auto stop_cleanup = gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); }); - new_calls.reserve(num_parallel_calls_); - auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool { - int64 
num_parallel_calls = num_parallel_calls_; + new_calls.reserve(num_parallel_calls_->value); + auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool { + int64 num_parallel_calls = num_parallel_calls_->value; int64 max_batch_results = (num_parallel_calls + dataset()->batch_size_ - 1) / dataset()->batch_size_; @@ -492,10 +536,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { }; while (true) { { - mutex_lock l(mu_); + mutex_lock l(*mu_); while (!cancelled_ && busy()) { RecordStop(ctx.get()); - cond_var_.wait(l); + cond_var_->wait(l); RecordStart(ctx.get()); } @@ -505,8 +549,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { while (!busy()) { if (call_counter_ % dataset()->batch_size_ == 0) { - batch_results_.emplace_back( - new BatchResult(dataset()->batch_size_)); + batch_results_.push_back( + std::make_shared<BatchResult>(dataset()->batch_size_)); } int64 offset = call_counter_++ % dataset()->batch_size_; new_calls.emplace_back(batch_results_.back(), offset); @@ -522,8 +566,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader, - size_t index) EXCLUSIVE_LOCKS_REQUIRED(mu_) { - batch_results_.emplace_back(new BatchResult(dataset()->batch_size_)); + size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { + batch_results_.push_back( + std::make_shared<BatchResult>(dataset()->batch_size_)); std::shared_ptr<BatchResult> result = batch_results_.back(); string prefix = strings::StrCat("batch_results_", index); mutex_lock l(result->mu); @@ -567,7 +612,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } Status ReadStatus(IteratorStateReader* reader, const string& prefix, - Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { int64 code_int; TF_RETURN_IF_ERROR(reader->ReadScalar( full_name(strings::StrCat(prefix, "_code")), &code_int)); @@ -585,7 +630,7 @@ class MapAndBatchDatasetOp : public 
UnaryDatasetOpKernel { } Status WriteBatchResult(IteratorStateWriter* writer, size_t index) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { std::shared_ptr<BatchResult> result = batch_results_[index]; string prefix = strings::StrCat("batch_results_", index); mutex_lock l(result->mu); @@ -626,7 +671,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { } Status WriteStatus(IteratorStateWriter* writer, const string& prefix, - const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + const Status& status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { TF_RETURN_IF_ERROR( writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")), static_cast<int64>(status.code()))); @@ -640,24 +685,26 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { // Used for coordination between the main thread, the runner thread, and // the callback threads. - mutex mu_; + const std::shared_ptr<mutex> mu_; // Used for coordination between the main thread, the runner thread, and // the callback threads. In particular, the runner thread should only - // schedule new calls when the number of in-flight calls is less than the - // user specified level of parallelism and there are slots available in - // the `batch_results_` buffer. - condition_variable cond_var_; + // schedule new calls when the number of in-flight calls is less than + // `num_parallel_calls_->value` and there are slots available in the + // `batch_results_` buffer. + const std::shared_ptr<condition_variable> cond_var_; // Identifies the maximum number of parallel calls. - std::atomic<int64> num_parallel_calls_; + const std::shared_ptr<model::SharedState> num_parallel_calls_; + const MapAndBatchIteratorFunction map_func_; + // Counts the number of outstanding calls for this batch. - int64 num_calls_ GUARDED_BY(mu_) = 0; + int64 num_calls_ GUARDED_BY(*mu_) = 0; // Counts the total number of calls. 
- int64 call_counter_ GUARDED_BY(mu_) = 0; + int64 call_counter_ GUARDED_BY(*mu_) = 0; std::unique_ptr<IteratorBase> input_impl_; // Buffer for storing the (intermediate) batch results. - std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(mu_); - std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_); - bool cancelled_ GUARDED_BY(mu_) = false; + std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_); + std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); + bool cancelled_ GUARDED_BY(*mu_) = false; }; const DatasetBase* const input_; @@ -667,9 +714,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel { const bool drop_remainder_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; - const NameAttrList map_fn_; const std::unique_ptr<CapturedFunction> captured_func_; const Eigen::ThreadPoolDevice* device_; // not owned + const MapAndBatchIteratorFunction map_func_; }; const int op_version_; diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc index f112e1dc43..6b6ffabf4f 100644 --- a/tensorflow/core/kernels/data/map_dataset_op.cc +++ b/tensorflow/core/kernels/data/map_dataset_op.cc @@ -17,7 +17,9 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { @@ -28,6 +30,9 @@ namespace { class MapDatasetOp : public UnaryDatasetOpKernel { public: + using MapIteratorFunction = std::function<Status( + IteratorContext*, std::vector<Tensor>, std::vector<Tensor>*)>; + explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); @@ -43,8 +48,42 @@ class MapDatasetOp : public UnaryDatasetOpKernel { use_inter_op_parallelism_, &captured_func)); + std::vector<int> indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + + MapIteratorFunction map_func; + CapturedFunction* raw_captured_func = captured_func.get(); + if (indices.empty()) { + map_func = [raw_captured_func](IteratorContext* ctx, + std::vector<Tensor> args, + std::vector<Tensor>* out_tensors) { + return raw_captured_func->Run(ctx, std::move(args), out_tensors); + }; + } else { + std::vector<bool> can_move = ComputeMoveVector(indices); + map_func = [raw_captured_func, indices, can_move]( + IteratorContext* ctx, std::vector<Tensor> args, + std::vector<Tensor>* out_tensors) { + const std::vector<Tensor>& captured_inputs = + raw_captured_func->captured_inputs(); + size_t num_args = args.size(); + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] < num_args) { + if (can_move[i]) { + out_tensors->push_back(std::move(args[indices[i]])); + } else { + out_tensors->push_back(args[indices[i]]); + } + } else { + out_tensors->push_back(captured_inputs[indices[i] - num_args]); + } + } + return Status::OK(); + }; + } + *output = new Dataset(ctx, input, func_, 
std::move(captured_func), - output_types_, output_shapes_); + output_types_, output_shapes_, std::move(map_func)); } private: @@ -54,13 +93,15 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const NameAttrList& func, std::unique_ptr<CapturedFunction> captured_func, const DataTypeVector& output_types, - const std::vector<PartialTensorShape>& output_shapes) + const std::vector<PartialTensorShape>& output_shapes, + MapIteratorFunction map_func) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), captured_func_(std::move(captured_func)), output_types_(output_types), - output_shapes_(output_shapes) { + output_shapes_(output_shapes), + map_func_(std::move(map_func)) { input_->Ref(); } @@ -68,8 +109,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - return std::unique_ptr<IteratorBase>( - new Iterator({this, strings::StrCat(prefix, "::Map")})); + return MakeUnique<Iterator>( + Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_); } const DataTypeVector& output_dtypes() const override { @@ -116,8 +157,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: class Iterator : public DatasetIterator<Dataset> { public: - explicit Iterator(const Params& params) - : DatasetIterator<Dataset>(params) {} + explicit Iterator(const Params& params, MapIteratorFunction map_func) + : DatasetIterator<Dataset>(params), map_func_(std::move(map_func)) {} Status Initialize(IteratorContext* ctx) override { TF_RETURN_IF_ERROR( @@ -139,10 +180,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { return Status::OK(); } - // TODO(mrry): Avoid blocking a threadpool thread. We will need to - // stack-rip the iterators and use async kernels. 
- Status s = - dataset()->captured_func_->Run(ctx, std::move(args), out_tensors); + Status s = map_func_(ctx, args, out_tensors); if (errors::IsOutOfRange(s)) { // `f` may deliberately raise `errors::OutOfRange` to indicate // that we should terminate the iteration early. @@ -167,6 +205,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { private: std::unique_ptr<IteratorBase> input_impl_; + const MapIteratorFunction map_func_; }; const DatasetBase* const input_; @@ -174,6 +213,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel { const std::unique_ptr<CapturedFunction> captured_func_; const DataTypeVector output_types_; const std::vector<PartialTensorShape> output_shapes_; + const MapIteratorFunction map_func_; }; DataTypeVector output_types_; diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc index 6657f2b2b3..705b0393de 100644 --- a/tensorflow/core/kernels/data/map_defun_op.cc +++ b/tensorflow/core/kernels/data/map_defun_op.cc @@ -62,24 +62,6 @@ class MapDefunOp : public AsyncOpKernel { ~MapDefunOp() override {} - Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) { - // Validates inputs and gets the size of their leading dimension. - *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1; - for (size_t i = 0; i < ctx->num_inputs(); ++i) { - if (ctx->input(i).dims() == 0) { - return errors::InvalidArgument( - "All inputs must have rank at least 1. Input ", i, - " has a rank of 0."); - } else if (ctx->input(i).dim_size(0) != *batch_size) { - return errors::InvalidArgument( - "All inputs must have the same dimension 0. 
Input ", i, - " has leading dimension ", ctx->input(i).dim_size(0), - ", while all previous inputs have leading dimension ", batch_size); - } - } - return Status::OK(); - } - void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { ComputeOptions* compute_opts = nullptr; @@ -150,8 +132,9 @@ class MapDefunOp : public AsyncOpKernel { // all calls to the function are complete. This struct also encapsulates // all the components that need to be passed to each MapFunctionCallFrame. - const std::vector<Tensor> args; + OpInputList args; const std::vector<TensorShape> arg_shapes; + OpInputList captured_inputs; const int64 batch_size; // Output of a compute call @@ -161,26 +144,31 @@ class MapDefunOp : public AsyncOpKernel { // Create a copy of output_shapes because every `Compute` may expect a // different output shape. - ComputeOptions(std::vector<Tensor> args, + ComputeOptions(OpInputList args, OpInputList captured_inputs, std::vector<TensorShape> arg_shapes, int64 batch_size, const std::vector<PartialTensorShape>& output_shapes_attr) - : args(std::move(args)), + : args(args), arg_shapes(std::move(arg_shapes)), + captured_inputs(captured_inputs), batch_size(batch_size), output_shapes(output_shapes_attr) {} }; // Get inputs to Compute and check that they are valid. Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) { - int64 batch_size = - ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1; + OpInputList arguments; + TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments)); + OpInputList captured_inputs; + TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs)); + + int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1; - for (size_t i = 0; i < ctx->num_inputs(); ++i) { - if (ctx->input(i).dims() == 0) { + for (size_t i = 0; i < arguments.size(); ++i) { + if (arguments[i].dims() == 0) { return errors::InvalidArgument( "All inputs must have rank at least 1. 
Input ", i, " has a rank of 0."); - } else if (ctx->input(i).dim_size(0) != batch_size) { + } else if (arguments[i].dim_size(0) != batch_size) { return errors::InvalidArgument( "All inputs must have the same dimension 0. Input ", i, " has leading dimension ", ctx->input(i).dim_size(0), @@ -188,19 +176,17 @@ class MapDefunOp : public AsyncOpKernel { } } - std::vector<Tensor> args; std::vector<TensorShape> arg_shapes; - args.reserve(ctx->num_inputs()); - arg_shapes.reserve(ctx->num_inputs()); + arg_shapes.reserve(arguments.size()); - for (size_t i = 0; i < ctx->num_inputs(); ++i) { - args.push_back(ctx->input(i)); - arg_shapes.push_back(ctx->input(i).shape()); + for (size_t i = 0; i < arguments.size(); ++i) { + arg_shapes.push_back(arguments[i].shape()); arg_shapes.at(i).RemoveDim(0); } - *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes), - batch_size, output_shapes_); + *compute_opts = + new ComputeOptions(arguments, captured_inputs, std::move(arg_shapes), + batch_size, output_shapes_); return Status::OK(); } @@ -235,12 +221,21 @@ class MapDefunOp : public AsyncOpKernel { } Status GetArg(int index, Tensor* val) const override { - if (index < 0 || index >= compute_opts_->args.size()) { + if (index < 0 || index >= compute_opts_->args.size() + + compute_opts_->captured_inputs.size()) { return errors::InvalidArgument( "Mismatch in number of function inputs."); } + + if (index >= compute_opts_->args.size()) { + // The function is calling for a captured input + *val = + compute_opts_->captured_inputs[index - compute_opts_->args.size()]; + return Status::OK(); + } + bool result = - val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1), + val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1), compute_opts_->arg_shapes.at(index)); if (!result) { return errors::Internal("GetArg failed."); @@ -248,7 +243,6 @@ class MapDefunOp : public AsyncOpKernel { // Ensure alignment *val = tensor::DeepCopy(*val); } - return Status::OK(); } 
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc index 5f143967d9..d909b9e9d3 100644 --- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc +++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc @@ -134,19 +134,17 @@ class MultiDeviceIterator : public ResourceBase { void Reset() LOCKS_EXCLUDED(mu_) { { mutex_lock l(mu_); - if (background_thread_finished_) { - return; - } - - cancelled_ = true; - // Wake up the background thread. - for (int i = 0; i < size_; ++i) { - buffer_[i].cond_var.notify_all(); - } + if (!background_thread_finished_) { + cancelled_ = true; + // Wake up the background thread. + for (int i = 0; i < size_; ++i) { + buffer_[i].cond_var.notify_all(); + } - // Make sure background thread has finished first. - while (!background_thread_finished_) { - shutdown_cond_var_.wait(l); + // Make sure background thread has finished first. + while (!background_thread_finished_) { + shutdown_cond_var_.wait(l); + } } } RunPendingCallbacks(); @@ -182,7 +180,7 @@ class MultiDeviceIterator : public ResourceBase { buffer_[shard_num].cond_var.notify_all(); } } else { - if (background_thread_finished_) { + if (end_of_iterator_) { produced_output = true; elem.end_of_sequence = true; } else { @@ -219,8 +217,12 @@ class MultiDeviceIterator : public ResourceBase { while (!buffer_[i].callbacks.empty()) { if (buffer_[i].data.empty()) { HostBufferElement elem; - elem.status = - errors::Cancelled("Cancelled and buffer not filled."); + if (end_of_iterator_) { + elem.end_of_sequence = true; + } else { + elem.status = + errors::Cancelled("Cancelled and buffer not filled."); + } cancellation_elements.push_back(std::move(elem)); } else { cancellation_elements.push_back( @@ -293,6 +295,7 @@ class MultiDeviceIterator : public ResourceBase { { mutex_lock l(mu_); background_thread_finished_ = true; + end_of_iterator_ = true; shutdown_cond_var_.notify_all(); } RunPendingCallbacks(); @@ 
-312,6 +315,7 @@ class MultiDeviceIterator : public ResourceBase { std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_); bool background_thread_finished_ GUARDED_BY(mu_) = false; bool background_thread_started_ GUARDED_BY(mu_) = false; + bool end_of_iterator_ GUARDED_BY(mu_) = false; bool cancelled_ GUARDED_BY(mu_) = false; condition_variable shutdown_cond_var_ GUARDED_BY(mu_); diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc index d5b725eac9..1cb7caa738 100644 --- a/tensorflow/core/kernels/data/optimize_dataset_op.cc +++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc @@ -154,12 +154,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { : DatasetIterator<Dataset>(params) {} Status Initialize(IteratorContext* ctx) override { - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = *(ctx->runner()); - params.stats_aggregator_getter = ctx->stats_aggregator_getter(); + IteratorContext::Params params = ctx->params(); params.lib = dataset()->lib_; - params.allocator_getter = ctx->allocator_getter(); return dataset()->optimized_input_->MakeIterator( IteratorContext(params), prefix(), &input_impl_); } @@ -167,14 +163,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel { Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors, bool* end_of_sequence) override { - IteratorContext::Params params; - params.env = ctx->env(); - params.runner = *(ctx->runner()); - params.stats_aggregator_getter = ctx->stats_aggregator_getter(); + IteratorContext::Params params = ctx->params(); params.lib = dataset()->lib_; - params.allocator_getter = ctx->allocator_getter(); - IteratorContext iter_ctx(params); - return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence); + return input_impl_->GetNext(IteratorContext(params), out_tensors, + end_of_sequence); } protected: diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc 
b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 2e6e0465f7..6b6b3d6ab9 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -1084,6 +1084,9 @@ REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU), // The above design choices were made with automated optimizations in mind, // isolating the degree of parallelism as the single tunable knob of this // implementation. +// +// TODO(b/116852688): Make coordination between the performance model and this +// transformation more robust. class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { public: explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx) @@ -1214,7 +1217,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { public: explicit Iterator(const Params& params) : DatasetIterator<Dataset>(params), - num_parallel_calls_(params.dataset->num_parallel_calls_), + mu_(std::make_shared<mutex>()), + cond_var_(std::make_shared<condition_variable>()), + num_parallel_calls_(std::make_shared<model::SharedState>( + params.dataset->num_parallel_calls_, mu_, cond_var_)), args_list_(params.dataset->cycle_length_), current_elements_(params.dataset->cycle_length_), element_in_use_(params.dataset->cycle_length_, false), @@ -1224,25 +1230,24 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { false /* low_latency_hint */)) {} ~Iterator() override { - mutex_lock l(mu_); + mutex_lock l(*mu_); // Cancel the runner thread. cancelled_ = true; - cond_var_.notify_all(); + cond_var_->notify_all(); // Wait for all in-flight calls to complete. 
while (num_calls_ > 0) { - cond_var_.wait(l); + cond_var_->wait(l); } } Status Initialize(IteratorContext* ctx) override { - mutex_lock l(mu_); - if (num_parallel_calls_ == kAutoTune) { - num_parallel_calls_ = 1; - AddTunableParameter(ctx, "parallelism", - &num_parallel_calls_ /* value */, 1 /* min */, - dataset()->cycle_length_ /* max */, &cond_var_); + mutex_lock l(*mu_); + if (num_parallel_calls_->value == kAutoTune) { + num_parallel_calls_->value = 1; + AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1, + dataset()->cycle_length_); } else { - AddConstantParameter(ctx, "parallelism", num_parallel_calls_); + AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value); } AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_); TF_RETURN_IF_ERROR( @@ -1256,12 +1261,12 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { std::shared_ptr<InvocationResult> result; do { { - mutex_lock l(mu_); + mutex_lock l(*mu_); EnsureRunnerThreadStarted(ctx); while (invocation_results_.empty() && (!end_of_input_ || num_open_ > 0)) { RecordStop(ctx); - cond_var_.wait(l); + cond_var_->wait(l); RecordStart(ctx); } if (!invocation_results_.empty()) { @@ -1271,7 +1276,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { *end_of_sequence = true; return Status::OK(); } - cond_var_.notify_all(); + cond_var_->notify_all(); } RecordStop(ctx); result->notification.WaitForNotification(); @@ -1287,10 +1292,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { protected: Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); + mutex_lock l(*mu_); // Wait for all in-flight calls to complete. 
while (num_calls_ > 0) { - cond_var_.wait(l); + cond_var_->wait(l); } CHECK_EQ(num_calls_, 0); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); @@ -1328,7 +1333,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - mutex_lock l(mu_); + mutex_lock l(*mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); int64 invocation_results_size; TF_RETURN_IF_ERROR(reader->ReadScalar( @@ -1381,7 +1386,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { }; void EnsureRunnerThreadStarted(IteratorContext* ctx) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx)); runner_thread_.reset(ctx->env()->StartThread( @@ -1398,7 +1403,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { void FetchOutputs( const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index, const std::vector<std::shared_ptr<InvocationResult>>& results) - LOCKS_EXCLUDED(mu_) { + LOCKS_EXCLUDED(*mu_) { RecordStart(ctx.get()); auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); bool end_of_input = false; @@ -1421,14 +1426,14 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { if (end_of_input) { current_elements_[cycle_index].reset(); } - mutex_lock l(mu_); + mutex_lock l(*mu_); element_in_use_[cycle_index] = false; num_calls_--; if (end_of_input) { args_list_[cycle_index].clear(); num_open_--; } - cond_var_.notify_all(); + cond_var_->notify_all(); } // Method responsible for 1) creating iterators out of input elements, 2) @@ -1439,20 +1444,20 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) { RecordStart(ctx.get()); auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); - auto busy = [this]() 
EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool { + auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool { return element_in_use_[cycle_index_] || - num_calls_ >= num_parallel_calls_ || + num_calls_ >= num_parallel_calls_->value || invocation_results_.size() >= dataset()->cycle_length_ * dataset()->block_length_; }; while (true) { - mutex_lock l(mu_); + mutex_lock l(*mu_); // Wait until this thread is cancelled, the end of input has been // reached, or the cycle element at the `cycle_index_` position is // not in use and there is space in the `invocation_results_` queue. while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) { RecordStop(ctx.get()); - cond_var_.wait(l); + cond_var_->wait(l); RecordStart(ctx.get()); } @@ -1506,13 +1511,13 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { } cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_; } - cond_var_.notify_all(); + cond_var_->notify_all(); } } Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, const Status& status) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { TF_RETURN_IF_ERROR(writer->WriteScalar( CodeKey(index), static_cast<int64>(status.code()))); if (!status.ok()) { @@ -1523,7 +1528,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { } Status ReadStatusLocked(IteratorStateReader* reader, size_t index, - Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { int64 code_int; TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); error::Code code = static_cast<error::Code>(code_int); @@ -1550,7 +1555,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { } Status WriteCurrentElements(IteratorStateWriter* writer) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { for (int idx = 0; idx < current_elements_.size(); idx++) { if (current_elements_[idx]) { TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx])); @@ 
-1569,7 +1574,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { Status ReadCurrentElements(IteratorContext* ctx, IteratorStateReader* reader) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { for (int idx = 0; idx < current_elements_.size(); idx++) { if (reader->Contains( full_name(strings::StrCat("args_size[", idx, "]")))) { @@ -1597,7 +1602,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { // Used for coordination between the main thread, the runner thread, and // the worker threads. - mutex mu_; + const std::shared_ptr<mutex> mu_; // Used for coordination between the main thread, the runner thread, and // the worker threads. In particular, the runner thread should only @@ -1605,45 +1610,45 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel { // user specified level of parallelism, there are slots available in the // `invocation_results_` buffer, the current cycle element is not in use, // and there are elements left to be fetched. - condition_variable cond_var_; + const std::shared_ptr<condition_variable> cond_var_; // Identifies the maximum number of parallel calls. - std::atomic<int64> num_parallel_calls_; + const std::shared_ptr<model::SharedState> num_parallel_calls_; // Iterator for input elements. - std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_); + std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(*mu_); // Identifies current cycle element. int64 cycle_index_ = 0; // Arguments for creating an iterator for cycle elements. - std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_); + std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(*mu_); // Iterators for the current cycle elements. Concurrent access is // protected by `element_in_use_`. std::vector<std::unique_ptr<IteratorBase>> current_elements_; // Identifies cycle elements that are in use by worker threads. 
- std::vector<bool> element_in_use_ GUARDED_BY(mu_); + std::vector<bool> element_in_use_ GUARDED_BY(*mu_); // Buffer for storing the invocation results. std::deque<std::shared_ptr<InvocationResult>> invocation_results_ - GUARDED_BY(mu_); + GUARDED_BY(*mu_); // Identifies whether end of input has been reached. - bool end_of_input_ GUARDED_BY(mu_) = false; + bool end_of_input_ GUARDED_BY(*mu_) = false; // Identifies the number of open iterators. - int64 num_open_ GUARDED_BY(mu_) = 0; + int64 num_open_ GUARDED_BY(*mu_) = 0; // Identifies the number of outstanding calls. - int64 num_calls_ GUARDED_BY(mu_) = 0; + int64 num_calls_ GUARDED_BY(*mu_) = 0; std::unique_ptr<thread::ThreadPool> thread_pool_; - std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_); + std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); // Identifies whether background activity should be cancelled. - bool cancelled_ GUARDED_BY(mu_) = false; + bool cancelled_ GUARDED_BY(*mu_) = false; }; const DatasetBase* const input_; diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc index 6abe6c8338..3a14924fba 100644 --- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/captured_function.h" #include "tensorflow/core/kernels/data/dataset.h" +#include "tensorflow/core/kernels/data/dataset_utils.h" #include "tensorflow/core/kernels/data/parallel_map_iterator.h" #include "tensorflow/core/lib/core/error_codes.pb.h" #include "tensorflow/core/lib/random/random.h" @@ -56,9 +57,55 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { use_inter_op_parallelism_, &captured_func)); + std::vector<int> indices; + OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices)); + + ParallelMapIteratorFunction map_func; + CapturedFunction* raw_captured_func = captured_func.get(); + if (indices.empty()) { + map_func = [raw_captured_func](IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, + std::vector<Tensor>* out_tensors, + StatusCallback done) { + raw_captured_func->RunAsync(ctx, std::move(args), out_tensors, + std::move(done), prefix); + }; + if (!use_inter_op_parallelism_) { + map_func = [map_func](IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, + std::vector<Tensor>* out_tensors, + StatusCallback done) { + (*ctx->runner())(std::bind(map_func, ctx, prefix, std::move(args), + out_tensors, std::move(done))); + }; + } + } else { + std::vector<bool> can_move = ComputeMoveVector(indices); + map_func = [raw_captured_func, indices, can_move]( + IteratorContext* ctx, const string& prefix, + std::vector<Tensor> args, std::vector<Tensor>* out_tensors, + StatusCallback done) { + const std::vector<Tensor>& captured_inputs = + raw_captured_func->captured_inputs(); + size_t num_args = args.size(); + for (size_t i = 0; i < indices.size(); ++i) { + if (indices[i] < num_args) { + if (can_move[i]) { + out_tensors->push_back(std::move(args[indices[i]])); + } else { + out_tensors->push_back(args[indices[i]]); + } + } else { + out_tensors->push_back(captured_inputs[indices[i] - num_args]); + } + } + done(Status::OK()); + }; 
+ } + *output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_, output_shapes_, use_inter_op_parallelism_, - std::move(captured_func)); + std::move(captured_func), std::move(map_func)); } private: @@ -69,7 +116,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const DataTypeVector& output_types, const std::vector<PartialTensorShape>& output_shapes, bool use_inter_op_parallelism, - std::unique_ptr<CapturedFunction> captured_func) + std::unique_ptr<CapturedFunction> captured_func, + ParallelMapIteratorFunction map_func) : DatasetBase(DatasetContext(ctx)), input_(input), func_(func), @@ -77,7 +125,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { output_types_(output_types), output_shapes_(output_shapes), use_inter_op_parallelism_(use_inter_op_parallelism), - captured_func_(std::move(captured_func)) { + captured_func_(std::move(captured_func)), + map_func_(std::move(map_func)) { input_->Ref(); } @@ -89,26 +138,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { return captured_func_->Instantiate(ctx); }; - const string& new_prefix = strings::StrCat(prefix, "::ParallelMap"); - ParallelMapIteratorFunction map_func = - [this, new_prefix](IteratorContext* ctx, - std::vector<Tensor> input_element, - std::vector<Tensor>* result, StatusCallback done) { - captured_func_->RunAsync(ctx, std::move(input_element), result, - std::move(done), new_prefix); - }; - if (!use_inter_op_parallelism_) { - map_func = [map_func]( - IteratorContext* ctx, std::vector<Tensor> input_element, - std::vector<Tensor>* result, StatusCallback done) { - (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element), - result, std::move(done))); - }; - } - - return NewParallelMapIterator({this, new_prefix}, input_, - std::move(init_func), std::move(map_func), - num_parallel_calls_); + return NewParallelMapIterator( + {this, strings::StrCat(prefix, "::ParallelMap")}, input_, + std::move(init_func), map_func_, num_parallel_calls_); } const 
DataTypeVector& output_dtypes() const override { @@ -176,6 +208,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel { const std::vector<PartialTensorShape> output_shapes_; const bool use_inter_op_parallelism_; const std::unique_ptr<CapturedFunction> captured_func_; + const ParallelMapIteratorFunction map_func_; }; DataTypeVector output_types_; diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc index ee20249bfe..ebf41925c9 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.cc +++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc @@ -22,11 +22,14 @@ limitations under the License. #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/util/ptr_util.h" namespace tensorflow { namespace data { namespace { +// TODO(b/116852688): Make coordination between the performance model and this +// transformation more robust. class ParallelMapIterator : public DatasetBaseIterator { public: explicit ParallelMapIterator( @@ -38,30 +41,32 @@ class ParallelMapIterator : public DatasetBaseIterator { input_dataset_(input_dataset), init_func_(std::move(init_func)), map_func_(std::move(map_func)), - num_parallel_calls_(num_parallel_calls) {} + mu_(std::make_shared<mutex>()), + cond_var_(std::make_shared<condition_variable>()), + num_parallel_calls_(std::make_shared<model::SharedState>( + num_parallel_calls, mu_, cond_var_)) {} ~ParallelMapIterator() override { - mutex_lock l(mu_); + mutex_lock l(*mu_); // Cancel the runner thread. cancelled_ = true; - cond_var_.notify_all(); + cond_var_->notify_all(); // Wait for all in-flight calls to complete. 
while (num_calls_ > 0) { - cond_var_.wait(l); + cond_var_->wait(l); } } Status Initialize(IteratorContext* ctx) override { - mutex_lock l(mu_); - if (num_parallel_calls_ == kAutoTune) { - num_parallel_calls_ = 1; + mutex_lock l(*mu_); + if (num_parallel_calls_->value == kAutoTune) { + num_parallel_calls_->value = 1; // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and // use it here for the maximum. - AddTunableParameter(ctx, "parallelism", &num_parallel_calls_ /* value */, - 1 /* min */, port::NumSchedulableCPUs() /* max */, - &cond_var_); + AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1, + port::NumSchedulableCPUs()); } else { - AddConstantParameter(ctx, "parallelism", num_parallel_calls_); + AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value); } TF_RETURN_IF_ERROR( input_dataset_->MakeIterator(ctx, prefix(), &input_impl_)); @@ -75,16 +80,16 @@ class ParallelMapIterator : public DatasetBaseIterator { bool* end_of_sequence) override { std::shared_ptr<InvocationResult> result; { - mutex_lock l(mu_); + mutex_lock l(*mu_); EnsureRunnerThreadStarted(ctx); while (invocation_results_.empty()) { RecordStop(ctx); - cond_var_.wait(l); + cond_var_->wait(l); RecordStart(ctx); } std::swap(result, invocation_results_.front()); invocation_results_.pop_front(); - cond_var_.notify_all(); + cond_var_->notify_all(); } RecordStop(ctx); result->notification.WaitForNotification(); @@ -94,28 +99,27 @@ class ParallelMapIterator : public DatasetBaseIterator { protected: Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); + mutex_lock l(*mu_); // Wait for all in-flight calls to complete. 
while (num_calls_ > 0) { - cond_var_.wait(l); + cond_var_->wait(l); } CHECK_EQ(num_calls_, 0); TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"), invocation_results_.size())); for (size_t i = 0; i < invocation_results_.size(); i++) { - std::shared_ptr<InvocationResult> result = invocation_results_[i]; - TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status)); + const auto& result = *(invocation_results_[i]); + TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status)); TF_RETURN_IF_ERROR(writer->WriteScalar( full_name(strings::StrCat("invocation_results[", i, "].size")), - result->return_values.size())); - for (size_t j = 0; j < result->return_values.size(); j++) { - TF_RETURN_IF_ERROR( - writer->WriteTensor(full_name(strings::StrCat( - "invocation_results[", i, "][", j, "]")), - result->return_values[j])); + result.return_values.size())); + for (size_t j = 0; j < result.return_values.size(); j++) { + TF_RETURN_IF_ERROR(writer->WriteTensor( + full_name(strings::StrCat("invocation_results[", i, "][", j, "]")), + result.return_values[j])); } - if (result->end_of_input) { + if (result.end_of_input) { TF_RETURN_IF_ERROR(writer->WriteScalar( full_name( strings::StrCat("invocation_results[", i, "].end_of_input")), @@ -127,15 +131,15 @@ class ParallelMapIterator : public DatasetBaseIterator { Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - mutex_lock l(mu_); + mutex_lock l(*mu_); TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); int64 invocation_results_size; TF_RETURN_IF_ERROR(reader->ReadScalar( full_name("invocation_results.size"), &invocation_results_size)); for (size_t i = 0; i < invocation_results_size; i++) { - std::shared_ptr<InvocationResult> result(new InvocationResult()); - invocation_results_.push_back(result); - TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status)); + 
invocation_results_.push_back(std::make_shared<InvocationResult>()); + auto& result = *invocation_results_.back(); + TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status)); size_t num_return_values; { int64 size; @@ -151,17 +155,16 @@ class ParallelMapIterator : public DatasetBaseIterator { ": ", size, " is not a valid value of type size_t.")); } } - result->return_values.reserve(num_return_values); + result.return_values.reserve(num_return_values); for (size_t j = 0; j < num_return_values; j++) { - result->return_values.emplace_back(); - TF_RETURN_IF_ERROR( - reader->ReadTensor(full_name(strings::StrCat( - "invocation_results[", i, "][", j, "]")), - &result->return_values.back())); + result.return_values.emplace_back(); + TF_RETURN_IF_ERROR(reader->ReadTensor( + full_name(strings::StrCat("invocation_results[", i, "][", j, "]")), + &result.return_values.back())); } - result->end_of_input = reader->Contains(full_name( + result.end_of_input = reader->Contains(full_name( strings::StrCat("invocation_results[", i, "].end_of_input"))); - result->notification.Notify(); + result.notification.Notify(); } return Status::OK(); } @@ -175,9 +178,9 @@ class ParallelMapIterator : public DatasetBaseIterator { }; void EnsureRunnerThreadStarted(IteratorContext* ctx) - EXCLUSIVE_LOCKS_REQUIRED(mu_) { + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { if (!runner_thread_) { - std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx)); + auto ctx_copy = std::make_shared<IteratorContext>(*ctx); runner_thread_.reset(ctx->env()->StartThread( {}, "runner_thread", std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy))); @@ -185,18 +188,18 @@ class ParallelMapIterator : public DatasetBaseIterator { } void CallCompleted(const std::shared_ptr<InvocationResult>& result) - LOCKS_EXCLUDED(mu_) { + LOCKS_EXCLUDED(*mu_) { { - mutex_lock l(mu_); + mutex_lock l(*mu_); num_calls_--; - cond_var_.notify_all(); + cond_var_->notify_all(); } result->notification.Notify(); } void 
CallFunction(const std::shared_ptr<IteratorContext>& ctx, const std::shared_ptr<InvocationResult>& result) - LOCKS_EXCLUDED(mu_) { + LOCKS_EXCLUDED(*mu_) { // Get the next input element. std::vector<Tensor> input_element; result->status = @@ -206,15 +209,15 @@ class ParallelMapIterator : public DatasetBaseIterator { return; } - // Call `func_(input_element)`, store the result in `result->return_values`, - // and notify `result->notification` to unblock a consumer. auto done = [this, result](Status status) { result->status.Update(status); CallCompleted(result); }; - map_func_(ctx.get(), std::move(input_element), &result->return_values, - std::move(done)); + // Apply the map function on `input_element`, storing the result in + // `result->return_values`, and invoking `done` when finished. + map_func_(ctx.get(), prefix(), std::move(input_element), + &result->return_values, std::move(done)); } Status ProcessResult(const std::shared_ptr<InvocationResult>& result, @@ -239,29 +242,29 @@ class ParallelMapIterator : public DatasetBaseIterator { RecordStart(ctx.get()); auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); }); std::vector<std::shared_ptr<InvocationResult>> new_calls; - new_calls.reserve(num_parallel_calls_); - auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool { - int64 num_parallel_calls = num_parallel_calls_; + new_calls.reserve(num_parallel_calls_->value); + auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool { + int64 num_parallel_calls = num_parallel_calls_->value; return num_calls_ >= num_parallel_calls || invocation_results_.size() >= num_parallel_calls; }; while (true) { { - mutex_lock l(mu_); + mutex_lock l(*mu_); while (!cancelled_ && busy()) { RecordStop(ctx.get()); - cond_var_.wait(l); + cond_var_->wait(l); RecordStart(ctx.get()); } if (cancelled_) { return; } while (!busy()) { - invocation_results_.emplace_back(new InvocationResult()); + invocation_results_.push_back(std::make_shared<InvocationResult>()); 
new_calls.push_back(invocation_results_.back()); num_calls_++; } - cond_var_.notify_all(); + cond_var_->notify_all(); } for (const auto& call : new_calls) { CallFunction(ctx, call); @@ -271,7 +274,8 @@ class ParallelMapIterator : public DatasetBaseIterator { } Status WriteStatusLocked(IteratorStateWriter* writer, size_t index, - const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + const Status& status) + EXCLUSIVE_LOCKS_REQUIRED(*mu_) { TF_RETURN_IF_ERROR( writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code()))); if (!status.ok()) { @@ -282,7 +286,7 @@ class ParallelMapIterator : public DatasetBaseIterator { } Status ReadStatusLocked(IteratorStateReader* reader, size_t index, - Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) { int64 code_int; TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int)); error::Code code = static_cast<error::Code>(code_int); @@ -312,23 +316,23 @@ class ParallelMapIterator : public DatasetBaseIterator { const std::function<Status(IteratorContext*)> init_func_; const ParallelMapIteratorFunction map_func_; // Used for coordination between the main thread and the runner thread. - mutex mu_; + const std::shared_ptr<mutex> mu_; // Used for coordination between the main thread and the runner thread. In // particular, the runner thread should only schedule new calls when the // number of in-flight calls is less than the user specified level of // parallelism and there are slots available in the `invocation_results_` // buffer. - condition_variable cond_var_; + const std::shared_ptr<condition_variable> cond_var_; // Identifies the maximum number of parallel calls. - std::atomic<int64> num_parallel_calls_; + const std::shared_ptr<model::SharedState> num_parallel_calls_; // Counts the number of outstanding calls. 
- int64 num_calls_ GUARDED_BY(mu_) = 0; + int64 num_calls_ GUARDED_BY(*mu_) = 0; std::unique_ptr<IteratorBase> input_impl_; // Buffer for storing the invocation results. std::deque<std::shared_ptr<InvocationResult>> invocation_results_ - GUARDED_BY(mu_); - std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_); - bool cancelled_ GUARDED_BY(mu_) = false; + GUARDED_BY(*mu_); + std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_); + bool cancelled_ GUARDED_BY(*mu_) = false; }; } // namespace @@ -346,9 +350,9 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator( const DatasetBase* input_dataset, std::function<Status(IteratorContext*)> init_func, ParallelMapIteratorFunction map_func, int32 num_parallel_calls) { - return std::unique_ptr<IteratorBase>( - new ParallelMapIterator(params, input_dataset, std::move(init_func), - std::move(map_func), num_parallel_calls)); + return MakeUnique<ParallelMapIterator>( + params, input_dataset, std::move(init_func), std::move(map_func), + num_parallel_calls); } } // namespace data diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h index dc26c5cf25..813f13c9e4 100644 --- a/tensorflow/core/kernels/data/parallel_map_iterator.h +++ b/tensorflow/core/kernels/data/parallel_map_iterator.h @@ -30,7 +30,7 @@ namespace data { // 3. A `std::vector<Tensor>*` to which the function will write the result. // 4. A `StatusCallback` that should be invoked when the function is complete. 
using ParallelMapIteratorFunction = - std::function<void(IteratorContext*, std::vector<Tensor>, + std::function<void(IteratorContext*, const string&, std::vector<Tensor>, std::vector<Tensor>*, StatusCallback)>; // Returns a new iterator that applies `map_func` to the elements of diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc index c28c06da62..7de5ea8860 100644 --- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc +++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc @@ -182,7 +182,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { std::unique_ptr<IteratorBase> MakeIteratorInternal( const string& prefix) const override { - auto map_fn = [this](IteratorContext* ctx, + auto map_fn = [this](IteratorContext* ctx, const string& prefix, std::vector<Tensor> input_element, std::vector<Tensor>* result, StatusCallback done) { (*ctx->runner())([this, ctx, input_element, result, done]() { @@ -253,7 +253,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { for (example::PerExampleFeatureStats feature_stats : example_result.feature_stats) { stats_aggregator->AddToHistogram( - strings::StrCat("record_stats", ":features"), + "features", {static_cast<double>(feature_stats.features_count)}); stats_aggregator->IncrementCounter( "features_count", "trainer", feature_stats.features_count); @@ -261,7 +261,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel { "feature_values_count", "trainer", feature_stats.feature_values_count); stats_aggregator->AddToHistogram( - strings::StrCat("record_stats", ":feature-values"), + "feature-values", {static_cast<double>(feature_stats.feature_values_count)}); } } diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc index dbe31f37b8..2a911aa368 100644 --- a/tensorflow/core/kernels/data/scan_dataset_op.cc +++ b/tensorflow/core/kernels/data/scan_dataset_op.cc @@ 
-32,8 +32,7 @@ namespace { class ScanDatasetOp : public UnaryDatasetOpKernel { public: explicit ScanDatasetOp(OpKernelConstruction* ctx) - : UnaryDatasetOpKernel(ctx), - graph_def_version_(ctx->graph_def_version()) { + : UnaryDatasetOpKernel(ctx) { OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("Tstate", &state_types_)); OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_)); @@ -258,7 +257,6 @@ class ScanDatasetOp : public UnaryDatasetOpKernel { const std::vector<PartialTensorShape> output_shapes_; }; - const int graph_def_version_; DataTypeVector state_types_; DataTypeVector output_types_; std::vector<PartialTensorShape> output_shapes_; diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc index f5314f7a75..c09a73fff1 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc @@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include <memory> #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/stats_aggregator.h" #include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/graph/graph_def_builder.h" #include "tensorflow/core/kernels/data/dataset.h" #include "tensorflow/core/lib/random/random.h" @@ -22,6 +24,52 @@ namespace tensorflow { namespace data { namespace { +class StatsAggregatorWithTagAndPrefix : public StatsAggregator { + public: + StatsAggregatorWithTagAndPrefix( + std::shared_ptr<StatsAggregator> stats_aggregator, const string& tag, + const string& prefix) + : wrapped_(stats_aggregator), tag_(tag), prefix_(prefix) {} + + void AddToHistogram(const string& name, + gtl::ArraySlice<double> values) override { + if (!tag_.empty()) { + wrapped_->AddToHistogram(strings::StrCat(tag_, "_", name), values); + } else { + wrapped_->AddToHistogram(name, values); + } + } + + void AddScalar(const string& name, float value) override { + if (!tag_.empty()) { + wrapped_->AddScalar(strings::StrCat(tag_, "_", name), value); + } else { + wrapped_->AddScalar(name, value); + } + } + + void EncodeToProto(Summary* out_summary) override { + wrapped_->EncodeToProto(out_summary); + } + + void IncrementCounter(const string& name, const string& label, + int64 val) override { + if (!prefix_.empty()) { + wrapped_->IncrementCounter(strings::StrCat(prefix_, "/", name), label, + val); + } else { + wrapped_->IncrementCounter(strings::StrCat("/tensorflow/", name), label, + val); + } + } + + private: + std::shared_ptr<StatsAggregator> wrapped_; + string tag_; + string prefix_; + TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorWithTagAndPrefix); +}; + class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { public: explicit SetStatsAggregatorDatasetOp(OpKernelConstruction* ctx) @@ -33,18 +81,28 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { 
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &stats_aggregator_resource)); core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource); + string tag; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag)); + string prefix; + OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix)); - *output = new Dataset(ctx, input, stats_aggregator_resource); + *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource, + tag, prefix); } private: class Dataset : public DatasetBase { public: explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, - StatsAggregatorResource* stats_aggregator_resource) + const Tensor& resource_handle, + StatsAggregatorResource* stats_aggregator_resource, + const string& tag, const string& prefix) : DatasetBase(DatasetContext(ctx)), input_(input), - stats_aggregator_resource_(stats_aggregator_resource) { + resource_handle_(resource_handle), + stats_aggregator_resource_(stats_aggregator_resource), + tag_(tag), + prefix_(prefix) { input_->Ref(); stats_aggregator_resource_->Ref(); } @@ -75,8 +133,18 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { Status AsGraphDefInternal(SerializationContext* ctx, DatasetGraphDefBuilder* b, Node** output) const override { - return errors::Unimplemented("%s does not support serialization", - DebugString()); + Node* input_graph_node = nullptr; + TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node)); + Node* resource_handle_node = nullptr; + TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node)); + Node* tag_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node)); + Node* prefix_node = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(prefix_, &prefix_node)); + TF_RETURN_IF_ERROR(b->AddDataset( + this, {input_graph_node, resource_handle_node, tag_node, prefix_node}, + output)); + return Status::OK(); } private: @@ -98,9 +166,10 @@ class SetStatsAggregatorDatasetOp : public 
UnaryDatasetOpKernel { IteratorContext::Params params; params.env = ctx->env(); params.runner = *(ctx->runner()); - params.stats_aggregator_getter = [stats_aggregator_resource]() { - return stats_aggregator_resource->stats_aggregator(); - }; + params.stats_aggregator = std::shared_ptr<StatsAggregator>( + new StatsAggregatorWithTagAndPrefix( + stats_aggregator_resource->stats_aggregator(), dataset()->tag_, + dataset()->prefix_)); params.lib = ctx->lib(); params.function_library = ctx->function_library(); params.allocator_getter = ctx->allocator_getter(); @@ -111,16 +180,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { protected: Status SaveInternal(IteratorStateWriter* writer) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_)); - return Status::OK(); + return errors::Unimplemented(dataset()->DebugString(), + " does not support checkpointing"); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - mutex_lock l(mu_); - TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_)); - return Status::OK(); + return errors::Unimplemented(dataset()->DebugString(), + " does not support checkpointing"); } private: @@ -129,7 +196,10 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { }; const DatasetBase* const input_; + const Tensor resource_handle_; StatsAggregatorResource* stats_aggregator_resource_; + string tag_; + string prefix_; }; }; diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc index a7ded67876..2d51467616 100644 --- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc +++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc @@ -82,11 +82,12 @@ class StatsAggregatorImpl : public StatsAggregator { auto counters_map = get_counters_map(); if (counters_map->find(name) == counters_map->end()) { counters_map->emplace( - name, monitoring::Counter<1>::New( - /*streamz name*/ "/tensorflow/" + 
name, - /*streamz description*/ - name + " generated or consumed by the component.", - /*streamz label name*/ "component_descriptor")); + name, + monitoring::Counter<1>::New( + /*streamz name*/ name, + /*streamz description*/ + strings::StrCat(name, " generated or consumed by the component."), + /*streamz label name*/ "component_descriptor")); } counters_map->at(name)->GetCell(label)->IncrementBy(val); } diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc index 81c432b938..74908994b4 100644 --- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc +++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc @@ -41,11 +41,16 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel { : DatasetBase(DatasetContext(ctx)), input_(input) { input_->Ref(); for (const PartialTensorShape& shape : input->output_shapes()) { - gtl::InlinedVector<int64, 4> partial_dim_sizes; - for (int i = 1; i < shape.dims(); ++i) { - partial_dim_sizes.push_back(shape.dim_size(i)); + if (!shape.unknown_rank()) { + gtl::InlinedVector<int64, 4> partial_dim_sizes; + for (int i = 1; i < shape.dims(); ++i) { + partial_dim_sizes.push_back(shape.dim_size(i)); + } + shapes_.emplace_back(std::move(partial_dim_sizes)); + } else { + // If the input shape is unknown, the output shape will be unknown. + shapes_.emplace_back(); } - shapes_.emplace_back(std::move(partial_dim_sizes)); } } diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc index 42fbf95cd3..28940e0849 100644 --- a/tensorflow/core/kernels/dequantize_op.cc +++ b/tensorflow/core/kernels/dequantize_op.cc @@ -96,8 +96,6 @@ class DequantizeOp : public OpKernel { output); } } else if (mode_ == QUANTIZE_MODE_SCALED) { - // TODO(pauldonnelly): Update QuantizeAndDequantizeV2 and - // QuantizeAndDequantizeV3 to match this SCALED mode again. const float scale_factor = std::numeric_limits<T>::min() == 0 ? 
(max_range / std::numeric_limits<T>::max()) diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc index c90ad2cfeb..ada1235449 100644 --- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc +++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc @@ -31,9 +31,37 @@ class FuzzParseTensor : public FuzzSession { } void FuzzImpl(const uint8_t* data, size_t size) final { + // We need to be sure that we don't request too many elements (i.e., we + // don't make ASAN OOM). In theory, a tensor shape can have arbitrary large + // number of elements, up to the limit of the memory available to the OS. + // However, due to the tracing done in ASAN, after 2^32 bytes of requested + // memory we would get a crash in the fuzzer (see b/34190148). Hence, let's + // try parsing the proto here, check that the size (if valid) is below a + // maximum threshold (using 2^20 for convenience), and then run the + // remainder of the fuzzer testing. Of course, this duplicates some work + // but it's better than repeating the investigation whenever Autofuzz + // detects another similar OOM. 
+ string as_string = string(reinterpret_cast<const char*>(data), size); + TensorProto proto; + if (!ParseProtoUnlimited(&proto, as_string)) { + LOG(WARNING) << "Unable to parse proto of tensor\n"; + return; + } + if (!TensorShape::IsValid(proto.tensor_shape())) { + LOG(WARNING) << "Invalid tensor shape\n"; + return; + } + TensorShape shape(proto.tensor_shape()); + const int64 num_elements = shape.num_elements(); + const int64 max_num_elements = 1 << 20; + if (num_elements > max_num_elements) { + LOG(WARNING) << "Requiring a tensor with too many elements\n"; + return; + } + + // Now we can do the actual fuzz implementation Tensor input_tensor(tensorflow::DT_STRING, TensorShape({})); - input_tensor.scalar<string>()() = - string(reinterpret_cast<const char*>(data), size); + input_tensor.scalar<string>()() = as_string; // TODO(b/32704451): Don't just ignore the ::tensorflow::Status object! RunOneInput(input_tensor).IgnoreError(); } diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h index 277ee2be02..1c78de253e 100644 --- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h +++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h @@ -114,7 +114,7 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> { generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator( slice_size, Tindices, Tparams, Tout, &error_loc); -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) // Eigen implementation below is not highly performant. gather_nd_generator // does not seem to be called in parallel, leading to very poor performance. 
// Additionally, since it uses scalar (Tscratch) to invoke 'generate', it @@ -126,12 +126,12 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> { const Eigen::array<Eigen::DenseIndex, 1> loc{i}; gather_nd_generator(loc); } -#else // INTEL_MKL +#else // INTEL_MKL && ENABLE_MKL Tscratch.device(d) = Tscratch.reshape(reshape_dims) .broadcast(broadcast_dims) .generate(gather_nd_generator) .sum(); -#endif +#endif // INTEL_MKL && ENABLE_MKL // error_loc() returns -1 if there's no out-of-bounds index, // otherwise it returns the location of an OOB index in Tindices. diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc index 79967aab38..4ad390a411 100644 --- a/tensorflow/core/kernels/matmul_op.cc +++ b/tensorflow/core/kernels/matmul_op.cc @@ -578,7 +578,7 @@ struct MatMulFunctor<SYCLDevice, T> { .Label("cublas"), \ MatMulOp<GPUDevice, T, true /* cublas */>) -#if defined(INTEL_MKL) +#if defined(INTEL_MKL) && defined(ENABLE_MKL) // MKL does not support half, bfloat16 and int32 types for // matrix-multiplication, so register the kernel to use default Eigen based @@ -606,9 +606,9 @@ TF_CALL_double(REGISTER_CPU); TF_CALL_complex64(REGISTER_CPU_EIGEN); TF_CALL_complex128(REGISTER_CPU_EIGEN); TF_CALL_double(REGISTER_CPU_EIGEN); -#endif +#endif // INTEL_MKL_DNN_ONLY -#else // INTEL MKL +#else // INTEL_MKL && ENABLE_MKL TF_CALL_float(REGISTER_CPU); TF_CALL_double(REGISTER_CPU); TF_CALL_half(REGISTER_CPU); @@ -616,7 +616,7 @@ TF_CALL_bfloat16(REGISTER_CPU); TF_CALL_int32(REGISTER_CPU); TF_CALL_complex64(REGISTER_CPU); TF_CALL_complex128(REGISTER_CPU); -#endif +#endif // INTEL_MKL && ENABLE_MKL #if GOOGLE_CUDA TF_CALL_float(REGISTER_GPU); diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc index 0841395dc3..bc135de11e 100644 --- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc @@ -223,10 +223,12 @@ class BatchMatMulMkl : public OpKernel { 
Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ BatchMatMulMkl<CPUDevice, TYPE>) +#ifdef ENABLE_MKL TF_CALL_float(REGISTER_BATCH_MATMUL_MKL); TF_CALL_double(REGISTER_BATCH_MATMUL_MKL); TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL); TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL); +#endif // ENABLE_MKL } // end namespace tensorflow #endif diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc index 52157ed5fb..f406ad2ab5 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc @@ -853,7 +853,7 @@ class MklConvCustomBackpropFilterOp // MKL DNN allocates large buffers when a conv gradient filter primtive is // created. So we don't cache conv backward primitives when the env - // variable TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is set to true. + // variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true. bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled(); conv_bwd_filter = MklConvBwdFilterPrimitiveFactory<T>::Get( convBwdFilterDims, do_not_cache); diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc index c38c9cc27c..a501ce2c93 100644 --- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc @@ -713,7 +713,7 @@ class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> { TFPaddingToMklDnnPadding(this->padding_)); // We don't cache those primitves if the env variable - // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true and if primitve descriptor + // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if primitve descriptor // includes potentialy large buffers. MKL DNN allocates buffers // in the following cases // 1. 
Legacy CPU without AVX512/AVX2, or diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc index 184e0cb003..b332edad0a 100644 --- a/tensorflow/core/kernels/mkl_conv_ops.cc +++ b/tensorflow/core/kernels/mkl_conv_ops.cc @@ -901,7 +901,7 @@ class MklConvOp : public OpKernel { // In some cases, primitve descriptor includes potentialy large buffers, // we don't cache those primitves if the env variable - // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true. MKL DNN allocates buffers + // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true. MKL DNN allocates buffers // in the following cases // 1. Legacy CPU without AVX512/AVX2, or // 2. 1x1 convolution with stride != 1 diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc index 077d62ce32..f4788f4851 100644 --- a/tensorflow/core/kernels/mkl_matmul_op.cc +++ b/tensorflow/core/kernels/mkl_matmul_op.cc @@ -217,7 +217,7 @@ class MklMatMulOp : public OpKernel { reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta, reinterpret_cast<MKL_Complex16*>(c), ldc); } -#endif +#endif // !INTEL_MKL_DNN_ONLY }; #define REGISTER_CPU(T) \ @@ -225,6 +225,7 @@ class MklMatMulOp : public OpKernel { Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \ MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>); +#ifdef ENABLE_MKL // TODO(inteltf) Consider template specialization when adding/removing // additional types TF_CALL_float(REGISTER_CPU); @@ -233,7 +234,8 @@ TF_CALL_float(REGISTER_CPU); TF_CALL_double(REGISTER_CPU); TF_CALL_complex64(REGISTER_CPU); TF_CALL_complex128(REGISTER_CPU); -#endif +#endif // !INTEL_MKL_DNN_ONLY +#endif // ENABLE_MKL } // namespace tensorflow #endif // INTEL_MKL diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc new file mode 100644 index 0000000000..d63e14adf6 --- /dev/null +++ b/tensorflow/core/kernels/mkl_slice_op.cc @@ -0,0 +1,358 @@ +/* Copyright 2018 The TensorFlow Authors. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// See docs in ../ops/array_ops.cc. + +#ifdef INTEL_MKL +#ifndef INTEL_MKL_ML_ONLY + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/kernels/ops_util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/prefetch.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +#include "mkldnn.hpp" +#include "tensorflow/core/util/mkl_util.h" + +using mkldnn::stream; +using mkldnn::view; + +namespace tensorflow { + +namespace { + +gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) { + gtl::InlinedVector<int64, 4> out; + if (tensor.dtype() == DT_INT32) { + for (int64 i = 0; i < tensor.NumElements(); ++i) { + out.push_back(tensor.flat<int32>()(i)); + } + } else if (tensor.dtype() == DT_INT64) { + for (int64 i = 0; i < tensor.NumElements(); ++i) { + out.push_back(tensor.flat<int64>()(i)); + } + } else { + // tensor must be either int32 or int64 + DCHECK(false); + } + return out; +} + +} // namespace + +typedef Eigen::ThreadPoolDevice CPUDevice; + +// A version of SharedValidation (slice_op.h) written for input that is in +// either Mkl layout or Tensorflow layout. 
+// A shared code to validate input shapes and check for identity, which is not dependent on the type of T. +// We do this to reduce code size by not duplicating all this for all T (float, double, int32, etc.) +static void ValidateMklInputs(OpKernelContext* context, bool* is_identity, + gtl::InlinedVector<int64, 4>* begin, + gtl::InlinedVector<int64, 4>* size) { + const int kInputTensorIndex = 0; + const int kInputBeginIndex = 1; + const int kInputSizeIndex = 2; + const Tensor& input = MklGetInput(context, kInputTensorIndex); + const Tensor& begin_tensor = MklGetInput(context, kInputBeginIndex); + const Tensor& size_tensor = MklGetInput(context, kInputSizeIndex); + + MklDnnShape input_mkl_shape, begin_mkl_shape, size_mkl_shape; + GetMklShape(context, kInputTensorIndex, &input_mkl_shape); + GetMklShape(context, kInputBeginIndex, &begin_mkl_shape); + GetMklShape(context, kInputSizeIndex, &size_mkl_shape); + + // Begin and size tensors cannot be in MklDnn layout. + DCHECK_EQ(begin_mkl_shape.IsMklTensor(), false); + DCHECK_EQ(size_mkl_shape.IsMklTensor(), false); + + TensorShape input_tf_shape = input_mkl_shape.IsMklTensor() + ? input_mkl_shape.GetTfShape() + : input.shape(); + const int input_dims = input_tf_shape.dims(); + + OP_REQUIRES( + context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) && + context->op_kernel().IsLegacyVector(size_tensor.shape()) && + begin_tensor.NumElements() == input_dims && + size_tensor.NumElements() == input_dims, + errors::InvalidArgument( + "Expected begin and size arguments to be 1-D tensors of size ", + input_dims, ", but got shapes ", begin_tensor.shape().DebugString(), + " and ", size_tensor.shape().DebugString(), " instead.")); + + *begin = IntTensorToInt64Vec(begin_tensor); + *size = IntTensorToInt64Vec(size_tensor); + for (int i = 0; i < input_dims; ++i) { + if ((*size)[i] == -1) { + // A size[i] of -1 means "all elements from begin[i] to dim_size(i)". 
+ (*size)[i] = input_tf_shape.dim_size(i) - (*begin)[i]; + } + } + + *is_identity = true; + for (int i = 0; i < input_dims; ++i) { + int64 b = (*begin)[i]; + int64 s = (*size)[i]; + if (input_tf_shape.dim_size(i) == 0) { + OP_REQUIRES( + context, b == 0 && s == 0, + errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b, + ") and size[", i, "] == 0 ", "(got ", s, + ") when ", "input.dim_size(", i, ") == 0")); + } else { + OP_REQUIRES(context, 0 <= b && b <= input_tf_shape.dim_size(i), + errors::InvalidArgument("Expected begin[", i, "] in [0, ", + input_tf_shape.dim_size(i), + "], but got ", b)); + OP_REQUIRES(context, 0 <= s && b + s <= input_tf_shape.dim_size(i), + errors::InvalidArgument("Expected size[", i, "] in [0, ", + input_tf_shape.dim_size(i) - b, + "], but ", "got ", s)); + } + const bool take_all = (b == 0) && (s == input_tf_shape.dim_size(i)); + (*is_identity) &= take_all; + } +} + +// A version of SharedSliceCommonCases function written for input tensor +// that may be in MklDnn layout or in Tensorflow layout. +template <typename T> +static void CheckCommonCasesForMklInputs(OpKernelContext* context, + gtl::InlinedVector<int64, 4>* begin, + gtl::InlinedVector<int64, 4>* size, + bool* done) { + bool is_identity = true; + *done = false; + + ValidateMklInputs(context, &is_identity, begin, size); + if (!context->status().ok()) return; + + const Tensor& input = MklGetInput(context, 0); + MklDnnShape input_mkl_shape; + GetMklShape(context, 0, &input_mkl_shape); + + if (is_identity) { + VLOG(1) << "Slice identity"; + context->set_output(0, input); + // Mkl metadata tensor in this case can just be forwarded from input to + // output. 
+ AllocateOutputSetMklShape(context, 0, input_mkl_shape); + *done = true; + } +} + +// MKL-DNN implementation of Slice +template <typename Device, typename T> +class MklDnnSliceOp : public OpKernel { + public: + explicit MklDnnSliceOp(OpKernelConstruction* context) : OpKernel(context) {} + + ~MklDnnSliceOp() {} + + void Compute(OpKernelContext* context) override { + gtl::InlinedVector<int64, 4> begin; + gtl::InlinedVector<int64, 4> size; + bool done = false; + + CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done); + if (!context->status().ok() || done == true) return; + + // Though MKL-DNN supports more than 8 dimension and + // less than 12 dimension tensor. + // But we are mimicking functionality of Eigen Slice op for CPU. + if (begin.size() >= 8) { + OP_REQUIRES( + context, false, + errors::Unimplemented("MklDnnSliceOp : Unhandled input dimensions")); + } + + ComputeMklDnnSlice(context, begin, size); + } + + private: + // Slice op implemented using MKL-DNN APIs. + void ComputeMklDnnSlice(OpKernelContext* context, + const gtl::InlinedVector<int64, 4>& begin, + const gtl::InlinedVector<int64, 4>& size) { + try { + // MKL-DNN API usage below is guided by description at: + // https://github.com/01org/mkl-dnn/issues/69 + // + // Relevant part of the description is copied below: + // + // Let's say you want to copy a part of memory into another buffer (and + // probably change the format). Then your steps are: + // + // 1. create memory primitive descriptor in_mem_pd and memory primitive + // in_mem_p for the entire source data. + // 2. create view primitive descriptor in_submem_pd based on in_mem_pd, + // initial offsets, and sub-sizes + // 3. create memory primitive descriptor out_mem_pd and memory primitive + // out_mem_p for the output (the logical sizes should match sub-sizes + // used in step 2, but the format might be arbitrary) + // 4. create reorder primitive descriptor reorder_pd based on in_submem_pd + // and out_mem_pd + // 5. 
create reorder primitive itself based on reorder_pd, in_mem_p, and + // out_mem_p. + // + // Please notice that there is no view primitive. There is only view + // primitive descriptor. And the reorder uses source memory as input but + // traverses it according to a view in_submem_pd. + + auto cpu_engine = engine(engine::cpu, 0); + MklDnnData<T> src(&cpu_engine); + MklDnnData<T> output(&cpu_engine); + + // Populate offsets and sizes in memory::dims format based on vector. + memory::dims begin_dims = {}; + begin_dims.resize(begin.size()); + for (size_t i = 0; i < begin.size(); ++i) begin_dims[i] = begin[i]; + memory::dims size_dims = {}; + bool empty = false; + size_dims.resize(size.size()); + for (size_t i = 0; i < size.size(); ++i) { + size_dims[i] = size[i]; + if (size_dims[i] == 0) empty = true; + } + + Tensor* output_tensor = nullptr; + MklDnnShape output_mkl_shape; + + // If no dimension is selected in slice, the result should be empty. + // Just return an empty output tensor, and a dummy Mkl-shape tensor. + if (empty) { // for empty dims + auto shape_to = MklDnnDimsToTFShape(size_dims); + AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to, + output_mkl_shape); + return; + } + + // Step 1 (as per above description) - Create memory for user data. + // We use blocked format here to describe input tensor. + const Tensor& input_tensor = MklGetInput(context, 0); + MklDnnShape input_mkl_shape; + GetMklShape(context, 0, &input_mkl_shape); + + if (input_mkl_shape.IsMklTensor()) { + auto input_mkl_format = input_mkl_shape.GetTfDataFormat(); + auto input_tf_format = MklDnnDataFormatToTFDataFormat(input_mkl_format); + begin_dims = MklDnnDimsInNCHW(begin_dims, input_tf_format); + size_dims = MklDnnDimsInNCHW(size_dims, input_tf_format); + auto input_md = input_mkl_shape.GetMklLayout(); + src.SetUsrMem(input_md, &input_tensor); + } else { + // Initialize input dimensions and strides to be used when input is not + // in MklDnn layout. 
+ memory::dims input_dims, input_strides; + input_dims = TFShapeToMklDnnDims(input_tensor.shape()); + input_strides = CalculateTFStrides(input_dims); + // Create input memory descriptor. + auto input_md = + MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides); + src.SetUsrMem(input_md, &input_tensor); + } + + // Step 2 - create view primitive descriptor + auto view_pd = + view::primitive_desc(src.GetUsrMemPrimDesc(), size_dims, begin_dims) + .dst_primitive_desc(); + auto output_strides = CalculateTFStrides(size_dims); + auto output_md = + MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides); + auto output_pd = memory::primitive_desc(output_md, cpu_engine); + + // Step 3 - Create memory for output. If input is in MklDnn layout, then + // output is also in MklDnn layout. Otherwise, output is in Tensorflow + // layout. + AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims, + &output_tensor, &output_mkl_shape); + DCHECK(output_tensor); + DCHECK_EQ(input_mkl_shape.IsMklTensor(), output_mkl_shape.IsMklTensor()); + output.SetUsrMem(output_md, output_tensor); + + std::vector<primitive> net; + // Step 4 - create reorder primitive desc between view_pd and output_pd. + auto reorder_pd = + reorder::primitive_desc(view_pd, output.GetUsrMemPrimDesc()); + // Step 5 - create reorder primitive itself. + net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *output.GetUsrMem())); + // Execute the reorder primitive. 
+ stream(stream::kind::eager).submit(net).wait(); + } catch (mkldnn::error& e) { + string error_msg = "Status: " + std::to_string(e.status) + ", message: " + + string(e.message) + ", in file " + string(__FILE__) + + ":" + std::to_string(__LINE__); + OP_REQUIRES_OK( + context, + errors::Aborted("Operation received an exception:", error_msg)); + } + } + + private: + void AllocateOutputTensor(OpKernelContext* context, + const MklDnnShape& input_mkl_shape, + memory::primitive_desc* output_pd, + const memory::dims& output_dims, + Tensor** output_tensor, + MklDnnShape* output_mkl_shape) { + DCHECK(output_tensor); + DCHECK(output_mkl_shape); + + TensorShape output_tf_shape; + + if (input_mkl_shape.IsMklTensor()) { + // Since input tensor is in Mkl layout, output tensor will be in Mkl + // layout. + + // Allocate shape of Mkl tensor. + output_mkl_shape->SetMklTensor(true); + output_mkl_shape->SetMklLayout(output_pd); + output_mkl_shape->SetElemType(MklDnnType<T>()); + output_mkl_shape->SetTfLayout(input_mkl_shape.GetDimension(), output_dims, + input_mkl_shape.GetTfDataFormat()); + + output_tf_shape.AddDim((output_pd->get_size() / sizeof(T)) + 1); + } else { + // If input is not in Mkl layout, then output won't be in Mkl layout. 
+ output_mkl_shape->SetMklTensor(false); + output_tf_shape = MklDnnDimsToTFShape(output_dims); + } + + AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape, + *output_mkl_shape); + } +}; + +// MKL-DNN Slice registration +#define REGISTER_MKL_SLICE(type) \ + REGISTER_KERNEL_BUILDER(Name("_MklSlice") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("T") \ + .HostMemory("begin") \ + .HostMemory("size") \ + .Label(mkl_op_registry::kMklOpLabel), \ + MklDnnSliceOp<CPUDevice, type>); + +TF_CALL_float(REGISTER_MKL_SLICE); +#undef REGISTER_MKL_SLICE + +} // namespace tensorflow + +#endif // INTEL_MKL_DNN +#endif // INTEL_MKL diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc index fc1c9003aa..3979e4b53a 100644 --- a/tensorflow/core/kernels/partitioned_function_ops.cc +++ b/tensorflow/core/kernels/partitioned_function_ops.cc @@ -97,7 +97,20 @@ class PartitionedCallOp : public AsyncOpKernel { OP_REQUIRES_ASYNC(ctx, fbody != nullptr, errors::Internal("Could not find handle ", handle), done); + OP_REQUIRES_ASYNC( + ctx, args.size() == fbody->arg_nodes.size(), + errors::InvalidArgument( + "Wrong number of arguments to the op; function expects ", + fbody->arg_nodes.size(), " but PartitionedCall received ", + args.size()), + done); + // We need to pass global op_registry as default_registry when creating + // graph. So that graph optimization passes can lookup all possible ops + // by name. 
auto graph = tensorflow::MakeUnique<Graph>(fbody->graph->flib_def()); + FunctionLibraryDefinition global_flib(OpRegistry::Global(), {}); + TF_CHECK_OK( + graph.get()->AddFunctionLibrary(global_flib.ToProto())); CopyGraph(*fbody->graph, graph.get()); OP_REQUIRES_OK_ASYNC(ctx, PinResourceArgs(graph.get(), args), done); @@ -250,9 +263,11 @@ class PartitionedCallOp : public AsyncOpKernel { VLOG(3) << "Partitioned function '" << func_.name() << "', yielding " << partitions.size() << " shards."; - const FunctionLibraryDefinition* flib_def = &graph->flib_def(); for (const auto& partition : partitions) { - std::unique_ptr<Graph> subgraph(new Graph(flib_def)); + std::unique_ptr<Graph> subgraph(new Graph(graph->flib_def())); + FunctionLibraryDefinition global_flib(OpRegistry::Global(), {}); + TF_CHECK_OK( + subgraph.get()->AddFunctionLibrary(global_flib.ToProto())); GraphConstructorOptions opts; opts.allow_internal_ops = true; opts.expect_device_spec = true; diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc index 04a53697c0..3810d817ca 100644 --- a/tensorflow/core/kernels/random_op.cc +++ b/tensorflow/core/kernels/random_op.cc @@ -489,13 +489,15 @@ class RandomGammaOp : public OpKernel { Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ RandomGammaOp<TYPE>) -#define REGISTER_INT(IntType) \ - REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .HostMemory("minval") \ - .HostMemory("maxval") \ - .TypeConstraint<IntType>("Tout"), \ +#define REGISTER_INT(IntType) \ + template struct functor::FillPhiloxRandom< \ + CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \ + REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ + .Device(DEVICE_CPU) \ + .HostMemory("shape") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint<IntType>("Tout"), \ RandomUniformIntOp<CPUDevice, IntType>); TF_CALL_half(REGISTER); @@ -538,14 +540,16 @@ 
TF_CALL_int64(REGISTER_INT); random::TruncatedNormalDistribution< \ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>); -#define REGISTER_INT(IntType) \ - REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("minval") \ - .HostMemory("maxval") \ - .TypeConstraint<int32>("T") \ - .TypeConstraint<IntType>("Tout"), \ +#define REGISTER_INT(IntType) \ + template struct functor::FillPhiloxRandom< \ + GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \ + REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \ + .Device(DEVICE_GPU) \ + .HostMemory("shape") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint<int32>("T") \ + .TypeConstraint<IntType>("Tout"), \ RandomUniformIntOp<GPUDevice, IntType>); TF_CALL_half(REGISTER); diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 26705a8d34..678d675c4a 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -51,7 +51,9 @@ limitations under the License. #define EIGEN_USE_GPU #endif -#include "tensorflow/core/kernels/resource_variable_ops.h" +#include <memory> +#include <vector> + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -60,10 +62,12 @@ limitations under the License. 
#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/dense_update_functor.h" #include "tensorflow/core/kernels/gather_functor.h" +#include "tensorflow/core/kernels/resource_variable_ops.h" #include "tensorflow/core/kernels/scatter_functor.h" #include "tensorflow/core/kernels/training_op_helpers.h" #include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/platform/mem.h" #include "tensorflow/core/platform/mutex.h" #include "tensorflow/core/platform/types.h" @@ -72,6 +76,8 @@ limitations under the License. namespace tensorflow { REGISTER_RESOURCE_HANDLE_KERNEL(Var); +REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU), + ResourceHandlesOp<Var>); ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_)); @@ -101,13 +107,58 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) { ctx->set_output(0, t); } +ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) { + int n; + OP_REQUIRES_OK(c, c->GetAttr("N", &n)); + OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_)); + OP_REQUIRES(c, n == dtypes_.size(), + errors::InvalidArgument( + "Mismatched number of arguments to ReadVariablesOp (", n, + " vs. ", dtypes_.size(), ")")); +} + +void ReadVariablesOp::Compute(OpKernelContext* ctx) { + std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables( + dtypes_.size()); + std::vector<const ResourceHandle*> handles(dtypes_.size()); + for (size_t i = 0; i < dtypes_.size(); ++i) { + handles[i] = &HandleFromInput(ctx, i); + } + const auto status = LookupResources(ctx, handles, &variables); + OP_REQUIRES(ctx, status.ok(), + errors::FailedPrecondition( + "Error while reading resource variable. This could mean that " + "the variable was uninitialized. 
", + status.ToString())); + + for (size_t i = 0; i < dtypes_.size(); ++i) { + // We're acquiring a reference to the underlying buffer while + // holding a shared lock to guarantee ordering of reads and + // writes. + tf_shared_lock ml(*variables[i]->mu()); + const Tensor& t = *variables[i]->tensor(); + OP_REQUIRES(ctx, dtypes_[i] == t.dtype(), + errors::InvalidArgument( + "Trying to read variable ", handles[i]->name(), + " from Container: ", handles[i]->container(), + " with wrong dtype. Expected ", DataTypeString(dtypes_[i]), + " got ", DataTypeString(t.dtype()))); + ctx->set_output(i, t); + } +} + REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU), ReadVariableOp); +REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU), + ReadVariablesOp); #if GOOGLE_CUDA REGISTER_KERNEL_BUILDER( Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"), ReadVariableOp); +REGISTER_KERNEL_BUILDER( + Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"), + ReadVariablesOp); #define REGISTER_GPU_KERNELS(type) \ namespace functor { \ @@ -122,11 +173,20 @@ REGISTER_KERNEL_BUILDER( .HostMemory("resource") \ .TypeConstraint<type>("dtype"), \ ResourceHandleOp<Var>) - TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS); TF_CALL_int64(REGISTER_GPU_KERNELS); TF_CALL_variant(REGISTER_GPU_KERNELS); #undef REGISTER_GPU_KERNELS + +REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp") + .Device(DEVICE_GPU) + .HostMemory("resources") + .TypeConstraint("dtypes", + {DT_INT64, DT_COMPLEX64, + DT_COMPLEX128, DT_HALF, DT_FLOAT, + DT_DOUBLE, DT_BOOL, DT_VARIANT}), + ResourceHandlesOp<Var>); + #endif // GOOGLE_CUDA template <typename T> @@ -366,6 +426,12 @@ class AssignUpdateVariableOp : public OpKernel { // ADD if value's refcount was 1. 
mutex_lock ml(*variable->mu()); Tensor* var_tensor = variable->tensor(); + OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()), + errors::InvalidArgument("Cannot update variable with shape ", + var_tensor->shape().DebugString(), + " using a Tensor with shape ", + value.shape().DebugString(), + ", shapes must be equal.")); OP_REQUIRES_OK(context, PrepareToUpdateVariable<Device, T>(context, var_tensor)); functor::DenseUpdate<Device, T, Op> update_functor; diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h index 9b60106f13..cffb732c38 100644 --- a/tensorflow/core/kernels/resource_variable_ops.h +++ b/tensorflow/core/kernels/resource_variable_ops.h @@ -28,6 +28,16 @@ class ReadVariableOp : public OpKernel { DataType dtype_; }; +class ReadVariablesOp : public OpKernel { + public: + explicit ReadVariablesOp(OpKernelConstruction* c); + void Compute(OpKernelContext* ctx) override; + bool IsExpensive() override { return false; } + + private: + DataTypeVector dtypes_; +}; + class DestroyResourceOp : public OpKernel { public: explicit DestroyResourceOp(OpKernelConstruction* ctx); diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc index 77594479cb..a006c69297 100644 --- a/tensorflow/core/kernels/slice_op.cc +++ b/tensorflow/core/kernels/slice_op.cc @@ -228,191 +228,6 @@ class SliceOp : public OpKernel { } }; -#ifdef INTEL_MKL -template <typename Device, typename T> -class MklSliceOp : public OpKernel { - public: - explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {} - - void Compute(OpKernelContext* context) override { - TensorShape output_shape; - gtl::InlinedVector<int64, 4> begin; - gtl::InlinedVector<int64, 4> size; - Tensor* result = nullptr; - bool done = false; - SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result, - &done); - if (!context->status().ok() || done == true) return; - - const Tensor& input = 
context->input(0); - const int input_dims = input.dims(); - - if (output_shape.num_elements() > 0) { - if (std::is_same<Device, CPUDevice>::value && input_dims == 2 && - DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) { - auto input = context->input(0).tensor<T, 2>(); - auto output = result->tensor<T, 2>(); - // TODO(agarwal): Consider multi-threading this loop for cases where - // size[0] is very large. - for (int i = 0; i < size[0]; ++i) { - const int64 row = begin[0] + i; - if (i + 1 < size[0]) { - port::prefetch<port::PREFETCH_HINT_T0>(&output(i + 1, 0)); - port::prefetch<port::PREFETCH_HINT_T0>(&input(row + 1, begin[1])); - } - memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T)); - } - return; - } -#define HANDLE_DIM(NDIM) \ - if (input_dims == NDIM) { \ - HandleCase<NDIM>(context, begin, size, result); \ - return; \ - } - - HANDLE_DIM(1); - HANDLE_DIM(2); - HANDLE_DIM(3); - HANDLE_DIM(4); - HANDLE_DIM(5); - HANDLE_DIM(6); - HANDLE_DIM(7); - -#undef HANDLE_DIM - - OP_REQUIRES( - context, false, - errors::Unimplemented("SliceOp : Unhandled input dimensions")); - } - } - - private: - // Helper function for DoesSliceShapeDifferInOnly1D. Checks if the following - // criteria matches for slice_dim: if indices for slice are 0 in all dims - // except slice_dim and if sizes of all the dimensions of the slice are same - // as the sizes of all the dimensions of the input except slice_dim, then - // returns True. Otherwise, returns False. - bool DoesSliceShapeDifferInOnly1DHelper(const TensorShape& input_shape, - const gtl::ArraySlice<int64>& begin, - const gtl::ArraySlice<int64>& size, - int slice_dim) { - for (int dim = 0; dim < 4; dim++) { - if (dim != slice_dim && - (begin[dim] != 0 || size[dim] != input_shape.dim_size(dim))) { - return false; - } - } - return true; - } - - // Is 'input' tensor being sliced over a single dimension out of 4? 
- // - // This check is applicable in the context of Slice of a 4-D tensor in - // NHWC or NCHW format over channel dimension. - // - // If indices for slice are 0 in all dims except one dimension and if sizes of - // all dimensions of slice are same as sizes of all dimensions of inputs - // except that dimension, then we are slicing over a single dimension. - // - // Returns True if Slicing over a single dimension, and sets slice_dim - // to the number of the dimension that satisfies criteria. - bool DoesSliceShapeDifferInOnly1D(const TensorShape& input_shape, - const gtl::ArraySlice<int64>& begin, - const gtl::ArraySlice<int64>& size, - int* slice_dim) { - for (int dim = 0; dim < 4; dim++) { - if (DoesSliceShapeDifferInOnly1DHelper(input_shape, begin, size, dim)) { - *slice_dim = dim; - return true; - } - } - return false; - } - - template <int NDIM> - void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin, - const gtl::ArraySlice<int64>& size, Tensor* result) { - int slice_dim = -1; - TensorShape in_shape = context->input(0).shape(); - // Special case for handling 4-D tensor slice when shape of the slice - // differs from the input tensor in only 1 out of 4 dimensions. - // This case arises in the context of Slice of 4-D tensor in NHWC or NCHW - // format over channel dimension. 
- if (NDIM == 4 && - DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) { - size_t in_strides[4] = { - (size_t)in_shape.dim_size(1) * in_shape.dim_size(2) * - in_shape.dim_size(3), - (size_t)in_shape.dim_size(2) * in_shape.dim_size(3), - (size_t)in_shape.dim_size(3), (size_t)1}; - - size_t out_strides[4] = {(size_t)size[1] * size[2] * size[3], - (size_t)size[2] * size[3], (size_t)size[3], - (size_t)1}; - - T* in_buf = const_cast<T*>( - const_cast<const T*>(context->input(0).flat<T>().data())); - T* op_buf = result->flat<T>().data(); - - if (slice_dim == 1) { - /* data format = NCHW */ - -#pragma omp parallel for - for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) { - T* ip = in_buf + (d0 * in_strides[0]); - T* op = op_buf + ((d0 - begin[0]) * out_strides[0]); -#pragma omp parallel for - for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) { - T* ip1 = ip + (d1 * in_strides[1]); - T* op1 = op + ((d1 - begin[1]) * out_strides[1]); - // For NCHW, H and W will be contiguous. So we can copy - // both with one memcpy. - memcpy(static_cast<void*>(op1), static_cast<void*>(ip1), - sizeof(T) * in_strides[1]); - } - } - return; - } else if (slice_dim == 3) { - /* data_format = NHWC */ - -#pragma omp parallel for - for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) { - T* ip = in_buf + (d0 * in_strides[0]); - T* op = op_buf + ((d0 - begin[0]) * out_strides[0]); -#pragma omp parallel for - for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) { - T* ip1 = ip + (d1 * in_strides[1]); - T* op1 = op + ((d1 - begin[1]) * out_strides[1]); -#pragma omp parallel for - for (ssize_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) { - T* ip2 = ip1 + (d2 * in_strides[2]); - T* ip3 = ip2 + begin[3]; - T* op2 = op1 + ((d2 - begin[2]) * out_strides[2]); - T* op3 = op2; - memcpy(static_cast<void*>(op3), static_cast<void*>(ip3), - sizeof(T) * size[3]); - } - } - } - return; - } - // slice_dim is not 1 or 3, then we fallback to Eigen implementation. 
- } - - Eigen::DSizes<Eigen::DenseIndex, NDIM> indices; - Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes; - for (int i = 0; i < NDIM; ++i) { - indices[i] = begin[i]; - sizes[i] = size[i]; - } - - functor::Slice<Device, T, NDIM>()( - context->eigen_device<Device>(), result->tensor<T, NDIM>(), - context->input(0).tensor<T, NDIM>(), indices, sizes); - } -}; -#endif - // Forward declarations of the functor specializations for declared in the // sharded source files. namespace functor { @@ -440,7 +255,6 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N); #undef DECLARE_CPU_SPEC } // namespace functor -#ifndef INTEL_MKL #define REGISTER_SLICE(type) \ REGISTER_KERNEL_BUILDER(Name("Slice") \ .Device(DEVICE_CPU) \ @@ -452,19 +266,6 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N); TF_CALL_POD_STRING_TYPES(REGISTER_SLICE); TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE); #undef REGISTER_SLICE -#else -#define REGISTER_SLICE(type) \ - REGISTER_KERNEL_BUILDER(Name("Slice") \ - .Device(DEVICE_CPU) \ - .TypeConstraint<type>("T") \ - .HostMemory("begin") \ - .HostMemory("size"), \ - MklSliceOp<CPUDevice, type>) - -TF_CALL_POD_STRING_TYPES(REGISTER_SLICE); -TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE); -#undef REGISTER_SLICE -#endif // INTEL_MKL #if GOOGLE_CUDA // Forward declarations of the functor specializations for GPU. 
diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc index eab176c7fb..925f5291a6 100644 --- a/tensorflow/core/kernels/stateless_random_ops.cc +++ b/tensorflow/core/kernels/stateless_random_ops.cc @@ -113,74 +113,109 @@ class StatelessRandomOp : public StatelessRandomOpBase { } }; -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomUniform") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<CPUDevice, random::UniformDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomNormal") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<CPUDevice, random::NormalDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessTruncatedNormal") \ - .Device(DEVICE_CPU) \ - .HostMemory("shape") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp< \ - CPUDevice, \ - random::TruncatedNormalDistribution< \ - random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >); +template <typename Device, typename IntType> +class StatelessRandomUniformIntOp : public StatelessRandomOpBase { + public: + using StatelessRandomOpBase::StatelessRandomOpBase; -TF_CALL_half(REGISTER); -TF_CALL_float(REGISTER); -TF_CALL_double(REGISTER); + void Fill(OpKernelContext* context, random::PhiloxRandom random, + Tensor* output) override { + const Tensor& minval = context->input(2); + const Tensor& maxval = context->input(3); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()), + errors::InvalidArgument("minval must be 0-D, got shape ", + minval.shape().DebugString())); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()), + errors::InvalidArgument("maxval must be 0-D, got shape ", + maxval.shape().DebugString())); + + // Verify that minval < maxval. 
Note that we'll never reach this point for + // empty output. Zero impossible things are fine. + const auto lo = minval.scalar<IntType>()(); + const auto hi = maxval.scalar<IntType>()(); + OP_REQUIRES( + context, lo < hi, + errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi)); + + // Build distribution + typedef random::UniformDistribution<random::PhiloxRandom, IntType> + Distribution; + Distribution dist(lo, hi); + + auto flat = output->flat<IntType>(); + // Reuse the compute kernels from the stateful random ops + functor::FillPhiloxRandom<Device, Distribution>()( + context, context->eigen_device<Device>(), random, flat.data(), + flat.size(), dist); + } +}; -#undef REGISTER +#define REGISTER(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessRandomUniform") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \ + random::PhiloxRandom, TYPE> >); \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessRandomNormal") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \ + random::PhiloxRandom, TYPE> >); \ + REGISTER_KERNEL_BUILDER( \ + Name("StatelessTruncatedNormal") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomOp< \ + DEVICE##Device, \ + random::TruncatedNormalDistribution< \ + random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >); + +#define REGISTER_INT(DEVICE, TYPE) \ + REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \ + .Device(DEVICE_##DEVICE) \ + .HostMemory("shape") \ + .HostMemory("seed") \ + .HostMemory("minval") \ + .HostMemory("maxval") \ + .TypeConstraint<TYPE>("dtype"), \ + StatelessRandomUniformIntOp<DEVICE##Device, TYPE>); + +#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE) 
+#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE) +#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE) +#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE) + +TF_CALL_half(REGISTER_CPU); +TF_CALL_bfloat16(REGISTER_CPU); +TF_CALL_float(REGISTER_CPU); +TF_CALL_double(REGISTER_CPU); +TF_CALL_int32(REGISTER_INT_CPU); +TF_CALL_int64(REGISTER_INT_CPU); #if GOOGLE_CUDA -#define REGISTER(TYPE) \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomUniform") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("seed") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<GPUDevice, random::UniformDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessRandomNormal") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("seed") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp<GPUDevice, random::NormalDistribution< \ - random::PhiloxRandom, TYPE> >); \ - REGISTER_KERNEL_BUILDER( \ - Name("StatelessTruncatedNormal") \ - .Device(DEVICE_GPU) \ - .HostMemory("shape") \ - .HostMemory("seed") \ - .TypeConstraint<TYPE>("dtype"), \ - StatelessRandomOp< \ - GPUDevice, \ - random::TruncatedNormalDistribution< \ - random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >); +TF_CALL_half(REGISTER_GPU); +TF_CALL_float(REGISTER_GPU); +TF_CALL_double(REGISTER_GPU); +TF_CALL_int32(REGISTER_INT_GPU); +TF_CALL_int64(REGISTER_INT_GPU); -TF_CALL_half(REGISTER); -TF_CALL_float(REGISTER); -TF_CALL_double(REGISTER); +#endif // GOOGLE_CUDA #undef REGISTER - -#endif // GOOGLE_CUDA +#undef REGISTER_INT +#undef REGISTER_CPU +#undef REGISTER_GPU +#undef REGISTER_INT_CPU +#undef REGISTER_INT_GPU } // namespace diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index f0575de4d9..3e8a4c5b72 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -149,7 +149,7 @@ class StridedSliceOp : public OpKernel { // NDIM and T if 
(is_simple_slice && std::is_same<Device, CPUDevice>::value && input_dims == 2 && processing_shape.dims() == 2 && - final_shape.dims() == 2) { + final_shape.dims() == 2 && new_axis_mask == 0) { MemCpyFunctor<T> functor; if (functor.Copy(input, begin, end, result)) { return; diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc index 3a9803a052..92c73220d8 100644 --- a/tensorflow/core/kernels/string_util.cc +++ b/tensorflow/core/kernels/string_util.cc @@ -16,10 +16,6 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" -namespace { -inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; } -} // namespace - namespace tensorflow { // Sets unit value based on str. diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h index 390cf57702..d40e93ea33 100644 --- a/tensorflow/core/kernels/string_util.h +++ b/tensorflow/core/kernels/string_util.h @@ -30,6 +30,9 @@ enum class UnicodeEncoding { UTF8 }; // TODO(edloper): Add support for: UTF32_CHAR, etc. enum class CharUnit { BYTE, UTF8_CHAR }; +// Whether or not the given byte is the trailing byte of a UTF-8/16/32 char. +inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; } + // Sets `encoding` based on `str`. Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding); @@ -40,6 +43,47 @@ Status ParseCharUnit(const string& str, CharUnit* unit); // Result may be incorrect if the input string is not valid UTF-8. int32 UTF8StrLen(const string& string); +// Get the next UTF8 character position starting at the given position and +// skipping the given number of characters. Position is a byte offset, and +// should never be `null`. The function return true if successful. However, if +// the end of the string is reached before the requested characters, then the +// position will point to the end of string and this function will return false. 
+template <typename T> +bool ForwardNUTF8CharPositions(const StringPiece in, + const T num_utf8_chars_to_shift, T* pos) { + const size_t size = in.size(); + T utf8_chars_counted = 0; + while (utf8_chars_counted < num_utf8_chars_to_shift && *pos < size) { + // move forward one utf-8 character + do { + ++*pos; + } while (IsTrailByte(in[*pos]) && *pos < size); + ++utf8_chars_counted; + } + return utf8_chars_counted == num_utf8_chars_to_shift; +} + +// Get the previous UTF8 character position starting at the given position and +// skipping the given number of characters. Position is a byte offset with a +// positive value, relative to the beginning of the string, and should never be +// `null`. The function return true if successful. However, if the beginning of +// the string is reached before the requested character, then the position will +// point to the beginning of the string and this function will return false. +template <typename T> +bool BackNUTF8CharPositions(const StringPiece in, + const T num_utf8_chars_to_shift, T* pos) { + const size_t start = 0; + T utf8_chars_counted = 0; + while (utf8_chars_counted < num_utf8_chars_to_shift && (*pos > start)) { + // move back one utf-8 character + do { + --*pos; + } while (IsTrailByte(in[*pos]) && *pos > start); + ++utf8_chars_counted; + } + return utf8_chars_counted == num_utf8_chars_to_shift; +} + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_ diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc index 07f1d6e767..93c427039d 100644 --- a/tensorflow/core/kernels/substr_op.cc +++ b/tensorflow/core/kernels/substr_op.cc @@ -26,6 +26,7 @@ limitations under the License. 
#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/kernels/bounds_check.h" +#include "tensorflow/core/kernels/string_util.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/types.h" @@ -37,7 +38,11 @@ namespace tensorflow { template <typename T> class SubstrOp : public OpKernel { public: - using OpKernel::OpKernel; + explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) { + string unit; + OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit)); + OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_)); + } void Compute(OpKernelContext* context) override { // Get inputs @@ -69,11 +74,23 @@ class SubstrOp : public OpKernel { tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()()); for (size_t i = 0; i < input_tensor.NumElements(); ++i) { StringPiece in(input(i)); - OP_REQUIRES( - context, FastBoundsCheck(std::abs(pos), in.size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' at index ", i)); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, FastBoundsCheck(byte_pos, in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index ", i)); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } } else { @@ -84,11 +101,23 @@ class SubstrOp : public OpKernel { StringPiece in(input(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); - 
OP_REQUIRES( - context, FastBoundsCheck(std::abs(pos), in.size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' at index ", i)); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, FastBoundsCheck(byte_pos, in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index ", i)); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } } @@ -151,12 +180,24 @@ class SubstrOp : public OpKernel { StringPiece in(input_bcast(i)); const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); - OP_REQUIRES( - context, - FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for string", - "b'", in, "' at index ", i)); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, + FastBoundsCheck(byte_pos, input_bcast(i).size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index ", i)); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i).assign(sub_in.data(), sub_in.size()); } break; @@ -205,12 +246,24 @@ class SubstrOp : public OpKernel { 
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i, j)); - OP_REQUIRES( - context, FastBoundsCheck(std::abs(pos), in.size() + 1), - errors::InvalidArgument("pos ", pos, " out of range for ", - "string b'", in, "' at index (", i, - ", ", j, ")")); - StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len); + T byte_pos = pos; + T byte_len = len; + switch (unit_) { + case CharUnit::UTF8_CHAR: + OP_REQUIRES( + context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string at index ", i)); + break; + case CharUnit::BYTE: + byte_pos = AdjustedPosIndex(byte_pos, in); + OP_REQUIRES( + context, FastBoundsCheck(byte_pos, in.size() + 1), + errors::InvalidArgument("pos ", pos, " out of range for ", + "string b'", in, "' at index (", + i, ", ", j, ")")); + } + StringPiece sub_in = in.substr(byte_pos, byte_len); output(i, j).assign(sub_in.data(), sub_in.size()); } } @@ -227,12 +280,73 @@ class SubstrOp : public OpKernel { private: // This adjusts the requested position. Note it does not perform any bound // checks. - T AdjustedPosIndex(const T pos_requested, const StringPiece s) { + static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) { if (pos_requested < 0) { return s.size() + pos_requested; } return pos_requested; } + + // Return true if successful; otherwise, return false if the `pos` argument + // is out of range in the string. + static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos, + T* len) { + if (*pos >= 0) { + return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len); + } else { + return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len); + } + } + + static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos, + const T len, T* char_pos, + T* char_len) { + *char_pos = 0; + // Determine byte position of the substring start. 
+ if (!ForwardNUTF8CharPositions(in, pos, char_pos)) { + return false; + } + // Determine position of the end of the substring. + // The length will be capped at the end of the string, and we ignore whether + // the string had enough characters to handle it or not. + *char_len = *char_pos; + ForwardNUTF8CharPositions(in, len, char_len); + // The length in bytes is the position end of the substring less the start. + *char_len = *char_len - *char_pos; + return true; + } + + // This function expects a negative position relative to the end of the + // string, but will update the character position to a positive number + // relative to the beginning of the string. + static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos, + const T len, T* char_pos, + T* char_len) { + // Initially treat the length as position of the end of the substring. + *char_len = in.size(); + // This is the number of character to skip from the end of the string to + // arrive at the position where the substring should end. + T utf8_chars_to_skip = -pos - len; + if (utf8_chars_to_skip < 0) { + utf8_chars_to_skip = 0; + } + // Find the byte position where the substring should end using the computed + // number of characters to skip. + if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) { + return false; + } + // Next, determine where the substring should begin. The number of chars to + // skip is the requested position minus the chars we've previously skipped. + *char_pos = *char_len; + if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) { + return false; + } + // The length in bytes is the position end of the substring less the start. 
+ *char_len = *char_len - *char_pos; + return true; + } + + CharUnit unit_ = CharUnit::BYTE; }; #define REGISTER_SUBSTR(type) \ diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc index 2e07050260..ea6b1ed500 100644 --- a/tensorflow/core/kernels/substr_op_test.cc +++ b/tensorflow/core/kernels/substr_op_test.cc @@ -42,7 +42,7 @@ limitations under the License. namespace tensorflow { // Test data from the TensorFlow README.md. -const char* lines[] = { +const char* ascii_lines[] = { "**TensorFlow** is an open source software library for numerical " "computation using data flow graphs.", "The graph nodes represent mathematical operations, while the graph edges " @@ -64,17 +64,76 @@ const char* lines[] = { "backwards compatibility guarantee like C++, Go, Java, JavaScript and " "Swift."}; +const char* unicode_lines[] = { + "TensorFlow\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe4\xbd\xbf\xe7\x94\xa8\xe6" + "\x95\xb0\xe6\x8d\xae\xe6\xb5\x81\xe5\x9b\xbe\xe8\xbf\x9b\xe8\xa1\x8c\xe6" + "\x95\xb0\xe5\x80\xbc\xe8\xae\xa1\xe7\xae\x97\xe7\x9a\x84\xe5\xbc\x80\xe6" + "\xba\x90\xe8\xbd\xaf\xe4\xbb\xb6\xe5\xba\x93\xe3\x80\x82", + "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\x8a\x82\xe7\x82\xb9\xe8\xa1\xa8\xe7\xa4\xba" + "\xe6\x95\xb0\xe5\xad\xa6\xe8\xbf\x90\xe7\xae\x97\xef\xbc\x8c\xe8\x80\x8c" + "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\xbe\xb9\xe7\xbc\x98\xe8\xa1\xa8\xe7\xa4\xba" + "\xe5\x9c\xa8\xe5\xae\x83\xe4\xbb\xac\xe4\xb9\x8b\xe9\x97\xb4\xe6\xb5\x81" + "\xe5\x8a\xa8\xe7\x9a\x84\xe5\xa4\x9a\xe7\xbb\xb4\xe6\x95\xb0\xe6\x8d\xae" + "\xe9\x98\xb5\xe5\x88\x97\xef\xbc\x88\xe5\xbc\xa0\xe9\x87\x8f\xef\xbc\x89" + "\xe3\x80\x82", + "\xe8\xbf\x99\xe7\xa7\x8d\xe7\x81\xb5\xe6\xb4\xbb\xe7\x9a\x84\xe4\xbd\x93" + "\xe7\xb3\xbb\xe7\xbb\x93\xe6\x9e\x84\xe4\xbd\xbf\xe6\x82\xa8\xe5\x8f\xaf" + "\xe4\xbb\xa5\xe5\xb0\x86\xe8\xae\xa1\xe7\xae\x97\xe9\x83\xa8\xe7\xbd\xb2" + "\xe5\x88\xb0\xe6\xa1\x8c\xe9\x9d\xa2\xef\xbc\x8c\xe6\x9c\x8d\xe5\x8a\xa1" + 
"\xe5\x99\xa8\xe6\x88\x96\xe7\xa7\xbb\xe5\x8a\xa8\xe8\xae\xbe\xe5\xa4\x87" + "\xe4\xb8\xad\xe7\x9a\x84\xe4\xb8\x80\xe4\xb8\xaa\xe6\x88\x96\xe5\xa4\x9a" + "\xe4\xb8\xaa CPU\xe6\x88\x96GPU\xef\xbc\x8c\xe8\x80\x8c\xe6\x97\xa0\xe9" + "\x9c\x80\xe9\x87\x8d\xe5\x86\x99\xe4\xbb\xa3\xe7\xa0\x81\xe3\x80\x82", + "TensorFlow\xe8\xbf\x98\xe5\x8c\x85\xe6\x8b\xac[TensorBoard]\xef\xbc\x88" + "https://www.tensorflow.org/guide/summaries_and_tensorboard\xef\xbc\x89\xef" + "\xbc\x8c\xe8\xbf\x99\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe6\x95\xb0\xe6" + "\x8d\xae\xe5\x8f\xaf\xe8\xa7\x86\xe5\x8c\x96\xe5\xb7\xa5\xe5\x85\xb7\xe5" + "\x8c\x85\xe3\x80\x82", + "TensorFlow\xe6\x9c\x80\xe5\x88\x9d\xe6\x98\xaf\xe7\x94\xb1\xe7\xa0\x94\xe7" + "\xa9\xb6\xe4\xba\xba\xe5\x91\x98\xe5\x92\x8c\xe5\xb7\xa5\xe7\xa8\x8b\xe5" + "\xb8\x88\xe5\x9c\xa8Google\xe6\x9c\xba\xe5\x99\xa8\xe6\x99\xba\xe8\x83\xbd" + "\xe7\xa0\x94\xe7\xa9\xb6\xe7\xbb\x84\xe7\xbb\x87\xe7\x9a\x84Google Brain" + "\xe5\x9b\xa2\xe9\x98\x9f\xe5\xbc\x80\xe5\x8f\x91\xe7\x9a\x84\xef\xbc\x8c" + "\xe7\x9b\xae\xe7\x9a\x84\xe6\x98\xaf\xe8\xbf\x9b\xe8\xa1\x8c\xe6\x9c\xba" + "\xe5\x99\xa8\xe5\xad\xa6\xe4\xb9\xa0\xe5\x92\x8c\xe6\xb7\xb1\xe5\xba\xa6" + "\xe7\xa5\x9e\xe7\xbb\x8f\xe7\xbd\x91\xe7\xbb\x9c\xe7\xa0\x94\xe7\xa9\xb6" + "\xe3\x80\x82", + "\xe8\xaf\xa5\xe7\xb3\xbb\xe7\xbb\x9f\xe8\xb6\xb3\xe4\xbb\xa5\xe9\x80\x82" + "\xe7\x94\xa8\xe4\xba\x8e\xe5\x90\x84\xe7\xa7\x8d\xe5\x85\xb6\xe4\xbb\x96" + "\xe9\xa2\x86\xe5\x9f\x9f\xe4\xb9\x9f\xe6\x98\xaf\xe5\xa6\x82\xe6\xad\xa4" + "\xe3\x80\x82", + "TensorFlow\xe6\x8f\x90\xe4\xbe\x9b\xe7\xa8\xb3\xe5\xae\x9a\xe7\x9a\x84" + "Python API\xe5\x92\x8c C API\xef\xbc\x8c\xe4\xbb\xa5\xe5\x8f\x8a\xe6\xb2" + "\xa1\xe6\x9c\x89 API\xe5\x90\x91\xe5\x90\x8e\xe5\x85\xbc\xe5\xae\xb9\xe6" + "\x80\xa7\xe4\xbf\x9d\xe8\xaf\x81\xef\xbc\x8c\xe5\xa6\x82 C ++\xef\xbc\x8c" + "Go\xef\xbc\x8cJava\xef\xbc\x8cJavaScript\xe5\x92\x8cSwift\xe3\x80\x82", +}; + +const char* const kByteUnit = "BYTE"; +const char* const kUTF8Unit = "UTF8_CHAR"; + 
Tensor GetTestTensor(int batch) { - const int sz = TF_ARRAYSIZE(lines); + const int sz = TF_ARRAYSIZE(ascii_lines); + Tensor t(DT_STRING, {batch}); + auto s = t.flat<string>(); + for (int i = 0; i < batch; ++i) { + s(i) = ascii_lines[i % sz]; + } + return t; +} + +Tensor GetTestUTF8Tensor(int batch) { + const int sz = TF_ARRAYSIZE(unicode_lines); Tensor t(DT_STRING, {batch}); auto s = t.flat<string>(); for (int i = 0; i < batch; ++i) { - s(i) = lines[i % sz]; + s(i) = unicode_lines[i % sz]; } return t; } -Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) { +Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len, + const char* const unit) { Graph* g = new Graph(OpRegistry::Global()); Tensor position(DT_INT32, TensorShape({})); position.flat<int32>().setConstant(pos); @@ -85,21 +144,46 @@ Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) { .Input(test::graph::Constant(g, input)) .Input(test::graph::Constant(g, position)) .Input(test::graph::Constant(g, length)) + .Attr("unit", unit) .Finalize(g, nullptr /* node */)); return g; } -void BM_Substr(int iters, int batch_size) { +void BM_SubstrByte(int iters, int batch_size) { testing::StopTiming(); testing::ItemsProcessed(static_cast<int64>(iters)); testing::UseRealTime(); Tensor input = GetTestTensor(batch_size); - Graph* g = SetupSubstrGraph(input, 3, 30); + Graph* g = SetupSubstrGraph(input, 3, 30, kByteUnit); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); +} + +void BM_SubstrUTF8(int iters, int batch_size) { + testing::StopTiming(); + testing::ItemsProcessed(static_cast<int64>(iters)); + testing::UseRealTime(); + Tensor input = GetTestUTF8Tensor(batch_size); + Graph* g = SetupSubstrGraph(input, 3, 30, kUTF8Unit); testing::StartTiming(); test::Benchmark("cpu", g).Run(iters); } -BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg( - 256); +BENCHMARK(BM_SubstrByte) + ->Arg(1) + ->Arg(8) + 
->Arg(16) + ->Arg(32) + ->Arg(64) + ->Arg(128) + ->Arg(256); +BENCHMARK(BM_SubstrUTF8) + ->Arg(1) + ->Arg(8) + ->Arg(16) + ->Arg(32) + ->Arg(64) + ->Arg(128) + ->Arg(256); } // end namespace tensorflow diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc index 83b83fcdb9..4262a5404b 100644 --- a/tensorflow/core/kernels/training_op_helpers.cc +++ b/tensorflow/core/kernels/training_op_helpers.cc @@ -15,14 +15,16 @@ limitations under the License. #include "tensorflow/core/kernels/training_op_helpers.h" +#include "tensorflow/core/util/ptr_util.h" + namespace tensorflow { -mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) { +mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, + Var** maybe_resource) { + *maybe_resource = nullptr; if (ctx->input_dtype(input) == DT_RESOURCE) { - Var* var; - if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) { - core::ScopedUnref scoped_unref(var); - return var->mu(); + if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) { + return (*maybe_resource)->mu(); } else { ctx->CtxFailureWithWarning( errors::Internal("Invalid variable reference.")); @@ -33,12 +35,13 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) { } // MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes -// in address order to mitigate deadlock. Returns a vector of acquired mutexes. -// Safe to pass duplicates - will only lock each distinct mutex once. If -// do_lock is false, returns immediately. Note that this silently doesn't lock -// mutexes for invalid variable references; in all usages this is followed by -// GetInputTensor which will signal a failure. -std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder( +// in address order to mitigate deadlock. Returns a structure that, when +// deleted, will release the acquired mutexes. Safe to pass duplicates - will +// only lock each distinct mutex once. 
If do_lock is false, returns +// immediately. Note that this silently doesn't lock mutexes for invalid +// variable references; in all usages this is followed by GetInputTensor which +// will signal a failure. +VariableInputLockHolder MaybeLockVariableInputMutexesInOrder( OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) { bool any_resource = false; for (auto i : input_ids) { @@ -47,14 +50,16 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder( break; } } - std::vector<mutex_lock> locks; if (!do_lock && !any_resource) { - return locks; + return VariableInputLockHolder({}, {}); } + std::vector<Var*> vars; std::vector<mutex*> mutexes; std::vector<int> acquire_order; for (auto input : input_ids) { - mutex* mutex = GetTrainingVariableMutex(ctx, input); + Var* var; + mutex* mutex = GetTrainingVariableMutex(ctx, input, &var); + if (var) vars.push_back(var); // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3). if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) { acquire_order.push_back(mutexes.size()); @@ -64,13 +69,19 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder( std::sort(acquire_order.begin(), acquire_order.end(), [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; }); + std::unique_ptr<std::vector<mutex_lock>> locks = + MakeUnique<std::vector<mutex_lock>>(); + locks->reserve(acquire_order.size()); + for (auto input : acquire_order) { - mutex* mu = GetTrainingVariableMutex(ctx, input); + Var* var; + mutex* mu = GetTrainingVariableMutex(ctx, input, &var); + core::ScopedUnref scoped_unref(var); if (mu != nullptr) { - locks.emplace_back(*mu); + locks->emplace_back(*mu); } } - return locks; + return VariableInputLockHolder(std::move(vars), std::move(locks)); } void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input, diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h index 071cb371a7..9f173a80f7 100644 --- 
a/tensorflow/core/kernels/training_op_helpers.h +++ b/tensorflow/core/kernels/training_op_helpers.h @@ -23,9 +23,42 @@ limitations under the License. namespace tensorflow { -mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input); +// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`. +// +// If `input` corresponds to a `DT_RESOURCE`-type variable input, +// `*maybe_resource` will be updated to contain the underlying resource, and the +// caller will be responsible for calling `Unref()` on that resource. +mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, + Var** maybe_resource); -std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder( +// Utility structure that releases a sequence of borrowed mutexes when it is +// deleted. +struct VariableInputLockHolder { + public: + VariableInputLockHolder(std::vector<Var*> vars, + std::unique_ptr<std::vector<mutex_lock>> locks) + : vars_(std::move(vars)), locks_(std::move(locks)) {} + + VariableInputLockHolder(VariableInputLockHolder&& other) + : vars_(std::move(other.vars_)), locks_(std::move(other.locks_)) {} + + ~VariableInputLockHolder() { + // Release the locks before unreffing the Vars, because each lock + // is potentially borrowed from a Var in vars_. + locks_.reset(); + for (Var* var : vars_) { + var->Unref(); + } + } + + private: + std::vector<Var*> vars_; + // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly, + // because a `std::vector<mutex_lock>` is not movable on all platforms. 
+ std::unique_ptr<std::vector<mutex_lock>> locks_; +}; + +VariableInputLockHolder MaybeLockVariableInputMutexesInOrder( OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids); void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input, diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 9a07ded17d..acf162deec 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -561,7 +561,9 @@ class ApplyAdadeltaOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - mutex* mu = GetTrainingVariableMutex(ctx, 0); + Var* resource; + mutex* mu = GetTrainingVariableMutex(ctx, 0, &resource); + core::ScopedUnref scoped_unref(resource); if (use_exclusive_lock_ && mu != nullptr) { mutex_lock l1(*mu); // Don't try to acquire a lock on the second ref as they share the same @@ -710,7 +712,9 @@ class SparseApplyAdadeltaOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - mutex* mu = GetTrainingVariableMutex(ctx, 0); + Var* var; + mutex* mu = GetTrainingVariableMutex(ctx, 0, &var); + core::ScopedUnref scoped_unref(var); // mu_accum is actually the same mutex as mu_var since currently we use a // global mutex. 
// diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc index 0f0f65c5a3..48e392c070 100644 --- a/tensorflow/core/kernels/transpose_op.cc +++ b/tensorflow/core/kernels/transpose_op.cc @@ -218,7 +218,7 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx, perm, out); } -#if defined(INTEL_MKL) +#if defined(INTEL_MKL) && defined(ENABLE_MKL) #define REGISTER(T) \ REGISTER_KERNEL_BUILDER(Name("Transpose") \ .Device(DEVICE_CPU) \ @@ -230,11 +230,8 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx, .TypeConstraint<T>("T") \ .HostMemory("perm"), \ MklConjugateTransposeCpuOp); -TF_CALL_ALL_TYPES(REGISTER); -#undef REGISTER - -#else // INTEL_MKL +#else // INTEL_MKL && ENABLE_MKL #define REGISTER(T) \ REGISTER_KERNEL_BUILDER(Name("Transpose") \ .Device(DEVICE_CPU) \ @@ -246,9 +243,10 @@ TF_CALL_ALL_TYPES(REGISTER); .TypeConstraint<T>("T") \ .HostMemory("perm"), \ ConjugateTransposeCpuOp); +#endif // INTEL_MKL && ENABLE_MKL + TF_CALL_ALL_TYPES(REGISTER) #undef REGISTER -#endif // INTEL_MKL #if GOOGLE_CUDA Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in, diff --git a/tensorflow/core/kernels/unicode_script_op.cc b/tensorflow/core/kernels/unicode_script_op.cc new file mode 100644 index 0000000000..085e397eba --- /dev/null +++ b/tensorflow/core/kernels/unicode_script_op.cc @@ -0,0 +1,53 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "unicode/errorcode.h" // TF:icu +#include "unicode/uscript.h" // TF:icu +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class UnicodeScriptOp : public OpKernel { + public: + explicit UnicodeScriptOp(OpKernelConstruction* context) : OpKernel(context) {} + + void Compute(OpKernelContext* context) override { + const Tensor* input_tensor; + OP_REQUIRES_OK(context, context->input("input", &input_tensor)); + const auto& input_flat = input_tensor->flat<int32>(); + + Tensor* output_tensor = nullptr; + OP_REQUIRES_OK(context, + context->allocate_output("output", input_tensor->shape(), + &output_tensor)); + auto output_flat = output_tensor->flat<int32>(); + + icu::ErrorCode status; + for (int i = 0; i < input_flat.size(); i++) { + UScriptCode script_code = uscript_getScript(input_flat(i), status); + if (status.isSuccess()) { + output_flat(i) = script_code; + } else { + output_flat(i) = -1; + status.reset(); + } + } + } +}; + +REGISTER_KERNEL_BUILDER(Name("UnicodeScript").Device(DEVICE_CPU), + UnicodeScriptOp); + +} // namespace tensorflow diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc index 3559baa18e..3bdcfc90b8 100644 --- a/tensorflow/core/kernels/unique_op.cc +++ b/tensorflow/core/kernels/unique_op.cc @@ -108,7 +108,7 @@ class UniqueOp : public OpKernel { std::unordered_map<T, TIndex> uniq; uniq.reserve(2 * N); - for (int64 i = 0, j = 0; i < N; ++i) { + for (Eigen::Index i = 0, j = 0; i < N; ++i) { auto it = uniq.insert(std::make_pair(Tin(i), j)); idx_vec(i) = it.first->second; if (it.second) { @@ -131,19 +131,20 @@ class UniqueOp : public OpKernel { // General implementation when unique is run over multiple elements. 
auto Tin = input.shaped<T, 3>(new_sizes); - auto hash_fn = [&Tin](const int64& key) { + auto hash_fn = [&Tin](const Eigen::Index& key) { size_t h = 0; - for (int64 i = 0; i < Tin.dimension(0); i++) { - for (int64 j = 0; j < Tin.dimension(2); j++) { + for (Eigen::Index i = 0; i < Tin.dimension(0); i++) { + for (Eigen::Index j = 0; j < Tin.dimension(2); j++) { h = Hash64Combine(h, hash<T>{}(Tin(i, key, j))); } } return h; }; - auto equal_to_fn = [&Tin](const int64& lhs, const int64& rhs) { - for (int64 i = 0; i < Tin.dimension(0); i++) { - for (int64 j = 0; j < Tin.dimension(2); j++) { + auto equal_to_fn = [&Tin](const Eigen::Index& lhs, + const Eigen::Index& rhs) { + for (Eigen::Index i = 0; i < Tin.dimension(0); i++) { + for (Eigen::Index j = 0; j < Tin.dimension(2); j++) { if (Tin(i, lhs, j) != Tin(i, rhs, j)) { return false; } diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 442686c92a..f55562ec99 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -133,6 +133,14 @@ Status TransposeShapeFn(InferenceContext* c) { } else { rank = perm->NumElements(); } + if (!c->RankKnown(input) && rank < 2) { + // A permutation array containing a single element is ambiguous. It could + // indicate either a scalar or a 1-dimensional array, both of which the + // transpose op returns unchanged. + c->set_output(0, input); + return Status::OK(); + } + std::vector<DimensionHandle> dims; dims.resize(rank); TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input)); @@ -1531,37 +1539,6 @@ REGISTER_OP("Size") .Attr("out_type: {int32, int64} = DT_INT32") .SetShapeFn(shape_inference::ScalarShape); -namespace { - -// This SliceHelper processes the output shape of the `slice` -// when the tensor of `sizes` is available. 
-template <typename T> -Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, - const Tensor* sizes_value, - std::vector<DimensionHandle>* dims) { - auto sizes_vec = sizes_value->vec<T>(); - for (int i = 0; i < sizes_value->NumElements(); ++i) { - DimensionHandle dim = c->Dim(c->input(0), i); - if (sizes_vec(i) != -1) { - auto dim_val = c->Value(dim); - if (sizes_vec(i) < 0) { - return errors::InvalidArgument( - "Out of bounds slicing on dimension ", i, " of length ", dim_val, - ": sizes vector cannot be < -1, but was ", sizes_vec(i)); - } - - dims->emplace_back(c->MakeDim(sizes_vec(i))); - } else { - DimensionHandle result; - TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result)); - dims->emplace_back(result); - } - } - - return Status::OK(); -} -} // namespace - // -------------------------------------------------------------------------- REGISTER_OP("Slice") .Input("input: T") @@ -1570,83 +1547,22 @@ REGISTER_OP("Slice") .Output("output: T") .Attr("T: type") .Attr("Index: {int32,int64}") - .SetShapeFn([](InferenceContext* c) { - ShapeHandle input = c->input(0); - ShapeHandle begin_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape)); - ShapeHandle sizes_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape)); - - // Merge to check compatibility of begin and sizes tensors. - TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape)); + .SetShapeFn(shape_inference::SliceShape); - DimensionHandle ndims = c->Dim(begin_shape, 0); - if (c->ValueKnown(ndims)) { - TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input)); - } - - // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known - // values, even though the `begin` value does not represent a shape. - ShapeHandle begin_value; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value)); - - // We check the tensor value here and will only use - // `MakeShapeFromShapeTensor` when `sizes_value` is null. 
- // The reason is that `sizes`might contain -1, which can't - // be represented (-1 in the ShapeHandle would mean "unknown". - const Tensor* sizes_value = c->input_tensor(2); - - if (sizes_value != nullptr) { - TF_RETURN_IF_ERROR( - c->WithRank(begin_value, sizes_value->NumElements(), &begin_value)); - std::vector<DimensionHandle> dims; - // If the begin and sizes tensors are available, then - // we can be precise about the shape of the output. - if (sizes_value->dtype() == DT_INT64) { - TF_RETURN_IF_ERROR( - SliceHelper<int64>(c, begin_value, sizes_value, &dims)); - } else { - TF_RETURN_IF_ERROR( - SliceHelper<int32>(c, begin_value, sizes_value, &dims)); - } - - c->set_output(0, c->MakeShape(dims)); - return Status::OK(); - } else { - // In case `sizes` is not available (`sizes_value` is null), - // we could try to use `MakeShapeFromShapeTensor` here. - // If sizes contain -1, we will simply consider it as `Unknown`. - // This is less than ideal but still an improvement of shape inference. - // The following is an example that returns [None, 1, None] with this - // code path: - // z = tf.zeros((1, 2, 3)) - // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1]) - // m.get_shape().as_list() - ShapeHandle sizes_value; - TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value)); - if (c->RankKnown(sizes_value)) { - TF_RETURN_IF_ERROR( - c->WithRank(begin_value, c->Rank(sizes_value), &begin_value)); - std::vector<DimensionHandle> dims; - dims.reserve(c->Rank(sizes_value)); - for (int i = 0; i < c->Rank(sizes_value); ++i) { - dims.emplace_back(c->Dim(sizes_value, i)); - } - c->set_output(0, c->MakeShape(dims)); - return Status::OK(); - } - - // We might know the rank of the input. 
- if (c->RankKnown(input)) { - c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); - return Status::OK(); - } else { - return shape_inference::UnknownShape(c); - } - } - - return Status::OK(); - }); +#ifdef INTEL_MKL +REGISTER_OP("_MklSlice") + .Input("input: T") + .Input("begin: Index") + .Input("size: Index") + .Input("mkl_input: uint8") + .Input("mkl_begin: uint8") + .Input("mkl_size: uint8") + .Output("output: T") + .Output("mkl_output: uint8") + .Attr("T: type") + .Attr("Index: {int32,int64}") + .SetShapeFn(shape_inference::SliceShape); +#endif REGISTER_OP("StridedSlice") .Input("input: T") diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc index 03dab390a7..1c29cd2491 100644 --- a/tensorflow/core/ops/array_ops_test.cc +++ b/tensorflow/core/ops/array_ops_test.cc @@ -975,6 +975,7 @@ TEST(ArrayOpsTest, Transpose_ShapeFn) { INFER_OK(op, "?;[2]", "[?,?]"); INFER_OK(op, "[?,?];[2]", "[d0_1,d0_0]"); INFER_OK(op, "[1,?];[2]", "[d0_1,d0_0]"); + INFER_OK(op, "?;[0]", "in0"); // Invalid arguments. 
perm = test::AsTensor<int32>({1, 2}); diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt index 86d4c6b421..0753316724 100644 --- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt @@ -21532,6 +21532,421 @@ op { } } op { + name: "ExperimentalAssertNextDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "transformations" + type: DT_STRING + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { + name: "ExperimentalCSVDataset" + input_arg { + name: "filenames" + type: DT_STRING + } + input_arg { + name: "compression_type" + type: DT_STRING + } + input_arg { + name: "buffer_size" + type: DT_INT64 + } + input_arg { + name: "header" + type: DT_BOOL + } + input_arg { + name: "field_delim" + type: DT_STRING + } + input_arg { + name: "use_quote_delim" + type: DT_BOOL + } + input_arg { + name: "na_value" + type: DT_STRING + } + input_arg { + name: "select_cols" + type: DT_INT64 + } + input_arg { + name: "record_defaults" + type_list_attr: "output_types" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_STRING + } + } + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "ExperimentalDirectedInterleaveDataset" + input_arg { + name: "selector_input_dataset" + type: DT_VARIANT + } + input_arg { + name: "data_input_datasets" + type: DT_VARIANT + number_attr: "N" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: 
"list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } +} +op { + name: "ExperimentalFunctionBufferingResource" + input_arg { + name: "string_arg" + type: DT_STRING + } + input_arg { + name: "target_device" + type: DT_STRING + } + output_arg { + name: "resource" + type: DT_RESOURCE + } + attr { + name: "shared_name" + type: "string" + } + attr { + name: "container" + type: "string" + } + attr { + name: "f" + type: "func" + } + attr { + name: "buffer_size" + type: "int" + } + attr { + name: "output_types" + type: "list(type)" + } + is_stateful: true +} +op { + name: "ExperimentalFunctionBufferingResourceGetNext" + input_arg { + name: "function_buffer_resource" + type: DT_RESOURCE + } + output_arg { + name: "output" + type_list_attr: "output_types" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "ExperimentalFunctionBufferingResourceReset" + input_arg { + name: "function_buffer_resource" + type: DT_RESOURCE + } + is_stateful: true +} +op { + name: "ExperimentalIdentityIndexedDataset" + input_arg { + name: "size" + type: DT_UINT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + is_stateful: true +} +op { + name: "ExperimentalIgnoreErrorsDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { + name: "ExperimentalIndexedDatasetGet" + input_arg { + name: "materialized" + type: DT_RESOURCE + } + input_arg { + name: "index" + type: DT_UINT64 + } + output_arg { + name: "components" + type_list_attr: "output_types" + } + attr { + name: "output_types" + type: 
"list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "ExperimentalIndexedDatasetMaterialize" + input_arg { + name: "dataset" + type: DT_VARIANT + } + input_arg { + name: "materialized" + type: DT_RESOURCE + } + is_stateful: true +} +op { + name: "ExperimentalIteratorGetDevice" + input_arg { + name: "resource" + type: DT_RESOURCE + } + output_arg { + name: "device" + type: DT_STRING + } + is_stateful: true +} +op { + name: "ExperimentalLMDBDataset" + input_arg { + name: "filenames" + type: DT_STRING + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "ExperimentalMaterializedIndexDatasetHandle" + output_arg { + name: "handle" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + } + attr { + name: "shared_name" + type: "string" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "ExperimentalThreadPoolDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "thread_pool" + type: DT_RESOURCE + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "ExperimentalThreadPoolHandle" + output_arg { + name: "handle" + type: DT_RESOURCE + } + attr { + name: "num_threads" + type: "int" + } + attr { + name: "max_intra_op_parallelism" + type: "int" + default_value { + i: 1 + } + } + attr { + 
name: "display_name" + type: "string" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} +op { + name: "ExperimentalUniqueDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "Expm1" input_arg { name: "x" @@ -24105,6 +24520,158 @@ op { } } op { + name: "FusedBatchNorm" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "scale" + type_attr: "T" + } + input_arg { + name: "offset" + type_attr: "T" + } + input_arg { + name: "mean" + type_attr: "T" + } + input_arg { + name: "variance" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "batch_mean" + type_attr: "T" + } + output_arg { + name: "batch_variance" + type_attr: "T" + } + output_arg { + name: "reserve_space_1" + type_attr: "T" + } + output_arg { + name: "reserve_space_2" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + } + } + } + attr { + name: "epsilon" + type: "float" + default_value { + f: 0.0001 + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } + } + attr { + name: "is_training" + type: "bool" + default_value { + b: true + } + } +} +op { + name: "FusedBatchNormGrad" + input_arg { + name: "y_backprop" + type_attr: "T" + } + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "scale" + type_attr: "T" + } + input_arg { + name: "reserve_space_1" + type_attr: "T" + } + input_arg { + name: "reserve_space_2" + type_attr: "T" + } + output_arg { + name: "x_backprop" + type_attr: "T" + } + 
output_arg { + name: "scale_backprop" + type_attr: "T" + } + output_arg { + name: "offset_backprop" + type_attr: "T" + } + output_arg { + name: "reserve_space_3" + type_attr: "T" + } + output_arg { + name: "reserve_space_4" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + } + } + } + attr { + name: "epsilon" + type: "float" + default_value { + f: 0.0001 + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + } + attr { + name: "is_training" + type: "bool" + default_value { + b: true + } + } +} +op { name: "FusedBatchNormGrad" input_arg { name: "y_backprop" @@ -24168,6 +24735,12 @@ op { default_value { s: "NHWC" } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } } attr { name: "is_training" @@ -24345,6 +24918,179 @@ op { } } op { + name: "FusedBatchNormGradV2" + input_arg { + name: "y_backprop" + type_attr: "T" + } + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "scale" + type: DT_FLOAT + } + input_arg { + name: "reserve_space_1" + type_attr: "U" + } + input_arg { + name: "reserve_space_2" + type_attr: "U" + } + output_arg { + name: "x_backprop" + type_attr: "T" + } + output_arg { + name: "scale_backprop" + type_attr: "U" + } + output_arg { + name: "offset_backprop" + type_attr: "U" + } + output_arg { + name: "reserve_space_3" + type_attr: "U" + } + output_arg { + name: "reserve_space_4" + type_attr: "U" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + } + } + } + attr { + name: "U" + type: "type" + allowed_values { + list { + type: DT_FLOAT + } + } + } + attr { + name: "epsilon" + type: "float" + default_value { + f: 0.0001 + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } + } + attr { + name: "is_training" + type: "bool" + default_value { + b: true + } + } +} +op { + name: 
"FusedBatchNormV2" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "scale" + type_attr: "U" + } + input_arg { + name: "offset" + type_attr: "U" + } + input_arg { + name: "mean" + type_attr: "U" + } + input_arg { + name: "variance" + type_attr: "U" + } + output_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "batch_mean" + type_attr: "U" + } + output_arg { + name: "batch_variance" + type_attr: "U" + } + output_arg { + name: "reserve_space_1" + type_attr: "U" + } + output_arg { + name: "reserve_space_2" + type_attr: "U" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + } + } + } + attr { + name: "U" + type: "type" + allowed_values { + list { + type: DT_FLOAT + } + } + } + attr { + name: "epsilon" + type: "float" + default_value { + f: 0.0001 + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + } + attr { + name: "is_training" + type: "bool" + default_value { + b: true + } + } +} +op { name: "FusedBatchNormV2" input_arg { name: "x" @@ -24392,6 +25138,7 @@ op { allowed_values { list { type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT } } @@ -24502,6 +25249,12 @@ op { default_value { s: "NHWC" } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } } attr { name: "is_training" @@ -26317,6 +27070,52 @@ op { is_stateful: true } op { + name: "If" + input_arg { + name: "cond" + type_attr: "Tcond" + } + input_arg { + name: "input" + type_list_attr: "Tin" + } + output_arg { + name: "output" + type_list_attr: "Tout" + } + attr { + name: "Tcond" + type: "type" + } + attr { + name: "Tin" + type: "list(type)" + has_minimum: true + } + attr { + name: "Tout" + type: "list(type)" + has_minimum: true + } + attr { + name: "then_branch" + type: "func" + } + attr { + name: "else_branch" + type: "func" + } + attr { + name: "output_shapes" + type: "list(shape)" + default_value { + list { + } + } + } + is_stateful: true +} +op { name: "Igamma" input_arg { 
name: "a" @@ -29768,6 +30567,52 @@ op { } } op { + name: "MapDefun" + input_arg { + name: "arguments" + type_list_attr: "Targuments" + } + input_arg { + name: "captured_inputs" + type_list_attr: "Tcaptured" + } + output_arg { + name: "output" + type_list_attr: "output_types" + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Tcaptured" + type: "list(type)" + default_value { + list { + } + } + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "f" + type: "func" + } +} +op { name: "MapIncompleteSize" output_arg { name: "size" @@ -44518,6 +45363,59 @@ op { is_stateful: true } op { + name: "ReduceDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "initial_state" + type_list_attr: "Tstate" + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + output_arg { + name: "components" + type_list_attr: "output_types" + } + attr { + name: "f" + type: "func" + } + attr { + name: "Tstate" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } +} +op { name: "ReduceJoin" input_arg { name: "inputs" @@ -58933,6 +59831,14 @@ op { name: "stats_aggregator" type: DT_RESOURCE } + input_arg { + name: "tag" + type: DT_STRING + } + input_arg { + name: "counter_prefix" + type: DT_STRING + } output_arg { name: "handle" type: DT_VARIANT @@ -69991,6 +70897,62 @@ op { } } op { + name: "StatelessRandomNormal" + input_arg { + name: "shape" + type_attr: "T" + 
} + input_arg { + name: "seed" + type_attr: "Tseed" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tseed" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "StatelessRandomUniform" input_arg { name: "shape" @@ -70088,6 +71050,118 @@ op { } } op { + name: "StatelessRandomUniform" + input_arg { + name: "shape" + type_attr: "T" + } + input_arg { + name: "seed" + type_attr: "Tseed" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tseed" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { + name: "StatelessRandomUniformInt" + input_arg { + name: "shape" + type_attr: "T" + } + input_arg { + name: "seed" + type_attr: "Tseed" + } + input_arg { + name: "minval" + type_attr: "dtype" + } + input_arg { + name: "maxval" + type_attr: "dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tseed" + type: "type" + default_value { 
+ type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "StatelessTruncatedNormal" input_arg { name: "shape" @@ -70185,6 +71259,62 @@ op { } } op { + name: "StatelessTruncatedNormal" + input_arg { + name: "shape" + type_attr: "T" + } + input_arg { + name: "seed" + type_attr: "Tseed" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_HALF + type: DT_BFLOAT16 + type: DT_FLOAT + type: DT_DOUBLE + } + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tseed" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "StatelessWhile" input_arg { name: "input" @@ -70984,6 +72114,48 @@ op { } } op { + name: "Substr" + input_arg { + name: "input" + type: DT_STRING + } + input_arg { + name: "pos" + type_attr: "T" + } + input_arg { + name: "len" + type_attr: "T" + } + output_arg { + name: "output" + type: DT_STRING + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "unit" + type: "string" + default_value { + s: "BYTE" + } + allowed_values { + list { + s: "BYTE" + s: "UTF8_CHAR" + } + } + } +} +op { name: "Sum" input_arg { name: "input" @@ -74573,6 +75745,17 @@ op { } } op { + name: "UnicodeScript" + input_arg { + name: "input" + type: DT_INT32 + } + output_arg { + name: "output" + type: DT_INT32 + } +} +op { name: "UniformCandidateSampler" input_arg { name: "true_classes" @@ -75981,6 +77164,39 @@ op { is_stateful: true } op { + name: "While" + input_arg { + name: "input" + type_list_attr: "T" + } + output_arg { + name: "output" + type_list_attr: "T" + } + attr { + name: "T" + type: "list(type)" + has_minimum: true + 
} + attr { + name: "cond" + type: "func" + } + attr { + name: "body" + type: "func" + } + attr { + name: "output_shapes" + type: "list(shape)" + default_value { + list { + } + } + } + is_stateful: true +} +op { name: "WholeFileReader" output_arg { name: "reader_handle" @@ -76283,6 +77499,62 @@ op { is_stateful: true } op { + name: "Xdivy" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} +op { + name: "Xlogy" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} +op { name: "ZerosLike" input_arg { name: "x" diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc index 1ada623cf5..ec22eee874 100644 --- a/tensorflow/core/ops/dataset_ops.cc +++ b/tensorflow/core/ops/dataset_ops.cc @@ -185,6 +185,8 @@ REGISTER_OP("ParseExampleDataset") REGISTER_OP("SetStatsAggregatorDataset") .Input("input_dataset: variant") .Input("stats_aggregator: resource") + .Input("tag: string") + .Input("counter_prefix: string") .Output("handle: variant") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") @@ -756,6 +758,19 @@ REGISTER_OP("DatasetToSingleElement") .Attr("output_shapes: list(shape) >= 1") .SetShapeFn(IteratorGetNextShapeFn); +REGISTER_OP("ReduceDataset") + .Input("input_dataset: variant") + .Input("initial_state: Tstate") + .Input("other_arguments: Targuments") + .Output("components: output_types") + .Attr("f: func") + .Attr("Tstate: list(type) >= 1") + .Attr("Targuments: list(type) >= 0") + 
.Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Attr("use_inter_op_parallelism: bool = true") + .SetShapeFn(IteratorGetNextShapeFn); + REGISTER_OP("IteratorToStringHandle") .Input("resource_handle: resource") .Output("string_handle: string") @@ -888,14 +903,18 @@ REGISTER_OP("ModelDataset") REGISTER_OP("MapDefun") .Input("arguments: Targuments") + .Input("captured_inputs: Tcaptured") .Output("output: output_types") .Attr("Targuments: list(type) >= 1") + .Attr("Tcaptured: list(type) >= 0 = []") .Attr("output_types: list(type) >= 1") .Attr("output_shapes: list(shape) >= 1") .Attr("f: func") .SetShapeFn([](shape_inference::InferenceContext* c) { std::vector<PartialTensorShape> output_shapes; TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + DataTypeVector t_args; + TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args)); if (output_shapes.size() != c->num_outputs()) { return errors::InvalidArgument( "`output_shapes` must be the same length as `output_types` (", @@ -903,10 +922,11 @@ REGISTER_OP("MapDefun") } int64 dim_zero = -1; - for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) { + for (size_t i = 0; i < t_args.size(); ++i) { if (c->Rank(c->input(i)) == 0) { return errors::InvalidArgument( - "Inputs must have rank at least 1. Input ", i, " has rank of 0"); + "Arguments must have rank at least 1. 
Input ", i, + " has rank of 0."); } auto dim_handle = c->Dim(c->input(i), 0); if (c->ValueKnown(dim_handle)) { @@ -914,7 +934,7 @@ REGISTER_OP("MapDefun") dim_zero = c->Value(dim_handle); } else if (c->Value(dim_handle) != dim_zero) { return errors::InvalidArgument( - "Inputs must have the same dimension 0."); + "Arguments must have the same dimension 0."); } } } diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc new file mode 100644 index 0000000000..f6bd5dce26 --- /dev/null +++ b/tensorflow/core/ops/experimental_dataset_ops.cc @@ -0,0 +1,207 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("ExperimentalDirectedInterleaveDataset") + .Input("selector_input_dataset: variant") + .Input("data_input_datasets: N * variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .Attr("N: int >= 1") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("ExperimentalCSVDataset") + .Input("filenames: string") + .Input("compression_type: string") + .Input("buffer_size: int64") + .Input("header: bool") + .Input("field_delim: string") + .Input("use_quote_delim: bool") + .Input("na_value: string") + .Input("select_cols: int64") + .Input("record_defaults: output_types") + .Output("handle: variant") + .Attr("output_types: list({float,double,int32,int64,string}) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // `filenames` must be a scalar or a vector. 
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused)); + // `compression_type`, `buffer_size`, `header`, `field_delim`, + // `use_quote_delim`, `na_value` must be scalars + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused)); + // `select_cols` must be a vector + TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused)); + // `record_defaults` must be lists of scalars + for (size_t i = 8; i < c->num_inputs(); ++i) { + shape_inference::ShapeHandle v; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v)); + if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) { + return errors::InvalidArgument( + "Shape of a default must be a length-0 or length-1 vector, or a " + "scalar."); + } + } + return shape_inference::ScalarShape(c); + }); + +REGISTER_OP("ExperimentalIgnoreErrorsDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("ExperimentalUniqueDataset") + .Input("input_dataset: variant") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("ExperimentalIteratorGetDevice") + .Input("resource: resource") + .Output("device: string") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("ExperimentalFunctionBufferingResource") + .Input("string_arg: string") + .Input("target_device: string") + .Output("resource: resource") + .Attr("shared_name: string") + .Attr("container: string") + .Attr("f: func") + .Attr("buffer_size: int") + .Attr("output_types: list(type)") + 
.SetShapeFn(shape_inference::UnknownShape); + +REGISTER_OP("ExperimentalFunctionBufferingResourceGetNext") + .Input("function_buffer_resource: resource") + .Attr("output_types: list(type)") + .Output("output: output_types") + .SetShapeFn(shape_inference::UnknownShape); + +REGISTER_OP("ExperimentalFunctionBufferingResourceReset") + .Input("function_buffer_resource: resource") + .SetShapeFn(shape_inference::UnknownShape); + +REGISTER_OP("ExperimentalThreadPoolDataset") + .Input("input_dataset: variant") + .Input("thread_pool: resource") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("ExperimentalThreadPoolHandle") + .Output("handle: resource") + .SetShapeFn(shape_inference::ScalarShape) + .Attr("num_threads: int") + .Attr("max_intra_op_parallelism: int = 1") + .Attr("display_name: string") + .Attr("container: string = ''") + .Attr("shared_name: string = ''"); + +REGISTER_OP("ExperimentalAssertNextDataset") + .Input("input_dataset: variant") + .Input("transformations: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // transformations should be a vector. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); + return shape_inference::ScalarShape(c); + }); + +REGISTER_OP("ExperimentalLMDBDataset") + .Input("filenames: string") + .Output("handle: variant") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked + // stateful to inhibit constant folding. 
+ .SetShapeFn(shape_inference::ScalarShape); + +REGISTER_OP("ExperimentalIdentityIndexedDataset") + .Input("size: uint64") + .Output("handle: variant") + .SetIsStateful() + .SetShapeFn( + shape_inference::ScalarShape); // TODO(saeta): check input shapes. + +/////////////////////////////////////////////////////////////////////////////// +// IndexedDataset Internals +/////////////////////////////////////////////////////////////////////////////// + +// Creates the handle. +REGISTER_OP("ExperimentalMaterializedIndexDatasetHandle") + .Output("handle: resource") + .Attr("container: string") + .Attr("shared_name: string") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(shape_inference::ScalarShape); + +// Actually materialize the materialize handle. +REGISTER_OP("ExperimentalIndexedDatasetMaterialize") + .Input("dataset: variant") + .Input("materialized: resource") + .SetShapeFn(shape_inference::NoOutputs); + +namespace { + +Status GetShapeFn(shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused)); + std::vector<PartialTensorShape> output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as `output_types` (", + output_shapes.size(), " vs. 
", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast<int>(i), output_shape_handle); + } + return Status::OK(); +} + +} // namespace + +REGISTER_OP("ExperimentalIndexedDatasetGet") + .Input("materialized: resource") + .Input("index: uint64") + .Output("components: output_types") + .Attr("output_types: list(type) >= 1") + .Attr("output_shapes: list(shape) >= 1") + .SetShapeFn(GetShapeFn); + +} // namespace tensorflow diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc index bda4a75c5d..22b4b07eff 100644 --- a/tensorflow/core/ops/functional_ops.cc +++ b/tensorflow/core/ops/functional_ops.cc @@ -110,8 +110,27 @@ REGISTER_OP("If") .Attr("Tout: list(type) >= 0") .Attr("then_branch: func") .Attr("else_branch: func") + .Attr("output_shapes: list(shape) = []") .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape); + .SetShapeFn([](shape_inference::InferenceContext* c) { + std::vector<PartialTensorShape> output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + // If `output_shapes` attr is set use that as the shapes of the outputs + // else return unknown shapes. + if (output_shapes.empty()) return shape_inference::UnknownShape(c); + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as num outputs (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast<int>(i), output_shape_handle); + } + return Status::OK(); + }); // TODO(drpng): remove this. 
REGISTER_OP("_While") @@ -150,10 +169,29 @@ REGISTER_OP("While") .Attr("T: list(type) >= 0") .Attr("cond: func") .Attr("body: func") + .Attr("output_shapes: list(shape) = []") .SetIsStateful() .SetShapeFn([](shape_inference::InferenceContext* c) { - for (int i = 0; i < c->num_outputs(); ++i) { - c->set_output(i, c->input(i)); + std::vector<PartialTensorShape> output_shapes; + TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes)); + // If `output_shapes` attr is set use that as the shapes of the outputs + // else use the input shapes. + if (!output_shapes.empty()) { + if (output_shapes.size() != c->num_outputs()) { + return errors::InvalidArgument( + "`output_shapes` must be the same length as num outputs (", + output_shapes.size(), " vs. ", c->num_outputs()); + } + for (size_t i = 0; i < output_shapes.size(); ++i) { + shape_inference::ShapeHandle output_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape( + output_shapes[i], &output_shape_handle)); + c->set_output(static_cast<int>(i), output_shape_handle); + } + } else { + for (int i = 0; i < c->num_outputs(); ++i) { + c->set_output(i, c->input(i)); + } } return Status::OK(); }); diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc index 07f876cb90..55dcc50325 100644 --- a/tensorflow/core/ops/math_grad.cc +++ b/tensorflow/core/ops/math_grad.cc @@ -549,6 +549,40 @@ Status PowGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("Pow", PowGrad); +Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForBinaryCwise(g, { + {{"zeros"}, "ZerosLike", {"x"}}, + {{"is_x_zero"}, "NotEqual", {"x", "zeros"}}, + {{"is_zero_cast"}, "Cast", {"is_x_zero"}, + {{"SrcT", DT_BOOL}, {"DstT", "$T"}}}, + {{"safe_logy"}, "Xlogy", {"is_zero_cast", "y"}}, + {{"xlogygrad"}, "Xdivy", {"x", "y"}}, + {{"gx"}, "Mul", {"safe_logy", "dz"}}, + {{"gy"}, "Mul", {"xlogygrad", "dz"}}, + }); + // clang-format on +} 
+REGISTER_OP_GRADIENT("Xlogy", XlogyGrad); + +Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) { + // clang-format off + return GradForBinaryCwise(g, { + {{"zeros"}, "ZerosLike", {"x"}}, + {{"is_x_zero"}, "NotEqual", {"x", "zeros"}}, + {{"is_zero_cast"}, "Cast", {"is_x_zero"}, + {{"SrcT", DT_BOOL}, {"DstT", "$T"}}}, + {{"safe_divy"}, "Xdivy", {"is_zero_cast", "y"}}, + {{"y2"}, "Square", {"y"}}, + {{"negy2"}, "Neg", {"y2"}}, + {{"xdivygrad"}, "Xdivy", {"x", "negy2"}}, + {{"gx"}, "Mul", {"safe_divy", "dz"}}, + {{"gy"}, "Mul", {"xdivygrad", "dz"}}, + }); + // clang-format on +} +REGISTER_OP_GRADIENT("Xdivy", XdivyGrad); + Status MaximumMinimumGradHelper(const string& comparator, const AttrSlice& attrs, FunctionDef* g) { // clang-format off diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc index 5ee79809ac..9fc6b34147 100644 --- a/tensorflow/core/ops/math_grad_test.cc +++ b/tensorflow/core/ops/math_grad_test.cc @@ -909,6 +909,46 @@ TEST_F(MathGradTest, ComplexPow) { } #endif // TENSORFLOW_USE_SYCL +TEST_F(MathGradTest, Xlogy) { + auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f}, + TensorShape({2, 3})); + auto y = test::AsTensor<float>({.5f, 2.f}, TensorShape({2, 1})); + Tensor dx; + Tensor dy; + auto g = [](float x, float y) -> float { return x == 0. ? 0. : std::log(y); }; + auto h = [](float x, float y) -> float { return x == 0. ? 0. 
: x / y; }; + SymGrad("Xlogy", x, y, &dx, &dy); + test::ExpectClose( + dx, test::AsTensor<float>({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f), + g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)}, + TensorShape({2, 3}))); + test::ExpectClose( + dy, test::AsTensor<float>({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f), + h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)}, + TensorShape({2, 1}))); +} + +TEST_F(MathGradTest, Xdivy) { + auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f}, + TensorShape({2, 3})); + auto y = test::AsTensor<float>({.5f, 2.f}, TensorShape({2, 1})); + Tensor dx; + Tensor dy; + auto g = [](float x, float y) -> float { return x == 0. ? 0. : 1 / y; }; + auto h = [](float x, float y) -> float { + return x == 0. ? 0. : -x / (y * y); + }; + SymGrad("Xdivy", x, y, &dx, &dy); + test::ExpectClose( + dx, test::AsTensor<float>({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f), + g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)}, + TensorShape({2, 3}))); + test::ExpectClose( + dy, test::AsTensor<float>({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f), + h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)}, + TensorShape({2, 1}))); +} + TEST_F(MathGradTest, Maximum) { auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f}, TensorShape({2, 3})); diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc index 717263a9b0..a9e5e7824d 100644 --- a/tensorflow/core/ops/math_ops.cc +++ b/tensorflow/core/ops/math_ops.cc @@ -429,6 +429,20 @@ Returns (x - y)(x - y) element-wise. 
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) )doc"); +REGISTER_OP("Xlogy") + .Input("x: T") + .Input("y: T") + .Output("z: T") + .Attr("T: {half, float, double, complex64, complex128}") + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); + +REGISTER_OP("Xdivy") + .Input("x: T") + .Input("y: T") + .Output("z: T") + .Attr("T: {half, float, double, complex64, complex128}") + .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn); + #undef BINARY_FEWER #undef BINARY_MORE @@ -1423,7 +1437,24 @@ REGISTER_OP("Bincount") .Attr("T: {int32, int64, float32, float64}") .Output("bins: T") .SetShapeFn([](InferenceContext* c) { - c->set_output(0, c->UnknownShapeOfRank(1)); + ShapeHandle unused; + // The input `size` must be a scalar. + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + + const Tensor* size_tensor = c->input_tensor(1); + if (size_tensor == nullptr) { + // Return unknown shape if size is not known. + c->set_output(0, c->UnknownShapeOfRank(1)); + return Status::OK(); + } + + // Return `[size]` shape if size is known. + int32 size_val = size_tensor->scalar<int32>()(); + if (size_val < 0) { + return errors::InvalidArgument("size (", size_val, + ") must be non-negative"); + } + c->set_output(0, c->MakeShape({size_val})); return Status::OK(); }); diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc index be4c3ed2b6..05379a7d69 100644 --- a/tensorflow/core/ops/math_ops_test.cc +++ b/tensorflow/core/ops/math_ops_test.cc @@ -559,4 +559,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) { INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?"); INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]"); } + +TEST(MathOpsTest, Bincount_ShapeFn) { + ShapeInferenceTestOp op("Bincount"); + + // size should be scalar. 
+ INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?"); + + INFER_OK(op, "?;?;?", "[?]"); + INFER_OK(op, "?;[];?", "[?]"); + INFER_OK(op, "[?];[];?", "[?]"); + INFER_OK(op, "[?];[];[?]", "[?]"); +} } // end namespace tensorflow diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 6191a88e5b..a9ca69ad86 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -178,7 +178,7 @@ REGISTER_OP("FusedBatchNorm") .Output("reserve_space_2: T") .Attr("T: {float}") .Attr("epsilon: float = 0.0001") - .Attr("data_format: string = 'NHWC'") + .Attr(GetConvnetDataFormatAttrString()) .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormShape); @@ -196,7 +196,7 @@ REGISTER_OP("FusedBatchNormV2") .Attr("T: {half, bfloat16, float}") .Attr("U: {float}") .Attr("epsilon: float = 0.0001") - .Attr("data_format: string = 'NHWC'") + .Attr(GetConvnetDataFormatAttrString()) .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormShape); @@ -213,7 +213,7 @@ REGISTER_OP("FusedBatchNormGrad") .Output("reserve_space_4: T") .Attr("T: {float}") .Attr("epsilon: float = 0.0001") - .Attr("data_format: string = 'NHWC'") + .Attr(GetConvnetDataFormatAttrString()) .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormGradShape); @@ -231,7 +231,7 @@ REGISTER_OP("FusedBatchNormGradV2") .Attr("T: {half, bfloat16, float}") .Attr("U: {float}") .Attr("epsilon: float = 0.0001") - .Attr("data_format: string = 'NHWC'") + .Attr(GetConvnetDataFormatAttrString()) .Attr("is_training: bool = true") .SetShapeFn(shape_inference::FusedBatchNormGradShape); diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 3ae4f1a59e..2048ad26ac 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -10039,6 +10039,421 @@ op { } } op { + name: "ExperimentalAssertNextDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + 
name: "transformations" + type: DT_STRING + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { + name: "ExperimentalCSVDataset" + input_arg { + name: "filenames" + type: DT_STRING + } + input_arg { + name: "compression_type" + type: DT_STRING + } + input_arg { + name: "buffer_size" + type: DT_INT64 + } + input_arg { + name: "header" + type: DT_BOOL + } + input_arg { + name: "field_delim" + type: DT_STRING + } + input_arg { + name: "use_quote_delim" + type: DT_BOOL + } + input_arg { + name: "na_value" + type: DT_STRING + } + input_arg { + name: "select_cols" + type: DT_INT64 + } + input_arg { + name: "record_defaults" + type_list_attr: "output_types" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_STRING + } + } + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "ExperimentalDirectedInterleaveDataset" + input_arg { + name: "selector_input_dataset" + type: DT_VARIANT + } + input_arg { + name: "data_input_datasets" + type: DT_VARIANT + number_attr: "N" + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "N" + type: "int" + has_minimum: true + minimum: 1 + } +} +op { + name: "ExperimentalFunctionBufferingResource" + input_arg { + name: "string_arg" + type: DT_STRING + } + input_arg { + name: "target_device" + type: DT_STRING + } + output_arg { + name: "resource" + type: DT_RESOURCE + } + attr 
{ + name: "shared_name" + type: "string" + } + attr { + name: "container" + type: "string" + } + attr { + name: "f" + type: "func" + } + attr { + name: "buffer_size" + type: "int" + } + attr { + name: "output_types" + type: "list(type)" + } + is_stateful: true +} +op { + name: "ExperimentalFunctionBufferingResourceGetNext" + input_arg { + name: "function_buffer_resource" + type: DT_RESOURCE + } + output_arg { + name: "output" + type_list_attr: "output_types" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "ExperimentalFunctionBufferingResourceReset" + input_arg { + name: "function_buffer_resource" + type: DT_RESOURCE + } + is_stateful: true +} +op { + name: "ExperimentalIdentityIndexedDataset" + input_arg { + name: "size" + type: DT_UINT64 + } + output_arg { + name: "handle" + type: DT_VARIANT + } + is_stateful: true +} +op { + name: "ExperimentalIgnoreErrorsDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { + name: "ExperimentalIndexedDatasetGet" + input_arg { + name: "materialized" + type: DT_RESOURCE + } + input_arg { + name: "index" + type: DT_UINT64 + } + output_arg { + name: "components" + type_list_attr: "output_types" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "ExperimentalIndexedDatasetMaterialize" + input_arg { + name: "dataset" + type: DT_VARIANT + } + input_arg { + name: "materialized" + type: DT_RESOURCE + } + is_stateful: true +} +op { + name: "ExperimentalIteratorGetDevice" + input_arg { + name: "resource" + type: DT_RESOURCE + } + 
output_arg { + name: "device" + type: DT_STRING + } + is_stateful: true +} +op { + name: "ExperimentalLMDBDataset" + input_arg { + name: "filenames" + type: DT_STRING + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "ExperimentalMaterializedIndexDatasetHandle" + output_arg { + name: "handle" + type: DT_RESOURCE + } + attr { + name: "container" + type: "string" + } + attr { + name: "shared_name" + type: "string" + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "ExperimentalThreadPoolDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "thread_pool" + type: DT_RESOURCE + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + is_stateful: true +} +op { + name: "ExperimentalThreadPoolHandle" + output_arg { + name: "handle" + type: DT_RESOURCE + } + attr { + name: "num_threads" + type: "int" + } + attr { + name: "max_intra_op_parallelism" + type: "int" + default_value { + i: 1 + } + } + attr { + name: "display_name" + type: "string" + } + attr { + name: "container" + type: "string" + default_value { + s: "" + } + } + attr { + name: "shared_name" + type: "string" + default_value { + s: "" + } + } + is_stateful: true +} +op { + name: "ExperimentalUniqueDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + output_arg { + name: "handle" + type: DT_VARIANT + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + 
minimum: 1 + } + attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } +} +op { name: "Expm1" input_arg { name: "x" @@ -11459,6 +11874,12 @@ op { default_value { s: "NHWC" } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } } attr { name: "is_training" @@ -11532,6 +11953,12 @@ op { default_value { s: "NHWC" } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } } attr { name: "is_training" @@ -11616,6 +12043,12 @@ op { default_value { s: "NHWC" } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } } attr { name: "is_training" @@ -11700,6 +12133,12 @@ op { default_value { s: "NHWC" } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } } attr { name: "is_training" @@ -12737,6 +13176,14 @@ op { name: "else_branch" type: "func" } + attr { + name: "output_shapes" + type: "list(shape)" + default_value { + list { + } + } + } is_stateful: true } op { @@ -14883,6 +15330,10 @@ op { name: "arguments" type_list_attr: "Targuments" } + input_arg { + name: "captured_inputs" + type_list_attr: "Tcaptured" + } output_arg { name: "output" type_list_attr: "output_types" @@ -14894,6 +15345,15 @@ op { minimum: 1 } attr { + name: "Tcaptured" + type: "list(type)" + default_value { + list { + } + } + has_minimum: true + } + attr { name: "output_types" type: "list(type)" has_minimum: true @@ -22913,6 +23373,59 @@ op { is_stateful: true } op { + name: "ReduceDataset" + input_arg { + name: "input_dataset" + type: DT_VARIANT + } + input_arg { + name: "initial_state" + type_list_attr: "Tstate" + } + input_arg { + name: "other_arguments" + type_list_attr: "Targuments" + } + output_arg { + name: "components" + type_list_attr: "output_types" + } + attr { + name: "f" + type: "func" + } + attr { + name: "Tstate" + type: "list(type)" + has_minimum: true + minimum: 1 + } + attr { + name: "Targuments" + type: "list(type)" + has_minimum: true + } + attr { + name: "output_types" + type: "list(type)" + has_minimum: true + minimum: 1 + } 
+ attr { + name: "output_shapes" + type: "list(shape)" + has_minimum: true + minimum: 1 + } + attr { + name: "use_inter_op_parallelism" + type: "bool" + default_value { + b: true + } + } +} +op { name: "ReduceJoin" input_arg { name: "inputs" @@ -28171,6 +28684,14 @@ op { name: "stats_aggregator" type: DT_RESOURCE } + input_arg { + name: "tag" + type: DT_STRING + } + input_arg { + name: "counter_prefix" + type: DT_STRING + } output_arg { name: "handle" type: DT_VARIANT @@ -32525,6 +33046,7 @@ op { allowed_values { list { type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE } @@ -32580,6 +33102,7 @@ op { allowed_values { list { type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE } @@ -32613,6 +33136,62 @@ op { } } op { + name: "StatelessRandomUniformInt" + input_arg { + name: "shape" + type_attr: "T" + } + input_arg { + name: "seed" + type_attr: "Tseed" + } + input_arg { + name: "minval" + type_attr: "dtype" + } + input_arg { + name: "maxval" + type_attr: "dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "Tseed" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } +} +op { name: "StatelessTruncatedNormal" input_arg { name: "shape" @@ -32635,6 +33214,7 @@ op { allowed_values { list { type: DT_HALF + type: DT_BFLOAT16 type: DT_FLOAT type: DT_DOUBLE } @@ -33308,6 +33888,19 @@ op { } } } + attr { + name: "unit" + type: "string" + default_value { + s: "BYTE" + } + allowed_values { + list { + s: "BYTE" + s: "UTF8_CHAR" + } + } + } } op { name: "Sum" @@ -35693,6 +36286,17 @@ op { } } op { + name: "UnicodeScript" + input_arg { + name: "input" + type: DT_INT32 + } + output_arg { + name: "output" + type: DT_INT32 + } +} +op 
{ name: "UniformCandidateSampler" input_arg { name: "true_classes" @@ -36500,6 +37104,14 @@ op { name: "body" type: "func" } + attr { + name: "output_shapes" + type: "list(shape)" + default_value { + list { + } + } + } is_stateful: true } op { @@ -36805,6 +37417,62 @@ op { is_stateful: true } op { + name: "Xdivy" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} +op { + name: "Xlogy" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } +} +op { name: "ZerosLike" input_arg { name: "x" diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index 26499540f1..adc9cd1486 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -19,6 +19,7 @@ #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/lib/core/errors.h" using ::tensorflow::shape_inference::InferenceContext; using ::tensorflow::shape_inference::ShapeAndType; @@ -56,6 +57,36 @@ Status ReadVariableShapeFn(InferenceContext* c) { return Status::OK(); } +Status ReadVariablesShapeFn(InferenceContext* c) { + int n; + TF_RETURN_IF_ERROR(c->GetAttr("N", &n)); + DataTypeVector value_dtypes; + TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &value_dtypes)); + if (n != value_dtypes.size()) { + return errors::InvalidArgument( + "Mismatched number of arguments to ReadVariablesOp"); + } + for (int 
i = 0; i < n; ++i) { + ShapeAndType shape_and_type; + auto* handle_data = c->input_handle_shapes_and_types(i); + if (handle_data == nullptr || handle_data->empty()) { + shape_and_type.shape = c->UnknownShape(); + shape_and_type.dtype = DT_INVALID; + } else { + shape_and_type = (*handle_data)[0]; + if (shape_and_type.dtype != value_dtypes[i]) { + return errors::InvalidArgument( + "Trying to read variable with wrong dtype. " + "Expected ", + DataTypeString(shape_and_type.dtype), " got ", + DataTypeString(value_dtypes[i])); + } + } + c->set_output(i, shape_and_type.shape); + } + return Status::OK(); +} + } // namespace REGISTER_OP("VarHandleOp") @@ -79,12 +110,53 @@ REGISTER_OP("VarHandleOp") return Status::OK(); }); +REGISTER_OP("_VarHandlesOp") + .Attr("containers: list(string)") + .Attr("shared_names: list(string)") + .Attr("N: int >= 0") + .Attr("dtypes: list(type)") + .Attr("shapes: list(shape)") + .Output("resources: N * resource") + .SetIsStateful() + .SetShapeFn([](InferenceContext* c) { + int n; + TF_RETURN_IF_ERROR(c->GetAttr("N", &n)); + DataTypeVector dtypes; + TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes)); + std::vector<PartialTensorShape> shapes; + TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes)); + if (dtypes.size() != n) { + return errors::InvalidArgument("Mismatched number of dtypes (n=", n, + ", num dtypes=", dtypes.size(), ")"); + } + if (shapes.size() != n) { + return errors::InvalidArgument("Mismatched number of shapes (n=", n, + ", num shapes=", shapes.size(), ")"); + } + for (int i = 0; i < n; ++i) { + c->set_output(i, c->Scalar()); + ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &s)); + c->set_output_handle_shapes_and_types( + i, std::vector<ShapeAndType>{{s, dtypes[i]}}); + } + + return Status::OK(); + }); + REGISTER_OP("ReadVariableOp") .Input("resource: resource") .Output("value: dtype") .Attr("dtype: type") .SetShapeFn(ReadVariableShapeFn); +REGISTER_OP("_ReadVariablesOp") + .Attr("N: int >= 0") + 
.Input("resources: N * resource") + .Output("values: dtypes") + .Attr("dtypes: list(type)") + .SetShapeFn(ReadVariablesShapeFn); + Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FunctionDefHelper::Define( diff --git a/tensorflow/core/ops/stateless_random_grad.cc b/tensorflow/core/ops/stateless_random_grad.cc new file mode 100644 index 0000000000..331e1d0152 --- /dev/null +++ b/tensorflow/core/ops/stateless_random_grad.cc @@ -0,0 +1,23 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/function.h" + +namespace tensorflow { +REGISTER_OP_NO_GRADIENT("StatelessRandomUniform"); +REGISTER_OP_NO_GRADIENT("StatelessRandomNormal"); +REGISTER_OP_NO_GRADIENT("StatelessTruncatedNormal"); +REGISTER_OP_NO_GRADIENT("StatelessMultinomial"); +} // end namespace tensorflow diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc index 742709fb18..f919a21d60 100644 --- a/tensorflow/core/ops/stateless_random_ops.cc +++ b/tensorflow/core/ops/stateless_random_ops.cc @@ -19,42 +19,55 @@ limitations under the License. 
namespace tensorflow { using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -static Status StatelessShape(shape_inference::InferenceContext* context) { +static Status StatelessShape(InferenceContext* c) { // Check seed shape ShapeHandle seed; - TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 1, &seed)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed)); DimensionHandle unused; - TF_RETURN_IF_ERROR(context->WithValue(context->Dim(seed, 0), 2, &unused)); + TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused)); // Set output shape ShapeHandle out; - TF_RETURN_IF_ERROR(context->MakeShapeFromShapeTensor(0, &out)); - context->set_output(0, out); + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out)); + c->set_output(0, out); return Status::OK(); } -#define REGISTER_STATELESS_OP(name) \ - REGISTER_OP(name) \ - .Input("shape: T") \ - .Input("seed: Tseed") \ - .Output("output: dtype") \ - .Attr("dtype: {half,float,double} = DT_FLOAT") \ - .Attr("T: {int32, int64} = DT_INT32") \ - .Attr("Tseed: {int32, int64} = DT_INT64") \ +#define REGISTER_STATELESS_OP(name) \ + REGISTER_OP(name) \ + .Input("shape: T") \ + .Input("seed: Tseed") \ + .Output("output: dtype") \ + .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \ + .Attr("T: {int32, int64} = DT_INT32") \ + .Attr("Tseed: {int32, int64} = DT_INT64") \ .SetShapeFn(StatelessShape) -// This op is exposed through contrib/stateless only. The interface may change. REGISTER_STATELESS_OP("StatelessRandomUniform"); - -// This op is exposed through contrib/stateless only. The interface may change. REGISTER_STATELESS_OP("StatelessRandomNormal"); - -// This op is exposed through contrib/stateless only. The interface may change. REGISTER_STATELESS_OP("StatelessTruncatedNormal"); -// This op is exposed through contrib/stateless only. The interface may change. 
+#undef REGISTER_STATELESS_OP + +REGISTER_OP("StatelessRandomUniformInt") + .Input("shape: T") + .Input("seed: Tseed") + .Input("minval: dtype") + .Input("maxval: dtype") + .Output("output: dtype") + .Attr("dtype: {int32, int64}") + .Attr("T: {int32, int64}") + .Attr("Tseed: {int32, int64} = DT_INT64") + .SetShapeFn([](InferenceContext* c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused)); + return StatelessShape(c); + }); + REGISTER_OP("StatelessMultinomial") .Input("logits: T") .Input("num_samples: int32") @@ -80,6 +93,4 @@ REGISTER_OP("StatelessMultinomial") return Status::OK(); }); -#undef REGISTER_STATELESS_OP - } // namespace tensorflow diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc index da1d2a6432..94d71a4113 100644 --- a/tensorflow/core/ops/string_ops.cc +++ b/tensorflow/core/ops/string_ops.cc @@ -223,6 +223,7 @@ REGISTER_OP("Substr") .Input("len: T") .Output("output: string") .Attr("T: {int32, int64}") + .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'") .SetShapeFn([](InferenceContext* c) { ShapeHandle pos_shape = c->input(1); ShapeHandle len_shape = c->input(2); @@ -244,4 +245,9 @@ REGISTER_OP("Substr") return shape_inference::BroadcastBinaryOpShapeFn(c); }); +REGISTER_OP("UnicodeScript") + .Input("input: int32") + .Output("output: int32") + .SetShapeFn(shape_inference::UnchangedShape); + } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc b/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc index f41b83ac34..affb68ebbb 100644 --- a/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc +++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc @@ -17,7 +17,6 @@ limitations under the License. 
#include <utility> #include "tensorflow/core/platform/cloud/curl_http_request.h" -#include "tensorflow/core/platform/cloud/retrying_utils.h" namespace tensorflow { @@ -25,21 +24,14 @@ namespace { // The URL to retrieve metadata when running in Google Compute Engine. constexpr char kGceMetadataBaseUrl[] = "http://metadata/computeMetadata/v1/"; -// The default initial delay between retries with exponential backoff. -constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec } // namespace ComputeEngineMetadataClient::ComputeEngineMetadataClient( - std::shared_ptr<HttpRequest::Factory> http_request_factory) - : ComputeEngineMetadataClient(std::move(http_request_factory), - kInitialRetryDelayUsec) {} - -ComputeEngineMetadataClient::ComputeEngineMetadataClient( std::shared_ptr<HttpRequest::Factory> http_request_factory, - int64 initial_retry_delay_usec) + const RetryConfig& config) : http_request_factory_(std::move(http_request_factory)), - initial_retry_delay_usec_(initial_retry_delay_usec) {} + retry_config_(config) {} Status ComputeEngineMetadataClient::GetMetadata( const string& path, std::vector<char>* response_buffer) { @@ -52,8 +44,7 @@ Status ComputeEngineMetadataClient::GetMetadata( return Status::OK(); }; - return RetryingUtils::CallWithRetries(get_metadata_from_gce, - initial_retry_delay_usec_); + return RetryingUtils::CallWithRetries(get_metadata_from_gce, retry_config_); } } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h index 534ccf30b2..7f060327da 100644 --- a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h +++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h @@ -18,6 +18,7 @@ limitations under the License. 
#include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/cloud/http_request.h" +#include "tensorflow/core/platform/cloud/retrying_utils.h" namespace tensorflow { @@ -31,10 +32,11 @@ namespace tensorflow { class ComputeEngineMetadataClient { public: explicit ComputeEngineMetadataClient( - std::shared_ptr<HttpRequest::Factory> http_request_factory); - ComputeEngineMetadataClient( std::shared_ptr<HttpRequest::Factory> http_request_factory, - int64 initial_retry_delay_usec); + const RetryConfig& config = RetryConfig( + 10000, /* init_delay_time_us = 1 ms */ + 1000000 /* max_delay_time_us = 1 s */ + )); virtual ~ComputeEngineMetadataClient() {} /// \brief Get the metadata value for a given attribute of the metadata @@ -54,7 +56,7 @@ class ComputeEngineMetadataClient { private: std::shared_ptr<HttpRequest::Factory> http_request_factory_; - const int64 initial_retry_delay_usec_; + const RetryConfig retry_config_; TF_DISALLOW_COPY_AND_ASSIGN(ComputeEngineMetadataClient); }; diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc b/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc index 4c41ccaa0e..e891b4a5e9 100644 --- a/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc +++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc @@ -30,7 +30,8 @@ TEST(ComputeEngineMetadataClientTest, GetMetadata) { std::shared_ptr<HttpRequest::Factory> http_factory = std::make_shared<FakeHttpRequestFactory>(&requests); - ComputeEngineMetadataClient client(http_factory, 0); + ComputeEngineMetadataClient client(http_factory, + RetryConfig(0 /* init_delay_time_us */)); std::vector<char> result; TF_EXPECT_OK( @@ -56,7 +57,8 @@ TEST(ComputeEngineMetadataClientTest, RetryOnFailure) { std::shared_ptr<HttpRequest::Factory> http_factory = std::make_shared<FakeHttpRequestFactory>(&requests); - ComputeEngineMetadataClient client(http_factory, 0); + ComputeEngineMetadataClient client(http_factory, 
+ RetryConfig(0 /* init_delay_time_us */)); std::vector<char> result; TF_EXPECT_OK( diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc b/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc index f7477eca23..476e4f9c1f 100644 --- a/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc +++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc @@ -34,8 +34,8 @@ TEST_F(ComputeEngineZoneProviderTest, GetZone) { auto httpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests); - auto metadata_client = - std::make_shared<ComputeEngineMetadataClient>(httpRequestFactory, 0); + auto metadata_client = std::make_shared<ComputeEngineMetadataClient>( + httpRequestFactory, RetryConfig(0 /* init_delay_time_us */)); ComputeEngineZoneProvider provider(metadata_client); @@ -55,8 +55,8 @@ TEST_F(ComputeEngineZoneProviderTest, InvalidZoneString) { auto httpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests); - auto metadata_client = - std::make_shared<ComputeEngineMetadataClient>(httpRequestFactory, 0); + auto metadata_client = std::make_shared<ComputeEngineMetadataClient>( + httpRequestFactory, RetryConfig(0 /* init_delay_time_us */)); ComputeEngineZoneProvider provider(metadata_client); diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc index 83ea8539ed..c61b68aeeb 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system.cc @@ -333,14 +333,14 @@ class GcsWritableFile : public WritableFile { GcsFileSystem* filesystem, GcsFileSystem::TimeoutConfig* timeouts, std::function<void()> file_cache_erase, - int64 initial_retry_delay_usec) + RetryConfig retry_config) : bucket_(bucket), object_(object), filesystem_(filesystem), timeouts_(timeouts), file_cache_erase_(std::move(file_cache_erase)), sync_needed_(true), - initial_retry_delay_usec_(initial_retry_delay_usec) { 
+ retry_config_(retry_config) { // TODO: to make it safer, outfile_ should be constructed from an FD if (GetTmpFilename(&tmp_content_filename_).ok()) { outfile_.open(tmp_content_filename_, @@ -357,14 +357,14 @@ class GcsWritableFile : public WritableFile { GcsFileSystem* filesystem, const string& tmp_content_filename, GcsFileSystem::TimeoutConfig* timeouts, std::function<void()> file_cache_erase, - int64 initial_retry_delay_usec) + RetryConfig retry_config) : bucket_(bucket), object_(object), filesystem_(filesystem), timeouts_(timeouts), file_cache_erase_(std::move(file_cache_erase)), sync_needed_(true), - initial_retry_delay_usec_(initial_retry_delay_usec) { + retry_config_(retry_config) { tmp_content_filename_ = tmp_content_filename; outfile_.open(tmp_content_filename_, std::ofstream::binary | std::ofstream::app); @@ -441,7 +441,7 @@ class GcsWritableFile : public WritableFile { first_attempt = false; return UploadToSession(session_uri, already_uploaded); }, - initial_retry_delay_usec_); + retry_config_); if (upload_status.code() == errors::Code::NOT_FOUND) { // GCS docs recommend retrying the whole upload. We're relying on the // RetryingFileSystem to retry the Sync() call. 
@@ -586,7 +586,7 @@ class GcsWritableFile : public WritableFile { GcsFileSystem::TimeoutConfig* timeouts_; std::function<void()> file_cache_erase_; bool sync_needed_; // whether there is buffered data that needs to be synced - int64 initial_retry_delay_usec_; + RetryConfig retry_config_; }; class GcsReadOnlyMemoryRegion : public ReadOnlyMemoryRegion { @@ -791,7 +791,7 @@ GcsFileSystem::GcsFileSystem( std::unique_ptr<ZoneProvider> zone_provider, size_t block_size, size_t max_bytes, uint64 max_staleness, uint64 stat_cache_max_age, size_t stat_cache_max_entries, uint64 matching_paths_cache_max_age, - size_t matching_paths_cache_max_entries, int64 initial_retry_delay_usec, + size_t matching_paths_cache_max_entries, RetryConfig retry_config, TimeoutConfig timeouts, const std::unordered_set<string>& allowed_locations, std::pair<const string, const string>* additional_header) : auth_provider_(std::move(auth_provider)), @@ -806,7 +806,7 @@ GcsFileSystem::GcsFileSystem( kCacheNeverExpire, kBucketLocationCacheMaxEntries)), allowed_locations_(allowed_locations), timeouts_(timeouts), - initial_retry_delay_usec_(initial_retry_delay_usec), + retry_config_(retry_config), additional_header_(additional_header) {} Status GcsFileSystem::NewRandomAccessFile( @@ -941,7 +941,7 @@ Status GcsFileSystem::NewWritableFile(const string& fname, TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); result->reset(new GcsWritableFile(bucket, object, this, &timeouts_, [this, fname]() { ClearFileCaches(fname); }, - initial_retry_delay_usec_)); + retry_config_)); return Status::OK(); } @@ -981,7 +981,7 @@ Status GcsFileSystem::NewAppendableFile(const string& fname, TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object)); result->reset(new GcsWritableFile( bucket, object, this, old_content_filename, &timeouts_, - [this, fname]() { ClearFileCaches(fname); }, initial_retry_delay_usec_)); + [this, fname]() { ClearFileCaches(fname); }, retry_config_)); return Status::OK(); } @@ 
-1534,7 +1534,7 @@ Status GcsFileSystem::RenameObject(const string& src, const string& target) { // on the server side, we can't just retry the whole RenameFile operation // because the source object is already gone. return RetryingUtils::DeleteWithRetries( - [this, &src]() { return DeleteFile(src); }, initial_retry_delay_usec_); + [this, &src]() { return DeleteFile(src); }, retry_config_); } Status GcsFileSystem::IsDirectory(const string& fname) { @@ -1590,8 +1590,7 @@ Status GcsFileSystem::DeleteRecursively(const string& dirname, // and therefore RetryingFileSystem won't pay attention to the failures, // we need to make sure these failures are properly retried. const auto& delete_file_status = RetryingUtils::DeleteWithRetries( - [this, &full_path]() { return DeleteFile(full_path); }, - initial_retry_delay_usec_); + [this, &full_path]() { return DeleteFile(full_path); }, retry_config_); if (!delete_file_status.ok()) { if (IsDirectory(full_path).ok()) { // The object is a directory marker. diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h index 71db707687..d0840a3046 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system.h +++ b/tensorflow/core/platform/cloud/gcs_file_system.h @@ -93,7 +93,7 @@ class GcsFileSystem : public FileSystem { uint64 stat_cache_max_age, size_t stat_cache_max_entries, uint64 matching_paths_cache_max_age, size_t matching_paths_cache_max_entries, - int64 initial_retry_delay_usec, TimeoutConfig timeouts, + RetryConfig retry_config, TimeoutConfig timeouts, const std::unordered_set<string>& allowed_locations, std::pair<const string, const string>* additional_header); @@ -332,7 +332,7 @@ class GcsFileSystem : public FileSystem { GcsStatsInterface* stats_ = nullptr; // Not owned. /// The initial delay for exponential backoffs when retrying failed calls. 
- const int64 initial_retry_delay_usec_ = 1000000L; + RetryConfig retry_config_; // Additional header material to be transmitted with all GCS requests std::unique_ptr<std::pair<const string, const string>> additional_header_; @@ -344,7 +344,8 @@ class GcsFileSystem : public FileSystem { class RetryingGcsFileSystem : public RetryingFileSystem<GcsFileSystem> { public: RetryingGcsFileSystem() - : RetryingFileSystem(std::unique_ptr<GcsFileSystem>(new GcsFileSystem)) {} + : RetryingFileSystem(std::unique_ptr<GcsFileSystem>(new GcsFileSystem), + RetryConfig(100000 /* init_delay_time_us */)) {} }; } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc index 14376ad339..702802b185 100644 --- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc +++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc @@ -24,6 +24,8 @@ namespace tensorflow { namespace { static GcsFileSystem::TimeoutConfig kTestTimeoutConfig(5, 1, 10, 20, 30); +static RetryConfig kTestRetryConfig(0 /* init_delay_time_us */); + // Default (empty) constraint config static std::unordered_set<string>* kAllowedLocationsDefault = new std::unordered_set<string>(); @@ -62,16 +64,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) { "Range: 6-11\n" "Timeouts: 5 1 20\n", "6789")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + 
std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::unique_ptr<RandomAccessFile> file; TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file)); @@ -108,9 +110,9 @@ TEST(GcsFileSystemTest, 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig, - *kAllowedLocationsAuto, nullptr /* gcs additional header */); + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsAuto, + nullptr /* gcs additional header */); std::unique_ptr<RandomAccessFile> file; TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file)); @@ -150,9 +152,9 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) { 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig, - *kAllowedLocationsAuto, nullptr /* gcs additional header */); + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsAuto, + nullptr /* gcs additional header */); std::unique_ptr<RandomAccessFile> file; @@ -191,9 +193,9 @@ TEST(GcsFileSystemTest, 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max 
age */, - 0 /* matching paths cache max entries */, - 0 /* initial retry delay */, kTestTimeoutConfig, - *kAllowedLocationsAuto, nullptr /* gcs additional header */); + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsAuto, + nullptr /* gcs additional header */); std::unique_ptr<RandomAccessFile> file; EXPECT_EQ(tensorflow::errors::FailedPrecondition( @@ -216,16 +218,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache_DifferentN) { "Range: 3-12\n" "Timeouts: 5 1 20\n", "3456789")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::unique_ptr<RandomAccessFile> file; TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file)); @@ -283,7 +285,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache) { std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */, 18 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* 
matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */); @@ -372,7 +374,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) { std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */, 18 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */); @@ -414,17 +416,17 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) { "Range: 8-15\n" "Timeouts: 5 1 20\n", "89abcdef")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */, - 16 /* max bytes */, 3600 /* max staleness */, - 3600 /* stat cache max age */, 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 8 /* block size */, 16 /* max bytes */, + 3600 /* max staleness */, 3600 /* stat cache max age */, + 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional 
header */); char scratch[100]; StringPiece result; // There should only be two HTTP requests issued to GCS even though we iterate @@ -492,7 +494,7 @@ TEST(GcsFileSystemTest, std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */, 18 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */); @@ -513,17 +515,17 @@ TEST(GcsFileSystemTest, TEST(GcsFileSystemTest, NewRandomAccessFile_NoObjectName) { std::vector<HttpRequest*> requests; - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), - 0 /* read ahead bytes */, 0 /* max bytes */, 0 /* max staleness */, - 0 /* stat cache max age */, 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* read ahead bytes */, 0 /* max bytes */, + 0 /* max staleness */, 0 /* stat cache max age */, + 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::unique_ptr<RandomAccessFile> file; EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -547,16 +549,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_InconsistentRead) { 
"012")}); // Set stat_cache_max_age to 1000s so that StatCache could work. - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 1e3 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 1e3 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); // Stat the file first so that the file stats are cached. 
FileStatistics stat; @@ -621,7 +623,7 @@ TEST(GcsFileSystemTest, NewWritableFile) { std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */, 8 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */); @@ -703,16 +705,16 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceeds) { "Timeouts: 5 1 30\n" "Put body: t2\n", "")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::unique_ptr<WritableFile> file; TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file)); @@ -773,17 +775,17 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) { "Range: 0-7\n" "Timeouts: 5 1 20\n", "01234567")}); - GcsFileSystem fs( - 
std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */, - 8 /* max bytes */, 3600 /* max staleness */, - 3600 /* stat cache max age */, 0 /* stat cache max entries */, - 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 8 /* block size */, 8 /* max bytes */, + 3600 /* max staleness */, 3600 /* stat cache max age */, + 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); // Pull the file's first block into the cache. This will trigger the first // HTTP request to GCS. 
std::unique_ptr<RandomAccessFile> rfile; @@ -867,9 +869,9 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) { std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 2 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + 0 /* matching paths cache max entries */, + RetryConfig(2 /* .init_delay_time_us */), kTestTimeoutConfig, + *kAllowedLocationsDefault, nullptr /* gcs additional header */); std::unique_ptr<WritableFile> file; TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file)); @@ -918,16 +920,16 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) { "Timeouts: 5 1 30\n" "Put body: content1,content2\n", "")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header 
*/); std::unique_ptr<WritableFile> file; TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file)); @@ -948,16 +950,16 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) { TEST(GcsFileSystemTest, NewWritableFile_NoObjectName) { std::vector<HttpRequest*> requests; - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::unique_ptr<WritableFile> file; EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -1013,7 +1015,7 @@ TEST(GcsFileSystemTest, NewAppendableFile) { std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 32 /* block size */, 32 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */); @@ -1041,16 +1043,16 @@ TEST(GcsFileSystemTest, 
NewAppendableFile) { TEST(GcsFileSystemTest, NewAppendableFile_NoObjectName) { std::vector<HttpRequest*> requests; - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::unique_ptr<WritableFile> file; EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -1075,16 +1077,16 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { "Range: 0-", content.size() - 1, "\n", "Timeouts: 5 1 20\n"), content)}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem 
fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::unique_ptr<ReadOnlyMemoryRegion> region; TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile( @@ -1096,16 +1098,16 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) { TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile_NoObjectName) { std::vector<HttpRequest*> requests; - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::unique_ptr<ReadOnlyMemoryRegion> region; EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -1120,16 +1122,16 @@ TEST(GcsFileSystemTest, 
FileExists_YesAsObject) { "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.FileExists("gs://bucket/path/file1.txt")); } @@ -1150,16 +1152,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsFolder) { "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/subfolder/\" }]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem 
fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.FileExists("gs://bucket/path/subfolder")); } @@ -1176,16 +1178,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsBucket) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"size\": \"100\"}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.FileExists("gs://bucket1")); TF_EXPECT_OK(fs.FileExists("gs://bucket1/")); @@ -1206,16 +1208,16 @@ TEST(GcsFileSystemTest, FileExists_NotAsObjectOrFolder) { "Auth Token: fake_token\n" "Timeouts: 5 
1 10\n", "{\"items\": []}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); EXPECT_EQ(errors::Code::NOT_FOUND, fs.FileExists("gs://bucket/path/file1.txt").code()); @@ -1233,16 +1235,16 @@ TEST(GcsFileSystemTest, FileExists_NotAsBucket) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + 
new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); EXPECT_EQ(errors::Code::INVALID_ARGUMENT, fs.FileExists("gs://bucket2/").code()); EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -1279,7 +1281,7 @@ TEST(GcsFileSystemTest, FileExists_StatCache) { std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */); @@ -1306,7 +1308,7 @@ TEST(GcsFileSystemTest, FileExists_DirectoryMark) { std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */); @@ -1322,16 +1324,16 @@ TEST(GcsFileSystemTest, GetChildren_NoItems) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"prefixes\": [\"path/subpath/\"]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* 
stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::vector<string> children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children)); @@ -1350,16 +1352,16 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles) { " { \"name\": \"path/file1.txt\" }," " { \"name\": \"path/file3.txt\" }]," "\"prefixes\": [\"path/subpath/\"]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching 
paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::vector<string> children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children)); @@ -1379,16 +1381,16 @@ TEST(GcsFileSystemTest, GetChildren_SelfDirectoryMarker) { " { \"name\": \"path/\" }," " { \"name\": \"path/file3.txt\" }]," "\"prefixes\": [\"path/subpath/\"]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::vector<string> children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children)); @@ -1407,16 +1409,16 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) { " { \"name\": \"path/file1.txt\" }," " { \"name\": \"path/file3.txt\" }]," "\"prefixes\": [\"path/subpath/\"]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block 
size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::vector<string> children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children)); @@ -1432,16 +1434,16 @@ TEST(GcsFileSystemTest, GetChildren_Root) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache 
max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::vector<string> children; TF_EXPECT_OK(fs.GetChildren("gs://bucket-a-b-c", &children)); @@ -1457,16 +1459,16 @@ TEST(GcsFileSystemTest, GetChildren_Empty) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::vector<string> children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children)); @@ -1498,16 +1500,16 @@ TEST(GcsFileSystemTest, GetChildren_Pagination) { " { \"name\": \"path/file4.txt\" }," " { \"name\": \"path/file5.txt\" }]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* 
matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::vector<string> children; TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children)); @@ -1525,16 +1527,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_NoWildcard) { "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/subpath/file2.txt\" }]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, 
*kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::vector<string> result; TF_EXPECT_OK( @@ -1553,16 +1555,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_BucketAndWildcard) { " { \"name\": \"path/file1.txt\" }," " { \"name\": \"path/subpath/file2.txt\" }," " { \"name\": \"path/file3.txt\" }]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::vector<string> result; TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/*/*", &result)); @@ -1582,16 +1584,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_Matches) { " { \"name\": \"path/file1.txt\" }," " { \"name\": \"path/subpath/file2.txt\" }," " { \"name\": \"path/file3.txt\" }]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 
0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::vector<string> result; TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file2.txt", &result)); @@ -1608,16 +1610,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) { "{\"items\": [ " " { \"name\": \"path/\" }," " { \"name\": \"path/file3.txt\" }]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, 
kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::vector<string> result; TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*", &result)); @@ -1634,16 +1636,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) { " { \"name\": \"path/file1.txt\" }," " { \"name\": \"path/subpath/file2.txt\" }," " { \"name\": \"path/file3.txt\" }]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::vector<string> result; TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file3.txt", &result)); @@ -1652,16 +1654,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) { TEST(GcsFileSystemTest, GetMatchingPaths_OnlyWildcard) { std::vector<HttpRequest*> requests; - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* 
block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::vector<string> result; EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -1686,16 +1688,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache) { " { \"name\": \"path/file1.txt\" }," " { \"name\": \"path/subpath/file2.txt\" }," " { \"name\": \"path/file3.txt\" }]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 3600 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 3600 
/* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); // Repeated calls to fs.GetMatchingPaths on these patterns should not lead to // any additional HTTP requests to GCS. @@ -1729,16 +1731,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) { "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/subpath/file2.txt\" }]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 3600 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 3600 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); // This loop should trigger the first HTTP request to GCS. 
for (int i = 0; i < 10; i++) { @@ -1800,7 +1802,7 @@ TEST(GcsFileSystemTest, DeleteFile) { std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */, 16 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */); @@ -1821,16 +1823,16 @@ TEST(GcsFileSystemTest, DeleteFile) { TEST(GcsFileSystemTest, DeleteFile_NoObjectName) { std::vector<HttpRequest*> requests; - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); EXPECT_EQ(errors::Code::INVALID_ARGUMENT, fs.DeleteFile("gs://bucket/").code()); @@ -1871,7 +1873,7 @@ TEST(GcsFileSystemTest, DeleteFile_StatCacheRemoved) { std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */, 16 /* max bytes */, 0 /* 
max staleness */, 3600 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */); @@ -1894,16 +1896,16 @@ TEST(GcsFileSystemTest, DeleteDir_Empty) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/")); } @@ -1923,16 +1925,16 @@ TEST(GcsFileSystemTest, DeleteDir_OnlyDirMarkerLeft) { "Timeouts: 5 1 10\n" "Delete: yes\n", "")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max 
age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/")); } @@ -1943,16 +1945,16 @@ TEST(GcsFileSystemTest, DeleteDir_BucketOnly) { "name%2CnextPageToken&maxResults=2\nAuth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, 
*kAllowedLocationsDefault, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.DeleteDir("gs://bucket")); } @@ -1965,16 +1967,16 @@ TEST(GcsFileSystemTest, DeleteDir_NonEmpty) { "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"path/file1.txt\" }]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); EXPECT_EQ(error::Code::FAILED_PRECONDITION, fs.DeleteDir("gs://bucket/path/").code()); @@ -1988,16 +1990,16 @@ TEST(GcsFileSystemTest, GetFileSize) { "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* 
matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); uint64 size; TF_EXPECT_OK(fs.GetFileSize("gs://bucket/file.txt", &size)); @@ -2006,16 +2008,16 @@ TEST(GcsFileSystemTest, GetFileSize) { TEST(GcsFileSystemTest, GetFileSize_NoObjectName) { std::vector<HttpRequest*> requests; - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); uint64 size; 
EXPECT_EQ(errors::Code::INVALID_ARGUMENT, @@ -2092,16 +2094,16 @@ TEST(GcsFileSystemTest, RenameFile_Folder) { "Timeouts: 5 1 10\n" "Delete: yes\n", "")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.RenameFile("gs://bucket/path1", "gs://bucket/path2/")); } @@ -2191,7 +2193,7 @@ TEST(GcsFileSystemTest, RenameFile_Object) { std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */, 64 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */); // Do an initial read of the source and destination files to load their @@ -2272,7 +2274,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) { std::unique_ptr<ZoneProvider>(new 
FakeZoneProvider), 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */); // Do an initial stat of the destination file to load their contents into the @@ -2332,16 +2334,16 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) { "Timeouts: 5 1 10\n" "Delete: yes\n", "", errors::NotFound("404"), 404)}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); TF_EXPECT_OK( fs.RenameFile("gs://bucket/path/src.txt", "gs://bucket/path/dst.txt")); @@ -2374,16 +2376,16 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) { "Post: yes\n" "Timeouts: 5 1 10\n", "{\"done\": false}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new 
FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); EXPECT_EQ( errors::Code::UNIMPLEMENTED, @@ -2399,16 +2401,16 @@ TEST(GcsFileSystemTest, Stat_Object) { "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* 
block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); FileStatistics stat; TF_EXPECT_OK(fs.Stat("gs://bucket/file.txt", &stat)); @@ -2433,16 +2435,16 @@ TEST(GcsFileSystemTest, Stat_Folder) { "Timeouts: 5 1 10\n", "{\"items\": [ " " { \"name\": \"subfolder/\" }]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); FileStatistics stat; TF_EXPECT_OK(fs.Stat("gs://bucket/subfolder", &stat)); @@ -2466,16 +2468,16 @@ TEST(GcsFileSystemTest, Stat_ObjectOrFolderNotFound) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new 
FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); FileStatistics stat; EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/path", &stat).code()); @@ -2487,16 +2489,16 @@ TEST(GcsFileSystemTest, Stat_Bucket) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max 
age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); FileStatistics stat; TF_EXPECT_OK(fs.Stat("gs://bucket/", &stat)); @@ -2511,16 +2513,16 @@ TEST(GcsFileSystemTest, Stat_BucketNotFound) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); FileStatistics stat; EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/", &stat).code()); @@ -2556,7 +2558,7 @@ TEST(GcsFileSystemTest, Stat_Cache) { std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, 
nullptr /* gcs additional header */); @@ -2598,7 +2600,7 @@ TEST(GcsFileSystemTest, Stat_Cache_Flush) { std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, nullptr /* gcs additional header */); // There should be a single HTTP request to GCS for fs.Stat in this loop. @@ -2628,16 +2630,16 @@ TEST(GcsFileSystemTest, Stat_FilenameEndingWithSlash) { "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"5\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); FileStatistics stat; TF_EXPECT_OK(fs.Stat("gs://bucket/dir/", &stat)); @@ -2660,16 +2662,16 @@ 
TEST(GcsFileSystemTest, IsDirectory_NotFound) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); EXPECT_EQ(error::Code::NOT_FOUND, fs.IsDirectory("gs://bucket/file.txt").code()); @@ -2691,16 +2693,16 @@ TEST(GcsFileSystemTest, IsDirectory_NotDirectoryButObject) { "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - 
nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); EXPECT_EQ(error::Code::FAILED_PRECONDITION, fs.IsDirectory("gs://bucket/file.txt").code()); @@ -2722,16 +2724,16 @@ TEST(GcsFileSystemTest, IsDirectory_Yes) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{\"items\": [{\"name\": \"subfolder/\"}]}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder")); TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder/")); @@ 
-2749,16 +2751,16 @@ TEST(GcsFileSystemTest, IsDirectory_Bucket) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "{}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.IsDirectory("gs://bucket")); TF_EXPECT_OK(fs.IsDirectory("gs://bucket/")); @@ -2770,16 +2772,16 @@ TEST(GcsFileSystemTest, IsDirectory_BucketNotFound) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem 
fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); EXPECT_EQ(error::Code::NOT_FOUND, fs.IsDirectory("gs://bucket/").code()); } @@ -2812,16 +2814,16 @@ TEST(GcsFileSystemTest, CreateDir_Folder) { "Timeouts: 5 1 30\n" "Put body: \n", "")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath")); TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath/")); @@ -2839,16 +2841,16 @@ TEST(GcsFileSystemTest, CreateDir_Bucket) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "")}); - 
GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); TF_EXPECT_OK(fs.CreateDir("gs://bucket/")); TF_EXPECT_OK(fs.CreateDir("gs://bucket")); @@ -2911,16 +2913,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) { "Timeouts: 5 1 10\n" "Delete: yes\n", "")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new 
FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); int64 undeleted_files, undeleted_dirs; TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files, @@ -3004,16 +3006,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) { "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); int64 undeleted_files, undeleted_dirs; TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files, @@ -3039,16 +3041,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) { "Auth Token: fake_token\n" "Timeouts: 5 1 10\n", "", errors::NotFound("404"), 404)}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new 
FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay*/, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); int64 undeleted_files, undeleted_dirs; EXPECT_EQ(error::Code::NOT_FOUND, @@ -3130,7 +3132,7 @@ TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) { std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, + 0 /* matching paths cache max entries */, kTestRetryConfig, kTestTimeoutConfig, *kAllowedLocationsDefault, add_header /* gcs additional header */); @@ -3199,16 +3201,16 @@ TEST(GcsFileSystemTest, CreateHttpRequest) { "Auth Token: fake_token\n" "Header Hello: world\n", "{}")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness 
*/, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); std::unique_ptr<HttpRequest> request; TF_EXPECT_OK(fs.CreateHttpRequest(&request)); @@ -3262,16 +3264,16 @@ TEST(GcsFileSystemTest, Stat_StatsRecording) { "Timeouts: 5 1 10\n", strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\"," "\"updated\": \"2016-04-29T23:15:24.896Z\"}"))}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 
/* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); TestGcsStats stats; fs.SetStats(&stats); @@ -3289,16 +3291,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_StatsRecording) { "Range: 0-5\n" "Timeouts: 5 1 20\n", "012345")}); - GcsFileSystem fs( - std::unique_ptr<AuthProvider>(new FakeAuthProvider), - std::unique_ptr<HttpRequest::Factory>( - new FakeHttpRequestFactory(&requests)), - std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */, - 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */, - 0 /* stat cache max entries */, 0 /* matching paths cache max age */, - 0 /* matching paths cache max entries */, 0 /* initial retry delay */, - kTestTimeoutConfig, *kAllowedLocationsDefault, - nullptr /* gcs additional header */); + GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider), + std::unique_ptr<HttpRequest::Factory>( + new FakeHttpRequestFactory(&requests)), + std::unique_ptr<ZoneProvider>(new FakeZoneProvider), + 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */, + 0 /* stat cache max age */, 0 /* stat cache max entries */, + 0 /* matching paths cache max age */, + 0 /* matching paths cache max entries */, kTestRetryConfig, + kTestTimeoutConfig, *kAllowedLocationsDefault, + nullptr /* gcs additional header */); TestGcsStats stats; fs.SetStats(&stats); diff --git a/tensorflow/core/platform/cloud/google_auth_provider_test.cc b/tensorflow/core/platform/cloud/google_auth_provider_test.cc index 07b88a880f..ec31c5ee8c 100644 --- a/tensorflow/core/platform/cloud/google_auth_provider_test.cc +++ b/tensorflow/core/platform/cloud/google_auth_provider_test.cc @@ -93,8 +93,8 @@ TEST_F(GoogleAuthProviderTest, EnvironmentVariable_Caching) { std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests); - auto metadataClient = - 
std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0); + auto metadataClient = std::make_shared<ComputeEngineMetadataClient>( + fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */)); GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), metadataClient, &env); oauth_client->return_token = "fake-token"; @@ -129,8 +129,8 @@ TEST_F(GoogleAuthProviderTest, GCloudRefreshToken) { FakeEnv env; std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests); - auto metadataClient = - std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0); + auto metadataClient = std::make_shared<ComputeEngineMetadataClient>( + fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */)); GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), metadataClient, &env); @@ -178,8 +178,8 @@ TEST_F(GoogleAuthProviderTest, RunningOnGCE) { FakeEnv env; std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests); - auto metadataClient = - std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0); + auto metadataClient = std::make_shared<ComputeEngineMetadataClient>( + fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */)); GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), metadataClient, &env); @@ -206,8 +206,8 @@ TEST_F(GoogleAuthProviderTest, OverrideForTesting) { FakeEnv env; std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&empty_requests); - auto metadataClient = - std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0); + auto metadataClient = std::make_shared<ComputeEngineMetadataClient>( + fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */)); GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), metadataClient, &env); @@ -228,8 +228,8 @@ 
TEST_F(GoogleAuthProviderTest, NothingAvailable) { FakeEnv env; std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests); - auto metadataClient = - std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0); + auto metadataClient = std::make_shared<ComputeEngineMetadataClient>( + fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */)); GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client), metadataClient, &env); diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h index 941ab7ad65..5ce6670dc7 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system.h +++ b/tensorflow/core/platform/cloud/retrying_file_system.h @@ -34,9 +34,9 @@ template <typename Underlying> class RetryingFileSystem : public FileSystem { public: RetryingFileSystem(std::unique_ptr<Underlying> base_file_system, - int64 delay_microseconds = 1000000) + const RetryConfig& retry_config) : base_file_system_(std::move(base_file_system)), - initial_delay_microseconds_(delay_microseconds) {} + retry_config_(retry_config) {} Status NewRandomAccessFile( const string& filename, @@ -55,7 +55,7 @@ class RetryingFileSystem : public FileSystem { Status FileExists(const string& fname) override { return RetryingUtils::CallWithRetries( [this, &fname]() { return base_file_system_->FileExists(fname); }, - initial_delay_microseconds_); + retry_config_); } Status GetChildren(const string& dir, std::vector<string>* result) override { @@ -63,7 +63,7 @@ class RetryingFileSystem : public FileSystem { [this, &dir, result]() { return base_file_system_->GetChildren(dir, result); }, - initial_delay_microseconds_); + retry_config_); } Status GetMatchingPaths(const string& pattern, @@ -72,31 +72,31 @@ class RetryingFileSystem : public FileSystem { [this, &pattern, result]() { return base_file_system_->GetMatchingPaths(pattern, result); }, - 
initial_delay_microseconds_); + retry_config_); } Status Stat(const string& fname, FileStatistics* stat) override { return RetryingUtils::CallWithRetries( [this, &fname, stat]() { return base_file_system_->Stat(fname, stat); }, - initial_delay_microseconds_); + retry_config_); } Status DeleteFile(const string& fname) override { return RetryingUtils::DeleteWithRetries( [this, &fname]() { return base_file_system_->DeleteFile(fname); }, - initial_delay_microseconds_); + retry_config_); } Status CreateDir(const string& dirname) override { return RetryingUtils::CallWithRetries( [this, &dirname]() { return base_file_system_->CreateDir(dirname); }, - initial_delay_microseconds_); + retry_config_); } Status DeleteDir(const string& dirname) override { return RetryingUtils::DeleteWithRetries( [this, &dirname]() { return base_file_system_->DeleteDir(dirname); }, - initial_delay_microseconds_); + retry_config_); } Status GetFileSize(const string& fname, uint64* file_size) override { @@ -104,7 +104,7 @@ class RetryingFileSystem : public FileSystem { [this, &fname, file_size]() { return base_file_system_->GetFileSize(fname, file_size); }, - initial_delay_microseconds_); + retry_config_); } Status RenameFile(const string& src, const string& target) override { @@ -112,13 +112,13 @@ class RetryingFileSystem : public FileSystem { [this, &src, &target]() { return base_file_system_->RenameFile(src, target); }, - initial_delay_microseconds_); + retry_config_); } Status IsDirectory(const string& dirname) override { return RetryingUtils::CallWithRetries( [this, &dirname]() { return base_file_system_->IsDirectory(dirname); }, - initial_delay_microseconds_); + retry_config_); } Status DeleteRecursively(const string& dirname, int64* undeleted_files, @@ -128,7 +128,7 @@ class RetryingFileSystem : public FileSystem { return base_file_system_->DeleteRecursively(dirname, undeleted_files, undeleted_dirs); }, - initial_delay_microseconds_); + retry_config_); } void FlushCaches() override { 
base_file_system_->FlushCaches(); } @@ -137,7 +137,7 @@ class RetryingFileSystem : public FileSystem { private: std::unique_ptr<Underlying> base_file_system_; - const int64 initial_delay_microseconds_; + const RetryConfig retry_config_; TF_DISALLOW_COPY_AND_ASSIGN(RetryingFileSystem); }; @@ -147,9 +147,8 @@ namespace retrying_internals { class RetryingRandomAccessFile : public RandomAccessFile { public: RetryingRandomAccessFile(std::unique_ptr<RandomAccessFile> base_file, - int64 delay_microseconds) - : base_file_(std::move(base_file)), - initial_delay_microseconds_(delay_microseconds) {} + const RetryConfig& retry_config) + : base_file_(std::move(base_file)), retry_config_(retry_config) {} Status Read(uint64 offset, size_t n, StringPiece* result, char* scratch) const override { @@ -157,20 +156,19 @@ class RetryingRandomAccessFile : public RandomAccessFile { [this, offset, n, result, scratch]() { return base_file_->Read(offset, n, result, scratch); }, - initial_delay_microseconds_); + retry_config_); } private: std::unique_ptr<RandomAccessFile> base_file_; - const int64 initial_delay_microseconds_; + const RetryConfig retry_config_; }; class RetryingWritableFile : public WritableFile { public: RetryingWritableFile(std::unique_ptr<WritableFile> base_file, - int64 delay_microseconds) - : base_file_(std::move(base_file)), - initial_delay_microseconds_(delay_microseconds) {} + const RetryConfig& retry_config) + : base_file_(std::move(base_file)), retry_config_(retry_config) {} ~RetryingWritableFile() override { // Makes sure the retrying version of Close() is called in the destructor. 
@@ -179,25 +177,24 @@ class RetryingWritableFile : public WritableFile { Status Append(StringPiece data) override { return RetryingUtils::CallWithRetries( - [this, &data]() { return base_file_->Append(data); }, - initial_delay_microseconds_); + [this, &data]() { return base_file_->Append(data); }, retry_config_); } Status Close() override { return RetryingUtils::CallWithRetries( - [this]() { return base_file_->Close(); }, initial_delay_microseconds_); + [this]() { return base_file_->Close(); }, retry_config_); } Status Flush() override { return RetryingUtils::CallWithRetries( - [this]() { return base_file_->Flush(); }, initial_delay_microseconds_); + [this]() { return base_file_->Flush(); }, retry_config_); } Status Sync() override { return RetryingUtils::CallWithRetries( - [this]() { return base_file_->Sync(); }, initial_delay_microseconds_); + [this]() { return base_file_->Sync(); }, retry_config_); } private: std::unique_ptr<WritableFile> base_file_; - const int64 initial_delay_microseconds_; + const RetryConfig retry_config_; }; } // namespace retrying_internals @@ -210,9 +207,9 @@ Status RetryingFileSystem<Underlying>::NewRandomAccessFile( [this, &filename, &base_file]() { return base_file_system_->NewRandomAccessFile(filename, &base_file); }, - initial_delay_microseconds_)); + retry_config_)); result->reset(new retrying_internals::RetryingRandomAccessFile( - std::move(base_file), initial_delay_microseconds_)); + std::move(base_file), retry_config_)); return Status::OK(); } @@ -224,9 +221,9 @@ Status RetryingFileSystem<Underlying>::NewWritableFile( [this, &filename, &base_file]() { return base_file_system_->NewWritableFile(filename, &base_file); }, - initial_delay_microseconds_)); + retry_config_)); result->reset(new retrying_internals::RetryingWritableFile( - std::move(base_file), initial_delay_microseconds_)); + std::move(base_file), retry_config_)); return Status::OK(); } @@ -238,9 +235,9 @@ Status RetryingFileSystem<Underlying>::NewAppendableFile( [this, 
&filename, &base_file]() { return base_file_system_->NewAppendableFile(filename, &base_file); }, - initial_delay_microseconds_)); + retry_config_)); result->reset(new retrying_internals::RetryingWritableFile( - std::move(base_file), initial_delay_microseconds_)); + std::move(base_file), retry_config_)); return Status::OK(); } @@ -252,7 +249,7 @@ Status RetryingFileSystem<Underlying>::NewReadOnlyMemoryRegionFromFile( return base_file_system_->NewReadOnlyMemoryRegionFromFile(filename, result); }, - initial_delay_microseconds_); + retry_config_); } } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc index 5910fef1d2..868eea096c 100644 --- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc +++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc @@ -184,7 +184,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_ImmediateSuccess) { std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); base_fs->random_access_file_to_return = std::move(base_file); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); // Retrieve the wrapped random access file. std::unique_ptr<RandomAccessFile> random_access_file; @@ -211,7 +212,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_SuccessWith3rdTry) { std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); base_fs->random_access_file_to_return = std::move(base_file); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); // Retrieve the wrapped random access file. 
std::unique_ptr<RandomAccessFile> random_access_file; @@ -235,7 +237,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_AllRetriesFailed) { std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); base_fs->random_access_file_to_return = std::move(base_file); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); // Retrieve the wrapped random access file. std::unique_ptr<RandomAccessFile> random_access_file; @@ -265,7 +268,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_NoRetriesForSomeErrors) { std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); base_fs->random_access_file_to_return = std::move(base_file); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); // Retrieve the wrapped random access file. std::unique_ptr<RandomAccessFile> random_access_file; @@ -291,7 +295,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_ImmediateSuccess) { std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); base_fs->writable_file_to_return = std::move(base_file); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); // Retrieve the wrapped writable file. std::unique_ptr<WritableFile> writable_file; @@ -317,7 +322,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_SuccessWith3rdTry) { std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); base_fs->writable_file_to_return = std::move(base_file); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); // Retrieve the wrapped writable file. 
std::unique_ptr<WritableFile> writable_file; @@ -343,7 +349,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_SuccessWith3rdTry_ViaDestructor) { std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); base_fs->writable_file_to_return = std::move(base_file); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); // Retrieve the wrapped writable file. std::unique_ptr<WritableFile> writable_file; @@ -368,7 +375,8 @@ TEST(RetryingFileSystemTest, NewAppendableFile_SuccessWith3rdTry) { std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); base_fs->writable_file_to_return = std::move(base_file); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); // Retrieve the wrapped appendable file. std::unique_ptr<WritableFile> writable_file; @@ -391,7 +399,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_AllRetriesFailed) { std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); base_fs->writable_file_to_return = std::move(base_file); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); // Retrieve the wrapped writable file. 
std::unique_ptr<WritableFile> writable_file; @@ -412,7 +421,8 @@ TEST(RetryingFileSystemTest, std::make_tuple("NewReadOnlyMemoryRegionFromFile", Status::OK())}); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); std::unique_ptr<ReadOnlyMemoryRegion> result; TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile("filename.txt", &result)); @@ -423,7 +433,8 @@ TEST(RetryingFileSystemTest, NewReadOnlyMemoryRegionFromFile_AllRetriesFailed) { CreateRetriableErrors("NewReadOnlyMemoryRegionFromFile", 11); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); std::unique_ptr<ReadOnlyMemoryRegion> result; const auto& status = @@ -440,7 +451,8 @@ TEST(RetryingFileSystemTest, GetChildren_SuccessWith2ndTry) { std::make_tuple("GetChildren", Status::OK())}); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); std::vector<string> result; TF_EXPECT_OK(fs.GetChildren("gs://path", &result)); @@ -450,7 +462,8 @@ TEST(RetryingFileSystemTest, GetChildren_AllRetriesFailed) { ExpectedCalls expected_fs_calls = CreateRetriableErrors("GetChildren", 11); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); std::vector<string> result; const auto& status = fs.GetChildren("gs://path", &result); @@ -466,7 +479,8 @@ 
TEST(RetryingFileSystemTest, GetMatchingPaths_SuccessWith2ndTry) { std::make_tuple("GetMatchingPaths", Status::OK())}); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); std::vector<string> result; TF_EXPECT_OK(fs.GetMatchingPaths("gs://path/dir", &result)); @@ -477,7 +491,8 @@ TEST(RetryingFileSystemTest, GetMatchingPaths_AllRetriesFailed) { CreateRetriableErrors("GetMatchingPaths", 11); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); std::vector<string> result; const auto& status = fs.GetMatchingPaths("gs://path/dir", &result); @@ -492,7 +507,8 @@ TEST(RetryingFileSystemTest, DeleteFile_SuccessWith2ndTry) { std::make_tuple("DeleteFile", Status::OK())}); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); std::vector<string> result; TF_EXPECT_OK(fs.DeleteFile("gs://path/file.txt")); @@ -502,7 +518,8 @@ TEST(RetryingFileSystemTest, DeleteFile_AllRetriesFailed) { ExpectedCalls expected_fs_calls = CreateRetriableErrors("DeleteFile", 11); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); std::vector<string> result; const auto& status = fs.DeleteFile("gs://path/file.txt"); @@ -517,7 +534,8 @@ TEST(RetryingFileSystemTest, CreateDir_SuccessWith2ndTry) { std::make_tuple("CreateDir", 
Status::OK())}); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); std::vector<string> result; TF_EXPECT_OK(fs.CreateDir("gs://path/newdir")); @@ -527,7 +545,8 @@ TEST(RetryingFileSystemTest, CreateDir_AllRetriesFailed) { ExpectedCalls expected_fs_calls = CreateRetriableErrors("CreateDir", 11); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); std::vector<string> result; const auto& status = fs.CreateDir("gs://path/newdir"); @@ -542,7 +561,8 @@ TEST(RetryingFileSystemTest, DeleteDir_SuccessWith2ndTry) { std::make_tuple("DeleteDir", Status::OK())}); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); std::vector<string> result; TF_EXPECT_OK(fs.DeleteDir("gs://path/dir")); @@ -552,7 +572,8 @@ TEST(RetryingFileSystemTest, DeleteDir_AllRetriesFailed) { ExpectedCalls expected_fs_calls = CreateRetriableErrors("DeleteDir", 11); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); std::vector<string> result; const auto& status = fs.DeleteDir("gs://path/dir"); @@ -568,7 +589,8 @@ TEST(RetryingFileSystemTest, GetFileSize_SuccessWith2ndTry) { std::make_tuple("GetFileSize", Status::OK())}); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - 
RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); uint64 size; TF_EXPECT_OK(fs.GetFileSize("gs://path/file.txt", &size)); @@ -578,7 +600,8 @@ TEST(RetryingFileSystemTest, GetFileSize_AllRetriesFailed) { ExpectedCalls expected_fs_calls = CreateRetriableErrors("GetFileSize", 11); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); uint64 size; const auto& status = fs.GetFileSize("gs://path/file.txt", &size); @@ -593,7 +616,8 @@ TEST(RetryingFileSystemTest, RenameFile_SuccessWith2ndTry) { std::make_tuple("RenameFile", Status::OK())}); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); TF_EXPECT_OK(fs.RenameFile("old_name", "new_name")); } @@ -602,7 +626,8 @@ TEST(RetryingFileSystemTest, RenameFile_AllRetriesFailed) { ExpectedCalls expected_fs_calls = CreateRetriableErrors("RenameFile", 11); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); const auto& status = fs.RenameFile("old_name", "new_name"); EXPECT_TRUE( @@ -616,7 +641,8 @@ TEST(RetryingFileSystemTest, Stat_SuccessWith2ndTry) { std::make_tuple("Stat", Status::OK())}); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us 
*/)); FileStatistics stat; TF_EXPECT_OK(fs.Stat("file_name", &stat)); @@ -626,7 +652,8 @@ TEST(RetryingFileSystemTest, Stat_AllRetriesFailed) { ExpectedCalls expected_fs_calls = CreateRetriableErrors("Stat", 11); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); FileStatistics stat; const auto& status = fs.Stat("file_name", &stat); @@ -639,7 +666,8 @@ TEST(RetryingFileSystemTest, FileExists_AllRetriesFailed) { ExpectedCalls expected_fs_calls = CreateRetriableErrors("FileExists", 11); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); const auto& status = fs.FileExists("file_name"); EXPECT_TRUE( @@ -653,7 +681,8 @@ TEST(RetryingFileSystemTest, FileExists_SuccessWith2ndTry) { std::make_tuple("FileExists", Status::OK())}); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); TF_EXPECT_OK(fs.FileExists("gs://path/dir")); } @@ -665,7 +694,8 @@ TEST(RetryingFileSystemTest, IsDirectory_SuccessWith2ndTry) { std::make_tuple("IsDirectory", Status::OK())}); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); TF_EXPECT_OK(fs.IsDirectory("gs://path/dir")); } @@ -674,7 +704,8 @@ TEST(RetryingFileSystemTest, IsDirectory_AllRetriesFailed) { ExpectedCalls expected_fs_calls = 
CreateRetriableErrors("IsDirectory", 11); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); const auto& status = fs.IsDirectory("gs://path/dir"); EXPECT_TRUE( @@ -689,7 +720,8 @@ TEST(RetryingFileSystemTest, DeleteRecursively_SuccessWith2ndTry) { std::make_tuple("DeleteRecursively", Status::OK())}); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); int64 undeleted_files, undeleted_dirs; TF_EXPECT_OK( @@ -701,7 +733,8 @@ TEST(RetryingFileSystemTest, DeleteRecursively_AllRetriesFailed) { CreateRetriableErrors("DeleteRecursively", 11); std::unique_ptr<MockFileSystem> base_fs( new MockFileSystem(expected_fs_calls)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); int64 undeleted_files, undeleted_dirs; const auto& status = @@ -715,7 +748,8 @@ TEST(RetryingFileSystemTest, FlushCaches) { ExpectedCalls none; bool flushed = false; std::unique_ptr<MockFileSystem> base_fs(new MockFileSystem(none, &flushed)); - RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0); + RetryingFileSystem<MockFileSystem> fs( + std::move(base_fs), RetryConfig(0 /* init_delay_time_us */)); fs.FlushCaches(); EXPECT_TRUE(flushed); } diff --git a/tensorflow/core/platform/cloud/retrying_utils.cc b/tensorflow/core/platform/cloud/retrying_utils.cc index d2df422024..cb0aecdd35 100644 --- a/tensorflow/core/platform/cloud/retrying_utils.cc +++ b/tensorflow/core/platform/cloud/retrying_utils.cc @@ -23,11 +23,6 @@ namespace tensorflow { namespace { -// In case of failure, every call will be 
retried kMaxRetries times. -constexpr int kMaxRetries = 10; -// Maximum backoff time in microseconds. -constexpr int64 kMaximumBackoffMicroseconds = 32000000; // 32 seconds. - bool IsRetriable(error::Code code) { switch (code) { case error::UNAVAILABLE: @@ -43,40 +38,41 @@ bool IsRetriable(error::Code code) { } // namespace Status RetryingUtils::CallWithRetries(const std::function<Status()>& f, - const int64 initial_delay_microseconds) { - return CallWithRetries(f, initial_delay_microseconds, [](int64 micros) { - return Env::Default()->SleepForMicroseconds(micros); - }); + const RetryConfig& config) { + return CallWithRetries( + f, + [](int64 micros) { return Env::Default()->SleepForMicroseconds(micros); }, + config); } Status RetryingUtils::CallWithRetries( - const std::function<Status()>& f, const int64 initial_delay_microseconds, - const std::function<void(int64)>& sleep_usec) { + const std::function<Status()>& f, + const std::function<void(int64)>& sleep_usec, const RetryConfig& config) { int retries = 0; while (true) { auto status = f(); if (!IsRetriable(status.code())) { return status; } - if (retries >= kMaxRetries) { + if (retries >= config.max_retries) { // Return AbortedError, so that it doesn't get retried again somewhere // at a higher level. return Status( error::ABORTED, strings::StrCat( - "All ", kMaxRetries, + "All ", config.max_retries, " retry attempts failed. 
The last failure: ", status.ToString())); } int64 delay_micros = 0; - if (initial_delay_microseconds > 0) { + if (config.init_delay_time_us > 0) { const int64 random_micros = random::New64() % 1000000; - delay_micros = std::min(initial_delay_microseconds << retries, - kMaximumBackoffMicroseconds) + + delay_micros = std::min(config.init_delay_time_us << retries, + config.max_delay_time_us) + random_micros; } LOG(INFO) << "The operation failed and will be automatically retried in " << (delay_micros / 1000000.0) << " seconds (attempt " - << (retries + 1) << " out of " << kMaxRetries + << (retries + 1) << " out of " << config.max_retries << "), caused by: " << status.ToString(); sleep_usec(delay_micros); retries++; @@ -84,8 +80,7 @@ Status RetryingUtils::CallWithRetries( } Status RetryingUtils::DeleteWithRetries( - const std::function<Status()>& delete_func, - const int64 initial_delay_microseconds) { + const std::function<Status()>& delete_func, const RetryConfig& config) { bool is_retried = false; return RetryingUtils::CallWithRetries( [delete_func, &is_retried]() { @@ -96,7 +91,7 @@ Status RetryingUtils::DeleteWithRetries( is_retried = true; return status; }, - initial_delay_microseconds); + config); } } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/retrying_utils.h b/tensorflow/core/platform/cloud/retrying_utils.h index 546b8d1c4a..1a7ce1b122 100644 --- a/tensorflow/core/platform/cloud/retrying_utils.h +++ b/tensorflow/core/platform/cloud/retrying_utils.h @@ -21,6 +21,26 @@ limitations under the License. namespace tensorflow { +// Default time before reporting failure: ~100 seconds. +struct RetryConfig { + RetryConfig(int64 init_delay_time_us = 100 * 1000, + int64 max_delay_time_us = 32 * 1000 * 1000, + int max_retries = 10) { + this->init_delay_time_us = init_delay_time_us; + this->max_delay_time_us = max_delay_time_us; + this->max_retries = max_retries; + } + + // In case of failure, every call will be retried max_retries times. 
+ int max_retries; + + // Initial backoff time + int64 init_delay_time_us; + + // Maximum backoff time in microseconds. + int64 max_delay_time_us; +}; + class RetryingUtils { public: /// \brief Retries the function in case of failure with exponential backoff. @@ -31,18 +51,19 @@ class RetryingUtils { /// retries. /// If all retries failed, returns the last error status. static Status CallWithRetries(const std::function<Status()>& f, - const int64 initial_delay_microseconds); + const RetryConfig& config); + /// sleep_usec is a function that sleeps for the given number of microseconds. static Status CallWithRetries(const std::function<Status()>& f, - const int64 initial_delay_microseconds, - const std::function<void(int64)>& sleep_usec); + const std::function<void(int64)>& sleep_usec, + const RetryConfig& config); /// \brief A retrying wrapper for a function that deletes a resource. /// /// The function takes care of the scenario when a delete operation /// returns a failure but succeeds under the hood: if a retry returns /// NOT_FOUND, the whole operation is considered a success. 
static Status DeleteWithRetries(const std::function<Status()>& delete_func, - const int64 initial_delay_microseconds); + const RetryConfig& config); }; } // namespace tensorflow diff --git a/tensorflow/core/platform/cloud/retrying_utils_test.cc b/tensorflow/core/platform/cloud/retrying_utils_test.cc index 1b6527618a..75fe8a98f4 100644 --- a/tensorflow/core/platform/cloud/retrying_utils_test.cc +++ b/tensorflow/core/platform/cloud/retrying_utils_test.cc @@ -30,7 +30,8 @@ TEST(RetryingUtilsTest, CallWithRetries_RetryDelays) { }; std::function<Status()> f = []() { return errors::Unavailable("Failed."); }; - const auto& status = RetryingUtils::CallWithRetries(f, 500000L, sleep); + const auto& status = RetryingUtils::CallWithRetries( + f, sleep, RetryConfig(500000 /* init_delay_time_us */)); EXPECT_EQ(errors::Code::ABORTED, status.code()); EXPECT_TRUE(str_util::StrContains( status.error_message(), @@ -60,8 +61,10 @@ TEST(RetryingUtilsTest, CallWithRetries_NotFoundIsNotRetried) { results.erase(results.begin()); return result; }; - EXPECT_EQ(errors::Code::NOT_FOUND, - RetryingUtils::CallWithRetries(f, 0).code()); + EXPECT_EQ( + errors::Code::NOT_FOUND, + RetryingUtils::CallWithRetries(f, RetryConfig(0 /* init_delay_time_us */)) + .code()); } TEST(RetryingUtilsTest, CallWithRetries_ImmediateSuccess) { @@ -74,7 +77,8 @@ TEST(RetryingUtilsTest, CallWithRetries_ImmediateSuccess) { results.erase(results.begin()); return result; }; - TF_EXPECT_OK(RetryingUtils::CallWithRetries(f, 1.0, sleep)); + TF_EXPECT_OK(RetryingUtils::CallWithRetries( + f, sleep, RetryConfig(1L /* init_delay_time_us */))); } TEST(RetryingUtilsTest, CallWithRetries_EventualSuccess) { @@ -86,7 +90,8 @@ TEST(RetryingUtilsTest, CallWithRetries_EventualSuccess) { results.erase(results.begin()); return result; }; - TF_EXPECT_OK(RetryingUtils::CallWithRetries(f, 0)); + TF_EXPECT_OK(RetryingUtils::CallWithRetries( + f, RetryConfig(0 /* init_delay_time_us */))); } TEST(RetryingUtilsTest, 
DeleteWithRetries_ImmediateSuccess) { @@ -96,7 +101,8 @@ TEST(RetryingUtilsTest, DeleteWithRetries_ImmediateSuccess) { delete_results.erase(delete_results.begin()); return result; }; - TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(delete_func, 0)); + TF_EXPECT_OK(RetryingUtils::DeleteWithRetries( + delete_func, RetryConfig(0 /* init_delay_time_us */))); } TEST(RetryingUtilsTest, DeleteWithRetries_EventualSuccess) { @@ -106,7 +112,8 @@ TEST(RetryingUtilsTest, DeleteWithRetries_EventualSuccess) { delete_results.erase(delete_results.begin()); return result; }; - TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(delete_func, 0)); + TF_EXPECT_OK(RetryingUtils::DeleteWithRetries( + delete_func, RetryConfig(0 /* init_delay_time_us */))); } TEST(RetryingUtilsTest, DeleteWithRetries_PermissionDeniedNotRetried) { @@ -118,7 +125,9 @@ TEST(RetryingUtilsTest, DeleteWithRetries_PermissionDeniedNotRetried) { return result; }; EXPECT_EQ(errors::Code::PERMISSION_DENIED, - RetryingUtils::DeleteWithRetries(delete_func, 0).code()); + RetryingUtils::DeleteWithRetries( + delete_func, RetryConfig(0 /* init_delay_time_us */)) + .code()); } TEST(RetryingUtilsTest, DeleteWithRetries_SuccessThroughFileNotFound) { @@ -129,7 +138,8 @@ TEST(RetryingUtilsTest, DeleteWithRetries_SuccessThroughFileNotFound) { delete_results.erase(delete_results.begin()); return result; }; - TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(delete_func, 0)); + TF_EXPECT_OK(RetryingUtils::DeleteWithRetries( + delete_func, RetryConfig(0 /* init_delay_time_us */))); } TEST(RetryingUtilsTest, DeleteWithRetries_FirstNotFoundReturnedAsIs) { @@ -140,7 +150,9 @@ TEST(RetryingUtilsTest, DeleteWithRetries_FirstNotFoundReturnedAsIs) { return result; }; EXPECT_EQ(error::NOT_FOUND, - RetryingUtils::DeleteWithRetries(delete_func, 0).code()); + RetryingUtils::DeleteWithRetries( + delete_func, RetryConfig(0 /* init_delay_time_us */)) + .code()); } } // namespace diff --git a/tensorflow/core/platform/default/build_config.bzl 
b/tensorflow/core/platform/default/build_config.bzl index bb841aeab7..d884c1aa7c 100644 --- a/tensorflow/core/platform/default/build_config.bzl +++ b/tensorflow/core/platform/default/build_config.bzl @@ -615,11 +615,7 @@ def tf_kernel_tests_linkstatic(): def tf_additional_lib_defines(): """Additional defines needed to build TF libraries.""" - return select({ - "//tensorflow:with_jemalloc_linux_x86_64": ["TENSORFLOW_USE_JEMALLOC"], - "//tensorflow:with_jemalloc_linux_ppc64le": ["TENSORFLOW_USE_JEMALLOC"], - "//conditions:default": [], - }) + return [] def tf_additional_lib_deps(): """Additional dependencies needed to build TF libraries.""" @@ -631,64 +627,45 @@ def tf_additional_lib_deps(): ] + if_static( ["@nsync//:nsync_cpp"], ["@nsync//:nsync_headers"], - ) + select({ - "//tensorflow:with_jemalloc_linux_x86_64_dynamic": ["@jemalloc//:jemalloc_headers"], - "//tensorflow:with_jemalloc_linux_ppc64le_dynamic": ["@jemalloc//:jemalloc_headers"], - "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"], - "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"], - "//conditions:default": [], - }) + ) def tf_additional_core_deps(): return select({ - "//tensorflow:with_gcp_support_android_override": [], - "//tensorflow:with_gcp_support_ios_override": [], - "//tensorflow:with_gcp_support": [ + "//tensorflow:android": [], + "//tensorflow:windows": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//conditions:default": [ "//tensorflow/core/platform/cloud:gcs_file_system", - ], - "//conditions:default": [], - }) + select({ - "//tensorflow:with_hdfs_support_windows_override": [], - "//tensorflow:with_hdfs_support_android_override": [], - "//tensorflow:with_hdfs_support_ios_override": [], - "//tensorflow:with_hdfs_support": [ - "//tensorflow/core/platform/hadoop:hadoop_file_system", - ], - "//conditions:default": [], - }) + select({ - "//tensorflow:with_aws_support_windows_override": [], - 
"//tensorflow:with_aws_support_android_override": [], - "//tensorflow:with_aws_support_ios_override": [], - "//tensorflow:with_aws_support": [ "//tensorflow/core/platform/s3:s3_file_system", + "//tensorflow/core/platform/hadoop:hadoop_file_system", ], - "//conditions:default": [], }) # TODO(jart, jhseu): Delete when GCP is default on. def tf_additional_cloud_op_deps(): return select({ - "//tensorflow:with_gcp_support_windows_override": [], - "//tensorflow:with_gcp_support_android_override": [], - "//tensorflow:with_gcp_support_ios_override": [], - "//tensorflow:with_gcp_support": [ + "//tensorflow:android": [], + "//tensorflow:windows": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//conditions:default": [ "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib", "//tensorflow/contrib/cloud:gcs_config_ops_op_lib", ], - "//conditions:default": [], }) # TODO(jart, jhseu): Delete when GCP is default on. def tf_additional_cloud_kernel_deps(): return select({ - "//tensorflow:with_gcp_support_windows_override": [], - "//tensorflow:with_gcp_support_android_override": [], - "//tensorflow:with_gcp_support_ios_override": [], - "//tensorflow:with_gcp_support": [ + "//tensorflow:android": [], + "//tensorflow:windows": [], + "//tensorflow:ios": [], + "//tensorflow:linux_s390x": [], + "//conditions:default": [ "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops", "//tensorflow/contrib/cloud/kernels:gcs_config_ops", ], - "//conditions:default": [], }) def tf_lib_proto_parsing_deps(): @@ -738,11 +715,7 @@ def tf_additional_binary_deps(): "//tensorflow/stream_executor:cuda_platform", "//tensorflow/core/platform/default/build_config:cuda", ], - ) + select({ - "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"], - "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"], - "//conditions:default": [], - }) + [ + ) + [ # TODO(allenl): Split these out into their own shared objects (they are # here because they are shared 
between contrib/ op shared objects and # core). diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h index 5b237c4736..5732271f15 100644 --- a/tensorflow/core/platform/env.h +++ b/tensorflow/core/platform/env.h @@ -228,6 +228,10 @@ class Env { /// |suffix|. Returns true if success. bool CreateUniqueFileName(string* prefix, const string& suffix); + /// \brief Return the runfiles directory if running under bazel. Returns + /// the directory the executable is located in if not running under bazel. + virtual string GetRunfilesDir() = 0; + // TODO(jeff,sanjay): Add back thread/thread-pool support if needed. // TODO(jeff,sanjay): if needed, tighten spec so relative to epoch, or // provide a routine to get the absolute time. @@ -360,6 +364,8 @@ class EnvWrapper : public Env { return target_->FormatLibraryFileName(name, version); } + string GetRunfilesDir() override { return target_->GetRunfilesDir(); } + private: void GetLocalTempDirectories(std::vector<string>* list) override { target_->GetLocalTempDirectories(list); diff --git a/tensorflow/core/platform/posix/env.cc b/tensorflow/core/platform/posix/env.cc index 418874d340..af95d8201e 100644 --- a/tensorflow/core/platform/posix/env.cc +++ b/tensorflow/core/platform/posix/env.cc @@ -119,6 +119,17 @@ class PosixEnv : public Env { return tensorflow::internal::FormatLibraryFileName(name, version); } + string GetRunfilesDir() override { + string bin_path = this->GetExecutablePath(); + string runfiles_path = bin_path + ".runfiles/org_tensorflow"; + Status s = this->IsDirectory(runfiles_path); + if (!s.ok()) { + return runfiles_path; + } else { + return bin_path.substr(0, bin_path.find_last_of("/\\")); + } + } + private: void GetLocalTempDirectories(std::vector<string>* list) override; }; diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc index b46b9927cd..acdd7798ea 100644 --- a/tensorflow/core/platform/posix/port.cc +++ b/tensorflow/core/platform/posix/port.cc 
@@ -13,10 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifdef TENSORFLOW_USE_JEMALLOC -#include "jemalloc/jemalloc.h" -#endif - #include "absl/base/internal/sysinfo.h" #include "tensorflow/core/platform/cpu_info.h" @@ -101,11 +97,7 @@ void* AlignedMalloc(size_t size, int minimum_alignment) { // memory aligned to at least the size of a pointer. const int required_alignment = sizeof(void*); if (minimum_alignment < required_alignment) return Malloc(size); -#ifdef TENSORFLOW_USE_JEMALLOC - int err = jemalloc_posix_memalign(&ptr, minimum_alignment, size); -#else int err = posix_memalign(&ptr, minimum_alignment, size); -#endif if (err != 0) { return nullptr; } else { @@ -116,29 +108,11 @@ void* AlignedMalloc(size_t size, int minimum_alignment) { void AlignedFree(void* aligned_memory) { Free(aligned_memory); } -void* Malloc(size_t size) { -#ifdef TENSORFLOW_USE_JEMALLOC - return jemalloc_malloc(size); -#else - return malloc(size); -#endif -} +void* Malloc(size_t size) { return malloc(size); } -void* Realloc(void* ptr, size_t size) { -#ifdef TENSORFLOW_USE_JEMALLOC - return jemalloc_realloc(ptr, size); -#else - return realloc(ptr, size); -#endif -} +void* Realloc(void* ptr, size_t size) { return realloc(ptr, size); } -void Free(void* ptr) { -#ifdef TENSORFLOW_USE_JEMALLOC - jemalloc_free(ptr); -#else - free(ptr); -#endif -} +void Free(void* ptr) { free(ptr); } void* NUMAMalloc(int node, size_t size, int minimum_alignment) { return AlignedMalloc(size, minimum_alignment); @@ -146,9 +120,7 @@ void* NUMAMalloc(int node, size_t size, int minimum_alignment) { void NUMAFree(void* ptr, size_t size) { Free(ptr); } -int NUMAGetMemAffinity(const void* addr) { - return kNUMANoAffinity; -} +int NUMAGetMemAffinity(const void* addr) { return kNUMANoAffinity; } void MallocExtension_ReleaseToSystem(std::size_t num_bytes) { // No-op. 
diff --git a/tensorflow/core/platform/windows/env.cc b/tensorflow/core/platform/windows/env.cc index 68ee3595a2..f26ccd1662 100644 --- a/tensorflow/core/platform/windows/env.cc +++ b/tensorflow/core/platform/windows/env.cc @@ -160,6 +160,17 @@ class WindowsEnv : public Env { return filename; } + string GetRunfilesDir() override { + string bin_path = this->GetExecutablePath(); + string runfiles_path = bin_path + ".runfiles\\org_tensorflow"; + Status s = this->IsDirectory(runfiles_path); + if (!s.ok()) { + return runfiles_path; + } else { + return bin_path.substr(0, bin_path.find_last_of("/\\")); + } + } + private: void GetLocalTempDirectories(std::vector<string>* list) override; diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc index 5375f56372..911ea1902f 100644 --- a/tensorflow/core/platform/windows/port.cc +++ b/tensorflow/core/platform/windows/port.cc @@ -13,10 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifdef TENSORFLOW_USE_JEMALLOC -#include "jemalloc/jemalloc.h" -#endif - #include <stdio.h> #include <stdlib.h> #include <string.h> @@ -70,55 +66,16 @@ void NUMASetThreadNodeAffinity(int node) {} int NUMAGetThreadNodeAffinity() { return kNUMANoAffinity; } void* AlignedMalloc(size_t size, int minimum_alignment) { -#ifdef TENSORFLOW_USE_JEMALLOC - void* ptr = NULL; - // posix_memalign requires that the requested alignment be at least - // sizeof(void*). In this case, fall back on malloc which should return - // memory aligned to at least the size of a pointer. 
- const int required_alignment = sizeof(void*); - if (minimum_alignment < required_alignment) return Malloc(size); - int err = jemalloc_posix_memalign(&ptr, minimum_alignment, size); - if (err != 0) { - return NULL; - } else { - return ptr; - } -#else return _aligned_malloc(size, minimum_alignment); -#endif } -void AlignedFree(void* aligned_memory) { -#ifdef TENSORFLOW_USE_JEMALLOC - jemalloc_free(aligned_memory); -#else - _aligned_free(aligned_memory); -#endif -} +void AlignedFree(void* aligned_memory) { _aligned_free(aligned_memory); } -void* Malloc(size_t size) { -#ifdef TENSORFLOW_USE_JEMALLOC - return jemalloc_malloc(size); -#else - return malloc(size); -#endif -} +void* Malloc(size_t size) { return malloc(size); } -void* Realloc(void* ptr, size_t size) { -#ifdef TENSORFLOW_USE_JEMALLOC - return jemalloc_realloc(ptr, size); -#else - return realloc(ptr, size); -#endif -} +void* Realloc(void* ptr, size_t size) { return realloc(ptr, size); } -void Free(void* ptr) { -#ifdef TENSORFLOW_USE_JEMALLOC - return jemalloc_free(ptr); -#else - return free(ptr); -#endif -} +void Free(void* ptr) { return free(ptr); } void* NUMAMalloc(int node, size_t size, int minimum_alignment) { return AlignedMalloc(size, minimum_alignment); diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD index af034bdd7d..2bf371276e 100644 --- a/tensorflow/core/profiler/BUILD +++ b/tensorflow/core/profiler/BUILD @@ -40,7 +40,6 @@ tf_proto_library( name = "protos_all", srcs = glob(["**/*.proto"]), cc_api_version = 2, - java_api_version = 2, protodeps = tf_additional_all_protos(), visibility = ["//visibility:public"], ) diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto index 85cd02350a..104ab039cb 100644 --- a/tensorflow/core/protobuf/config.proto +++ b/tensorflow/core/protobuf/config.proto @@ -453,6 +453,11 @@ message RunOptions { // same group_key value (in a distributed computation where tasks // run disjoint graphs). 
int64 collective_graph_key = 1; + // If true, then operations (using the inter-op pool) across all + // session::run() calls will be centrally scheduled, optimizing for (median + // and tail) latency. + // Consider using this option for CPU-bound workloads like inference. + bool use_run_handler_pool = 2; }; Experimental experimental = 8; diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto index bb8f88336d..8c31468ff5 100644 --- a/tensorflow/core/protobuf/rewriter_config.proto +++ b/tensorflow/core/protobuf/rewriter_config.proto @@ -75,8 +75,10 @@ message RewriterConfig { // Try to allocate some independent Op outputs contiguously in order to // merge or eliminate downstream Ops (off by default). Toggle scoped_allocator_optimization = 15; - // Force small ops onto the CPU (default is ON). + // Force small ops onto the CPU (default is OFF). Toggle pin_to_host_optimization = 18; + // Disable the entire meta optimizer (off by default). + bool disable_meta_optimizer = 19; // Controls how many times we run the optimizers in meta optimizer (default // is once). @@ -143,8 +145,8 @@ message RewriterConfig { // not configurable (in contrast to memory optimization passes through the // meta-optimizer) and act only on manual op annotations. // - // Custom registered optimizers will be run after the base optimizers, in - // the order that they are specified. + // Custom optimizers (see custom_optimizers) that are not part of this + // schedule will be run after - in the order that they were specified. 
repeated string optimizers = 100; // Message to describe custom graph optimizer and its parameters diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h index cf7ffd8149..04aaea4f89 100644 --- a/tensorflow/core/util/mkl_util.h +++ b/tensorflow/core/util/mkl_util.h @@ -2039,8 +2039,8 @@ class MklPrimitiveFactory { /// Fuction to check whether primitive memory optimization is enabled static inline bool IsPrimitiveMemOptEnabled() { bool is_primitive_mem_opt_enabled = true; - TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE", true, - &is_primitive_mem_opt_enabled)); + TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true, + &is_primitive_mem_opt_enabled)); return is_primitive_mem_opt_enabled; } @@ -2095,9 +2095,8 @@ static inline memory::format get_desired_format(int channel, fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c; } else if (port::TestCPUFeature(port::CPUFeature::AVX2) && (channel % 8) == 0) { - fmt_desired = is_2d - ? memory::format::nChw8c - : memory::format::ncdhw; //not support avx2 for 3d yet. + fmt_desired = is_2d ? memory::format::nChw8c + : memory::format::ncdhw; // no avx2 support for 3d yet. } else { fmt_desired = is_2d ? 
memory::format::nchw : memory::format::ncdhw; } @@ -2209,7 +2208,8 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) { // utility function to determine if it is conv 1x1 and stride != 1 // for purpose of temporarily disabling primitive reuse -inline bool IsConv1x1StrideNot1(memory::dims filter_dims, memory::dims strides) { +inline bool IsConv1x1StrideNot1(memory::dims filter_dims, + memory::dims strides) { if (filter_dims.size() != 4 || strides.size() != 2) return false; return ((filter_dims[2] == 1) && (filter_dims[3] == 1) && diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc index c081ceae57..e01058dff6 100644 --- a/tensorflow/core/util/port.cc +++ b/tensorflow/core/util/port.cc @@ -38,10 +38,10 @@ bool CudaSupportsHalfMatMulAndConv() { } bool IsMklEnabled() { -#ifdef INTEL_MKL +#if defined(INTEL_MKL) && defined(ENABLE_MKL) return true; #else return false; -#endif +#endif // INTEL_MKL && ENABLE_MKL } } // end namespace tensorflow diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD index 648358606c..f40ec9b752 100644 --- a/tensorflow/core/util/tensor_bundle/BUILD +++ b/tensorflow/core/util/tensor_bundle/BUILD @@ -64,6 +64,11 @@ cc_library( tf_cc_test( name = "tensor_bundle_test", srcs = ["tensor_bundle_test.cc"], + data = glob(["testdata/**"]), + tags = [ + "nomsan", + "notsan", + ], deps = [ ":tensor_bundle", "//tensorflow/core:framework", diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc index ea8a259d1a..2dcb57a1f9 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc @@ -64,27 +64,36 @@ namespace { // Reads "num_elements" string elements from file[offset, offset+size) into the // length-N "destination". Discards the original content of "destination". 
// -// Checksums the string lengths (as restored uint32, not varint32 bytes) and -// string bytes, and stores it into "actual_crc32c". +// Checksums the string lengths (as restored uint32 or uint64, not varint64 +// bytes) and string bytes, and stores it into "actual_crc32c". Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements, size_t offset, size_t size, string* destination, uint32* actual_crc32c) { if (size == 0) return Status::OK(); CHECK_GT(size, 0); - // Reads "num_elements" varint32's from "buffered_file". + // Reads "num_elements" varint64's from "buffered_file". TF_RETURN_IF_ERROR(buffered_file->Seek(offset)); - std::vector<uint32> string_lengths(num_elements); + std::vector<uint64> string_lengths(num_elements); for (size_t i = 0; i < num_elements; ++i) { - TF_RETURN_IF_ERROR(buffered_file->ReadVarint32(&string_lengths[i])); + TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i])); + if (string_lengths[i] <= UINT32_MAX) { + // We need to do this because older checkpoints only used uint32s and we + // should still support them. + const uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]); + *actual_crc32c = crc32c::Extend( + *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32), + sizeof(uint32)); + } else { + *actual_crc32c = crc32c::Extend( + *actual_crc32c, reinterpret_cast<const char*>(&string_lengths[i]), + sizeof(uint64)); + } } if (offset + size < buffered_file->Tell()) { return errors::DataLoss("String lengths longer than expected offset ", offset + size); } - *actual_crc32c = - crc32c::Value(reinterpret_cast<const char*>(string_lengths.data()), - sizeof(uint32) * num_elements); // Reads the length-checksum. uint32 length_checksum = 0; @@ -104,7 +113,7 @@ Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements, // Reads the actual string bytes. 
for (size_t i = 0; i < num_elements; ++i) { - const uint32 string_length = string_lengths[i]; + const uint64 string_length = string_lengths[i]; string* buffer = &destination[i]; buffer->resize(string_length); @@ -218,8 +227,8 @@ Status WriteTensor(const Tensor& val, FileOutputBuffer* out, Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out, size_t* bytes_written, uint32* crc32c) { // On-disk format: - // [varint32 len0]..[varint32 lenL][4 byte cksum on lengths][string bytes] - // Var "crc32c" checksums the string lengths (as uint32, not varint32 bytes), + // [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes] + // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes), // the length-checksum, and all the string bytes. DCHECK_EQ(val.dtype(), DT_STRING); const string* strings = GetStringBackingBuffer(val); @@ -230,12 +239,21 @@ Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out, *crc32c = 0; for (int64 i = 0; i < val.NumElements(); ++i) { const string* elem = &strings[i]; - DCHECK_EQ(elem->size(), static_cast<uint32>(elem->size())); - const uint32 elem_size = static_cast<uint32>(elem->size()); - - core::PutVarint32(&lengths, elem_size); - *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size), - sizeof(uint32)); + DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size())); + const uint64 elem_size = static_cast<uint64>(elem->size()); + + core::PutVarint64(&lengths, elem_size); + if (elem_size <= UINT32_MAX) { + // We need to do this because older checkpoints only used uint32s and we + // should still support them. 
+ const uint32 elem_size_uint32 = static_cast<uint32>(elem_size); + *crc32c = crc32c::Extend(*crc32c, + reinterpret_cast<const char*>(&elem_size_uint32), + sizeof(uint32)); + } else { + *crc32c = crc32c::Extend( + *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64)); + } } TF_RETURN_IF_ERROR(out->Append(lengths)); *bytes_written = lengths.size(); diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc index 59c42baa06..9567e4750b 100644 --- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc +++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc @@ -39,6 +39,11 @@ string Prefix(const string& prefix) { return strings::StrCat(testing::TmpDir(), "/", prefix); } +string TestdataPrefix(const string& prefix) { + return strings::StrCat(testing::TensorFlowSrcRoot(), + "/core/util/tensor_bundle/testdata/", prefix); +} + template <typename T> Tensor Constant(T v, TensorShape shape) { Tensor ret(DataTypeToEnum<T>::value, shape); @@ -458,7 +463,26 @@ TEST(TensorBundleTest, NonStandardShapes) { TestNonStandardShapes<qint8>(); } +TEST(TensorBundleTest, StringTensorsOldFormat) { + // Test string tensor bundle made with previous version of code that use + // varint32s to store string lengths (we now use varint64s). 
+ BundleReader reader(Env::Default(), TestdataPrefix("old_string_tensors/foo")); + TF_ASSERT_OK(reader.status()); + EXPECT_EQ(AllTensorKeys(&reader), + std::vector<string>({"floats", "scalar", "string_tensor", "strs"})); + + Expect<string>(&reader, "string_tensor", Tensor(DT_STRING, TensorShape({1}))); + Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"})); + Expect<string>( + &reader, "strs", + test::AsTensor<string>({"hello", "", "x01", string(1 << 10, 'c')})); + Expect<float>(&reader, "floats", Constant_2x3<float>(16.18)); +} + TEST(TensorBundleTest, StringTensors) { + constexpr size_t kLongLength = static_cast<size_t>(UINT32_MAX) + 1; + Tensor long_string_tensor(DT_STRING, TensorShape({1})); + { BundleWriter writer(Env::Default(), Prefix("foo")); TF_EXPECT_OK(writer.Add("string_tensor", @@ -467,6 +491,12 @@ TEST(TensorBundleTest, StringTensors) { TF_EXPECT_OK(writer.Add( "strs", test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')}))); + + // Requires a 64-bit length. + string* backing_string = long_string_tensor.flat<string>().data(); + backing_string->assign(kLongLength, 'd'); + TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor)); + // Mixes in some floats. 
TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18))); TF_ASSERT_OK(writer.Finish()); @@ -474,9 +504,9 @@ TEST(TensorBundleTest, StringTensors) { { BundleReader reader(Env::Default(), Prefix("foo")); TF_ASSERT_OK(reader.status()); - EXPECT_EQ( - AllTensorKeys(&reader), - std::vector<string>({"floats", "scalar", "string_tensor", "strs"})); + EXPECT_EQ(AllTensorKeys(&reader), + std::vector<string>({"floats", "long_scalar", "scalar", + "string_tensor", "strs"})); Expect<string>(&reader, "string_tensor", Tensor(DT_STRING, TensorShape({1}))); @@ -484,7 +514,35 @@ TEST(TensorBundleTest, StringTensors) { Expect<string>( &reader, "strs", test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')})); + Expect<float>(&reader, "floats", Constant_2x3<float>(16.18)); + + // We don't use the Expect function so we can re-use the + // `long_string_tensor` buffer for reading out long_scalar to keep memory + // usage reasonable. + EXPECT_TRUE(reader.Contains("long_scalar")); + DataType dtype; + TensorShape shape; + TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape)); + EXPECT_EQ(DT_STRING, dtype); + EXPECT_EQ(TensorShape({1}), shape); + + // Zero-out the string so that we can be sure the new one is read in. + string* backing_string = long_string_tensor.flat<string>().data(); + backing_string->assign(""); + + // Read long_scalar and check it contains kLongLength 'd's. + TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor)); + ASSERT_EQ(backing_string, long_string_tensor.flat<string>().data()); + EXPECT_EQ(kLongLength, backing_string->length()); + for (char c : *backing_string) { + // Not using ASSERT_EQ('d', c) because this way is twice as fast due to + // compiler optimizations. 
+ if (c != 'd') { + FAIL() << "long_scalar is not full of 'd's as expected."; + break; + } + } } } diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README new file mode 100644 index 0000000000..428d3ef79e --- /dev/null +++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README @@ -0,0 +1,3 @@ +This tensor bundle was generated from cl/214343133, before string tensor +lengths were written as varint64s. This is here to check backwards +compatibility between the new code and old checkpoints. diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 Binary files differnew file mode 100644 index 0000000000..23b488e5fe --- /dev/null +++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index Binary files differnew file mode 100644 index 0000000000..a22a69e6e1 --- /dev/null +++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc index 1e5a9c5712..489999d1e8 100644 --- a/tensorflow/core/util/util.cc +++ b/tensorflow/core/util/util.cc @@ -120,4 +120,20 @@ string SliceDebugString(const TensorShape& shape, const int64 flat) { return result; } +#ifdef INTEL_MKL +bool DisableMKL() { + enum MklStatus { MKL_DEFAULT = 0, MKL_ON = 1, MKL_OFF = 2 }; + static MklStatus status = MKL_DEFAULT; + if (status == MKL_DEFAULT) { + char* tf_disable_mkl = getenv("TF_DISABLE_MKL"); + if ((tf_disable_mkl != NULL) && (std::stoi(tf_disable_mkl) == 1)) { + VLOG(2) << "TF-MKL: Disabling MKL"; + status = MKL_OFF; + } else { + status = MKL_ON; + } + } + return status == 
MKL_OFF ? true : false; +} +#endif // INTEL_MKL } // namespace tensorflow diff --git a/tensorflow/core/util/util.h b/tensorflow/core/util/util.h index 93dfd51ab5..4aa47aa48a 100644 --- a/tensorflow/core/util/util.h +++ b/tensorflow/core/util/util.h @@ -56,6 +56,11 @@ string PrintMemory(const char* ptr, size_t n); // "tensor", "tensor[i]", "tensor[i, j]", etc. string SliceDebugString(const TensorShape& shape, const int64 flat); +// disable MKL in runtime +#ifdef INTEL_MKL +bool DisableMKL(); +#endif // INTEL_MKL + } // namespace tensorflow #endif // TENSORFLOW_CORE_UTIL_UTIL_H_ |