Diffstat (limited to 'tensorflow/core')
-rw-r--r-- tensorflow/core/BUILD | 53
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt | 21
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt | 58
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt | 25
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt | 13
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt | 8
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt | 8
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt | 13
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt | 35
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt | 8
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt | 2
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt | 23
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt | 26
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt | 46
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_Substr.pbtxt | 10
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt | 28
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt | 8
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_Substr.pbtxt | 8
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_Tile.pbtxt | 4
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt | 6
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt | 6
-rw-r--r-- tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt | 6
-rw-r--r-- tensorflow/core/common_runtime/collective_param_resolver_local.cc | 1
-rw-r--r-- tensorflow/core/common_runtime/constant_folding.cc | 10
-rw-r--r-- tensorflow/core/common_runtime/copy_tensor.cc | 82
-rw-r--r-- tensorflow/core/common_runtime/direct_session.cc | 57
-rw-r--r-- tensorflow/core/common_runtime/direct_session.h | 23
-rw-r--r-- tensorflow/core/common_runtime/direct_session_test.cc | 86
-rw-r--r-- tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc | 16
-rw-r--r-- tensorflow/core/common_runtime/eager/BUILD | 7
-rw-r--r-- tensorflow/core/common_runtime/eager/attr_builder.cc | 16
-rw-r--r-- tensorflow/core/common_runtime/eager/attr_builder.h | 6
-rw-r--r-- tensorflow/core/common_runtime/eager/context.cc | 4
-rw-r--r-- tensorflow/core/common_runtime/eager/context.h | 2
-rw-r--r-- tensorflow/core/common_runtime/eager/execute.cc | 67
-rw-r--r-- tensorflow/core/common_runtime/eval_const_tensor.cc | 18
-rw-r--r-- tensorflow/core/common_runtime/executor.cc | 4
-rw-r--r-- tensorflow/core/common_runtime/executor.h | 6
-rw-r--r-- tensorflow/core/common_runtime/graph_optimizer.cc | 4
-rw-r--r-- tensorflow/core/common_runtime/graph_optimizer.h | 5
-rw-r--r-- tensorflow/core/common_runtime/lower_if_op.cc | 52
-rw-r--r-- tensorflow/core/common_runtime/lower_if_op.h | 5
-rw-r--r-- tensorflow/core/common_runtime/lower_if_op_test.cc | 4
-rw-r--r-- tensorflow/core/common_runtime/mkl_cpu_allocator.h | 1
-rw-r--r-- tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc | 4
-rw-r--r-- tensorflow/core/common_runtime/process_util.cc | 31
-rw-r--r-- tensorflow/core/common_runtime/ring_reducer.cc | 75
-rw-r--r-- tensorflow/core/common_runtime/ring_reducer_test.cc | 83
-rw-r--r-- tensorflow/core/common_runtime/threadpool_device.cc | 11
-rw-r--r-- tensorflow/core/framework/common_shape_fns.cc | 107
-rw-r--r-- tensorflow/core/framework/common_shape_fns.h | 3
-rw-r--r-- tensorflow/core/framework/dataset.h | 34
-rw-r--r-- tensorflow/core/framework/function.cc | 8
-rw-r--r-- tensorflow/core/framework/function.h | 5
-rw-r--r-- tensorflow/core/framework/function_testlib.cc | 17
-rw-r--r-- tensorflow/core/framework/model.cc | 83
-rw-r--r-- tensorflow/core/framework/model.h | 42
-rw-r--r-- tensorflow/core/framework/node_def_util.h | 1
-rw-r--r-- tensorflow/core/framework/op.h | 20
-rw-r--r-- tensorflow/core/framework/op_def_builder.cc | 24
-rw-r--r-- tensorflow/core/framework/op_def_builder.h | 14
-rw-r--r-- tensorflow/core/framework/resource_mgr.cc | 9
-rw-r--r-- tensorflow/core/framework/resource_mgr.h | 117
-rw-r--r-- tensorflow/core/framework/run_handler.cc | 249
-rw-r--r-- tensorflow/core/framework/run_handler.h | 95
-rw-r--r-- tensorflow/core/framework/run_handler_util.cc | 57
-rw-r--r-- tensorflow/core/framework/run_handler_util.h | 43
-rw-r--r-- tensorflow/core/framework/run_handler_util_test.cc | 93
-rw-r--r-- tensorflow/core/framework/tensor.cc | 2
-rw-r--r-- tensorflow/core/framework/tensor.h | 2
-rw-r--r-- tensorflow/core/framework/tensor_test.cc | 3
-rw-r--r-- tensorflow/core/graph/graph.cc | 9
-rw-r--r-- tensorflow/core/graph/graph.h | 9
-rw-r--r-- tensorflow/core/graph/mkl_layout_pass.cc | 28
-rw-r--r-- tensorflow/core/graph/mkl_layout_pass_test.cc | 24
-rw-r--r-- tensorflow/core/graph/mkl_tfconversion_pass.cc | 7
-rw-r--r-- tensorflow/core/graph/mkl_tfconversion_pass_test.cc | 4
-rw-r--r-- tensorflow/core/graph/node_builder.cc | 7
-rw-r--r-- tensorflow/core/graph/node_builder.h | 4
-rw-r--r-- tensorflow/core/grappler/costs/graph_properties.h | 4
-rw-r--r-- tensorflow/core/grappler/costs/graph_properties_test.cc | 6
-rw-r--r-- tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt | 7
-rw-r--r-- tensorflow/core/grappler/graph_view.cc | 35
-rw-r--r-- tensorflow/core/grappler/graph_view.h | 3
-rw-r--r-- tensorflow/core/grappler/graph_view_test.cc | 22
-rw-r--r-- tensorflow/core/grappler/grappler_item.cc | 1
-rw-r--r-- tensorflow/core/grappler/grappler_item.h | 9
-rw-r--r-- tensorflow/core/grappler/grappler_item_builder.cc | 8
-rw-r--r-- tensorflow/core/grappler/grappler_item_builder.h | 2
-rw-r--r-- tensorflow/core/grappler/grappler_item_builder_test.cc | 23
-rw-r--r-- tensorflow/core/grappler/op_types.cc | 122
-rw-r--r-- tensorflow/core/grappler/op_types.h | 3
-rw-r--r-- tensorflow/core/grappler/optimizers/BUILD | 2
-rw-r--r-- tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 16
-rw-r--r-- tensorflow/core/grappler/optimizers/constant_folding.cc | 47
-rw-r--r-- tensorflow/core/grappler/optimizers/constant_folding_test.cc | 130
-rw-r--r-- tensorflow/core/grappler/optimizers/data/BUILD | 78
-rw-r--r-- tensorflow/core/grappler/optimizers/data/filter_fusion.cc | 13
-rw-r--r-- tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc | 11
-rw-r--r-- tensorflow/core/grappler/optimizers/data/function_utils.cc | 32
-rw-r--r-- tensorflow/core/grappler/optimizers/data/graph_test_utils.cc | 49
-rw-r--r-- tensorflow/core/grappler/optimizers/data/graph_test_utils.h | 36
-rw-r--r-- tensorflow/core/grappler/optimizers/data/graph_utils.cc | 45
-rw-r--r-- tensorflow/core/grappler/optimizers/data/graph_utils.h | 38
-rw-r--r-- tensorflow/core/grappler/optimizers/data/graph_utils_test.cc | 40
-rw-r--r-- tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc | 289
-rw-r--r-- tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h | 55
-rw-r--r-- tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc | 84
-rw-r--r-- tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc | 5
-rw-r--r-- tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc | 14
-rw-r--r-- tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc | 21
-rw-r--r-- tensorflow/core/grappler/optimizers/data/map_fusion.cc | 30
-rw-r--r-- tensorflow/core/grappler/optimizers/data/map_fusion_test.cc | 10
-rw-r--r-- tensorflow/core/grappler/optimizers/data/map_parallelization.cc | 3
-rw-r--r-- tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc | 13
-rw-r--r-- tensorflow/core/grappler/optimizers/data/map_vectorization.cc | 44
-rw-r--r-- tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc | 112
-rw-r--r-- tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc | 2
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/BUILD | 13
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc | 33
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc | 40
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h | 24
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc | 2
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h | 15
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc | 19
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h | 44
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization_utils.cc | 662
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization_utils.h | 35
-rw-r--r-- tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc | 459
-rw-r--r-- tensorflow/core/grappler/optimizers/meta_optimizer.cc | 53
-rw-r--r-- tensorflow/core/grappler/optimizers/meta_optimizer.h | 4
-rw-r--r-- tensorflow/core/grappler/optimizers/meta_optimizer_test.cc | 156
-rw-r--r-- tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc | 401
-rw-r--r-- tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h | 4
-rw-r--r-- tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc | 118
-rw-r--r-- tensorflow/core/grappler/optimizers/remapper.cc | 8
-rw-r--r-- tensorflow/core/grappler/optimizers/shape_optimizer.cc | 12
-rw-r--r-- tensorflow/core/grappler/utils.cc | 39
-rw-r--r-- tensorflow/core/grappler/utils.h | 110
-rw-r--r-- tensorflow/core/grappler/utils/functions.cc | 55
-rw-r--r-- tensorflow/core/grappler/utils/functions.h | 5
-rw-r--r-- tensorflow/core/grappler/utils_test.cc | 19
-rw-r--r-- tensorflow/core/kernels/BUILD | 76
-rw-r--r-- tensorflow/core/kernels/batch_matmul_op_complex.cc | 10
-rw-r--r-- tensorflow/core/kernels/batch_matmul_op_real.cc | 9
-rw-r--r-- tensorflow/core/kernels/batching_util/BUILD | 5
-rw-r--r-- tensorflow/core/kernels/collective_ops.cc | 21
-rw-r--r-- tensorflow/core/kernels/conv_ops.cc | 321
-rw-r--r-- tensorflow/core/kernels/conv_ops.h | 44
-rw-r--r-- tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc | 26
-rw-r--r-- tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc | 26
-rw-r--r-- tensorflow/core/kernels/cwise_op_xdivy.cc | 38
-rw-r--r-- tensorflow/core/kernels/cwise_op_xlogy.cc | 41
-rw-r--r-- tensorflow/core/kernels/cwise_ops.h | 45
-rw-r--r-- tensorflow/core/kernels/cwise_ops_common.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/BUILD | 16
-rw-r--r-- tensorflow/core/kernels/data/concatenate_dataset_op.cc | 8
-rw-r--r-- tensorflow/core/kernels/data/dataset_utils.cc | 47
-rw-r--r-- tensorflow/core/kernels/data/dataset_utils.h | 20
-rw-r--r-- tensorflow/core/kernels/data/dataset_utils_test.cc | 46
-rw-r--r-- tensorflow/core/kernels/data/experimental/BUILD | 139
-rw-r--r-- tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc | 156
-rw-r--r-- tensorflow/core/kernels/data/experimental/csv_dataset_op.cc | 860
-rw-r--r-- tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc | 281
-rw-r--r-- tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc | 156
-rw-r--r-- tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc | 141
-rw-r--r-- tensorflow/core/kernels/data/experimental/indexed_dataset.cc | 375
-rw-r--r-- tensorflow/core/kernels/data/experimental/indexed_dataset.h | 119
-rw-r--r-- tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc | 218
-rw-r--r-- tensorflow/core/kernels/data/experimental/prefetching_kernels.cc | 482
-rw-r--r-- tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc | 220
-rw-r--r-- tensorflow/core/kernels/data/experimental/unique_dataset_op.cc | 224
-rw-r--r-- tensorflow/core/kernels/data/filter_dataset_op.cc | 162
-rw-r--r-- tensorflow/core/kernels/data/generator_dataset_op.cc | 9
-rw-r--r-- tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/group_by_window_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/iterator_ops.cc | 111
-rw-r--r-- tensorflow/core/kernels/data/map_and_batch_dataset_op.cc | 275
-rw-r--r-- tensorflow/core/kernels/data/map_dataset_op.cc | 62
-rw-r--r-- tensorflow/core/kernels/data/map_defun_op.cc | 68
-rw-r--r-- tensorflow/core/kernels/data/multi_device_iterator_ops.cc | 34
-rw-r--r-- tensorflow/core/kernels/data/optimize_dataset_op.cc | 16
-rw-r--r-- tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc | 89
-rw-r--r-- tensorflow/core/kernels/data/parallel_map_dataset_op.cc | 79
-rw-r--r-- tensorflow/core/kernels/data/parallel_map_iterator.cc | 136
-rw-r--r-- tensorflow/core/kernels/data/parallel_map_iterator.h | 2
-rw-r--r-- tensorflow/core/kernels/data/parse_example_dataset_op.cc | 6
-rw-r--r-- tensorflow/core/kernels/data/scan_dataset_op.cc | 4
-rw-r--r-- tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc | 98
-rw-r--r-- tensorflow/core/kernels/data/stats_aggregator_ops.cc | 11
-rw-r--r-- tensorflow/core/kernels/data/unbatch_dataset_op.cc | 13
-rw-r--r-- tensorflow/core/kernels/dequantize_op.cc | 2
-rw-r--r-- tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc | 32
-rw-r--r-- tensorflow/core/kernels/gather_nd_op_cpu_impl.h | 6
-rw-r--r-- tensorflow/core/kernels/matmul_op.cc | 8
-rw-r--r-- tensorflow/core/kernels/mkl_batch_matmul_op.cc | 2
-rw-r--r-- tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc | 2
-rw-r--r-- tensorflow/core/kernels/mkl_conv_grad_input_ops.cc | 2
-rw-r--r-- tensorflow/core/kernels/mkl_conv_ops.cc | 2
-rw-r--r-- tensorflow/core/kernels/mkl_matmul_op.cc | 6
-rw-r--r-- tensorflow/core/kernels/mkl_slice_op.cc | 358
-rw-r--r-- tensorflow/core/kernels/partitioned_function_ops.cc | 19
-rw-r--r-- tensorflow/core/kernels/random_op.cc | 34
-rw-r--r-- tensorflow/core/kernels/resource_variable_ops.cc | 70
-rw-r--r-- tensorflow/core/kernels/resource_variable_ops.h | 10
-rw-r--r-- tensorflow/core/kernels/slice_op.cc | 199
-rw-r--r-- tensorflow/core/kernels/stateless_random_ops.cc | 155
-rw-r--r-- tensorflow/core/kernels/strided_slice_op.cc | 2
-rw-r--r-- tensorflow/core/kernels/string_util.cc | 4
-rw-r--r-- tensorflow/core/kernels/string_util.h | 44
-rw-r--r-- tensorflow/core/kernels/substr_op.cc | 162
-rw-r--r-- tensorflow/core/kernels/substr_op_test.cc | 100
-rw-r--r-- tensorflow/core/kernels/training_op_helpers.cc | 45
-rw-r--r-- tensorflow/core/kernels/training_op_helpers.h | 37
-rw-r--r-- tensorflow/core/kernels/training_ops.cc | 8
-rw-r--r-- tensorflow/core/kernels/transpose_op.cc | 10
-rw-r--r-- tensorflow/core/kernels/unicode_script_op.cc | 53
-rw-r--r-- tensorflow/core/kernels/unique_op.cc | 15
-rw-r--r-- tensorflow/core/ops/array_ops.cc | 130
-rw-r--r-- tensorflow/core/ops/array_ops_test.cc | 1
-rw-r--r-- tensorflow/core/ops/compat/ops_history.v1.pbtxt | 1272
-rw-r--r-- tensorflow/core/ops/dataset_ops.cc | 26
-rw-r--r-- tensorflow/core/ops/experimental_dataset_ops.cc | 207
-rw-r--r-- tensorflow/core/ops/functional_ops.cc | 44
-rw-r--r-- tensorflow/core/ops/math_grad.cc | 34
-rw-r--r-- tensorflow/core/ops/math_grad_test.cc | 40
-rw-r--r-- tensorflow/core/ops/math_ops.cc | 33
-rw-r--r-- tensorflow/core/ops/math_ops_test.cc | 12
-rw-r--r-- tensorflow/core/ops/nn_ops.cc | 8
-rw-r--r-- tensorflow/core/ops/ops.pbtxt | 668
-rw-r--r-- tensorflow/core/ops/resource_variable_ops.cc | 72
-rw-r--r-- tensorflow/core/ops/stateless_random_grad.cc | 23
-rw-r--r-- tensorflow/core/ops/stateless_random_ops.cc | 53
-rw-r--r-- tensorflow/core/ops/string_ops.cc | 6
-rw-r--r-- tensorflow/core/platform/cloud/compute_engine_metadata_client.cc | 15
-rw-r--r-- tensorflow/core/platform/cloud/compute_engine_metadata_client.h | 10
-rw-r--r-- tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc | 6
-rw-r--r-- tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc | 8
-rw-r--r-- tensorflow/core/platform/cloud/gcs_file_system.cc | 25
-rw-r--r-- tensorflow/core/platform/cloud/gcs_file_system.h | 7
-rw-r--r-- tensorflow/core/platform/cloud/gcs_file_system_test.cc | 1280
-rw-r--r-- tensorflow/core/platform/cloud/google_auth_provider_test.cc | 20
-rw-r--r-- tensorflow/core/platform/cloud/retrying_file_system.h | 67
-rw-r--r-- tensorflow/core/platform/cloud/retrying_file_system_test.cc | 102
-rw-r--r-- tensorflow/core/platform/cloud/retrying_utils.cc | 35
-rw-r--r-- tensorflow/core/platform/cloud/retrying_utils.h | 29
-rw-r--r-- tensorflow/core/platform/cloud/retrying_utils_test.cc | 32
-rw-r--r-- tensorflow/core/platform/default/build_config.bzl | 65
-rw-r--r-- tensorflow/core/platform/env.h | 6
-rw-r--r-- tensorflow/core/platform/posix/env.cc | 11
-rw-r--r-- tensorflow/core/platform/posix/port.cc | 36
-rw-r--r-- tensorflow/core/platform/windows/env.cc | 11
-rw-r--r-- tensorflow/core/platform/windows/port.cc | 51
-rw-r--r-- tensorflow/core/profiler/BUILD | 1
-rw-r--r-- tensorflow/core/protobuf/config.proto | 5
-rw-r--r-- tensorflow/core/protobuf/rewriter_config.proto | 8
-rw-r--r-- tensorflow/core/util/mkl_util.h | 12
-rw-r--r-- tensorflow/core/util/port.cc | 4
-rw-r--r-- tensorflow/core/util/tensor_bundle/BUILD | 5
-rw-r--r-- tensorflow/core/util/tensor_bundle/tensor_bundle.cc | 52
-rw-r--r-- tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc | 64
-rw-r--r-- tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README | 3
-rw-r--r-- tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 | bin 0 -> 1080 bytes
-rw-r--r-- tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index | bin 0 -> 211 bytes
-rw-r--r-- tensorflow/core/util/util.cc | 16
-rw-r--r-- tensorflow/core/util/util.h | 5
299 files changed, 14171 insertions, 3334 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index bc0bfb793c..900a0e11c4 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -149,6 +149,7 @@ load(
"tf_cuda_tests_tags",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library")
load(
"//third_party/mkl:build_defs.bzl",
@@ -238,7 +239,6 @@ tf_proto_library(
srcs = [],
cc_api_version = 2,
default_header = True,
- java_api_version = 2,
js_api_version = 2,
protodeps = [
":protos_all_proto",
@@ -271,6 +271,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+java_proto_library(
+ name = "example_java_proto",
+ visibility = ["//visibility:public"],
+ deps = [":example_protos"],
+)
+
closure_proto_library(
name = "example_protos_closure",
visibility = ["//visibility:public"],
@@ -707,14 +713,11 @@ cc_library(
cc_library(
name = "feature_util",
srcs = ["example/feature_util.cc"],
- hdrs = [
- "example/feature_util.h",
- "platform/types.h",
- ],
+ hdrs = ["example/feature_util.h"],
visibility = ["//visibility:public"],
deps = [
":core_stringpiece",
- ":platform_protobuf",
+ ":lib_proto_parsing",
":protos_all_cc",
],
)
@@ -1041,6 +1044,7 @@ tf_gen_op_libs(
"dataset_ops",
"decode_proto_ops",
"encode_proto_ops",
+ "experimental_dataset_ops",
"function_ops",
"functional_ops",
"image_ops",
@@ -1057,7 +1061,6 @@ tf_gen_op_libs(
"random_grad",
"random_ops",
"remote_fused_graph_ops",
- "resource_variable_ops",
"rpc_ops",
"scoped_allocator_ops",
"sdca_ops",
@@ -1099,6 +1102,14 @@ tf_gen_op_libs(
deps = ["//tensorflow/core/kernels:debug_ops"],
)
+tf_gen_op_libs(
+ is_external = False,
+ op_lib_names = [
+ "resource_variable_ops",
+ ],
+ deps = [":lib"],
+)
+
# And one for all user ops
cc_library(
name = "user_ops_op_lib",
@@ -1164,6 +1175,7 @@ cc_library(
":dataset_ops_op_lib",
":decode_proto_ops_op_lib",
":encode_proto_ops_op_lib",
+ ":experimental_dataset_ops_op_lib",
":function_ops_op_lib",
":functional_ops_op_lib",
":image_ops_op_lib",
@@ -1230,6 +1242,7 @@ cc_library(
srcs = [
"ops/math_grad.cc",
"ops/random_grad.cc",
+ "ops/stateless_random_grad.cc",
],
linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669
visibility = ["//visibility:public"],
@@ -1363,6 +1376,7 @@ cc_library(
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
+ "//tensorflow/core/kernels:mkl_slice_op",
"//tensorflow/core/kernels:mkl_softmax_op",
"//tensorflow/core/kernels:mkl_transpose_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
@@ -2377,7 +2391,6 @@ tf_proto_library(
srcs = ERROR_CODES_PROTO_SRCS,
cc_api_version = 2,
default_header = True,
- java_api_version = 2,
js_api_version = 2,
provide_cc_alias = True,
)
@@ -2398,7 +2411,6 @@ tf_proto_library(
srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS,
cc_api_version = 2,
default_header = True,
- java_api_version = 2,
js_api_version = 2,
protodeps = [
":error_codes_proto",
@@ -2478,6 +2490,8 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"framework/op_segment.h",
"framework/rendezvous.h", # only needed for tests
"framework/resource_var.h",
+ "framework/run_handler.h",
+ "framework/run_handler_util.h",
"framework/tensor_reference.h",
"framework/tracking_allocator.h", # only needed for tests
"framework/unique_tensor_references.h",
@@ -2554,6 +2568,7 @@ tf_cuda_library(
"**/*test*",
"**/*main.cc",
"example/example_parser_configuration.*",
+ "example/feature_util.cc",
"util/reporter.cc",
"framework/fake_input.*",
"framework/op_gen_lib.*",
@@ -2583,6 +2598,7 @@ tf_cuda_library(
],
}),
deps = [
+ ":feature_util",
":lib",
":lib_internal",
":protos_all_proto_text",
@@ -2962,6 +2978,7 @@ tf_cuda_library(
":core_cpu_internal",
":device_tracer",
":framework",
+ ":framework_internal",
":graph",
":lib",
":lib_internal",
@@ -2999,7 +3016,7 @@ tf_cuda_library(
"platform/device_tracer.h",
],
copts = tf_copts(),
- cuda_deps = tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps(),
+ cuda_deps = if_cuda_is_configured(tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps()),
visibility = ["//visibility:private"],
deps = [
":core_cpu_internal",
@@ -3734,7 +3751,7 @@ tf_cc_tests_gpu(
tf_cc_tests_gpu(
name = "hierarchical_tree_broadcaster_test",
- size = "small",
+ size = "medium",
srcs = [
"common_runtime/hierarchical_tree_broadcaster_test.cc",
],
@@ -3821,6 +3838,7 @@ tf_cc_test_mkl(
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
+ "//tensorflow/core/kernels:mkl_slice_op",
"//tensorflow/core/kernels:mkl_softmax_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
]),
@@ -4108,6 +4126,19 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "framework_run_handler_util_test",
+ size = "small",
+ srcs = ["framework/run_handler_util_test.cc"],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":framework_internal",
+ ":lib",
+ ":test",
+ ":test_main",
+ ],
+)
+
tf_cuda_cc_test(
name = "common_runtime_direct_session_test",
size = "small",
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt
new file mode 100644
index 0000000000..fa8fc96bb2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalAssertNextDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt
new file mode 100644
index 0000000000..5fd88e7a0c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalCSVDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt
new file mode 100644
index 0000000000..ac1f9719fe
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt
@@ -0,0 +1,21 @@
+op {
+ graph_op_name: "ExperimentalDirectedInterleaveDataset"
+ in_arg {
+ name: "selector_input_dataset"
+ description: <<END
+A dataset of scalar `DT_INT64` elements that determines which of the
+`N` data inputs should produce the next output element.
+END
+ }
+ in_arg {
+ name: "data_input_datasets"
+ description: <<END
+`N` datasets with the same type that will be interleaved according to
+the values of `selector_input_dataset`.
+END
+ }
+ summary: <<END
+A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
+END
+ visibility: HIDDEN
+}
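
For orientation, a minimal sketch of how this op surfaces to users, assuming the `choose_from_datasets` transformation that wraps it in the Python tf.data API (the exact import path varies by release and is an assumption here):

    import tensorflow as tf

    # Three data inputs with a common element type.
    datasets = [tf.data.Dataset.from_tensors("a").repeat(),
                tf.data.Dataset.from_tensors("b").repeat(),
                tf.data.Dataset.from_tensors("c").repeat()]
    # Scalar DT_INT64 selector: picks which of the N inputs yields the next element.
    choice = tf.data.Dataset.range(3).repeat(2)  # 0, 1, 2, 0, 1, 2
    interleaved = tf.data.experimental.choose_from_datasets(datasets, choice)
    # Yields: a, b, c, a, b, c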
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt
new file mode 100644
index 0000000000..66511eff60
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt
@@ -0,0 +1,58 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResource"
+ in_arg {
+ name: "string_arg"
+ description: <<END
+String argument to the function call.
+END
+ }
+ in_arg {
+ name: "target_device"
+ description: <<END
+Target device to execute the function on.
+END
+ }
+ out_arg {
+ name: "resource"
+ description: <<END
+Handle to the resource created.
+END
+ }
+ attr {
+ name: "shared_name"
+ description: <<END
+If non-empty, this resource will be shared under the given name across
+multiple sessions.
+END
+ }
+ attr {
+ name: "container"
+ description: <<END
+If non-empty, this resource is placed in the given container.
+Otherwise, a default container is used.
+END
+ }
+ attr {
+ name: "f"
+ description: <<END
+Function to be executed.
+END
+ }
+ attr {
+ name: "buffer_size"
+ description: <<END
+Size of the buffer.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ summary: <<END
+Creates a resource that fills up a buffer by making function calls.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt
new file mode 100644
index 0000000000..bf4b66b22b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt
@@ -0,0 +1,25 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResourceGetNext"
+ in_arg {
+ name: "function_buffer_resource"
+ description: <<END
+The FunctionBufferingResource handle.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A list of return values.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ summary: <<END
+Gets the next element from a FunctionBufferingResource.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt
new file mode 100644
index 0000000000..729718ddb3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResourceReset"
+ in_arg {
+ name: "function_buffer_resource"
+ description: <<END
+The FunctionBufferingResource handle.
+END
+ }
+ summary: <<END
+Resets the FunctionBufferingResource.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt
new file mode 100644
index 0000000000..fe266c111f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIdentityIndexedDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt
new file mode 100644
index 0000000000..d42546516d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalIgnoreErrorsDataset"
+ summary: <<END
+Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+END
+ visibility: HIDDEN
+}
+
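A behavioral sketch, assuming the `ignore_errors` transformation as the public surface for this op (only the op itself is defined in this change):

    import tensorflow as tf

    ds = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
    # check_numerics raises InvalidArgument when 1/0 produces inf at iteration time.
    ds = ds.map(lambda x: tf.debugging.check_numerics(1. / x, "error"))
    ds = ds.apply(tf.data.experimental.ignore_errors())
    # Yields 1.0, 0.5, 0.25 -- the failing element is silently dropped.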
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt
new file mode 100644
index 0000000000..e285f87e10
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIndexedDatasetGet"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt
new file mode 100644
index 0000000000..60c32473b5
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIndexedDatasetMaterialize"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt
new file mode 100644
index 0000000000..b72b229e9a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalIteratorGetDevice"
+ summary: <<END
+Returns the name of the device on which `resource` has been placed.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt
new file mode 100644
index 0000000000..b38b23a51d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalLMDBDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt
new file mode 100644
index 0000000000..9676b9d284
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalMaterializedIndexDatasetHandle"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt
new file mode 100644
index 0000000000..d73b5bfda3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalThreadPoolDataset"
+ in_arg {
+ name: "thread_pool"
+ description: <<END
+A resource produced by the ThreadPoolHandle op.
+END
+ }
+ summary: <<END
+Creates a dataset that uses a custom thread pool to compute `input_dataset`.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt
new file mode 100644
index 0000000000..48bf93406c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt
@@ -0,0 +1,35 @@
+op {
+ graph_op_name: "ExperimentalThreadPoolHandle"
+ out_arg {
+ name: "handle"
+ description: <<END
+A resource that can be consumed by one or more ExperimentalThreadPoolDataset
+ops.
+END
+ }
+ attr {
+ name: "num_threads"
+ description: <<END
+The number of threads in the thread pool.
+END
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ description: <<END
+The maximum degree of parallelism to use within operations that execute on this
+threadpool.
+END
+ }
+ attr {
+ name: "display_name"
+ description: <<END
+A human-readable name for the threads that may be visible in some
+visualizations.
+END
+ }
+ summary: <<END
+Creates a handle to a custom thread pool, for use by ExperimentalThreadPoolDataset ops.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt
new file mode 100644
index 0000000000..68ed797a0c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalUniqueDataset"
+ summary: <<END
+Creates a dataset that contains the unique elements of `input_dataset`.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
index 40d7d371ca..7142a0e3f2 100644
--- a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
@@ -9,7 +9,7 @@ The lower regularized incomplete Gamma function is defined as:
where
-\\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\)
+\\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\)
is the lower incomplete Gamma function.
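
As a concrete check of the corrected formula, `tf.math.igamma` computes the lower regularized function P(a, x) = gamma(a, x) / Gamma(a); at a = 1 it reduces to 1 - exp(-x) (TF 2.x eager style assumed):

    import tensorflow as tf

    a = tf.constant([1.0, 1.0])
    x = tf.constant([1.0, 2.0])
    # P(1, x) = 1 - exp(-x)  ->  approximately [0.632, 0.865]
    print(tf.math.igamma(a, x).numpy())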
diff --git a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
index 4433693759..d158f4b502 100644
--- a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
@@ -4,16 +4,23 @@ op {
in_arg {
name: "arguments"
description: <<END
- A list of tensors whose types are Targuments, corresponding to the inputs the
- function should be mapped over.
+ A list of tensors whose types are `Targuments`, corresponding to the inputs
+ the function should be mapped over.
+END
+ }
+ in_arg {
+ name: "captured_inputs"
+ description: <<END
+ A list of tensors whose types are `Tcaptured`, corresponding to the captured
+ inputs of the defun.
END
}
out_arg {
name: "output"
description: <<END
- A list of output tensors whose types are output_types and whose dimensions 0
- are the same as the dimensions 0 of the tensors in arguments, and whose
- remaining dimensions correspond to those in output_shapes.
+ A list of output tensors whose types are `output_types` and whose dimensions
+ 0 are the same as the dimensions 0 of the tensors in `arguments`, and whose
+ remaining dimensions correspond to those in `output_shapes`.
END
}
attr {
@@ -21,6 +28,10 @@ END
description: "A list of types."
}
attr {
+ name: "Tcaptured"
+ description: "A list of types."
+ }
+ attr {
name: "output_types"
description: "A list of types."
}
@@ -29,6 +40,6 @@ END
description: "A list of shapes."
}
summary: <<END
- Maps a function on the list of tensors unpacked from inputs on dimension 0.
+ Maps a function on the list of tensors unpacked from arguments on dimension 0.
END
}
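
MapDefun has no public Python endpoint; purely as a semantic sketch (not the kernel's actual implementation), its contract of mapping `f` over dimension 0 of `arguments` while passing `captured_inputs` through unchanged is roughly:

    import tensorflow as tf

    def map_defun_semantics(f, arguments, captured_inputs):
        # Unpack every argument along dimension 0, apply f per slice,
        # then re-stack the per-slice results along a new dimension 0.
        slices = zip(*[tf.unstack(arg) for arg in arguments])
        outputs = [f(*s, *captured_inputs) for s in slices]
        return tf.stack(outputs)

    # map_defun_semantics(lambda x, c: x * c,
    #                     [tf.constant([1, 2, 3])],
    #                     [tf.constant(10)])  ->  [10, 20, 30]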
diff --git a/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt
new file mode 100644
index 0000000000..08414b3e68
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt
@@ -0,0 +1,26 @@
+op {
+ visibility: HIDDEN
+ graph_op_name: "ReduceDataset"
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ in_arg {
+ name: "initial_state"
+ description: <<END
+A nested structure of tensors, representing the initial state of the
+transformation.
+END
+ }
+ attr {
+ name: "f"
+ description: <<END
+A function that maps `(old_state, input_element)` to `new_state`. It must take
+two arguments and return a nested structure of tensors. The structure of
+`new_state` must match the structure of `initial_state`.
+END
+ }
+ summary: "Reduces the input dataset to a singleton using a reduce function."
+}
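
In user code this op surfaces as `tf.data.Dataset.reduce` (assuming the later public binding); for example:

    import tensorflow as tf

    ds = tf.data.Dataset.range(5)
    # f maps (old_state, input_element) -> new_state; the state here is a scalar sum.
    total = ds.reduce(initial_state=tf.constant(0, dtype=tf.int64),
                      reduce_func=lambda state, x: state + x)
    print(total.numpy())  # 10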
diff --git a/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt
new file mode 100644
index 0000000000..b6a6dbdf54
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StatelessRandomUniformInt.pbtxt
@@ -0,0 +1,46 @@
+op {
+ graph_op_name: "StatelessRandomUniformInt"
+ visibility: HIDDEN
+ in_arg {
+ name: "shape"
+ description: <<END
+The shape of the output tensor.
+END
+ }
+ in_arg {
+ name: "seed"
+ description: <<END
+2 seeds (shape [2]).
+END
+ }
+ in_arg {
+ name: "minval"
+ description: <<END
+Minimum value (inclusive, scalar).
+END
+ }
+ in_arg {
+ name: "maxval"
+ description: <<END
+Maximum value (exclusive, scalar).
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+Random values with specified shape.
+END
+ }
+ attr {
+ name: "dtype"
+ description: <<END
+The type of the output.
+END
+ }
+ summary: "Outputs deterministic pseudorandom random integers from a uniform distribution."
+ description: <<END
+The generated values follow a uniform distribution in the range `[minval, maxval)`.
+
+The outputs are a deterministic function of `shape`, `seed`, `minval`, and `maxval`.
+END
+}
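
A determinism sketch through `tf.random.stateless_uniform`, which requests this op when given an integer dtype with explicit bounds (the wrapper name is an assumption about the Python surface, not part of this change):

    import tensorflow as tf

    seed = tf.constant([7, 42], dtype=tf.int32)  # shape [2]
    a = tf.random.stateless_uniform([3], seed=seed, minval=0, maxval=10,
                                    dtype=tf.int32)
    b = tf.random.stateless_uniform([3], seed=seed, minval=0, maxval=10,
                                    dtype=tf.int32)
    # Same shape/seed/minval/maxval -> identical values on every call.
    assert (a.numpy() == b.numpy()).all()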
diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
index 5246090ab3..fe0fcc9508 100644
--- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
@@ -18,6 +18,16 @@ END
Scalar defining the number of characters to include in each substring
END
}
+ attr {
+ name: "unit"
+ description: <<END
+The unit that is used to create the substring. One of: `"BYTE"` (for
+defining position and length by bytes) or `"UTF8_CHAR"` (for the UTF-8
+encoded Unicode code points). The default is `"BYTE"`. Results are undefined if
+`unit=UTF8_CHAR` and the `input` strings do not contain structurally valid
+UTF-8.
+END
+ }
out_arg {
name: "output"
description: <<END
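
The BYTE/UTF8_CHAR distinction in the new attr is easiest to see on a multi-byte code point (using `tf.strings.substr` as the wrapper name, an assumption since the Python endpoint is hidden elsewhere in this change):

    import tensorflow as tf

    s = tf.constant("héllo")  # "é" occupies 2 bytes in UTF-8
    print(tf.strings.substr(s, 1, 2).numpy())                    # b'\xc3\xa9': 2 raw bytes, i.e. "é"
    print(tf.strings.substr(s, 1, 2, unit="UTF8_CHAR").numpy())  # b'\xc3\xa9l': the 2 characters "él"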
diff --git a/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt
new file mode 100644
index 0000000000..7898fe8d6b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt
@@ -0,0 +1,28 @@
+op {
+ graph_op_name: "UnicodeScript"
+ endpoint {
+ name: "UnicodeScript"
+ }
+ in_arg {
+ name: "input"
+ description: <<END
+A Tensor of int32 Unicode code points.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A Tensor of int32 script codes corresponding to each input code point.
+END
+ }
+ summary: <<END
+Determine the script codes of a given tensor of Unicode integer code points.
+END
+ description: <<END
+This operation converts Unicode code points to script codes corresponding to
+each code point. Script codes correspond to International Components for
+Unicode (ICU) UScriptCode values. See http://icu-project.org/apiref/icu4c/uscript_8h.html.
+Returns -1 (USCRIPT_INVALID_CODE) for invalid codepoints. Output shape will
+match input shape.
+END
+}
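
Usage through the `strings.unicode_script` endpoint added for Python below; the returned values are ICU UScriptCode constants:

    import tensorflow as tf

    # Code points for "A" (Latin), "Я" (Cyrillic), and one invalid value.
    cp = tf.constant([0x41, 0x42F, -1], dtype=tf.int32)
    print(tf.strings.unicode_script(cp).numpy())
    # -> [25, 8, -1]: USCRIPT_LATIN, USCRIPT_CYRILLIC, USCRIPT_INVALID_CODE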
diff --git a/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt b/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt
new file mode 100644
index 0000000000..ca107abc6b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "Xdivy"
+ summary: "Returns 0 if x == 0, and x / y otherwise, elementwise."
+}
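
A quick numeric check of the safe-division semantics, via the `math.xdivy` endpoint added later in this change:

    import tensorflow as tf

    x = tf.constant([0.0, 4.0])
    y = tf.constant([0.0, 2.0])
    print(tf.math.xdivy(x, y).numpy())  # [0., 2.] -- 0/0 yields 0, not NaN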
diff --git a/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt b/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt
new file mode 100644
index 0000000000..da625f7836
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "Xlogy"
+ summary: "Returns 0 if x == 0, and x * log(y) otherwise, elementwise."
+}
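
And the companion op, which defines 0 * log(0) as 0 instead of NaN (endpoint `math.xlogy`, also added below):

    import tensorflow as tf

    x = tf.constant([0.0, 1.0])
    y = tf.constant([0.0, 10.0])
    print(tf.math.xlogy(x, y).numpy())  # [0., 2.3025851] == [0., log(10)]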
diff --git a/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt b/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt
new file mode 100644
index 0000000000..3d937c745c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListPushBackBatch"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
index 9552fc92e3..e395e333bf 100644
--- a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "BatchToSpaceND"
endpoint {
- name: "manip.batch_to_space_nd"
+ name: "batch_to_space_nd"
}
endpoint {
- name: "batch_to_space_nd"
+ name: "manip.batch_to_space_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt b/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt
new file mode 100644
index 0000000000..44f25b5d93
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "EmptyTensorList"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
index 71257c8855..598f23bde3 100644
--- a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "GatherNd"
endpoint {
- name: "manip.gather_nd"
+ name: "gather_nd"
}
endpoint {
- name: "gather_nd"
+ name: "manip.gather_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
index b17806b338..5020844204 100644
--- a/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
@@ -1,10 +1,4 @@
op {
graph_op_name: "RegexReplace"
- endpoint {
- name: "strings.regex_replace"
- }
- endpoint {
- name: "regex_replace"
- deprecated: true
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
index c469665b66..b3d596de7a 100644
--- a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "Reshape"
endpoint {
- name: "manip.reshape"
+ name: "reshape"
}
endpoint {
- name: "reshape"
+ name: "manip.reshape"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
index 77f595927b..51478b7c34 100644
--- a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "ReverseV2"
endpoint {
- name: "manip.reverse"
+ name: "reverse"
}
endpoint {
- name: "reverse"
+ name: "manip.reverse"
deprecated: true
}
endpoint {
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
index a65a19b542..85888da45a 100644
--- a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "ScatterNd"
endpoint {
- name: "manip.scatter_nd"
+ name: "scatter_nd"
}
endpoint {
- name: "scatter_nd"
+ name: "manip.scatter_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
index af323a6cf3..146b97f444 100644
--- a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "SpaceToBatchND"
endpoint {
- name: "manip.space_to_batch_nd"
+ name: "space_to_batch_nd"
}
endpoint {
- name: "space_to_batch_nd"
+ name: "manip.space_to_batch_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt
new file mode 100644
index 0000000000..d3c70190dd
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessMultinomial"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt
new file mode 100644
index 0000000000..e294325fb8
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessRandomNormal"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt
new file mode 100644
index 0000000000..95d414c54a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessRandomUniform"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt
new file mode 100644
index 0000000000..c72bdda94a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessTruncatedNormal"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
index 4778d7927c..4fb9ee56e9 100644
--- a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
@@ -1,10 +1,4 @@
op {
graph_op_name: "Substr"
- endpoint {
- name: "strings.substr"
- }
- endpoint {
- name: "substr"
- deprecated: true
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt
new file mode 100644
index 0000000000..45fc55e71e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListConcatLists"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt
new file mode 100644
index 0000000000..e1ad713e7f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListElementShape"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt
new file mode 100644
index 0000000000..4aaefba3c5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListFromTensor"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt
new file mode 100644
index 0000000000..aaf607d70e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListGather"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt
new file mode 100644
index 0000000000..3bb5f39cbc
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListGetItem"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt
new file mode 100644
index 0000000000..a04c20bb8a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListLength"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt
new file mode 100644
index 0000000000..9287162f22
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListPopBack"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt
new file mode 100644
index 0000000000..da2bc11721
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListPushBack"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt
new file mode 100644
index 0000000000..77e63747d5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListReserve"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt
new file mode 100644
index 0000000000..0015189d7f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListScatter"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt
new file mode 100644
index 0000000000..4999ee7ad9
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListSetItem"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt
new file mode 100644
index 0000000000..2dc7b2784b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListStack"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
index c34061c941..1d8695f1fd 100644
--- a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "Tile"
endpoint {
- name: "manip.tile"
+ name: "tile"
}
endpoint {
- name: "tile"
+ name: "manip.tile"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt
new file mode 100644
index 0000000000..a884a46143
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "UnicodeScript"
+ endpoint {
+ name: "strings.unicode_script"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt b/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt
new file mode 100644
index 0000000000..984442ba2b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "Xdivy"
+ endpoint {
+ name: "math.xdivy"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt b/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt
new file mode 100644
index 0000000000..b4a5299256
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "Xlogy"
+ endpoint {
+ name: "math.xlogy"
+ }
+}
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 3b2dc6a050..7cb90de3c7 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -522,7 +522,6 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
InitInstanceSharedParams(
gr, cp, ir,
[this, ir, done](const Status& s) UNLOCK_FUNCTION(ir->out_mu) {
- DCHECK(!ir->out_mu.try_lock());
DCHECK(ir->out_mu_available);
ir->status.Update(s);
ir->out_mu.unlock();
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index 99cb9ac6a0..e81e61b633 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -470,19 +470,19 @@ bool ReplaceTensorWithConstant(
const ConstantFoldNameGenerator& generate_new_name) {
// Be conservative when replacing a tensor with a constant, when not
// running on CPU.
- // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
+ // 1) Do not replace another constant.
+ // 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
// constraint, do not replace it.
- // 2) If the destination tensor is an int32 tensor, but has DEVICE_MEMORY
+ // 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY
// constraint, do not replace it.
- // 3) If the constant op created does not have a kernel implementation
- // for the device, do not use it.
// 4) If the size of the constant in bytes is too large (>
// max_constant_in_bytes), do not replace it. This prevents the size of the
// Graph from growing too large.
+ // 5) If the constant op created does not have a kernel implementation
+ // for the device, do not use it.
// TODO(keveman): Consider adding a new constant op that has a kernel
// implementation for all types, but with HostMemory constraint on it's
// output.
- // 5) Do not replace another constant.
if (tensor.first->IsConstant()) {
return false;
}
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index d800a86199..6e2eb66b94 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -61,26 +61,33 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
status_cb->Unref();
};
auto copier = std::bind(
- [dst, recv_dev_context, out_allocator, status_cb](
- StatusCallback wrapped_done_,
- // Begin unbound arguments
- const Tensor& from, Tensor* to) {
- if (!DMAHelper::CanUseDMA(&from)) {
- Status err = errors::InvalidArgument(
- "During Variant Host->Device Copy: "
- "non-DMA-copy attempted of tensor type: ",
- DataTypeString(from.dtype()));
- status_cb->UpdateStatus(err);
- return err;
- }
- if (status_cb->ok()) {
+ [dst, recv_dev_context, out_allocator, status_cb, cpu_allocator,
+ edge_name](StatusCallback wrapped_done_,
+ // Begin unbound arguments
+ const Tensor& from, Tensor* to) {
+ if (from.dtype() == DT_VARIANT) {
status_cb->Ref();
- *to = Tensor(out_allocator, from.dtype(), from.shape());
- recv_dev_context->CopyCPUTensorToDevice(&from, dst, to,
- wrapped_done_);
+ CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name,
+ dst, to, recv_dev_context, wrapped_done_);
return Status::OK();
} else {
- return status_cb->status();
+ if (!DMAHelper::CanUseDMA(&from)) {
+ Status err = errors::InvalidArgument(
+ "During Variant Host->Device Copy: "
+ "non-DMA-copy attempted of tensor type: ",
+ DataTypeString(from.dtype()));
+ status_cb->UpdateStatus(err);
+ return err;
+ }
+ if (status_cb->ok()) {
+ status_cb->Ref();
+ *to = Tensor(out_allocator, from.dtype(), from.shape());
+ recv_dev_context->CopyCPUTensorToDevice(&from, dst, to,
+ wrapped_done_);
+ return Status::OK();
+ } else {
+ return status_cb->status();
+ }
}
},
std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
@@ -119,26 +126,33 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
status_cb->Unref();
};
auto copier = std::bind(
- [edge_name, src, send_dev_context, out_allocator, status_cb](
- StatusCallback wrapped_done_,
- // Begin unbound arguments
- const Tensor& from, Tensor* to) {
- if (!DMAHelper::CanUseDMA(&from)) {
- Status err = errors::InvalidArgument(
- "During Variant Device->Host Copy: "
- "non-DMA-copy attempted of tensor type: ",
- DataTypeString(from.dtype()));
- status_cb->UpdateStatus(err);
- return err;
- }
- if (status_cb->ok()) {
+ [edge_name, src, send_dev_context, out_allocator, status_cb,
+ cpu_allocator](StatusCallback wrapped_done_,
+ // Begin unbound arguments
+ const Tensor& from, Tensor* to) {
+ if (from.dtype() == DT_VARIANT) {
status_cb->Ref();
- *to = Tensor(out_allocator, from.dtype(), from.shape());
- send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
- wrapped_done_);
+ CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name,
+ src, to, send_dev_context, wrapped_done_);
return Status::OK();
} else {
- return status_cb->status();
+ if (!DMAHelper::CanUseDMA(&from)) {
+ Status err = errors::InvalidArgument(
+ "During Variant Device->Host Copy: "
+ "non-DMA-copy attempted of tensor type: ",
+ DataTypeString(from.dtype()));
+ status_cb->UpdateStatus(err);
+ return err;
+ }
+ if (status_cb->ok()) {
+ status_cb->Ref();
+ *to = Tensor(out_allocator, from.dtype(), from.shape());
+ send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
+ wrapped_done_);
+ return Status::OK();
+ } else {
+ return status_cb->status();
+ }
}
},
std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
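Both copiers above now special-case DT_VARIANT inputs by calling themselves, so a variant tensor whose elements wrap further variant tensors is unwrapped one nesting level per call before the DMA check runs. A toy sketch of that recursion shape, using a hypothetical Value type that stands in for a tensor (illustrative only, not TF code):

#include <memory>
#include <string>

// A value either carries a plain payload (the DMA-copyable base case) or
// wraps another value (the DT_VARIANT case).
struct Value {
  std::string payload;
  std::shared_ptr<Value> wrapped;  // non-null when this value is a "variant"
};

void Copy(const Value& from, Value* to) {
  if (from.wrapped != nullptr) {
    // Variant case: recurse, copying one nesting level per call.
    to->wrapped = std::make_shared<Value>();
    Copy(*from.wrapped, to->wrapped.get());
    return;
  }
  to->payload = from.payload;  // base case: direct copy of the payload
}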
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index af5d5b17e7..458e133b68 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/run_handler.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -244,6 +245,21 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool,
#endif // __ANDROID__
}
+static RunHandlerPool* GetOrCreateRunHandlerPool(
+ const SessionOptions& options) {
+ static RunHandlerPool* pool =
+ new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options));
+ return pool;
+}
+
+bool DirectSession::ShouldUseRunHandlerPool() const {
+ if (options_.config.session_inter_op_thread_pool_size() > 0 ||
+ options_.config.use_per_session_threads()) {
+ return false;
+ }
+ return true;
+}
+
DirectSession::DirectSession(const SessionOptions& options,
const DeviceMgr* device_mgr,
DirectSessionFactory* const factory)
@@ -363,7 +379,7 @@ Status DirectSession::MaybeInitializeExecutionState(
Status DirectSession::Create(const GraphDef& graph) {
TF_RETURN_IF_ERROR(init_error_);
if (graph.node_size() > 0) {
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
if (graph_created_) {
return errors::AlreadyExists(
"A Graph has already been created for this session.");
@@ -375,7 +391,7 @@ Status DirectSession::Create(const GraphDef& graph) {
Status DirectSession::Extend(const GraphDef& graph) {
TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
return ExtendLocked(graph);
}
@@ -582,16 +598,37 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
}
}
- Executor::Args::Runner default_runner = [this,
- pool](Executor::Args::Closure c) {
- SchedClosure(pool, std::move(c));
- };
+ std::unique_ptr<RunHandler> handler;
+ if (ShouldUseRunHandlerPool() &&
+ run_options.experimental().use_run_handler_pool()) {
+ // Non-null only when a global inter-op pool is used.
+ VLOG(1) << "Using RunHandler to schedule inter-op closures.";
+ handler = GetOrCreateRunHandlerPool(options_)->Get();
+ }
+ auto* handler_ptr = handler.get();
+
+ Executor::Args::Runner default_runner = nullptr;
+
+ if (pool == nullptr) {
+ default_runner = [](Executor::Args::Closure c) { c(); };
+ } else if (handler_ptr != nullptr) {
+ default_runner = [handler_ptr](Executor::Args::Closure c) {
+ handler_ptr->ScheduleInterOpClosure(std::move(c));
+ };
+ } else {
+ default_runner = [this, pool](Executor::Args::Closure c) {
+ SchedClosure(pool, std::move(c));
+ };
+ }
+
for (const auto& item : executors_and_keys->items) {
- // TODO(zhengxq): support partial run.
- // TODO(zhengxq): if the device picks its own threadpool, we need to assign
+ // TODO(azaks): support partial run.
+ // TODO(azaks): if the device picks its own threadpool, we need to assign
// less threads to the main compute pool by default.
thread::ThreadPool* device_thread_pool =
item.device->tensorflow_device_thread_pool();
+ // TODO(crk): Investigate usage of RunHandlerPool when using device specific
+ // thread pool(s).
if (!device_thread_pool) {
args.runner = default_runner;
} else {
@@ -1172,7 +1209,7 @@ Status DirectSession::CreateExecutors(
int graph_def_version;
{
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
graph_def_version =
execution_state_->original_graph_def().versions().producer();
}
@@ -1400,7 +1437,7 @@ Status DirectSession::CreateGraphs(
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types,
DataTypeVector* output_types, int64* collective_graph_key) {
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
std::unique_ptr<ClientGraph> client_graph;
std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
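Callers opt into the new code path per step through RunOptions. A usage sketch, assuming an existing `session`, `inputs`, `output_names`, and `target_nodes` (this mirrors the UseRunHandlerPool test added below):

RunOptions run_options;
run_options.mutable_experimental()->set_use_run_handler_pool(true);
std::vector<Tensor> outputs;
Status s = session->Run(run_options, inputs, output_names, target_nodes,
                        &outputs, /*run_metadata=*/nullptr);

Note the request is honored only when ShouldUseRunHandlerPool() also returns true, i.e. when the session schedules inter-op work on the global pool rather than per-session threads or a session_inter_op_thread_pool.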
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index c2cf3c7fd7..3a168bbe3f 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -215,7 +215,7 @@ class DirectSession : public Session {
// if not already initialized.
Status MaybeInitializeExecutionState(const GraphDef& graph,
bool* out_already_initialized)
- EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
+ EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);
// Retrieves an already existing set of executors to run 'inputs' and
// 'outputs', or creates and caches them for future use.
@@ -247,8 +247,11 @@ class DirectSession : public Session {
ExecutorsAndKeys* executors_and_keys,
RunMetadata* run_metadata);
+ // Returns whether inter-op execution uses a global pool.
+ bool ShouldUseRunHandlerPool() const;
+
::tensorflow::Status ExtendLocked(const GraphDef& graph)
- EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
+ EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);
::tensorflow::Status ResourceHandleToInputTensor(
const Tensor& resource_tensor, Tensor* retrieved_tensor);
@@ -289,7 +292,7 @@ class DirectSession : public Session {
}
::tensorflow::Status CheckGraphCreated(const char* method) {
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
if (!graph_created_) {
return errors::InvalidArgument(
"Session was not created with a graph before ", method, "!");
@@ -313,10 +316,8 @@ class DirectSession : public Session {
DeviceSet device_set_;
string session_handle_;
- bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
-
- mutex graph_def_lock_;
- GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
+ mutex graph_state_lock_;
+ bool graph_created_ GUARDED_BY(graph_state_lock_) = false;
// The thread-pools to use for running ops, with a bool indicating if the pool
// is owned.
@@ -367,11 +368,11 @@ class DirectSession : public Session {
// nodes can not be moved to a different device. Maps node names to
// device names.
std::unordered_map<string, string> stateful_placements_
- GUARDED_BY(graph_def_lock_);
+ GUARDED_BY(graph_state_lock_);
// Execution_state; used when placing the entire graph.
std::unique_ptr<GraphExecutionState> execution_state_
- GUARDED_BY(graph_def_lock_);
+ GUARDED_BY(graph_state_lock_);
// The function library, before any rewrites or optimizations have been
// performed. In particular, CreateGraphs() may need to modify the function
@@ -386,7 +387,7 @@ class DirectSession : public Session {
std::atomic<int64> edge_name_counter_ = {0};
std::atomic<int64> handle_name_counter_ = {0};
- // For generating step ids that are unique across all sessions.
+ // For generating step ids that are unique across this session.
static std::atomic_int_fast64_t step_id_counter_;
// Global timeout for all blocking operations in this session.
@@ -395,8 +396,6 @@ class DirectSession : public Session {
// Manages all the cost models for the graphs executed in this session.
CostModelManager cost_model_manager_;
- Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
-
// For testing collective graph key generation.
mutex collective_graph_key_lock_;
int64 collective_graph_key_ GUARDED_BY(collective_graph_key_lock_) = -1;
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 65e816c202..a6440c55ad 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -625,6 +625,34 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) {
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
+TEST_F(DirectSessionMinusAXTest, UseRunHandlerPool) {
+ Initialize({3, 2, -1, 0});
+ auto session = CreateSession();
+ ASSERT_TRUE(session != nullptr);
+ TF_ASSERT_OK(session->Create(def_));
+ std::vector<std::pair<string, Tensor>> inputs;
+
+ // Request two targets: one fetched output and one non-fetched target node.
+ std::vector<string> output_names = {y_ + ":0"};
+ std::vector<string> target_nodes = {y_neg_};
+ std::vector<Tensor> outputs;
+
+ // Prepare RunOptions with the run handler pool enabled.
+ RunOptions run_options;
+ run_options.mutable_experimental()->set_use_run_handler_pool(true);
+
+ Status s = session->Run(run_options, inputs, output_names, target_nodes,
+ &outputs, nullptr);
+ TF_ASSERT_OK(s);
+
+ ASSERT_EQ(1, outputs.size());
+ // The first output should be initialized and contain the correct
+ // value.
+ auto mat = outputs[0].matrix<float>();
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ EXPECT_FLOAT_EQ(5.0, mat(0, 0));
+}
+
TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
GraphDef def;
Graph g(OpRegistry::Global());
@@ -2234,8 +2262,8 @@ class DirectSessionCollectiveTest : public ::testing::Test {
TF_RETURN_IF_ERROR(session->Create(g));
std::vector<Tensor> outputs;
TF_RETURN_IF_ERROR(
- session->Run({{"input1:0", t1}, {"input2:0", t2}}, {},
- {"collective_call1:0", "collective_call2:0"}, &outputs));
+ session->Run({{"input0:0", t1}, {"input1:0", t2}}, {},
+ {"collective_call0:0", "collective_call1:0"}, &outputs));
DirectSession* direct_session = static_cast<DirectSession*>(session.get());
{
mutex_lock l(direct_session->collective_graph_key_lock_);
@@ -2273,6 +2301,26 @@ class DirectSessionCollectiveTest : public ::testing::Test {
}});
}
+ NodeDef Input(int id) {
+ AttrValue dtype_attr;
+ SetAttrValue(DT_FLOAT, &dtype_attr);
+ NodeDef input;
+ input.set_name(strings::StrCat("input", id));
+ input.set_op("Placeholder");
+ input.mutable_attr()->insert({"dtype", dtype_attr});
+ return input;
+ }
+
+ NodeDef CollectiveCall(const string& op, const string& input, int cpu_id) {
+ NodeDef collective_call;
+ collective_call.set_name(strings::StrCat("collective_call", cpu_id));
+ collective_call.set_op(op);
+ collective_call.add_input(input);
+ collective_call.set_device(
+ strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", cpu_id));
+ return collective_call;
+ }
+
// Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and
// CPU1, with instance_key 1, and appropriate placeholder inputs. If
// `add_unused_function` is true, adds another CollectiveFunction with
@@ -2289,42 +2337,17 @@ class DirectSessionCollectiveTest : public ::testing::Test {
*lib->add_function() = unused_function;
}
- // Inputs.
- AttrValue dtype_attr;
- SetAttrValue(DT_FLOAT, &dtype_attr);
- NodeDef input1;
- input1.set_name("input1");
- input1.set_op("Placeholder");
- input1.mutable_attr()->insert({"dtype", dtype_attr});
- NodeDef input2;
- input2.set_name("input2");
- input2.set_op("Placeholder");
- input2.mutable_attr()->insert({"dtype", dtype_attr});
-
+ *g.add_node() = Input(0);
+ *g.add_node() = Input(1);
// CollectiveReduce on CPU0 with instance_key 1.
- NodeDef collective_call1;
- collective_call1.set_name("collective_call1");
- collective_call1.set_op("CollectiveFunction1");
- collective_call1.add_input("input1");
- collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:0");
+ *g.add_node() = CollectiveCall("CollectiveFunction1", "input0", 0);
// CollectiveReduce on CPU1 with instance_key 1.
- NodeDef collective_call2;
- collective_call2.set_name("collective_call2");
- collective_call2.set_op("CollectiveFunction1");
- collective_call2.add_input("input2");
- collective_call1.set_device("/job:localhost/replica:0/task:0/device:CPU:1");
-
- *g.add_node() = input1;
- *g.add_node() = input2;
- *g.add_node() = collective_call1;
- *g.add_node() = collective_call2;
+ *g.add_node() = CollectiveCall("CollectiveFunction1", "input1", 1);
return g;
}
};
-#ifndef GOOGLE_CUDA
-// TODO(ayushd): enable this test for GPU builds.
TEST_F(DirectSessionCollectiveTest,
TestCollectiveGraphKeyUsesOnlyCalledFunctions) {
int64 key1;
@@ -2333,6 +2356,5 @@ TEST_F(DirectSessionCollectiveTest,
TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key2));
ASSERT_EQ(key1, key2);
}
-#endif
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
index 2ed4f69f90..2c63b8704e 100644
--- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
@@ -108,7 +108,7 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
EXPECT_EQ(2, shape.dim(0).size());
EXPECT_EQ(1, shape.dim(1).size());
if (node->name() == y->name()) {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// if MKL is used, it goes through various additional
// graph rewrite passes. In TF, every time a graph pass
// happens, "constant" nodes are allocated
@@ -117,16 +117,16 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
// which increments the value of AllocationId.
// Thus AllocationId becomes larger than in plain TF if MKL
// is used. Currently, IDs for MKL are 8 more than for TF.
- EXPECT_EQ(29, cm->AllocationId(node, 0));
-#else
EXPECT_EQ(21, cm->AllocationId(node, 0));
-#endif
- } else {
-#ifdef INTEL_MKL
- EXPECT_EQ(30, cm->AllocationId(node, 0));
#else
+ EXPECT_EQ(13, cm->AllocationId(node, 0));
+#endif // INTEL_MKL && ENABLE_MKL
+ } else {
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
EXPECT_EQ(22, cm->AllocationId(node, 0));
-#endif
+#else
+ EXPECT_EQ(14, cm->AllocationId(node, 0));
+#endif // INTEL_MKL && ENABLE_MKL
}
}
EXPECT_LE(0, cm->MaxExecutionTime(node));
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index be5f3bae3a..7b74c67c85 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -147,10 +147,11 @@ tf_cuda_library(
"kernel_and_device.h",
],
visibility = ["//tensorflow:internal"],
- deps = select({
+ deps = [
+ "@farmhash_archive//:farmhash",
+ ] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
- "//util/hash:farmhash_fingerprint",
],
"//conditions:default": [
"//tensorflow/core:core_cpu_lib",
@@ -219,13 +220,13 @@ tf_cuda_library(
visibility = ["//tensorflow:internal"],
deps = [
":kernel_and_device",
+ "@farmhash_archive//:farmhash",
# Only the TF_AttrType enum is required, so pull in just the C headers.
# TODO(b/113535673): Break this dependency and avoid the C header completely.
"//tensorflow/c:c_api_headers",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
- "//util/hash:farmhash_fingerprint",
],
"//conditions:default": [
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc
index cf1cd4134e..5c8369de87 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.cc
+++ b/tensorflow/core/common_runtime/eager/attr_builder.cc
@@ -136,6 +136,22 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
m->insert(*it);
}
}
+ // For any attr-value pairs that exist in the op def (from op registry) but
+ // not `m`, fill them into `m`, so that we can run a TFE_Op without having to
+ // specify all the default attr values (e.g. for matmul, the `transpose_a`
+ // attr defaults to false).
+ const OpDef* op_def = nullptr;
+ Status s = OpDefForOp(op_name_.c_str(), &op_def);
+ // This is expected, if this op is a custom function, and is therefore not
+ // present in the op registry.
+ if (!s.ok()) return;
+
+ DCHECK(op_def);
+ for (const auto& attr_def : op_def->attr()) {
+ if (attr_def.has_default_value() && !m->count(attr_def.name())) {
+ SetInAttrValueMap(m, attr_def.name(), attr_def.default_value());
+ }
+ }
}
const NodeDef& AttrBuilder::BuildNodeDef() {
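The default-filling loop above means a TFE_Op can omit any attr that has a registry default. A sketch of the same OpDef lookup via the public op registry (illustrative only; the code above goes through the OpDefForOp helper instead):

#include "tensorflow/core/framework/op.h"

const OpDef* op_def = nullptr;
Status s = OpRegistry::Global()->LookUpOpDef("MatMul", &op_def);
if (s.ok()) {
  for (const auto& attr_def : op_def->attr()) {
    if (attr_def.has_default_value()) {
      // For MatMul this yields transpose_a = false and transpose_b = false,
      // matching the example in the comment above.
    }
  }
}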
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h
index cbe6a1cb50..c114ea4ba0 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.h
+++ b/tensorflow/core/common_runtime/eager/attr_builder.h
@@ -110,6 +110,12 @@ class AttrBuilder {
using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>;
void MayBeInitializeNodeDef();
+ // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as
+ // well as any default attr-value pairs from the associated op_def, if there
+ // is one.
+ //
+ // If `include_those_in_node_def` is true, also include any attr-value pairs
+ // from `node_def_`.
void FillAttrValueMap(AttrValueMap* m, bool include_those_in_node_def) const;
template <class T>
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 18420b60fd..f23cefb33d 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -70,7 +70,9 @@ EagerContext::EagerContext(const SessionOptions& opts,
async_default_(async),
log_memory_(LogMemory::IsEnabled()),
env_(opts.env),
- use_send_tensor_rpc_(false) {
+ use_send_tensor_rpc_(false),
+ pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
+ "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", true)) {
if (device_mgr_owned) {
local_device_manager_.reset(device_mgr);
local_unowned_device_manager_ = nullptr;
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 5ed6057ec6..15eeaa8066 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -202,6 +202,7 @@ class EagerContext {
// EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
// instead (which in-turn use WorkerService.RecvTensor RPCs).
bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
+ bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }
private:
void InitDeviceMapAndAsync();
@@ -293,6 +294,7 @@ class EagerContext {
#endif
bool use_send_tensor_rpc_;
+ const bool pin_small_ops_to_cpu_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 1bc63616d0..a52f933d75 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -579,19 +579,23 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
return Status::OK();
#endif
}
-} // namespace
-Status EagerExecute(EagerOperation* op,
- gtl::InlinedVector<TensorHandle*, 2>* retvals,
- int* num_retvals) {
- // Ensure all resource-touching ops run in the device the resource is,
- // regardless of anything else that has been specified. This is identical to
- // the graph mode behavior.
+// The Op device may be updated if:
+// - A resource touching input is specified: all resource-touching ops run in
+// the device the resource is, regardless of anything else that has been
+// specified. This is identical to the graph mode behavior.
+//
+// - All op inputs are on the CPU, small (at most 64 elements), and integers
+// (int32/int64). This can be disabled by setting the environment variable
+// "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
+Status MaybeUpdateOpDevice(EagerOperation* op) {
EagerContext* ctx = op->EagerContext();
+ bool device_set_for_resource_variable = false;
+ bool all_inputs_eligible_for_cpu_pinning = ctx->PinSmallOpsToCPU();
+
for (int i = 0; i < op->Inputs().size(); ++i) {
Device* input_op_device = nullptr;
- auto status = op->Inputs()[i]->OpDevice(&input_op_device);
- if (!status.ok()) return status;
+ TF_RETURN_IF_ERROR(op->Inputs()[i]->OpDevice(&input_op_device));
VLOG(2) << "for op " << op->Name() << " input " << i << " "
<< DataTypeString(op->Inputs()[i]->dtype) << " "
<< (input_op_device == nullptr ? "cpu" : input_op_device->name())
@@ -603,8 +607,53 @@ Status EagerExecute(EagerOperation* op,
<< d->name() << " because input #" << i
<< " is a resource in this device.";
op->SetDevice(d);
+
+ device_set_for_resource_variable = true;
+ all_inputs_eligible_for_cpu_pinning = false;
+ } else if (all_inputs_eligible_for_cpu_pinning) {
+ TensorHandle* handle = op->Inputs()[i];
+
+ // Input is on CPU.
+ if (input_op_device != nullptr && input_op_device != ctx->HostCPU()) {
+ all_inputs_eligible_for_cpu_pinning = false;
+ continue;
+ }
+
+ if (handle->dtype != DataType::DT_INT32 &&
+ handle->dtype != DataType::DT_INT64) {
+ all_inputs_eligible_for_cpu_pinning = false;
+ continue;
+ }
+
+ int64 num_elements;
+ TF_RETURN_IF_ERROR(handle->NumElements(&num_elements));
+ if (num_elements > 64) {
+ all_inputs_eligible_for_cpu_pinning = false;
+ }
}
}
+
+ // Ops without inputs are usually ops that generate a tensor in some way and
+ // usually require being present on whatever device they are scheduled on
+ // (e.g. VarHandleOp or _Recv).
+ // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
+ // an op, but there is a GPU kernel?
+ if (!op->Inputs().empty() && all_inputs_eligible_for_cpu_pinning) {
+ VLOG(1) << "Forcing op " << op->Name()
+ << " to be on the CPU since all input tensors have an "
+ "int32/int64 dtype, and are small (at most 64 elements).";
+ op->SetDevice(ctx->HostCPU());
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+Status EagerExecute(EagerOperation* op,
+ gtl::InlinedVector<TensorHandle*, 2>* retvals,
+ int* num_retvals) {
+ TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
+
bool op_is_local = IsLocal(op->EagerContext(), op->Device());
if (op_is_local) {
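The eligibility logic above reduces to a per-input predicate: the input must already live on the host CPU, have an int32 or int64 dtype, and hold at most 64 elements. A standalone sketch of that rule (hypothetical helper, not part of the change):

#include "tensorflow/core/framework/types.h"

bool EligibleForCpuPinning(DataType dtype, int64 num_elements,
                           bool on_host_cpu) {
  return on_host_cpu &&
         (dtype == DT_INT32 || dtype == DT_INT64) &&
         num_elements <= 64;
}

Setting TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING to "0" or "false" in the environment (read once in context.cc above) disables the heuristic for the whole EagerContext.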
diff --git a/tensorflow/core/common_runtime/eval_const_tensor.cc b/tensorflow/core/common_runtime/eval_const_tensor.cc
index c1542f1f57..87749da7af 100644
--- a/tensorflow/core/common_runtime/eval_const_tensor.cc
+++ b/tensorflow/core/common_runtime/eval_const_tensor.cc
@@ -113,6 +113,13 @@ Status TryToInferTensorOutputFromInputShapes(const Edge& edge,
return Status::OK();
}
+// Returns true if 'node' has a registered CPU kernel.
+bool HasCpuKernel(const Node& node) {
+ return FindKernelDef(DeviceType(DEVICE_CPU), node.def(), /*def=*/nullptr,
+ /*kernel_class_name=*/nullptr)
+ .ok();
+}
+
// Extracts the subgraph ending at 'target_node' that is statically computable
// and inserts into 'out_graph'. If statically computable, 'is_constant_graph'
// will be set to true.
@@ -136,6 +143,12 @@ Status ExtractConstantSubgraph(
return Status::OK();
}
+ // Since constant-folding runs on the CPU, do not attempt to constant-fold
+ // operators that have no CPU kernel.
+ if (!HasCpuKernel(target_node)) {
+ return Status::OK();
+ }
+
// TODO(skyewm): should more of the filtering applied in input nodes below be
// applied to target_node here?
@@ -201,6 +214,11 @@ Status ExtractConstantSubgraph(
return Status::OK();
}
+ if (!HasCpuKernel(*current_node)) {
+ *is_constant_graph = false;
+ return Status::OK();
+ }
+
// If there is nothing more to recurse down, see if
// the generator node is a constant.
if (current_node->num_inputs() == 0) {
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index 2c48084cab..40ec1502da 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -54,6 +54,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
@@ -1240,6 +1241,7 @@ class ExecutorState {
StepStatsCollectorInterface* const stats_collector_;
const tracing::TraceCollector* const trace_collector_;
const tracing::EventCollector* const event_collector_;
+ Context context_;
// QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
// instead of a pointer? (avoids having to delete).
@@ -1367,6 +1369,7 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
trace_collector_(tracing::GetTraceCollector()),
event_collector_(
tracing::GetEventCollector(tracing::EventCategory::kCompute)),
+ context_(ContextKind::kThread),
slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
call_frame_(args.call_frame),
impl_(impl),
@@ -1586,6 +1589,7 @@ bool MightTrace(const NodeItem& item,
}
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
+ WithContext wc(context_);
const GraphView& gview = impl_->gview_;
TaggedNodeSeq ready;
TaggedNodeReadyQueue inline_ready;
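The new Context member captures the platform thread context in effect when the ExecutorState is constructed, and WithContext reinstalls it on whichever inter-op thread later runs Process(). The pattern in isolation, as a sketch against the platform Context API used above:

#include "tensorflow/core/platform/context.h"

class Task {
 public:
  Task() : context_(ContextKind::kThread) {}  // capture at construction

  void Run() {
    WithContext wc(context_);  // reinstall the captured context here
    // ... work runs as if on the originating thread's context ...
  }

 private:
  Context context_;
};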
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index 6cd4fd22ea..34bf73972f 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -97,12 +97,6 @@ class Executor {
typedef std::function<void()> Closure;
typedef std::function<void(Closure)> Runner;
Runner runner = nullptr;
-
- // A callback that is invoked each time a node has finished executing.
- typedef std::function<Status(const string& node_name, const int output_slot,
- const Tensor* tensor, const bool is_ref,
- OpKernelContext* ctx)>
- NodeOutputsCallback;
};
typedef std::function<void(const Status&)> DoneCallback;
virtual void RunAsync(const Args& args, DoneCallback done) = 0;
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index 96ecfb41d4..37a979a8f1 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -38,7 +38,8 @@ void GraphOptimizer::Optimize(
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
- const std::function<bool(const Node*)>& cse_consider_fn) {
+ const std::function<bool(const Node*)>& cse_consider_fn,
+ const std::function<bool(const Node*)>& cf_consider_fn) {
Graph* g = graph->get();
DumpGraph("Initial", g);
@@ -62,6 +63,7 @@ void GraphOptimizer::Optimize(
if (opts_.do_constant_folding()) {
ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map;
+ cf_opts.consider = cf_consider_fn;
if (opts_.max_folded_constant_in_bytes() > 0) {
cf_opts.max_constant_size_in_bytes =
opts_.max_folded_constant_in_bytes();
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 80246281cd..789cc56942 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -45,12 +45,15 @@ class GraphOptimizer {
//
// If cse_consider_fn is not null then only nodes for which cse_consider_fn
// returns true will be considered for CSE.
+ // If cf_consider_fn is not null then only nodes for which cf_consider_fn
+ // returns true will be considered for CF.
void Optimize(
FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
- const std::function<bool(const Node*)>& cse_consider_fn = nullptr);
+ const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
+ const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
const OptimizerOptions& options() { return opts_; }
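As with cse_consider_fn, passing nullptr means every node is considered. A call-site sketch that excludes one hypothetical op type from constant folding, assuming `opts`, `runtime`, `env`, `device`, `graph`, and `shape_map` are in scope:

GraphOptimizer optimizer(opts);
optimizer.Optimize(runtime, env, device, &graph, shape_map,
                   /*cse_consider_fn=*/nullptr,
                   /*cf_consider_fn=*/[](const Node* n) {
                     return n->type_string() != "MyOpaqueOp";
                   });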
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index dfce7c23e7..9306386117 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -38,11 +38,12 @@ class CondBuilder {
public:
enum Branch { kElseBranch = 0, kThenBranch = 1 };
- // Create a CondBuilder to create the lowering of If op. that has then and
+ // Create a CondBuilder to create the lowered form of `if_op` with then and
// else functions named `then_fn_name` and `else_fn_name` respectively in the
- // given graph.
+ // `graph`. The functions should be available in `flib`.
CondBuilder(Node* if_op, const string& then_fn_name,
- const string& else_fn_name, Graph* graph);
+ const string& else_fn_name, const FunctionLibraryDefinition& flib,
+ Graph* graph);
// Constructs the basic conditional control flow using switch and merge nodes.
Status CreatePivotNodes();
@@ -89,6 +90,7 @@ class CondBuilder {
Node* then_call_node_;
Node* else_call_node_;
Graph* graph_;
+ const FunctionLibraryDefinition& flib_;
string name_;
NodeBuilder then_call_builder_;
@@ -96,13 +98,17 @@ class CondBuilder {
};
CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
- const string& else_fn_name, Graph* graph)
+ const string& else_fn_name,
+ const FunctionLibraryDefinition& flib, Graph* graph)
: if_op_(if_op),
graph_(graph),
+ flib_(flib),
name_(if_op->name()),
then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()),
else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) {
TF_CHECK_OK(if_op_->input_node(0, &pred_));
+ then_call_builder_.Device(if_op_->requested_device());
+ else_call_builder_.Device(if_op_->requested_device());
}
Status CondBuilder::CreatePivotNodes() {
@@ -113,15 +119,18 @@ Status CondBuilder::CreatePivotNodes() {
NodeBuilder(NewName("switch_pred"), "Switch", graph_->op_registry())
.Input(NodeOut(pred_, 0))
.Input(NodeOut(pred_, 0))
+ .Device(if_op_->requested_device())
.Finalize(graph_, &switch_pred));
control_predecessor_ = switch_pred;
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_f"), "Identity", graph_->op_registry())
.Input(switch_pred, kElseBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_f_));
TF_RETURN_IF_ERROR(
NodeBuilder(NewName("pivot_t"), "Identity", graph_->op_registry())
.Input(switch_pred, kThenBranch)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &pivot_t_));
return Status::OK();
}
@@ -136,6 +145,7 @@ Status CondBuilder::AddInput(Node* src, int src_output) {
NodeBuilder(NewName(src->name()), "Switch", graph_->op_registry())
.Input(src, src_output)
.Input(pred_, 0)
+ .Device(if_op_->requested_device())
.Finalize(graph_, &input));
then_call_builder_.Input(input, kThenBranch);
else_call_builder_.Input(input, kElseBranch);
@@ -174,6 +184,7 @@ Status CondBuilder::AddOutputs() {
TF_RETURN_IF_ERROR(
NodeBuilder(graph_->NewName("merge"), "Merge", graph_->op_registry())
.Input({NodeOut(then_call_node_, i), NodeOut(else_call_node_, i)})
+ .Device(if_op_->requested_device())
.Finalize(graph_, &merges[i]));
outputs_[i] = NodeOut(merges[i], 0);
}
@@ -193,15 +204,15 @@ Status CondBuilder::AddOutputs() {
return Status::OK();
}
-Status InlineCallInGraph(Node* n, Graph* g) {
- const auto& lib = g->flib_def();
- const FunctionDef* fdef = lib.Find(n->type_string());
+Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib,
+ Graph* g) {
+ const FunctionDef* fdef = flib.Find(n->type_string());
CHECK(fdef != nullptr);
FunctionBody* fbody;
TF_RETURN_IF_ERROR(
- FunctionDefToBodyHelper(*fdef, n->attrs(), &lib,
- [&lib](const string& op, const OpDef** sig) {
- return lib.LookUpOpDef(op, sig);
+ FunctionDefToBodyHelper(*fdef, n->attrs(), &flib,
+ [&flib](const string& op, const OpDef** sig) {
+ return flib.LookUpOpDef(op, sig);
},
&fbody));
// TODO(jpienaar): Improve this interface to make the need to delete it
@@ -214,13 +225,13 @@ Status InlineCallInGraph(Node* n, Graph* g) {
Status CondBuilder::BuildLoweredIfOutput() {
// Build the identity node output.
NodeBuilder ib(name_, "IdentityN");
- ib.Input(outputs_);
+ ib.Input(outputs_).Device(if_op_->requested_device());
return ib.Finalize(graph_, &lowered_if_output_);
}
Status CondBuilder::InlineCallNodes() {
- TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, graph_));
- TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, graph_));
+ TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, flib_, graph_));
+ TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, flib_, graph_));
return Status::OK();
}
@@ -240,6 +251,12 @@ Status LowerIfOpPass::Run(const GraphOptimizationPassOptions& options) {
return errors::Internal("Lowering If op requires a graph to be available.");
}
+ FunctionLibraryDefinition* flib = options.flib_def;
+ if (flib == nullptr) {
+ return errors::Internal(
+ "Lowering If op requires a FunctionLibraryDefinition to be available.");
+ }
+
// Match all the nodes that need to be rewritten.
gtl::InlinedVector<Node*, 2> matches;
for (Node* n : g->op_nodes()) {
@@ -251,12 +268,14 @@ Status LowerIfOpPass::Run(const GraphOptimizationPassOptions& options) {
}
}
for (Node* n : matches) {
- TF_RETURN_IF_ERROR(RewriteNode(n, g));
+ TF_RETURN_IF_ERROR(RewriteNode(n, *flib, g));
}
return Status::OK();
}
-Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) {
+Status LowerIfOpPass::RewriteNode(Node* n,
+ const FunctionLibraryDefinition& flib,
+ Graph* g) {
const AttrValue* then_attr = n->attrs().Find("then_branch");
if (then_attr == nullptr) {
return errors::InvalidArgument("Then branch function missing");
@@ -266,7 +285,8 @@ Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) {
return errors::InvalidArgument("Else branch function missing");
}
- CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), g);
+ CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), flib,
+ g);
TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
TF_RETURN_IF_ERROR(cb.AddInputs());
TF_RETURN_IF_ERROR(cb.AddOutputs());
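Stamping if_op_->requested_device() onto every node the builder emits (the pred switch, both pivots, the per-input switches, the merges, the branch call nodes, and the IdentityN output) keeps the whole lowered subgraph colocated with the original If op instead of leaving those nodes unconstrained for the placer.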
diff --git a/tensorflow/core/common_runtime/lower_if_op.h b/tensorflow/core/common_runtime/lower_if_op.h
index a9ef39ae5c..5ab1123e3f 100644
--- a/tensorflow/core/common_runtime/lower_if_op.h
+++ b/tensorflow/core/common_runtime/lower_if_op.h
@@ -29,8 +29,9 @@ class LowerIfOpPass : public GraphOptimizationPass {
Status Run(const GraphOptimizationPassOptions& options) override;
private:
- // Rewrite the given If node `n` in graph `g` to use the switch-merge form.
- Status RewriteNode(Node* n, Graph* g);
+ // Rewrite the given If node `n` in graph `g` to use the switch-merge
+ // form. `flib` should contain the branch functions referenced by `n`.
+ Status RewriteNode(Node* n, const FunctionLibraryDefinition& flib, Graph* g);
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc
index 319a617b32..044a355d06 100644
--- a/tensorflow/core/common_runtime/lower_if_op_test.cc
+++ b/tensorflow/core/common_runtime/lower_if_op_test.cc
@@ -36,9 +36,7 @@ namespace tensorflow {
namespace {
Status Rewrite(std::unique_ptr<Graph>* graph) {
- FunctionDefLibrary flib;
- FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
-
+ FunctionLibraryDefinition flib_def((*graph)->flib_def());
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.flib_def = &flib_def;
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 538a70668a..429b19599b 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -251,6 +251,7 @@ class MklCPUAllocator : public Allocator {
// max_alloc_size from large_size_allocator would be the maximum
// size allocated by MklCPUAllocator.
stats->max_alloc_size = l_stats.max_alloc_size;
+ stats->bytes_limit = std::max(s_stats.bytes_limit, l_stats.bytes_limit);
}
void ClearStats() override {
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
index a67411cd2e..e08ab57638 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/common_runtime/mkl_cpu_allocator.h"
@@ -50,4 +50,4 @@ TEST(MKLBFCAllocatorTest, TestMaxLimit) {
} // namespace tensorflow
-#endif // INTEL_MKL
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/common_runtime/process_util.cc b/tensorflow/core/common_runtime/process_util.cc
index a5d31b75c7..e1dc08d645 100644
--- a/tensorflow/core/common_runtime/process_util.cc
+++ b/tensorflow/core/common_runtime/process_util.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/util.h"
namespace tensorflow {
@@ -56,24 +57,26 @@ int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
const int32 inter_op = options.config.inter_op_parallelism_threads();
if (inter_op != 0) return inter_op;
#ifdef INTEL_MKL
- // MKL library executes ops in parallel using OMP threads
- // Set inter_op conservatively to avoid thread oversubscription that could
- // lead to severe perf degradations and OMP resource exhaustion
- int mkl_intra_op = 1;
+ if (!DisableMKL()) {
+ // MKL library executes ops in parallel using OMP threads
+ // Set inter_op conservatively to avoid thread oversubscription that could
+ // lead to severe perf degradations and OMP resource exhaustion
+ int mkl_intra_op = 1;
#ifdef _OPENMP
- mkl_intra_op = omp_get_max_threads();
+ mkl_intra_op = omp_get_max_threads();
#endif // _OPENMP
- CHECK_GE(mkl_intra_op, 1);
- const int32 mkl_inter_op = std::max(
- (port::NumSchedulableCPUs() + mkl_intra_op - 1) / mkl_intra_op, 2);
- VLOG(0) << "Creating new thread pool with default inter op setting: "
- << mkl_inter_op
- << ". Tune using inter_op_parallelism_threads for best performance.";
- return mkl_inter_op;
-#else
+ DCHECK_GE(mkl_intra_op, 1);
+ const int32 mkl_inter_op = std::max(
+ (port::NumSchedulableCPUs() + mkl_intra_op - 1) / mkl_intra_op, 2);
+ VLOG(0)
+ << "Creating new thread pool with default inter op setting: "
+ << mkl_inter_op
+ << ". Tune using inter_op_parallelism_threads for best performance.";
+ return mkl_inter_op;
+ }
+#endif // INTEL_MKL
// Default to using the number of cores available in the process.
return port::NumSchedulableCPUs();
-#endif // INTEL_MKL
}
thread::ThreadPool* NewThreadPoolFromSessionOptions(
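As a worked example of the ceiling division above: with 56 schedulable CPUs and mkl_intra_op = omp_get_max_threads() = 14, mkl_inter_op = max((56 + 14 - 1) / 14, 2) = max(4, 2) = 4 inter-op threads.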
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index a81f8650bf..b1fe928ba7 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -41,6 +41,16 @@ limitations under the License.
// Set true for greater intelligibility of debug mode log messages.
#define READABLE_KEYS false
+// The RingReduce algorithm exchanges chunks of the tensor between devices.
+// The chunk size depends on the number of subdivisions specified in the
+// algorithm. If the user does not specify the number of subdivisions, we
+// infer the number dynamically so that the resulting chunk size does not exceed
+// kMaxChunkSizeBytes, empirically set at 4 MiB.
+constexpr size_t kMaxChunkSizeBytes = (4 * 1024 * 1024);
+// kMaxSubdivsPerDevice is used to give an upper bound on the number of
+// subdivisions dynamically generated. A reasonable value would be a small
+// multiple of the number of NICs adjacent to each device.
+constexpr int kMaxSubdivsPerDevice = 2;
namespace tensorflow {
namespace {
@@ -92,7 +102,62 @@ RingReducer::RingReducer()
RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); }
+Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
+ if (col_params->instance.shape.num_elements() == 0) {
+ return errors::Internal("shape in CollectiveParams should be non-empty");
+ }
+ const int kAvgDevPerTask =
+ col_params->group.group_size / col_params->group.num_tasks;
+ const int kMaxNumSubdivs = kMaxSubdivsPerDevice * kAvgDevPerTask;
+ if (kMaxNumSubdivs <= 0) {
+ return errors::Internal("Unexpected kMaxNumSubdivs ", kMaxNumSubdivs,
+ " in RingReducer");
+ }
+ // NOTE(ayushd): If no subdiv_offsets have been specified, dynamically add
+ // as many offsets as needed so that the size of tensor chunks <=
+ // kMaxChunkSizeBytes. Empirically, chunks that are too small or too large
+ // lead to worse performance.
+ int num_subdivs = 0;
+ const size_t tensor_size = col_params->instance.shape.num_elements() *
+ DataTypeSize(col_params->instance.data_type);
+ size_t chunk_size;
+ do {
+ ++num_subdivs;
+ int num_chunks = col_params->group.group_size * num_subdivs;
+ chunk_size = tensor_size / num_chunks;
+ VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks
+ << " chunk_size " << chunk_size;
+ } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs);
+ if (num_subdivs <= 0) {
+ return errors::Internal("Unexpected num_subdivs ", num_subdivs,
+ " in RingReducer");
+ }
+
+ int subdiv_stride = kAvgDevPerTask / num_subdivs;
+ if (subdiv_stride == 0) subdiv_stride = 1;
+ col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs);
+ for (int sdi = 0; sdi < num_subdivs; ++sdi) {
+ int subdiv_offset = subdiv_stride * sdi;
+ if (sdi % 2 == 1) subdiv_offset *= -1;
+ col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset);
+ }
+
+ if (VLOG_IS_ON(2)) {
+ string subdiv_buf;
+ for (const int subdiv_offset :
+ col_params->instance.impl_details.subdiv_offsets) {
+ strings::StrAppend(&subdiv_buf, " ", subdiv_offset);
+ }
+ VLOG(2) << "Dynamically generated " << num_subdivs
+ << " subdiv_offsets:" << subdiv_buf << " tensor_size "
+ << tensor_size << " chunk_size " << chunk_size;
+ }
+
+ return Status::OK();
+}
+
Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
+ // TODO(b/113171733): change CHECKs to return errors.
CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE);
CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce");
const string& device_name =
@@ -123,12 +188,11 @@ Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
dev_per_task.push_back(dev_count);
CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
- // Generate a ring permutation for each requested offset.
if (col_params->instance.impl_details.subdiv_offsets.empty()) {
- return errors::Internal(
- "Subdiv offsets should be non-empty for ring reducer, size=",
- col_params->instance.impl_details.subdiv_offsets.size());
+ TF_RETURN_IF_ERROR(GenerateSubdivsInCollectiveParams(col_params));
}
+
+ // Generate a ring permutation for each requested offset.
VLOG(2) << "Setting up perms for col_params " << col_params
<< " subdiv_permutations "
<< &col_params->instance.impl_details.subdiv_permutations;
@@ -646,7 +710,8 @@ bool RingReducer::RunAsyncParts() {
case RF_SEND:
--send_pending_count;
break;
- default: {} // Ignore any other actions
+ default: {
+ } // Ignore any other actions
}
}
}
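To trace the do/while loop with the numbers from the AutomaticSubdivs test below: a 24-device group reducing a 144 MiB float tensor first tries one subdivision (24 chunks of 6 MiB, above the 4 MiB cap), then two subdivisions (48 chunks of 3 MiB), and stops there; with kAvgDevPerTask = 8 the stride is 4, producing subdiv_offsets {0, -4}.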
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index 28df85399e..75aba43572 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -549,37 +549,38 @@ class RingReducerTest : public ::testing::Test {
int32 reduce_counter_ GUARDED_BY(mu_) = 0;
};
-TEST_F(RingReducerTest, InitializeParams) {
- static const int kNumDevsPerTask = 8;
- static const int kNumTasks = 3;
- static const int kNumDevs = kNumDevsPerTask * kNumTasks;
+CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
+ const int num_tasks) {
CollectiveParams cp;
- std::vector<string> device_names;
- std::vector<string> task_names;
+ const int kNumDevs = num_devs_per_task * num_tasks;
cp.group.group_key = 1;
cp.group.group_size = kNumDevs;
cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = kNumTasks;
+ cp.group.num_tasks = num_tasks;
cp.instance.instance_key = 3;
cp.instance.type = REDUCTION_COLLECTIVE;
cp.instance.data_type = DataType(DT_FLOAT);
- cp.instance.shape = TensorShape({5});
+ cp.instance.shape = TensorShape({kNumDevs});
cp.instance.impl_details.collective_name = "RingReduce";
cp.instance.impl_details.subdiv_offsets.push_back(0);
cp.is_source = false;
for (int i = 0; i < kNumDevs; ++i) {
- int task_id = i / kNumDevsPerTask;
- int dev_id = i % kNumDevsPerTask;
+ int task_id = i / num_devs_per_task;
+ int dev_id = i % num_devs_per_task;
string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
- task_names.push_back(task_name);
string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
- device_names.push_back(device_name);
cp.instance.task_names.push_back(task_name);
cp.instance.device_names.push_back(device_name);
}
+ return cp;
+}
- int test_rank = 0;
- cp.default_rank = test_rank;
+TEST_F(RingReducerTest, InitializeParams) {
+ const int kNumDevsPerTask = 8;
+ const int kNumTasks = 3;
+ CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
+
+ cp.default_rank = 0;
cp.instance.impl_details.subdiv_offsets = {0, 4};
RunSubdivPermsTest(&cp,
{{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
@@ -588,8 +589,15 @@ TEST_F(RingReducerTest, InitializeParams) {
8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}},
{0, 4});
- test_rank = 3;
- cp.default_rank = test_rank;
+ cp.instance.impl_details.subdiv_offsets = {0, -4};
+ RunSubdivPermsTest(&cp,
+ {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+ {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8,
+ 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
+ {0, 3});
+
+ cp.default_rank = 3;
cp.instance.impl_details.subdiv_offsets = {3, -3};
RunSubdivPermsTest(&cp,
{{3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14,
@@ -599,6 +607,49 @@ TEST_F(RingReducerTest, InitializeParams) {
{0, 1});
}
+TEST_F(RingReducerTest, AutomaticSubdivs) {
+ const int kNumDevsPerTask = 8;
+ const int kNumTasks = 3;
+ const int kNumDevs = kNumDevsPerTask * kNumTasks;
+ CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
+
+ // Test automatic generation of subdiv offsets.
+ cp.default_rank = 0;
+ cp.instance.impl_details.subdiv_offsets.clear();
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
+ {0});
+
+ // Set shape so that with 2 subdivs chunk_size is 3 MiB. This should cause 2
+ // offsets, {0, -4}, to be generated.
+ {
+ int num_subdivs = 2;
+ int num_chunks = kNumDevs * num_subdivs;
+ size_t chunk_size = 3 * 1048576; // 3 MiB
+ size_t tensor_size = chunk_size * num_chunks;
+ cp.instance.shape =
+ TensorShape({static_cast<int64>(tensor_size / DataTypeSize(DT_FLOAT))});
+ }
+ cp.instance.impl_details.subdiv_offsets.clear();
+ RunSubdivPermsTest(&cp,
+ {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+ {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8,
+ 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
+ {0, 3});
+}
+
+TEST_F(RingReducerTest, AutomaticSubdivUpperBound) {
+ const int kNumDevsPerTask = 1;
+ const int kNumTasks = 4;
+ CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
+
+ cp.default_rank = 0;
+ cp.instance.impl_details.subdiv_offsets.clear();
+ cp.instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)});
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0});
+}
+
// TODO(b/113171733): change to use TEST_P.
#define DEF_TEST(B, T, W, D, S, L, A) \
TEST_F(RingReducerTest, \
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index 0fbc20b34b..6404d8bc6a 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/util.h"
#ifdef INTEL_MKL
#ifdef _OPENMP
@@ -49,6 +50,8 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
allocator_(allocator),
scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {
#ifdef INTEL_MKL
+ // Early return when MKL is disabled
+ if (DisableMKL()) return;
#ifdef _OPENMP
const char* user_omp_threads = getenv("OMP_NUM_THREADS");
if (user_omp_threads == nullptr) {
@@ -113,8 +116,12 @@ class MklCPUAllocatorFactory : public AllocatorFactory {
}
};
-REGISTER_MEM_ALLOCATOR("MklCPUAllocator", 200, MklCPUAllocatorFactory);
+#ifdef ENABLE_MKL
+REGISTER_MEM_ALLOCATOR("MklCPUAllocator", (DisableMKL() ? 50 : 200),
+ MklCPUAllocatorFactory);
+#endif // ENABLE_MKL
+
} // namespace
-#endif
+#endif // INTEL_MKL
} // namespace tensorflow
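The registration priority is what makes the switch work at run time: the default CPU allocator is registered at priority 100, so an ENABLE_MKL build advertises MklCPUAllocator at 200 (winning) when MKL is active, and demotes it to 50 (losing to the default) when DisableMKL() reports that MKL has been turned off via the environment.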
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 20a07d86a2..50403b4004 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -1306,6 +1306,113 @@ Status RandomShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
+namespace {
+
+// This SliceHelper computes the output shape of the `Slice` op
+// when the `sizes` tensor is available.
+template <typename T>
+Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
+ const Tensor* sizes_value,
+ std::vector<DimensionHandle>* dims) {
+ auto sizes_vec = sizes_value->vec<T>();
+ for (int i = 0; i < sizes_value->NumElements(); ++i) {
+ DimensionHandle dim = c->Dim(c->input(0), i);
+ if (sizes_vec(i) != -1) {
+ auto dim_val = c->Value(dim);
+ if (sizes_vec(i) < 0) {
+ return errors::InvalidArgument(
+ "Out of bounds slicing on dimension ", i, " of length ", dim_val,
+ ": sizes vector cannot be < -1, but was ", sizes_vec(i));
+ }
+
+ dims->emplace_back(c->MakeDim(sizes_vec(i)));
+ } else {
+ DimensionHandle result;
+ TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
+ dims->emplace_back(result);
+ }
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+Status SliceShape(InferenceContext* c) {
+ ShapeHandle input = c->input(0);
+ ShapeHandle begin_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
+ ShapeHandle sizes_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
+
+ // Merge to check compatibility of begin and sizes tensors.
+ TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
+
+ DimensionHandle ndims = c->Dim(begin_shape, 0);
+ if (c->ValueKnown(ndims)) {
+ TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
+ }
+
+ // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
+ // values, even though the `begin` value does not represent a shape.
+ ShapeHandle begin_value;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
+
+ // We check the tensor value here and will only use
+ // `MakeShapeFromShapeTensor` when `sizes_value` is null.
+ // The reason is that `sizes` might contain -1, which can't
+ // be represented (-1 in the ShapeHandle would mean "unknown").
+ const Tensor* sizes_value = c->input_tensor(2);
+
+ if (sizes_value != nullptr) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
+ std::vector<DimensionHandle> dims;
+ // If the begin and sizes tensors are available, then
+ // we can be precise about the shape of the output.
+ if (sizes_value->dtype() == DT_INT64) {
+ TF_RETURN_IF_ERROR(
+ SliceHelper<int64>(c, begin_value, sizes_value, &dims));
+ } else {
+ TF_RETURN_IF_ERROR(
+ SliceHelper<int32>(c, begin_value, sizes_value, &dims));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ } else {
+ // In case `sizes` is not available (`sizes_value` is null),
+ // we could try to use `MakeShapeFromShapeTensor` here.
+ // If sizes contain -1, we will simply consider it as `Unknown`.
+ // This is less than ideal but still an improvement to shape inference.
+ // The following is an example that returns [None, 1, None] with this
+ // code path:
+ // z = tf.zeros((1, 2, 3))
+ // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
+ // m.get_shape().as_list()
+ ShapeHandle sizes_value;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
+ if (c->RankKnown(sizes_value)) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
+ std::vector<DimensionHandle> dims;
+ dims.reserve(c->Rank(sizes_value));
+ for (int i = 0; i < c->Rank(sizes_value); ++i) {
+ dims.emplace_back(c->Dim(sizes_value, i));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ }
+ // We might know the rank of the input.
+ if (c->RankKnown(input)) {
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
+ return Status::OK();
+ } else {
+ return shape_inference::UnknownShape(c);
+ }
+ }
+
+ return Status::OK();
+}
+
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
ShapeHandle values_shape, ShapeHandle shape_shape) {
// Validate ranks.
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index e6f9f935f9..3a496e06ae 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -293,6 +293,9 @@ inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
// Shape function for random operations.
Status RandomShape(shape_inference::InferenceContext* c);
+// Shape function for Slice operations.
+Status SliceShape(shape_inference::InferenceContext* c);
+
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
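A sketch of how an op definition would adopt the new shape function (hypothetical op name; the real Slice registration lives with the array ops):

REGISTER_OP("MySlice")
    .Input("input: T")
    .Input("begin: Index")
    .Input("size: Index")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Index: {int32,int64}")
    .SetShapeFn(shape_inference::SliceShape);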
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 697e0604bf..964a7d5f8c 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -278,15 +278,8 @@ class IteratorContext {
// Function call support.
std::function<void(std::function<void()>)> runner = nullptr;
- // A function that returns the current `StatsAggregator` instance to be
- // used when recording statistics about the iterator.
- //
- // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator`
- // is a property of the `IteratorResource` (which this class does not know
- // about), and (ii) it can change after the `IteratorContext` has been
- // created. Better suggestions are welcome!
- std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter =
- nullptr;
+ // The `StatsAggregator` object to record statistics about the iterator.
+ std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
// The FunctionLibraryRuntime object to be used to make function calls.
FunctionLibraryRuntime* lib = nullptr;
@@ -320,13 +313,6 @@ class IteratorContext {
return &params_.runner;
}
- std::shared_ptr<StatsAggregator> stats_aggregator() {
- if (params_.stats_aggregator_getter) {
- return params_.stats_aggregator_getter();
- } else {
- return nullptr;
- }
- }
std::shared_ptr<const FunctionLibraryDefinition> function_library() {
return params_.function_library;
@@ -344,8 +330,8 @@ class IteratorContext {
return params_.allocator_getter;
}
- std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter() {
- return params_.stats_aggregator_getter;
+ std::shared_ptr<StatsAggregator> stats_aggregator() {
+ return params_.stats_aggregator;
}
std::shared_ptr<model::Model> model() { return params_.model; }
@@ -657,15 +643,15 @@ class DatasetBaseIterator : public IteratorBase {
// When performance modeling is enabled, this method adds a tunable parameter
// to the model node corresponding to this iterator.
//
- // The performance modeling logic may use `value` to set the value of the
+ // The performance modeling logic may use `state` to set the value of the
// tunable parameter at any point during the lifetime of this iterator. When
- // it does, it notifies `cond_var`.
+ // it does, it acquires `state->mu` and notifies `state->cond_var`.
void AddTunableParameter(IteratorContext* ctx, const string& name,
- std::atomic<int64>* value, int64 min, int64 max,
- condition_variable* cond_var) {
+ std::shared_ptr<model::SharedState> state, int64 min,
+ int64 max) {
if (ctx->model()) {
- ctx->model()->AddTunableParameter(prefix(), name, value, min, max,
- cond_var);
+ ctx->model()->AddTunableParameter(prefix(), name, std::move(state), min,
+ max);
}
}
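
For context, an iterator that owns a tunable parallelism value would now wrap it in a model::SharedState before registering it. A minimal sketch with hypothetical names and bounds, not taken from this diff:

    // Inside an iterator, assuming the new API above.
    auto mu = std::make_shared<mutex>();
    auto cond_var = std::make_shared<condition_variable>();
    auto state = std::make_shared<model::SharedState>(/*value=*/1, mu, cond_var);
    // The model may later update state->value under *mu and notify *cond_var.
    AddTunableParameter(ctx, "parallelism", state, /*min=*/1, /*max=*/16);
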
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index a17959a448..20f957190b 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1101,6 +1101,14 @@ Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
return Status::OK();
}
+Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
+ mutex_lock l(mu_);
+ bool added;
+ TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
+ TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
+ return Status::OK();
+}
+
Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
const auto& i = function_defs_.find(func);
if (i == function_defs_.end()) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index e01eb7503d..4d6d68e214 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -331,6 +331,11 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
// a non-OK status if "func" was not found in the library, OK otherwise.
Status ReplaceFunction(const string& func, const FunctionDef& fdef);
+ // Replaces the gradient corresponding to `grad.function_name()`. Returns
+ // a non-OK status if "grad.function_name()" was not found in the library, OK
+ // otherwise.
+ Status ReplaceGradient(const GradientDef& grad);
+
// Adds the functions and gradients in 'other' to this function library.
// Duplicate functions and gradients are ignored.
// This operation is atomic.
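
A minimal usage sketch for the new method, assuming `lib_def` is a FunctionLibraryDefinition that already registered a gradient for a function named "Foo" (names hypothetical):

    GradientDef grad;
    grad.set_function_name("Foo");
    grad.set_gradient_func("FooGradV2");
    TF_RETURN_IF_ERROR(lib_def.ReplaceGradient(grad));
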
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index d5c203d276..0445c242e9 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -93,7 +93,6 @@ FunctionDef IsZero() {
FunctionDef RandomUniform() {
const Tensor kZero = test::AsScalar<int64>(0);
- const Tensor kTen = test::AsScalar<int64>(10);
return FDH::Define(
// Name
@@ -108,19 +107,11 @@ FunctionDef RandomUniform() {
"Const",
{},
{{"value", kZero}, {"dtype", DT_INT64}}},
- {{"random_uniform/min"},
- "Const",
- {},
- {{"value", kZero}, {"dtype", DT_INT64}}},
- {{"random_uniform/max"},
- "Const",
- {},
- {{"value", kTen}, {"dtype", DT_INT64}}},
{{"random_uniform"},
- "RandomUniformInt",
- {},
- {{"T", DT_INT64},
- {"Tout", DT_INT64},
+ "RandomUniform",
+ {"random_uniform/shape"},
+ {{"T", DT_INT32},
+ {"Tout", DT_FLOAT},
{"seed", 87654321},
{"seed2", 42}}}});
}
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index b0330ec990..bfdb3a6658 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -296,12 +296,12 @@ void Model::AddProcessingTime(const string& name, int64 delta) {
void Model::AddTunableParameter(const string& node_name,
const string& parameter_name,
- std::atomic<int64>* value, int64 min, int64 max,
- condition_variable* cond_var) {
+ std::shared_ptr<SharedState> state, int64 min,
+ int64 max) {
tf_shared_lock l(mu_);
auto node = *gtl::FindOrNull(lookup_table_, node_name);
DCHECK(node);
- node->add_tunable_param(parameter_name, value, min, max, cond_var);
+ node->add_tunable_param(parameter_name, std::move(state), min, max);
}
// The optimization algorithm starts by setting all tunable parallelism
@@ -311,54 +311,55 @@ void Model::AddTunableParameter(const string& node_name,
// is less than or equal to the processing time needed to produce an element
// divided by CPU budget.
void Model::Optimize(int64 cpu_budget) {
- tf_shared_lock lock(mu_);
std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
- const int64 processing_time = ProcessingTime();
- tunables = CollectTunables();
- for (auto tunable : tunables) {
- tunable->value = 1;
- }
- while (true) {
- const int64 output_time = OutputTime();
- bool all_tunables = true;
- for (auto& tunable : tunables) {
- if (tunable->value < tunable->max) {
- all_tunables = false;
+ {
+ tf_shared_lock lock(mu_);
+ const int64 processing_time = ProcessingTime();
+ tunables = CollectTunables();
+ for (auto tunable : tunables) {
+ tunable->value = 1;
+ }
+ while (true) {
+ const int64 output_time = OutputTime();
+ bool all_tunables = true;
+ for (auto& tunable : tunables) {
+ if (tunable->value < tunable->max) {
+ all_tunables = false;
+ break;
+ }
+ }
+ if (output_time < processing_time / cpu_budget || all_tunables) {
break;
}
- }
- if (output_time < processing_time / cpu_budget || all_tunables) {
- break;
- }
- int64 best_delta = -1;
- Model::Node::Tunable* best_tunable = nullptr;
- for (auto& tunable : tunables) {
- if (tunable->value == tunable->max) {
- continue;
+ int64 best_delta = -1;
+ Model::Node::Tunable* best_tunable = nullptr;
+ for (auto& tunable : tunables) {
+ if (tunable->value == tunable->max) {
+ continue;
+ }
+ tunable->value++;
+ int64 delta = output_time - OutputTime();
+ if (delta > best_delta) {
+ best_delta = delta;
+ best_tunable = tunable.get();
+ }
+ tunable->value--;
}
- tunable->value++;
- int64 delta = output_time - OutputTime();
- if (delta > best_delta) {
- best_delta = delta;
- best_tunable = tunable.get();
+ if (!best_tunable) {
+ // NOTE: This can happen because we are performing the optimization
+ // while the model data is changing. If this becomes an issue, we should
+ // look into performing the optimization using a model snapshot.
+ break;
}
- tunable->value--;
+ best_tunable->value++;
}
- if (!best_tunable) {
- // NOTE: This can happen because we are performing the optimization
- // while the model data is changing. If this becomes an issue, we should
- // look into performing the optimization using a model snapshot.
- break;
- }
- best_tunable->value++;
}
VLOG(2) << "Number of knobs: " << tunables.size();
for (auto& tunable : tunables) {
VLOG(2) << "Setting tunable parameter: " << tunable->value;
- tunable->value_ptr->store(tunable->value);
- if (tunable->cond_var) {
- tunable->cond_var->notify_all();
- }
+ mutex_lock l(*tunable->state->mu);
+ tunable->state->value = tunable->value;
+ tunable->state->cond_var->notify_all();
}
}
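
As a worked illustration of the stopping criterion above (numbers invented): if the modeled processing time is 8000ns and `cpu_budget` is 4, the hill-climbing loop keeps incrementing whichever tunable yields the largest drop in OutputTime(), and stops once the modeled output time falls below 8000 / 4 = 2000ns or every tunable has reached its max.
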
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index 26402f5cd3..eae0fa70e8 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -33,6 +33,19 @@ namespace tensorflow {
namespace data {
namespace model {
+// Represents thread-safe state that can be shared between an input pipeline and
+// the performance model.
+struct SharedState {
+ public:
+ explicit SharedState(int64 value, std::shared_ptr<mutex> mu,
+ std::shared_ptr<condition_variable> cond_var)
+ : value(value), mu(std::move(mu)), cond_var(std::move(cond_var)) {}
+
+ std::shared_ptr<mutex> mu;
+ std::shared_ptr<condition_variable> cond_var;
+ int64 value;
+};
+
// Abstract representation of a TensorFlow input pipeline that can be used
// for collecting runtime information and optimizing performance. It collects
// runtime information about execution of the input pipeline that is used to
@@ -62,8 +75,8 @@ class Model {
// Adds a tunable parameter for the given node.
void AddTunableParameter(const string& node_name,
const string& parameter_name,
- std::atomic<int64>* value, int64 min, int64 max,
- condition_variable* cond_var) LOCKS_EXCLUDED(mu_);
+ std::shared_ptr<SharedState> value, int64 min,
+ int64 max) LOCKS_EXCLUDED(mu_);
// Runs optimization.
void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);
@@ -109,13 +122,8 @@ class Model {
public:
// Represents a tunable parameter.
struct Tunable {
- Tunable(std::atomic<int64>* value, int64 min, int64 max,
- condition_variable* cond_var)
- : value(*value),
- min(min),
- max(max),
- value_ptr(value),
- cond_var(cond_var) {}
+ Tunable(std::shared_ptr<SharedState> state, int64 min, int64 max)
+ : value(state->value), min(min), max(max), state(std::move(state)) {}
// Identifies the model value of the parameter. This can be different from
// the actual value (e.g. during optimization search).
@@ -127,12 +135,8 @@ class Model {
// Identifies the maximum value of the parameter.
int64 max;
- // Points to the actual value of the parameter. Not owned.
- std::atomic<int64>* value_ptr;
-
- // If non-null, this condition variable is notified when the model updates
- // the actual value of the parameter (via `value_ptr`). Not owned.
- condition_variable* cond_var;
+ // Shared state of the parameter.
+ std::shared_ptr<SharedState> state;
};
Node(int64 id, const string& name, std::shared_ptr<Node> output)
@@ -158,12 +162,12 @@ class Model {
}
// Adds a tunable parameter.
- void add_tunable_param(const string& name, std::atomic<int64>* value,
- int64 min, int64 max, condition_variable* cond_var)
- LOCKS_EXCLUDED(mu_) {
+ void add_tunable_param(const string& name,
+ std::shared_ptr<SharedState> state, int64 min,
+ int64 max) LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
tunable_params_[name] =
- std::make_shared<Tunable>(value, min, max, cond_var);
+ std::make_shared<Tunable>(std::move(state), min, max);
}
// Returns the unique node ID.
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 187bfa2c88..0ff67554eb 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
#include <string>
-#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index 25f8de8dcc..81ed5f95f0 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -209,16 +209,16 @@ template <>
class OpDefBuilderWrapper<true> {
public:
OpDefBuilderWrapper(const char name[]) : builder_(name) {}
- OpDefBuilderWrapper<true>& Attr(StringPiece spec) {
- builder_.Attr(spec);
+ OpDefBuilderWrapper<true>& Attr(string spec) {
+ builder_.Attr(std::move(spec));
return *this;
}
- OpDefBuilderWrapper<true>& Input(StringPiece spec) {
- builder_.Input(spec);
+ OpDefBuilderWrapper<true>& Input(string spec) {
+ builder_.Input(std::move(spec));
return *this;
}
- OpDefBuilderWrapper<true>& Output(StringPiece spec) {
- builder_.Output(spec);
+ OpDefBuilderWrapper<true>& Output(string spec) {
+ builder_.Output(std::move(spec));
return *this;
}
OpDefBuilderWrapper<true>& SetIsCommutative() {
@@ -237,12 +237,12 @@ class OpDefBuilderWrapper<true> {
builder_.SetAllowsUninitializedInput();
return *this;
}
- OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) {
- builder_.Deprecated(version, explanation);
+ OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) {
+ builder_.Deprecated(version, std::move(explanation));
return *this;
}
- OpDefBuilderWrapper<true>& Doc(StringPiece text) {
- builder_.Doc(text);
+ OpDefBuilderWrapper<true>& Doc(string text) {
+ builder_.Doc(std::move(text));
return *this;
}
OpDefBuilderWrapper<true>& SetShapeFn(
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 34a7a43d38..8a9bb63182 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -526,32 +526,32 @@ void FinalizeDoc(const string& text, OpDef* op_def,
} // namespace
-OpDefBuilder::OpDefBuilder(StringPiece op_name) {
- op_def()->set_name(string(op_name)); // NOLINT
+OpDefBuilder::OpDefBuilder(string op_name) {
+ op_def()->set_name(std::move(op_name));
}
-OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) {
- attrs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Attr(string spec) {
+ attrs_.push_back(std::move(spec));
return *this;
}
-OpDefBuilder& OpDefBuilder::Input(StringPiece spec) {
- inputs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Input(string spec) {
+ inputs_.push_back(std::move(spec));
return *this;
}
-OpDefBuilder& OpDefBuilder::Output(StringPiece spec) {
- outputs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Output(string spec) {
+ outputs_.push_back(std::move(spec));
return *this;
}
#ifndef TF_LEAN_BINARY
-OpDefBuilder& OpDefBuilder::Doc(StringPiece text) {
+OpDefBuilder& OpDefBuilder::Doc(string text) {
if (!doc_.empty()) {
errors_.push_back(
strings::StrCat("Extra call to Doc() for Op ", op_def()->name()));
} else {
- doc_.assign(text.data(), text.size());
+ doc_ = std::move(text);
}
return *this;
}
@@ -577,14 +577,14 @@ OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
return *this;
}
-OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) {
+OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) {
if (op_def()->has_deprecation()) {
errors_.push_back(
strings::StrCat("Deprecated called twice for Op ", op_def()->name()));
} else {
OpDeprecation* deprecation = op_def()->mutable_deprecation();
deprecation->set_version(version);
- deprecation->set_explanation(string(explanation));
+ deprecation->set_explanation(std::move(explanation));
}
return *this;
}
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h
index 0b39d6e848..8077b20598 100644
--- a/tensorflow/core/framework/op_def_builder.h
+++ b/tensorflow/core/framework/op_def_builder.h
@@ -51,7 +51,7 @@ struct OpRegistrationData {
class OpDefBuilder {
public:
// Constructs an OpDef with just the name field set.
- explicit OpDefBuilder(StringPiece op_name);
+ explicit OpDefBuilder(string op_name);
// Adds an attr to this OpDefBuilder (and returns *this). The spec has
// format "<name>:<type>" or "<name>:<type>=<default>"
@@ -84,7 +84,7 @@ class OpDefBuilder {
// * Ability to restrict the type of the tensor like the existing
// restrictions for type attrs.
// Perhaps by linking the type of the tensor to a type attr?
- OpDefBuilder& Attr(StringPiece spec);
+ OpDefBuilder& Attr(string spec);
// Adds an input or output to this OpDefBuilder (and returns *this).
// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
@@ -101,8 +101,8 @@ class OpDefBuilder {
// in the spec?
// TODO(josh11b): SparseInput() and SparseOutput() matching the Python
// handling?
- OpDefBuilder& Input(StringPiece spec);
- OpDefBuilder& Output(StringPiece spec);
+ OpDefBuilder& Input(string spec);
+ OpDefBuilder& Output(string spec);
// Turns on the indicated boolean flag in this OpDefBuilder (and
// returns *this).
@@ -112,7 +112,7 @@ class OpDefBuilder {
OpDefBuilder& SetAllowsUninitializedInput();
// Deprecate the op at a certain GraphDef version.
- OpDefBuilder& Deprecated(int version, StringPiece explanation);
+ OpDefBuilder& Deprecated(int version, string explanation);
// Adds docs to this OpDefBuilder (and returns *this).
// Docs have the format:
@@ -128,9 +128,9 @@ class OpDefBuilder {
// to suppress the automatically-generated type documentation in
// generated output.
#ifndef TF_LEAN_BINARY
- OpDefBuilder& Doc(StringPiece text);
+ OpDefBuilder& Doc(string text);
#else
- OpDefBuilder& Doc(StringPiece text) { return *this; }
+ OpDefBuilder& Doc(string text) { return *this; }
#endif
// Sets the shape function to be used for shape inference.
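
Taking the specs by value lets call sites hand over composed strings without an extra copy; a minimal sketch, assuming strings::StrCat and a hypothetical op name:

    string spec = strings::StrCat("x: ", DataTypeString(DT_FLOAT));
    OpDefBuilder("MyOp").Input(std::move(spec));
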
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index ebdaaec153..508a8d3149 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -288,4 +288,13 @@ Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
return ctx->resource_manager()->Delete(p);
}
+Status ResourceHandlesShape(shape_inference::InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ for (int i = 0; i < n; ++i) {
+ c->set_output(i, c->Scalar());
+ }
+ return Status::OK();
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index d58deaa3fc..4a531648d9 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
+#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>
@@ -127,6 +128,14 @@ class ResourceMgr {
Status Lookup(const string& container, const string& name,
T** resource) const TF_MUST_USE_RESULT;
+ // Similar to Lookup, but looks up multiple resources at once, with only a
+ // single lock acquisition.
+ template <typename T>
+ Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
+ containers_and_names,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
+ resource) const TF_MUST_USE_RESULT;
+
// If "container" has a resource "name", returns it in
// "*resource". Otherwise, invokes creator() to create the resource.
// The caller takes the ownership of one ref on "*resource".
@@ -239,14 +248,31 @@ Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
ResourceHandle* handle);
// Create a resource pointed by a given resource handle.
+//
+// If successful, the caller transfers the ownership of one ref on `resource` to
+// `ctx->resource_mgr()`.
template <typename T>
Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
// Looks up a resource pointed by a given resource handle.
+//
+// If the lookup is successful, the caller takes the ownership of one ref on
+// `*value`, and must call its `Unref()` method when it has finished using it.
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
+// Looks up multiple resources pointed by a sequence of resource handles.
+template <typename T>
+Status LookupResources(
+    OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values);
+
// Looks up or creates a resource.
+//
+// If successful, the caller takes the ownership of one ref on `*value`, and
+// must call its `Unref()` method when it has finished using it. If the
+// `creator` is invoked, its reference on the created resource is transferred
+// to `ctx->resource_mgr()`.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator);
@@ -358,6 +384,26 @@ class ResourceHandleOp : public OpKernel {
std::atomic<bool> initialized_{false};
};
+// Utility op kernel to produce a handle to a resource of type T.
+template <typename T>
+class ResourceHandlesOp : public OpKernel {
+ public:
+ explicit ResourceHandlesOp(OpKernelConstruction* context);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ bool IsExpensive() override { return false; }
+
+ private:
+ std::vector<string> containers_;
+ std::vector<string> names_;
+ mutex mutex_;
+ std::vector<Tensor> resources_;
+ std::atomic<bool> initialized_{false};
+};
+
+Status ResourceHandlesShape(shape_inference::InferenceContext* c);
+
// Registers a kernel for an op which produces a handle to a resource of the
// specified type.
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type) \
@@ -390,6 +436,24 @@ Status ResourceMgr::Lookup(const string& container, const string& name,
}
template <typename T>
+Status ResourceMgr::LookupMany(
+ absl::Span<std::pair<const string*, const string*> const>
+ containers_and_names,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const {
+ CheckDeriveFromResourceBase<T>();
+ tf_shared_lock l(mu_);
+ resources->resize(containers_and_names.size());
+ for (size_t i = 0; i < containers_and_names.size(); ++i) {
+ T* resource;
+ TF_RETURN_IF_ERROR(LookupInternal(*containers_and_names[i].first,
+ *containers_and_names[i].second,
+ &resource));
+ (*resources)[i].reset(resource);
+ }
+ return Status::OK();
+}
+
+template <typename T>
Status ResourceMgr::LookupInternal(const string& container, const string& name,
T** resource) const {
ResourceBase* found = nullptr;
@@ -499,6 +563,19 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
}
template <typename T>
+Status LookupResources(
+ OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values) {
+ std::vector<std::pair<const string*, const string*>> containers_and_names(
+ p.size());
+ for (size_t i = 0; i < p.size(); ++i) {
+ TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i]));
+ containers_and_names[i] = {&p[i]->container(), &p[i]->name()};
+ }
+ return ctx->resource_manager()->LookupMany(containers_and_names, values);
+}
+
+template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
@@ -555,6 +632,46 @@ void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
ctx->set_output(0, resource_);
}
+template <typename T>
+ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ int n;
+ OP_REQUIRES_OK(context, context->GetAttr("N", &n));
+ OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_));
+ OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_));
+ OP_REQUIRES(
+ context, containers_.size() == n,
+ errors::InvalidArgument("Number of containers (", containers_.size(),
+ ") must be equal to N (", n, ")"));
+ OP_REQUIRES(context, names_.size() == n,
+ errors::InvalidArgument("Number of names (", containers_.size(),
+ ") must be equal to N (", n, ")"));
+ resources_.resize(n);
+}
+
+template <typename T>
+void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
+ if (!initialized_.load()) {
+ mutex_lock ml(mutex_);
+ // Checking again to see if another thread has initialized the resource.
+ if (!initialized_.load()) {
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ for (size_t i = 0; i < resources_.size(); ++i) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
+ &resources_[i], attr));
+ ResourceHandle h =
+ MakeResourceHandle<T>(ctx, containers_[i], names_[i]);
+ resources_[i].template scalar<ResourceHandle>()() = h;
+ }
+ initialized_.store(true);
+ }
+ }
+ for (size_t i = 0; i < resources_.size(); ++i) {
+ ctx->set_output(i, resources_[i]);
+ }
+}
+
} // end namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
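
A minimal usage sketch for the batched lookup, assuming MyResource derives from ResourceBase and `handle_a`/`handle_b` came from the kernel's inputs (all names hypothetical):

    const ResourceHandle* handles[] = {&handle_a, &handle_b};
    std::vector<std::unique_ptr<MyResource, core::RefCountDeleter>> resources;
    TF_RETURN_IF_ERROR(LookupResources<MyResource>(ctx, handles, &resources));
    // Each unique_ptr unrefs its resource when it goes out of scope.
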
diff --git a/tensorflow/core/framework/run_handler.cc b/tensorflow/core/framework/run_handler.cc
new file mode 100644
index 0000000000..0c4007eafc
--- /dev/null
+++ b/tensorflow/core/framework/run_handler.cc
@@ -0,0 +1,249 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/run_handler.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/run_handler_util.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+
+// Contains the concrete implementation of the RunHandler.
+// Externally visible RunHandler class simply forwards the work to this one.
+class RunHandler::Impl {
+ public:
+ explicit Impl(RunHandlerPool::Impl* pool_impl) : pool_impl_(pool_impl) {
+ Reset();
+ }
+
+ ~Impl() {}
+
+ void set_inter_op_scheduling_range(std::uint_fast32_t start,
+ std::uint_fast32_t limit) {
+ inter_op_scheduling_range_.store(EncodePartition(start, limit),
+ std::memory_order_release);
+ }
+
+ std::uint_fast32_t inter_op_scheduling_range() const {
+ return inter_op_scheduling_range_.load(std::memory_order_acquire);
+ }
+
+  // Returns the time (in microseconds since the Unix epoch) at which the
+  // handler was requested via RunHandlerPool::Get().
+ uint64 start_time_us() const { return start_time_us_; }
+
+ void ScheduleInterOpClosure(std::function<void()> fn);
+
+ void Reset();
+
+ RunHandlerPool::Impl* pool_impl() { return pool_impl_; }
+
+ private:
+  // Encoding/decoding logic for storing [start, limit) in a single
+  // uint_fast32_t. We assume that pool_num_threads < (1 << 16).
+ const int kMaxPartitionBits = 16;
+ const int kMaxThreads = 1 << kMaxPartitionBits;
+
+ std::uint_fast32_t EncodePartition(std::uint_fast32_t start,
+ std::uint_fast32_t limit) {
+ return (start << kMaxPartitionBits) | limit;
+ }
+
+ void DecodePartition(std::uint_fast32_t val, std::uint_fast32_t* start,
+ std::uint_fast32_t* limit) {
+ *limit = val & (kMaxThreads - 1);
+ val >>= kMaxPartitionBits;
+ *start = val;
+ }
+
+ std::atomic_uint_fast32_t inter_op_scheduling_range_;
+ RunHandlerPool::Impl* pool_impl_; // NOT OWNED.
+ uint64 start_time_us_;
+};
+
+// Contains shared state across all run handlers present in the pool. Also
+// responsible for pool management decisions.
+// This class is thread safe.
+class RunHandlerPool::Impl {
+ public:
+ explicit Impl(int num_inter_op_threads)
+ : max_handlers_(128),
+ inter_op_thread_pool_(new thread::ThreadPool(
+ Env::Default(), ThreadOptions(), "inter_op", num_inter_op_threads)),
+ iterations_(0) {
+ VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
+ for (int i = 0; i < max_handlers_; ++i) {
+ handlers_.emplace_back(new RunHandler::Impl(this));
+ free_handlers_.push_back(handlers_.back().get());
+ }
+ }
+
+ ~Impl() {
+    // Sanity check that all handlers have been returned to the pool before
+ // destruction.
+ DCHECK_EQ(handlers_.size(), max_handlers_);
+ DCHECK_EQ(free_handlers_.size(), handlers_.size());
+ DCHECK_EQ(sorted_active_handlers_.size(), 0);
+ }
+
+ thread::ThreadPool* inter_op_thread_pool() const {
+ return inter_op_thread_pool_.get();
+ }
+
+ std::unique_ptr<RunHandler> Get() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ while (free_handlers_.empty()) {
+ one_handler_free_.wait(l);
+ }
+ // Remove the last entry from free_handlers_ and add to the end of
+ // sorted_active_handlers_.
+ auto* handler_impl = free_handlers_.back();
+ handler_impl->Reset();
+ // Sortedness isn't violated if we simply add at the end of the list, since
+ // handlers are expected to be obtained in increasing order of time.
+ sorted_active_handlers_.push_back(handler_impl);
+ DCHECK_LE(sorted_active_handlers_.size(), max_handlers_);
+ free_handlers_.pop_back();
+
+ RecomputePoolStatsLocked();
+ return WrapUnique<RunHandler>(new RunHandler(handler_impl));
+ }
+
+ void ReleaseHandler(RunHandler::Impl* handler) LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ DCHECK_GT(sorted_active_handlers_.size(), 0);
+
+ uint64 now = tensorflow::Env::Default()->NowMicros();
+ double elapsed = (now - handler->start_time_us()) / 1000.0;
+ time_hist_.Add(elapsed);
+
+ // Erase from and update sorted_active_handlers_. Add it to the end of
+ // free_handlers_.
+ auto iter = std::find(sorted_active_handlers_.begin(),
+ sorted_active_handlers_.end(), handler);
+ DCHECK(iter != sorted_active_handlers_.end())
+ << "Unexpected handler: " << handler
+ << " is being requested for release";
+
+ // Remove this handler from this list and add it to the list of free
+ // handlers.
+ sorted_active_handlers_.erase(iter);
+ free_handlers_.push_back(handler);
+ DCHECK_LE(free_handlers_.size(), max_handlers_);
+
+ RecomputePoolStatsLocked();
+ }
+ one_handler_free_.notify_one();
+ }
+
+ private:
+ void RecomputePoolStatsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Maximum number of handlers pre-created during pool construction time. The
+  // number was chosen expecting each handler to want at least one
+  // inter-op thread for execution (during compute-intensive workloads like
+ // inference).
+ const int max_handlers_;
+
+ // Thread safe part.
+ const std::unique_ptr<thread::ThreadPool> inter_op_thread_pool_;
+
+ // Thread compatible part used only by lock under RunHandlerPool.
+ // Handlers are sorted by start time.
+ std::vector<RunHandler::Impl*> sorted_active_handlers_ GUARDED_BY(mu_);
+ std::vector<RunHandler::Impl*> free_handlers_ GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ GUARDED_BY(mu_);
+ // Histogram of elapsed runtime of every handler (in ms).
+ histogram::Histogram time_hist_ GUARDED_BY(mu_);
+ std::vector<std::uint_fast32_t> inter_op_start_ GUARDED_BY(mu_);
+ std::vector<std::uint_fast32_t> inter_op_limit_ GUARDED_BY(mu_);
+ int64 iterations_ GUARDED_BY(mu_);
+ condition_variable one_handler_free_;
+ mutex mu_;
+};
+
+void RunHandlerPool::Impl::RecomputePoolStatsLocked() {
+ int num_active_requests = sorted_active_handlers_.size();
+ if (num_active_requests == 0) return;
+
+ int num_threads = inter_op_thread_pool_->NumThreads();
+
+ inter_op_start_.resize(num_active_requests);
+ inter_op_limit_.resize(num_active_requests);
+
+ const int kMinThreadsPerRequest = 3;
+ ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
+ kMinThreadsPerRequest, &inter_op_start_,
+ &inter_op_limit_);
+
+ for (int i = 0; i < num_active_requests; ++i) {
+ sorted_active_handlers_[i]->set_inter_op_scheduling_range(
+ inter_op_start_[i], inter_op_limit_[i]);
+ }
+
+ if (iterations_++ % 5000 == 0 && VLOG_IS_ON(1)) {
+ VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
+ VLOG(1) << "Active session runs: " << num_active_requests;
+ uint64 now = tensorflow::Env::Default()->NowMicros();
+ string ranges_str = "";
+ string times_str = "";
+ for (int i = 0; i < num_active_requests; ++i) {
+ if (i > 0) {
+ times_str += " ";
+ ranges_str += " ";
+ }
+
+ times_str += strings::StrCat(
+ (now - sorted_active_handlers_[i]->start_time_us()) / 1000.0, " ms.");
+ ranges_str += strings::StrCat("[", inter_op_start_[i], ", ",
+ inter_op_limit_[i], ")");
+ }
+ VLOG(1) << "Elapsed times are: " << times_str;
+ VLOG(1) << "Ranges are: " << ranges_str;
+ }
+}
+
+void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
+ std::uint_fast32_t start = 0, limit = 0;
+ DecodePartition(inter_op_scheduling_range(), &start, &limit);
+ pool_impl_->inter_op_thread_pool()->Schedule(std::move(fn));
+}
+
+void RunHandler::Impl::Reset() {
+ set_inter_op_scheduling_range(
+ 0, pool_impl_->inter_op_thread_pool()->NumThreads());
+ start_time_us_ = tensorflow::Env::Default()->NowMicros();
+}
+
+RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
+ : impl_(new Impl(num_inter_op_threads)) {}
+
+RunHandlerPool::~RunHandlerPool() {}
+
+std::unique_ptr<RunHandler> RunHandlerPool::Get() { return impl_->Get(); }
+
+RunHandler::RunHandler(Impl* impl) : impl_(impl) {}
+
+void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
+ impl_->ScheduleInterOpClosure(std::move(fn));
+}
+
+RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/run_handler.h b/tensorflow/core/framework/run_handler.h
new file mode 100644
index 0000000000..72fa6301b4
--- /dev/null
+++ b/tensorflow/core/framework/run_handler.h
@@ -0,0 +1,95 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
+
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+
+class RunHandler;
+
+// RunHandlerPool is a fixed size pool of pre-allocated RunHandlers
+// that can be used for tracking inter-op work for a given Session::Run().
+// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes
+// 'active' when its unique_ptr is returned by Get() and is being used by a
+// client. It becomes 'inactive' once more when its unique_ptr gets destroyed.
+//
+// Expected usage:
+//
+// * Create a single RunHandlerPool (say run_handler_pool_).
+//
+// * When a Session::Run() is invoked, obtain a handler by:
+// auto handler = run_handler_pool_->Get();
+//
+// * Use handler for scheduling all inter-op work by:
+// handler->ScheduleInterOpClosure(closure);
+//
+// This class is thread safe.
+class RunHandlerPool {
+ public:
+ explicit RunHandlerPool(int num_inter_op_threads);
+ ~RunHandlerPool();
+
+ // Returns an inactive RunHandler from the pool.
+ //
+ // RunHandlers in RunHandlerPool are initially 'inactive'.
+  // A RunHandler becomes 'active' when its unique_ptr is returned by Get()
+ // and is being used by a client. It becomes 'inactive' once more when the
+ // unique_ptr is destroyed.
+ //
+ // Will block unless there is an inactive handler.
+ std::unique_ptr<RunHandler> Get();
+
+ private:
+ class Impl;
+ friend class RunHandler;
+
+ std::unique_ptr<Impl> impl_;
+};
+
+// RunHandler can be used to schedule inter-op closures to run on a global pool
+// shared across all Session::Run(s).
+//
+// It can only be created via RunHandlerPool::Get().
+//
+// This class can be used instead of directly scheduling closures on a global
+// pool since it maintains a global view across all sessions and optimizes pool
+// scheduling to improve (median and tail) latency.
+//
+// This class is thread safe.
+class RunHandler {
+ public:
+ void ScheduleInterOpClosure(std::function<void()> fn);
+
+ ~RunHandler();
+
+ private:
+ class Impl;
+ friend class RunHandlerPool::Impl;
+
+ explicit RunHandler(Impl* impl);
+
+ Impl* impl_; // NOT OWNED.
+};
+
+} // end namespace tensorflow.
+
+#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
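
A minimal end-to-end sketch of the usage the header describes (pool size and closure illustrative):

    RunHandlerPool pool(/*num_inter_op_threads=*/16);
    auto handler = pool.Get();               // Blocks until a handler is free.
    handler->ScheduleInterOpClosure([] { /* inter-op work */ });
    handler.reset();                         // Returns the handler to the pool.
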
diff --git a/tensorflow/core/framework/run_handler_util.cc b/tensorflow/core/framework/run_handler_util.cc
new file mode 100644
index 0000000000..3087998c69
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util.cc
@@ -0,0 +1,57 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/run_handler_util.h"
+
+#include <algorithm>
+#include <cmath>
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads,
+ int min_threads_per_request,
+ std::vector<std::uint_fast32_t>* start_vec,
+ std::vector<std::uint_fast32_t>* end_vec) {
+ // Each request is expected to have weight W[i] = num_active_requests - i.
+ // Therefore, total_weight = sum of all request weights.
+ float total_weight = 0.5f * num_active_requests * (num_active_requests + 1);
+ float demand_factor = static_cast<float>(num_threads) / total_weight;
+ float last_cumulative_weight = 0.0;
+ min_threads_per_request = std::max(1, min_threads_per_request);
+ for (int i = 0; i != num_active_requests; i++) {
+ float cumulative_weight =
+ static_cast<float>(i + 1) *
+ (num_active_requests - static_cast<float>(i) * 0.5f);
+ float weight = cumulative_weight - last_cumulative_weight;
+ // Quantize thread_demand by rounding up, and also satisfying
+ // `min_threads_per_request` constraint.
+ // Note: We subtract a small epsilon (0.00001) to prevent ceil(..) from
+ // rounding weights like 4.0 to 5.
+ int demand =
+ std::max(min_threads_per_request,
+ static_cast<int>(ceil(weight * demand_factor - 0.00001f)));
+  // For the quantized range [start, end): compute the floor of the real start,
+ // and expand downwards from there with length `demand` and adjust for
+ // boundary conditions.
+ int start = last_cumulative_weight * demand_factor;
+ int end = std::min(num_threads, start + demand);
+ start = std::max(0, std::min(start, end - demand));
+ start_vec->at(i) = start;
+ end_vec->at(i) = end;
+ last_cumulative_weight = cumulative_weight;
+ }
+}
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/run_handler_util.h b/tensorflow/core/framework/run_handler_util.h
new file mode 100644
index 0000000000..c0c36aeccb
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util.h
@@ -0,0 +1,43 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
+
+#include <cstdint>
+#include <vector>
+
+namespace tensorflow {
+
+// Assign thread ranges to requests.
+// Requests are numbered 0...num_active_requests-1, and
+// threads are numbered 0...num_threads-1.
+// On return, the range start_vec->at(i)...end_vec->at(i)-1
+// indicates the subrange of the threads available to request i.
+// The ranges given to different requests may overlap.
+// Lower numbered requests will tend to be assigned more threads.
+// Thus, a client might associate older requests with lower
+// array indices so they receive access to more threads.
+// However, the routine ensures that each request is given access
+// to at least min(min_threads_per_request, num_threads) threads.
+// Every thread will be assigned to at least one request range,
+// assuming there is at least one request.
+void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads,
+ int min_threads_per_request,
+ std::vector<std::uint_fast32_t>* start_vec,
+ std::vector<std::uint_fast32_t>* end_vec);
+
+} // end namespace tensorflow
+#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
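
As a hand-computed illustration of the weighting above: with num_active_requests = 3, num_threads = 12 and min_threads_per_request = 2, the request weights are 3, 2, 1, so total_weight = 6 and demand_factor = 2, and the resulting ranges are [0, 6), [6, 10) and [10, 12):

    std::vector<std::uint_fast32_t> start(3), end(3);
    ComputeInterOpSchedulingRanges(/*num_active_requests=*/3, /*num_threads=*/12,
                                   /*min_threads_per_request=*/2, &start, &end);
    // Expected (hand-computed): start = {0, 6, 10}, end = {6, 10, 12}.
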
diff --git a/tensorflow/core/framework/run_handler_util_test.cc b/tensorflow/core/framework/run_handler_util_test.cc
new file mode 100644
index 0000000000..a1928c132b
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util_test.cc
@@ -0,0 +1,93 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/run_handler_util.h"
+
+#include <vector>
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+namespace tensorflow {
+namespace {
+
+void VerifyFunction(int num_active_requests, int num_threads,
+ int min_threads_per_request, bool print_stats = false) {
+ if (print_stats) {
+ LOG(INFO) << "Test case# num_active_requests: " << num_active_requests
+ << " num_threads: " << num_threads
+ << " min_threads: " << min_threads_per_request;
+ }
+ std::vector<std::uint_fast32_t> start(num_active_requests);
+ std::vector<std::uint_fast32_t> end(num_active_requests);
+
+ ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
+ min_threads_per_request, &start, &end);
+ string range_str = "";
+ for (int i = 0; i < num_active_requests; ++i) {
+ if (i > 0) range_str += " ";
+ range_str += strings::StrCat("[", start[i], ", ", end[i], ")");
+
+ ASSERT_GE(start[i], 0) << range_str;
+ ASSERT_LE(end[i], num_threads) << range_str;
+ if (i > 0) {
+ // Due to linearly decreasing demand, #threads(i - 1) >= #threads(i)
+ ASSERT_GE(end[i - 1] - start[i - 1], end[i] - start[i]) << range_str;
+ // No missing threads.
+ ASSERT_GE(end[i - 1], start[i]) << range_str;
+ }
+ // Each interval is at least of size 'min_threads_per_request'.
+ ASSERT_GE((end[i] - start[i]), min_threads_per_request) << range_str;
+ // Verify that assigned (quantized) threads is not overly estimated
+ // from real demand, when the demand is high (>=
+ // min_threads_per_request).
+ float entry_weight = num_active_requests - i;
+ float total_weight = 0.5f * num_active_requests * (num_active_requests + 1);
+ float thread_demand = (entry_weight * num_threads) / total_weight;
+ if (thread_demand > min_threads_per_request) {
+ // We expect some over-estimation of threads due to quantization,
+ // but we hope it's not more than 1 extra thread.
+ ASSERT_NEAR(end[i] - start[i], thread_demand, 1.0)
+ << "Ranges: " << range_str << " thread_demand: " << thread_demand
+ << " i: " << i;
+ }
+ }
+ ASSERT_EQ(end[num_active_requests - 1], num_threads);
+ ASSERT_EQ(start[0], 0);
+ if (print_stats) {
+ LOG(INFO) << "Assigned ranges: " << range_str;
+ }
+}
+
+TEST(RunHandlerUtilTest, TestComputeInterOpSchedulingRanges) {
+ const int kMinThreadsPerRequestBound = 12;
+ const int kMaxActiveRequests = 128;
+ const int kMaxThreads = 128;
+
+ for (int min_threads_per_request = 1;
+ min_threads_per_request <= kMinThreadsPerRequestBound;
+ ++min_threads_per_request) {
+ for (int num_active_requests = 1; num_active_requests <= kMaxActiveRequests;
+ ++num_active_requests) {
+ for (int num_threads = min_threads_per_request;
+ num_threads <= kMaxThreads; ++num_threads) {
+ VerifyFunction(num_active_requests, num_threads,
+ min_threads_per_request);
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index 3df677675e..1dea6da911 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -813,7 +813,7 @@ Tensor Tensor::Slice(int64 start, int64 limit) const {
}
Tensor Tensor::SubSlice(int64 index) const {
- CHECK_GE(dims(), 2); // Crash ok.
+ CHECK_GE(dims(), 1); // Crash ok.
CHECK_LE(0, index); // Crash ok.
int64 dim0_size = shape_.dim_size(0);
CHECK_LE(index, dim0_size); // Crash ok.
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 8a0c70fef2..d0f9eb56e2 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -219,7 +219,7 @@ class Tensor {
/// must check the returned tensor's alignment before calling certain
/// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
///
- /// REQUIRES: `dims()` >= 2
+ /// REQUIRES: `dims()` >= 1
/// REQUIRES: `0 <= dim0_start < dim_size(0)`
Tensor SubSlice(int64 index) const;
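
With the relaxed rank requirement, a rank-1 tensor can now be sub-sliced, yielding a (possibly unaligned) scalar view; a minimal sketch:

    Tensor t(DT_FLOAT, TensorShape({4}));
    t.vec<float>().setConstant(2.0f);
    Tensor s = t.SubSlice(1);  // shape [], check alignment before calling flat().
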
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 0bfa53e6c5..c596604143 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -1246,6 +1246,9 @@ TEST(Tensor, SubSlice_Basic) {
EXPECT_EQ(&tx(5, j, k), &ty(j, k));
}
}
+ Tensor z = y.SubSlice(3).SubSlice(31);
+ auto tz = z.unaligned_flat<float>();
+ EXPECT_EQ(*tz.data(), 5.0);
}
{
// Test unaligned access via a SubSlice.
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 1630ab7a15..7a4a0096fa 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -192,6 +192,11 @@ void Node::ClearAttr(const string& name) {
(*props_->node_def.mutable_attr()).erase(name);
}
+void Node::set_name(string name) {
+ MaybeCopyOnWrite();
+ props_->node_def.set_name(std::move(name));
+}
+
void Node::set_requested_device(const string& device) {
MaybeCopyOnWrite();
props_->node_def.set_device(device);
@@ -643,7 +648,7 @@ Status Graph::IsValidNode(const Node* node) const {
Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
TF_RETURN_IF_ERROR(IsValidNode(node));
- if (idx >= node->num_outputs()) {
+ if (idx >= node->num_outputs() || idx < 0) {
return errors::OutOfRange("Node '", node->name(), "' (type: '",
node->op_def().name(),
"', num of outputs: ", node->num_outputs(),
@@ -654,7 +659,7 @@ Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
Status Graph::IsValidInputTensor(const Node* node, int idx) const {
TF_RETURN_IF_ERROR(IsValidNode(node));
- if (idx >= node->num_inputs()) {
+ if (idx >= node->num_inputs() || idx < 0) {
return errors::OutOfRange("Node '", node->name(), "' (type: '",
node->op_def().name(),
"', num of inputs: ", node->num_inputs(),
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 52e9f23a76..2944951f82 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -72,6 +72,7 @@ class Node {
int id() const { return id_; }
int cost_id() const { return cost_id_; }
const string& name() const;
+ void set_name(string name);
const string& type_string() const;
// def() provides the NodeDef the user supplied, but the specifics
@@ -590,12 +591,12 @@ class Graph {
// Returns OK if `node` is non-null and belongs to this graph
Status IsValidNode(const Node* node) const;
- // Returns OK if IsValidNode(`node`) and `idx` is less than
- // node->num_outputs()
+ // Returns OK if IsValidNode(`node`) and `idx` is a valid output. Does not
+ // accept control outputs.
Status IsValidOutputTensor(const Node* node, int idx) const;
- // Returns OK if IsValidNode(`node`) and `idx` is less than
- // node->num_inputs()
+  // Returns OK if IsValidNode(`node`) and `idx` is a valid input. Does not
+  // accept control inputs.
Status IsValidInputTensor(const Node* node, int idx) const;
// Create and return a new WhileContext owned by this graph. This is called
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index f5b0105862..7394b1cddf 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/util.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/mkl_layout_pass.h"
@@ -977,7 +978,9 @@ std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif // ENABLE_MKL
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
@@ -2448,6 +2451,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.tanh = "Tanh";
csinfo_.tanh_grad = "TanhGrad";
csinfo_.reshape = "Reshape";
+ csinfo_.slice = "Slice";
csinfo_.softmax = "Softmax";
csinfo_.split = "Split";
// Element-wise ops. Ensure you also add any new ops to IsOpElementWise
@@ -2555,6 +2559,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.reshape,
mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsReshape, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.slice,
+ mkl_op_registry::GetMklOpName(csinfo_.slice),
+ CopyAttrsSlice, AlwaysRewrite});
rinfo_.push_back({csinfo_.softmax,
mkl_op_registry::GetMklOpName(csinfo_.softmax),
CopyAttrsDataType, AlwaysRewrite});
@@ -2674,6 +2681,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string tanh;
string tanh_grad;
string reshape;
+ string slice;
string softmax;
string split;
string squared_difference;
@@ -3132,6 +3140,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb);
// Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
@@ -3150,7 +3159,9 @@ MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif // ENABLE_MKL
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
@@ -3735,6 +3746,19 @@ void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
nb->Attr("Tshape", Tshape);
}
+void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ DataType Index;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index));
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("Index", Index);
+}
+
void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
@@ -4488,6 +4512,10 @@ Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
if (options.graph == nullptr && options.partition_graphs == nullptr) {
return Status::OK();
}
+ if (DisableMKL()) {
+ VLOG(2) << "TF-MKL: Disabling MKL";
+ return Status::OK();
+ }
auto process_graph = [&](std::unique_ptr<Graph>* g) {
// Get the ownership of a graph
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index e8bac847e5..77640e287c 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/graph/mkl_layout_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
@@ -3510,6 +3510,26 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
"B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}
+TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Int32Input'}"
+ "node { name: 'C' op: 'Int32Input'}"
+ "node { name: 'D' op: 'Slice'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'Index' value { type: DT_INT32 } }"
+ " input: ['A', 'B', 'C'] }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'D'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Int32Input);C(Int32Input);"
+ "D(_MklSlice);DMT/_0(Const);DMT/_1(Const);DMT/"
+ "_2(Const);E(Zeta)|A->D;A->E;"
+ "A:control->DMT/_0:control;A:control->DMT/"
+ "_1:control;A:control->DMT/_2:control;"
+ "B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+
/////////////////////////////////////////////////////////////////////
// Post-rewrite fixup pass test
@@ -3586,4 +3606,4 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
} // namespace tensorflow
-#endif /* INTEL_MKL */
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index b67a321fc1..6804ab84ce 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/util.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
@@ -133,7 +134,9 @@ class MklToTfConversionPass : public GraphOptimizationPass {
// complete picture of inputs and outputs of the nodes in the graphs.
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
+#endif // ENABLE_MKL
Status MklToTfConversionPass::InsertConversionNodeOnEdge(
std::unique_ptr<Graph>* g, Edge* e) {
@@ -422,6 +425,10 @@ Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) {
if (options.graph == nullptr && options.partition_graphs == nullptr) {
return Status::OK();
}
+ if (DisableMKL()) {
+ VLOG(2) << "TF-MKL: Disabling MKL";
+ return Status::OK();
+ }
auto process_graph = [&](std::unique_ptr<Graph>* g) {
// Get the ownership of graph
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index ebcb6de551..319437a801 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
@@ -304,4 +304,4 @@ BENCHMARK(BM_RunMklToTfConversionPass)->Arg(1000)->Arg(10000);
} // namespace
} // namespace tensorflow
-#endif /* INTEL_MKL */
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index a446e0d136..d92874909f 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -99,6 +99,11 @@ NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
return *this;
}
+NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
+ assigned_device_ = string(device);
+ return *this;
+}
+
Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
// In case of error, set *created_node to nullptr.
if (created_node != nullptr) *created_node = nullptr;
@@ -115,6 +120,8 @@ Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
Node* node = graph->AddNode(node_def, &status);
if (!status.ok()) return status;
+ node->set_assigned_device_name(assigned_device_);
+
for (size_t i = 0; i < inputs_.size(); ++i) {
if (inputs_[i].node != nullptr) { // Skip back edges.
graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
index 4727ee7b56..d576985a23 100644
--- a/tensorflow/core/graph/node_builder.h
+++ b/tensorflow/core/graph/node_builder.h
@@ -100,6 +100,9 @@ class NodeBuilder {
// "assigned device" in the Node).
NodeBuilder& Device(StringPiece device_spec);
+ // Sets the device name in the "assigned device" field in tensorflow::Node.
+ NodeBuilder& AssignedDevice(StringPiece device);
+
// Set the value of an attr. attr_name must match the name of one of
// attrs defined by the Op, and value must have the corresponding type
// (see SetAttrValue() in ../framework/attr_value_util.h for legal
@@ -141,6 +144,7 @@ class NodeBuilder {
std::vector<NodeOut> inputs_;
std::vector<Node*> control_inputs_;
std::vector<string> errors_;
+ string assigned_device_;
};
// IMPLEMENTATION -------------------------------------------------------------
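For context, a minimal usage sketch of the new AssignedDevice() setter (hypothetical node and device names, not taken from this patch):

    #include "tensorflow/core/framework/tensor.h"
    #include "tensorflow/core/graph/graph.h"
    #include "tensorflow/core/graph/node_builder.h"

    namespace tensorflow {
    // Unlike Device(), which only records the requested device in the NodeDef,
    // AssignedDevice() fills the Node's assigned-device field at Finalize() time.
    Status BuildPlacedConst(Graph* g, Node** out) {
      Tensor value(DT_FLOAT, TensorShape({}));
      value.scalar<float>()() = 1.0f;
      return NodeBuilder("placed_const", "Const")  // Hypothetical node name.
          .Attr("dtype", DT_FLOAT)
          .Attr("value", value)
          .AssignedDevice("/job:localhost/replica:0/task:0/device:CPU:0")
          .Finalize(g, out);
    }
    }  // namespace tensorflow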
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index f716cd72c9..28fd7565cc 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -74,6 +74,10 @@ class GraphProperties {
// shape information.
void ClearInputProperties(const string& node_name);
void ClearOutputProperties(const string& node_name);
+ // Returns true if we have *any* properties.
+ bool has_properties() const {
+ return input_properties_.size() > 0 || output_properties_.size() > 0;
+ }
private:
// Relaxes shapes <shapes_and_types>, determined from an EnqueueV2 node, into
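A minimal sketch of how a caller might use the new accessor to bail out early (assumes the usual GraphProperties workflow; illustrative only):

    // Skip per-node lookups when static shape inference yielded no
    // input/output properties at all.
    GraphProperties properties(item);  // `item` is a grappler::GrapplerItem.
    TF_RETURN_IF_ERROR(properties.InferStatically(/*assume_valid_feeds=*/false));
    if (!properties.has_properties()) {
      return Status::OK();  // Nothing to work with.
    }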
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 362092a6cf..db10f586bc 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -1340,6 +1340,8 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
Output g = ops::Shape(s.WithOpName("g"), c);
Output h = ops::Fill(s.WithOpName("h"), g, zero);
+ Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
+ Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -1382,6 +1384,10 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
ASSERT_EQ(2, shape_f.dim_size());
EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());
+
+ const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
+ ASSERT_EQ(1, shape_j.dim_size());
+ EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
}
TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
index c94ee2f227..0ec95dd684 100644
--- a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
@@ -88,6 +88,13 @@ library {
}
}
}
+ attr {
+ key: "output_shapes"
+ value {
+ list {
+ }
+ }
+ }
}
ret {
key: "while"
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index 2619a9a8f3..de0a63fc4e 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -20,23 +20,25 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
- for (int output_arg_id = 0; output_arg_id < op.output_arg_size();
- ++output_arg_id) {
+namespace {
+int OpPortIdToArgId(const NodeDef& node,
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
+ int port_id) {
+ for (int arg_id = 0; arg_id < args.size(); ++arg_id) {
if (port_id < 0) {
return -1;
} else if (port_id == 0) {
- return output_arg_id;
+ return arg_id;
}
- // Default is 1 port per output arg.
+ // Default is 1 port per arg.
int n = 1;
- const auto& output_arg = op.output_arg(output_arg_id);
- if (!output_arg.number_attr().empty()) {
- n = node.attr().at(output_arg.number_attr()).i();
- } else if (!output_arg.type_list_attr().empty()) {
- n = node.attr().at(output_arg.type_list_attr()).list().type_size();
+ const auto& arg = args.Get(arg_id);
+ if (!arg.number_attr().empty()) {
+ n = node.attr().at(arg.number_attr()).i();
+ } else if (!arg.type_list_attr().empty()) {
+ n = node.attr().at(arg.type_list_attr()).list().type_size();
}
if (n < 0) {
@@ -44,13 +46,22 @@ int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
DCHECK_GE(n, 0);
return -1;
} else if (port_id < n) {
- return output_arg_id;
+ return arg_id;
}
port_id -= n;
}
return -1;
}
+} // end namespace
+
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+ return OpPortIdToArgId(node, op.output_arg(), port_id);
+}
+
+int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+ return OpPortIdToArgId(node, op.input_arg(), port_id);
+}
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
for (int i = 0; i < graph_->node_size(); i++) {
@@ -72,7 +83,7 @@ void GraphView::AddUniqueNodeOrDie(NodeDef* node) {
void GraphView::AddFanouts(NodeDef* node) {
for (int i = 0; i < node->input_size(); ++i) {
OutputPort fanin;
- string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
+ const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
fanin.node = nodes_[fanin_name];
InputPort input;
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index ec946ca3b5..09c36a1368 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -26,7 +26,7 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-// Map a node/op's output port_id to arg_id.
+// Map a node/op's input/output port_id to arg_id.
//
// The port_id refers to the n-th tensor of the node, while the arg_id refers to
// the n-th arg of the op. These two can be different if an op's arg is a list
@@ -34,6 +34,7 @@ namespace grappler {
//
// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
+int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
// A utility class to simplify the traversal of a GraphDef.
class GraphView {
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index 3d7d2faf7c..f90e2c8cfc 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -26,7 +26,7 @@ namespace {
class GraphViewTest : public ::testing::Test {};
-TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
+TEST_F(GraphViewTest, OpPortIdToArgIdShapeN) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
ops::ShapeN b(s.WithOpName("b"), {a, a, a});
@@ -45,9 +45,16 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
EXPECT_TRUE(
OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
- EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *a_op_def, 0));
- EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *a_op_def, 1));
+ // Const has 0 inputs, 1 output.
+ EXPECT_EQ(-1, OpInputPortIdToArgId(a_node_def, *a_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(a_node_def, *a_op_def, 0));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(a_node_def, *a_op_def, 1));
+ // ShapeN has N=3 inputs and outputs.
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0));
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 1));
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 2));
+ EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 3));
EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0));
EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1));
EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2));
@@ -55,7 +62,7 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdShapeN) {
EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4));
}
-TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
+TEST_F(GraphViewTest, OpPortIdToArgIdSparseSplit) {
for (int num_splits : {1, 2}) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10});
@@ -70,6 +77,13 @@ TEST_F(GraphViewTest, OpOutputPortIdToArgIdSparseSplit) {
EXPECT_TRUE(
OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+ // We have 4 inputs.
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0));
+ EXPECT_EQ(1, OpInputPortIdToArgId(b_node_def, *b_op_def, 1));
+ EXPECT_EQ(2, OpInputPortIdToArgId(b_node_def, *b_op_def, 2));
+ EXPECT_EQ(3, OpInputPortIdToArgId(b_node_def, *b_op_def, 3));
+ EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 4));
+
for (int port_id = 0; port_id <= num_splits * 3; ++port_id) {
int arg_id = -1;
if (port_id < num_splits * 3) {
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index bbc0fedd22..2c490f3966 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -38,6 +38,7 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) {
restore_op = other.restore_op;
save_restore_loc_tensor = other.save_restore_loc_tensor;
queue_runners = other.queue_runners;
+ allowed_optimizations = other.allowed_optimizations;
graph.Swap(graph_def);
}
diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h
index 939e5fa046..a0748abfe6 100644
--- a/tensorflow/core/grappler/grappler_item.h
+++ b/tensorflow/core/grappler/grappler_item.h
@@ -77,6 +77,15 @@ struct GrapplerItem {
// Return a set of node names that must be preserved. This includes feed and
// fetch nodes, keep_ops, init_ops.
std::unordered_set<string> NodesToPreserve() const;
+
+ // Restrict types of optimizations that are allowed for this GrapplerItem.
+ struct AllowedOptimizations {
+ // Whether it is allowed to add nodes to the graph that do not have a
+ // registered gradient function.
+ bool non_differentiable_rewrites = true;
+ };
+
+ AllowedOptimizations allowed_optimizations;
};
// Return the transitive fanin of a set of terminal nodes.
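A hypothetical caller-side sketch of the new field (the arithmetic optimizer hunk further below shows the consumer side):

    // A caller optimizing a function body that must remain differentiable can
    // forbid rewrites that introduce nodes without registered gradients.
    GrapplerItem item;
    // ... populate item.graph, item.fetch, etc. ...
    item.allowed_optimizations.non_differentiable_rewrites = false;
    // Optimizers then honor the flag, e.g.:
    //   options_.unary_ops_composition &=
    //       item.allowed_optimizations.non_differentiable_rewrites;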
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 029515ad3c..369046666d 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -192,9 +192,13 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
const string feed_name = NodeName(feed_node);
new_item->feed.emplace_back(feed_name, Tensor());
}
+ for (const auto& fetch_node : cfg.fetch_nodes) {
+ new_item->fetch.emplace_back(NodeName(fetch_node));
+ }
- // Attempt to detect the fetch node(s).
- if (meta_graph.collection_def().count("train_op") > 0) {
+ // Attempt to detect the fetch node(s) if they were not set explicitly.
+ if (new_item->fetch.empty() &&
+ meta_graph.collection_def().count("train_op") > 0) {
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
if (nodes.has_node_list()) {
for (const auto& node : nodes.node_list().value()) {
diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h
index aafd2fdcda..1698587f8c 100644
--- a/tensorflow/core/grappler/grappler_item_builder.h
+++ b/tensorflow/core/grappler/grappler_item_builder.h
@@ -49,6 +49,8 @@ struct ItemConfig {
bool prune_graph = false;
// Override feed nodes list.
std::set<string> feed_nodes;
+ // Override fetch nodes list.
+ std::set<string> fetch_nodes;
};
// Factory method for creating a GrapplerItem from a MetaGraphDef.
diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc
index 4b90bf3038..d00981f174 100644
--- a/tensorflow/core/grappler/grappler_item_builder_test.cc
+++ b/tensorflow/core/grappler/grappler_item_builder_test.cc
@@ -313,6 +313,29 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithUnknownDimInSignatureInput) {
EXPECT_EQ(item2->feed[0].second.NumElements(), 1);
}
+TEST_F(GrapplerItemBuilderTest, ExplicitFeedAndFetch) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), 0);
+ auto y = ops::Const(s.WithOpName("y"), 1);
+ auto z = ops::Add(s.WithOpName("z"), x, y);
+
+ MetaGraphDef meta_graph;
+ TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def()));
+
+ ItemConfig config;
+ config.feed_nodes.insert("x");
+ config.fetch_nodes.insert("z");
+
+ std::unique_ptr<GrapplerItem> item =
+ GrapplerItemFromMetaGraphDef("0", meta_graph, config);
+ ASSERT_TRUE(item != nullptr);
+
+ EXPECT_EQ(item->feed.size(), 1);
+ EXPECT_EQ(item->fetch.size(), 1);
+ EXPECT_EQ(item->feed[0].first, "x");
+ EXPECT_EQ(item->fetch[0], "z");
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 3521669b63..cbf5c8e038 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -13,14 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unordered_set>
-
+#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -102,6 +101,22 @@ bool IsConjugateTranspose(const NodeDef& node) {
return node.op() == "ConjugateTranspose";
}
+bool IsControlFlow(const NodeDef& node) {
+ // TODO(williamchan): Add a microbenchmark to compare FlatSet vs. iterative
+ // string comparison.
+ static const gtl::FlatSet<string>* const kControlFlowOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
+ "ControlTrigger",
+ "Enter",
+ "Exit",
+ "LoopCond",
+ "Merge",
+ "NextIteration",
+ "Switch",
+ }));
+ return kControlFlowOps->count(node.op()) > 0;
+}
+
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
bool IsConv2DBackpropFilter(const NodeDef& node) {
@@ -140,26 +155,26 @@ bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
// e.g. sqrt, exp. If *is_non_decreasing is false, the function is non-increasing,
// e.g. inv.
bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
- static const std::unordered_set<string>* monotonic_non_decreasing_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1",
"Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint",
"Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh",
}));
- static const std::unordered_set<string>* monotonic_non_increasing_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"Inv",
"Reciprocal",
"Erfc",
"Rsqrt",
"Neg",
}));
- if (monotonic_non_decreasing_ops->count(node.op()) > 0) {
+ if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
if (is_non_decreasing) {
*is_non_decreasing = true;
}
return true;
- } else if (monotonic_non_increasing_ops->count(node.op()) > 0) {
+ } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) {
if (is_non_decreasing) {
*is_non_decreasing = false;
}
@@ -425,8 +440,44 @@ bool IsSwitch(const NodeDef& node) {
return op == "Switch" || op == "RefSwitch";
}
+bool IsSymbolicGradient(const NodeDef& node) {
+ return node.op() == "SymbolicGradient";
+}
+
bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
+bool IsTensorArray(const NodeDef& node) {
+ static const gtl::FlatSet<string>* const kTensorArrayOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
+ "TensorArray",
+ "TensorArrayV2",
+ "TensorArrayV3",
+ "TensorArrayGrad",
+ "TensorArrayGradV2",
+ "TensorArrayGradV3",
+ "TensorArrayGradWithShape",
+ "TensorArrayWrite",
+ "TensorArrayWriteV2",
+ "TensorArrayWriteV3",
+ "TensorArrayRead",
+ "TensorArrayReadV2",
+ "TensorArrayReadV3",
+ "TensorArrayConcat",
+ "TensorArrayConcatV2",
+ "TensorArrayConcatV3",
+ "TensorArraySplit",
+ "TensorArraySplitV2",
+ "TensorArraySplitV3",
+ "TensorArraySize",
+ "TensorArraySizeV2",
+ "TensorArraySizeV3",
+ "TensorArrayClose",
+ "TensorArrayCloseV2",
+ "TensorArrayCloseV3",
+ }));
+ return kTensorArrayOps->count(node.op()) > 0;
+}
+
bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
@@ -538,30 +589,29 @@ OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
OPDEF_PROPERTY_HELPER(Commutative, commutative)
bool IsInvolution(const NodeDef& node) {
- static const std::unordered_set<string>* involution_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
- "Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"}));
- return involution_ops->count(node.op()) > 0;
+ static const gtl::FlatSet<string>* const kInvolutionOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert",
+ "Neg", "LogicalNot"}));
+ return kInvolutionOps->count(node.op()) > 0;
}
bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
return true;
}
- static const std::unordered_set<string>*
- value_and_order_and_shape_preserving_ops =
- CHECK_NOTNULL((new const std::unordered_set<string>{
- "CheckNumerics",
- "DebugGradientIdentity",
- "DeepCopy"
- "Enter",
- "Exit",
- "PreventGradient",
- "Print",
- "Snapshot",
- "StopGradient",
- }));
- return value_and_order_and_shape_preserving_ops->count(node.op()) > 0 ||
+ static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps =
+ CHECK_NOTNULL((new const gtl::FlatSet<string>{
+ "CheckNumerics",
+ "DebugGradientIdentity",
+ "DeepCopy"
+ "Enter",
+ "Exit",
+ "PreventGradient",
+ "Print",
+ "Snapshot",
+ "StopGradient",
+ }));
+ return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 ||
IsIdentity(node);
}
@@ -569,31 +619,31 @@ bool IsValueAndOrderPreserving(const NodeDef& node) {
if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
return true;
}
- static const std::unordered_set<string>* value_and_order_preserving_ops =
- CHECK_NOTNULL((new const std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps =
+ CHECK_NOTNULL((new const gtl::FlatSet<string>{
"ExpandDims",
"Reshape",
"Squeeze",
}));
- return value_and_order_preserving_ops->count(node.op()) > 0 ||
+ return kValueAndOrderPreservingOps->count(node.op()) > 0 ||
IsValueAndOrderAndShapePreserving(node);
}
bool IsValuePreserving(const NodeDef& node) {
- static const std::unordered_set<string>* value_preserving_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kValuePreservingOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"InvertPermutation",
"Reverse",
"Roll",
"Transpose",
}));
return IsValueAndOrderPreserving(node) ||
- value_preserving_ops->count(node.op()) > 0;
+ kValuePreservingOps->count(node.op()) > 0;
}
bool IsUnaryElementWise(const NodeDef& node) {
- static const std::unordered_set<string>* element_wise_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kElementWiseOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"Abs",
"Acos",
"Acosh",
@@ -642,7 +692,7 @@ bool IsUnaryElementWise(const NodeDef& node) {
"Tan"
"Tanh",
}));
- return element_wise_ops->count(node.op()) > 0 ||
+ return kElementWiseOps->count(node.op()) > 0 ||
IsValueAndOrderAndShapePreserving(node);
}
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 25ab6b65ac..d4e0159e81 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -46,6 +46,7 @@ bool IsConjugateTranspose(const NodeDef& node);
bool IsConcat(const NodeDef& node);
bool IsConcatOffset(const NodeDef& node);
bool IsConstant(const NodeDef& node);
+bool IsControlFlow(const NodeDef& node);
bool IsConv2D(const NodeDef& node);
bool IsConv2DBackpropFilter(const NodeDef& node);
bool IsConv2DBackpropInput(const NodeDef& node);
@@ -149,7 +150,9 @@ bool IsStridedSliceGrad(const NodeDef& node);
bool IsSub(const NodeDef& node);
bool IsSum(const NodeDef& node);
bool IsSwitch(const NodeDef& node);
+bool IsSymbolicGradient(const NodeDef& node);
bool IsTanhGrad(const NodeDef& node);
+bool IsTensorArray(const NodeDef& node);
bool IsTile(const NodeDef& node);
bool IsTranspose(const NodeDef& node);
bool IsTruncateDiv(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 960d1addb3..c708f84948 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -525,6 +525,7 @@ cc_library(
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/utils:colocation",
@@ -541,6 +542,7 @@ tf_cuda_cc_test(
":custom_graph_optimizer_registry",
":meta_optimizer",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 75ed12635e..7d5014ee0a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -276,7 +276,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
for (int i = 0; i < output->input_size(); ++i) {
auto input = output->input(i);
- string name = ParseNodeName(input, &position);
+ StringPiece name = ParseNodeNameAsStringPiece(input, &position);
if (name == node.name() && /*control input*/ position < 0) {
return true;
}
@@ -1568,7 +1568,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
for (NodeDef* output : outputs) {
if (IsControlInput(output->input(0))) continue;
int port;
- const string node_name = ParseNodeName(output->input(0), &port);
+ const StringPiece node_name =
+ ParseNodeNameAsStringPiece(output->input(0), &port);
if (node_name == node.name()) {
tails->insert(ChainLink(output, port));
} else {
@@ -1618,7 +1619,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
} else {
for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
int port;
- const string node_name = ParseNodeName(new_tail->input(0), &port);
+ const StringPiece node_name =
+ ParseNodeNameAsStringPiece(new_tail->input(0), &port);
if (node_name != tail->name()) {
return Status::OK();
}
@@ -2929,8 +2931,8 @@ uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const {
for (const auto& input : node.input()) {
int pos;
- string node_name = ParseNodeName(input, &pos);
- h = Hash64CombineUnordered(Hash64(node_name), h);
+ const StringPiece node_name = ParseNodeNameAsStringPiece(input, &pos);
+ h = Hash64CombineUnordered(Hash64(node_name.data(), node_name.size()), h);
h = Hash64CombineUnordered(std::hash<int>()(pos), h);
}
for (const auto& attr : node.attr()) {
@@ -3247,6 +3249,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
optimized_graph_ = &optimized_item.graph;
node_map_.reset(new NodeMap(optimized_graph_));
+ // Disable restricted graph rewrites.
+ options_.unary_ops_composition &=
+ item.allowed_optimizations.non_differentiable_rewrites;
+
if (options_.dedup_computations) {
DedupComputations();
}
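The repeated ParseNodeName -> ParseNodeNameAsStringPiece swaps above avoid a string copy per input; a sketch of the semantics being relied on (illustrative values):

    // Assumed semantics: the StringPiece variant returns a view into `input`
    // (valid only while `input` is alive) instead of copying the name.
    int position = 0;
    const string input = "some_node:2";
    const StringPiece name = ParseNodeNameAsStringPiece(input, &position);
    // name == "some_node", position == 2. A control input such as
    // "^some_node" yields position < 0, which the checks above test for.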
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index ca5d3a6dfd..3d0d95bba7 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -616,28 +616,37 @@ Status ConstantFolding::MaterializeReductionIndices(
// We can't do anything if we don't know the rank of the input.
return Status::OK();
}
- const int rank = input_prop.shape().dim_size();
- if (rank == 0) {
+ const int input_rank = input_prop.shape().dim_size();
+ if (input_rank < 1) {
// Unexpected graph, don't try to change it.
return Status::OK();
}
+ const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
+ DataType dtype = reduction_indices_prop.dtype();
+ if (dtype != DT_INT32 && dtype != DT_INT64) {
+ return Status::OK();
+ }
+ PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
+ const int num_reduction_indices = reduction_indices_shape.num_elements();
+
const std::vector<OpInfo::TensorProperties>& output_props =
properties.GetOutputProperties(node->name());
if (output_props.size() != 1) {
return Status::OK();
}
- const bool keep_dims =
- node->attr().count("keep_dims") && node->attr().at("keep_dims").b();
const OpInfo::TensorProperties& output_prop = output_props[0];
- PartialTensorShape output_shape(output_prop.shape());
- if (output_shape.num_elements() != 1) {
- bool full_reduction = false;
+ const int output_rank =
+ output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();
+
+ bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
+ if (!full_reduction) {
+ // A full reduction will generate a tensor of one of the shapes
+ // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
+ // elements in the output of the reduction, we may deduce it from reshape
+ // nodes following it.
for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
- if (!IsReshape(*fanout) && !keep_dims) {
- // Depending on how it's setup, a full reduction will generate a tensor
- // of shape [], [1], [1, 1], [1, 1, ...]. If keep_dims isn't true, we
- // rely on the existence of a reshape node following the reduction to
- // ensure that the fanout is fed a scalar of the right shape.
+ full_reduction = false;
+ if (!IsReshape(*fanout)) {
return Status::OK();
}
const std::vector<OpInfo::TensorProperties>& reshape_props =
@@ -658,20 +667,15 @@ Status ConstantFolding::MaterializeReductionIndices(
}
}
- const OpInfo::TensorProperties& reduction_prop = input_props[1];
- DataType dtype = reduction_prop.dtype();
- if (dtype != DT_INT32 && dtype != DT_INT64) {
- return Status::OK();
- }
- // We know it's a full reduction. We can generate the set of indices to
- // reduce.
+ // We know it's a full reduction. We can generate the full set of indices to
+ // reduce as a constant node.
string const_name = OptimizedNodeName(*node, "-reduction_indices");
if (node_map_->GetNode(const_name)) {
return Status::OK();
}
NodeDef* reduction_indices = graph_->add_node();
- Tensor value(dtype, TensorShape({rank}));
- for (int i = 0; i < rank; ++i) {
+ Tensor value(dtype, TensorShape({input_rank}));
+ for (int i = 0; i < input_rank; ++i) {
if (dtype == DT_INT32) {
value.vec<int32>()(i) = i;
} else {
@@ -680,6 +684,7 @@ Status ConstantFolding::MaterializeReductionIndices(
}
TF_RETURN_IF_ERROR(
CreateNodeDef(const_name, TensorValue(&value), reduction_indices));
+
reduction_indices->set_device(node->device());
string ctrl_dep =
AddControlDependency(node->input(1), graph_, node_map_.get());
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index b09360a2c2..fab01edfed 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -2591,58 +2591,100 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
}
TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output input =
- ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
- ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
- Output indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
- Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
- Output size = ops::Const(s.WithOpName("size"), 1, {1});
- Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
+ for (bool use_reshape : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input =
+ ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
+ // If use_reshape is false, we need to know the number of indices to apply
+ // the rewrite.
+ Output indices = ops::Placeholder(
+ s.WithOpName("indices"), DT_INT32,
+ ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2})));
+ Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
+ if (use_reshape) {
+ Output size = ops::Const(s.WithOpName("size"), 1, {1});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
+ }
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch.push_back("reshape");
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back(use_reshape ? "reshape" : "sum");
- auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
- Tensor indices_t(DT_INT32, TensorShape({2}));
- indices_t.flat<int>()(0) = 0;
- indices_t.flat<int>()(1) = 1;
- auto tensors_expected = EvaluateNodes(
- item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
- EXPECT_EQ(1, tensors_expected.size());
+ auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
+ Tensor indices_t(DT_INT32, TensorShape({2}));
+ indices_t.flat<int>()(0) = 0;
+ indices_t.flat<int>()(1) = 1;
+ auto tensors_expected = EvaluateNodes(
+ item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
+ EXPECT_EQ(1, tensors_expected.size());
- ConstantFolding optimizer(nullptr /* cpu_device */);
- GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ // Use aggressive mode to force the shape inference to propagate placeholder
+ // shapes.
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
- // Run a second time to make sure the optimization is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ // Run a second time to make sure the optimization is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
- int found = 0;
- for (const auto& node : output.node()) {
- if (node.name() == "ConstantFolding/sum-reduction_indices") {
- ++found;
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^indices", node.input(0));
- EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape())
- .num_elements());
- } else if (node.name() == "sum") {
- ++found;
- EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
- } else if (node.name() == "indices") {
- ++found;
+ int found = 0;
+ for (const auto& node : output.node()) {
+ if (node.name() == "ConstantFolding/sum-reduction_indices") {
+ ++found;
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^indices", node.input(0));
+ EXPECT_EQ(2,
+ TensorShape(node.attr().at("value").tensor().tensor_shape())
+ .num_elements());
+ } else if (node.name() == "sum") {
+ ++found;
+ EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
+ } else if (node.name() == "indices") {
+ ++found;
+ }
}
+ EXPECT_EQ(3, found);
+
+ auto tensors = EvaluateNodes(output, item.fetch,
+ {{"input", input_t}, {"indices", indices_t}});
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}
- EXPECT_EQ(3, found);
+}
- auto tensors = EvaluateNodes(output, item.fetch,
- {{"input", input_t}, {"indices", indices_t}});
- EXPECT_EQ(1, tensors.size());
- test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
+TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) {
+ for (bool input_rank_known : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input =
+ (input_rank_known ? ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape(
+ PartialTensorShape({-1, -1})))
+ : ops::Placeholder(s.WithOpName("input"), DT_FLOAT));
+ Output indices =
+ ops::Placeholder(s.WithOpName("indices"), DT_INT32,
+ ops::Placeholder::Shape(
+ PartialTensorShape({input_rank_known ? 1 : 2})));
+ Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back("sum");
+
+ // Use aggressive mode to force the shape inference to propagate placeholder
+ // shapes.
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ CompareGraphs(item.graph, output);
+ }
}
TEST_F(ConstantFoldingTest, LargeConstant) {
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index cf305cebe1..ee7c14e3ab 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -22,6 +22,7 @@ cc_library(
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -31,6 +32,7 @@ tf_cc_test(
visibility = ["//visibility:public"],
deps = [
":filter_fusion",
+ ":graph_test_utils",
":graph_utils",
"//tensorflow/core:framework",
"//tensorflow/core:test",
@@ -87,11 +89,12 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ ":graph_utils",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
- "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -121,11 +124,12 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
- "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -135,6 +139,7 @@ tf_cc_test(
visibility = ["//visibility:public"],
deps = [
":graph_utils",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@@ -146,6 +151,62 @@ tf_cc_test(
)
cc_library(
+ name = "graph_test_utils",
+ testonly = 1,
+ srcs = ["graph_test_utils.cc"],
+ hdrs = [
+ "graph_test_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core:testlib",
+ ] + tf_protos_all(),
+)
+
+cc_library(
+ name = "hoist_random_uniform",
+ srcs = ["hoist_random_uniform.cc"],
+ hdrs = [
+ "hoist_random_uniform.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":graph_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "hoist_random_uniform_test",
+ srcs = ["hoist_random_uniform_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_test_utils",
+ ":graph_utils",
+ ":hoist_random_uniform",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ ] + tf_protos_all(),
+)
+
+cc_library(
name = "latency_all_edges",
srcs = ["latency_all_edges.cc"],
hdrs = [
@@ -256,7 +317,7 @@ cc_library(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
- "//tensorflow/core:ptr_util",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -265,6 +326,7 @@ tf_cc_test(
srcs = ["map_and_filter_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_and_filter_fusion",
"//tensorflow/core:framework",
@@ -294,6 +356,7 @@ cc_library(
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -302,6 +365,7 @@ tf_cc_test(
srcs = ["map_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_fusion",
"//tensorflow/core:framework",
@@ -339,6 +403,7 @@ tf_cc_test(
srcs = ["map_parallelization_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_parallelization",
"//tensorflow/core:framework",
@@ -422,6 +487,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":filter_fusion",
+ ":hoist_random_uniform",
":latency_all_edges",
":map_and_batch_fusion",
":map_and_filter_fusion",
@@ -458,7 +524,9 @@ cc_library(
deps = [
":function_utils",
":graph_utils",
+ "//tensorflow/cc:ops",
"@com_google_absl//absl/strings",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -474,6 +542,7 @@ tf_cc_test(
srcs = ["vectorization_utils_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_utils",
":function_utils",
":vectorization_utils",
"//tensorflow/core:framework",
@@ -483,7 +552,10 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ # For ops we need registered
+ "//tensorflow/core/kernels/data:dataset_ops",
"//tensorflow/core/kernels:cast_op",
+ "//tensorflow/core/kernels:logging_ops",
"//tensorflow/tools/graph_transforms:transform_utils",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
index c71aa6e804..1ad495bbad 100644
--- a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
@@ -43,19 +43,14 @@ NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node,
fused_node.set_op("FilterDataset");
fused_node.add_input(first_filter_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = first_filter_node.attr().at("predicate");
*attr.mutable_func()->mutable_name() = fused_function.signature().name();
(*fused_node.mutable_attr())["predicate"] = std::move(attr);
- copy_attribute("Targuments", first_filter_node, &fused_node);
+ graph_utils::CopyAttribute("Targuments", first_filter_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, second_filter_node, &fused_node);
+ graph_utils::CopyAttribute(key, second_filter_node, &fused_node);
return fused_node;
}
@@ -120,8 +115,8 @@ Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// functions, or make sure that optimization passes run after filter
// fusion.
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_predicate));
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
nodes_to_delete.insert(first_filter_node->name());
nodes_to_delete.insert(second_filter_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
index 12b1924efd..c8becc5cc0 100644
--- a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -28,14 +28,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "FilterDataset", {string(input_node_name)},
- {{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeFilterNode;
TEST(FilterFusionTest, FuseTwoFilterIntoOne) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc
index e95ea1a4c1..311df15bc2 100644
--- a/tensorflow/core/grappler/optimizers/data/function_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc
@@ -14,31 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace grappler {
namespace function_utils {
-namespace {
-
-template <typename Predicate, typename Collection>
-std::vector<int> GetElementIndicesWithPredicate(const Predicate& predicate,
- const Collection& collection) {
- std::vector<int> indices = {};
- unsigned idx = 0;
- for (auto&& element : collection) {
- if (predicate(element)) {
- indices.push_back(idx);
- }
- idx++;
- }
- return indices;
-}
-
-} // namespace
FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
const string& output, int position)
@@ -152,32 +137,27 @@ bool ContainsFunctionOutputWithName(StringPiece name,
}
int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
function.signature().input_arg());
- return indices.empty() ? -1 : indices.front();
}
int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
function.signature().output_arg());
- return indices.empty() ? -1 : indices.front();
}
int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&name](const NodeDef& node) { return node.name() == name; },
function.node_def());
- return indices.empty() ? -1 : indices.front();
}
int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return graph_utils::GetFirstElementIndexWithPredicate(
[&op](const NodeDef& node) { return node.op() == op; },
function.node_def());
-
- return indices.empty() ? -1 : indices.front();
}
void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
new file mode 100644
index 0000000000..b2eec7220e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_tests_utils {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "MapDataset", {string(input_node_name)},
+ {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<DataType>{}}});
+}
+
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "FilterDataset", {string(input_node_name)},
+ {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<TensorShape>{}}});
+}
+
+} // end namespace graph_tests_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
new file mode 100644
index 0000000000..ca0fde997d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_tests_utils {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name = "XTimesTwo");
+
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name = "IsZero");
+
+} // end namespace graph_tests_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 2dd9ee822e..b863a25dc5 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -201,25 +202,22 @@ bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
int FindGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&name](const FunctionDef& function) {
return function.signature().name() == name;
},
library.function());
- return indices.empty() ? -1 : indices.front();
}
int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&name](const NodeDef& node) { return node.name() == name; },
graph.node());
- return indices.empty() ? -1 : indices.front();
}
int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
- return indices.empty() ? -1 : indices.front();
}
std::vector<int> FindAllGraphNodesWithOp(const string& op,
@@ -260,6 +258,41 @@ void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
}
function->mutable_signature()->set_name(std::move(name));
}
+
+void CopyAttribute(const string& attribute_name, const NodeDef& from,
+ NodeDef* to_node) {
+ (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
+}
+
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
+ const NodeDef& second, NodeDef* to_node) {
+ CopyAttribute(attribute_name, first, to_node);
+ (*to_node->mutable_attr())
+ .at(attribute_name)
+ .mutable_list()
+ ->MergeFrom(second.attr().at(attribute_name).list());
+}
+
+Status EnsureNodeNamesUnique(Graph* g) {
+ // Modeled after Scope::Impl::GetUniqueName
+ std::unordered_map<string, int> name_map;
+
+ for (auto node : g->op_nodes()) {
+ const string& prefix = node->name();
+ if (auto entry = gtl::FindOrNull(name_map, prefix)) {
+ string unique_name;
+ do {
+ unique_name = strings::StrCat(prefix, "_", ++(*entry));
+ } while (name_map.find(unique_name) != name_map.end());
+ name_map.insert({unique_name, 0});
+ node->set_name(std::move(unique_name));
+ } else {
+ name_map.insert({node->name(), 0});
+ }
+ }
+
+ return Status::OK();
+}
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index b117482db2..d130fee204 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -31,6 +32,21 @@ namespace tensorflow {
namespace grappler {
namespace graph_utils {
+// Returns the index of the first element in collection that fulfills predicate.
+// If no such element exists, returns -1.
+template <typename Predicate, typename Collection>
+int GetFirstElementIndexWithPredicate(const Predicate& predicate,
+ const Collection& collection) {
+ unsigned idx = 0;
+ for (auto&& element : collection) {
+ if (predicate(element)) {
+ return idx;
+ }
+ idx++;
+ }
+ return -1;
+}
+
// Adds a node to the graph.
NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<string>& inputs,
@@ -101,11 +117,29 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
// is unique across the graph.
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
-// Sets the node name using the `prefix` name as a prefix while guaranteeing the
-// name is unique across the graph.
+// Sets the function name using the `prefix` name as a prefix while guaranteeing
+// the name is unique across the function library.
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function);
+// Copies attribute having name `attribute_name` from node `from` to node
+// `to_node`.
+void CopyAttribute(const string& attribute_name, const NodeDef& from,
+ NodeDef* to_node);
+
+// Concatenates list attribute having name `attribute_name` from `first` and
+// `second` node, setting it to `to_node`.
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
+ const NodeDef& second, NodeDef* to_node);
+
+// Checks that all nodes in the graphs have unique names, and sets their names
+// to be unique if they are not already. This is necessary as Graph does not
+// have the provisions to deduplicate names, and name deduplication elsewhere
+// in tensorflow happens in other layers (for example, in the Scope class of the
+// C++ API). Note that the nodes in the graph are identified by their id,
+// and renaming nodes does not mutate any edges.
+Status EnsureNodeNamesUnique(Graph* g);
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
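A small sketch of the two new attribute helpers in a fusion-style rewrite (`first` and `second` are hypothetical dataset NodeDefs):

    NodeDef fused;
    // Copy a scalar attribute from one node...
    graph_utils::CopyAttribute("Targuments", first, &fused);
    // ...and merge a list attribute from two nodes: `fused`'s "output_shapes"
    // ends up holding first's list followed by second's.
    graph_utils::ConcatAttributeList("output_shapes", first, second, &fused);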
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index 6877c207c4..4ab6d71532 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -24,6 +25,18 @@ namespace grappler {
namespace graph_utils {
namespace {
+TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) {
+ std::vector<int> vec({1, 2, 3, 4, 5, 6});
+ auto result = GetFirstElementIndexWithPredicate(
+ [](int elem) { return elem % 3 == 0; }, vec);
+
+ EXPECT_EQ(result, 2);
+
+ result = GetFirstElementIndexWithPredicate(
+ [](int elem) { return elem % 7 == 0; }, vec);
+ EXPECT_EQ(result, -1);
+}
+
TEST(GraphUtilsTest, AddScalarConstNodeBool) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
@@ -217,6 +230,33 @@ TEST(GraphUtilsTest, GetInputNode) {
EXPECT_EQ(GetInputNode(*node1, graph), nullptr);
}
+TEST(GraphUtilsTest, EnsureNodeNamesUnique) {
+ Graph g(OpRegistry::Global());
+
+ Node *const_0, *const_1, *const_2;
+
+ // Arbitrary const value to attach to the test nodes.
+ Tensor tensor(DT_INT32, {});
+ tensor.scalar<int32>()() = 5;
+
+ for (auto node : {&const_0, &const_1}) {
+ TF_EXPECT_OK(NodeBuilder("Const", "Const")
+ .Attr("value", tensor)
+ .Attr("dtype", DT_INT32)
+ .Finalize(&g, node));
+ }
+ // Make sure the generated name doesn't clash with an existing name either.
+ TF_EXPECT_OK(NodeBuilder("Const_1", "Const")
+ .Attr("value", tensor)
+ .Attr("dtype", DT_INT32)
+ .Finalize(&g, &const_2));
+
+ TF_EXPECT_OK(EnsureNodeNamesUnique(&g));
+ EXPECT_NE(const_0->name(), const_1->name());
+ EXPECT_NE(const_1->name(), const_2->name());
+ EXPECT_NE(const_0->name(), const_2->name());
+}
+
} // namespace
} // namespace graph_utils
} // namespace grappler
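The test above covers the interesting corner case: a generated candidate such
as "Const_1" can itself already be taken. A standalone sketch of one dedup
strategy consistent with the EnsureNodeNamesUnique contract, using a
per-base-name counter that retries until the candidate is free. This is an
illustration over plain strings, not the exact TensorFlow implementation:

    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <unordered_set>
    #include <vector>

    // Renames duplicates so all names are unique, even when a generated
    // "<name>_<n>" candidate collides with a pre-existing name.
    std::vector<std::string> EnsureUnique(std::vector<std::string> names) {
      std::unordered_set<std::string> taken(names.begin(), names.end());
      std::unordered_map<std::string, int> counter;
      std::unordered_set<std::string> seen;
      for (auto& name : names) {
        if (seen.insert(name).second) continue;  // First occurrence keeps its name.
        std::string candidate;
        do {
          candidate = name + "_" + std::to_string(++counter[name]);
        } while (taken.count(candidate));
        taken.insert(candidate);
        seen.insert(candidate);
        name = candidate;
      }
      return names;
    }

    int main() {
      // Mirrors the test: two nodes named "Const" plus an existing "Const_1".
      for (const auto& n : EnsureUnique({"Const", "Const", "Const_1"}))
        std::cout << n << "\n";  // Prints: Const, Const_2, Const_1
    }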
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
new file mode 100644
index 0000000000..ce0b2db039
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
@@ -0,0 +1,289 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node,
+ const FunctionDef& stateless_function,
+ MutableGraphView* graph) {
+ NodeDef stateless_map;
+ graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(),
+ &stateless_map);
+
+ stateless_map.set_op("MapDataset");
+ stateless_map.add_input(zip_node.name());
+ // Forward the placeholder inputs (all inputs after the input dataset).
+ for (int i = 1; i < map_node.input_size(); i++)
+ stateless_map.add_input(map_node.input(i));
+
+ auto attr = map_node.attr().at("f");
+ *attr.mutable_func()->mutable_name() = stateless_function.signature().name();
+ *attr.mutable_func()->mutable_attr() = stateless_function.attr();
+ (*stateless_map.mutable_attr())["f"] = std::move(attr);
+
+ graph_utils::CopyAttribute("Targuments", map_node, &stateless_map);
+ for (auto key : {"output_shapes", "output_types"})
+ graph_utils::CopyAttribute(key, map_node, &stateless_map);
+
+ if (const auto* attr =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"))
+ (*stateless_map.mutable_attr())["use_inter_op_parallelism"] = *attr;
+
+ return stateless_map;
+}
+
+NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
+ MutableGraphView* graph) {
+ NodeDef random_dataset;
+ random_dataset.set_op("RandomDataset");
+ graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(),
+ &random_dataset);
+
+ const auto* seed = graph_utils::AddScalarConstNode<int64>(
+ random_uniform_node.attr().at("seed").i(), graph);
+ const auto* seed2 = graph_utils::AddScalarConstNode<int64>(
+ random_uniform_node.attr().at("seed2").i(), graph);
+
+ random_dataset.add_input(seed->name());
+ random_dataset.add_input(seed2->name());
+
+ (*random_dataset.mutable_attr())["output_shapes"].mutable_list()->add_shape();
+ (*random_dataset.mutable_attr())["output_types"].mutable_list()->add_type(
+ DT_INT64);
+
+ return random_dataset;
+}
+
+NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
+ NodeDef batch_dataset;
+ batch_dataset.set_op("BatchDatasetV2");
+ graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(),
+ &batch_dataset);
+ const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph);
+ const auto* drop_remainder = graph_utils::AddScalarConstNode(false, graph);
+ batch_dataset.add_input(random_dataset.name());
+ batch_dataset.add_input(batch_size->name());
+ batch_dataset.add_input(drop_remainder->name());
+
+ (*batch_dataset.mutable_attr())["output_shapes"]
+ .mutable_list()
+ ->add_shape()
+ ->mutable_dim()
+ ->Add()
+ ->set_size(-1);
+ (*batch_dataset.mutable_attr())["output_types"].mutable_list()->add_type(
+ DT_INT64);
+
+ return batch_dataset;
+}
+
+NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
+ MutableGraphView* graph) {
+ NodeDef zip_node;
+ graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(),
+ &zip_node);
+
+ zip_node.set_op("ZipDataset");
+ zip_node.add_input(first_node.name());
+ zip_node.add_input(second_node.name());
+
+ for (auto key : {"output_shapes", "output_types"})
+ graph_utils::ConcatAttributeList(key, first_node, second_node, &zip_node);
+
+ (*zip_node.mutable_attr())["N"].set_i(2);
+
+ return zip_node;
+}
+
+// Inserts the seed argument before the placeholders, which are the last
+// arguments of the function signature.
+OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) {
+ int new_argument_idx = signature->input_arg_size() - num_placeholders;
+ signature->add_input_arg();
+ for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) {
+ signature->mutable_input_arg()->SwapElements(i - 1, i);
+ }
+ auto* seed_arg = signature->mutable_input_arg(new_argument_idx);
+ seed_arg->set_name(strings::StrCat("seed_arg", new_argument_idx));
+ seed_arg->set_type(DT_INT64);
+
+ return seed_arg;
+}
+
+// Makes a function that uses `StatelessRandomUniform` instead of
+// `RandomUniform`, making it less stateful. The resulting function can still
+// be stateful, but when the only remaining stateful ops are e.g. `Assert`, it
+// is still parallelizable.
+const FunctionDef* MakeLessStatefulFunction(const FunctionDef& map_function,
+ bool is_stateful,
+ int num_placeholders,
+ FunctionDefLibrary* library) {
+ FunctionDef* stateless_function = library->add_function();
+ *stateless_function = map_function;
+ // Record whether the function is still stateful once RandomUniform is
+ // hoisted out (it stays stateful only if it contains other stateful ops).
+ stateless_function->mutable_signature()->set_is_stateful(is_stateful);
+ graph_utils::SetUniqueGraphFunctionName("stateless_function", library,
+ stateless_function);
+
+ auto* seed_arg = InsertSeedArgument(stateless_function->mutable_signature(),
+ num_placeholders);
+
+ auto* const random_uniform = stateless_function->mutable_node_def(
+ function_utils::FindFunctionNodeWithOp("RandomUniform",
+ *stateless_function));
+
+ // Replace RandomUniform node with StatelessRandomUniform.
+ random_uniform->set_op("StatelessRandomUniform");
+ random_uniform->add_input(seed_arg->name());
+ (*random_uniform->mutable_attr())["Tseed"].set_type(DT_INT64);
+ random_uniform->mutable_attr()->erase("seed");
+ random_uniform->mutable_attr()->erase("seed2");
+
+ return stateless_function;
+}
+
+// Returns true if the function is stateful, has a single RandomUniform op,
+// and has no other stateful ops except Assert, i.e. the RandomUniform can be
+// hoisted. `is_stateful_after_hoisting` is set to false if RandomUniform was
+// the only stateful op, meaning the function becomes stateless after hoisting.
+bool CanHoistRandomUniform(const FunctionDef& map_function,
+ const FunctionLibraryDefinition& library,
+ bool* is_stateful_after_hoisting,
+ const NodeDef** random_uniform_op) {
+ if (!map_function.signature().is_stateful()) return false;
+ *is_stateful_after_hoisting = true;
+
+ bool have_other_stateful_ops = false;
+
+ for (const auto& node : map_function.node_def()) {
+ const OpDef* op_def;
+ TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def));
+ // Skip stateless nodes; also skip Assert, which is nominally stateful but
+ // does not actually carry any state.
+ if (!op_def->is_stateful()) continue;
+
+ if (op_def->name() == "Assert") {
+ have_other_stateful_ops = true;
+ continue;
+ }
+
+ // TODO(prazek): For now we only handle RandomUniform, we should handle
+ // RandomUniformInt as well.
+ if (op_def->name() != "RandomUniform") return false;
+
+ // TODO(prazek): For now we can only hoist single RandomUniform.
+ if (*random_uniform_op != nullptr) return false;
+
+ *random_uniform_op = &node;
+ }
+
+ if (!have_other_stateful_ops) *is_stateful_after_hoisting = false;
+
+ // Have we found a single RandomUniform to hoist?
+ return *random_uniform_op != nullptr;
+}
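A standalone sketch of this decision logic, with a hard-coded set standing in
for the op-registry statefulness lookup (the set contents and all names below
are hypothetical, for illustration only):

    #include <iostream>
    #include <set>
    #include <string>
    #include <vector>

    // Decides whether a function body (given as a list of op names) qualifies
    // for hoisting: exactly one RandomUniform and no stateful ops besides
    // Assert. Also reports whether the function stays stateful once the
    // RandomUniform is hoisted out.
    bool CanHoist(const std::vector<std::string>& ops, bool* stateful_after) {
      // Hypothetical stand-in for the OpRegistry statefulness lookup.
      static const std::set<std::string> kStateful = {"RandomUniform", "Assert",
                                                      "AssignVariableOp"};
      int num_random_uniform = 0;
      bool other_stateful = false;
      for (const auto& op : ops) {
        if (!kStateful.count(op)) continue;  // Stateless ops are irrelevant.
        if (op == "Assert") {
          other_stateful = true;
          continue;
        }
        if (op != "RandomUniform") return false;  // Unhoistable stateful op.
        ++num_random_uniform;
      }
      *stateful_after = other_stateful;
      return num_random_uniform == 1;
    }

    int main() {
      bool stateful_after;
      std::cout << CanHoist({"Mul", "RandomUniform"}, &stateful_after) << " "
                << stateful_after << "\n";  // Prints: 1 0
      std::cout << CanHoist({"RandomUniform", "Assert"}, &stateful_after) << " "
                << stateful_after << "\n";  // Prints: 1 1
    }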
+
+int NumberOfPlaceholders(const NodeDef& map_node) {
+ // The first input of MapDataset is the input dataset; all remaining inputs
+ // are placeholders.
+ return map_node.input_size() - 1;
+}
+
+} // namespace
+
+Status HoistRandomUniform::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ item.graph.library());
+
+ auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
+ // TODO(prazek): we could also handle ParallelMapDataset and
+ // MapAndBatchDataset.
+ if (node.op() == "MapDataset") return &node;
+ return nullptr;
+ };
+
+ for (const NodeDef& node : item.graph.node()) {
+ const NodeDef* map_node = get_map_node(node);
+ if (!map_node) continue;
+
+ const auto& fun = map_node->attr().at("f");
+ const FunctionDef* func = function_library.Find(fun.func().name());
+
+ const NodeDef* random_uniform_op = nullptr;
+ bool is_stateful_after_hoisting = true;
+ if (!CanHoistRandomUniform(*func, function_library,
+ &is_stateful_after_hoisting, &random_uniform_op))
+ continue;
+ const auto* random_seed_dataset =
+ graph.AddNode(MakeRandomDataset(*random_uniform_op, &graph));
+
+ const auto* batch_dataset =
+ graph.AddNode(MakeBatchTwo(*random_seed_dataset, &graph));
+
+ const NodeDef& parent_node = *graph_utils::GetInputNode(*map_node, graph);
+
+ const auto* zip_node =
+ graph.AddNode(MakeZipNode(parent_node, *batch_dataset, &graph));
+
+ const auto* stateless_func = MakeLessStatefulFunction(
+ *func, is_stateful_after_hoisting, NumberOfPlaceholders(*map_node),
+ output->mutable_library());
+
+ const auto* stateless_map = graph.AddNode(
+ MakeStatelessMap(*map_node, *zip_node, *stateless_func, &graph));
+
+ graph.ReplaceInput(*map_node, *stateless_map);
+
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
+ nodes_to_delete.insert(map_node->name());
+ }
+
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void HoistRandomUniform::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(HoistRandomUniform, "hoist_random_uniform");
+
+} // end namespace grappler
+} // end namespace tensorflow
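InsertSeedArgument above bubbles the new argument into position with pairwise
swaps because the signature lives in a protobuf repeated field; over a plain
container the same move is a single std::rotate. A sketch under that
assumption, modeling the signature as a vector of argument names:

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    // Inserts `arg` just before the trailing `num_placeholders` arguments,
    // i.e. at the same position InsertSeedArgument picks in the real signature.
    void InsertBeforePlaceholders(std::vector<std::string>* args,
                                  const std::string& arg, int num_placeholders) {
      args->push_back(arg);  // Append, then rotate it into position.
      std::rotate(args->end() - num_placeholders - 1, args->end() - 1,
                  args->end());
    }

    int main() {
      std::vector<std::string> args = {"x", "y", "placeholder0", "placeholder1"};
      InsertBeforePlaceholders(&args, "seed_arg", /*num_placeholders=*/2);
      for (const auto& a : args) std::cout << a << " ";
      // Prints: x y seed_arg placeholder0 placeholder1
    }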
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
new file mode 100644
index 0000000000..d1bcf6782d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimization hoists instances of `random_uniform` out of a function
+// with the aim of making it stateless. It creates a new function that takes a
+// random seed as an extra argument and uses `stateless_random_uniform` instead
+// of `random_uniform` to make it stateless.
+// It also creates RandomDataset(seed).batch(2), which is zipped with the old
+// input to the map. The batch of two is needed because
+// `stateless_random_uniform` requires a pair of seeds.
+// TODO(prazek): for now only `RandomUniform` is handled, but we could handle
+// `RandomUniformInt` similarly.
+class HoistRandomUniform : public CustomGraphOptimizer {
+ public:
+ HoistRandomUniform() = default;
+ ~HoistRandomUniform() override = default;
+
+ string name() const override { return "hoist_random_uniform"; }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
new file mode 100644
index 0000000000..455459e3f6
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(HoistRandomUniform, SimpleHoisting) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"},
+ {{"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<DataType>{}}}),
+ graph_tests_utils::MakeMapNode("map1", "range", "RandomUniform"),
+ NDef("cache", "CacheDataset", {"map1", "filename"}, {})},
+ // FunctionLib
+ {
+ test::function::RandomUniform(),
+ });
+
+ HoistRandomUniform optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ const int new_map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output);
+ const int zip_dataset_id =
+ graph_utils::FindGraphNodeWithOp("ZipDataset", output);
+ const int random_dataset_id =
+ graph_utils::FindGraphNodeWithOp("RandomDataset", output);
+ const int batch_random_id =
+ graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output);
+ ASSERT_NE(random_dataset_id, -1);
+ ASSERT_NE(zip_dataset_id, -1);
+ ASSERT_NE(new_map_id, -1);
+ ASSERT_NE(batch_random_id, -1);
+
+ const auto& new_map = output.node(new_map_id);
+ const auto& zip = output.node(zip_dataset_id);
+ const auto& random = output.node(random_dataset_id);
+ const auto& batch = output.node(batch_random_id);
+
+ ASSERT_EQ(new_map.input_size(), 1);
+ EXPECT_EQ(new_map.input(0), zip.name());
+
+ ASSERT_EQ(zip.input_size(), 2);
+ EXPECT_EQ(zip.input(0), "range");
+ EXPECT_EQ(zip.input(1), batch.name());
+
+ ASSERT_EQ(batch.input_size(), 3);
+ EXPECT_EQ(batch.input(0), random.name());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 63945b8b9e..e66766eb23 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -80,11 +80,12 @@ NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
// Set `f` and `Targuments` attributes.
for (auto key : {"f", "Targuments"}) {
- (*new_node.mutable_attr())[key] = map_node.attr().at(key);
+ graph_utils::CopyAttribute(key, map_node, &new_node);
}
+
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
- (*new_node.mutable_attr())[key] = batch_node.attr().at(key);
+ graph_utils::CopyAttribute(key, batch_node, &new_node);
}
return new_node;
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
index f1844a141c..c4868eacbb 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -41,19 +42,18 @@ NodeDef MakeFusedNode(const NodeDef& map_node,
fused_node.set_op("MapDataset");
fused_node.add_input(map_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = map_node.attr().at("f");
attr.mutable_func()->set_name(fused_function.signature().name());
(*fused_node.mutable_attr())["f"] = std::move(attr);
- copy_attribute("Targuments", map_node, &fused_node);
+ graph_utils::CopyAttribute("Targuments", map_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, map_node, &fused_node);
+ graph_utils::CopyAttribute(key, map_node, &fused_node);
+
+ if (const auto* attr =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"))
+ (*fused_node.mutable_attr())["use_inter_op_parallelism"] = *attr;
// Add the predicate output attributes.
(*fused_node.mutable_attr())["output_types"]
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
index f029a093fa..6e6da37d7c 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -27,24 +28,8 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
namespace {
-
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
-
-NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "FilterDataset", {string(input_node_name)},
- {{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeFilterNode;
+using graph_tests_utils::MakeMapNode;
TEST(MapAndFilterFusionTest, FuseMapAndFilter) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
index a78ecb09f7..bd943342e8 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -40,24 +41,31 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node,
NodeDef fused_node;
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(),
&fused_node);
-
fused_node.set_op("MapDataset");
fused_node.add_input(parent_map_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = parent_map_node.attr().at("f");
*attr.mutable_func()->mutable_name() = fused_function.signature().name();
(*fused_node.mutable_attr())["f"] = std::move(attr);
- copy_attribute("Targuments", parent_map_node, &fused_node);
-
+ graph_utils::CopyAttribute("Targuments", parent_map_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, map_node, &fused_node);
+ graph_utils::CopyAttribute(key, map_node, &fused_node);
+ auto value_or_false = [](const AttrValue* attr) {
+ if (!attr) return false;
+ return attr->b();
+ };
+
+ const auto* first_parallelism =
+ gtl::FindOrNull(parent_map_node.attr(), "use_inter_op_parallelism");
+ const auto* second_parallelism =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism");
+ // Some graphs cannot execute with use_inter_op_parallelism=False, so we need
+ // to set it to true if either of the fused ops has it set to true.
+ if (value_or_false(first_parallelism) || value_or_false(second_parallelism)) {
+ (*fused_node.mutable_attr())["use_inter_op_parallelism"].set_b(true);
+ }
return fused_node;
}
@@ -123,8 +131,8 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// fusion.
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function));
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
nodes_to_delete.insert(parent_map_node->name());
nodes_to_delete.insert(map_node->name());
}
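The use_inter_op_parallelism handling above is an OR over two optional
attributes, with an absent attribute read as false. The same rule expressed
with std::optional (C++17) standing in for the AttrValue lookups, purely as an
illustration:

    #include <iostream>
    #include <optional>

    // Mirrors value_or_false plus the fusion rule: the fused node needs
    // inter-op parallelism if either original map node asked for it; an
    // unset attribute reads as false.
    bool MergedInterOpParallelism(std::optional<bool> first,
                                  std::optional<bool> second) {
      return first.value_or(false) || second.value_or(false);
    }

    int main() {
      std::cout << MergedInterOpParallelism(std::nullopt, std::nullopt) << "\n";  // 0
      std::cout << MergedInterOpParallelism(true, std::nullopt) << "\n";          // 1
      std::cout << MergedInterOpParallelism(false, false) << "\n";                // 0
    }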
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
index b25dfbd0b8..8889f9dddd 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -28,14 +29,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeMapNode;
TEST(MapFusionTest, FuseTwoMapNodesIntoOne) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
index 305325e434..782c9f48b7 100644
--- a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
@@ -84,9 +84,6 @@ Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item,
auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph));
graph.ReplaceInput(*map_node, *parallel_map);
-
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
nodes_to_delete.insert(map_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
index b2a5d9b6af..9fdfe8af30 100644
--- a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -28,16 +28,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
- StringPiece function_name) {
- return test::function::NDef(
- name, "MapDataset", {string(input_node_name)},
- {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
-
+using graph_tests_utils::MakeMapNode;
const char stateless_fun_name[] = "XTimesTwo";
const char stateful_fun_name[] = "RandomUniform";
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index 7a2f1910da..a9254ed58b 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -35,10 +35,6 @@ namespace tensorflow {
namespace grappler {
namespace {
-void CopyAttribute(const string& attr_name, const NodeDef& from, NodeDef* to) {
- (*to->mutable_attr())[attr_name] = from.attr().at(attr_name);
-}
-
// Returns a FunctionDef containing a MapDefun op that wraps the original
// function.
FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
@@ -48,7 +44,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
// Function inputs and outputs are the same as original, just
// with different shapes.
*vectorized_func->mutable_signature() = orig_func.signature();
- graph_utils::SetUniqueGraphFunctionName("vectorized_function", library,
+ graph_utils::SetUniqueGraphFunctionName("naively_vectorized_fn", library,
vectorized_func);
// Add MapDefun node
@@ -61,7 +57,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
for (const string& k : {"f", "output_types", "output_shapes"}) {
// Function, output types and (unbatched) shapes are the same as the
// original map node.
- CopyAttribute(k, map_node, map_defun_node);
+ graph_utils::CopyAttribute(k, map_node, map_defun_node);
}
// Get types of input arguments from original map function
@@ -71,6 +67,7 @@ FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
map_defun_node->add_input(input.name());
}
(*map_defun_node->mutable_attr())["Targuments"] = t_args;
+ AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node);
// Set return values to match output names
string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");
@@ -90,21 +87,19 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
// efficient vectorization with VectorizeMapDefun.
FunctionDef* vectorized_func =
CreateMapDefunWrapper(map_node, orig_func, library);
- NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0);
- DCHECK_EQ(map_defun_node->op(), "MapDefun");
-
- // Create a copy of the original function so that we can mutate it, and
- // attach that to the map defun node.
- FunctionDef* map_defun_fn = library->add_function();
- *map_defun_fn = orig_func;
- graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library,
- map_defun_fn);
- (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name(
- map_defun_fn->signature().name());
-
- vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn,
- map_defun_node);
- return vectorized_func;
+ const NodeDef& map_defun_node = vectorized_func->node_def(0);
+ DCHECK_EQ(map_defun_node.op(), "MapDefun");
+
+ // TODO(b/116285210): Unreferenced functions should get cleaned up later
+ FunctionDef* result;
+ Status s = vectorization_utils::VectorizeMapDefun(
+ *vectorized_func, map_defun_node, library, &result);
+
+ if (!s.ok()) {
+ LOG(ERROR) << "VectorizeMapDefun failed: " << s;
+ return vectorized_func;
+ }
+ return result;
}
bool IsOutputShapesFullyDefined(const NodeDef& node) {
@@ -195,13 +190,16 @@ NodeDef MakeNewMapNode(const NodeDef& old_map_node,
}
// Set attrs
- CopyAttribute("Targuments", old_map_node, &map_node);
+ graph_utils::CopyAttribute("Targuments", old_map_node, &map_node);
auto& func_attr = (*map_node.mutable_attr())["f"];
func_attr.mutable_func()->set_name(vectorized_func.signature().name());
for (auto key : {"output_shapes", "output_types"}) {
- CopyAttribute(key, old_batch_node, &map_node);
+ graph_utils::CopyAttribute(key, old_batch_node, &map_node);
}
+
+ (*map_node.mutable_attr())["use_inter_op_parallelism"].set_b(true);
+
return map_node;
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
index ed1bd6bc97..f4faf41549 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
@@ -30,72 +30,51 @@ namespace {
using test::function::GDef;
using test::function::NDef;
-void MakeTensorShapeProtoHelper(const gtl::ArraySlice<int> dims,
- TensorShapeProto* t) {
- for (size_t i = 0; i < dims.size(); ++i) {
- auto* d = t->add_dim();
- d->set_size(dims[i]);
- }
-}
-
-AttrValue MakeShapeListAttr(
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& shapes) {
- AttrValue shapes_attr;
- for (size_t i = 0; i < shapes.size(); ++i) {
- MakeTensorShapeProtoHelper(shapes[i],
- shapes_attr.mutable_list()->add_shape());
- }
-
- return shapes_attr;
-}
-
-NodeDef MakeMapNodeHelper(
- StringPiece name, StringPiece input_node_name, StringPiece function_name,
- StringPiece map_op_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
+NodeDef MakeMapNodeHelper(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name, StringPiece map_op_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
return test::function::NDef(
name, map_op_name, {string(input_node_name)},
{{"f", FunctionDefHelper::FunctionRef(string(function_name))},
{"Targuments", {}},
- {"output_shapes", MakeShapeListAttr(output_shapes)},
+ {"output_shapes", output_shapes},
{"output_types", output_types}});
}
-NodeDef MakeMapNode(
- StringPiece name, StringPiece input_node_name, StringPiece function_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
return MakeMapNodeHelper(name, input_node_name, function_name, "MapDataset",
output_shapes, output_types);
}
-NodeDef MakeBatchNode(
- StringPiece name, StringPiece input_node_name,
- StringPiece input_batch_size_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
- return NDef(name, "BatchDataset",
- {string(input_node_name), string(input_batch_size_name)},
- {{"output_types", output_types},
- {"output_shapes", MakeShapeListAttr(output_shapes)}});
+NodeDef MakeBatchNode(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDataset",
+ {string(input_node_name), string(input_batch_size_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
}
-NodeDef MakeBatchV2Node(
- StringPiece name, StringPiece input_node_name,
- StringPiece input_batch_size_name, StringPiece input_drop_remainder_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
- return NDef(name, "BatchDatasetV2",
- {string(input_node_name), string(input_batch_size_name),
- string(input_drop_remainder_name)},
- {{"output_types", output_types},
- {"output_shapes", MakeShapeListAttr(output_shapes)}});
+NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ StringPiece input_drop_remainder_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDatasetV2",
+ {string(input_node_name), string(input_batch_size_name),
+ string(input_drop_remainder_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
}
-NodeDef MakeRangeNode(StringPiece name, const gtl::ArraySlice<string>& inputs) {
+NodeDef MakeRangeNode(StringPiece name, gtl::ArraySlice<string> inputs) {
return NDef(name, "RangeDataset", inputs,
- {{"output_shapes", MakeShapeListAttr({{}})},
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})},
{"output_types", gtl::ArraySlice<DataType>({DT_INT64})}});
}
@@ -184,7 +163,7 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
item.graph = GDef(
{NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
NDef("input", "InputDataset", {},
- {{"output_shapes", MakeShapeListAttr({{}})}}),
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})}}),
MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
// FunctionLib
@@ -196,6 +175,37 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
}
+TEST(MapVectorizationTest, VectorizeWithFullyDefinedFunction) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ MakeRangeNode("range", {"start", "stop", "step"}),
+ MakeMapNode("map", "range", "Func", {{}}, {DT_INT32}),
+ MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
+ // FunctionLib
+ {FunctionDefHelper::Create(
+ "Func", {"x: int64", "y: int64"}, {"res: int64", "res2: int64"}, {},
+ {{{"o"}, "Mul", {"x", "x"}, {{"T", DT_INT64}}}},
+ {{"res", "o:z"}, {"res2", "o:z"}})});
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
+ 1);
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(),
+ 1);
+ const NodeDef& map_node =
+ output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
+ const NodeDef& batch_node =
+ output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output));
+ EXPECT_EQ(map_node.input(0), batch_node.name());
+ EXPECT_EQ(batch_node.input(0), "range");
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
index cb0ff670e8..99c4afa634 100644
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
@@ -64,7 +64,7 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
- (*new_node.mutable_attr())[key] = repeat_node.attr().at(key);
+ graph_utils::CopyAttribute(key, repeat_node, &new_node);
}
return new_node;
};
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
index 1462cb234d..985d6c6c3a 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -9,13 +9,24 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
VECTORIZER_DEPS = [
":vectorizer_registry",
- "//tensorflow/core/grappler/optimizers/data:function_utils",
+ "//tensorflow/core/grappler/optimizers/data:graph_utils",
] + tf_protos_all()
cc_library(
+ name = "wrapped_tensor",
+ hdrs = ["wrapped_tensor.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ ],
+)
+
+cc_library(
name = "vectorizer",
hdrs = ["vectorizer.h"],
deps = [
+ ":wrapped_tensor",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
index c1739737a0..f445157531 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -14,41 +14,38 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class CastVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
- if (inputs.size() != 1) {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
+ Status s;
+ if (node.num_inputs() != 1) {
return errors::Internal("Cast op should only have one input.");
}
- // Add new Cast node
- NodeDef* new_cast_node = outer_scope->add_node_def();
- *new_cast_node = node;
- new_cast_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", node.name()), outer_scope,
- new_cast_node);
- new_cast_node->set_input(0, inputs[0]);
+ // Add new Cast node with the same op and attrs as the original node
+ auto new_cast_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
- // Add the output mapping to conversion map
- (*conversion_map)[strings::StrCat(node.name(), ":y:0")] =
- strings::StrCat(new_cast_node->name(), ":y:0");
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index, new_cast_node,
+ 0);
+ // Add output mappings
+ outputs->push_back({new_cast_node, 0, true});
return Status::OK();
}
};
REGISTER_VECTORIZER("Cast", CastVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
index 776d3179c5..f1ba741821 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -14,40 +14,38 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
+namespace {
class UnpackVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
- if (inputs.size() != 1) {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
+ Status s;
+ if (node.num_inputs() != 1 || inputs.size() != 1) {
return errors::Internal("Unpack op should only have one input.");
}
- // Add new Unpack node
- NodeDef* new_unpack_node = outer_scope->add_node_def();
- *new_unpack_node = node;
- new_unpack_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", node.name()), outer_scope,
- new_unpack_node);
+ // Add new Unpack node with the same op and attrs as the original node
+ auto new_unpack_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
// Increment "axis" attr by 1:
- (*new_unpack_node->mutable_attr())["axis"].set_i(
- node.attr().at("axis").i() + 1);
- new_unpack_node->set_input(0, inputs[0]);
+ int new_axis = node.def().attr().at("axis").i() + 1;
+ new_unpack_node->AddAttr("axis", new_axis);
- // Add the output mappings to conversion map
- int num = new_unpack_node->attr().at("num").i();
+ outer_scope->AddEdge(inputs[0].node, inputs[0].output_index,
+ new_unpack_node, 0);
+
+ // Add the output mappings
+ int num = node.def().attr().at("num").i();
for (int i = 0; i < num; ++i) {
- (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] =
- strings::StrCat(new_unpack_node->name(), ":output:", i);
+ outputs->push_back({new_unpack_node, i, true});
}
return Status::OK();
@@ -56,6 +54,6 @@ class UnpackVectorizer : public Vectorizer {
REGISTER_VECTORIZER("Unpack", UnpackVectorizer);
-} // namespace vectorization_utils
+} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
index d341dbba7d..8d4676aae0 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -17,13 +17,13 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
// for an example.
@@ -31,19 +31,19 @@ class Vectorizer {
public:
virtual ~Vectorizer() {}
- // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope`
+ // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope`
// that produce the same vector output(s) as executing `node`'s op
- // on elements of the vector inputs, and adding mappings to `conversion_map`
- // from old output tensor names to new (vectorized) output tensor names.
- // The new node(s) collectively have the same number of inputs and outputs as
- // the node being converted, and use the tensor names in `inputs` as their
- // inputs.
- virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) = 0;
+ // on elements of `inputs`. The new Node(s) collectively have the
+ // same number of input and output ports as the node being converted.
+ // Adds edges between the newly created nodes and nodes in `inputs`, and
+ // appends entries to `outputs`, where the i'th entry wraps the new output
+ // port corresponding to the i'th output port of the node being converted.
+ virtual Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) = 0;
};
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
index a6551e36ac..e1cf77a7d5 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
@@ -19,7 +19,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
VectorizerRegistry* VectorizerRegistry::Global() {
static VectorizerRegistry* registry = new VectorizerRegistry;
@@ -42,6 +41,5 @@ void VectorizerRegistry::Register(const string& op_type,
vectorizers_.insert(std::pair<const string&, std::unique_ptr<Vectorizer>>(
op_type, std::move(vectorizer)));
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
index 16159d47ca..ad54c74933 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
@@ -23,7 +23,6 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
// A global VectorizerRegistry is used to hold all the vectorizers.
class VectorizerRegistry {
@@ -59,16 +58,12 @@ class VectorizerRegistration {
#define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \
REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer)
-#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
- static ::tensorflow::grappler::vectorization_utils:: \
- vectorizer_registration::VectorizerRegistration \
- vectorizer_registration_##ctr( \
- op_type, \
- ::std::unique_ptr< \
- ::tensorflow::grappler::vectorization_utils::Vectorizer>( \
- new vectorizer()))
+#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
+ static ::tensorflow::grappler::vectorizer_registration:: \
+ VectorizerRegistration vectorizer_registration_##ctr( \
+ op_type, ::std::unique_ptr<::tensorflow::grappler::Vectorizer>( \
+ new vectorizer()))
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
index 86e303564b..054aeb9a8f 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -20,13 +20,12 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-namespace vectorization_utils {
class TestVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<WrappedTensor>&& inputs,
+ std::vector<WrappedTensor>* outputs) override {
return Status::OK();
}
};
@@ -39,12 +38,14 @@ TEST(TestVectorizer, TestTestVectorizer) {
auto vectorizer = VectorizerRegistry::Global()->Get("test_op");
EXPECT_NE(vectorizer, nullptr);
- FunctionDef function;
- NodeDef node;
- std::map<string, string> conversion_map;
- EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok());
+ Graph g(OpRegistry::Global());
+ NodeDef node_def;
+ Status s;
+ Node* node = g.AddNode(node_def, &s);
+ std::vector<WrappedTensor> inputs, outputs;
+ EXPECT_TRUE(
+ vectorizer->Vectorize(*node, &g, std::move(inputs), &outputs).ok());
}
-} // namespace vectorization_utils
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
new file mode 100644
index 0000000000..4439b4ab4e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/wrapped_tensor.h
@@ -0,0 +1,44 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
+
+#include "tensorflow/core/graph/graph.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// Represents a tensor that has been vectorized.
+struct WrappedTensor {
+ Node* const node;
+ const int output_index;
+
+ // Whether the tensor is stacked, i.e. represents the results of applying
+ // the operation on all slices of the input, where each row i of the
+ // tensor corresponds to the op's output on slice i of the input. False
+ // if the tensor is not stacked, i.e. represents the result of the op on
+ // a single slice of the input, where the result does not vary between
+ // slices.
+ bool stacked;
+
+ WrappedTensor(Node* node, int output_index, bool stacked)
+ : node(node), output_index(output_index), stacked(stacked) {}
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_WRAPPED_TENSOR_H_
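A toy illustration of the stacked flag, with vectors of ints standing in for
tensors: a stacked operand carries one value per input slice, while an
unstacked operand is shared by all slices and is broadcast. The types and
function below are invented for illustration only:

    #include <iostream>
    #include <vector>

    // Toy stand-ins: a "tensor" is a vector of per-slice values when stacked,
    // or a single shared value when not.
    struct Wrapped {
      std::vector<int> values;  // One entry per slice if stacked, else one.
      bool stacked;
    };

    // Vectorized elementwise add over `num_slices` slices: stacked inputs are
    // indexed per slice, unstacked inputs are broadcast to every slice.
    std::vector<int> AddVectorized(const Wrapped& a, const Wrapped& b,
                                   int num_slices) {
      std::vector<int> out;
      for (int i = 0; i < num_slices; ++i)
        out.push_back((a.stacked ? a.values[i] : a.values[0]) +
                      (b.stacked ? b.values[i] : b.values[0]));
      return out;
    }

    int main() {
      Wrapped stacked{{1, 2, 3}, true};  // Varies per slice.
      Wrapped shared{{10}, false};       // Same for every slice.
      for (int v : AddVectorized(stacked, shared, 3)) std::cout << v << " ";
      // Prints: 11 12 13
    }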
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index cb56b65985..ba857ab5d9 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -17,274 +17,588 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/lib/gtl/map_util.h"
-#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
-using function_utils::FunctionDefTensorDesc;
-
namespace {
-void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
- const string& output_retval, const DataType t) {
- // Set to unknown shape
- TensorShapeProto tensor_shape_proto;
- PartialTensorShape().AsProto(&tensor_shape_proto);
+// Describes a tensor with its operation Node and output position
+typedef std::pair<Node*, int> TensorDesc;
- function_utils::AddFunctionOutputWithUniqueName(
- "vectorized_out", output_retval, map_defun_fn, t);
+const char* const kRetValOp = "_Retval";
- *(*map_defun_node->mutable_attr())["output_shapes"]
- .mutable_list()
- ->add_shape() = tensor_shape_proto;
- (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t);
+void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
+ Graph* graph) {
+ // NOTE: We need two for loops here because we can't mutate the set of output
+ // edges as we iterate over them.
+ std::vector<const Edge*> edges_to_replace;
+ for (auto edge : old_src.first->out_edges()) {
+ if (edge->src_output() == old_src.second) {
+ edges_to_replace.push_back(edge);
+ }
+ }
+ for (auto edge : edges_to_replace) {
+ graph->AddEdge(new_src.first, new_src.second, edge->dst(),
+ edge->dst_input());
+ graph->RemoveEdge(edge);
+ }
}
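The two loops in ReplaceEdgeSources follow the usual collect-then-mutate
pattern, since the edge set cannot be modified while it is being iterated. The
same pattern over a plain edge list, as a standalone sketch:

    #include <algorithm>
    #include <iostream>
    #include <vector>

    struct Edge { int src, dst; };

    // Redirects every edge leaving `old_src` to leave `new_src` instead. As in
    // ReplaceEdgeSources, the edges to rewire are collected in a first pass,
    // because erasing and re-adding while iterating would invalidate the loop.
    void ReplaceSources(std::vector<Edge>& edges, int old_src, int new_src) {
      std::vector<Edge> replacements;
      for (const auto& e : edges)
        if (e.src == old_src) replacements.push_back({new_src, e.dst});
      edges.erase(std::remove_if(edges.begin(), edges.end(),
                                 [&](const Edge& e) { return e.src == old_src; }),
                  edges.end());
      edges.insert(edges.end(), replacements.begin(), replacements.end());
    }

    int main() {
      std::vector<Edge> edges = {{0, 1}, {0, 2}, {3, 1}};
      ReplaceSources(edges, /*old_src=*/0, /*new_src=*/7);
      for (const auto& e : edges) std::cout << e.src << "->" << e.dst << "\n";
      // Prints: 3->1, 7->1, 7->2
    }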
-void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node, int output_position) {
- DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
- << "Trying to remove output that doesn't exist. Output number: "
- << output_position;
+Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
+ const TensorDesc& output) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DataType type = output.first->output_type(output.second);
+ int index = map_defun_fn->ret_nodes.size();
- int num_later_outputs =
- map_defun_fn->signature().output_arg_size() - output_position - 1;
+ NodeDef ret_node_def;
+ ret_node_def.set_name("map_out");
+ ret_node_def.set_op(kRetValOp);
+ AddNodeAttr("T", type, &ret_node_def);
+ AddNodeAttr("index", index, &ret_node_def);
- // Remove from map_defun_fn's ret dict and output args
- map_defun_fn->mutable_ret()->erase(
- map_defun_fn->signature().output_arg(output_position).name());
- map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
- output_position, 1);
+ Status s;
+ Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s);
+ TF_RETURN_IF_ERROR(s);
- // Renumber outputs that come after
- for (int i = 0; i < num_later_outputs; ++i) {
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i + 1),
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i),
- outer_scope);
- }
- map_defun_node->mutable_attr()
- ->at("output_shapes")
- .mutable_list()
- ->mutable_shape()
- ->DeleteSubrange(output_position, 1);
- map_defun_node->mutable_attr()
- ->at("output_types")
- .mutable_list()
- ->mutable_type()
- ->ExtractSubrange(output_position, 1, nullptr);
+ map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
+ map_defun_fn->ret_nodes.push_back(ret_node);
+ map_defun_fn->ret_types.push_back(type);
+
+ return s;
}
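+
+// Example (illustrative): promoting the tensor {cast_node, 0} of type
+// DT_INT64 appends a "_Retval" node with attrs T=DT_INT64 and
+// index=<previous ret count>, wired from cast_node:0; the MapDefun node then
+// exposes it as an extra output once its attrs are updated.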
-int FindOutputToConvert(const FunctionDef& function,
- const std::set<string>& unconvertible,
- FunctionDefTensorDesc* f) {
- for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) {
- const string& ret_key = function.signature().output_arg(i).name();
- *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
+ FunctionBody* map_defun_fn, Node* map_defun_node) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
+ << "Trying to remove output that doesn't exist. Output number: "
+ << output_position;
- if (unconvertible.find(f->node_name) == unconvertible.end()) {
- return i;
- }
+ int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1;
+
+ // Modify map_defun_fn's signature and remove the output node from its graph
+ map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]);
+ map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() +
+ output_position);
+ map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
+ output_position);
+
+ // Renumber the nodes and edges that come after
+ for (int i = 0; i < num_later_outputs; ++i) {
+ ReplaceEdgeSources({map_defun_node, output_position + i + 1},
+ {map_defun_node, output_position + i}, outer_scope);
+ // Each ret node has an "index" attr that has to be updated
+ map_defun_fn->ret_nodes[output_position + i]->AddAttr("index",
+ output_position + i);
}
- return -1;
}
// Helper class that vectorizes the body of a MapDefun node, adding new
// operations to the graph that collectively compute the same value as what
// running the MapDefun function on slices of the input would produce.
-// Each instance of the class encapsulates all the data necessary to vectorize a
-// MapDefun op in place.
+// This class transforms the input FunctionDefs into their corresponding
+// Graph objects and works on the graphs directly, then converts them back
+// to FunctionDefs when GetResult is called.
class Vectorization {
public:
- Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node)
- : outer_scope_(outer_scope),
- map_defun_fn_(map_defun_fn),
- map_defun_node_(map_defun_node) {}
+ explicit Vectorization(FunctionDefLibrary* lib)
+ : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {}
- // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
- // the outer_scope_, until there are no convertible outputs remaining.
- // This method is idempotent.
- void Vectorize();
+  // Adds the vectorized function and the new map_defun_fn to `lib`, and
+  // points `result` to the former. Returns an error status if the
+  // FunctionDef -> Graph -> FunctionDef round trip failed anywhere along the
+  // way.
+ Status Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDef** result);
private:
- // Vectorizes the map defun function's output at output_position
- Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
- // Given a descriptor of the original output tensor, gets a string
- // corresponding to the converted output tensor.
- Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
- string* converted);
- Status AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc);
+ // Converts FunctionDefs to Graphs and adds mappings from
+ // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_.
+ Status Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node);
+
+ // Converts Graphs back to FunctionDefs and adds them to `lib_`.
+ Status GetResult(FunctionDef** vectorized_function);
+
+ // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in
+ // `outer_scope_`, until there are no convertible outputs remaining.
+ void VectorizeHelper();
+
+ // Vectorizes map_defun_fn's output at output_position.
+ Status ConvertOutput(int output_position);
// Adds mappings from a node's output tensors to converted output tensors,
// creating the necessary new node(s). Generally, the steps to convert an op
// are:
- // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_,
- // and modify map_defun_node_ attrs accordingly
- // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+ // 1) Create new node(s) in `outer_scope_` that act on batched input tensors.
// These operations collectively compute the same value as what running
// the original operation on slices of the input tensors would produce.
// For example, a Cast op in MapDefun translates to a Cast op in
- // outer_scope_, since the vectorized version of Cast is itself.
- // 3) Set inputs of new node(s) to the corresponding converted inputs (that
- // are now outputs of map_defun_node_)
- // 4) For each output of the old node, add the mapping of output strings to
- // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0")
- Status AddConversionMappingFromOp(const NodeDef& node,
- const FunctionDefTensorDesc& output_desc);
-
- // Maps a tensor name to the name of the corresponding vectorized tensor. For
- // example, "Cast:y:0" -> "Vectorize/Cast:y:0"
- std::map<string, string> conversion_map_;
- // Unconvertible node names
- std::set<string> unconvertible_;
-
- FunctionDef* outer_scope_;
- FunctionDef* map_defun_fn_;
- NodeDef* map_defun_node_;
+ // `outer_scope_`, since the vectorized version of Cast is itself.
+  //      2) Promote the inputs of the op to outputs of the
+  //         `map_defun_node_` and `map_defun_fn_`.
+  //      3) Add edges between the promoted inputs (that are now outputs of
+  //         `map_defun_node_`) and the input ports of the new node(s).
+ // 4) For each output of the old node, add the mapping of output tensors to
+ // the conversion map.
+ Status AddConversionMapping(Node* op_node);
+
+ // Given a tensor t in `unstacked`, stacks it by doing the equivalent of
+ // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of
+ // inputs to `map_defun_node_`. This stacked tensor will be compatible with
+ // the expected output shape of `map_defun_node_`.
+ // This is equivalent to the _stack function in python Pfor.
+ Status StackTensor(WrappedTensor* unstacked, TensorDesc* result);
+
+ // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by
+ // doing a depth-first search from the ret nodes. Lifts nodes that are
+ // unstacked (i.e. don't derive from arg nodes) into `outer_scope_` directly
+  // and adds mappings to `conversion_map_`.
+ Status AddUnstackedNodeMappings();
+
+ // Recursive helper for `AddUnstackedNodeMappings`, returns true if tensor
+ // is unstacked.
+ bool AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, Status* status);
+
+  // Adds mappings from `map_defun_fn_` arg nodes to the corresponding
+  // `map_defun_node_` input nodes in `conversion_map_`.
+ Status AddArgNodeMappings();
+
+ // Maps a tensor to the corresponding WrappedTensor. For example,
+ // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true)
+ std::map<TensorDesc, WrappedTensor> conversion_map_;
+
+ // Unconvertible ret nodes
+ std::set<Node*> unconvertible_;
+
+ FunctionDefLibrary* lib_; // Not owned
+ FunctionLibraryDefinition lib_def_;
+  // Note that FunctionBody has a pointer to a Graph object that corresponds
+  // to the function's subgraph, with additional kArgOp and kRetValOp nodes
+  // that denote the function's arguments and return values. These nodes have
+  // the attrs "T" for the type, and "index" for the argument / retval index,
+  // respectively. FunctionBody also keeps track of arg/ret_nodes and
+  // arg/ret_types, which should be ordered according to argument/output
+  // indices.
+ std::unique_ptr<Graph> outer_scope_;
+ std::unique_ptr<FunctionBody> map_defun_fn_;
+ Node* map_defun_node_ = nullptr; // Owned by `outer_scope`
+
+  // Caches the loop_len_node_ needed for tiling unstacked outputs. This
+  // node produces a vector with one element: the loop length.
+ Node* loop_len_node_ = nullptr; // Owned by `outer_scope`
+ Status status_;
};
-Status Vectorization::AddConversionMappingFromOp(
- const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
- for (const string& input_name : node.input()) {
- if (IsControlInput(input_name)) {
+Status Vectorization::AddConversionMapping(Node* op_node) {
+ for (auto edge : op_node->in_edges()) {
+ if (edge->IsControlEdge()) {
return errors::InvalidArgument(
"Vectorizing outputs with control inputs is currently not "
"supported.");
}
}
- // TODO(rachelim): Have some mechanism for registering converters and some
- // uniform, simpler way to represent them.
+ auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string());
+ if (vectorizer == nullptr) {
+ return errors::Unimplemented("No vectorizer registered for op: ",
+ op_node->type_string());
+ }
+ std::vector<WrappedTensor> inputs, outputs;
+ inputs.reserve(op_node->num_inputs());
+ outputs.reserve(op_node->num_outputs());
+
+ std::vector<const Edge*> input_edges;
+ TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
+
+ // The inputs for the node to be converted may already have been converted
+ // themselves. For those that are not, we promote them to MapDefun outputs.
+ for (size_t i = 0; i < op_node->num_inputs(); ++i) {
+ auto edge = input_edges[i];
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ inputs.push_back(*found);
+ } else {
+ // TODO(rachelim): Handle the case where unconverted inputs are unstacked.
+ // We assume that all unconverted inputs will be stacked, since we
+ // converted all unstacked nodes in `Initialize`. However, it's actually
+ // possible that yet-unconverted nodes may produce unstacked outputs after
+ // they are vectorized. (For example, see the "Shape" converter in
+ // tensorflow/python/ops/parallel_for/pfor.py). If a vectorizer expects
+ // an unstacked input but receives a stacked one, vectorizer->Vectorize
+ // will return an error.
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ int output_index = map_defun_fn_->ret_nodes.size() - 1;
+ inputs.push_back({map_defun_node_, output_index, true});
+ }
+ }
- DataTypeVector types;
- const OpDef* op_def = nullptr;
- TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
- TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ std::move(inputs), &outputs));
- std::vector<string> promoted_inputs;
- promoted_inputs.reserve(node.input_size());
- for (int i = 0; i < node.input_size(); ++i) {
- promoted_inputs.push_back(strings::StrCat(
- map_defun_node_->name(),
- ":output:", map_defun_fn_->signature().output_arg_size() + i));
+ if (op_node->num_outputs() != outputs.size()) {
+ return errors::Internal(
+ "Number of vectorizer outputs does not match. Expected: ",
+ op_node->num_outputs(), " Actual: ", outputs.size());
}
- auto vectorizer = VectorizerRegistry::Global()->Get(node.op());
- if (vectorizer == nullptr) {
- return errors::Unimplemented("No vectorizer registered for op: ",
- node.op());
+ // Add output mappings.
+ for (size_t i = 0; i < op_node->num_outputs(); ++i) {
+ conversion_map_.insert({{op_node, i}, outputs[i]});
}
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_,
- &conversion_map_));
+ return Status::OK();
+}
+
+Status Vectorization::ConvertOutput(int output_position) {
+ // ret_edge->src() is the actual op that generated the retval, and
+ // ret_edge->dst() is the retval node whose op is "_Retval"
+ const Edge* ret_edge;
+ TF_RETURN_IF_ERROR(
+ map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge));
+
+ TensorDesc output({ret_edge->src(), ret_edge->src_output()});
+ TensorDesc converted_output;
+
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ auto found = gtl::FindOrNull(conversion_map_, output);
+ if (!found) {
+ TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
+ found = &conversion_map_.at(output);
+ }
- // If we get here, the conversion was successful, so we promote the inputs
- // of the ops to MapDefun outputs.
- for (int i = 0; i < types.size(); ++i) {
- AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+ if (found->stacked) {
+ converted_output = {found->node, found->output_index};
+ } else {
+ // Some outputs may be unstacked if they don't derive from arg nodes
+ // (for example, if a function returns a constant). For these, we
+    // have to add extra nodes to tile them in the 0th dimension.
+ TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
}
+ ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
+ outer_scope_.get());
+ RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(),
+ map_defun_node_);
+
return Status::OK();
}
-Status Vectorization::AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc) {
- int input_index = function_utils::FindFunctionInputWithName(
- output_desc.node_name, *map_defun_fn_);
- if (input_index == -1) {
- return errors::Internal("Cannot convert non-existent input.");
+Status Vectorization::Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node,
+ FunctionDef** result) {
+ TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node));
+ VectorizeHelper();
+ return GetResult(result);
+}
+
+void Vectorization::VectorizeHelper() {
+ while (true) {
+ int output_position = graph_utils::GetFirstElementIndexWithPredicate(
+ [this](Node* n) {
+ return this->unconvertible_.find(n) == this->unconvertible_.end();
+ },
+ map_defun_fn_->ret_nodes);
+
+ // No outputs left to convert
+ if (output_position == -1) break;
+
+ Status s = ConvertOutput(output_position);
+ if (!s.ok()) {
+ Node* output_node = map_defun_fn_->ret_nodes.at(output_position);
+ VLOG(2) << "Could not convert the output at node: "
+ << output_node->DebugString() << "\nError: " << s;
+ unconvertible_.insert(output_node);
+ }
+ }
+
+ // If we've converted all the outputs of the MapDefun function, we no longer
+ // need the MapDefun node and can delete it.
+ if (map_defun_fn_->ret_nodes.empty()) {
+ outer_scope_->RemoveNode(map_defun_node_);
+ } else {
+ // Update MapDefun node attrs accordingly
+ DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
+ map_defun_node_->AddAttr(
+ "output_shapes",
+ std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
+ map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
+ }
+}
+
+Status Vectorization::Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node) {
+  // Convert outer_scope and map_defun_fn to FunctionBody objects so we can
+ // work on Graphs directly.
+ const FunctionDef* map_defun_fn =
+ lib_def_.Find(map_defun_node.attr().at("f").func().name());
+
+ if (map_defun_fn == nullptr) {
+ return errors::NotFound("Could not find function with name ",
+ map_defun_node.attr().at("f").func().name(),
+ " in function library.");
+ }
+
+ auto get_func_sig = [this](const string& op, const OpDef** sig) {
+ return this->lib_def_.LookUpOpDef(op, sig);
+ };
+
+ FunctionBody* outer_fn;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_,
+ get_func_sig, &outer_fn));
+ // We don't need outer_fn, just the graph
+ outer_scope_.reset(outer_fn->graph);
+ outer_fn->graph = nullptr;
+ delete outer_fn;
+
+ FunctionBody* tmp;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_,
+ get_func_sig, &tmp));
+ map_defun_fn_.reset(tmp);
+
+ // Find the MapDefun node in outer_scope_
+ int node_id = graph_utils::GetFirstElementIndexWithPredicate(
+ [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); },
+ outer_scope_->nodes());
+ if (node_id == -1) {
+ return errors::NotFound("Could not find node with name ",
+ map_defun_node.name(), " in outer_scope.");
}
+ map_defun_node_ = outer_scope_->FindNodeId(node_id);
+
+ TF_RETURN_IF_ERROR(AddArgNodeMappings());
+
+ TF_RETURN_IF_ERROR(AddUnstackedNodeMappings());
+ loop_len_node_ = nullptr;
- conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
return Status::OK();
}
-Status Vectorization::ConvertOutputHelper(
- const FunctionDefTensorDesc& output_desc, string* converted) {
- // It's possible the output already has a mapping, if it comes from a node
- // that has already been converted.
- if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
- *converted = *found;
- return Status::OK();
+// TODO(rachelim): It might be profitable to use the C++ API for this instead of
+// NodeBuilder
+Status Vectorization::StackTensor(WrappedTensor* unstacked,
+ TensorDesc* result) {
+ // Note that all these nodes are necessary as the size of the batch may not be
+ // constant.
+ if (unstacked->stacked) {
+ return errors::Internal("Can only stack unstacked tensor.");
}
- int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
- *map_defun_fn_);
- if (index == -1) { // The output comes from an input
- TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
- } else {
- TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
- map_defun_fn_->node_def(index), output_desc));
+ Graph* g = outer_scope_.get();
+ auto node_builder = [](StringPiece op) {
+ return NodeBuilder(strings::StrCat("vectorized/stack/", op), op);
+ };
+
+ auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph,
+ Node** result) {
+ TF_RETURN_IF_ERROR(val.status);
+ return node_builder("Const")
+ .Attr("value", val.tensor)
+ .Attr("dtype", val.tensor.dtype())
+ .Finalize(graph, result);
+ };
+
+ // If loop_len_node_ hasn't been created yet, add the node and cache it.
+ if (loop_len_node_ == nullptr) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node));
+
+ Node* shape_node;
+ TF_RETURN_IF_ERROR(
+ node_builder("Shape").Input(input_node).Finalize(g, &shape_node));
+
+ Node* const_vec_0;
+ TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0));
+ Node* const_vec_1;
+ TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1));
+
+ Node* strided_slice_node;
+ TF_RETURN_IF_ERROR(node_builder("StridedSlice")
+ .Input(shape_node) // input
+ .Input(const_vec_0) // begin
+ .Input(const_vec_1) // end
+ .Input(const_vec_1) // strides
+ .Finalize(g, &strided_slice_node));
+
+ // Produces a vector of length 1
+ TF_RETURN_IF_ERROR(node_builder("Reshape")
+ .Input(strided_slice_node) // tensor
+ .Input(const_vec_1) // shape
+ .Finalize(g, &loop_len_node_));
}
- *converted = conversion_map_.at(output_desc.full_str);
+
+ Node* ones_shape;
+ TF_RETURN_IF_ERROR(node_builder("Shape")
+ .Input(unstacked->node) // input
+ .Finalize(g, &ones_shape));
+
+ Node* ones;
+ TF_RETURN_IF_ERROR(
+ node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones));
+
+ Node* const_0;
+ TF_RETURN_IF_ERROR(make_const(0, g, &const_0));
+
+ Node* multiples;
+ TF_RETURN_IF_ERROR(node_builder("Concat")
+ .Input(const_0) // concat_dim
+ .Input({{loop_len_node_, 0}, {ones, 0}}) // values
+ .Finalize(g, &multiples));
+
+ Node* expand_dims;
+ TF_RETURN_IF_ERROR(node_builder("ExpandDims")
+ .Input(unstacked->node) // input
+ .Input(const_0) // dim
+ .Finalize(g, &expand_dims));
+
+ TF_RETURN_IF_ERROR(node_builder("Tile")
+ .Input(expand_dims) // input
+ .Input(multiples) // multiples
+ .Finalize(g, &result->first));
+ result->second = 0;
return Status::OK();
}
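+
+// For intuition, the subgraph built by StackTensor computes, in TF Python
+// pseudocode (a sketch of the emitted ops, not code from this file):
+//   loop_len  = tf.reshape(tf.shape(map_defun_input0)[0:1], [1])  # cached
+//   multiples = tf.concat([loop_len, tf.ones_like(tf.shape(t))], 0)
+//   stacked   = tf.tile(tf.expand_dims(t, 0), multiples)
+// where `t` is the unstacked tensor and `map_defun_input0` is input 0 of the
+// MapDefun node.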
-Status Vectorization::ConvertOutput(int output_position,
- const FunctionDefTensorDesc& output_desc) {
- string converted_output_name;
- TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+Status Vectorization::AddArgNodeMappings() {
+ for (auto arg_node : map_defun_fn_->arg_nodes) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(
+ arg_node->attrs().Find("index")->i(), &input_node));
- // Remove the old output and make everything that referenced it point
- // to the new string
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node_->name(), ":output:", output_position),
- converted_output_name, outer_scope_);
- RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
- output_position);
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0, true}});
+ // Control inputs
+ conversion_map_.insert({{arg_node, Graph::kControlSlot},
+ {input_node, Graph::kControlSlot, true}});
+ }
return Status::OK();
}
-void Vectorization::Vectorize() {
- while (true) {
- FunctionDefTensorDesc desc;
- int output_position =
- FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
- if (output_position == -1) break;
+bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
+ Status* status) {
+ if (auto found = gtl::FindOrNull(conversion_map_, tensor)) {
+ return !found->stacked;
+ }
- if (!ConvertOutput(output_position, desc).ok()) {
- unconvertible_.insert(desc.node_name);
+ if (tensor.first->op_def().is_stateful()) {
+ // We don't lift stateful nodes directly out of the MapDefun, since they may
+ // have to be executed N times.
+ return false;
+ }
+
+ bool is_unstacked = true;
+ for (auto edge : tensor.first->in_edges()) {
+ // Ignore Source nodes. Note that these are also ignored in the
+ // GraphToFunctionDef conversion.
+ if (edge->src()->IsSource()) continue;
+
+ // A node is unstacked if all of its inputs are unstacked
+ is_unstacked &= AddUnstackedNodeMappingsHelper(
+ {edge->src(), edge->src_output()}, status);
+ }
+
+ if (!is_unstacked) {
+ return false;
+ }
+
+  // If the node is unstacked, we copy it into outer_scope_ and
+  // add it to the map. Note that we don't clean up the original nodes in
+  // map_defun_fn_ after copying; we rely on them being pruned out later.
+ Node* node = outer_scope_->AddNode(tensor.first->def(), status);
+ if (!status->ok()) return true;
+
+ // Add input edges to nodes that should already have been lifted.
+ for (auto edge : tensor.first->in_edges()) {
+ // Ignore Source nodes. Note that these are also ignored in the
+ // GraphToFunctionDef conversion.
+ if (edge->src()->IsSource()) continue;
+
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ outer_scope_->AddEdge(found->node, found->output_index, node,
+ edge->dst_input());
+ } else {
+ status->Update(errors::Internal(
+ "Could not find input conversion even though we did depth first "
+ "conversion."));
}
}
- // If we've converted all the outputs of the MapDefun function, we no longer
- // need the MapDefun node and can delete it.
- if (map_defun_fn_->signature().output_arg_size() == 0) {
- outer_scope_->mutable_node_def()->DeleteSubrange(
- function_utils::FindFunctionNodeWithName(map_defun_node_->name(),
- *outer_scope_),
- 1);
+ // Add output mappings
+ for (int i = 0; i < tensor.first->num_outputs(); ++i) {
+ conversion_map_.insert({{tensor.first, i}, WrappedTensor(node, i, false)});
+ }
+ conversion_map_.insert({{tensor.first, Graph::kControlSlot},
+ WrappedTensor(node, Graph::kControlSlot, false)});
+
+ return true;
+}
+
+Status Vectorization::AddUnstackedNodeMappings() {
+ SetVector<Node*> unstacked_nodes;
+ Status s;
+ for (const auto& ret_node : map_defun_fn_->ret_nodes) {
+ const Edge* in_edge = nullptr;
+ TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge));
+ AddUnstackedNodeMappingsHelper({in_edge->src(), in_edge->src_output()}, &s);
+ TF_RETURN_IF_ERROR(s);
}
+ return Status::OK();
+}
- if (!unconvertible_.empty()) {
- VLOG(2) << "The following nodes could not be converted: ["
- << absl::StrJoin(unconvertible_, ", ") << "].";
+Status Vectorization::GetResult(FunctionDef** vectorized_function) {
+ TF_RETURN_IF_ERROR(status_);
+ TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(outer_scope_.get()));
+ TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(map_defun_fn_->graph));
+
+ if (!map_defun_fn_->ret_nodes.empty()) {
+ FunctionDef* map_defun_fn = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn));
+
+ AttrValue func_attr;
+ func_attr.mutable_func()->set_name(map_defun_fn->signature().name());
+ map_defun_node_->AddAttr("f", func_attr);
}
+
+ *vectorized_function = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_,
+ *vectorized_function);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *outer_scope_, (*vectorized_function)->signature().name(),
+ *vectorized_function));
+ return Status::OK();
}
+
} // namespace
-void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node) {
- Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result) {
+ *result = nullptr;
+ return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result);
}
} // end namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
index bb405faa77..bd7d390900 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -24,22 +24,28 @@ namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
-// Given a function, `map_defun_fn`, that is mapped across some input vector
-// elements via a MapDefun operation, `VectorizeMapDefun` attempts to
-// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the
-// `outer_scope`; that is, replacing `map_defun_fn` operations with new
-// `outer_scope` operations that produce the same vector output(s) as executing
-// the `map_defun_fn` operations on elements of vector input(s) would. If all
-// `map_defun_fn` operations are successfully lifted, `map_defun_node` is
-// eliminated from `outer_scope` altogether. However, if some operations cannot
-// be lifted, and this vectorization only succeeds partially, `map_defun_node`
-// remains to be used for operations that were not lifted.
+// Given a MapDefun node (`map_defun_node`) in a FunctionDef (`outer_scope`)
+// that maps a function in lib across some input vector elements,
+// `VectorizeMapDefun` attempts to create a vectorized version of `outer_scope`
+// by "lifting" operations from the MapDefun function to the new function
+// (`result`); that is, replacing operations in the MapDefun function with
+// operations that produce the same vector output(s) as executing the original
+// operations on elements of vector input(s) would. If all operations in the
+// MapDefun function are successfully lifted, `result` contains no MapDefun
+// node at all. However, if some operations cannot be lifted, and this
+// vectorization only succeeds partially, a MapDefun node remains in `result` to
+// be used for operations that were not lifted, and the modified MapDefun
+// function is added to `lib`. The newly vectorized function `result` is also
+// added to `lib`.
+//
+// Returns Status::OK() if the vectorization is completely or partially
+// successful. Otherwise, returns an error, and sets `result` to nullptr.
//
// Example:
// If the input to the `VectorizeMapDefun` function is a MapDefun
// whose `map_defun_fn` performs the Cast operation, the vectorization will
// eliminate the MapDefun. This is because the Cast operation supports
-// any tensor shape and can thus be lifted to the `outer_scope`.
+// any tensor shape and can thus be lifted to `result`.
//
// Before:
//
@@ -68,7 +74,7 @@ namespace vectorization_utils {
//
// After:
//
-// outer_scope +------+
+// result +------+
// +---------------+ Arg0 +---------+
// | +---+--+ |
// | | |
@@ -80,8 +86,9 @@ namespace vectorization_utils {
// +---------------+ Ret0 +---------+
// +------+
//
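+// Example usage (a sketch; assumes `outer_fn`, `map_defun_node`, and `lib`
+// were obtained elsewhere, e.g. from a GraphDef's function library):
+//
+//   FunctionDef* vectorized = nullptr;
+//   TF_RETURN_IF_ERROR(vectorization_utils::VectorizeMapDefun(
+//       outer_fn, map_defun_node, lib, &vectorized));
+//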
-void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node);
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result);
} // end namespace vectorization_utils
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index e129fa9237..a6020e36bb 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
@@ -54,12 +55,18 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
func.set_name(function_name);
NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
+ graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node);
graph_transforms::SetNodeAttr("output_types", output_types, node);
graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
graph_transforms::SetNodeAttr("f", func, node);
return node;
}
+string GetRetval(const FunctionDef& function_def, int index) {
+ return function_def.ret().at(
+ function_def.signature().output_arg(index).name());
+}
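+
+// For example, in the tests below the outer functions map their first return
+// value to "MapDefun:output:0", so GetRetval(outer, 0) == "MapDefun:output:0"
+// before vectorization.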
+
// TODO(rachelim): Use FunctionDefHelper::Create instead
FunctionDef CreateFunction(
StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
@@ -85,7 +92,6 @@ FunctionDef CreateFunction(
return func;
}
-TEST(FunctionDefInputDescTest, ConstructedCorrectly) {}
// Before:
//
@@ -133,10 +139,17 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- EXPECT_EQ(outer.ret().at("mapdefun"), "ret0");
- EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1");
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+  EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_EQ(GetRetval(*vectorized, 0), "ret0");
+ EXPECT_EQ(GetRetval(*vectorized, 1), "ret1");
}
// Before:
@@ -149,12 +162,12 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// | +-----------+ Arg0 +---+ Arg1 +----+ |
// | | +---+--+ +---+--+ | |
// | | | | | |
-// | | +------+ | +---v--+ | |
-// | | |Const | | | Op0 | | |
-// | | +---v--+ | +---+--+ | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
// | | | | | | |
// | | | +---v--+ +---v--+ | |
-// | | +---| XOp1 | | XOp2 | | |
+// | | +---| XOp1 | | Cast | | |
// | | +---+--+ +---+--+ | |
// | | | | | |
// | | MapDefun +---v--+ +---v--+ | |
@@ -165,23 +178,50 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// +---------------+ Ret0 +---+ Ret1 +--------+
// +------+ +------+
//
-// where XOp1 and XOp2 are not convertible.
+// where XOp1 is not convertible.
//
// After:
//
-// No change because the ops are not convertible.
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ | |
+// | +-----------+ Arg0 +-+ | |
+// | | +---+--+ | | |
+// | | | | | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
+// | | | | | | |
+// | | | +---v--+ | +---v--+ |
+// | | +---| XOp1 | | | Cast | |
+// | | +---+--+ | +---+--+ |
+// | | | | | |
+// | | MapDefun +---v--+ | | |
+// | +-----------+ Ret0 +-+ | |
+// | +---+--+ | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
//
TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
{{"ret0", DT_INT32}, {"ret1", DT_INT32}},
- {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}});
+ {{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}});
+ // TODO(rachelim): If we ever write a converter for MatMul, we have to
+ // change this test.
NodeDef* x_op1 =
- function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner);
+ function_utils::AddNode("MatMul", "MatMul", {"arg0", "arg0"}, {}, &inner);
CHECK_NOTNULL(x_op1);
+ graph_transforms::SetNodeAttr("T", DT_INT32, x_op1);
- NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner);
- CHECK_NOTNULL(x_op2);
+ NodeDef* cast_node =
+ AddCastNode("Cast", {"arg1"}, DT_INT32, DT_INT32, false, &inner);
+ CHECK_NOTNULL(cast_node);
FunctionDef outer = CreateFunction(
"outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
@@ -193,12 +233,22 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
- // They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
- EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto map_defun_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized));
+ // The Cast node should be converted just fine.
+ EXPECT_EQ(GetRetval(*vectorized, 1), "Cast:y:0");
+
+ // The inner function should only have one retval.
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+ EXPECT_EQ(map_defun_fn->signature().output_arg_size(), 1);
}
// Before:
@@ -257,14 +307,19 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -330,16 +385,21 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -411,21 +471,26 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), "x");
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -486,7 +551,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{"ret1", "MyUnstack:output:1"},
{"ret2", "MyUnstack:output:2"}});
NodeDef* cast_op =
- AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT32, false, &inner);
CHECK_NOTNULL(cast_op);
NodeDef* unstack_op =
AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner);
@@ -505,25 +570,30 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 2);
+ EXPECT_EQ(vectorized->node_def_size(), 2);
}
// Before:
@@ -561,9 +631,11 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}},
{{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
- // The attrs aren't relevant
- NodeDef* print_op =
- function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner);
+ NodeDef* print_op = function_utils::AddNode(
+ "Print", "Print", {"arg0", "arg0"}, {/*attrs*/}, &inner);
+ graph_transforms::SetNodeAttr("T", DT_INT32, print_op);
+ graph_transforms::SetNodeAttr("U", gtl::ArraySlice<DataType>({DT_INT32}),
+ print_op);
CHECK_NOTNULL(print_op);
NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
false, &inner);
@@ -578,11 +650,278 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
// They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+  // We check this somewhat manually, as node names may have changed
+ EXPECT_EQ(vectorized->node_def_size(), 1);
+ const NodeDef& map_defun_node = vectorized->node_def(0);
+ EXPECT_EQ(map_defun_node.op(), "MapDefun");
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+
+ const NodeDef& print_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Print", *map_defun_fn));
+ const NodeDef& cast_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *map_defun_fn));
+ string control_input = strings::StrCat("^", print_node.name());
+ EXPECT_TRUE(cast_node.input(0) == control_input ||
+ cast_node.input(1) == control_input);
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | | |
+// | | +------+ | |
+// | | |Const | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | +------+ |
+// | |Const | |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// | |Stack*| |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeConst) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2)},
+ {{"ret0", "Const:output:0"}});
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized));
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | | |
+// | | +------+ | |
+// | | |Const | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | +------+ |
+// | |Const | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | |Stack*| |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2)},
+ {{"ret0", "Cast:y:0"}});
+ AddCastNode("Cast", {"Const:output:0"}, DT_INT32, DT_INT64, false, &inner);
+
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ auto const_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Const", *vectorized));
+ auto cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')),
+ const_node.name());
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | +------+ +------+ | |
+// | | |Const | |Const | | |
+// | | +---+--+ +---+--+ | |
+// | | : +---v--+ | |
+// | | ::::::> Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | |
+// | +------+ |
+// | +------+ |Const | |
+// | |Const | +---+--+ |
+// | +---+--+ | |
+// | : +---v--+ |
+// | ::::::> Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +Stack*+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2),
+ FunctionDefHelper::Const("ConstDep", 3)},
+ {{"ret0", "Cast:y:0"}});
+ AddCastNode("Cast", {"Const:output:0", "^ConstDep"}, DT_INT32, DT_INT64,
+ false, &inner);
+
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto find_const = [vectorized](int val) -> const NodeDef* {
+ for (const auto& n : vectorized->node_def()) {
+ if (n.attr().at("value").tensor().int_val(0) == val) {
+ return &n;
+ }
+ }
+ return nullptr;
+ };
+
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ auto const_node = find_const(2);
+ auto const_dep_node = find_const(3);
+ auto cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')),
+ const_node->name());
+ EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name()));
}
// TODO(rachelim): More test cases when we get around to implementing them:
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index c59645e5f2..3f33b16ba8 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -106,7 +107,8 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("scoped_allocator",
new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
cfg_.scoped_allocator_opts()));
- MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
+ MK_OPT("pin_to_host",
+ new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
return std::unique_ptr<GraphOptimizer>();
}
@@ -115,6 +117,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
Status MetaOptimizer::InitializeOptimizers(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
+ if (cfg_.disable_meta_optimizer()) {
+ return Status::OK();
+ }
if (!cfg_.disable_model_pruning()) {
optimizers->push_back(MakeUnique<ModelPruner>());
}
@@ -172,11 +177,12 @@ Status MetaOptimizer::InitializeOptimizers(
optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
}
- return InitializeCustomGraphOptimizers(optimizers);
+ return InitializeCustomGraphOptimizers(std::set<string>(), optimizers);
}
Status MetaOptimizer::InitializeOptimizersByName(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
+ std::set<string> initialized_custom_optimizers;
for (const string& optimizer_name : cfg_.optimizers()) {
auto optimizer = MakeNewOptimizer(optimizer_name);
if (optimizer) {
@@ -190,18 +196,26 @@ Status MetaOptimizer::InitializeOptimizersByName(
if (custom_optimizer) {
VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
- TF_RETURN_IF_ERROR(custom_optimizer->Init());
+ TF_RETURN_IF_ERROR(custom_optimizer->Init(
+ GetCustomGraphOptimizerConfig(optimizer_name)));
optimizers->push_back(std::move(custom_optimizer));
+ initialized_custom_optimizers.insert(optimizer_name);
} else {
VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
}
}
- return InitializeCustomGraphOptimizers(optimizers);
+ return InitializeCustomGraphOptimizers(initialized_custom_optimizers,
+ optimizers);
}
Status MetaOptimizer::InitializeCustomGraphOptimizers(
+ const std::set<string>& pre_initialized_optimizers,
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
for (const auto& optimizer_config : cfg_.custom_optimizers()) {
+ if (pre_initialized_optimizers.find(optimizer_config.name()) !=
+ pre_initialized_optimizers.end()) {
+ continue;
+ }
// Initialize the ExperimentalImplementationSelector here instead of through
// the CustomGraphOptimizer registry, due to the static link issue in TensorRT
// that causes double registration.
@@ -237,6 +251,16 @@ Status MetaOptimizer::InitializeCustomGraphOptimizers(
return Status::OK();
}
+const RewriterConfig::CustomGraphOptimizer*
+MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const {
+ for (const auto& config : cfg_.custom_optimizers()) {
+ if (config.name() == name) {
+ return &config;
+ }
+ }
+ return nullptr;
+}
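+
+// For example (illustrative), given a RewriterConfig with
+//   custom_optimizers { name: "TestOptimizerWithParams"
+//                       parameter_map { key: "foo" value {} } }
+// GetCustomGraphOptimizerConfig("TestOptimizerWithParams") returns a pointer
+// to that entry, which InitializeOptimizersByName then passes to the
+// optimizer's Init().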
+
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
@@ -391,6 +415,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
FunctionLibraryDefinition flib(OpRegistry::Global(),
optimized_graph->library());
+ // Find functions for which we might need to compute a gradient at runtime.
+ gtl::FlatSet<string> differentiable_functions;
+ for (const NodeDef& node : optimized_graph->node()) {
+ if (IsSymbolicGradient(node)) {
+ const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
+ if (f_attr) differentiable_functions.insert(f_attr->func().name());
+ }
+ }
+
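+  // For example (illustrative), a graph node
+  //   node { op: "SymbolicGradient"
+  //          attr { key: "f" value { func { name: "MyFn" } } } }
+  // marks "MyFn" as differentiable, so non-differentiable rewrites are
+  // disabled for its body below.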
// Optimize each function only once.
std::unordered_set<string> optimized_funcs;
bool optimize_function_library = true;
@@ -406,6 +439,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Skip parametrized functions (function type or body is defined only at
// function call time by caller node attributes).
+ // They should be specialized to their instantiation type parameters by
+    // the function optimizer, before we can optimize the function body.
if (IsParametrized(func)) continue;
VLOG(3) << "Optimize function: function=" << func_name;
@@ -420,6 +455,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
func, flib, item.graph.versions().producer(), &func_item));
+    // If we need to compute the gradient of the optimized function at
+    // runtime, we can't perform non-differentiable rewrites.
+ if (differentiable_functions.find(func_name) !=
+ differentiable_functions.end()) {
+ func_item.allowed_optimizations.non_differentiable_rewrites = false;
+ }
+
// Optimize function body graph.
GraphDef optimized_func_graph;
TF_RETURN_IF_ERROR(
@@ -470,6 +512,9 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
}
bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
+ if (cfg.disable_meta_optimizer()) {
+ return false;
+ }
return !cfg.disable_model_pruning() ||
cfg.layout_optimizer() != RewriterConfig::OFF ||
cfg.function_optimization() != RewriterConfig::OFF ||
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 831c5e37c0..99a0a33ffa 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -54,7 +54,11 @@ class MetaOptimizer : public GraphOptimizer {
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
// Initialize active optimizers from RewriterConfig.custom_optimizers.
Status InitializeCustomGraphOptimizers(
+ const std::set<string>& pre_initialized_optimizers,
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
+ // Returns the config for a custom graph optimizer. Null if none was found.
+ const RewriterConfig::CustomGraphOptimizer* GetCustomGraphOptimizerConfig(
+ const string& name) const;
// Run optimization pass over a single GrapplerItem. Meta optimizer might run
// multiple such passes: 1) for the main graph 2) for the function library
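
What the new pre_initialized_optimizers parameter buys, in isolation: custom optimizers the caller has already constructed are skipped by name, so nothing is initialized twice. A plain-C++ sketch with invented names (hypothetical, illustration only):

#include <cassert>
#include <set>
#include <string>
#include <vector>

// Returns the configured optimizer names that still need initialization,
// skipping any the caller initialized itself (cf. the `continue` at the top
// of this diff's InitializeCustomGraphOptimizers hunk).
std::vector<std::string> InitRemaining(
    const std::set<std::string>& pre_initialized,
    const std::vector<std::string>& configured) {
  std::vector<std::string> initialized;
  for (const auto& name : configured) {
    if (pre_initialized.count(name)) continue;  // already done by the caller
    initialized.push_back(name);
  }
  return initialized;
}

int main() {
  auto done = InitRemaining({"TensorRTOptimizer"},
                            {"TensorRTOptimizer", "MyCustomOptimizer"});
  assert(done.size() == 1 && done[0] == "MyCustomOptimizer");
}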
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index e74e0f7501..3f3f43382f 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -71,6 +72,59 @@ class TestGraphOptimizer : public TestOptimizer {
REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer);
+class TestOptimizerWithParams : public TestOptimizer {
+ public:
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ CHECK(config != nullptr);
+ return Status::OK();
+ }
+};
+
+REGISTER_GRAPH_OPTIMIZER(TestOptimizerWithParams);
+
+// Record various properties of the GrapplerItems passed for optimization.
+class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer {
+ public:
+ static void SetAllowedOptimizations(
+ gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ allowed_optimizations) {
+ allowed_optimizations_ = allowed_optimizations;
+ }
+ static void ResetAllowedOptimizations() { allowed_optimizations_ = nullptr; }
+
+ GrapplerItemPropertiesAccumulator() {}
+ string name() const override {
+ return "grappler_item_properties_accumulator";
+ }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override {
+ *optimized_graph = item.graph;
+ if (allowed_optimizations_) {
+ allowed_optimizations_->insert({item.id, item.allowed_optimizations});
+ }
+ return Status::OK();
+ }
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ static gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ allowed_optimizations_;
+};
+
+gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ GrapplerItemPropertiesAccumulator::allowed_optimizations_;
+
+REGISTER_GRAPH_OPTIMIZER(GrapplerItemPropertiesAccumulator);
+
class MetaOptimizerTest : public GrapplerTest {};
TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
@@ -90,6 +144,25 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
EXPECT_TRUE(TestOptimizer::IsOptimized());
}
+TEST_F(MetaOptimizerTest, RunsCustomOptimizerWithParams) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ TestOptimizer::SetOptimized(false);
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("TestOptimizerWithParams");
+ auto* custom_config = rewriter_config.add_custom_optimizers();
+ custom_config->set_name("TestOptimizerWithParams");
+ (*custom_config->mutable_parameter_map())["foo"] = AttrValue();
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestOptimizer::IsOptimized());
+}
+
TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
GrapplerItem item;
@@ -305,6 +378,89 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
test::ExpectTensorEqual<int>(tensors_expected[1], tensors[1]);
}
+TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) {
+ using test::function::NDef;
+ using FDH = FunctionDefHelper;
+
+ // We will record what types of optimizations the meta optimizer allows for
+ // each GrapplerItem (the main graph and the graph of each function).
+ gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>
+ allowed_optimizations;
+ GrapplerItemPropertiesAccumulator::SetAllowedOptimizations(
+ &allowed_optimizations);
+
+ // Just record properties of optimized Grappler items.
+ RewriterConfig rewriter_config;
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+ rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator");
+ rewriter_config.set_min_graph_nodes(-1);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+
+ // Define simple function library with two identical mul functions.
+ FunctionDef mul_func_1 = FunctionDefHelper::Create(
+ "MyMul1", {"x:float", "y:float"}, {"z:float"}, {},
+ {{{"mul"}, "Mul", {"x", "y"}, {}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "mul:z:0"}});
+
+ FunctionDef mul_func_2 = FunctionDefHelper::Create(
+ "MyMul2", {"x:float", "y:float"}, {"z:float"}, {},
+ {{{"mul"}, "Mul", {"x", "y"}, {}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "mul:z:0"}});
+
+ // Tensorflow graph:
+ //
+ // x0 = tf.Placeholder(tf.float);
+ // x1 = tf.Placeholder(tf.float);
+ // dy = tf.Placeholder(tf.float);
+ //
+ // mul_1 = MyMul1(x0, x1);
+ // mul_2 = MyMul2(x0, x1);
+ // dx = SymbolicGradient({x0, x1, dy}, f=MyMul2)
+ GrapplerItem item;
+ item.id = "main";
+ item.graph = test::function::GDef(
+ {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ // Calls into function library
+ NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice),
+ NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice),
+ // Symbolic gradient of MyMul2
+ NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
+ {{"f", FDH::FunctionRef("MyMul2", {})},
+ {"Tin", DataTypeSlice{DT_FLOAT}},
+ {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
+ kDevice)},
+ // FunctionLib
+ {mul_func_1, mul_func_2});
+ item.fetch = {"mul_1", "mul_2", "dx"};
+
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ // Our custom optimizer must be called for the main graph and for the two
+ // functions.
+ ASSERT_EQ(allowed_optimizations.size(), 3);
+
+ auto allowed_optimizations_main =
+ gtl::FindOrNull(allowed_optimizations, "main");
+ ASSERT_NE(allowed_optimizations_main, nullptr);
+ EXPECT_TRUE(allowed_optimizations_main->non_differentiable_rewrites);
+
+ auto allowed_optimizations_my_mul_1 =
+ gtl::FindOrNull(allowed_optimizations, "MyMul1");
+ ASSERT_NE(allowed_optimizations_my_mul_1, nullptr);
+ EXPECT_TRUE(allowed_optimizations_my_mul_1->non_differentiable_rewrites);
+
+ auto allowed_optimizations_my_mul_2 =
+ gtl::FindOrNull(allowed_optimizations, "MyMul2");
+ ASSERT_NE(allowed_optimizations_my_mul_2, nullptr);
+ EXPECT_FALSE(allowed_optimizations_my_mul_2->non_differentiable_rewrites);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
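
The GrapplerItemPropertiesAccumulator above relies on a static sink pointer so a registered optimizer can report back to the test that registered it. The same pattern in miniature, stripped of TF types (a sketch, not commit code):

#include <cassert>
#include <map>
#include <string>

struct Item { std::string id; bool flag; };

class Accumulator {
 public:
  // The test installs a sink; every observer instance writes into it.
  static void SetSink(std::map<std::string, bool>* sink) { sink_ = sink; }
  void Observe(const Item& item) {
    if (sink_) sink_->insert({item.id, item.flag});
  }

 private:
  static std::map<std::string, bool>* sink_;
};
std::map<std::string, bool>* Accumulator::sink_ = nullptr;

int main() {
  std::map<std::string, bool> seen;
  Accumulator::SetSink(&seen);
  Accumulator a;
  a.Observe({"main", true});
  a.Observe({"MyMul2", false});
  assert(seen.size() == 2 && !seen.at("MyMul2"));
}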
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 2190d38937..29a3b2b74c 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -25,23 +25,67 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
namespace grappler {
+
namespace internal {
+namespace {
// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;
-// Find KernelDef for `node`.
-Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
- // Try find KernelDef for node.device, else GPU or CPU.
- for (const DeviceType& device :
- {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}) {
- Status s = FindKernelDef(device, node, kdef, nullptr);
+struct OpDevicePortHasher {
+ std::size_t operator()(const std::tuple<string, string, int>& x) const {
+ uint64 code = Hash64Combine(Hash64(std::get<0>(x)), Hash64(std::get<1>(x)));
+
+ return Hash64Combine(code, hash<int>()(std::get<2>(x)));
+ }
+};
+using OpDevicePortOnHostMap =
+ gtl::FlatMap<std::tuple<string, string, int>, bool, OpDevicePortHasher>;
+
+// Returns true for all node types that are blacklisted and must not be swapped.
+bool IsBlacklisted(const NodeDef& node) {
+ return
+ // Collective ops should not be swapped.
+ IsCollective(node) ||
+ // ControlFlow ops should not be swapped.
+ IsControlFlow(node) ||
+ // NoOp ops should not be swapped (due to group dependencies).
+ IsNoOp(node);
+}
+
+// Checks if the tensor is an integer type and of small size.
+bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
+ // Check type to be int32 or int64.
+ if (prop.dtype() != DataType::DT_INT32 &&
+ prop.dtype() != DataType::DT_INT64) {
+ return false;
+ }
+
+ // Check size known and small.
+ const int64 size = NumCoefficients(prop.shape());
+ if (size < 0 || size > kTensorMaxSize) {
+ return false;
+ }
+
+ return true;
+}
+
+// Finds a KernelDef for `node`, greedily returning the first one found in
+// `devices`.
+Status TryFindKernelDef(const std::vector<DeviceType>& devices,
+ const NodeDef& node, const KernelDef** kdef) {
+ for (const DeviceType& device : devices) {
+ const KernelDef* kernel = nullptr;
+ Status s = FindKernelDef(device, node, &kernel, nullptr);
if (s.ok()) {
+ if (kdef) {
+ *kdef = kernel;
+ }
return Status::OK();
}
}
@@ -49,96 +93,239 @@ Status TryFindKernelDef(const NodeDef& node, const KernelDef** kdef) {
return errors::NotFound("Could not find KernelDef for op: ", node.op());
}
-// Check if all node's inputs are pinned to CPU memory.
-bool AreAllNodeInputsPinnedToHost(const GraphView& graph, const NodeDef& node) {
- // Loop through all the inputs excluding the controlling nodes.
- for (const GraphView::OutputPort& fanin : graph.GetFanins(node, false)) {
- // Check if (the fanin) op's device is on CPU.
- if (str_util::StrContains(fanin.node->device(), DEVICE_CPU)) {
- continue;
- }
-
- // Check if (the fanin) op's output port is pinned to HostMemory.
- const OpDef* fanin_odef = nullptr;
- Status s = OpRegistry::Global()->LookUpOpDef(fanin.node->op(), &fanin_odef);
- if (!s.ok()) {
- LOG(INFO) << "Could not find OpDef for : " << fanin.node->op();
- return false;
- }
+// Checks if a node's output port is host friendly.
+// Roughly this means checking if the output port is on Host memory.
+Status IsNodeOutputPortHostFriendly(
+ const GraphView& graph, GraphProperties* properties, const NodeDef& node,
+ int port_id, OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
+ bool* is_candidate) {
+ *is_candidate = false;
- const int output_arg_id =
- OpOutputPortIdToArgId(*fanin.node, *fanin_odef, fanin.port_id);
- if (output_arg_id < 0) {
- LOG(WARNING) << "Invalid port: " << fanin.port_id << "!\n"
- << node.DebugString() << "\n"
- << fanin.node->DebugString() << "\n"
- << fanin_odef->DebugString();
- return false;
- }
+ // Make sure we are not a blacklisted op.
+ if (IsBlacklisted(node)) {
+ return Status::OK();
+ }
- const KernelDef* fanin_kdef = nullptr;
- s = TryFindKernelDef(*fanin.node, &fanin_kdef);
- if (!s.ok()) {
- LOG(INFO) << "Could not find KernelDef for : " << fanin.node->op();
- return false;
- }
+ // Check to make sure we have the right properties (i.e., statically shaped).
+ if (!properties->has_properties()) {
+ // This is an expensive call, call it lazily.
+ TF_RETURN_IF_ERROR(properties->InferStatically(
+ /*assume_valid_feeds=*/false));
+ }
+ const auto& output_properties = properties->GetOutputProperties(node.name());
+ if (port_id >= output_properties.size()) {
+ LOG(WARNING) << "port_id=" << port_id
+ << " but output_properties.size()=" << output_properties.size()
+ << "\n"
+ << node.DebugString();
+ return Status::OK();
+ }
+ if (!IsTensorIntegerAndSmall(output_properties[port_id])) {
+ return Status::OK();
+ }
- bool fanin_pinned = false;
- for (const string& host_memory_arg : fanin_kdef->host_memory_arg()) {
- if (fanin_odef->output_arg(output_arg_id).name() == host_memory_arg) {
- fanin_pinned = true;
- break;
+ // These nodes may be optimized away downstream (even if pinned to Host), so
+ // we should (recursively) check their source.
+ if (IsIdentity(node)) {
+ for (const auto& fanin : graph.GetFanins(node, false)) {
+ bool fanin_candidate = false;
+ TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+ graph, properties, *fanin.node, fanin.port_id,
+ op_device_outport_pinned_to_host_cache, &fanin_candidate));
+ if (!fanin_candidate) {
+ return Status::OK();
}
}
+ *is_candidate = true;
+ return Status::OK();
+ }
+
+ // Check if op's device is on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ *is_candidate = true;
+ return Status::OK();
+ }
+
+ // Check `op_device_outport_pinned_to_host_cache` for our
+ // {op, device, port_id} combo to see if the arg is pinned on Host.
+ const std::tuple<string, string, int> cache_key(node.op(), node.device(),
+ port_id);
+ auto it = op_device_outport_pinned_to_host_cache->find(cache_key);
+ if (it != op_device_outport_pinned_to_host_cache->end()) {
+ *is_candidate = it->second;
+ return Status::OK();
+ }
+
+ // Check if op's output port is pinned to HostMemory.
+ const OpDef* op = nullptr;
+ Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
+ if (!s.ok()) {
+ LOG(WARNING) << "Could not find OpDef for : " << node.op();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
+ return Status::OK();
+ }
+
+ // Map the port_id to output_arg_id.
+ const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id);
+ if (output_arg_id < 0) {
+ LOG(WARNING) << "Invalid port: " << port_id << "!\n"
+ << node.DebugString() << "\n"
+ << op->DebugString();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
+ return Status::OK();
+ }
- if (!fanin_pinned) {
- return false;
+ // Find the kernel.
+ const KernelDef* kernel = nullptr;
+ s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node,
+ &kernel);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, false);
+ return Status::OK();
+ }
+
+ // Check if the output_arg is pinned to Host.
+ for (const string& host_memory_arg : kernel->host_memory_arg()) {
+ if (op->output_arg(output_arg_id).name() == host_memory_arg) {
+ *is_candidate = true;
+ break;
}
}
- return true;
+ op_device_outport_pinned_to_host_cache->emplace(cache_key, *is_candidate);
+
+ return Status::OK();
}
-bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
- // Check if Tensor is integer and small size.
+// Checks if a node's input port is Host friendly.
+// Roughly this means checking if the input port is on Host memory.
+bool IsNodeInputPortHostFriendly(
+ const NodeDef& node, int port_id,
+ OpDevicePortOnHostMap* op_device_inport_pinned_to_host_cache) {
+ // If node is on Host, assume its inputs are Host friendly.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ return true;
+ }
- // Check type to be int32 or int64.
- if (prop.dtype() != DataType::DT_INT32 &&
- prop.dtype() != DataType::DT_INT64) {
- return false;
+ // Check `op_device_inport_pinned_to_host_cache` for our
+ // {op, device, port_id} combo to see if the arg is pinned on Host.
+ std::tuple<string, string, int> cache_key(node.op(), node.device(), port_id);
+ auto it = op_device_inport_pinned_to_host_cache->find(cache_key);
+ if (it != op_device_inport_pinned_to_host_cache->end()) {
+ return it->second;
}
- // Check size known and small.
- const int64 size = NumCoefficients(prop.shape());
- if (size < 0 || size > kTensorMaxSize) {
+ // Check if op's input port is pinned to HostMemory.
+ const OpDef* op = nullptr;
+ Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
+ if (!s.ok()) {
+ LOG(WARNING) << "Could not find OpDef for : " << node.op();
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
+ return false;
+ }
+ const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
+
+ // Find the kernel.
+ const KernelDef* kernel = nullptr;
+ s = internal::TryFindKernelDef(
+ {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
return false;
}
- return true;
+ // Check if the input_arg is pinned to Host.
+ for (const string& host_memory_arg : kernel->host_memory_arg()) {
+ if (op->input_arg(input_arg_id).name() == host_memory_arg) {
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, true);
+ return true;
+ }
+ }
+
+ op_device_inport_pinned_to_host_cache->emplace(cache_key, false);
+
+ return false;
}
-bool AreAllNodeInputsAndOutputsIntsAndSmall(const GraphProperties& properties,
- const NodeDef& node) {
- for (const auto& prop : properties.GetInputProperties(node.name())) {
+// Checks if a node is a candidate to pin to Host.
+// The rough algorithm is as follows:
+// 1] Check if node is blacklisted.
+// 2] Check if node can run on Host.
+// 3] Check all inputs/outputs are Host "friendly" (currently, friendly means
+//    small integer tensors that are pinned to Host).
+Status IsNodeHostCandidate(
+ const GraphView& graph, GraphProperties* properties, const NodeDef& node,
+ OpDevicePortOnHostMap* op_device_outport_pinned_to_host_cache,
+ bool* is_candidate) {
+ *is_candidate = false;
+
+ // Skip these node types.
+ if (IsBlacklisted(node)) {
+ return Status::OK();
+ }
+
+ // Check if node already on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ *is_candidate = true;
+ return Status::OK();
+ }
+
+ // Check that the node can run on the CPU.
+ Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr);
+ if (!s.ok()) {
+ return Status::OK();
+ }
+
+ // Check all outputs are Host friendly.
+ if (!properties->has_properties()) {
+ // This is an expensive call, call it lazily.
+ TF_RETURN_IF_ERROR(properties->InferStatically(
+ /*assume_valid_feeds=*/false));
+ }
+ for (const auto& prop : properties->GetOutputProperties(node.name())) {
if (!IsTensorIntegerAndSmall(prop)) {
- return false;
+ return Status::OK();
}
}
- for (const auto& prop : properties.GetOutputProperties(node.name())) {
- if (!IsTensorIntegerAndSmall(prop)) {
- return false;
+ // Check all inputs are Host friendly.
+ for (const GraphView::OutputPort& fanin :
+ graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
+ bool fanin_candidate = false;
+ TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+ graph, properties, *fanin.node, fanin.port_id,
+ op_device_outport_pinned_to_host_cache, &fanin_candidate));
+ if (!fanin_candidate) {
+ return Status::OK();
}
}
- return true;
+
+ *is_candidate = true;
+ return Status::OK();
}
-string TryFindHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, const string& device) {
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (const auto& node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
+ node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+} // end namespace
+
+// Tries to swap `device` to a Host device from `devices`. Returns true iff
+// there was a swap.
+bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, string* device) {
// Force this node onto the CPU.
- if (device.empty() && has_device_cpu) {
- return "/device:CPU:0";
- } else if (str_util::StrContains(device, DEVICE_GPU)) {
+ if (device->empty() && has_device_cpu) {
+ *device = "/device:CPU:0";
+ return true;
+ } else if (str_util::StrContains(*device, DEVICE_GPU)) {
// Sometimes the cluster can have:
// devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
// and we need to handle them properly.
@@ -146,30 +333,19 @@ string TryFindHostDevice(const gtl::FlatSet<string>& devices,
{std::pair<string, string>("GPU", "CPU:0"),
std::pair<string, string>("/device", "/device:CPU:0")}) {
const string device_host =
- strings::StrCat(device.substr(0, device.rfind(device_match.first)),
+ strings::StrCat(device->substr(0, device->rfind(device_match.first)),
device_match.second);
if (devices.find(device_host) != devices.end()) {
- return device_host;
+ *device = device_host;
+ return true;
}
}
}
- // We couldn't find an appropriate Host device, return original device.
- return device;
-}
-
-bool IsTPUGraphDef(const GraphDef& def) {
- for (const auto& node : def.node()) {
- if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
- node.op() == "TPUPartitionedCall") {
- return true;
- }
- }
+ // We couldn't find an appropriate Host device; return false.
return false;
}
-// All the nodes that should be blacklisted and not swapped.
-bool IsBlacklisted(const NodeDef& node) { return IsCollective(node); }
} // end namespace internal
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
@@ -182,7 +358,6 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
GraphProperties properties(item);
- bool has_properties = false;
GraphView graph(optimized_graph);
gtl::FlatSet<string> devices;
@@ -202,45 +377,26 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// All the Const nodes, and their original devices in topological order.
std::vector<std::pair<NodeDef*, string>> const_nodes;
- for (auto& node : *optimized_graph->mutable_node()) {
- // Check if node already on CPU.
- if (str_util::StrContains(node.device(), DEVICE_CPU)) {
- continue;
- }
-
- // Skip these node types.
- if (internal::IsBlacklisted(node)) {
- continue;
- }
+ // Caches mapping {op, device, port} -> whether that port is pinned to Host.
+ internal::OpDevicePortOnHostMap op_device_outport_pinned_to_host_cache;
+ internal::OpDevicePortOnHostMap op_device_inport_pinned_to_host_cache;
- // Check the node can be run on CPU.
- Status s = FindKernelDef(DEVICE_CPU, node, nullptr, nullptr);
- if (!s.ok()) {
- continue;
- }
-
- // Check all input's are pinned to CPU.
- if (!internal::AreAllNodeInputsPinnedToHost(graph, node)) {
- continue;
- }
-
- if (!has_properties) {
- // This is an expensive call, call it lazily.
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
- has_properties = true;
- }
-
- // Check all inputs and outputs are integers and small.
- if (!internal::AreAllNodeInputsAndOutputsIntsAndSmall(properties, node)) {
+ for (auto& node : *optimized_graph->mutable_node()) {
+ bool is_candidate = false;
+ TF_RETURN_IF_ERROR(internal::IsNodeHostCandidate(
+ graph, &properties, node, &op_device_outport_pinned_to_host_cache,
+ &is_candidate));
+ if (!is_candidate) {
continue;
}
- if (IsConstant(node)) {
- const_nodes.emplace_back(&node, node.device());
+ const string original_device = node.device();
+ const bool swapped = internal::TrySwapToHostDevice(devices, has_device_cpu,
+ node.mutable_device());
+ // Keep track of all Const nodes that we swapped.
+ if (swapped && IsConstant(node)) {
+ const_nodes.emplace_back(&node, original_device);
}
- // Try and swap the device to Host.
- node.set_device(
- internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
}
// Traverse all `const_nodes`, and map them back to GPU greedily.
@@ -248,10 +404,13 @@ Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
NodeDef* node = it.first;
const string& device = it.second;
- // Check all the consumers of this node, if any of them are on the original
- // device, swap this node back onto the original device.
+ // Check all the consumers of this node; if any of them are not Host
+ // friendly, swap this node back onto the original device.
for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
- if (fanout.node->device() == device) {
+ // The consumer is not Host friendly; swap this node back to the original
+ // device.
+ if (!internal::IsNodeInputPortHostFriendly(
+ *fanout.node, fanout.port_id,
+ &op_device_inport_pinned_to_host_cache)) {
node->set_device(device);
break;
}
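
The two OpDevicePortOnHostMap caches above memoize an expensive kernel/op-def lookup per {op, device, port} triple. A standalone approximation with std::unordered_map (the diff's Hash64Combine-based hasher plays the role of KeyHasher here; names invented):

#include <cstddef>
#include <string>
#include <tuple>
#include <unordered_map>

using Key = std::tuple<std::string, std::string, int>;

struct KeyHasher {
  std::size_t operator()(const Key& k) const {
    // Boost-style hash combining; Hash64Combine serves the same purpose above.
    std::size_t h = std::hash<std::string>()(std::get<0>(k));
    h ^= std::hash<std::string>()(std::get<1>(k)) + 0x9e3779b9 + (h << 6) + (h >> 2);
    h ^= std::hash<int>()(std::get<2>(k)) + 0x9e3779b9 + (h << 6) + (h >> 2);
    return h;
  }
};

using OnHostCache = std::unordered_map<Key, bool, KeyHasher>;

// Lookup-or-compute, mirroring IsNodeInputPortHostFriendly's cache usage.
bool IsPinnedToHost(const Key& key, OnHostCache* cache) {
  auto it = cache->find(key);
  if (it != cache->end()) return it->second;  // cache hit
  bool pinned = /* expensive kernel/op-def lookup elided */ false;
  (*cache)[key] = pinned;
  return pinned;
}

int main() {
  OnHostCache cache;
  Key k{"Const", "/device:GPU:0", 0};
  IsPinnedToHost(k, &cache);         // computes and memoizes
  return IsPinnedToHost(k, &cache);  // served from the cache
}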
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
index d557a03463..bed4a9ef95 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -26,8 +26,8 @@ namespace tensorflow {
namespace grappler {
namespace internal {
// Try and find an appropriate Host device in `devices` given `device`.
-string TryFindHostDevice(const gtl::FlatSet<string>& devices,
- bool has_device_cpu, const string& device);
+bool TrySwapToHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, string* device);
} // end namespace internal
// Optimize TensorFlow ops that should be swapped into the CPU to avoid
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
index 173cb3fe3c..9bb030b220 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -28,30 +28,60 @@ namespace {
class PinToHostOptimizerTest : public GrapplerTest {};
-TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceNoDevices) {
gtl::FlatSet<string> devices = {};
- EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
-
- devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
- "/device:CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
- "/device:CPU:0");
-
- devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_CPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_CPU:0");
-
- devices = {"/device:XLA_GPU:0"};
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
- "/device:XLA_GPU:0");
- EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
- "/device:XLA_GPU:*");
+
+ string device = "ABC";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "ABC");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, true, &device));
+ EXPECT_EQ(device, "/device:CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaCpuXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_TRUE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_CPU:0");
+}
+
+TEST_F(PinToHostOptimizerTest, TrySwapToHostDeviceXlaGpu) {
+ gtl::FlatSet<string> devices = {"/device:XLA_GPU:0"};
+
+ string device = "";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_TRUE(device.empty());
+
+ device = "/device:XLA_GPU:0";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:0");
+
+ device = "/device:XLA_GPU:*";
+ EXPECT_FALSE(internal::TrySwapToHostDevice(devices, false, &device));
+ EXPECT_EQ(device, "/device:XLA_GPU:*");
}
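+
+The device-string rewrite these tests exercise reduces to a suffix replacement at the last "GPU" (or "/device") match, validated against the known device set. A simplified standalone sketch (drops the empty-device/has_device_cpu branch handled above):
+
+#include <cassert>
+#include <set>
+#include <string>
+#include <utility>
+
+bool TrySwapToHost(const std::set<std::string>& devices, std::string* device) {
+  const std::pair<std::string, std::string> matches[] = {
+      {"GPU", "CPU:0"}, {"/device", "/device:CPU:0"}};
+  for (const auto& m : matches) {
+    // Replace everything from the last match onward with a CPU spelling.
+    const std::string host =
+        device->substr(0, device->rfind(m.first)) + m.second;
+    if (devices.count(host)) {
+      *device = host;
+      return true;
+    }
+  }
+  return false;  // no suitable Host device; leave `device` untouched
+}
+
+int main() {
+  std::set<std::string> devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+  std::string d = "/device:XLA_GPU:0";
+  assert(TrySwapToHost(devices, &d) && d == "/device:XLA_CPU:0");
+}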
TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
@@ -160,6 +190,48 @@ TEST_F(PinToHostOptimizerTest, NoSwap) {
EXPECT_EQ(found, 3);
}
+TEST_F(PinToHostOptimizerTest, Identity) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ // `a` and `c` are on GPU, `d` is on CPU; consequently `e` should not be
+ // swapped. `b` should be placed onto Host since `c` pins its input to Host
+ // memory.
+ Output a =
+ ops::Const(s.WithOpName("a").WithDevice("/device:GPU:0"), 1, {64, 64});
+ Output b = ops::Const(s.WithOpName("b"), {0, 1}, {2});
+ Output c =
+ ops::ReduceProd(s.WithOpName("c").WithDevice("/device:GPU:0"), a, b);
+ Output d = ops::Identity(s.WithDevice("/device:CPU:0").WithOpName("d"), c);
+ Output e = ops::Multiply(s.WithOpName("e"), d, d);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_EQ(node.device(), "/device:GPU:0");
+ } else if (node.name() == "b") {
+ // If CUDA, then there is a GPU kernel registration that is pinned to Host
+ // memory. Consequently, `b` will be mapped to Host correctly if there is
+ // a GPU kernel registered.
+#if GOOGLE_CUDA
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+#else
+ EXPECT_TRUE(node.device().empty());
+#endif
+ } else if (node.name() == "d") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ } else if (node.name() == "e") {
+ EXPECT_TRUE(node.device().empty());
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 5);
+}
+
TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 008a289cfd..9ada8b7ff9 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -168,11 +168,12 @@ void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) {
Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
GraphDef* optimized_graph) {
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ bool inferred_properties = false;
GraphView graph(const_cast<GraphDef*>(&item.graph));
// During inference, most of the inputs to FusedBatchNorm are constant, and we
// can therefore replace the op with a much cheaper set of primitives.
+ optimized_graph->mutable_node()->Reserve(item.graph.node_size());
for (const NodeDef& node : item.graph.node()) {
if (node.op() == "FusedBatchNorm" || node.op() == "FusedBatchNormV2") {
bool optimizable = (node.attr().count("T") == 0 ||
@@ -181,6 +182,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
!node.attr().at("is_training").b());
if (optimizable) {
int const_inputs = 0;
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& props = properties.GetInputProperties(node.name());
for (const auto& prop : props) {
if (prop.has_value()) {
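
This hunk, like the shape_optimizer.cc change below, moves InferStatically behind a flag so the expensive inference runs at most once, and only for graphs that actually contain an optimizable node. The pattern in miniature (a plain-C++ sketch):

#include <cassert>

struct Properties {
  int infer_calls = 0;
  void InferStatically() { ++infer_calls; }  // stands in for the expensive call
};

void MaybeOptimizeNode(bool optimizable, Properties* props, bool* inferred) {
  if (!optimizable) return;
  if (!*inferred) {  // infer lazily, at most once per pass
    props->InferStatically();
    *inferred = true;
  }
  // ... consult props to rewrite the node ...
}

int main() {
  Properties props;
  bool inferred = false;
  MaybeOptimizeNode(false, &props, &inferred);  // no inference needed
  MaybeOptimizeNode(true, &props, &inferred);   // infers now
  MaybeOptimizeNode(true, &props, &inferred);   // reuses the result
  assert(props.infer_calls == 1);
}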
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index 4542d17ccc..6ccb1cd783 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -33,7 +33,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
*optimized_graph = item.graph;
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ bool inferred_properties = false;
GraphView graph(optimized_graph);
// The product of all the dimensions in a tensor shape can be expressed more
@@ -55,6 +55,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
const GraphView::OutputPort reduce_indices =
graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1));
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& prop =
properties.GetOutputProperties(reduce_indices.node->name());
if (prop.size() < reduce_indices.port_id) {
@@ -92,6 +97,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
continue;
}
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& prop1 = properties.GetInputProperties(input1.node->name());
const auto& prop2 = properties.GetInputProperties(input2.node->name());
if (prop1.size() != 1 || prop2.size() != 1) {
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index db6e4e6852..5867d01324 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -156,45 +156,6 @@ bool IsControlInput(const string& name) {
return !name.empty() && name[0] == '^';
}
-string NodeName(const string& name) {
- int position;
- return ParseNodeName(name, &position);
-}
-
-int NodePosition(const string& name) {
- int position;
- ParseNodeNameAsStringPiece(name, &position);
- return position;
-}
-
-int NodePositionIfSameNode(const string& input_name, const string& node_name) {
- const bool is_ctrl = input_name[0] == '^';
- auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
- auto node_it = node_name.begin();
- if (node_name.empty() ||
- std::distance(input_it, input_name.end()) < node_name.size()) {
- return -2;
- }
- while (node_it != node_name.end()) {
- if (*input_it++ != *node_it++) {
- return -2;
- }
- }
- if (input_it == input_name.end()) {
- return is_ctrl ? -1 : 0;
- } else if (*input_it++ == ':') {
- StringPiece remaining(&(*input_it),
- std::distance(input_it, input_name.end()));
- int position;
- if (!strings::safe_strto32(remaining, &position)) {
- return -2;
- }
- return is_ctrl ? -1 : position;
- } else {
- return -2;
- }
-}
-
string AddPrefixToNodeName(const string& name, const string& prefix,
const string& delimiter) {
if (!name.empty()) {
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index 296ee1678e..95126d470c 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -29,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
namespace grappler {
@@ -102,40 +101,92 @@ bool IsControlInput(const string& name);
// True iff 'name1' and 'name2' refer to the same input.
bool IsSameInput(const string& name1, const string& name2);
+// Returns the trailing position number (or zero if no number is present) if
+// NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
+// Returns -2 if NodeName(input_name) is not equal to node_name.
+// Note: This function is used very heavily, and this hand-optimized
+// version is 3-4x faster than the version using Scanner, which it replaced.
+// This is worth the reduction in readability.
+inline int NodePositionIfSameNode(const string& input_name,
+ const string& node_name) {
+ if (input_name.empty()) return -2;
+ const bool is_ctrl = input_name[0] == '^';
+ auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
+ auto node_it = node_name.begin();
+ if (node_name.empty() ||
+ std::distance(input_it, input_name.end()) < node_name.size()) {
+ return -2;
+ }
+ while (node_it != node_name.end()) {
+ if (*input_it++ != *node_it++) {
+ return -2;
+ }
+ }
+ if (input_it == input_name.end()) {
+ return is_ctrl ? -1 : 0;
+ } else if (*input_it++ == ':') {
+ StringPiece remaining(&(*input_it),
+ std::distance(input_it, input_name.end()));
+ int position;
+ if (!strings::safe_strto32(remaining, &position)) {
+ return -2;
+ }
+ return is_ctrl ? -1 : position;
+ } else {
+ return -2;
+ }
+}
+
// Return the node name corresponding to 'name' if name is valid, or the empty
// string otherwise.
-string NodeName(const string& name);
+inline StringPiece NodeNameAsStringPiece(const string& name) {
+ static const string empty;
+ if (name.empty()) return StringPiece(empty);
+ const auto begin_it = name[0] == '^' ? name.begin() + 1 : name.begin();
+ auto end_it = begin_it;
+ while (end_it != name.end() && *end_it != ':') {
+ ++end_it;
+ }
+ if (end_it != name.end() && *end_it != ':') {
+ return StringPiece(empty);
+ }
+ return StringPiece(&(*begin_it), std::distance(begin_it, end_it));
+}
-// Get the trailing position number ":{digits}" (if any) of a node name.
-// Returns -1 for control inputs.
-int NodePosition(const string& name);
+// Return the node name corresponding to 'name' if name is valid, or the empty
+// string otherwise.
+inline string NodeName(const string& name) {
+ return string(NodeNameAsStringPiece(name));
+}
+// Returns the node name and position in a single call.
inline StringPiece ParseNodeNameAsStringPiece(const string& name,
int* position) {
- // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any)
- // to get a node name.
- strings::Scanner scan(name);
- scan.ZeroOrOneLiteral("^")
- .RestartCapture()
- .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
- .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
- StringPiece capture;
- StringPiece remaining;
- if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) {
+ static const string empty;
+ if (name.empty()) {
*position = 0;
- static const string empty;
return StringPiece(empty);
- } else {
- if (name[0] == '^') {
- *position = -1;
- } else if (remaining.empty()) {
- *position = 0;
- } else {
- // Skip the first ':' character.
- CHECK(strings::safe_strto32(remaining.substr(1), position));
+ }
+ const bool is_ctrl = name[0] == '^';
+ const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin();
+ *position = is_ctrl ? -1 : 0;
+ auto end_it = begin_it;
+ while (end_it != name.end() && *end_it != ':') {
+ ++end_it;
+ }
+ const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it));
+ if (end_it != name.end()) {
+ if (*end_it != ':') {
+ return StringPiece(empty);
+ } else if (!is_ctrl) {
+ ++end_it;
+ StringPiece remaining(&(*end_it), std::distance(end_it, name.end()));
+ if (!strings::safe_strto32(remaining, position)) {
+ return StringPiece(empty);
+ }
}
- return capture;
}
+ return node_name;
}
// Returns the node name and position in a single call.
@@ -143,10 +194,11 @@ inline string ParseNodeName(const string& name, int* position) {
return string(ParseNodeNameAsStringPiece(name, position));
}
-// Returns NodePosition(input_name) if NodeName(input_name) == node_name.
-// Otherwise returns -2;
-// REQUIRES: inputs_name.size() > 0 && node_name.size() > 0.
-int NodePositionIfSameNode(const string& input_name, const string& node_name);
+inline int NodePosition(const string& name) {
+ int position;
+ ParseNodeNameAsStringPiece(name, &position);
+ return position;
+}
// Add a prefix to a node name with a custom delimiter.
string AddPrefixToNodeName(const string& name, const string& prefix,
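
A test-style sanity check of the hand-rolled parsers above (a sketch against the real grappler header, so it needs the TF build; expected values follow the contract documented in the comments):

#include <cassert>
#include "tensorflow/core/grappler/utils.h"

int main() {
  using tensorflow::grappler::NodePositionIfSameNode;
  using tensorflow::grappler::ParseNodeNameAsStringPiece;

  int pos = 0;
  // Trailing ":{digits}" becomes the position; '^' marks a control input.
  assert(ParseNodeNameAsStringPiece("foo/bar:2", &pos) == "foo/bar" && pos == 2);
  assert(ParseNodeNameAsStringPiece("^foo/bar", &pos) == "foo/bar" && pos == -1);

  assert(NodePositionIfSameNode("foo:3", "foo") == 3);  // trailing position
  assert(NodePositionIfSameNode("^foo", "foo") == -1);  // control input
  assert(NodePositionIfSameNode("bar", "foo") == -2);   // different node
}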
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index a428aea7f5..6861fb423c 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -41,7 +41,8 @@ Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration,
tensorflow::NameRangeMap outputs_range_map;
TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
node, registration.op_def, nullptr, &outputs_range_map));
- connectivity->RegisterFunctionBodyOutputs(node.name(), outputs_range_map);
+ connectivity->RegisterFunctionBodyOutputs(node.name(),
+ std::move(outputs_range_map));
return Status::OK();
}
@@ -75,20 +76,22 @@ Status ResolveFunctionBodyNodeAttrPlaceholders(
} // namespace
void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
- const InputArgExpansion& input_arg_expansion) {
- const auto& input_name = input_arg_expansion.input_name;
+ InputArgExpansion input_arg_expansion) {
+ string input_name = input_arg_expansion.input_name;
const auto& placeholders = input_arg_expansion.placeholders;
- input_arg_expansions_.emplace(input_name, input_arg_expansion);
+
for (int i = 0; i < placeholders.size(); ++i) {
const string& placeholder = input_arg_expansion.placeholders[i];
- input_arg_placeholders_.emplace(
- placeholder, InputArgPlaceholder{input_name, /*position=*/i});
+ input_arg_placeholders_.insert(
+ {placeholder, InputArgPlaceholder{input_name, /*position=*/i}});
}
+ input_arg_expansions_.insert(
+ {std::move(input_name), std::move(input_arg_expansion)});
}
void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs(
- const string& node_name, const tensorflow::NameRangeMap& outputs) {
- function_body_outputs_[node_name] = outputs;
+ const string& node_name, tensorflow::NameRangeMap&& outputs) {
+ function_body_outputs_[node_name] = std::move(outputs);
}
Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
@@ -174,11 +177,12 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
const auto& output_range = output->second;
if (position == -1) {
+ graph_def_inputs->reserve(graph_def_inputs->size() +
+ output_range.second - output_range.first);
// If position is not defined expand node output range
for (int i = output_range.first; i < output_range.second; ++i) {
- i == 0 ? graph_def_inputs->push_back(node_name)
- : graph_def_inputs->push_back(
- strings::StrCat(node_name, ":", i));
+ graph_def_inputs->push_back(
+ i == 0 ? node_name : strings::StrCat(node_name, ":", i));
}
} else {
if (position > (output_range.second - output_range.first)) {
@@ -187,9 +191,8 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
" position: ", position, " (out of range)");
}
int pos = output_range.first + position;
- pos == 0 ? graph_def_inputs->push_back(node_name)
- : graph_def_inputs->push_back(
- strings::StrCat(node_name, ":", pos));
+ graph_def_inputs->push_back(
+ pos == 0 ? node_name : strings::StrCat(node_name, ":", pos));
}
return Status::OK();
@@ -211,8 +214,8 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs(
}
function_body_node->clear_input();
- for (const string& expanded_input : expanded_inputs)
- function_body_node->add_input(expanded_input);
+ for (string& expanded_input : expanded_inputs)
+ function_body_node->add_input(std::move(expanded_input));
return Status::OK();
}
@@ -323,7 +326,7 @@ GrapplerFunctionItem::GrapplerFunctionItem(
// Fill the feed nodes with input placeholders.
for (const InputArgExpansion& input_arg : input_arg_expansions_) {
for (const string& placeholder : input_arg.placeholders) {
- feed.emplace_back(placeholder, Tensor());
+ feed.push_back({placeholder, Tensor()});
input_arg_placeholders_.insert(placeholder);
}
}
@@ -460,7 +463,7 @@ Status InstantiationBodyParameters(
auto it = func_instantiation_attr.find(placeholder);
if (it != func_instantiation_attr.end()) {
- body_parameters->emplace(placeholder, it->second);
+ body_parameters->insert({placeholder, it->second});
} else {
return errors::InvalidArgument("Can't resolve placeholder: ",
placeholder);
@@ -498,10 +501,6 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
// GraphDef input format (name[:position])
GrapplerFunctionConnectivity connectivity;
- std::vector<InputArgExpansion> inputs;
- std::vector<OutputArgExpansion> outputs;
- std::vector<string> keep_nodes;
-
// Function body shares the library with the graph that instantiated it.
GraphDef function_body;
*function_body.mutable_library() = flib.ToProto();
@@ -518,6 +517,9 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
}
}
+ std::vector<InputArgExpansion> inputs;
+ inputs.reserve(signature.input_arg_size());
+
// For each input argument create a placeholder in function body.
for (const OpDef::ArgDef& input : signature.input_arg()) {
if (!input.type_list_attr().empty() || !input.number_attr().empty()) {
@@ -542,9 +544,10 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
/*is_ref*/ input.is_ref(),
/*placeholders=*/{input.name()}};
connectivity.RegisterInputArgExpansion(input_expansion);
- inputs.push_back(input_expansion);
+ inputs.push_back(std::move(input_expansion));
}
+ std::vector<string> keep_nodes;
// Add all function nodes to the function body
for (const NodeDef& func_def_node : func.node_def()) {
NodeDef* new_node = function_body.add_node();
@@ -572,6 +575,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node));
}
+ std::vector<OutputArgExpansion> outputs;
+ outputs.reserve(signature.output_arg_size());
// Add function outputs
for (const OpDef::ArgDef& out : signature.output_arg()) {
std::vector<string> output_tensors;
@@ -589,8 +594,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
OutputArgExpansion output{/*output_name=*/out.name(),
/*data_type=*/output_data_type,
/*is_ref=*/out.is_ref(),
- /*output_tensors=*/output_tensors};
- outputs.push_back(output);
+ /*output_tensors=*/std::move(output_tensors)};
+ outputs.push_back(std::move(output));
}
bool is_stateful = signature.is_stateful();
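
The recurring change in this file is the sink idiom: take the argument by value (or rvalue reference) and std::move it into place, so callers passing temporaries pay a move instead of a copy. A minimal sketch (invented Registry type, illustration only):

#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

class Registry {
 public:
  // Pass-by-value sink: callers with temporaries get a move, others copy once.
  void Register(std::string name, std::vector<int> values) {
    entries_.insert({std::move(name), std::move(values)});
  }

 private:
  std::unordered_map<std::string, std::vector<int>> entries_;
};

int main() {
  Registry r;
  std::vector<int> v = {1, 2, 3};
  r.Register("outputs", std::move(v));  // contents moved, vector not copied
}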
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index 733caf325f..ef944ced09 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include <unordered_map>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -70,9 +71,9 @@ struct OutputArgExpansion {
// and fold it back when doing backward conversion.
class GrapplerFunctionConnectivity {
public:
- void RegisterInputArgExpansion(const InputArgExpansion& input_arg_expansion);
+ void RegisterInputArgExpansion(InputArgExpansion input_arg_expansion);
void RegisterFunctionBodyOutputs(const string& node_name,
- const tensorflow::NameRangeMap& outputs);
+ tensorflow::NameRangeMap&& outputs);
// Expand input encoded in FunctionDef format (name[:output][:position]) into
// multiple inputs in GraphDef format (name[:position]).
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index 6b787a6910..9b6c1f690b 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -371,6 +371,25 @@ BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl);
BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0);
BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end);
+#define BM_ParseNodeNameAsStringPiece(I, NAME) \
+ static void BM_ParseNodeNameAsStringPiece_##NAME(int iters) { \
+ string input = I; \
+ for (int i = 0; i < iters; ++i) { \
+ int position; \
+ const StringPiece name = ParseNodeNameAsStringPiece(input, &position); \
+ CHECK_GE(position, -1); \
+ CHECK(!name.empty()); \
+ } \
+ } \
+ BENCHMARK(BM_ParseNodeNameAsStringPiece_##NAME)
+
+BM_ParseNodeNameAsStringPiece("foo", foo);
+BM_ParseNodeNameAsStringPiece("foo/bar/baz", foo_bar_baz);
+BM_ParseNodeNameAsStringPiece("^foo/bar/baz", foo_bar_baz_ctrl);
+BM_ParseNodeNameAsStringPiece("foo:123", foo123);
+BM_ParseNodeNameAsStringPiece("foo/bar/baz:123", foo_bar_baz_123);
+BM_ParseNodeNameAsStringPiece("^foo/bar/baz:123", foo_bar_baz_123_ctrl);
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 1a3db2c7cd..3a920f26f3 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1197,8 +1197,10 @@ tf_cc_test(
tf_cc_test(
name = "example_parsing_ops_test",
- size = "large",
+ size = "medium",
srcs = ["example_parsing_ops_test.cc"],
+ shard_count = 4,
+ tags = ["optonly"],
deps = [
":example_parsing_ops",
":ops_testutil",
@@ -2028,8 +2030,8 @@ tf_kernel_library(
":variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:resource_variable_ops_op_lib",
- "//third_party/eigen3",
],
)
@@ -4049,11 +4051,6 @@ cc_library(
)
SPARSE_DEPS = [
- ":bounds_check",
- ":cwise_op",
- ":fill_functor",
- ":scatter_functor",
- "//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:sparse_ops_op_lib",
@@ -4086,7 +4083,9 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_cross_op",
prefix = "sparse_cross_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4098,13 +4097,19 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_dense_binary_op_shared",
prefix = "sparse_dense_binary_op_shared",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":cwise_op",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_sparse_binary_op_shared",
prefix = "sparse_sparse_binary_op_shared",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":cwise_op",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4136,7 +4141,9 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_softmax",
prefix = "sparse_softmax",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4148,25 +4155,37 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_tensor_dense_add_op",
prefix = "sparse_tensor_dense_add_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":scatter_functor",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_tensor_dense_matmul_op",
prefix = "sparse_tensor_dense_matmul_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":bounds_check",
+ ":fill_functor",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_to_dense_op",
prefix = "sparse_to_dense_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_xent_op",
prefix = "sparse_xent_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":bounds_check",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4431,6 +4450,7 @@ cc_library(
":string_strip_op",
":string_to_hash_bucket_op",
":substr_op",
+ ":unicode_script_op",
],
)
@@ -4438,7 +4458,12 @@ cc_library(
name = "string_util",
srcs = ["string_util.cc"],
hdrs = ["string_util.h"],
- deps = ["//tensorflow/core:lib"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@icu//:common",
+ ],
)
STRING_DEPS = [
@@ -5254,6 +5279,8 @@ filegroup(
"cwise_op_squared_difference.cc",
"cwise_op_sub.cc",
"cwise_op_tanh.cc",
+ "cwise_op_xlogy.cc",
+ "cwise_op_xdivy.cc",
"data_format_ops.cc",
"decode_wav_op.cc",
"deep_conv2d.cc",
@@ -5469,6 +5496,7 @@ filegroup(
"batch_kernels.*",
"regex_full_match_op.cc",
"regex_replace_op.cc",
+ "unicode_script_op.cc",
# Ops that are inherently incompatible with Android (e.g. tied to x86 platform).
"mkl_*",
"xsmm_*",
@@ -6414,6 +6442,12 @@ tf_mkl_kernel_library(
)
tf_mkl_kernel_library(
+ name = "mkl_slice_op",
+ prefix = "mkl_slice_op",
+ deps = ARRAY_DEPS + mkl_deps(),
+)
+
+tf_mkl_kernel_library(
name = "mkl_identity_op",
prefix = "mkl_identity_op",
deps = ARRAY_DEPS + mkl_deps(),
@@ -6557,6 +6591,16 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "unicode_script_op",
+ srcs = ["unicode_script_op.cc"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:string_ops_op_lib",
+ "@icu//:common",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc
index 54c45bfe63..f48bd0c318 100644
--- a/tensorflow/core/kernels/batch_matmul_op_complex.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc
@@ -17,14 +17,18 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own complex64/128 kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
#if GOOGLE_CUDA
TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU);
-#endif
+#endif // GOOGLE_CUDA
} // namespace tensorflow
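
The registration guard above (and its twin in batch_matmul_op_real.cc below) encodes "register the TF kernels unless the MKL_ML build owns them". A tiny standalone program evaluating the same predicate for a few hypothetical build configurations:

#include <cstdio>

struct BuildConfig {
  bool intel_mkl;
  bool mkl_dnn_only;
  bool enable_mkl;
};

// Mirrors: !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
bool RegistersTfKernels(const BuildConfig& c) {
  return !c.intel_mkl || c.mkl_dnn_only || !c.enable_mkl;
}

int main() {
  std::printf("%d\n", RegistersTfKernels({true, false, true}));    // 0: MKL_ML registers its own
  std::printf("%d\n", RegistersTfKernels({true, true, true}));     // 1: MKL-DNN-only build
  std::printf("%d\n", RegistersTfKernels({false, false, false}));  // 1: plain TF build
  return 0;
}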
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
index 584b507c70..25ae795d8e 100644
--- a/tensorflow/core/kernels/batch_matmul_op_real.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -21,10 +21,15 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own float and double kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
+
TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index 792eb74e31..0d53240330 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -1,7 +1,7 @@
# Description: Utilities.
package(
- default_visibility = ["//tensorflow:internal"],
+ default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
@@ -12,7 +12,6 @@ cc_library(
name = "periodic_function_dynamic",
srcs = ["periodic_function.cc"],
hdrs = ["periodic_function.h"],
- visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:protos_all_cc",
@@ -21,7 +20,6 @@ cc_library(
cc_library(
name = "periodic_function",
- visibility = ["//visibility:public"],
deps = [
":periodic_function_dynamic",
"//tensorflow/core:lib",
@@ -190,7 +188,6 @@ cc_library(
testonly = 1,
srcs = ["fake_clock_env.cc"],
hdrs = ["fake_clock_env.h"],
- visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow",
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index e0da91125b..82e2913b64 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -143,6 +143,7 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
c->forward_input_or_allocate_output(
{0}, 0, c->input(0).shape(), &output),
done);
+ col_params_.instance.shape = c->input(0).shape();
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
auto actual_done = [c, col_exec, done](const Status& s) {
@@ -171,7 +172,7 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
- OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
col_params_.is_source = true;
col_params_.instance.impl_details.subdiv_offsets = {0};
@@ -195,13 +196,14 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
if (c->mutable_output(0) == nullptr) {
// Allocate the output tensor, trying to reuse the input.
Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(
- c, c->forward_input_or_allocate_output({0}, 0, shape_, &output),
- done);
+ OP_REQUIRES_OK_ASYNC(c,
+ c->forward_input_or_allocate_output(
+ {0}, 0, col_params_.instance.shape, &output),
+ done);
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
OP_REQUIRES_ASYNC(
- c, shape_.IsSameSize(c->input(0).shape()),
+ c, col_params_.instance.shape.IsSameSize(c->input(0).shape()),
errors::Internal("Declared shape of op ", col_params_.name,
" does not match shape of input"),
done);
@@ -214,8 +216,6 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
}
private:
- TensorShape shape_;
-
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
};
@@ -234,7 +234,7 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
- OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
col_params_.is_source = false;
col_params_.instance.impl_details.subdiv_offsets = {0};
@@ -258,7 +258,8 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
if (c->mutable_output(0) == nullptr) {
// No input, so must allocate output.
Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done);
+ OP_REQUIRES_OK_ASYNC(
+ c, c->allocate_output(0, col_params_.instance.shape, &output), done);
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
@@ -270,8 +271,6 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
}
private:
- TensorShape shape_;
-
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
};
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 717a9f40a9..78856c4a99 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -264,150 +264,198 @@ class LaunchXsmmConvOp<CPUDevice, float> {
};
#endif
+#define TF_REQUIRES(EXP, STATUS) \
+ do { \
+ if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \
+ } while (false)
+
+Status InitConv2DParameters(const OpKernelConstruction* context,
+ Conv2DParameters* params) {
+ TF_RETURN_IF_ERROR(context->GetAttr("dilations", &params->dilations));
+ TF_RETURN_IF_ERROR(context->GetAttr("strides", &params->strides));
+ TF_RETURN_IF_ERROR(context->GetAttr("padding", &params->padding));
+ string data_format_string;
+ TF_RETURN_IF_ERROR(context->GetAttr("data_format", &data_format_string));
+ TF_REQUIRES(FormatFromString(data_format_string, &params->data_format),
+ errors::InvalidArgument("Invalid data format"));
+
+ const auto& strides = params->strides;
+ const auto& dilations = params->dilations;
+ const auto& data_format = params->data_format;
+
+ TF_REQUIRES(dilations.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ TF_REQUIRES(strides.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ const int64 stride_n = GetTensorDim(strides, data_format, 'N');
+ const int64 stride_c = GetTensorDim(strides, data_format, 'C');
+ const int64 stride_h = GetTensorDim(strides, data_format, 'H');
+ const int64 stride_w = GetTensorDim(strides, data_format, 'W');
+ TF_REQUIRES(
+ stride_n == 1 && stride_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ TF_REQUIRES(stride_h > 0 && stride_w > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
+
+ const int64 dilation_n = GetTensorDim(dilations, data_format, 'N');
+ const int64 dilation_c = GetTensorDim(dilations, data_format, 'C');
+ const int64 dilation_h = GetTensorDim(dilations, data_format, 'H');
+ const int64 dilation_w = GetTensorDim(dilations, data_format, 'W');
+ TF_REQUIRES(
+ dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ TF_REQUIRES(
+ dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ return Status::OK();
+}
+
+Status ComputeConv2DDimension(const Conv2DParameters& params,
+ const Tensor& input, const Tensor& filter,
+ Conv2DDimensions* dimensions) {
+ // Check that 2D convolution input and filter have exactly 4 dimensions.
+ TF_REQUIRES(input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().DebugString()));
+ TF_REQUIRES(filter.dims() == 4,
+ errors::InvalidArgument("filter must be 4-dimensional: ",
+ filter.shape().DebugString()));
+ for (int i = 0; i < 3; i++) {
+ TF_REQUIRES(
+ FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
+ errors::InvalidArgument("filter too large"));
+ }
+
+  // The last dimension for input is in_depth. Check that it matches the
+  // filter's in_depth or is evenly divisible by it.
+ const int64 in_depth_raw = GetTensorDim(input, params.data_format, 'C');
+ const int64 patch_depth_raw = filter.dim_size(2);
+ TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input depth too large"));
+ TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Patch depth too large"));
+ const int in_depth = static_cast<int>(in_depth_raw);
+ const int patch_depth = static_cast<int>(patch_depth_raw);
+ TF_REQUIRES(in_depth % patch_depth == 0,
+ errors::InvalidArgument(
+ "input depth must be evenly divisible by filter depth: ",
+ in_depth, " vs ", patch_depth));
+
+ // The last dimension for filter is out_depth.
+ const int out_depth = static_cast<int>(filter.dim_size(3));
+
+ // The second dimension for input is rows/height.
+ // The first dimension for filter is rows/height.
+ const int64 input_rows_raw = GetTensorDim(input, params.data_format, 'H');
+ TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input rows too large"));
+ const int input_rows = static_cast<int>(input_rows_raw);
+ const int filter_rows = static_cast<int>(filter.dim_size(0));
+
+ // The third dimension for input is columns/width.
+ // The second dimension for filter is columns/width.
+ const int64 input_cols_raw = GetTensorDim(input, params.data_format, 'W');
+ TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input cols too large"));
+ const int input_cols = static_cast<int>(input_cols_raw);
+ const int filter_cols = static_cast<int>(filter.dim_size(1));
+
+ // The first dimension for input is batch.
+ const int64 batch_raw = GetTensorDim(input, params.data_format, 'N');
+ TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("batch is too large"));
+ const int batch = static_cast<int>(batch_raw);
+
+ // Take the stride and dilation from the second and third dimensions only (we
+ // do not support striding or dilation on the batch or depth dimension).
+ const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H');
+ const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W');
+ const int dilation_rows =
+ GetTensorDim(params.dilations, params.data_format, 'H');
+ const int dilation_cols =
+ GetTensorDim(params.dilations, params.data_format, 'W');
+
+ // Compute windowed output sizes for rows and columns.
+ int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
+ input_rows, filter_rows, dilation_rows, stride_rows, params.padding,
+ &out_rows, &pad_rows));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
+ input_cols, filter_cols, dilation_cols, stride_cols, params.padding,
+ &out_cols, &pad_cols));
+
+ dimensions->batch = batch;
+ dimensions->input_rows = input_rows;
+ dimensions->input_cols = input_cols;
+ dimensions->in_depth = in_depth;
+ dimensions->filter_rows = filter_rows;
+ dimensions->filter_cols = filter_cols;
+ dimensions->patch_depth = patch_depth;
+ dimensions->out_depth = out_depth;
+ dimensions->stride_rows = stride_rows;
+ dimensions->stride_cols = stride_cols;
+ dimensions->dilation_rows = dilation_rows;
+ dimensions->dilation_cols = dilation_cols;
+ dimensions->out_rows = out_rows;
+ dimensions->out_cols = out_cols;
+ dimensions->pad_rows = pad_rows;
+ dimensions->pad_cols = pad_cols;
+
+ return Status::OK();
+}
+
+#undef TF_REQUIRES
+
template <typename Device, typename T>
class Conv2DOp : public BinaryOp<T> {
public:
explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
- OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
- string data_format;
- OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
- OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
+
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- OP_REQUIRES(context, strides_.size() == 4,
- errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
- const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
- const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
- const int64 stride_h = GetTensorDim(strides_, data_format_, 'H');
- const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
- OP_REQUIRES(
- context, stride_n == 1 && stride_c == 1,
- errors::InvalidArgument("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
- OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
- errors::InvalidArgument(
- "Row and column strides should be larger than 0."));
-
- const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
- OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
void Compute(OpKernelContext* context) override {
// Input tensor is of the following dimensions:
// [ batch, in_rows, in_cols, in_depth ]
-
const Tensor& input = context->input(0);
// Input filter is of the following dimensions:
// [ filter_rows, filter_cols, in_depth, out_depth]
const Tensor& filter = context->input(1);
- // For 2D convolution, there should be 4 dimensions.
- OP_REQUIRES(context, input.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input.shape().DebugString()));
- OP_REQUIRES(context, filter.dims() == 4,
- errors::InvalidArgument("filter must be 4-dimensional: ",
- filter.shape().DebugString()));
-
- for (int i = 0; i < 3; i++) {
- OP_REQUIRES(
- context,
- FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
- errors::InvalidArgument("filter too large"));
- }
+ Conv2DDimensions dimensions;
+ OP_REQUIRES_OK(context,
+ ComputeConv2DDimension(params_, input, filter, &dimensions));
- // The last dimension for input is in_depth. It must be the same as the
- // filter's in_depth or be evenly divisible by filter's in_depth.
- const int64 in_depth = GetTensorDim(input, data_format_, 'C');
- const int64 patch_depth = filter.dim_size(2);
- OP_REQUIRES(context, in_depth % patch_depth == 0,
- errors::InvalidArgument(
- "input depth must be evenly divisible by filter depth: ",
- in_depth, " vs ", patch_depth));
-
- // The last dimension for filter is out_depth.
- const int out_depth = static_cast<int>(filter.dim_size(3));
-
- // The second dimension for input is rows/height.
- // The first dimension for filter is rows/height.
- const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input rows too large"));
- const int input_rows = static_cast<int>(input_rows_raw);
- const int filter_rows = static_cast<int>(filter.dim_size(0));
-
- // The third dimension for input is columns/width.
- // The second dimension for filter is columns/width.
- const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input cols too large"));
- const int input_cols = static_cast<int>(input_cols_raw);
- const int filter_cols = static_cast<int>(filter.dim_size(1));
-
- // The first dimension for input is batch.
- const int64 batch_raw = GetTensorDim(input, data_format_, 'N');
- OP_REQUIRES(context,
- FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("batch is too large"));
- const int batch = static_cast<int>(batch_raw);
-
- // For now we take the stride and dilation from the second and third
- // dimensions only (we do not support striding or dilation on the batch or
- // depth dimension).
- const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
- const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
-
- const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
- const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
-
- int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
- OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
- input_rows, filter_rows, dilation_rows,
- stride_rows, padding_, &out_rows, &pad_rows));
- OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
- input_cols, filter_cols, dilation_cols,
- stride_cols, padding_, &out_cols, &pad_cols));
- TensorShape out_shape =
- ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
+ TensorShape out_shape = ShapeFromFormat(
+ params_.data_format, dimensions.batch, dimensions.out_rows,
+ dimensions.out_cols, dimensions.out_depth);
// Output tensor is of the following dimensions:
// [ in_batch, out_rows, out_cols, out_depth ]
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
- VLOG(2) << "Conv2D: in_depth = " << in_depth
- << ", patch_depth = " << patch_depth
- << ", input_cols = " << input_cols
- << ", filter_cols = " << filter_cols
- << ", input_rows = " << input_rows
- << ", filter_rows = " << filter_rows
- << ", stride_rows = " << stride_rows
- << ", stride_cols = " << stride_cols
- << ", dilation_rows = " << dilation_rows
- << ", dilation_cols = " << dilation_cols
- << ", out_depth = " << out_depth;
+ VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth
+ << ", patch_depth = " << dimensions.patch_depth
+ << ", input_cols = " << dimensions.input_cols
+ << ", filter_cols = " << dimensions.filter_cols
+ << ", input_rows = " << dimensions.input_rows
+ << ", filter_rows = " << dimensions.filter_rows
+ << ", stride_rows = " << dimensions.stride_rows
+ << ", stride_cols = " << dimensions.stride_cols
+ << ", dilation_rows = " << dimensions.dilation_rows
+ << ", dilation_cols = " << dimensions.dilation_cols
+ << ", out_depth = " << dimensions.out_depth;
// If there is nothing to compute, return.
if (out_shape.num_elements() == 0) {
@@ -416,36 +464,41 @@ class Conv2DOp : public BinaryOp<T> {
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
if (LaunchXsmmConvOp<Device, T>::Run(
- context, input, filter, batch, input_rows, input_cols, in_depth,
- filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
- output, data_format_)) {
+ context, input, filter, dimensions.batch, dimensions.input_rows,
+ dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
+ dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols,
+ dimensions.out_rows, dimensions.out_cols, dimensions.out_depth,
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, output,
+ params_.data_format)) {
return;
}
#endif
if (LaunchDeepConvOp<Device, T>::Run(
- context, input, filter, batch, input_rows, input_cols, in_depth,
- filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
- output, data_format_)) {
+ context, input, filter, dimensions.batch, dimensions.input_rows,
+ dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
+ dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols,
+ dimensions.out_rows, dimensions.out_cols, dimensions.out_depth,
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, output,
+ params_.data_format)) {
return;
}
launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
- output, data_format_);
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, params_.padding,
+ output, params_.data_format);
}
private:
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
+ Conv2DParameters params_;
bool use_cudnn_;
- Padding padding_;
- TensorFormat data_format_;
- LaunchConv2DOp<Device, T> launcher_;
bool cudnn_use_autotune_;
+ LaunchConv2DOp<Device, T> launcher_;
+
TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
};
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
index adf4601b43..7ec878e0b2 100644
--- a/tensorflow/core/kernels/conv_ops.h
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -66,6 +66,50 @@ struct Im2ColBufferResource : public ResourceBase {
string DebugString() { return "Im2ColBufferResource"; }
};
+// Convolution parameters specified by Op attributes.
+struct Conv2DParameters {
+ std::vector<int32> dilations;
+ std::vector<int32> strides;
+ Padding padding;
+ TensorFormat data_format;
+};
+
+// Convolution dimensions inferred from parameters, input and filter tensors.
+struct Conv2DDimensions {
+ int batch;
+ int input_rows;
+ int input_cols;
+ int in_depth;
+
+ int filter_rows;
+ int filter_cols;
+ int patch_depth;
+ int out_depth;
+
+ int stride_rows;
+ int stride_cols;
+
+ int dilation_rows;
+ int dilation_cols;
+
+ int64 out_rows;
+ int64 out_cols;
+ int64 pad_rows;
+ int64 pad_cols;
+};
+
+// Initializes and validates Conv2D parameters configured by OpKernel
+// attributes.
+Status InitConv2DParameters(const OpKernelConstruction* context,
+ Conv2DParameters* params);
+
+// Computes and validates convolution dimensions from the Conv2D parameters.
+// If the parameters are valid, `dimensions` is updated with the derived
+// convolution dimensions; otherwise an error is returned.
+Status ComputeConv2DDimension(const Conv2DParameters& params,
+ const Tensor& input, const Tensor& filter,
+ Conv2DDimensions* dimensions);
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_H_
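A minimal sketch of how a kernel consumes the two helpers declared above, mirroring the Conv2DOp changes in conv_ops.cc. The kernel name MyConv2DLikeOp is hypothetical, and the sketch assumes it is compiled inside TensorFlow so the included headers resolve:

// A hypothetical kernel (MyConv2DLikeOp is illustrative, not part of this
// change) showing the intended two-phase use: validate attributes once at
// construction, derive per-call dimensions in Compute().
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/util/tensor_format.h"

namespace tensorflow {

class MyConv2DLikeOp : public OpKernel {
 public:
  explicit MyConv2DLikeOp(OpKernelConstruction* context) : OpKernel(context) {
    // Parses and validates the dilations/strides/padding/data_format attrs.
    OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
  }

  void Compute(OpKernelContext* context) override {
    const Tensor& input = context->input(0);
    const Tensor& filter = context->input(1);

    // Derives batch/spatial/depth sizes plus windowed output sizes, with all
    // shape validation centralized in one helper.
    Conv2DDimensions dimensions;
    OP_REQUIRES_OK(
        context, ComputeConv2DDimension(params_, input, filter, &dimensions));

    TensorShape out_shape = ShapeFromFormat(
        params_.data_format, dimensions.batch, dimensions.out_rows,
        dimensions.out_cols, dimensions.out_depth);
    Tensor* output = nullptr;
    OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
    // ... launch the actual convolution here ...
  }

 private:
  Conv2DParameters params_;
};

}  // namespace tensorflow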
diff --git a/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc
new file mode 100644
index 0000000000..e4b21a66c6
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY5(xdivy, Eigen::half, float, double, complex64, complex128);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc
new file mode 100644
index 0000000000..1e1b5a426e
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY5(xlogy, Eigen::half, float, double, complex64, complex128);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_xdivy.cc b/tensorflow/core/kernels/cwise_op_xdivy.cc
new file mode 100644
index 0000000000..6a6aec5e86
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_xdivy.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Xdivy", functor::xdivy, float, Eigen::half, double,
+ complex64, complex128);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Xdivy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
+ BinaryOp<SYCLDevice, functor::xdivy<TYPE>>);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+#undef REGISTER_SYCL_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER5(BinaryOp, GPU, "Xdivy", functor::xdivy, float, Eigen::half, double,
+ complex64, complex128);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_xlogy.cc b/tensorflow/core/kernels/cwise_op_xlogy.cc
new file mode 100644
index 0000000000..e71a9109b2
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_xlogy.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Xlogy", functor::xlogy, float, Eigen::half, double,
+ complex64, complex128);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Xlogy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
+ BinaryOp<SYCLDevice, functor::xlogy<TYPE>>);
+REGISTER_SYCL_KERNEL(Eigen::half);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+REGISTER_SYCL_KERNEL(complex64);
+REGISTER_SYCL_KERNEL(complex128);
+#undef REGISTER_SYCL_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER5(BinaryOp, GPU, "Xlogy", functor::xlogy, float, Eigen::half, double,
+ complex64, complex128);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 22eb66e979..66ba827a90 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -471,6 +471,45 @@ struct functor_traits<bitwise_xor_op<Scalar>> {
enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true };
};
+// TODO(srvasude): Add packet versions of this operation.
+template <typename Scalar>
+struct xlogy_op {
+ EIGEN_EMPTY_STRUCT_CTOR(xlogy_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
+ operator()(const Scalar& x, const Scalar& y) const {
+ if (x == Scalar(0.)) {
+ return Scalar(0.);
+ }
+ return x * numext::log(y);
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<xlogy_op<Scalar>> {
+ enum {
+ Cost = (sizeof(Scalar) == 4 ? 40 : 85) + Eigen::NumTraits<Scalar>::MulCost,
+ PacketAccess = false
+ };
+};
+
+// TODO(srvasude): Add packet versions of this operation.
+template <typename Scalar>
+struct xdivy_op {
+ EIGEN_EMPTY_STRUCT_CTOR(xdivy_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
+ operator()(const Scalar& x, const Scalar& y) const {
+ if (x == Scalar(0.)) {
+ return Scalar(0.);
+ }
+ return x / y;
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<xdivy_op<Scalar>> {
+ enum { Cost = Eigen::NumTraits<Scalar>::MulCost, PacketAccess = false };
+};
+
} // end namespace internal
} // end namespace Eigen
@@ -830,6 +869,12 @@ struct squared_difference
Eigen::internal::scalar_difference_op<T>>> {};
template <typename T>
+struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {};
+
+template <typename T>
+struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {};
+
+template <typename T>
struct less : base<T, Eigen::internal::less<T>, bool> {};
template <typename T>
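The x == 0 short-circuit above is the point of these ops: xlogy(0, y) and xdivy(0, y) are defined to be 0 even where 0 * log(y) or 0 / y would evaluate to NaN. A standalone sketch of the same contract using plain functions (not the Eigen functors themselves):

#include <cmath>
#include <cstdio>

// Same contract as the xlogy_op/xdivy_op functors above: the x == 0 branch
// short-circuits before log(y) or the division is evaluated.
double Xlogy(double x, double y) { return x == 0.0 ? 0.0 : x * std::log(y); }
double Xdivy(double x, double y) { return x == 0.0 ? 0.0 : x / y; }

int main() {
  std::printf("%f\n", Xlogy(0.0, 0.0));            // 0, while 0 * log(0) is NaN
  std::printf("%f\n", Xdivy(0.0, 0.0));            // 0, while 0 / 0 is NaN
  std::printf("%f\n", Xlogy(2.0, std::exp(1.0)));  // 2 * log(e) == 2
}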
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
index 980edffceb..8ad3b4d1fc 100644
--- a/tensorflow/core/kernels/cwise_ops_common.cc
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -20,9 +20,9 @@ namespace tensorflow {
BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out,
DataType in)
: OpKernel(ctx) {
-#ifndef INTEL_MKL
+#if !defined(INTEL_MKL) || !defined(ENABLE_MKL)
OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out}));
-#endif
+#endif // !INTEL_MKL || !ENABLE_MKL
}
void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) {
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 87efdff789..37c1c54786 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -45,6 +45,16 @@ cc_library(
],
)
+tf_cc_test(
+ name = "dataset_utils_test",
+ srcs = ["dataset_utils_test.cc"],
+ deps = [
+ ":dataset_utils",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
cc_library(
name = "captured_function",
srcs = ["captured_function.cc"],
@@ -205,6 +215,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -232,6 +243,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -245,6 +257,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
@@ -285,6 +298,7 @@ tf_kernel_library(
deps = [
":captured_function",
":dataset",
+ ":dataset_utils",
":parallel_map_iterator",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:dataset_ops_op_lib",
@@ -458,6 +472,7 @@ tf_kernel_library(
srcs = ["stats_aggregator_dataset_op.cc"],
deps = [
":dataset",
+ "//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
],
@@ -765,6 +780,7 @@ tf_kernel_library(
":window_dataset_op",
":writer_ops",
":zip_dataset_op",
+ "//tensorflow/core/kernels/data/experimental:dataset_kernels",
],
)
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index a04f150e71..9607e9444c 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -171,16 +171,16 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
static PartialTensorShape MostSpecificCompatibleShape(
const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
- PartialTensorShape output_tensorshape;
if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
- return output_tensorshape;
+ return PartialTensorShape();
+ PartialTensorShape output_tensorshape({});
auto dims1 = ts1.dim_sizes();
auto dims2 = ts2.dim_sizes();
for (int d = 0; d < ts1.dims(); d++) {
if (dims1[d] == dims2[d])
- output_tensorshape.Concatenate(dims1[d]);
+ output_tensorshape.AddDim(dims1[d]);
else
- output_tensorshape.Concatenate(-1);
+ output_tensorshape.AddDim(-1);
}
return output_tensorshape;
}
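Concretely, merging [2, 3] with [2, 4] under the fixed function yields [2, -1], and merging shapes of different rank yields an unknown-rank shape. A standalone sketch of the same merge rule over plain dimension vectors (PartialTensorShape is elided; an empty vector stands in for unknown rank):

#include <cstdio>
#include <vector>

// Same rule as MostSpecificCompatibleShape above, over plain dim vectors.
std::vector<long long> MostSpecificCompatible(
    const std::vector<long long>& a, const std::vector<long long>& b) {
  if (a.size() != b.size()) return {};  // rank mismatch -> unknown rank
  std::vector<long long> out;
  for (size_t d = 0; d < a.size(); ++d)
    out.push_back(a[d] == b[d] ? a[d] : -1);  // -1 marks an unknown dim
  return out;
}

int main() {
  for (long long d : MostSpecificCompatible({2, 3}, {2, 4}))
    std::printf("%lld ", d);  // prints: 2 -1
}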
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index e10833f525..a40f7f2146 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -15,10 +15,57 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
namespace tensorflow {
namespace data {
+Status ComputeShortCircuitIndices(OpKernelContext* ctx,
+ const NameAttrList& func,
+ std::vector<int>* indices) {
+ FunctionLibraryRuntime::Handle fn_handle;
+ TF_RETURN_IF_ERROR(ctx->function_library()->Instantiate(
+ func.name(), AttrSlice(&func.attr()), &fn_handle));
+ auto cleanup = gtl::MakeCleanup([ctx, fn_handle]() {
+ Status s = ctx->function_library()->ReleaseHandle(fn_handle);
+ if (!s.ok()) {
+ LOG(WARNING) << "Failed to release handle: " << s.error_message();
+ }
+ });
+
+ const FunctionBody* fn_body =
+ ctx->function_library()->GetFunctionBody(fn_handle);
+ indices->resize(fn_body->ret_nodes.size());
+ for (size_t i = 0; i < fn_body->ret_nodes.size(); ++i) {
+ Node* ret_node = fn_body->ret_nodes[i];
+ Node* ret_input_node;
+ TF_RETURN_IF_ERROR(ret_node->input_node(0, &ret_input_node));
+ if (ret_input_node->def().op() == FunctionLibraryDefinition::kArgOp) {
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(ret_input_node->def(), "index", &((*indices)[i])));
+ } else {
+ indices->clear();
+ break;
+ }
+ }
+ return Status::OK();
+}
+
+std::vector<bool> ComputeMoveVector(const std::vector<int>& indices) {
+ std::map<int, int> last_use;
+ for (size_t i = 0; i < indices.size(); ++i) {
+ last_use[indices[i]] = i;
+ }
+ std::vector<bool> can_move;
+ can_move.resize(indices.size());
+ for (size_t i = 0; i < indices.size(); ++i) {
+ can_move[i] = last_use[indices[i]] == i;
+ }
+ return can_move;
+}
+
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 6ec1350cd4..d777062293 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -22,6 +22,26 @@ limitations under the License.
namespace tensorflow {
namespace data {
+// This method is used to determine whether we can short-circuit the evaluation
+// of the user-defined function `func`. Short-circuiting is possible if every
+// function output corresponds to one of its inputs (e.g. `f(x) = x`, `f(x,y) =
+// (y,x)`, or `f(x) = (x,x)`).
+//
+// If short-circuiting is possible, the method stores the mapping from output
+// indices to input indices in `indices`. Otherwise, `indices` will be empty.
+//
+// Returns non-ok status if analysis of the function fails.
+//
+// TODO(jsimsa): Extend this to support constants as well.
+Status ComputeShortCircuitIndices(OpKernelContext* ctx,
+ const NameAttrList& func,
+ std::vector<int>* indices);
+
+// Given a vector that maps output indices to input indices, returns a vector
+// that identifies, for each output index, whether the input can be moved to
+// that output (assuming output indices are processed left to right).
+std::vector<bool> ComputeMoveVector(const std::vector<int>& indices);
+
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
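To make the contract concrete: for f(x, y) = (y, x, x) the short-circuit indices are {1, 0, 0}, and input 0 is used by outputs 1 and 2, so only its last use (output 2) may take the input by move. A standalone sketch mirroring ComputeMoveVector:

#include <cstdio>
#include <map>
#include <vector>

// Mirrors ComputeMoveVector above: an input may be moved only at the output
// position that is its last use, assuming outputs are produced left to right.
std::vector<bool> MoveVector(const std::vector<int>& indices) {
  std::map<int, size_t> last_use;
  for (size_t i = 0; i < indices.size(); ++i) last_use[indices[i]] = i;
  std::vector<bool> can_move(indices.size());
  for (size_t i = 0; i < indices.size(); ++i)
    can_move[i] = (last_use[indices[i]] == i);
  return can_move;
}

int main() {
  for (bool b : MoveVector({1, 0, 0}))
    std::printf("%d ", int(b));  // prints: 1 0 1
}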
diff --git a/tensorflow/core/kernels/data/dataset_utils_test.cc b/tensorflow/core/kernels/data/dataset_utils_test.cc
new file mode 100644
index 0000000000..43295b8ebb
--- /dev/null
+++ b/tensorflow/core/kernels/data/dataset_utils_test.cc
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+TEST(DatasetUtils, ComputeMoveVector) {
+ struct TestCase {
+ std::vector<int> indices;
+ std::vector<bool> expected;
+ };
+
+ TestCase test_cases[] = {
+ TestCase{{}, {}},
+ TestCase{{1}, {true}},
+ TestCase{{1, 1}, {false, true}},
+ TestCase{{1, 2}, {true, true}},
+ TestCase{{1, 1, 2}, {false, true, true}},
+ TestCase{{1, 2, 2}, {true, false, true}},
+ };
+
+ for (auto& test_case : test_cases) {
+ EXPECT_EQ(test_case.expected, ComputeMoveVector(test_case.indices));
+ }
+}
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
new file mode 100644
index 0000000000..43406db3ed
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -0,0 +1,139 @@
+# Description:
+# Contains experimental kernels for datasets and iterators.
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_kernel_library",
+)
+
+cc_library(
+ name = "indexed_dataset_headers",
+ hdrs = ["indexed_dataset.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "indexed_dataset",
+ srcs = [
+ "identity_indexed_dataset.cc",
+ "indexed_dataset.cc",
+ ],
+ deps = [
+ ":indexed_dataset_headers",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "prefetching_kernels",
+ srcs = ["prefetching_kernels.cc"],
+ deps = [
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
+ name = "directed_interleave_dataset_op",
+ srcs = ["directed_interleave_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "csv_dataset_op",
+ srcs = ["csv_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
+ name = "ignore_errors_dataset_op",
+ srcs = ["ignore_errors_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "lmdb_dataset_op",
+ srcs = ["lmdb_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ "@lmdb",
+ ],
+)
+
+tf_kernel_library(
+ name = "threadpool_dataset_op",
+ srcs = ["threadpool_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "unique_dataset_op",
+ srcs = ["unique_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "assert_next_dataset_op",
+ srcs = ["assert_next_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "dataset_kernels",
+ deps = [
+ ":assert_next_dataset_op",
+ ":csv_dataset_op",
+ ":directed_interleave_dataset_op",
+ ":ignore_errors_dataset_op",
+ ":indexed_dataset",
+ ":lmdb_dataset_op",
+ ":prefetching_kernels",
+ ":threadpool_dataset_op",
+ ":unique_dataset_op",
+ ],
+)
diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
new file mode 100644
index 0000000000..3511cca0f5
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
@@ -0,0 +1,156 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <map>
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+class AssertNextDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit AssertNextDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ std::vector<string> transformations;
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "transformations",
+ &transformations));
+ *output =
+ new Dataset(ctx, input, transformations, output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const std::vector<string>& transformations,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ transformations_(transformations),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Assert")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "AssertNextDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* transformations_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, transformations_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ std::vector<string> tokens =
+ str_util::Split(prefix(), ':', str_util::SkipEmpty());
+ if (dataset()->transformations_.size() > tokens.size() - 2) {
+ return errors::InvalidArgument(
+ "Asserted next ", dataset()->transformations_.size(),
+ " transformations but encountered only ", tokens.size() - 2, ".");
+ }
+ int n = tokens.size();
+ for (size_t i = 0; i < dataset()->transformations_.size(); ++i) {
+ if (dataset()->transformations_[i] != tokens[n - 2 - i]) {
+ return errors::InvalidArgument(
+ "Asserted ", dataset()->transformations_[i],
+ " transformation at offset ", i, " but encountered ",
+ tokens[n - 2 - i], " transformation instead.");
+ }
+ }
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ std::unique_ptr<IteratorBase> input_impl_;
+ };
+
+ const DatasetBase* input_;
+ const std::vector<string> transformations_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalAssertNextDataset").Device(DEVICE_CPU),
+ AssertNextDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
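Iterator::Initialize() above matches the asserted transformations against the iterator's prefix from right to left, skipping the trailing token that names this op. A standalone sketch of that check (the prefix string in main is a hypothetical example of an iterator prefix):

#include <cstdio>
#include <sstream>
#include <string>
#include <vector>

// Mirrors the check in Iterator::Initialize() above, using std::getline in
// place of str_util::Split.
bool AssertedNextMatches(const std::string& prefix,
                         const std::vector<std::string>& transformations) {
  std::vector<std::string> tokens;
  std::stringstream ss(prefix);
  std::string tok;
  while (std::getline(ss, tok, ':'))
    if (!tok.empty()) tokens.push_back(tok);  // SkipEmpty drops "::" gaps
  if (transformations.size() > tokens.size() - 2) return false;
  int n = tokens.size();
  for (size_t i = 0; i < transformations.size(); ++i)
    if (transformations[i] != tokens[n - 2 - i]) return false;
  return true;
}

int main() {
  // tokens: {Iterator, Range, Map, Assert}; tokens[n - 2] == "Map".
  std::printf("%d\n", AssertedNextMatches("Iterator::Range::Map::Assert",
                                          {"Map", "Range"}));  // prints: 1
}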
diff --git a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
new file mode 100644
index 0000000000..7451ca4cb1
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
@@ -0,0 +1,860 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/parsing_ops.cc.
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/io/inputstream_interface.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
+#include "tensorflow/core/lib/io/zlib_inputstream.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class CSVDatasetOp : public DatasetOpKernel {
+ public:
+ explicit CSVDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ const Tensor* filenames_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
+ OP_REQUIRES(
+ ctx, filenames_tensor->dims() <= 1,
+ errors::InvalidArgument("`filenames` must be a scalar or a vector."));
+
+ string compression_type;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
+ &compression_type));
+
+ OpInputList record_defaults_list;
+ OP_REQUIRES_OK(ctx,
+ ctx->input_list("record_defaults", &record_defaults_list));
+ for (int i = 0; i < record_defaults_list.size(); ++i) {
+ OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1,
+ errors::InvalidArgument(
+ "Each record default should be at most rank 1"));
+ OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2,
+ errors::InvalidArgument(
+ "There should only be 1 default per field but field ", i,
+ " has ", record_defaults_list[i].NumElements()));
+ }
+
+ const Tensor* select_cols_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("select_cols", &select_cols_tensor));
+ OP_REQUIRES(ctx, select_cols_tensor->dims() == 1,
+ errors::InvalidArgument("`select_cols` must be a vector."));
+
+ int64 buffer_size;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
+ OP_REQUIRES(ctx, buffer_size > 0,
+ errors::InvalidArgument("buffer_size should be positive"));
+
+ string delim;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "field_delim", &delim));
+ OP_REQUIRES(ctx, delim.size() == 1,
+ errors::InvalidArgument("field_delim should be only 1 char"));
+
+ bool header;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "header", &header));
+
+ bool use_quote_delim;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "use_quote_delim",
+ &use_quote_delim));
+ string na_value;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "na_value", &na_value));
+
+ std::vector<Tensor> record_defaults;
+ record_defaults.reserve(record_defaults_list.size());
+ for (const Tensor& t : record_defaults_list) {
+ record_defaults.push_back(t);
+ }
+
+ std::vector<string> filenames;
+ filenames.reserve(filenames_tensor->NumElements());
+ for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
+ filenames.push_back(filenames_tensor->flat<string>()(i));
+ }
+
+ io::ZlibCompressionOptions zlib_compression_options =
+ io::ZlibCompressionOptions::DEFAULT();
+ if (compression_type == "ZLIB") {
+ zlib_compression_options = io::ZlibCompressionOptions::DEFAULT();
+ } else if (compression_type == "GZIP") {
+ zlib_compression_options = io::ZlibCompressionOptions::GZIP();
+ } else {
+ OP_REQUIRES(ctx, compression_type.empty(),
+ errors::InvalidArgument(
+ "Unsupported compression_type: ", compression_type, "."));
+ }
+ zlib_compression_options.input_buffer_size = buffer_size;
+
+ std::vector<int64> select_cols;
+ select_cols.reserve(select_cols_tensor->NumElements());
+ for (int i = 0; i < select_cols_tensor->NumElements(); ++i) {
+ select_cols.push_back(select_cols_tensor->flat<int64>()(i));
+ }
+ OP_REQUIRES(
+ ctx, output_types_.size() == select_cols.size() || select_cols.empty(),
+ errors::InvalidArgument("select_cols should match output size"));
+ for (int i = 1; i < select_cols.size(); i++) {
+ OP_REQUIRES(ctx, select_cols[i - 1] < select_cols[i],
+ errors::InvalidArgument(
+ "select_cols should be strictly increasing indices"));
+ }
+ OP_REQUIRES(
+ ctx, select_cols.empty() || select_cols.front() >= 0,
+ errors::InvalidArgument("select_cols should be non-negative indices"));
+
+ *output = new Dataset(ctx, std::move(filenames), header,
+ std::move(compression_type), zlib_compression_options,
+ output_types_, output_shapes_,
+ std::move(record_defaults), std::move(select_cols),
+ use_quote_delim, delim[0], std::move(na_value));
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, std::vector<string> filenames, bool header,
+ string compression_type, io::ZlibCompressionOptions options,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes,
+ std::vector<Tensor> record_defaults, std::vector<int64> select_cols,
+ bool use_quote_delim, char delim, string na_value)
+ : DatasetBase(DatasetContext(ctx)),
+ filenames_(std::move(filenames)),
+ header_(header),
+ out_type_(output_types),
+ output_shapes_(output_shapes),
+ record_defaults_(std::move(record_defaults)),
+ select_cols_(std::move(select_cols)),
+ use_quote_delim_(use_quote_delim),
+ delim_(delim),
+ na_value_(std::move(na_value)),
+ use_compression_(!compression_type.empty()),
+ compression_type_(std::move(compression_type)),
+ options_(options) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::CSV")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override { return out_type_; }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override { return "CSVDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* filenames = nullptr;
+ Node* compression_type = nullptr;
+ Node* buffer_size = nullptr;
+ Node* header = nullptr;
+ Node* delim = nullptr;
+ Node* use_quote_delim = nullptr;
+ Node* na_value = nullptr;
+ Node* select_cols = nullptr;
+
+ std::vector<Node*> record_defaults;
+ record_defaults.reserve(record_defaults_.size());
+ for (const Tensor& t : record_defaults_) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ record_defaults.emplace_back(node);
+ }
+
+ TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+ TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
+ TF_RETURN_IF_ERROR(
+ b->AddScalar(options_.input_buffer_size, &buffer_size));
+ TF_RETURN_IF_ERROR(b->AddScalar(header_, &header));
+
+ string delim_string(1, delim_);
+ TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim));
+ TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim));
+ TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value));
+ TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols));
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this,
+ {std::make_pair(0, filenames), std::make_pair(1, compression_type),
+ std::make_pair(2, buffer_size), std::make_pair(3, header),
+ std::make_pair(4, delim), std::make_pair(5, use_quote_delim),
+ std::make_pair(6, na_value),
+ std::make_pair(7, select_cols)}, // Single tensor inputs
+ {std::make_pair(8, record_defaults)}, // Tensor list inputs
+ {}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ bool select_all = dataset()->select_cols_.empty();
+ do {
+ // We are currently processing a file, so try to read the next record
+ if (input_stream_) {
+ Status s = ReadRecord(ctx, out_tensors, select_all,
+ dataset()->select_cols_);
+ if (s.ok()) {
+ // Validate output
+ if (out_tensors->size() != dataset()->out_type_.size()) {
+ return errors::InvalidArgument(
+ "Expect ", dataset()->out_type_.size(), " fields but have ",
+ out_tensors->size(), " in record");
+ }
+
+ *end_of_sequence = false;
+ return s;
+ }
+ if (!errors::IsOutOfRange(s)) {
+ // Not at the end of file, return OK or non-EOF errors to caller.
+ *end_of_sequence = false;
+ return s;
+ }
+ // We have reached the end of the current file, so maybe
+ // move on to next file.
+ ResetStreamsLocked();
+ ++current_file_index_;
+ }
+ // Iteration ends when there are no more files to process.
+ if (current_file_index_ == dataset()->filenames_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
+ current_file_index_));
+ // `input_stream_` is empty if
+ // 1. GetNext has not been called even once.
+ // 2. All files have been read and the iterator has been exhausted.
+ if (input_stream_ && num_buffer_reads_ > 0) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_));
+ // If num_buffer_reads_ == 0, the buffer hasn't been filled even once.
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"),
+ num_buffer_reads_));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ ResetStreamsLocked();
+ int64 current_file_index;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
+ &current_file_index));
+ current_file_index_ = size_t(current_file_index);
+ // The keys "pos" and "num_buffer_reads" are written only if
+ // the iterator was saved with an open, partially read file.
+ if (reader->Contains(full_name("pos"))) {
+ int64 pos, num_buffer_reads;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"),
+ &num_buffer_reads));
+
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+
+ num_buffer_reads_ = size_t(num_buffer_reads - 1);
+
+ // Restores the most recently held buffer
+ Status s = input_stream_->SkipNBytes(
+ num_buffer_reads_ * dataset()->options_.input_buffer_size);
+ if (!s.ok() && !errors::IsOutOfRange(s)) {
+ // An OutOfRange error is expected here when the file size is not an
+ // exact multiple of the buffer size and the last buffer read is
+ // shorter than buffer_size. That case is valid, so only other errors
+ // are surfaced to the caller.
+ return s;
+ }
+
+ Status s2 = FillBuffer(&buffer_);
+ if (!s2.ok() && !errors::IsOutOfRange(s2)) {
+ return s2;
+ }
+ pos_ = size_t(pos);
+ }
+ return Status::OK();
+ }
+
+ private:
+ // Reads an entire CSV row from the input stream, either from the
+ // existing buffer or by filling the buffer as needed. Converts extracted
+ // fields to output tensors as we go.
+ //
+ // When this function is called, pos_ should be the index of the first
+ // character of the record in buffer_, or past the end of the buffer.
+ // Note: ctx and out_tensors are only used in this function
+ // when fields are included in the record.
+ Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+ bool select_all, const std::vector<int64>& selected)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (pos_ >= buffer_.size()) {
+ // At the end of the file, this will return errors::OutOfRange
+ TF_RETURN_IF_ERROR(FillBuffer(&buffer_));
+ pos_ = 0;
+ }
+
+        // Note: a leading '\n' from a "\r\n" linebreak between this and the
+        // previous record has already been consumed by
+        // SkipNewLineIfNecessary() at the end of that record.
+
+ bool end_of_record = false; // Keep track of when we find \n, \r or EOF
+ size_t num_parsed = 0;
+ size_t num_selected_parsed = 0;
+
+ Status result;
+
+ while (!end_of_record) { // Read till we reach \n, \r or EOF
+ bool include =
+ select_all || (num_selected_parsed < selected.size() &&
+ selected[num_selected_parsed] == num_parsed);
+
+ // Don't fail fast, so that the next call to GetNext may still return
+ // a valid record
+ result.Update(
+ ParseOneField(ctx, out_tensors, &end_of_record, include));
+
+ num_parsed++;
+ if (include) num_selected_parsed++;
+ }
+
+ return result;
+ }
+
+ // Parses one field from position pos_ in the buffer. Fields are
+ // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of
+ // the next field.
+ Status ParseOneField(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_record, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (pos_ >= buffer_.size()) {
+          // If we get here, the previous field's end coincided with the end
+          // of the buffer, so we can safely discard its contents and refill.
+ Status s = FillBuffer(&buffer_);
+
+ if (errors::IsOutOfRange(s)) {
+ // Reached EOF, and last field is empty
+ *end_of_record = true;
+ if (include) {
+ return FieldToOutput(ctx, StringPiece(), out_tensors);
+ } else {
+ return Status::OK();
+ }
+ } else if (!s.ok()) {
+ return s; // Surface other errors back to caller
+ }
+
+ pos_ = 0;
+ }
+
+ if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') {
+ return ParseQuotedField(ctx, out_tensors, end_of_record, include);
+ }
+
+ return ParseUnquotedField(ctx, out_tensors, end_of_record, include);
+ }
+
+ // For keeping track of relevant parts of a field from a previous buffer
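+      // (e.g. a field spanning two buffer fills has its first fragment saved
+      // as a Piece by SaveAndFillBuffer() and recombined with the remainder
+      // in QuotedFieldToOutput() or UnquotedFieldToOutput())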
+ struct Piece {
+ size_t start;
+ size_t len;
+ string buffer;
+
+ Piece(string buffer, size_t start, size_t len)
+ : start(start), len(len), buffer(std::move(buffer)) {}
+ };
+
+ // Given that pos_ exceeds the buffer, saves the relevant part of the
+ // current buffer (if necessary), fills the buffer, and resets indices to
+ // 0.
+ Status SaveAndFillBuffer(std::vector<Piece>* earlier_pieces,
+ size_t* start, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ string temp_buffer;
+
+ buffer_.swap(temp_buffer);
+ if (include && pos_ > *start) {
+ earlier_pieces->push_back(
+ Piece(std::move(temp_buffer), *start, pos_ - *start));
+ }
+ pos_ = 0;
+ *start = 0;
+ return FillBuffer(&buffer_);
+ }
+
+      // Parses a quoted field from position pos_ in the buffer. Continually
+      // reads from the buffer until the end of the field is reached (a
+      // closing quote followed by delim, CRLF, or EOF). Advances pos_ to keep
+      // track of our position in the buffer as we go, stopping at the first
+      // character of the next field.
+ Status ParseQuotedField(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_record, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ std::vector<Piece> earlier_pieces;
+ size_t start = pos_;
+ pos_++; // Starting quotation mark
+
+ Status parse_result;
+ while (true) { // Each iter reads 1 char, filling buffer if necessary
+ if (pos_ >= buffer_.size()) {
+ Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+ if (errors::IsOutOfRange(s)) {
+ return errors::InvalidArgument(
+ "Reached end of file without closing quoted field in "
+ "record");
+ } else if (!s.ok()) {
+ return s; // Surface all other errors to caller
+ }
+ }
+
+ char ch = buffer_[pos_];
+ if (ch == '"') {
+ // When we encounter a quote, we look ahead to the next character to
+ // decide what to do
+ pos_++;
+ if (pos_ >= buffer_.size()) {
+ Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+ if (errors::IsOutOfRange(s)) {
+ // This was the last field. We are done
+ *end_of_record = true;
+ parse_result.Update(QuotedFieldToOutput(
+ ctx, StringPiece(), out_tensors, earlier_pieces, include));
+ return parse_result;
+ } else if (!s.ok()) {
+ return s;
+ }
+ }
+
+ char next = buffer_[pos_];
+ pos_++;
+ if (next == dataset()->delim_) {
+ parse_result.Update(QuotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
+ out_tensors, earlier_pieces, include));
+ return parse_result;
+
+ } else if (next == '\n' || next == '\r') {
+ *end_of_record = true;
+ parse_result.Update(QuotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
+ out_tensors, earlier_pieces, include));
+ if (next == '\r') SkipNewLineIfNecessary();
+ return parse_result;
+ } else if (next != '"') {
+ // Take note of the error, but keep going to end of field.
+ include = false; // So we don't get funky errors when trying to
+ // unescape the quotes.
+ parse_result.Update(errors::InvalidArgument(
+ "Quote inside a string has to be escaped by another quote"));
+ }
+
+ } else {
+ pos_++;
+ }
+ }
+ }
+
+ // Converts quoted field to an output tensor, removing the starting
+ // and ending quotes from it and unescaping double quotations if
+ // necessary.
+ Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field,
+ std::vector<Tensor>* out_tensors,
+ const std::vector<Piece>& earlier_pieces,
+ bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!include) return Status::OK();
+
+ if (earlier_pieces.empty()) {
+ if (field.find('\"', 1) == field.size() - 1) {
+ // `field` contains no escaped quotation marks.
+ // Exclude framing quotation marks
+ field.remove_prefix(1);
+ field.remove_suffix(1);
+ return FieldToOutput(ctx, field, out_tensors);
+ }
+ }
+ string field_complete;
+ size_t str_len = field.size();
+ for (const Piece& p : earlier_pieces) {
+ str_len += p.len;
+ }
+ field_complete.reserve(str_len);
+
+ // This bool flips every time we see a quote, so that we skip the second
+ // quote of every pair of adjacent quotes in the field. We need to track
+ // this across iterations of the for loop because adjacent double quotes
+ // may be in different buffers. Initialize to true because we also skip
+ // the opening quotation mark of the quoted field.
+ bool skip_next_quote = true;
+ for (const Piece& p : earlier_pieces) {
+ AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len),
+ &field_complete, &skip_next_quote);
+ }
+ AppendUnescapedPiece(field, &field_complete, &skip_next_quote);
+ StringPiece result = StringPiece(field_complete);
+ result.remove_suffix(1); // Skip final quote
+
+ return FieldToOutput(ctx, result, out_tensors);
+ }
+
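+      // Appends `piece` to `*field_complete`, collapsing each escaped pair
+      // of double quotes ("" -> ") along the way. `*skip_next_quote` carries
+      // the pairing state across pieces split over buffer boundaries.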
+ void AppendUnescapedPiece(StringPiece piece, string* field_complete,
+ bool* skip_next_quote) {
+ size_t from = 0;
+ size_t found = piece.find('\"', from);
+ while (found != string::npos) {
+ if (!*skip_next_quote) {
+ // This is the first quote in a pair of adjacent double quotes
+ field_complete->append(piece.data() + from, found + 1 - from);
+ }
+ *skip_next_quote = !*skip_next_quote;
+ from = found + 1;
+ found = piece.find('\"', from);
+ }
+ // Include the chunk after the last quotation mark in the string
+ if (from < piece.size()) {
+ field_complete->append(piece.data() + from, piece.size() - from);
+ }
+ }
+
+ // Parses unquoted field from position pos_ in the buffer. Continually
+ // reads from buffer until end of field is reached (delim, CRLF, or EOF).
+ // Advances pos_ to keep track of our position in the buffer as we go,
+ // stopping at the first character of the next field.
+ Status ParseUnquotedField(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_record, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ std::vector<Piece> earlier_pieces;
+ size_t start = pos_;
+ Status parse_result;
+
+ while (true) { // Each iter reads 1 char, filling buffer if necessary
+ if (pos_ >= buffer_.size()) {
+ Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+ // Handle errors
+ if (errors::IsOutOfRange(s)) {
+ // Whatever we have is the last field of the last record
+ *end_of_record = true;
+ parse_result.Update(UnquotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+ earlier_pieces, include));
+ return parse_result;
+ } else if (!s.ok()) {
+ return s; // Surface all other errors to caller
+ }
+ }
+
+ char ch = buffer_[pos_];
+
+ if (ch == dataset()->delim_) {
+ parse_result.Update(UnquotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+ earlier_pieces, include));
+ pos_++;
+ return parse_result;
+ }
+ if (ch == '\n' || ch == '\r') {
+            // If the linebreak is "\r\n", we also skip over the trailing
+            // '\n' so it is not treated as part of the next record.
+ parse_result.Update(UnquotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+ earlier_pieces, include));
+ *end_of_record = true;
+ pos_++;
+ if (ch == '\r') SkipNewLineIfNecessary();
+ return parse_result;
+ }
+ if (dataset()->use_quote_delim_ && ch == '"') {
+ // Take note of the error, but keep going to end of field.
+ parse_result.Update(errors::InvalidArgument(
+ "Unquoted fields cannot have quotes inside"));
+ }
+ // Otherwise, go to next character
+ pos_++;
+ }
+ }
+
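+      // Reads up to input_buffer_size bytes from the underlying stream into
+      // *result. A partial read at the end of a file counts as success;
+      // OutOfRange is surfaced only when nothing could be read.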
+ Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ result->clear();
+ ++num_buffer_reads_;
+ Status s = input_stream_->ReadNBytes(
+ dataset()->options_.input_buffer_size, result);
+
+ if (errors::IsOutOfRange(s) && !result->empty()) {
+ // Ignore OutOfRange error when ReadNBytes read < N bytes.
+ return Status::OK();
+ }
+ return s;
+ }
+
+ // Given a field, converts it to the right output tensor type
+ Status FieldToOutput(IteratorContext* ctx, StringPiece field,
+ std::vector<Tensor>* out_tensors) {
+ size_t output_idx = out_tensors->size();
+ if (output_idx >= dataset()->out_type_.size()) {
+ // We can get here if we're selecting all columns, but the number of
+ // fields exceeds the number of defaults provided
+ return errors::InvalidArgument("Expect ", dataset()->out_type_.size(),
+ " fields but have more in record");
+ }
+ const DataType& dtype = dataset()->out_type_[output_idx];
+ Tensor component(ctx->allocator({}), dtype, {});
+ if ((field.empty() || field == dataset()->na_value_) &&
+ dataset()->record_defaults_[output_idx].NumElements() != 1) {
+ // If the field is empty or NA value, and default is not given,
+ // report error.
+ return errors::InvalidArgument("Field ", output_idx,
+ " is required but missing in record!");
+ }
+
+ switch (dtype) {
+ // For each case, if the field is empty, we use the default.
+ // Otherwise, we convert it to the right type.
+ case DT_INT32: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<int32>()() =
+ dataset()->record_defaults_[output_idx].flat<int32>()(0);
+ } else {
+ int32 value;
+ if (!strings::safe_strto32(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid int32: ", field);
+ }
+ component.scalar<int32>()() = value;
+ }
+ break;
+ }
+ case DT_INT64: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<int64>()() =
+ dataset()->record_defaults_[output_idx].flat<int64>()(0);
+ } else {
+ int64 value;
+ if (!strings::safe_strto64(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid int64: ", field);
+ }
+ component.scalar<int64>()() = value;
+ }
+ break;
+ }
+ case DT_FLOAT: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<float>()() =
+ dataset()->record_defaults_[output_idx].flat<float>()(0);
+ } else {
+ float value;
+ if (!strings::safe_strtof(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid float: ", field);
+ }
+ component.scalar<float>()() = value;
+ }
+ break;
+ }
+ case DT_DOUBLE: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<double>()() =
+ dataset()->record_defaults_[output_idx].flat<double>()(0);
+ } else {
+ double value;
+ if (!strings::safe_strtod(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid double: ", field);
+ }
+ component.scalar<double>()() = value;
+ }
+ break;
+ }
+ case DT_STRING: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<string>()() =
+ dataset()->record_defaults_[output_idx].flat<string>()(0);
+ } else {
+ component.scalar<string>()() = string(field);
+ }
+ break;
+ }
+ default:
+ return errors::InvalidArgument("csv: data type ", dtype,
+ " not supported in field ",
+ output_idx);
+ }
+ out_tensors->push_back(std::move(component));
+ return Status::OK();
+ }
+
+ // Records can be delimited by "\r\n" line breaks. When we encounter a
+ // '\r', we have to check the next character to see if it is part of the
+ // linebreak, and ignore it if so.
+ void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (pos_ >= buffer_.size()) {
+ Status s = FillBuffer(&buffer_);
+ pos_ = 0;
+ // If we failed to fill buffer, it doesn't matter because we're done
+ // with the record
+ if (!s.ok()) return;
+ }
+ if (buffer_[pos_] == '\n') {
+ pos_++;
+ }
+ }
+
+ // Given a string field, and its index in the output,
+ // converts it to a Tensor of the right type and adds it to the
+ // out_tensors vector.
+ Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field,
+ std::vector<Tensor>* out_tensors,
+ const std::vector<Piece>& earlier_pieces,
+ bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!include) return Status::OK();
+
+ if (earlier_pieces.empty()) {
+ return FieldToOutput(ctx, field, out_tensors);
+ }
+
+ size_t str_len = field.size();
+ for (const Piece& p : earlier_pieces) {
+ str_len += p.len;
+ }
+ string field_complete;
+ field_complete.reserve(str_len);
+
+ for (const Piece& p : earlier_pieces) {
+ field_complete.append(p.buffer, p.start, p.len);
+ }
+
+ field_complete.append(field.data(), field.size());
+ return FieldToOutput(ctx, field_complete, out_tensors);
+ }
+
+ // Sets up reader streams to read from the file at `current_file_index_`.
+ Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (current_file_index_ >= dataset()->filenames_.size()) {
+ return errors::InvalidArgument(
+ "current_file_index_:", current_file_index_,
+ " >= filenames_.size():", dataset()->filenames_.size());
+ }
+
+ // Actually move on to next file.
+ TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
+ dataset()->filenames_[current_file_index_], &file_));
+ random_access_input_stream_ =
+ std::make_shared<io::RandomAccessInputStream>(file_.get(), false);
+
+ if (dataset()->use_compression_) {
+ input_stream_ = std::make_shared<io::ZlibInputStream>(
+ random_access_input_stream_.get(),
+ dataset()->options_.input_buffer_size,
+ dataset()->options_.input_buffer_size, dataset()->options_);
+ } else {
+ input_stream_ = random_access_input_stream_;
+ }
+ buffer_.clear();
+ pos_ = 0;
+ num_buffer_reads_ = 0;
+ if (dataset()->header_) {
+          // Read one line, but don't include it. Pass nullptrs as dummy
+          // pointers to objects that shouldn't be invoked anyway. We need to
+          // process this as a record here, instead of just finding the first
+          // newline, because the header may contain quoted fields with
+          // embedded newlines.
+ std::vector<int64> empty;
+ Status s = ReadRecord(nullptr, nullptr, false, empty);
+ if (!s.ok()) {
+ return errors::InvalidArgument("Can't read header of file");
+ }
+ }
+ return Status::OK();
+ }
+
+ // Resets all reader streams.
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ input_stream_.reset();
+ file_.reset();
+ }
+
+ mutex mu_;
+ string buffer_ GUARDED_BY(mu_); // Maintain our own buffer
+ size_t pos_ GUARDED_BY(
+ mu_); // Index into the buffer must be maintained between iters
+ size_t num_buffer_reads_ GUARDED_BY(mu_);
+ std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_
+ GUARDED_BY(mu_);
+ std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_);
+ size_t current_file_index_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<RandomAccessFile> file_
+ GUARDED_BY(mu_); // must outlive input_stream_
+ }; // class Iterator
+
+ const std::vector<string> filenames_;
+ const bool header_;
+ const DataTypeVector out_type_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ const std::vector<Tensor> record_defaults_;
+ const std::vector<int64> select_cols_;
+ const bool use_quote_delim_;
+ const char delim_;
+ const string na_value_;
+ const bool use_compression_;
+ const string compression_type_;
+ const io::ZlibCompressionOptions options_;
+ }; // class Dataset
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+}; // class CSVDatasetOp
+
+// Register the kernel implementation for CSVDataset.
+REGISTER_KERNEL_BUILDER(Name("ExperimentalCSVDataset").Device(DEVICE_CPU),
+ CSVDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
new file mode 100644
index 0000000000..c47a9099c4
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
@@ -0,0 +1,281 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
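+//
+// For example, if the selector dataset yields [0, 1, 0] and there are two
+// data inputs A and B, the output sequence is A[0], B[0], A[1]: each selector
+// value picks the input that produces the next element.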
+
+class DirectedInterleaveDatasetOp : public DatasetOpKernel {
+ public:
+ explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx)
+ : DatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ DatasetBase* selector_input;
+ OP_REQUIRES_OK(ctx,
+ GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
+
+ OP_REQUIRES(
+ ctx,
+ selector_input->output_dtypes().size() == 1 &&
+ selector_input->output_dtypes()[0] == DT_INT64 &&
+ selector_input->output_shapes().size() == 1 &&
+ selector_input->output_shapes()[0].IsCompatibleWith(
+ PartialTensorShape({})),
+ errors::InvalidArgument(
+ "The selector input must be a dataset of scalar int64 elements."));
+
+ std::vector<DatasetBase*> data_inputs;
+ for (size_t i = 1; i < ctx->num_inputs(); ++i) {
+ DatasetBase* input;
+ OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
+ data_inputs.push_back(input);
+
+ OP_REQUIRES(
+ ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
+ errors::InvalidArgument(
+ "All inputs must have the same output_dtypes. First input "
+ "has types ",
+ DataTypeVectorString(data_inputs[0]->output_dtypes()),
+ ", and input ", i - 1, " has types ",
+ DataTypeVectorString(input->output_dtypes())));
+ }
+ *output = new Dataset(ctx, selector_input, std::move(data_inputs));
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
+ std::vector<DatasetBase*> data_inputs)
+ : DatasetBase(DatasetContext(ctx)),
+ selector_input_(selector_input),
+ data_inputs_(std::move(data_inputs)) {
+ selector_input_->Ref();
+
+ output_shapes_ = data_inputs_[0]->output_shapes();
+ data_inputs_[0]->Ref();
+ for (size_t i = 1; i < data_inputs_.size(); ++i) {
+ const DatasetBase* data_input = data_inputs_[i];
+ data_input->Ref();
+ for (size_t j = 0; j < output_shapes_.size(); ++j) {
+ output_shapes_[j] = MostSpecificCompatibleShape(
+ output_shapes_[j], data_input->output_shapes()[j]);
+ }
+ }
+ }
+
+ ~Dataset() override {
+ selector_input_->Unref();
+ for (DatasetBase* data_input : data_inputs_) {
+ data_input->Unref();
+ }
+ }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::DirectedInterleave")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return data_inputs_[0]->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* selector_input_node;
+ TF_RETURN_IF_ERROR(
+ b->AddInputDataset(ctx, selector_input_, &selector_input_node));
+ std::vector<Node*> data_input_nodes(data_inputs_.size());
+ for (size_t i = 0; i < data_inputs_.size(); ++i) {
+ TF_RETURN_IF_ERROR(
+ b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i]));
+ }
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}},
+ {{1, data_input_nodes}}, {}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ num_active_inputs_(params.dataset->data_inputs_.size()) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
+ ctx, strings::StrCat(prefix(), ".selector"),
+ &selector_input_impl_));
+ data_input_impls_.resize(dataset()->data_inputs_.size());
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ const DatasetBase* data_input = dataset()->data_inputs_[i];
+ TF_RETURN_IF_ERROR(data_input->MakeIterator(
+ ctx, strings::StrCat(prefix(), "[", i, "]"),
+ &data_input_impls_[i]));
+ }
+ return Status::OK();
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (!selector_input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ while (true) {
+ std::vector<Tensor> selector_result;
+ *end_of_sequence = false;
+ TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
+ ctx, &selector_result, end_of_sequence));
+ if (*end_of_sequence) {
+ selector_input_impl_.reset();
+ for (auto& data_input_impl : data_input_impls_) {
+ data_input_impl.reset();
+ }
+ return Status::OK();
+ }
+
+ int64 selected_input = selector_result[0].scalar<int64>()();
+          if (selected_input < 0 ||
+              selected_input >= data_input_impls_.size()) {
+ return errors::InvalidArgument(
+ "Selector index out of range: ", selected_input,
+ " >= ", data_input_impls_.size());
+ }
+
+ if (data_input_impls_[selected_input]) {
+ bool end_of_selected_input = false;
+ TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
+ ctx, out_tensors, &end_of_selected_input));
+
+ if (!end_of_selected_input) {
+ return Status::OK();
+ }
+
+ data_input_impls_[selected_input].reset();
+ --num_active_inputs_;
+
+ if (num_active_inputs_ == 0) {
+ selector_input_impl_.reset();
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ }
+
+ LOG(WARNING) << "DirectedInterleave selected an exhausted input: "
+ << selected_input;
+ }
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (selector_input_impl_) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, selector_input_impl_));
+ } else {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
+ }
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ const auto& data_input_impl = data_input_impls_[i];
+ if (data_input_impl) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, data_input_impl));
+ } else {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
+ ""));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (!reader->Contains(full_name("selector_input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
+ } else {
+ selector_input_impl_.reset();
+ }
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ if (!reader->Contains(full_name(
+ strings::StrCat("data_input_impl_empty[", i, "]")))) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
+ } else {
+ data_input_impls_[i].reset();
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> selector_input_impl_ GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
+ GUARDED_BY(mu_);
+ int64 num_active_inputs_ GUARDED_BY(mu_);
+ };
+
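+  // Returns the most specific shape compatible with both inputs: matching
+  // dimensions are kept, mismatched dimensions become unknown, and a rank
+  // mismatch (or unknown rank) yields a fully unknown shape. For example,
+  // [2,3] and [2,4] combine to [2,?].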
+ static PartialTensorShape MostSpecificCompatibleShape(
+ const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
+ PartialTensorShape output_tensorshape;
+ if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
+ return output_tensorshape;
+ auto dims1 = ts1.dim_sizes();
+ auto dims2 = ts2.dim_sizes();
+ for (int d = 0; d < ts1.dims(); d++) {
+ if (dims1[d] == dims2[d])
+ output_tensorshape.Concatenate(dims1[d]);
+ else
+ output_tensorshape.Concatenate(-1);
+ }
+ return output_tensorshape;
+ }
+
+ const DatasetBase* const selector_input_;
+ const std::vector<DatasetBase*> data_inputs_;
+ std::vector<PartialTensorShape> output_shapes_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
+ DirectedInterleaveDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
new file mode 100644
index 0000000000..2141f118ca
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
@@ -0,0 +1,156 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
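+// An IndexedDataset containing the uint64 values [0, size). Iteration yields
+// the values in order, and random access is O(1): Get(i) simply returns i.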
+class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel {
+ public:
+ using IndexedDatasetOpKernel::IndexedDatasetOpKernel;
+
+ void MakeIndexedDataset(OpKernelContext* ctx,
+ IndexedDataset** output) override {
+ uint64 size = -1;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<uint64>(ctx, "size", &size));
+ OP_REQUIRES(ctx, size > 0, errors::InvalidArgument("`size` must be > 0"));
+ *output = new Dataset(ctx, size);
+ }
+
+ class Dataset : public IndexedDataset {
+ public:
+ Dataset(OpKernelContext* ctx, uint64 size)
+ : IndexedDataset(DatasetContext(ctx)), size_(size) {}
+
+ Status MaterializeDataset(
+ std::shared_ptr<MaterializedIndexedDataset>* materialized) override {
+ materialized->reset(new Materialized(this));
+ return Status::OK();
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_UINT64});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::IdentityIndexedDataset")}));
+ }
+
+ string DebugString() const override {
+ return "IdentityIndexedDataset::Dataset";
+ }
+
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** node) const override {
+ return errors::Unimplemented(
+ "identity_indexed_dataset.AsGraphDefInternal");
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (cur_ < dataset()->size_) {
+ Tensor result_tensor(ctx->allocator({}), DT_UINT64, {});
+ result_tensor.scalar<uint64>()() = cur_++;
+ out_tensors->emplace_back(std::move(result_tensor));
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+        uint64 cur_ GUARDED_BY(mu_) = 0;
+ };
+
+ class Materialized : public MaterializedIndexedDataset {
+ public:
+ explicit Materialized(Dataset* dataset) : dataset_(dataset) {
+ dataset->Ref();
+ }
+
+ ~Materialized() override {
+ // TODO(saeta): Pull this into MaterializedIndexedDataset
+ dataset_->Unref();
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return dataset_->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return dataset_->output_shapes();
+ }
+
+ Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) const override {
+ LOG(INFO) << "Materialized(" << dataset_->size_ << ")::Get(" << index
+ << ")";
+ if (index >= dataset_->size_) {
+ // Note: use InvalidArgument instead of OutOfRange error because many
+ // things consider OutOfRange to be a "clean termination" error.
+ return errors::InvalidArgument(
+ "Index ", index,
+ " is out of range for this dataset. (Size is: ", dataset_->size_,
+ ".)");
+ }
+ Tensor result_tensor(ctx.allocator({}), DT_UINT64, {});
+ result_tensor.scalar<uint64>()() = index;
+ out_tensors->emplace_back(std::move(result_tensor));
+ return Status::OK();
+ }
+
+ Status Size(uint64* size) const override {
+ *size = dataset_->size_;
+ return Status::OK();
+ }
+
+ private:
+ const Dataset* const dataset_; // Not owned.
+ };
+
+ const uint64 size_;
+ std::shared_ptr<Materialized> materialized_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIdentityIndexedDataset").Device(DEVICE_CPU),
+ IdentityIndexedDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
new file mode 100644
index 0000000000..b34377c642
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
@@ -0,0 +1,141 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
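+//
+// IgnoreErrorsDataset wraps an input dataset and silently drops any element
+// whose production fails: GetNext is retried on the input until it either
+// succeeds or reaches the end of the sequence.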
+
+class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit IgnoreErrorsDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ *output = new Dataset(ctx, input);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+ return "IgnoreErrorsDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ {
+ tf_shared_lock l(mu_);
+ if (!input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ while (!s.ok()) {
+ out_tensors->clear();
+ s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ }
+ }
+ if (*end_of_sequence) {
+ mutex_lock l(mu_);
+ input_impl_.reset();
+ }
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (input_impl_)
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ else
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impls_empty"), ""));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (reader->Contains(full_name("input_impls_empty")))
+ input_impl_.reset();
+ else
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* const input_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIgnoreErrorsDataset").Device(DEVICE_CPU),
+ IgnoreErrorsDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
new file mode 100644
index 0000000000..75ea462f40
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
@@ -0,0 +1,375 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " types but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (expected[i] != received[i]) {
+ return errors::InvalidArgument("Data type mismatch at component ", i,
+ ": expected ", DataTypeString(expected[i]),
+ " but got ", DataTypeString(received[i]),
+ ".");
+ }
+ }
+ return Status::OK();
+}
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " shapes but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (!expected[i].IsCompatibleWith(received[i])) {
+ return errors::InvalidArgument("Incompatible shapes at component ", i,
+ ": expected ", expected[i].DebugString(),
+ " but got ", received[i].DebugString(),
+ ".");
+ }
+ }
+
+ return Status::OK();
+}
+
+class MaterializedDatasetResource : public ResourceBase {
+ public:
+ MaterializedDatasetResource(
+ const DataTypeVector& output_dtypes,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {}
+
+ string DebugString() override {
+ return "Materialized IndexedDataset resource";
+ }
+
+ Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) {
+ std::shared_ptr<MaterializedIndexedDataset> captured(materialized_);
+ if (captured) {
+ return captured->Get(std::move(ctx), index, out_tensors);
+ } else {
+ return errors::FailedPrecondition(
+ "Get() failed because the MaterializedIndexedDataset has not been "
+ "initialized. Ensure that you have run the materialization operation "
+ "for this MaterializedIndexedDataset before retrieving elements.");
+ }
+ }
+
+ // TODO(saeta): Implement Save and Restore
+
+ const DataTypeVector& output_dtypes() const { return output_dtypes_; }
+ const std::vector<PartialTensorShape>& output_shapes() const {
+ return output_shapes_;
+ }
+
+ Status set_materialized_dataset(
+ const std::shared_ptr<MaterializedIndexedDataset>& dataset) {
+ if (dataset) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_dtypes_, dataset->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, dataset->output_shapes()));
+ }
+ materialized_ = dataset;
+ return Status::OK();
+ }
+
+ private:
+ std::shared_ptr<MaterializedIndexedDataset> materialized_;
+ const DataTypeVector output_dtypes_;
+ const std::vector<PartialTensorShape> output_shapes_;
+};
+
+// A wrapper class for storing an `IndexedDataset` instance in a DT_VARIANT
+// tensor. Objects of the wrapper class own a reference on an instance of an
+// `IndexedDataset`, and the wrapper's copy constructor and destructor take
+// care of managing the reference count.
+//
+// NOTE: This is not a feature-complete implementation of the DT_VARIANT
+// specification. In particular, we cannot currently serialize an arbitrary
+// `IndexedDataset` object, so the `Encode()` and `Decode()` methods are not
+// implemented.
+//
+// NOTE(saeta): When `IndexedDataset`s get merged into core, we can instead just
+// use `tensorflow::DatasetVariantWrapper`.
+class IndexedDatasetVariantWrapper {
+ public:
+ IndexedDatasetVariantWrapper() : dataset_(nullptr) {}
+
+ // Transfers ownership of `dataset` to `*this`.
+ explicit IndexedDatasetVariantWrapper(IndexedDataset* dataset)
+ : dataset_(dataset) {}
+
+ IndexedDatasetVariantWrapper(const IndexedDatasetVariantWrapper& other)
+ : dataset_(other.dataset_) {
+ if (dataset_) dataset_->Ref();
+ }
+
+ ~IndexedDatasetVariantWrapper() {
+ if (dataset_) dataset_->Unref();
+ }
+
+ IndexedDataset* get() const { return dataset_; }
+
+ string TypeName() const { return "tensorflow::IndexedDatasetVariantWrapper"; }
+ string DebugString() const {
+ if (dataset_) {
+ return dataset_->DebugString();
+ } else {
+ return "<Uninitialized IndexedDatasetVariantWrapper>";
+ }
+ }
+
+ void Encode(VariantTensorData* data) const {
+ LOG(ERROR) << "The Encode() method is not implemented for "
+ "IndexedDatasetVariantWrapper objects.";
+ }
+
+ bool Decode(const VariantTensorData& data) {
+ LOG(ERROR) << "The Decode() method is not implemented for "
+ "IndexedDatasetVariantWrapper objects.";
+ return false;
+ }
+
+ private:
+ IndexedDataset* const dataset_; // Owns one reference.
+};
+
+} // namespace
+
+Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
+ IndexedDataset** out_dataset) {
+  if (!(tensor.dtype() == DT_VARIANT &&
+ TensorShapeUtils::IsScalar(tensor.shape()))) {
+ return errors::InvalidArgument(
+ "IndexedDataset tensor must be a scalar of dtype DT_VARIANT.");
+ }
+ const Variant& variant = tensor.scalar<Variant>()();
+ const IndexedDatasetVariantWrapper* wrapper =
+ variant.get<IndexedDatasetVariantWrapper>();
+ if (wrapper == nullptr) {
+ return errors::InvalidArgument("Tensor must be an IndexedDataset object.");
+ }
+ *out_dataset = wrapper->get();
+ if (*out_dataset == nullptr) {
+ return errors::Internal("Read uninitialized IndexedDataset variant.");
+ }
+ return Status::OK();
+}
+
+Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
+ Tensor* tensor) {
+  if (!(tensor->dtype() == DT_VARIANT &&
+ TensorShapeUtils::IsScalar(tensor->shape()))) {
+ return errors::InvalidArgument(
+ "Dataset tensor must be a scalar of dtype DT_VARIANT.");
+ }
+ tensor->scalar<Variant>()() = IndexedDatasetVariantWrapper(dataset);
+ return Status::OK();
+}
+
+void IndexedDatasetOpKernel::Compute(OpKernelContext* ctx) {
+ IndexedDataset* dataset = nullptr;
+ MakeIndexedDataset(ctx, &dataset);
+
+ if (ctx->status().ok()) {
+ OP_REQUIRES(ctx, dataset != nullptr,
+ errors::Internal("MakeIndexedDataset did not correctly "
+ "construct the IndexedDataset"));
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
+ OP_REQUIRES_OK(ctx, StoreIndexedDatasetInVariantTensor(dataset, output));
+ }
+}
+
+namespace {
+
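+// Creates (or, given a shared_name, looks up) the MaterializedDatasetResource
+// that backs a materialized IndexedDataset, and outputs a resource handle to
+// it.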
+class MaterializedHandleOp : public OpKernel {
+ public:
+ explicit MaterializedHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ ~MaterializedHandleOp() override {
+ if (resource_ != nullptr) {
+ resource_->Unref();
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->template Delete<MaterializedDatasetResource>(
+ cinfo_.container(), cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource can have been deleted by session resets.
+ // Note: cargo-culted from $tf/core/framework/resource_op_kernel.h
+ }
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (resource_ == nullptr) {
+ ResourceMgr* mgr = context->resource_manager();
+ OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+
+ MaterializedDatasetResource* resource;
+ OP_REQUIRES_OK(context,
+ mgr->LookupOrCreate<MaterializedDatasetResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this](MaterializedDatasetResource** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ *ret = new MaterializedDatasetResource(
+ output_dtypes_, output_shapes_);
+ return Status::OK();
+ }));
+ Status s = VerifyResource(resource);
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ resource->Unref();
+ context->SetStatus(s);
+ return;
+ }
+
+ resource_ = resource;
+ }
+ }
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<MaterializedDatasetResource>()));
+ }
+
+ private:
+  // During the first Compute(), the resource is either created or looked up
+  // using shared_name. In the latter case, the resource found should be
+  // verified to be compatible with this op's configuration. The verification
+  // may fail, e.g., when two graphs ask resources of the same shared name to
+  // have inconsistent output types or shapes.
+ Status VerifyResource(MaterializedDatasetResource* resource) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
+ return Status::OK();
+ }
+
+ mutex mu_;
+ ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
+ MaterializedDatasetResource* resource_ GUARDED_BY(mu_) = nullptr;
+ DataTypeVector output_dtypes_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+// TODO(saeta): Make async.
+class MaterializeDatasetOp : public OpKernel {
+ public:
+ explicit MaterializeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ IndexedDataset* dataset;
+ OP_REQUIRES_OK(ctx,
+ GetIndexedDatasetFromVariantTensor(ctx->input(0), &dataset));
+
+ MaterializedDatasetResource* materialized_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
+ &materialized_resource));
+ core::ScopedUnref unref(materialized_resource);
+ std::shared_ptr<MaterializedIndexedDataset> materialized;
+ OP_REQUIRES_OK(ctx, dataset->MaterializeDataset(&materialized));
+ OP_REQUIRES_OK(
+ ctx, materialized_resource->set_materialized_dataset(materialized));
+ }
+};
+
+// TODO(saeta): Make async
+class IndexedDatasetGet : public OpKernel {
+ public:
+ explicit IndexedDatasetGet(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ MaterializedDatasetResource* materialized_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0),
+ &materialized_resource));
+ auto cleanup = gtl::MakeCleanup([materialized_resource] {
+ materialized_resource->Unref(); // Note: can't use core::ScopedUnref.
+ });
+
+ const Tensor* index_t;
+ OP_REQUIRES_OK(ctx, ctx->input("index", &index_t));
+ // TODO(saeta): Support batch reads (indexes should be non-scalar!)
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(index_t->shape()),
+ errors::InvalidArgument("index must be a scalar"));
+ const uint64 index = index_t->scalar<uint64>()();
+
+ std::vector<Tensor> out_tensors;
+ Status s =
+ materialized_resource->Get(IteratorContext(ctx), index, &out_tensors);
+
+ // Note: Unref materialized_resource to avoid destruction races. (Important
+ // in a [future] async op implementation.)
+ cleanup.release()();
+
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else {
+ auto expected_shapes = materialized_resource->output_shapes();
+ auto expected_types = materialized_resource->output_dtypes();
+ for (size_t i = 0; i < out_tensors.size(); ++i) {
+ OP_REQUIRES(
+ ctx, expected_shapes[i].IsCompatibleWith(out_tensors[i].shape()),
+ errors::Internal(
+ "Materialized dataset output at index ", i,
+ " is incompatible with the expected shape. (Expected: ",
+ expected_shapes[i], ", got: ", out_tensors[i].shape(), ")"));
+ OP_REQUIRES(ctx, out_tensors[i].dtype() == expected_types[i],
+ errors::Internal("Materialized dataset output at index ", i,
+ " was not the expected dtype. (Expected: ",
+ expected_types[i],
+ ", got: ", out_tensors[i].dtype(), ")"));
+ ctx->set_output(i, out_tensors[i]);
+ }
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalMaterializedIndexDatasetHandle").Device(DEVICE_CPU),
+ MaterializedHandleOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetMaterialize").Device(DEVICE_CPU),
+ MaterializeDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetGet").Device(DEVICE_CPU),
+ IndexedDatasetGet);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/indexed_dataset.h b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
new file mode 100644
index 0000000000..27a8360cbc
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace data {
+
+// TODO(saeta): Urgh, this is ugly.
+class MaterializedIndexedDataset {
+ public:
+ virtual ~MaterializedIndexedDataset() = default;
+
+ // Retrieve the element at a given index. The output tensors are stored in
+ // out_tensors.
+ //
+  // If `index` is greater than or equal to `Size()`, an error is returned.
+ //
+ // Get is thread-safe.
+ virtual Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) const = 0;
+
+ // Size determines the number of elements in this IndexedDataset.
+ //
+ // Size is thread-safe.
+ virtual Status Size(uint64* size) const = 0;
+
+ // Returns a vector of DataType values, representing the respective
+ // element types of each tuple component in the outputs of this dataset.
+ virtual const DataTypeVector& output_dtypes() const = 0;
+
+ // Returns a vector of tensor shapes, representing the respective
+ // (and possibly partially defined) shapes of each tuple component
+ // in the outputs of this dataset.
+ virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+};
+
+// IndexedDataset represents a dataset that supports random access in addition
+// to iterator-based sequential access.
+//
+// Note: IndexedDatasets are HIGHLY experimental at this time. Expect
+// significant (backwards incompatible) changes!
+class IndexedDataset : public DatasetBase {
+ public:
+ IndexedDataset(DatasetContext&& ctx) : DatasetBase(std::move(ctx)) {}
+
+ // Materialize (if necessary) the dataset, and return a pointer.
+ // TODO(saeta): Add in `IteratorContext* ctx` when materializing.
+ virtual Status MaterializeDataset(
+ std::shared_ptr<MaterializedIndexedDataset>* materialized) = 0;
+};
+
+// IndexedDatasetOpKernel abstracts away interfacing IndexedDatasets with the
+// rest of the TensorFlow runtime.
+//
+// Most `IndexedDataset`s will be private members of classes inheriting from
+// this class.
+class IndexedDatasetOpKernel : public OpKernel {
+ public:
+ IndexedDatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ void Compute(OpKernelContext* ctx) final;
+
+ protected:
+ // Subclasses should implement this method. It will be called during Compute
+ // execution.
+ virtual void MakeIndexedDataset(OpKernelContext* ctx,
+ IndexedDataset** output) = 0;
+
+ template <typename T>
+ Status ParseScalarArgument(OpKernelContext* ctx,
+ const StringPiece& argument_name, T* output) {
+ const Tensor* argument_t;
+ TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+ if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
+ return errors::InvalidArgument(argument_name, " must be a scalar");
+ }
+ *output = argument_t->scalar<T>()();
+ return Status::OK();
+ }
+};
+
+// Validates and extracts an `IndexedDataset` object from `tensor`.
+//
+// `tensor` must have been written by a call to
+// `StoreIndexedDatasetInVariantTensor`.
+//
+// The retrieved pointer is a borrowed reference to the dataset, which is
+// owned
+// by the tensor. The consumer must either acquire its own reference to the
+// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
+// destroyed or mutated while the retrieved pointer is in use.
+Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
+ IndexedDataset** out_dataset);
+
+// Stores an `IndexedDataset` object in `tensor`.
+//
+// The ownership of `dataset` is transferred to `tensor`.
+Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
+ Tensor* tensor);
+
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
diff --git a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
new file mode 100644
index 0000000000..8a88d32f0c
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
@@ -0,0 +1,218 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <sys/stat.h>
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/platform/file_system.h"
+
+#include "lmdb.h" // NOLINT(build/include)
+
+namespace tensorflow {
+namespace data {
+namespace {
+
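+// Yields the (key, value) string pairs stored in one or more LMDB databases,
+// read sequentially through a read-only cursor.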
+class LMDBDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ const Tensor* filenames_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
+ OP_REQUIRES(
+ ctx, filenames_tensor->dims() <= 1,
+ errors::InvalidArgument("`filenames` must be a scalar or a vector."));
+
+ std::vector<string> filenames;
+ filenames.reserve(filenames_tensor->NumElements());
+ for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
+ filenames.push_back(filenames_tensor->flat<string>()(i));
+ }
+
+ *output = new Dataset(ctx, filenames);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const std::vector<string>& filenames)
+ : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::LMDB")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes =
+ new DataTypeVector({DT_STRING, DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}, {}});
+ return *shapes;
+ }
+
+ string DebugString() const override { return "LMDBDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* filenames = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ do {
+ if (mdb_cursor_) {
+ Tensor key_tensor(ctx->allocator({}), DT_STRING, {});
+ key_tensor.scalar<string>()() = string(
+ static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
+ out_tensors->emplace_back(std::move(key_tensor));
+
+ Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
+ value_tensor.scalar<string>()() =
+ string(static_cast<const char*>(mdb_value_.mv_data),
+ mdb_value_.mv_size);
+ out_tensors->emplace_back(std::move(value_tensor));
+
+          int val =
+              mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT);
+ if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ if (val == MDB_NOTFOUND) {
+ ResetStreamsLocked();
+ ++current_file_index_;
+ }
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ if (current_file_index_ == dataset()->filenames_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return errors::Unimplemented(
+ "Checkpointing is currently not supported for LMDBDataset.");
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ return errors::Unimplemented(
+ "Checkpointing is currently not supported for LMDBDataset.");
+ }
+
+ private:
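+      // Opens the LMDB environment for the next file in `filenames_`, starts
+      // a read-only transaction, and positions a cursor on the first record
+      // (resetting the streams again immediately if the database is empty).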
+ Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (current_file_index_ >= dataset()->filenames_.size()) {
+ return errors::InvalidArgument(
+ "current_file_index_:", current_file_index_,
+ " >= filenames_.size():", dataset()->filenames_.size());
+ }
+ const string& filename = dataset()->filenames_[current_file_index_];
+
+ int val = mdb_env_create(&mdb_env_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK;
+
+ struct stat source_stat;
+      if (stat(filename.c_str(), &source_stat) == 0 &&
+          S_ISREG(source_stat.st_mode)) {
+ flags |= MDB_NOSUBDIR;
+ }
+ val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST);
+ if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ if (val == MDB_NOTFOUND) {
+ ResetStreamsLocked();
+ }
+ return Status::OK();
+ }
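+      // Releases the cursor, transaction, and environment for the current
+      // database, in the reverse order of their acquisition.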
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (mdb_env_ != nullptr) {
+ if (mdb_cursor_) {
+ mdb_cursor_close(mdb_cursor_);
+ mdb_cursor_ = nullptr;
+ }
+ mdb_dbi_close(mdb_env_, mdb_dbi_);
+ mdb_txn_abort(mdb_txn_);
+ mdb_env_close(mdb_env_);
+ mdb_txn_ = nullptr;
+ mdb_dbi_ = 0;
+ mdb_env_ = nullptr;
+ }
+ }
+ mutex mu_;
+ size_t current_file_index_ GUARDED_BY(mu_) = 0;
+ MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr;
+ MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr;
+ MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0;
+ MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr;
+
+ MDB_val mdb_key_ GUARDED_BY(mu_);
+ MDB_val mdb_value_ GUARDED_BY(mu_);
+ };
+
+ const std::vector<string> filenames_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalLMDBDataset").Device(DEVICE_CPU),
+ LMDBDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
new file mode 100644
index 0000000000..2c6179d9f5
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
@@ -0,0 +1,482 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <deque>
+
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
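+// The kernels below implement function-driven prefetching: a
+// FunctionBufferingResource repeatedly invokes a user-supplied function on a
+// target device and keeps up to `buffer_size` results buffered, while the
+// companion GetNext and Reset ops consume elements and clear state.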
+struct BufferElement {
+ // The producer sets `status` if getting the input element fails.
+ Status status;
+ // The buffered data element.
+ std::vector<Tensor> value;
+};
+
+using FunctionBufferCallback = std::function<void(const BufferElement&)>;
+
+class FunctionBufferingResource : public ResourceBase {
+ public:
+ FunctionBufferingResource(FunctionLibraryRuntime* lib,
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+ const NameAttrList& func, int64 buffer_size,
+ const string& source_device,
+ const string& target_device,
+ const std::vector<Tensor>& func_args,
+ const DataTypeVector& output_types)
+ : lib_(lib),
+ pflr_(std::move(pflr)),
+ func_(func),
+ buffer_size_(buffer_size),
+ source_device_(source_device),
+ target_device_(target_device),
+ func_args_(func_args),
+ output_types_(output_types),
+ handle_(kInvalidHandle),
+ is_buffering_(false),
+ end_of_sequence_(false),
+ cancelled_(false) {}
+
+ ~FunctionBufferingResource() override {
+ Cancel();
+ }
+
+ string DebugString() override {
+ return strings::StrCat("FunctionBufferingResource. Size: ", buffer_size_,
+ "; target_device: ", target_device_);
+ }
+
+  // Instantiates the function on the first call and caches the resulting
+  // handle; subsequent calls reuse it.
+ Status Instantiate() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+    // Reuse the existing handle if it has already been set.
+ if (handle_ != kInvalidHandle) {
+ return Status::OK();
+ }
+ AttrValueMap attr_values = func_.attr();
+ FunctionLibraryRuntime::InstantiateOptions opts;
+ opts.target = target_device_;
+ return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), opts,
+ &handle_);
+ }
+
+  // Returns true once the end of the sequence has been reached and the
+  // buffer has been exhausted.
+ bool Finished() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ return end_of_sequence_ && buffer_.empty();
+ }
+
+  // Requests cancellation of any in-flight buffering and blocks until it
+  // has quiesced.
+ void Cancel() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ cancelled_ = true;
+ while (is_buffering_) {
+ cond_var_.wait(l);
+ }
+ }
+
+ // Cancels all pending operations and then clears out the state.
+ void Reset() LOCKS_EXCLUDED(mu_) {
+ Cancel();
+ mutex_lock l(mu_);
+ buffer_.clear();
+ requests_.clear();
+ is_buffering_ = false;
+ end_of_sequence_ = false;
+ cancelled_ = false;
+ }
+
+  // If the buffer is non-empty, runs `callback` on its first element;
+  // otherwise queues `callback` to be invoked once an element becomes
+  // available, starting the buffering if it is not already in progress.
+ void MaybeGet(FunctionBufferCallback callback) LOCKS_EXCLUDED(mu_) {
+ bool start_buffering = false;
+ bool produced_output = false;
+ BufferElement buffer_element;
+ {
+ mutex_lock l(mu_);
+ if (!is_buffering_ && !end_of_sequence_) {
+ start_buffering = true;
+ }
+ if (!buffer_.empty()) {
+ produced_output = true;
+ std::swap(buffer_element, buffer_.front());
+ buffer_.pop_front();
+ } else {
+ produced_output = false;
+ requests_.push_back(std::move(callback));
+ }
+ }
+ if (produced_output) {
+ callback(buffer_element);
+ }
+ if (start_buffering) {
+ FillBuffer();
+ }
+ }
+
+ private:
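+  // Launches one asynchronous invocation of the buffering function (unless
+  // cancellation is in progress). The completion callback enqueues the
+  // result, hands the oldest buffered element to any waiting request, and
+  // calls FillBuffer() again until the buffer is full or the sequence ends.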
+ void FillBuffer() LOCKS_EXCLUDED(mu_) {
+ FunctionLibraryRuntime::Handle handle;
+ std::vector<FunctionBufferCallback> cancellation_callbacks;
+ std::vector<BufferElement> cancellation_buffer_elements;
+ bool cancelled = false;
+ {
+ mutex_lock l(mu_);
+ handle = handle_;
+ if (cancelled_) {
+ cancelled = true;
+ // Run through and fulfill all pending requests, if possible.
+ while (!requests_.empty()) {
+ if (!buffer_.empty()) {
+ cancellation_buffer_elements.push_back(std::move(buffer_.front()));
+ buffer_.pop_front();
+ cancellation_callbacks.push_back(std::move(requests_.front()));
+ requests_.pop_front();
+ } else {
+ LOG(ERROR) << "Buffer ran out of elements and we couldn't satisfy: "
+ << requests_.size() << " requests";
+ break;
+ }
+ }
+ is_buffering_ = false;
+ } else {
+ is_buffering_ = true;
+ }
+ }
+ if (cancelled) {
+      for (size_t i = 0; i < cancellation_callbacks.size(); ++i) {
+ cancellation_callbacks[i](cancellation_buffer_elements[i]);
+ }
+ cond_var_.notify_all();
+ return;
+ }
+ FunctionLibraryRuntime::Options opts;
+    // Copied from CapturedFunction::generate_step_id().
+ opts.step_id = -std::abs(static_cast<int64>(random::New64()));
+ opts.source_device = source_device_;
+ AllocatorAttributes arg_alloc_attr;
+ arg_alloc_attr.set_on_host(true);
+ opts.args_alloc_attrs.push_back(arg_alloc_attr);
+ for (const auto& dtype : output_types_) {
+ AllocatorAttributes ret_alloc_attrs;
+ if (DataTypeAlwaysOnHost(dtype)) {
+ ret_alloc_attrs.set_on_host(true);
+ }
+ opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
+ }
+ if (opts.source_device != target_device_) {
+ opts.remote_execution = true;
+ }
+ opts.create_rendezvous = true;
+ auto* rets = new std::vector<Tensor>;
+ lib_->Run(opts, handle, func_args_, rets,
+ [this, rets](const Status& status) {
+ FunctionBufferCallback callback = nullptr;
+ BufferElement buffer_front;
+ bool restart_buffering = false;
+ {
+ mutex_lock l(mu_);
+ BufferElement buffer_element;
+ buffer_element.status = status;
+ if (status.ok()) {
+ buffer_element.value.swap(*rets);
+ } else {
+ end_of_sequence_ = true;
+ is_buffering_ = false;
+ }
+ buffer_.push_back(std::move(buffer_element));
+ if (!requests_.empty()) {
+ buffer_front = std::move(buffer_.front());
+ buffer_.pop_front();
+ callback = std::move(requests_.front());
+ requests_.pop_front();
+ }
+ if (buffer_.size() < buffer_size_ && !end_of_sequence_) {
+ restart_buffering = true;
+ } else {
+ // When the buffer is full, we don't want to call
+ // FillBuffer() unless we're in cancellation phase in which
+ // case FillBuffer() will do the final cleanup post
+ // cancellation.
+ if (cancelled_) {
+ restart_buffering = true;
+ }
+ is_buffering_ = false;
+ }
+                }
+                // The return-value vector is no longer needed; delete it to
+                // avoid leaking one allocation per function invocation.
+                delete rets;
+ if (callback != nullptr) {
+ callback(buffer_front);
+ }
+ if (restart_buffering) {
+ FillBuffer();
+ }
+ });
+ }
+
+ mutex mu_;
+ FunctionLibraryRuntime* lib_;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ NameAttrList func_;
+ const int64 buffer_size_;
+ const string source_device_;
+ const string target_device_;
+ const std::vector<Tensor> func_args_;
+ const DataTypeVector output_types_;
+ FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_);
+ std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
+ std::deque<FunctionBufferCallback> requests_ GUARDED_BY(mu_);
+ bool is_buffering_ GUARDED_BY(mu_);
+ bool end_of_sequence_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_);
+ condition_variable cond_var_;
+};
+
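+// Creates a FunctionBufferingResource on first use and returns a handle to
+// it. Creation happens lazily in Compute(), where the `target_device` input
+// is available and can be canonicalized.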
+class FunctionBufferResourceHandleOp : public OpKernel {
+ public:
+ explicit FunctionBufferResourceHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), flib_def_(nullptr) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ }
+
+ ~FunctionBufferResourceHandleOp() override {
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->Delete<FunctionBufferingResource>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+      // Do nothing; the resource may already have been deleted by a session
+      // reset.
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* string_arg;
+ OP_REQUIRES_OK(ctx, ctx->input("string_arg", &string_arg));
+ std::vector<Tensor> func_args;
+ func_args.push_back(*string_arg);
+
+ const string& source_device = ctx->device()->name();
+
+ // Obtain and canonicalize target_device.
+ const Tensor* target_arg;
+ OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg));
+ string target_device;
+ OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName(
+ target_arg->scalar<string>()(), source_device,
+ &target_device));
+
+ FunctionLibraryRuntime* lib = ctx->function_library();
+ OP_REQUIRES(ctx, lib != nullptr,
+ errors::Internal("No function library is provided."));
+
+ mutex_lock l(mu_);
+ if (!initialized_) {
+ OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def()));
+ FunctionLibraryRuntime* clone_lib;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr;
+ OP_REQUIRES_OK(ctx, lib->Clone(&flib_def_, &pflr, &clone_lib));
+ // Create the resource.
+ FunctionBufferingResource* buffer;
+ OP_REQUIRES_OK(
+ ctx,
+ ctx->resource_manager()->LookupOrCreate<FunctionBufferingResource>(
+ cinfo_.container(), cinfo_.name(), &buffer,
+ [clone_lib, &pflr, &source_device, &target_device, func_args,
+ this](FunctionBufferingResource** ptr) {
+ *ptr = new FunctionBufferingResource(
+ clone_lib, std::move(pflr), func_, buffer_size_,
+ source_device, target_device, func_args, output_types_);
+ return Status::OK();
+ }));
+ core::ScopedUnref s(buffer);
+ OP_REQUIRES_OK(ctx, buffer->Instantiate());
+ initialized_ = true;
+ }
+
+ OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+ ctx, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<FunctionBufferingResource>()));
+ }
+
+ private:
+ mutex mu_;
+ ContainerInfo cinfo_ GUARDED_BY(mu_);
+ bool initialized_ GUARDED_BY(mu_) = false;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+ NameAttrList func_;
+ int64 buffer_size_;
+ string container_;
+ string name_;
+ DataTypeVector output_types_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
+ .Device(DEVICE_CPU)
+ .HostMemory("resource")
+ .HostMemory("string_arg")
+ .HostMemory("target_device"),
+ FunctionBufferResourceHandleOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
+ .HostMemory("string_arg")
+ .HostMemory("target_device"),
+ FunctionBufferResourceHandleOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
+ .Device(DEVICE_SYCL)
+ .HostMemory("resource")
+ .HostMemory("string_arg")
+ .HostMemory("target_device"),
+ FunctionBufferResourceHandleOp);
+#endif // TENSORFLOW_USE_SYCL
+
+// Retrieves the next element from a FunctionBufferingResource, waiting
+// asynchronously until one has been buffered.
+class FunctionBufferingResourceGetNextOp : public AsyncOpKernel {
+ public:
+ explicit FunctionBufferingResourceGetNextOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx) {}
+
+ ~FunctionBufferingResourceGetNextOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, HandleFromInput(ctx, "function_buffer_resource", &handle), done);
+ FunctionBufferingResource* buffer = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer),
+ done);
+
+ if (buffer->Finished()) {
+ buffer->Unref();
+ ctx->SetStatus(errors::OutOfRange("end_of_sequence"));
+ done();
+ return;
+ }
+
+ FunctionBufferCallback callback =
+ [ctx, buffer, done](const BufferElement& buffer_element) {
+ Status s = buffer_element.status;
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ buffer->Unref();
+ done();
+ return;
+ }
+ for (size_t i = 0; i < buffer_element.value.size(); ++i) {
+ ctx->set_output(i, buffer_element.value[i]);
+ }
+ buffer->Unref();
+ done();
+ };
+ buffer->MaybeGet(std::move(callback));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
+ .Device(DEVICE_CPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceGetNextOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
+ .Device(DEVICE_GPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceGetNextOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
+ .Device(DEVICE_SYCL)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceGetNextOp);
+#endif // TENSORFLOW_USE_SYCL
+
+// Resets the FunctionBufferingResource, cancelling all pending requests and
+// clearing out the buffer.
+class FunctionBufferingResourceResetOp : public OpKernel {
+ public:
+ explicit FunctionBufferingResourceResetOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ ~FunctionBufferingResourceResetOp() override {}
+
+ void Compute(OpKernelContext* ctx) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(ctx,
+ HandleFromInput(ctx, "function_buffer_resource", &handle));
+ FunctionBufferingResource* buffer = nullptr;
+ OP_REQUIRES_OK(
+ ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer));
+ core::ScopedUnref s(buffer);
+
+ buffer->Reset();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
+ .Device(DEVICE_CPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceResetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
+ .Device(DEVICE_GPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceResetOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
+ .Device(DEVICE_SYCL)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceResetOp);
+#endif // TENSORFLOW_USE_SYCL
+
+class IteratorGetDeviceOp : public OpKernel {
+ public:
+ using OpKernel::OpKernel;
+
+ void Compute(OpKernelContext* ctx) override {
+    // NOTE(mrry): We do not currently validate that the handle
+ // corresponds to a real IteratorResource, because that symbol is
+ // not exposed from the framework library.
+ Tensor* device_name_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &device_name_t));
+ // NOTE(mrry): Since the operation's input is a resource, we must be
+ // colocated with it, and so we can simply return the current device's
+ // name without looking at the input.
+ device_name_t->scalar<string>()() = ctx->device()->name();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIteratorGetDevice").Device(DEVICE_CPU),
+ IteratorGetDeviceOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
new file mode 100644
index 0000000000..8d561ca0e3
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
@@ -0,0 +1,220 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
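+// Wraps a thread::ThreadPool in a ResourceBase so that it can be registered
+// with the ResourceMgr and shared across datasets; optionally caps the
+// intra-op parallelism observed by work scheduled on it.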
+class ThreadPoolResource : public ResourceBase {
+ public:
+ ThreadPoolResource(Env* env, const ThreadOptions& thread_options,
+ const string& name, int num_threads, bool low_latency_hint,
+ int max_intra_op_parallelism)
+ : thread_pool_(env, thread_options, name, num_threads, low_latency_hint),
+ max_intra_op_parallelism_(max_intra_op_parallelism) {}
+
+ // Schedules fn() for execution in the pool of threads.
+ void Schedule(std::function<void()> fn) {
+ if (max_intra_op_parallelism_ < 0) {
+ thread_pool_.Schedule(std::move(fn));
+ } else {
+ thread_pool_.Schedule(std::bind(
+ [this](std::function<void()> bound_fn) {
+ // TODO(mrry): Consider moving this thread-local configuration to
+ // the threads themselves.
+ ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_);
+ bound_fn();
+ },
+ std::move(fn)));
+ }
+ }
+
+ string DebugString() override { return "ThreadPoolResource"; }
+
+ private:
+ thread::ThreadPool thread_pool_;
+ const int max_intra_op_parallelism_;
+};
+
+// Creates a handle to a ThreadPool resource. Note that we don't use
+// ResourceOpKernel here because the ThreadPoolResource constructor requires
+// access to `OpKernelContext::env()`, which isn't provided by
+// `ResourceOpKernel<T>::CreateResource()`.
+class ThreadPoolHandleOp : public OpKernel {
+ public:
+ explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism",
+ &max_intra_op_parallelism_));
+ OP_REQUIRES(
+ ctx, num_threads_ > 0,
+ errors::InvalidArgument("`num_threads` must be greater than zero."));
+ }
+
+  // The resource is deleted from the resource manager only when it is private
+  // to this kernel. Ideally the resource would be deleted once it is no
+  // longer held by anyone, but doing so would break backward compatibility.
+ ~ThreadPoolHandleOp() override {
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->Delete<ThreadPoolResource>(cinfo_.container(), cinfo_.name())
+ .ok()) {
+        // Do nothing; the resource may already have been deleted by a
+        // session reset.
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ if (!initialized_) {
+ ResourceMgr* mgr = ctx->resource_manager();
+ OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
+ ThreadPoolResource* resource;
+ OP_REQUIRES_OK(ctx, mgr->LookupOrCreate<ThreadPoolResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, ctx](ThreadPoolResource** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                                *ret = new ThreadPoolResource(
+                                    ctx->env(), {}, display_name_,
+                                    num_threads_,
+                                    false /* low_latency_hint */,
+                                    max_intra_op_parallelism_);
+ return Status::OK();
+ }));
+ initialized_ = true;
+ }
+ OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+ ctx, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<ThreadPoolResource>()));
+ }
+
+ private:
+ mutex mu_;
+ ContainerInfo cinfo_ GUARDED_BY(mu_);
+ bool initialized_ GUARDED_BY(mu_) = false;
+ string display_name_;
+ int num_threads_;
+ int max_intra_op_parallelism_;
+};
+
+class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit ThreadPoolDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ ThreadPoolResource* threadpool_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
+ &threadpool_resource));
+    core::ScopedUnref unref_pool(threadpool_resource);
+
+ *output = new Dataset(ctx, input, threadpool_resource);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ ThreadPoolResource* threadpool)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ threadpool_(threadpool) {
+ input_->Ref();
+ threadpool_->Ref();
+ }
+
+ ~Dataset() override {
+ input_->Unref();
+ threadpool_->Unref();
+ }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::ThreadPool")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+ return "ThreadPoolDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ return errors::Unimplemented("%s does not support serialization",
+ DebugString());
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ ThreadPoolResource* pool = dataset()->threadpool_;
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = [pool](std::function<void()> c) {
+ pool->Schedule(std::move(c));
+ };
+ params.stats_aggregator = ctx->stats_aggregator();
+ params.lib = ctx->lib();
+ params.function_library = ctx->function_library();
+ params.allocator_getter = ctx->allocator_getter();
+ IteratorContext threadpool_ctx(params);
+ return input_impl_->GetNext(&threadpool_ctx, out_tensors,
+ end_of_sequence);
+ }
+
+ private:
+ std::unique_ptr<IteratorBase> input_impl_;
+ };
+
+ const DatasetBase* const input_;
+ ThreadPoolResource* const threadpool_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
+ ThreadPoolHandleOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalThreadPoolDataset").Device(DEVICE_CPU),
+ ThreadPoolDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
new file mode 100644
index 0000000000..cd612e0eb2
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
@@ -0,0 +1,224 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
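+// UniqueDataset emits only the first occurrence of each element of its
+// input, tracking previously seen elements in a hash set keyed on tensor
+// contents. Only single-component int32, int64, and string inputs are
+// supported. For example, an input of {1, 2, 1, 3, 2} yields {1, 2, 3}.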
+class UniqueDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit UniqueDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ OP_REQUIRES(ctx, input->output_dtypes().size() == 1,
+ errors::InvalidArgument("UniqueDataset only supports "
+ "inputs with a single component."));
+
+ DataType input_dtype = input->output_dtypes()[0];
+ OP_REQUIRES(ctx,
+ input_dtype == DT_INT32 || input_dtype == DT_INT64 ||
+ input_dtype == DT_STRING,
+ errors::InvalidArgument(
+ "UniqueDataset only supports inputs with a single "
+ "`tf.int32`, `tf.int64`, or `tf.string` component."));
+
+ *output = new Dataset(ctx, input);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input)
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Unique")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+ return strings::StrCat("UniqueDatasetOp::Dataset");
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const typename Iterator::Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ bool saw_new_value;
+ do {
+ saw_new_value = false;
+ out_tensors->clear();
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
+ if (*end_of_sequence) {
+ break;
+ }
+ DCHECK_EQ(1, out_tensors->size());
+ saw_new_value = unique_elements_.insert((*out_tensors)[0]).second;
+ } while (!saw_new_value);
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (input_impl_) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ } else {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impl_empty"), ""));
+ }
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name("unique_elements_size"), unique_elements_.size()));
+ size_t i = 0;
+ for (const Tensor& t : unique_elements_) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("unique_elements[", i++, "]")), t));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (!reader->Contains(full_name("input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ } else {
+ input_impl_.reset();
+ }
+ int64 num_unique_elements;
+ unique_elements_.clear();
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("unique_elements_size"),
+ &num_unique_elements));
+ for (int64 i = 0; i < num_unique_elements; ++i) {
+ Tensor unique_element;
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("unique_elements[", i, "]")),
+ &unique_element));
+ auto insert_result = unique_elements_.insert(unique_element);
+ if (!insert_result.second) {
+ return errors::InvalidArgument(
+ "Checkpoint contained two unique elements with the same "
+ "value.");
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ struct TensorHash {
+ size_t operator()(const Tensor& t) const {
+ if (t.dtype() == DT_INT32 || t.dtype() == DT_INT64) {
+ return Hash64(t.tensor_data().data(), t.tensor_data().size());
+ } else {
+ DCHECK_EQ(DT_STRING, t.dtype());
+ auto flat_t = t.flat<string>();
+ uint64 hash = 0;
+ for (int64 i = 0; i < t.NumElements(); ++i) {
+ hash = Hash64Combine(hash, Hash64(flat_t(i)));
+ }
+ return static_cast<size_t>(hash);
+ }
+ }
+ };
+
+ struct TensorKeyEqual {
+ bool operator()(const Tensor& lhs, const Tensor& rhs) const {
+ if (lhs.shape() != rhs.shape() || lhs.dtype() != rhs.dtype()) {
+ return false;
+ }
+ switch (lhs.dtype()) {
+#define HANDLE_TYPE(T) \
+ case T: \
+ do { \
+ auto lhs_flat = lhs.flat<EnumToDataType<T>::Type>(); \
+ auto rhs_flat = rhs.flat<EnumToDataType<T>::Type>(); \
+ for (int64 i = 0; i < lhs.NumElements(); ++i) { \
+ if (lhs_flat(i) != rhs_flat(i)) { \
+ return false; \
+ } \
+ } \
+ return true; \
+ } while (0)
+
+ HANDLE_TYPE(DT_INT32);
+ HANDLE_TYPE(DT_INT64);
+ HANDLE_TYPE(DT_STRING);
+ default:
+ DCHECK(false) << "UniqueDataset unhandled data type: "
+ << DataTypeString(lhs.dtype());
+ return false;
+ }
+ }
+ };
+
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::unordered_set<Tensor, TensorHash, TensorKeyEqual> unique_elements_
+ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* const input_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalUniqueDataset").Device(DEVICE_CPU),
+ UniqueDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index 00884314a9..be7d182a1f 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -18,9 +18,11 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -31,67 +33,84 @@ namespace {
class FilterDatasetOp : public UnaryDatasetOpKernel {
public:
+  using FilterIteratorPredicate =
+      std::function<Status(IteratorContext*, const std::vector<Tensor>&,
+                           bool*)>;
+
explicit FilterDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("predicate", &func_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- FunctionLibraryRuntime::Handle pred_handle;
- OP_REQUIRES_OK(ctx,
- ctx->function_library()->Instantiate(
- func_.name(), AttrSlice(&func_.attr()), &pred_handle));
- auto cleanup = gtl::MakeCleanup([ctx, pred_handle]() {
- OP_REQUIRES_OK(ctx, ctx->function_library()->ReleaseHandle(pred_handle));
- });
-
- const FunctionBody* pred_body =
- ctx->function_library()->GetFunctionBody(pred_handle);
- OP_REQUIRES(ctx, pred_body->ret_nodes.size() == 1,
- errors::InvalidArgument(
- "predicate function must have a single return value."));
- Node* ret_node = pred_body->ret_nodes[0];
- Node* ret_input_node;
- OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node));
-
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
&captured_func));
- if (ret_input_node->def().op() == "_Arg") {
- int32 index = -1;
- OP_REQUIRES_OK(ctx, GetNodeAttr(ret_input_node->def(), "index", &index));
- *output = new FilterTensorDataset(ctx, input, func_,
- std::move(captured_func), index);
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+ OP_REQUIRES(ctx, indices.size() <= 1,
+ errors::InvalidArgument(
+ "predicate function has more than one return value."));
+
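+    // Short-circuit optimization: when the predicate simply returns one of
+    // its arguments, read that tensor's value directly instead of invoking
+    // the function runtime for every element.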
+ FilterIteratorPredicate filter_pred;
+ if (indices.empty()) {
+ CapturedFunction* raw_captured_func = captured_func.get();
+ filter_pred = [raw_captured_func](IteratorContext* ctx,
+ const std::vector<Tensor>& args,
+ bool* out_matched) {
+ std::vector<Tensor> result;
+ TF_RETURN_IF_ERROR(
+ raw_captured_func->RunWithBorrowedArgs(ctx, args, &result));
+
+ if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
+ result[0].NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = result[0].scalar<bool>()();
+ return Status::OK();
+ };
} else {
- *output = new FilterFunctionDataset(ctx, input, func_,
- std::move(captured_func));
+ filter_pred = [indices](IteratorContext* ctx,
+ const std::vector<Tensor>& args,
+ bool* out_matched) {
+ const Tensor& predicate = args[indices[0]];
+ if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
+ return errors::InvalidArgument(
+ "Filter predicate `f` must return a scalar bool.");
+ }
+ *out_matched = predicate.scalar<bool>()();
+ return Status::OK();
+ };
}
+
+ *output = new Dataset(ctx, input, func_, std::move(captured_func),
+ std::move(filter_pred));
}
private:
- const int graph_def_version_;
-
- class FilterDatasetBase : public DatasetBase {
+ class Dataset : public DatasetBase {
public:
- FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func,
- std::unique_ptr<CapturedFunction> captured_func)
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func,
+ FilterIteratorPredicate filter_pred)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
- captured_func_(std::move(captured_func)) {
+ captured_func_(std::move(captured_func)),
+ filter_pred_(std::move(filter_pred)) {
input_->Ref();
}
- ~FilterDatasetBase() override { input_->Unref(); }
+ ~Dataset() override { input_->Unref(); }
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Filter")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::Filter")},
+ filter_pred_);
}
const DataTypeVector& output_dtypes() const override {
@@ -133,17 +152,15 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- virtual Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const = 0;
-
private:
- class Iterator : public DatasetIterator<FilterDatasetBase> {
+ class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
- : DatasetIterator<FilterDatasetBase>(params),
+ explicit Iterator(const Params& params,
+ FilterIteratorPredicate filter_pred)
+ : DatasetIterator<Dataset>(params),
filtered_elements_(0),
- dropped_elements_(0) {
+ dropped_elements_(0),
+ filter_pred_(std::move(filter_pred)) {
std::vector<string> components =
str_util::Split(params.prefix, "::", str_util::SkipEmpty());
prefix_end_ = components.back();
@@ -180,8 +197,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- TF_RETURN_IF_ERROR(
- dataset()->EvaluatePredicate(ctx, *out_tensors, &matched));
+ TF_RETURN_IF_ERROR(filter_pred_(ctx, *out_tensors, &matched));
if (!matched) {
// Clear the output tensor list since it didn't match.
out_tensors->clear();
@@ -251,64 +267,14 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
int64 filtered_elements_ GUARDED_BY(mu_);
int64 dropped_elements_ GUARDED_BY(mu_);
+ const FilterIteratorPredicate filter_pred_;
string prefix_end_;
};
const DatasetBase* const input_;
const NameAttrList func_;
-
- protected:
const std::unique_ptr<CapturedFunction> captured_func_;
- };
-
- class FilterFunctionDataset : public FilterDatasetBase {
- public:
- using FilterDatasetBase::FilterDatasetBase;
-
- protected:
- Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const override {
- // TODO(mrry): Avoid blocking a threadpool thread. We will need to
- // stack-rip the iterators and use async kernels.
- std::vector<Tensor> result;
- TF_RETURN_IF_ERROR(
- captured_func_->RunWithBorrowedArgs(ctx, element, &result));
-
- if (result.size() != 1 || result[0].dtype() != DT_BOOL ||
- result[0].NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = result[0].scalar<bool>()();
- return Status::OK();
- }
- };
-
- class FilterTensorDataset : public FilterDatasetBase {
- public:
- FilterTensorDataset(OpKernelContext* ctx, const DatasetBase* input,
- const NameAttrList& func,
- std::unique_ptr<CapturedFunction> captured_func,
- int32 index)
- : FilterDatasetBase(ctx, input, func, std::move(captured_func)),
- index_(index) {}
-
- protected:
- Status EvaluatePredicate(IteratorContext* ctx,
- const std::vector<Tensor>& element,
- bool* out_matched) const override {
- const Tensor& predicate = element[index_];
- if (predicate.dtype() != DT_BOOL || predicate.NumElements() != 1) {
- return errors::InvalidArgument(
- "Filter predicate `f` must return a scalar bool.");
- }
- *out_matched = predicate.scalar<bool>()();
- return Status::OK();
- }
-
- private:
- const int32 index_;
+ const FilterIteratorPredicate filter_pred_;
};
private:
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index 71a36314a0..b4367d5a11 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -86,8 +86,6 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx));
TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
- TF_RETURN_IF_ERROR(
- dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
return Status::OK();
}
@@ -96,6 +94,12 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
bool* end_of_sequence) override {
mutex_lock l(mu_);
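+      // Run `init_func` once, on the first call to GetNextInternal();
+      // Initialize() above no longer invokes it eagerly.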
+ if (!initialized_) {
+ TF_RETURN_IF_ERROR(
+ dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
+ initialized_ = true;
+ }
+
if (finalized_) {
*end_of_sequence = true;
return Status::OK();
@@ -123,6 +127,7 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
private:
mutex mu_;
+ bool initialized_ GUARDED_BY(mu_) = false;
bool finalized_ GUARDED_BY(mu_) = false;
std::vector<Tensor> state_ GUARDED_BY(mu_);
};
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index d6ee42a7c6..e7244ee208 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -30,8 +30,7 @@ namespace {
class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
public:
explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
@@ -421,7 +420,6 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList key_func_;
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index 8b417bb1c2..14aefe5d54 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -31,8 +31,7 @@ namespace {
class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
public:
explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size_func", &window_size_func_));
@@ -507,7 +506,6 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList key_func_;
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index c0bc507ec0..7a833668ac 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -659,6 +659,115 @@ class ToSingleElementOp : public AsyncOpKernel {
BackgroundWorker background_worker_;
};
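+// Runs a user-provided reduce function over all elements of the input
+// dataset, threading the `initial_state` tensors through successive calls
+// and returning the final state as the op's outputs. For example, with
+// f(state, x) = state + x this computes the sum of the dataset's elements.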
+class ReduceDatasetOp : public AsyncOpKernel {
+ public:
+ explicit ReduceDatasetOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ background_worker_(
+ ctx->env(),
+ strings::StrCat("reduce_thread_", SanitizeThreadSuffix(name()))) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &reduce_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+ &use_inter_op_parallelism_));
+ }
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ // The call to `iterator->GetNext()` may block and depend on an
+ // inter-op thread pool thread, so we issue the call from the
+ // owned thread pool.
+ background_worker_.Schedule([this, ctx, done]() {
+ DatasetBase* dataset;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
+ OpInputList inputs;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("initial_state", &inputs),
+ done);
+ std::vector<Tensor> state(inputs.begin(), inputs.end());
+
+ std::unique_ptr<CapturedFunction> captured_func;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ CapturedFunction::Create(reduce_func_, ctx, "other_arguments",
+ use_inter_op_parallelism_, &captured_func),
+ done);
+
+ IteratorContext iter_ctx(ctx);
+ OP_REQUIRES_OK_ASYNC(ctx, captured_func->Instantiate(&iter_ctx), done);
+
+ std::unique_ptr<IteratorBase> iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator),
+ done);
+
+ // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
+ // avoid destruction races.
+ IteratorBase* raw_iterator = iterator.release();
+ auto cleanup = gtl::MakeCleanup([raw_iterator, done] {
+ delete raw_iterator;
+ done();
+ });
+
+ // Iterate through the input dataset.
+ Status status;
+ while (true) {
+ std::vector<Tensor> next_input_element;
+ bool end_of_input;
+ status = raw_iterator->GetNext(&iter_ctx, &next_input_element,
+ &end_of_input);
+ if (!status.ok() || end_of_input) {
+ break;
+ }
+
+ // Run the reduce function to update the current state.
+ std::vector<Tensor> args;
+ args.reserve(state.size() + next_input_element.size());
+ std::copy(state.begin(), state.end(), std::back_inserter(args));
+ std::copy(next_input_element.begin(), next_input_element.end(),
+ std::back_inserter(args));
+
+ std::vector<Tensor> reduce_func_output;
+ status =
+ captured_func->Run(&iter_ctx, std::move(args), &reduce_func_output);
+ if (!status.ok()) {
+ break;
+ }
+ std::swap(reduce_func_output, state);
+ }
+
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ return;
+ }
+ for (int i = 0; i < state.size(); ++i) {
+ OP_REQUIRES_ASYNC(
+ ctx, state[i].dtype() == output_types_[i],
+ errors::InvalidArgument(
+ "The result does not match the expected type for component ", i,
+ ". Expected: ", DataTypeString(output_types_[i]),
+ ". Actual: ", DataTypeString(state[i].dtype()), "."),
+ done);
+ OP_REQUIRES_ASYNC(
+ ctx, output_shapes_[i].IsCompatibleWith(state[i].shape()),
+ errors::InvalidArgument(
+ "The result does not match the expected shape for component ",
+ i, ". Expected: ", output_shapes_[i].DebugString(),
+ ". Actual: ", state[i].shape().DebugString(), "."),
+ done);
+ ctx->set_output(i, state[i]);
+ }
+ });
+ }
+
+ private:
+ NameAttrList reduce_func_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ bool use_inter_op_parallelism_;
+ BackgroundWorker background_worker_;
+};
+
class OneShotIteratorOp : public AsyncOpKernel {
public:
explicit OneShotIteratorOp(OpKernelConstruction* ctx)
@@ -1146,6 +1255,8 @@ REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_GPU),
AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
ToSingleElementOp);
+REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU),
+ ReduceDatasetOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 2bbf4af664..f45a239793 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/inplace_ops_functor.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
@@ -29,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -37,8 +39,14 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
+ using MapAndBatchIteratorFunction =
+ std::function<void(IteratorContext*, const string&, std::vector<Tensor>,
+ std::shared_ptr<std::vector<Tensor>>, StatusCallback)>;
+
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
@@ -89,31 +97,73 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
&captured_func));
- *output = new Dataset(ctx, input, batch_size, num_parallel_calls,
- drop_remainder, output_types_, output_shapes_, func_,
- std::move(captured_func), &ctx->eigen_cpu_device());
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
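+    // Short-circuit optimization: when the map function merely forwards (a
+    // subset of) its arguments, bypass the function runtime and copy or move
+    // the selected tensors into the batch directly.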
+ MapAndBatchIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](
+ IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::shared_ptr<std::vector<Tensor>> out_tensors,
+ StatusCallback done) {
+ raw_captured_func->RunAsync(ctx, std::move(args), out_tensors.get(),
+ std::move(done), prefix);
+ };
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::shared_ptr<std::vector<Tensor>> out_tensors,
+ StatusCallback done) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ done(Status::OK());
+ };
+ }
+
+ *output = new Dataset(ctx, input, func_, batch_size, num_parallel_calls,
+ drop_remainder, output_types_, output_shapes_,
+ std::move(captured_func), &ctx->eigen_cpu_device(),
+ std::move(map_func));
}
private:
class Dataset : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func, int64 batch_size,
int64 num_parallel_calls, bool drop_remainder,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
- const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
- const Eigen::ThreadPoolDevice* device)
+ const Eigen::ThreadPoolDevice* device,
+ MapAndBatchIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
+ func_(func),
batch_size_(batch_size),
num_parallel_calls_(num_parallel_calls),
drop_remainder_(drop_remainder),
output_types_(output_types),
output_shapes_(output_shapes),
- map_fn_(func),
captured_func_(std::move(captured_func)),
- device_(device) {
+ device_(device),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -121,8 +171,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::MapAndBatch")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::MapAndBatch")},
+ map_func_);
}
const DataTypeVector& output_dtypes() const override {
@@ -141,7 +192,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* batch_size_node;
@@ -163,7 +214,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
other_arguments_types.emplace_back(t.dtype());
}
AttrValue f;
- b->BuildAttrValue(map_fn_, &f);
+ b->BuildAttrValue(func_, &f);
AttrValue other_arguments_types_attr;
b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
@@ -183,31 +234,35 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
+ explicit Iterator(const Params& params,
+ MapAndBatchIteratorFunction map_func)
: DatasetIterator<Dataset>(params),
- num_parallel_calls_(params.dataset->num_parallel_calls_) {}
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ params.dataset->num_parallel_calls_, mu_, cond_var_)),
+ map_func_(std::move(map_func)) {}
~Iterator() override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
- cond_var_.notify_all();
+ cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
- if (num_parallel_calls_ == kAutoTune) {
- num_parallel_calls_ = 1;
- AddTunableParameter(ctx, "parallelism",
- &num_parallel_calls_ /* value */, 1 /* min */,
- port::NumSchedulableCPUs() /* max */, &cond_var_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ port::NumSchedulableCPUs());
} else {
- AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
}
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
@@ -219,27 +274,27 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
bool* end_of_sequence) override {
std::shared_ptr<BatchResult> result;
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (batch_results_.empty() ||
batch_results_.front()->num_calls > 0) {
RecordStop(ctx);
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx);
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
@@ -255,7 +310,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("call_counter"), &call_counter_));
@@ -293,55 +348,17 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
int64 num_calls; // access guarded by owner's mutex
};
- void Callback(const std::shared_ptr<IteratorContext>& ctx,
- const std::shared_ptr<BatchResult>& result,
- const std::shared_ptr<std::vector<Tensor>>& return_values,
- int64 offset, const Status& status) LOCKS_EXCLUDED(mu_) {
- result->UpdateStatus(status);
- if (status.ok()) {
- EnsureOutputAllocated(ctx, result, return_values);
- for (size_t i = 0; i < return_values->size(); ++i) {
- const Tensor& tensor = return_values->at(i);
- Tensor* batch = &(result->output)[i];
- if (tensor.NumElements() !=
- (batch->NumElements() / batch->dim_size(0))) {
- TensorShape batch_shape = batch->shape();
- batch_shape.RemoveDim(0);
- result->UpdateStatus(errors::InvalidArgument(
- "Cannot add tensor to the batch: number of elements does not "
- "match. Shapes are: [tensor]: ",
- tensor.shape().DebugString(),
- ", [batch]: ", batch_shape.DebugString()));
- break;
- }
- // TODO(mrry): Add a version of DoParallelConcat that allows us to
- // move `tensor` where possible, to speed up string tensor batching.
- Status copy_status = ::tensorflow::functor::DoParallelConcat(
- *dataset()->device_, tensor, offset, batch);
- if (!copy_status.ok()) {
- result->UpdateStatus(copy_status);
- break;
- }
- }
- {
- mutex_lock l(result->mu);
- result->num_elements++;
- }
- }
- CallCompleted(result);
- }
-
void CallCompleted(const std::shared_ptr<BatchResult>& result)
- LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
+ LOCKS_EXCLUDED(*mu_) {
+ mutex_lock l(*mu_);
num_calls_--;
result->num_calls--;
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
void CallFunction(std::shared_ptr<IteratorContext> ctx,
const std::shared_ptr<BatchResult>& result,
- int64 offset) LOCKS_EXCLUDED(mu_) {
+ int64 offset) LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
bool end_of_input;
@@ -359,21 +376,48 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return;
}
- // Call `captured_func_(input_element)`, using `Callback` to store the
- // result in `result`.
- (*ctx->runner())(std::bind(
- [this, result, offset](std::shared_ptr<IteratorContext> ctx,
- std::vector<Tensor> input_element) {
- std::shared_ptr<std::vector<Tensor>> return_values(
- new std::vector<Tensor>());
- dataset()->captured_func_->RunAsync(
- ctx.get(), std::move(input_element), return_values.get(),
- [this, ctx, result, return_values, offset](Status status) {
- Callback(ctx, result, return_values, offset, status);
- },
- prefix());
- },
- ctx, std::move(input_element)));
+ std::shared_ptr<std::vector<Tensor>> return_values =
+ std::make_shared<std::vector<Tensor>>();
+ auto done = [this, ctx, result, return_values, offset](Status status) {
+ result->UpdateStatus(status);
+ if (status.ok()) {
+ EnsureOutputAllocated(ctx, result, return_values);
+ for (size_t i = 0; i < return_values->size(); ++i) {
+ const Tensor& tensor = return_values->at(i);
+ Tensor* batch = &(result->output)[i];
+ if (tensor.NumElements() !=
+ (batch->NumElements() / batch->dim_size(0))) {
+ TensorShape batch_shape = batch->shape();
+ batch_shape.RemoveDim(0);
+ result->UpdateStatus(errors::InvalidArgument(
+ "Cannot add tensor to the batch: number of elements does "
+ "not match. Shapes are: [tensor]: ",
+ tensor.shape().DebugString(),
+ ", [batch]: ", batch_shape.DebugString()));
+ break;
+ }
+ // TODO(mrry): Add a version of DoParallelConcat that allows us to
+ // move `tensor` where possible, to speed up string tensor
+ // batching.
+ Status copy_status = ::tensorflow::functor::DoParallelConcat(
+ *dataset()->device_, tensor, offset, batch);
+ if (!copy_status.ok()) {
+ result->UpdateStatus(copy_status);
+ break;
+ }
+ }
+ {
+ mutex_lock l(result->mu);
+ result->num_elements++;
+ }
+ }
+ CallCompleted(result);
+ };
+
+      // Apply the map function to `input_element`, storing the result in
+ // `return_values`, and invoking `done` when finished.
+ map_func_(ctx.get(), prefix(), std::move(input_element),
+ std::move(return_values), std::move(done));
}
Status CopyPartialBatch(Tensor* output, const Tensor& value,
@@ -398,9 +442,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
- std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
+ auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_.reset(ctx->env()->StartThread(
{}, "runner_thread",
std::bind(&Iterator::RunnerThread, this, ctx_copy)));
@@ -474,14 +518,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
RecordStart(ctx.get());
auto stop_cleanup =
gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
- new_calls.reserve(num_parallel_calls_);
- auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
- int64 num_parallel_calls = num_parallel_calls_;
+ new_calls.reserve(num_parallel_calls_->value);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_->value;
int64 max_batch_results =
(num_parallel_calls + dataset()->batch_size_ - 1) /
dataset()->batch_size_;
@@ -492,10 +536,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
};
while (true) {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
while (!cancelled_ && busy()) {
RecordStop(ctx.get());
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx.get());
}
@@ -505,8 +549,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
while (!busy()) {
if (call_counter_ % dataset()->batch_size_ == 0) {
- batch_results_.emplace_back(
- new BatchResult(dataset()->batch_size_));
+ batch_results_.push_back(
+ std::make_shared<BatchResult>(dataset()->batch_size_));
}
int64 offset = call_counter_++ % dataset()->batch_size_;
new_calls.emplace_back(batch_results_.back(), offset);
@@ -522,8 +566,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
- size_t index) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- batch_results_.emplace_back(new BatchResult(dataset()->batch_size_));
+ size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+ batch_results_.push_back(
+ std::make_shared<BatchResult>(dataset()->batch_size_));
std::shared_ptr<BatchResult> result = batch_results_.back();
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
@@ -567,7 +612,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status ReadStatus(IteratorStateReader* reader, const string& prefix,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(prefix, "_code")), &code_int));
@@ -585,7 +630,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
std::shared_ptr<BatchResult> result = batch_results_[index];
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
@@ -626,7 +671,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status WriteStatus(IteratorStateWriter* writer, const string& prefix,
- const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const Status& status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
static_cast<int64>(status.code())));
@@ -640,24 +685,26 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// Used for coordination between the main thread, the runner thread, and
// the callback threads.
- mutex mu_;
+ const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread, the runner thread, and
// the callback threads. In particular, the runner thread should only
- // schedule new calls when the number of in-flight calls is less than the
- // user specified level of parallelism and there are slots available in
- // the `batch_results_` buffer.
- condition_variable cond_var_;
+ // schedule new calls when the number of in-flight calls is less than
+ // `num_parallel_calls_->value` and there are slots available in the
+ // `batch_results_` buffer.
+ const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
- std::atomic<int64> num_parallel_calls_;
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
+ const MapAndBatchIteratorFunction map_func_;
+
// Counts the number of outstanding calls for this batch.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
// Counts the total number of calls.
- int64 call_counter_ GUARDED_BY(mu_) = 0;
+ int64 call_counter_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the (intermediate) batch results.
- std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
- bool cancelled_ GUARDED_BY(mu_) = false;
+ std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ bool cancelled_ GUARDED_BY(*mu_) = false;
};
const DatasetBase* const input_;
@@ -667,9 +714,9 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const bool drop_remainder_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
- const NameAttrList map_fn_;
const std::unique_ptr<CapturedFunction> captured_func_;
const Eigen::ThreadPoolDevice* device_; // not owned
+ const MapAndBatchIteratorFunction map_func_;
};
const int op_version_;
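Note on the recurring change above: the iterator's mutex and condition variable become shared_ptr members feeding a model::SharedState so that the autotuner can adjust the parallelism knob and wake the runner thread while sharing ownership of the same primitives. A minimal standalone sketch of the pattern, standard C++ only (SharedState here is a stand-in, not the TF type):

#include <condition_variable>
#include <cstdint>
#include <memory>
#include <mutex>

// Stand-in for tensorflow::data::model::SharedState: a tunable value plus
// shared ownership of the synchronization primitives guarding it.
struct SharedState {
  SharedState(int64_t value, std::shared_ptr<std::mutex> mu,
              std::shared_ptr<std::condition_variable> cond_var)
      : value(value), mu(std::move(mu)), cond_var(std::move(cond_var)) {}
  int64_t value;
  std::shared_ptr<std::mutex> mu;
  std::shared_ptr<std::condition_variable> cond_var;
};

// The tuner owns only the SharedState, yet it can update the value and wake
// threads blocked on the iterator's condition variable.
void SetParallelism(SharedState* state, int64_t new_value) {
  std::lock_guard<std::mutex> l(*state->mu);
  state->value = new_value;
  state->cond_var->notify_all();
}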
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index f112e1dc43..6b6ffabf4f 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -17,7 +17,9 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
@@ -28,6 +30,9 @@ namespace {
class MapDatasetOp : public UnaryDatasetOpKernel {
public:
+ using MapIteratorFunction = std::function<Status(
+ IteratorContext*, std::vector<Tensor>, std::vector<Tensor>*)>;
+
explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -43,8 +48,42 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
use_inter_op_parallelism_,
&captured_func));
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
+ MapIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](IteratorContext* ctx,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors) {
+ return raw_captured_func->Run(ctx, std::move(args), out_tensors);
+ };
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ return Status::OK();
+ };
+ }
+
*output = new Dataset(ctx, input, func_, std::move(captured_func),
- output_types_, output_shapes_);
+ output_types_, output_shapes_, std::move(map_func));
}
private:
@@ -54,13 +93,15 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& output_types,
- const std::vector<PartialTensorShape>& output_shapes)
+ const std::vector<PartialTensorShape>& output_shapes,
+ MapIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
output_types_(output_types),
- output_shapes_(output_shapes) {
+ output_shapes_(output_shapes),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -68,8 +109,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(
- new Iterator({this, strings::StrCat(prefix, "::Map")}));
+ return MakeUnique<Iterator>(
+ Iterator::Params{this, strings::StrCat(prefix, "::Map")}, map_func_);
}
const DataTypeVector& output_dtypes() const override {
@@ -116,8 +157,8 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
class Iterator : public DatasetIterator<Dataset> {
public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
+ explicit Iterator(const Params& params, MapIteratorFunction map_func)
+ : DatasetIterator<Dataset>(params), map_func_(std::move(map_func)) {}
Status Initialize(IteratorContext* ctx) override {
TF_RETURN_IF_ERROR(
@@ -139,10 +180,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
- // TODO(mrry): Avoid blocking a threadpool thread. We will need to
- // stack-rip the iterators and use async kernels.
- Status s =
- dataset()->captured_func_->Run(ctx, std::move(args), out_tensors);
+ Status s = map_func_(ctx, args, out_tensors);
if (errors::IsOutOfRange(s)) {
// `f` may deliberately raise `errors::OutOfRange` to indicate
// that we should terminate the iteration early.
@@ -167,6 +205,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
private:
std::unique_ptr<IteratorBase> input_impl_;
+ const MapIteratorFunction map_func_;
};
const DatasetBase* const input_;
@@ -174,6 +213,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const std::unique_ptr<CapturedFunction> captured_func_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
+ const MapIteratorFunction map_func_;
};
DataTypeVector output_types_;
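The short-circuit path above skips invoking the function entirely when `f` merely forwards its inputs: ComputeShortCircuitIndices yields, for each output, the index of the input that produces it, and ComputeMoveVector marks inputs referenced exactly once, which may be moved rather than copied. A self-contained sketch of the forwarding loop (Tensor is a stand-in type):

#include <cstddef>
#include <string>
#include <utility>
#include <vector>

using Tensor = std::string;  // stand-in for tensorflow::Tensor

void ShortCircuitMap(const std::vector<int>& indices,
                     const std::vector<bool>& can_move,
                     std::vector<Tensor> args,
                     const std::vector<Tensor>& captured_inputs,
                     std::vector<Tensor>* out_tensors) {
  const std::size_t num_args = args.size();
  for (std::size_t i = 0; i < indices.size(); ++i) {
    const std::size_t idx = static_cast<std::size_t>(indices[i]);
    if (idx < num_args) {
      // An argument may be moved when no other output references it.
      if (can_move[i]) {
        out_tensors->push_back(std::move(args[idx]));
      } else {
        out_tensors->push_back(args[idx]);
      }
    } else {
      // Indices past the argument range select captured inputs.
      out_tensors->push_back(captured_inputs[idx - num_args]);
    }
  }
}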
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index 6657f2b2b3..705b0393de 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -62,24 +62,6 @@ class MapDefunOp : public AsyncOpKernel {
~MapDefunOp() override {}
- Status GetInputBatchSize(OpKernelContext* ctx, int64* batch_size) {
- // Validates inputs and gets the size of their leading dimension.
- *batch_size = ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- if (ctx->input(i).dims() == 0) {
- return errors::InvalidArgument(
- "All inputs must have rank at least 1. Input ", i,
- " has a rank of 0.");
- } else if (ctx->input(i).dim_size(0) != *batch_size) {
- return errors::InvalidArgument(
- "All inputs must have the same dimension 0. Input ", i,
- " has leading dimension ", ctx->input(i).dim_size(0),
- ", while all previous inputs have leading dimension ", batch_size);
- }
- }
- return Status::OK();
- }
-
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
ComputeOptions* compute_opts = nullptr;
@@ -150,8 +132,9 @@ class MapDefunOp : public AsyncOpKernel {
// all calls to the function are complete. This struct also encapsulates
// all the components that need to be passed to each MapFunctionCallFrame.
- const std::vector<Tensor> args;
+ OpInputList args;
const std::vector<TensorShape> arg_shapes;
+ OpInputList captured_inputs;
const int64 batch_size;
// Output of a compute call
@@ -161,26 +144,31 @@ class MapDefunOp : public AsyncOpKernel {
// Create a copy of output_shapes because every `Compute` may expect a
// different output shape.
- ComputeOptions(std::vector<Tensor> args,
+ ComputeOptions(OpInputList args, OpInputList captured_inputs,
std::vector<TensorShape> arg_shapes, int64 batch_size,
const std::vector<PartialTensorShape>& output_shapes_attr)
- : args(std::move(args)),
+ : args(args),
arg_shapes(std::move(arg_shapes)),
+ captured_inputs(captured_inputs),
batch_size(batch_size),
output_shapes(output_shapes_attr) {}
};
// Get inputs to Compute and check that they are valid.
Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
- int64 batch_size =
- ctx->input(0).dims() > 0 ? ctx->input(0).dim_size(0) : -1;
+ OpInputList arguments;
+ TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments));
+ OpInputList captured_inputs;
+ TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs));
+
+ int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1;
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- if (ctx->input(i).dims() == 0) {
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ if (arguments[i].dims() == 0) {
return errors::InvalidArgument(
"All inputs must have rank at least 1. Input ", i,
" has a rank of 0.");
- } else if (ctx->input(i).dim_size(0) != batch_size) {
+ } else if (arguments[i].dim_size(0) != batch_size) {
return errors::InvalidArgument(
"All inputs must have the same dimension 0. Input ", i,
" has leading dimension ", ctx->input(i).dim_size(0),
@@ -188,19 +176,17 @@ class MapDefunOp : public AsyncOpKernel {
}
}
- std::vector<Tensor> args;
std::vector<TensorShape> arg_shapes;
- args.reserve(ctx->num_inputs());
- arg_shapes.reserve(ctx->num_inputs());
+ arg_shapes.reserve(arguments.size());
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- args.push_back(ctx->input(i));
- arg_shapes.push_back(ctx->input(i).shape());
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ arg_shapes.push_back(arguments[i].shape());
arg_shapes.at(i).RemoveDim(0);
}
- *compute_opts = new ComputeOptions(std::move(args), std::move(arg_shapes),
- batch_size, output_shapes_);
+ *compute_opts =
+ new ComputeOptions(arguments, captured_inputs, std::move(arg_shapes),
+ batch_size, output_shapes_);
return Status::OK();
}
@@ -235,12 +221,21 @@ class MapDefunOp : public AsyncOpKernel {
}
Status GetArg(int index, Tensor* val) const override {
- if (index < 0 || index >= compute_opts_->args.size()) {
+ if (index < 0 || index >= compute_opts_->args.size() +
+ compute_opts_->captured_inputs.size()) {
return errors::InvalidArgument(
"Mismatch in number of function inputs.");
}
+
+ if (index >= compute_opts_->args.size()) {
+        // The function is requesting a captured input.
+ *val =
+ compute_opts_->captured_inputs[index - compute_opts_->args.size()];
+ return Status::OK();
+ }
+
bool result =
- val->CopyFrom(compute_opts_->args.at(index).Slice(iter_, iter_ + 1),
+ val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
compute_opts_->arg_shapes.at(index));
if (!result) {
return errors::Internal("GetArg failed.");
@@ -248,7 +243,6 @@ class MapDefunOp : public AsyncOpKernel {
// Ensure alignment
*val = tensor::DeepCopy(*val);
}
-
return Status::OK();
}
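With arguments and captured inputs now carried as separate OpInputList members, GetArg routes an index either to a per-iteration slice of the batched arguments or to a captured input returned whole. A sketch of just the routing logic, with std:: containers standing in for OpInputList:

#include <cstddef>
#include <stdexcept>
#include <vector>

template <typename T>
const T& RouteArg(const std::vector<T>& args, const std::vector<T>& captured,
                  std::size_t index) {
  if (index >= args.size() + captured.size()) {
    throw std::out_of_range("Mismatch in number of function inputs.");
  }
  // Captured inputs are addressed after the batched arguments.
  return index < args.size() ? args[index] : captured[index - args.size()];
}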
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
index 5f143967d9..d909b9e9d3 100644
--- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -134,19 +134,17 @@ class MultiDeviceIterator : public ResourceBase {
void Reset() LOCKS_EXCLUDED(mu_) {
{
mutex_lock l(mu_);
- if (background_thread_finished_) {
- return;
- }
-
- cancelled_ = true;
- // Wake up the background thread.
- for (int i = 0; i < size_; ++i) {
- buffer_[i].cond_var.notify_all();
- }
+ if (!background_thread_finished_) {
+ cancelled_ = true;
+ // Wake up the background thread.
+ for (int i = 0; i < size_; ++i) {
+ buffer_[i].cond_var.notify_all();
+ }
- // Make sure background thread has finished first.
- while (!background_thread_finished_) {
- shutdown_cond_var_.wait(l);
+ // Make sure background thread has finished first.
+ while (!background_thread_finished_) {
+ shutdown_cond_var_.wait(l);
+ }
}
}
RunPendingCallbacks();
@@ -182,7 +180,7 @@ class MultiDeviceIterator : public ResourceBase {
buffer_[shard_num].cond_var.notify_all();
}
} else {
- if (background_thread_finished_) {
+ if (end_of_iterator_) {
produced_output = true;
elem.end_of_sequence = true;
} else {
@@ -219,8 +217,12 @@ class MultiDeviceIterator : public ResourceBase {
while (!buffer_[i].callbacks.empty()) {
if (buffer_[i].data.empty()) {
HostBufferElement elem;
- elem.status =
- errors::Cancelled("Cancelled and buffer not filled.");
+ if (end_of_iterator_) {
+ elem.end_of_sequence = true;
+ } else {
+ elem.status =
+ errors::Cancelled("Cancelled and buffer not filled.");
+ }
cancellation_elements.push_back(std::move(elem));
} else {
cancellation_elements.push_back(
@@ -293,6 +295,7 @@ class MultiDeviceIterator : public ResourceBase {
{
mutex_lock l(mu_);
background_thread_finished_ = true;
+ end_of_iterator_ = true;
shutdown_cond_var_.notify_all();
}
RunPendingCallbacks();
@@ -312,6 +315,7 @@ class MultiDeviceIterator : public ResourceBase {
std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
bool background_thread_finished_ GUARDED_BY(mu_) = false;
bool background_thread_started_ GUARDED_BY(mu_) = false;
+ bool end_of_iterator_ GUARDED_BY(mu_) = false;
bool cancelled_ GUARDED_BY(mu_) = false;
condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
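The new end_of_iterator_ flag lets the cancellation path distinguish a buffer that drained because the input finished (report end-of-sequence) from one that drained because of shutdown (report cancellation). A small sketch of that branch, under those assumptions:

#include <string>

struct HostBufferElement {
  bool end_of_sequence = false;
  std::string status = "OK";
};

HostBufferElement DrainedElement(bool end_of_iterator) {
  HostBufferElement elem;
  if (end_of_iterator) {
    elem.end_of_sequence = true;  // the input genuinely ran out
  } else {
    elem.status = "Cancelled and buffer not filled.";  // shutdown mid-stream
  }
  return elem;
}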
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index d5b725eac9..1cb7caa738 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -154,12 +154,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = *(ctx->runner());
- params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+ IteratorContext::Params params = ctx->params();
params.lib = dataset()->lib_;
- params.allocator_getter = ctx->allocator_getter();
return dataset()->optimized_input_->MakeIterator(
IteratorContext(params), prefix(), &input_impl_);
}
@@ -167,14 +163,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = *(ctx->runner());
- params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+ IteratorContext::Params params = ctx->params();
params.lib = dataset()->lib_;
- params.allocator_getter = ctx->allocator_getter();
- IteratorContext iter_ctx(params);
- return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
+ return input_impl_->GetNext(IteratorContext(params), out_tensors,
+ end_of_sequence);
}
protected:
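Both call sites above switch from rebuilding IteratorContext::Params field by field to copying the caller's params and overriding only `lib`, so any fields added later are forwarded automatically. The pattern, in a plain-struct sketch:

struct Params {
  int env = 0;
  int runner = 0;
  int lib = 0;
  // ... further fields are forwarded without being named here.
};

Params WithLib(const Params& base, int lib) {
  Params params = base;  // copy the full caller context
  params.lib = lib;      // override only the function library
  return params;
}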
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 2e6e0465f7..6b6b3d6ab9 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -1084,6 +1084,9 @@ REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
// The above design choices were made with automated optimizations in mind,
// isolating the degree of parallelism as the single tunable knob of this
// implementation.
+//
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx)
@@ -1214,7 +1217,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- num_parallel_calls_(params.dataset->num_parallel_calls_),
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ params.dataset->num_parallel_calls_, mu_, cond_var_)),
args_list_(params.dataset->cycle_length_),
current_elements_(params.dataset->cycle_length_),
element_in_use_(params.dataset->cycle_length_, false),
@@ -1224,25 +1230,24 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
false /* low_latency_hint */)) {}
~Iterator() override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
- cond_var_.notify_all();
+ cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
- mutex_lock l(mu_);
- if (num_parallel_calls_ == kAutoTune) {
- num_parallel_calls_ = 1;
- AddTunableParameter(ctx, "parallelism",
- &num_parallel_calls_ /* value */, 1 /* min */,
- dataset()->cycle_length_ /* max */, &cond_var_);
+ mutex_lock l(*mu_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ dataset()->cycle_length_);
} else {
- AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
}
AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
@@ -1256,12 +1261,12 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
std::shared_ptr<InvocationResult> result;
do {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty() &&
(!end_of_input_ || num_open_ > 0)) {
RecordStop(ctx);
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx);
}
if (!invocation_results_.empty()) {
@@ -1271,7 +1276,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
RecordStop(ctx);
result->notification.WaitForNotification();
@@ -1287,10 +1292,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
@@ -1328,7 +1333,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
@@ -1381,7 +1386,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
};
void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
@@ -1398,7 +1403,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
void FetchOutputs(
const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
const std::vector<std::shared_ptr<InvocationResult>>& results)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
bool end_of_input = false;
@@ -1421,14 +1426,14 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
if (end_of_input) {
current_elements_[cycle_index].reset();
}
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
element_in_use_[cycle_index] = false;
num_calls_--;
if (end_of_input) {
args_list_[cycle_index].clear();
num_open_--;
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
// Method responsible for 1) creating iterators out of input elements, 2)
@@ -1439,20 +1444,20 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
- auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
return element_in_use_[cycle_index_] ||
- num_calls_ >= num_parallel_calls_ ||
+ num_calls_ >= num_parallel_calls_->value ||
invocation_results_.size() >=
dataset()->cycle_length_ * dataset()->block_length_;
};
while (true) {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait until this thread is cancelled, the end of input has been
// reached, or the cycle element at the `cycle_index_` position is
// not in use and there is space in the `invocation_results_` queue.
while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
RecordStop(ctx.get());
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx.get());
}
@@ -1506,13 +1511,13 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
const Status& status)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
@@ -1523,7 +1528,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
@@ -1550,7 +1555,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status WriteCurrentElements(IteratorStateWriter* writer)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (current_elements_[idx]) {
TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
@@ -1569,7 +1574,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
Status ReadCurrentElements(IteratorContext* ctx,
IteratorStateReader* reader)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (reader->Contains(
full_name(strings::StrCat("args_size[", idx, "]")))) {
@@ -1597,7 +1602,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// Used for coordination between the main thread, the runner thread, and
// the worker threads.
- mutex mu_;
+ const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread, the runner thread, and
// the worker threads. In particular, the runner thread should only
@@ -1605,45 +1610,45 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// user specified level of parallelism, there are slots available in the
// `invocation_results_` buffer, the current cycle element is not in use,
// and there are elements left to be fetched.
- condition_variable cond_var_;
+ const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
- std::atomic<int64> num_parallel_calls_;
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Iterator for input elements.
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(*mu_);
// Identifies current cycle element.
int64 cycle_index_ = 0;
// Arguments for creating an iterator for cycle elements.
- std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
+ std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(*mu_);
// Iterators for the current cycle elements. Concurrent access is
// protected by `element_in_use_`.
std::vector<std::unique_ptr<IteratorBase>> current_elements_;
// Identifies cycle elements that are in use by worker threads.
- std::vector<bool> element_in_use_ GUARDED_BY(mu_);
+ std::vector<bool> element_in_use_ GUARDED_BY(*mu_);
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
- GUARDED_BY(mu_);
+ GUARDED_BY(*mu_);
// Identifies whether end of input has been reached.
- bool end_of_input_ GUARDED_BY(mu_) = false;
+ bool end_of_input_ GUARDED_BY(*mu_) = false;
// Identifies the number of open iterators.
- int64 num_open_ GUARDED_BY(mu_) = 0;
+ int64 num_open_ GUARDED_BY(*mu_) = 0;
// Identifies the number of outstanding calls.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<thread::ThreadPool> thread_pool_;
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
// Identifies whether background activity should be cancelled.
- bool cancelled_ GUARDED_BY(mu_) = false;
+ bool cancelled_ GUARDED_BY(*mu_) = false;
};
const DatasetBase* const input_;
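For the interleave runner, the gating predicate now reads the tunable value through the shared state: a call is scheduled only when the current cycle element is free, the in-flight count is below the parallelism knob, and the results queue has room. The predicate in isolation, with plain types:

#include <cstddef>
#include <cstdint>
#include <vector>

bool Busy(const std::vector<bool>& element_in_use, std::size_t cycle_index,
          int64_t num_calls, int64_t num_parallel_calls,
          std::size_t num_queued, std::size_t cycle_length,
          std::size_t block_length) {
  return element_in_use[cycle_index] || num_calls >= num_parallel_calls ||
         num_queued >= cycle_length * block_length;
}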
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index 6abe6c8338..3a14924fba 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
#include "tensorflow/core/lib/random/random.h"
@@ -56,9 +57,55 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
use_inter_op_parallelism_,
&captured_func));
+ std::vector<int> indices;
+ OP_REQUIRES_OK(ctx, ComputeShortCircuitIndices(ctx, func_, &indices));
+
+ ParallelMapIteratorFunction map_func;
+ CapturedFunction* raw_captured_func = captured_func.get();
+ if (indices.empty()) {
+ map_func = [raw_captured_func](IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ raw_captured_func->RunAsync(ctx, std::move(args), out_tensors,
+ std::move(done), prefix);
+ };
+ if (!use_inter_op_parallelism_) {
+ map_func = [map_func](IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args,
+ std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ (*ctx->runner())(std::bind(map_func, ctx, prefix, std::move(args),
+ out_tensors, std::move(done)));
+ };
+ }
+ } else {
+ std::vector<bool> can_move = ComputeMoveVector(indices);
+ map_func = [raw_captured_func, indices, can_move](
+ IteratorContext* ctx, const string& prefix,
+ std::vector<Tensor> args, std::vector<Tensor>* out_tensors,
+ StatusCallback done) {
+ const std::vector<Tensor>& captured_inputs =
+ raw_captured_func->captured_inputs();
+ size_t num_args = args.size();
+ for (size_t i = 0; i < indices.size(); ++i) {
+ if (indices[i] < num_args) {
+ if (can_move[i]) {
+ out_tensors->push_back(std::move(args[indices[i]]));
+ } else {
+ out_tensors->push_back(args[indices[i]]);
+ }
+ } else {
+ out_tensors->push_back(captured_inputs[indices[i] - num_args]);
+ }
+ }
+ done(Status::OK());
+ };
+ }
+
*output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
output_shapes_, use_inter_op_parallelism_,
- std::move(captured_func));
+ std::move(captured_func), std::move(map_func));
}
private:
@@ -69,7 +116,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
bool use_inter_op_parallelism,
- std::unique_ptr<CapturedFunction> captured_func)
+ std::unique_ptr<CapturedFunction> captured_func,
+ ParallelMapIteratorFunction map_func)
: DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
@@ -77,7 +125,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
output_types_(output_types),
output_shapes_(output_shapes),
use_inter_op_parallelism_(use_inter_op_parallelism),
- captured_func_(std::move(captured_func)) {
+ captured_func_(std::move(captured_func)),
+ map_func_(std::move(map_func)) {
input_->Ref();
}
@@ -89,26 +138,9 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
return captured_func_->Instantiate(ctx);
};
- const string& new_prefix = strings::StrCat(prefix, "::ParallelMap");
- ParallelMapIteratorFunction map_func =
- [this, new_prefix](IteratorContext* ctx,
- std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done), new_prefix);
- };
- if (!use_inter_op_parallelism_) {
- map_func = [map_func](
- IteratorContext* ctx, std::vector<Tensor> input_element,
- std::vector<Tensor>* result, StatusCallback done) {
- (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element),
- result, std::move(done)));
- };
- }
-
- return NewParallelMapIterator({this, new_prefix}, input_,
- std::move(init_func), std::move(map_func),
- num_parallel_calls_);
+ return NewParallelMapIterator(
+ {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
+ std::move(init_func), map_func_, num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
@@ -176,6 +208,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
const bool use_inter_op_parallelism_;
const std::unique_ptr<CapturedFunction> captured_func_;
+ const ParallelMapIteratorFunction map_func_;
};
DataTypeVector output_types_;
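When use_inter_op_parallelism is false, the op wraps the freshly built map function so the call itself is deferred to the context's runner rather than executed on the calling thread. A standalone sketch of the wrapping, where Runner stands in for *ctx->runner() and the tensor arguments are simplified to ints:

#include <functional>
#include <utility>
#include <vector>

using Runner = std::function<void(std::function<void()>)>;
using MapFunc = std::function<void(std::vector<int>, Runner*)>;

MapFunc WrapWithRunner(MapFunc map_func) {
  return [map_func](std::vector<int> args, Runner* runner) {
    // Defer the underlying call to the runner instead of invoking it inline.
    (*runner)([map_func, args = std::move(args), runner]() mutable {
      map_func(std::move(args), runner);
    });
  };
}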
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index ee20249bfe..ebf41925c9 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -22,11 +22,14 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace data {
namespace {
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class ParallelMapIterator : public DatasetBaseIterator {
public:
explicit ParallelMapIterator(
@@ -38,30 +41,32 @@ class ParallelMapIterator : public DatasetBaseIterator {
input_dataset_(input_dataset),
init_func_(std::move(init_func)),
map_func_(std::move(map_func)),
- num_parallel_calls_(num_parallel_calls) {}
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ num_parallel_calls, mu_, cond_var_)) {}
~ParallelMapIterator() override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
- cond_var_.notify_all();
+ cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
- mutex_lock l(mu_);
- if (num_parallel_calls_ == kAutoTune) {
- num_parallel_calls_ = 1;
+ mutex_lock l(*mu_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
// TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
// use it here for the maximum.
- AddTunableParameter(ctx, "parallelism", &num_parallel_calls_ /* value */,
- 1 /* min */, port::NumSchedulableCPUs() /* max */,
- &cond_var_);
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ port::NumSchedulableCPUs());
} else {
- AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
}
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
@@ -75,16 +80,16 @@ class ParallelMapIterator : public DatasetBaseIterator {
bool* end_of_sequence) override {
std::shared_ptr<InvocationResult> result;
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty()) {
RecordStop(ctx);
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx);
}
std::swap(result, invocation_results_.front());
invocation_results_.pop_front();
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
RecordStop(ctx);
result->notification.WaitForNotification();
@@ -94,28 +99,27 @@ class ParallelMapIterator : public DatasetBaseIterator {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
- std::shared_ptr<InvocationResult> result = invocation_results_[i];
- TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ const auto& result = *(invocation_results_[i]);
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("invocation_results[", i, "].size")),
- result->return_values.size()));
- for (size_t j = 0; j < result->return_values.size(); j++) {
- TF_RETURN_IF_ERROR(
- writer->WriteTensor(full_name(strings::StrCat(
- "invocation_results[", i, "][", j, "]")),
- result->return_values[j]));
+ result.return_values.size()));
+ for (size_t j = 0; j < result.return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
+ result.return_values[j]));
}
- if (result->end_of_input) {
+ if (result.end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")),
@@ -127,15 +131,15 @@ class ParallelMapIterator : public DatasetBaseIterator {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name("invocation_results.size"), &invocation_results_size));
for (size_t i = 0; i < invocation_results_size; i++) {
- std::shared_ptr<InvocationResult> result(new InvocationResult());
- invocation_results_.push_back(result);
- TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ invocation_results_.push_back(std::make_shared<InvocationResult>());
+ auto& result = *invocation_results_.back();
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
size_t num_return_values;
{
int64 size;
@@ -151,17 +155,16 @@ class ParallelMapIterator : public DatasetBaseIterator {
": ", size, " is not a valid value of type size_t."));
}
}
- result->return_values.reserve(num_return_values);
+ result.return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
- result->return_values.emplace_back();
- TF_RETURN_IF_ERROR(
- reader->ReadTensor(full_name(strings::StrCat(
- "invocation_results[", i, "][", j, "]")),
- &result->return_values.back()));
+ result.return_values.emplace_back();
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
+ &result.return_values.back()));
}
- result->end_of_input = reader->Contains(full_name(
+ result.end_of_input = reader->Contains(full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")));
- result->notification.Notify();
+ result.notification.Notify();
}
return Status::OK();
}
@@ -175,9 +178,9 @@ class ParallelMapIterator : public DatasetBaseIterator {
};
void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
- std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
+ auto ctx_copy = std::make_shared<IteratorContext>(*ctx);
runner_thread_.reset(ctx->env()->StartThread(
{}, "runner_thread",
std::bind(&ParallelMapIterator::RunnerThread, this, ctx_copy)));
@@ -185,18 +188,18 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
void CallCompleted(const std::shared_ptr<InvocationResult>& result)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
num_calls_--;
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
result->notification.Notify();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
result->status =
@@ -206,15 +209,15 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
- // Call `func_(input_element)`, store the result in `result->return_values`,
- // and notify `result->notification` to unblock a consumer.
auto done = [this, result](Status status) {
result->status.Update(status);
CallCompleted(result);
};
- map_func_(ctx.get(), std::move(input_element), &result->return_values,
- std::move(done));
+    // Apply the map function to `input_element`, storing the result in
+ // `result->return_values`, and invoking `done` when finished.
+ map_func_(ctx.get(), prefix(), std::move(input_element),
+ &result->return_values, std::move(done));
}
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
@@ -239,29 +242,29 @@ class ParallelMapIterator : public DatasetBaseIterator {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
- new_calls.reserve(num_parallel_calls_);
- auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
- int64 num_parallel_calls = num_parallel_calls_;
+ new_calls.reserve(num_parallel_calls_->value);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_->value;
return num_calls_ >= num_parallel_calls ||
invocation_results_.size() >= num_parallel_calls;
};
while (true) {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
while (!cancelled_ && busy()) {
RecordStop(ctx.get());
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
while (!busy()) {
- invocation_results_.emplace_back(new InvocationResult());
+ invocation_results_.push_back(std::make_shared<InvocationResult>());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
for (const auto& call : new_calls) {
CallFunction(ctx, call);
@@ -271,7 +274,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
- const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
@@ -282,7 +286,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
@@ -312,23 +316,23 @@ class ParallelMapIterator : public DatasetBaseIterator {
const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
// Used for coordination between the main thread and the runner thread.
- mutex mu_;
+ const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread and the runner thread. In
// particular, the runner thread should only schedule new calls when the
// number of in-flight calls is less than the user specified level of
// parallelism and there are slots available in the `invocation_results_`
// buffer.
- condition_variable cond_var_;
+ const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
- std::atomic<int64> num_parallel_calls_;
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Counts the number of outstanding calls.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
- GUARDED_BY(mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
- bool cancelled_ GUARDED_BY(mu_) = false;
+ GUARDED_BY(*mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ bool cancelled_ GUARDED_BY(*mu_) = false;
};
} // namespace
@@ -346,9 +350,9 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBase* input_dataset,
std::function<Status(IteratorContext*)> init_func,
ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
- return std::unique_ptr<IteratorBase>(
- new ParallelMapIterator(params, input_dataset, std::move(init_func),
- std::move(map_func), num_parallel_calls));
+ return MakeUnique<ParallelMapIterator>(
+ params, input_dataset, std::move(init_func), std::move(map_func),
+ num_parallel_calls);
}
} // namespace data
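A small cleanup recurring through this commit replaces `emplace_back(new T)` and `unique_ptr<T>(new T)` with make_shared/MakeUnique, which allocates the control block together with the object and stays exception-safe. The two spellings, in a sketch:

#include <deque>
#include <memory>

struct InvocationResult {};

void Enqueue(std::deque<std::shared_ptr<InvocationResult>>* results) {
  // Before: results->emplace_back(new InvocationResult());
  // After: one allocation, and no leak if a later step throws.
  results->push_back(std::make_shared<InvocationResult>());
}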
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h
index dc26c5cf25..813f13c9e4 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.h
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.h
@@ -30,7 +30,7 @@ namespace data {
// 3. A `std::vector<Tensor>*` to which the function will write the result.
// 4. A `StatusCallback` that should be invoked when the function is complete.
using ParallelMapIteratorFunction =
- std::function<void(IteratorContext*, std::vector<Tensor>,
+ std::function<void(IteratorContext*, const string&, std::vector<Tensor>,
std::vector<Tensor>*, StatusCallback)>;
// Returns a new iterator that applies `map_func` to the elements of
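The signature change threads a prefix string through to the map function, so callers can attribute tracing and stats to the right op without capturing the iterator. A sketch of the widened type and a conforming no-op function, with stand-in types and the IteratorContext* parameter elided:

#include <functional>
#include <string>
#include <utility>
#include <vector>

using Tensor = std::string;                        // stand-in
using StatusCallback = std::function<void(bool)>;  // stand-in for Status cb

using MapIteratorFn =
    std::function<void(const std::string& /*prefix*/, std::vector<Tensor>,
                       std::vector<Tensor>*, StatusCallback)>;

// A conforming implementation that forwards its input unchanged.
void IdentityMap(const std::string& /*prefix*/, std::vector<Tensor> in,
                 std::vector<Tensor>* out, StatusCallback done) {
  *out = std::move(in);
  done(/*ok=*/true);
}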
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
index c28c06da62..7de5ea8860 100644
--- a/tensorflow/core/kernels/data/parse_example_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -182,7 +182,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- auto map_fn = [this](IteratorContext* ctx,
+ auto map_fn = [this](IteratorContext* ctx, const string& prefix,
std::vector<Tensor> input_element,
std::vector<Tensor>* result, StatusCallback done) {
(*ctx->runner())([this, ctx, input_element, result, done]() {
@@ -253,7 +253,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
for (example::PerExampleFeatureStats feature_stats :
example_result.feature_stats) {
stats_aggregator->AddToHistogram(
- strings::StrCat("record_stats", ":features"),
+ "features",
{static_cast<double>(feature_stats.features_count)});
stats_aggregator->IncrementCounter(
"features_count", "trainer", feature_stats.features_count);
@@ -261,7 +261,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
"feature_values_count", "trainer",
feature_stats.feature_values_count);
stats_aggregator->AddToHistogram(
- strings::StrCat("record_stats", ":feature-values"),
+ "feature-values",
{static_cast<double>(feature_stats.feature_values_count)});
}
}
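The histogram keys above drop the hard-coded "record_stats" prefix because tagging is now applied centrally by the StatsAggregatorWithTagAndPrefix wrapper (see the stats_aggregator_dataset_op.cc hunks below): each call site emits the bare name and the wrapper rewrites it. The name composition, in one helper:

#include <string>

std::string HistogramName(const std::string& tag, const std::string& name) {
  return tag.empty() ? name : tag + "_" + name;
}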
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index dbe31f37b8..2a911aa368 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -32,8 +32,7 @@ namespace {
class ScanDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ScanDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tstate", &state_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -258,7 +257,6 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector state_types_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index f5314f7a75..c09a73fff1 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -12,9 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <memory>
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h"
@@ -22,6 +24,52 @@ namespace tensorflow {
namespace data {
namespace {
+class StatsAggregatorWithTagAndPrefix : public StatsAggregator {
+ public:
+ StatsAggregatorWithTagAndPrefix(
+ std::shared_ptr<StatsAggregator> stats_aggregator, const string& tag,
+ const string& prefix)
+ : wrapped_(stats_aggregator), tag_(tag), prefix_(prefix) {}
+
+ void AddToHistogram(const string& name,
+ gtl::ArraySlice<double> values) override {
+ if (!tag_.empty()) {
+ wrapped_->AddToHistogram(strings::StrCat(tag_, "_", name), values);
+ } else {
+ wrapped_->AddToHistogram(name, values);
+ }
+ }
+
+ void AddScalar(const string& name, float value) override {
+ if (!tag_.empty()) {
+ wrapped_->AddScalar(strings::StrCat(tag_, "_", name), value);
+ } else {
+ wrapped_->AddScalar(name, value);
+ }
+ }
+
+ void EncodeToProto(Summary* out_summary) override {
+ wrapped_->EncodeToProto(out_summary);
+ }
+
+ void IncrementCounter(const string& name, const string& label,
+ int64 val) override {
+ if (!prefix_.empty()) {
+ wrapped_->IncrementCounter(strings::StrCat(prefix_, "/", name), label,
+ val);
+ } else {
+ wrapped_->IncrementCounter(strings::StrCat("/tensorflow/", name), label,
+ val);
+ }
+ }
+
+ private:
+ std::shared_ptr<StatsAggregator> wrapped_;
+ string tag_;
+ string prefix_;
+ TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorWithTagAndPrefix);
+};
+
class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
public:
explicit SetStatsAggregatorDatasetOp(OpKernelConstruction* ctx)
@@ -33,18 +81,28 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
&stats_aggregator_resource));
core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
+ string tag;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
+ string prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix));
- *output = new Dataset(ctx, input, stats_aggregator_resource);
+ *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource,
+ tag, prefix);
}
private:
class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
- StatsAggregatorResource* stats_aggregator_resource)
+ const Tensor& resource_handle,
+ StatsAggregatorResource* stats_aggregator_resource,
+ const string& tag, const string& prefix)
: DatasetBase(DatasetContext(ctx)),
input_(input),
- stats_aggregator_resource_(stats_aggregator_resource) {
+ resource_handle_(resource_handle),
+ stats_aggregator_resource_(stats_aggregator_resource),
+ tag_(tag),
+ prefix_(prefix) {
input_->Ref();
stats_aggregator_resource_->Ref();
}
@@ -75,8 +133,18 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- return errors::Unimplemented("%s does not support serialization",
- DebugString());
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* resource_handle_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
+ Node* tag_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node));
+ Node* prefix_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(prefix_, &prefix_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, resource_handle_node, tag_node, prefix_node},
+ output));
+ return Status::OK();
}
private:
@@ -98,9 +166,10 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
IteratorContext::Params params;
params.env = ctx->env();
params.runner = *(ctx->runner());
- params.stats_aggregator_getter = [stats_aggregator_resource]() {
- return stats_aggregator_resource->stats_aggregator();
- };
+ params.stats_aggregator = std::shared_ptr<StatsAggregator>(
+ new StatsAggregatorWithTagAndPrefix(
+ stats_aggregator_resource->stats_aggregator(), dataset()->tag_,
+ dataset()->prefix_));
params.lib = ctx->lib();
params.function_library = ctx->function_library();
params.allocator_getter = ctx->allocator_getter();
@@ -111,16 +180,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- return Status::OK();
+ return errors::Unimplemented(dataset()->DebugString(),
+ " does not support checkpointing");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- return Status::OK();
+ return errors::Unimplemented(dataset()->DebugString(),
+ " does not support checkpointing");
}
private:
@@ -129,7 +196,10 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
};
const DatasetBase* const input_;
+ const Tensor resource_handle_;
StatsAggregatorResource* stats_aggregator_resource_;
+ string tag_;
+ string prefix_;
};
};
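
For reference, the wrapper class added above is a straightforward decorator: it rewrites metric names before delegating to the wrapped aggregator. A minimal standalone sketch of its naming rules (plain std::string stand-ins, not the real StatsAggregator interface):

    #include <iostream>
    #include <string>

    // Scalars and histograms get "tag_" prepended when a tag is set; counters
    // get "prefix/" when a prefix is set, else the "/tensorflow/" default.
    std::string ScalarName(const std::string& tag, const std::string& name) {
      return tag.empty() ? name : tag + "_" + name;
    }

    std::string CounterName(const std::string& prefix, const std::string& name) {
      return prefix.empty() ? "/tensorflow/" + name : prefix + "/" + name;
    }

    int main() {
      std::cout << ScalarName("train", "record_stats:feature-values") << "\n";
      // -> train_record_stats:feature-values
      std::cout << CounterName("", "data_elements") << "\n";
      // -> /tensorflow/data_elements
    }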
diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
index a7ded67876..2d51467616 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
@@ -82,11 +82,12 @@ class StatsAggregatorImpl : public StatsAggregator {
auto counters_map = get_counters_map();
if (counters_map->find(name) == counters_map->end()) {
counters_map->emplace(
- name, monitoring::Counter<1>::New(
- /*streamz name*/ "/tensorflow/" + name,
- /*streamz description*/
- name + " generated or consumed by the component.",
- /*streamz label name*/ "component_descriptor"));
+ name,
+ monitoring::Counter<1>::New(
+ /*streamz name*/ name,
+ /*streamz description*/
+ strings::StrCat(name, " generated or consumed by the component."),
+ /*streamz label name*/ "component_descriptor"));
}
counters_map->at(name)->GetCell(label)->IncrementBy(val);
}
diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
index 81c432b938..74908994b4 100644
--- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
@@ -41,11 +41,16 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetBase(DatasetContext(ctx)), input_(input) {
input_->Ref();
for (const PartialTensorShape& shape : input->output_shapes()) {
- gtl::InlinedVector<int64, 4> partial_dim_sizes;
- for (int i = 1; i < shape.dims(); ++i) {
- partial_dim_sizes.push_back(shape.dim_size(i));
+ if (!shape.unknown_rank()) {
+ gtl::InlinedVector<int64, 4> partial_dim_sizes;
+ for (int i = 1; i < shape.dims(); ++i) {
+ partial_dim_sizes.push_back(shape.dim_size(i));
+ }
+ shapes_.emplace_back(std::move(partial_dim_sizes));
+ } else {
+ // If the input shape is unknown, the output shape will be unknown.
+ shapes_.emplace_back();
}
- shapes_.emplace_back(std::move(partial_dim_sizes));
}
}
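
The new guard above keeps unknown-rank inputs from crashing the shape loop: known-rank components drop their leading (batch) dimension, while unknown-rank components stay fully unknown. A sketch of the known-rank rule on plain vectors (illustrative helper, with -1 marking an unknown dimension):

    #include <iostream>
    #include <vector>

    // Drop dimension 0 (the batch dimension) from a known-rank shape.
    std::vector<long long> UnbatchShape(const std::vector<long long>& shape) {
      return std::vector<long long>(shape.begin() + 1, shape.end());
    }

    int main() {
      // A batch of 32x32x3 images with unknown batch size: [-1, 32, 32, 3].
      for (long long d : UnbatchShape({-1, 32, 32, 3})) std::cout << d << " ";
      std::cout << "\n";  // -> 32 32 3
    }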
diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc
index 42fbf95cd3..28940e0849 100644
--- a/tensorflow/core/kernels/dequantize_op.cc
+++ b/tensorflow/core/kernels/dequantize_op.cc
@@ -96,8 +96,6 @@ class DequantizeOp : public OpKernel {
output);
}
} else if (mode_ == QUANTIZE_MODE_SCALED) {
- // TODO(pauldonnelly): Update QuantizeAndDequantizeV2 and
- // QuantizeAndDequantizeV3 to match this SCALED mode again.
const float scale_factor =
std::numeric_limits<T>::min() == 0
? (max_range / std::numeric_limits<T>::max())
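
From the branch visible above, the SCALED-mode scale factor for an unsigned quantized type (whose min() is 0) is max_range divided by the type's maximum value; the signed branch is elided by the hunk. A small worked example for the unsigned case only:

    #include <cstdint>
    #include <iostream>
    #include <limits>

    // Unsigned case: scale = max_range / numeric_limits<T>::max().
    template <typename T>
    float ScaledModeScale(float max_range) {
      static_assert(std::numeric_limits<T>::min() == 0, "unsigned types only");
      return max_range / std::numeric_limits<T>::max();
    }

    int main() {
      // quint8 spans [0, 255]; with max_range = 6.0f, the quantized value 255
      // dequantizes back to 6.0.
      std::cout << ScaledModeScale<std::uint8_t>(6.0f) << "\n";  // ~0.023529
    }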
diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
index c90ad2cfeb..ada1235449 100644
--- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
+++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
@@ -31,9 +31,37 @@ class FuzzParseTensor : public FuzzSession {
}
void FuzzImpl(const uint8_t* data, size_t size) final {
+ // We need to be sure that we don't request too many elements (i.e., we
+ // don't make ASAN OOM). In theory, a tensor shape can have an arbitrarily
+ // large number of elements, up to the limit of the memory available to the OS.
+ // However, due to the tracing done in ASAN, after 2^32 bytes of requested
+ // memory we would get a crash in the fuzzer (see b/34190148). Hence, let's
+ // try parsing the proto here, check that the size (if valid) is below a
+ // maximum threshold (using 2^20 for convenience), and then run the
+ // remainder of the fuzzer testing. Of course, this duplicates some work
+ // but it's better than repeating the investigation whenever Autofuzz
+ // detects another similar OOM.
+ string as_string = string(reinterpret_cast<const char*>(data), size);
+ TensorProto proto;
+ if (!ParseProtoUnlimited(&proto, as_string)) {
+ LOG(WARNING) << "Unable to parse proto of tensor\n";
+ return;
+ }
+ if (!TensorShape::IsValid(proto.tensor_shape())) {
+ LOG(WARNING) << "Invalid tensor shape\n";
+ return;
+ }
+ TensorShape shape(proto.tensor_shape());
+ const int64 num_elements = shape.num_elements();
+ const int64 max_num_elements = 1 << 20;
+ if (num_elements > max_num_elements) {
+ LOG(WARNING) << "Requiring a tensor with too many elements\n";
+ return;
+ }
+
+ // Now we can do the actual fuzz implementation
Tensor input_tensor(tensorflow::DT_STRING, TensorShape({}));
- input_tensor.scalar<string>()() =
- string(reinterpret_cast<const char*>(data), size);
+ input_tensor.scalar<string>()() = as_string;
// TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
RunOneInput(input_tensor).IgnoreError();
}
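
The size check added above is a reusable fuzzing pattern: decode the attacker-controlled input, validate the element count against a fixed budget, and only then allocate. A stripped-down sketch of the same guard with plain integers in place of the TF types:

    #include <cstdint>
    #include <iostream>

    // Reject inputs whose decoded element count exceeds a fixed cap, so the
    // fuzz target never asks the allocator (and ASAN) for gigabytes.
    bool WithinFuzzBudget(std::int64_t num_elements) {
      constexpr std::int64_t kMaxNumElements = 1 << 20;  // same cap as above
      return num_elements >= 0 && num_elements <= kMaxNumElements;
    }

    int main() {
      std::cout << WithinFuzzBudget(1 << 10) << "\n";    // 1: accepted
      std::cout << WithinFuzzBudget(1LL << 32) << "\n";  // 0: rejected
    }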
diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
index 277ee2be02..1c78de253e 100644
--- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
+++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
@@ -114,7 +114,7 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
slice_size, Tindices, Tparams, Tout, &error_loc);
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// Eigen implementation below is not highly performant. gather_nd_generator
// does not seem to be called in parallel, leading to very poor performance.
// Additionally, since it uses scalar (Tscratch) to invoke 'generate', it
@@ -126,12 +126,12 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
const Eigen::array<Eigen::DenseIndex, 1> loc{i};
gather_nd_generator(loc);
}
-#else // INTEL_MKL
+#else // INTEL_MKL && ENABLE_MKL
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
.broadcast(broadcast_dims)
.generate(gather_nd_generator)
.sum();
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
// error_loc() returns -1 if there's no out-of-bounds index,
// otherwise it returns the location of an OOB index in Tindices.
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index 79967aab38..4ad390a411 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -578,7 +578,7 @@ struct MatMulFunctor<SYCLDevice, T> {
.Label("cublas"), \
MatMulOp<GPUDevice, T, true /* cublas */>)
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// MKL does not support half, bfloat16 and int32 types for
// matrix-multiplication, so register the kernel to use default Eigen based
@@ -606,9 +606,9 @@ TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU_EIGEN);
TF_CALL_complex128(REGISTER_CPU_EIGEN);
TF_CALL_double(REGISTER_CPU_EIGEN);
-#endif
+#endif // INTEL_MKL_DNN_ONLY
-#else // INTEL MKL
+#else // INTEL_MKL && ENABLE_MKL
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
@@ -616,7 +616,7 @@ TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
#if GOOGLE_CUDA
TF_CALL_float(REGISTER_GPU);
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 0841395dc3..bc135de11e 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -223,10 +223,12 @@ class BatchMatMulMkl : public OpKernel {
Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMulMkl<CPUDevice, TYPE>)
+#ifdef ENABLE_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
+#endif // ENABLE_MKL
} // end namespace tensorflow
#endif
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 52157ed5fb..f406ad2ab5 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -853,7 +853,7 @@ class MklConvCustomBackpropFilterOp
// MKL DNN allocates large buffers when a conv gradient filter primtive is
// created. So we don't cache conv backward primitives when the env
- // variable TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is set to true.
+ // variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true.
bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();
conv_bwd_filter = MklConvBwdFilterPrimitiveFactory<T>::Get(
convBwdFilterDims, do_not_cache);
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index c38c9cc27c..a501ce2c93 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -713,7 +713,7 @@ class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
TFPaddingToMklDnnPadding(this->padding_));
// We don't cache those primitves if the env variable
- // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true and if primitve descriptor
+ // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if primitve descriptor
// includes potentialy large buffers. MKL DNN allocates buffers
// in the following cases
// 1. Legacy CPU without AVX512/AVX2, or
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 184e0cb003..b332edad0a 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -901,7 +901,7 @@ class MklConvOp : public OpKernel {
// In some cases, primitve descriptor includes potentialy large buffers,
// we don't cache those primitves if the env variable
- // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true. MKL DNN allocates buffers
+ // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true. MKL DNN allocates buffers
// in the following cases
// 1. Legacy CPU without AVX512/AVX2, or
// 2. 1x1 convolution with stride != 1
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index 077d62ce32..f4788f4851 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -217,7 +217,7 @@ class MklMatMulOp : public OpKernel {
reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
reinterpret_cast<MKL_Complex16*>(c), ldc);
}
-#endif
+#endif // !INTEL_MKL_DNN_ONLY
};
#define REGISTER_CPU(T) \
@@ -225,6 +225,7 @@ class MklMatMulOp : public OpKernel {
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
+#ifdef ENABLE_MKL
// TODO(inteltf) Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);
@@ -233,7 +234,8 @@ TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif // !INTEL_MKL_DNN_ONLY
+#endif // ENABLE_MKL
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc
new file mode 100644
index 0000000000..d63e14adf6
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_slice_op.cc
@@ -0,0 +1,358 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/array_ops.cc.
+
+#ifdef INTEL_MKL
+#ifndef INTEL_MKL_ML_ONLY
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/prefetch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#include "mkldnn.hpp"
+#include "tensorflow/core/util/mkl_util.h"
+
+using mkldnn::stream;
+using mkldnn::view;
+
+namespace tensorflow {
+
+namespace {
+
+gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
+ gtl::InlinedVector<int64, 4> out;
+ if (tensor.dtype() == DT_INT32) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int32>()(i));
+ }
+ } else if (tensor.dtype() == DT_INT64) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int64>()(i));
+ }
+ } else {
+ // tensor must be either int32 or int64
+ DCHECK(false);
+ }
+ return out;
+}
+
+} // namespace
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+// A version of SharedValidation (slice_op.h) written for input that is in
+// either Mkl layout or Tensorflow layout.
+// This shared code validates input shapes and checks for identity; it does
+// not depend on the element type T, which keeps code size down by avoiding
+// one copy of the logic per T (float, double, int32, etc.).
+static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size) {
+ const int kInputTensorIndex = 0;
+ const int kInputBeginIndex = 1;
+ const int kInputSizeIndex = 2;
+ const Tensor& input = MklGetInput(context, kInputTensorIndex);
+ const Tensor& begin_tensor = MklGetInput(context, kInputBeginIndex);
+ const Tensor& size_tensor = MklGetInput(context, kInputSizeIndex);
+
+ MklDnnShape input_mkl_shape, begin_mkl_shape, size_mkl_shape;
+ GetMklShape(context, kInputTensorIndex, &input_mkl_shape);
+ GetMklShape(context, kInputBeginIndex, &begin_mkl_shape);
+ GetMklShape(context, kInputSizeIndex, &size_mkl_shape);
+
+ // Begin and size tensors cannot be in MklDnn layout.
+ DCHECK_EQ(begin_mkl_shape.IsMklTensor(), false);
+ DCHECK_EQ(size_mkl_shape.IsMklTensor(), false);
+
+ TensorShape input_tf_shape = input_mkl_shape.IsMklTensor()
+ ? input_mkl_shape.GetTfShape()
+ : input.shape();
+ const int input_dims = input_tf_shape.dims();
+
+ OP_REQUIRES(
+ context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
+ context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
+ begin_tensor.NumElements() == input_dims &&
+ size_tensor.NumElements() == input_dims,
+ errors::InvalidArgument(
+ "Expected begin and size arguments to be 1-D tensors of size ",
+ input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
+ " and ", size_tensor.shape().DebugString(), " instead."));
+
+ *begin = IntTensorToInt64Vec(begin_tensor);
+ *size = IntTensorToInt64Vec(size_tensor);
+ for (int i = 0; i < input_dims; ++i) {
+ if ((*size)[i] == -1) {
+ // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
+ (*size)[i] = input_tf_shape.dim_size(i) - (*begin)[i];
+ }
+ }
+
+ *is_identity = true;
+ for (int i = 0; i < input_dims; ++i) {
+ int64 b = (*begin)[i];
+ int64 s = (*size)[i];
+ if (input_tf_shape.dim_size(i) == 0) {
+ OP_REQUIRES(
+ context, b == 0 && s == 0,
+ errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
+ ") and size[", i, "] == 0 ", "(got ", s,
+ ") when ", "input.dim_size(", i, ") == 0"));
+ } else {
+ OP_REQUIRES(context, 0 <= b && b <= input_tf_shape.dim_size(i),
+ errors::InvalidArgument("Expected begin[", i, "] in [0, ",
+ input_tf_shape.dim_size(i),
+ "], but got ", b));
+ OP_REQUIRES(context, 0 <= s && b + s <= input_tf_shape.dim_size(i),
+ errors::InvalidArgument("Expected size[", i, "] in [0, ",
+ input_tf_shape.dim_size(i) - b,
+ "], but ", "got ", s));
+ }
+ const bool take_all = (b == 0) && (s == input_tf_shape.dim_size(i));
+ (*is_identity) &= take_all;
+ }
+}
+
+// A version of SharedSliceCommonCases function written for input tensor
+// that may be in MklDnn layout or in Tensorflow layout.
+template <typename T>
+static void CheckCommonCasesForMklInputs(OpKernelContext* context,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size,
+ bool* done) {
+ bool is_identity = true;
+ *done = false;
+
+ ValidateMklInputs(context, &is_identity, begin, size);
+ if (!context->status().ok()) return;
+
+ const Tensor& input = MklGetInput(context, 0);
+ MklDnnShape input_mkl_shape;
+ GetMklShape(context, 0, &input_mkl_shape);
+
+ if (is_identity) {
+ VLOG(1) << "Slice identity";
+ context->set_output(0, input);
+ // Mkl metadata tensor in this case can just be forwarded from input to
+ // output.
+ AllocateOutputSetMklShape(context, 0, input_mkl_shape);
+ *done = true;
+ }
+}
+
+// MKL-DNN implementation of Slice
+template <typename Device, typename T>
+class MklDnnSliceOp : public OpKernel {
+ public:
+ explicit MklDnnSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ ~MklDnnSliceOp() {}
+
+ void Compute(OpKernelContext* context) override {
+ gtl::InlinedVector<int64, 4> begin;
+ gtl::InlinedVector<int64, 4> size;
+ bool done = false;
+
+ CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done);
+ if (!context->status().ok() || done == true) return;
+
+ // Although MKL-DNN supports tensors of more than 8 (but fewer than 12)
+ // dimensions, we mimic the behavior of the Eigen Slice op for CPU,
+ // which handles at most 8 dimensions.
+ if (begin.size() >= 8) {
+ OP_REQUIRES(
+ context, false,
+ errors::Unimplemented("MklDnnSliceOp : Unhandled input dimensions"));
+ }
+
+ ComputeMklDnnSlice(context, begin, size);
+ }
+
+ private:
+ // Slice op implemented using MKL-DNN APIs.
+ void ComputeMklDnnSlice(OpKernelContext* context,
+ const gtl::InlinedVector<int64, 4>& begin,
+ const gtl::InlinedVector<int64, 4>& size) {
+ try {
+ // MKL-DNN API usage below is guided by description at:
+ // https://github.com/01org/mkl-dnn/issues/69
+ //
+ // Relevant part of the description is copied below:
+ //
+ // Let's say you want to copy a part of memory into another buffer (and
+ // probably change the format). Then your steps are:
+ //
+ // 1. create memory primitive descriptor in_mem_pd and memory primitive
+ // in_mem_p for the entire source data.
+ // 2. create view primitive descriptor in_submem_pd based on in_mem_pd,
+ // initial offsets, and sub-sizes
+ // 3. create memory primitive descriptor out_mem_pd and memory primitive
+ // out_mem_p for the output (the logical sizes should match sub-sizes
+ // used in step 2, but the format might be arbitrary)
+ // 4. create reorder primitive descriptor reorder_pd based on in_submem_pd
+ // and out_mem_pd
+ // 5. create reorder primitive itself based on reorder_pd, in_mem_p, and
+ // out_mem_p.
+ //
+ // Please notice that there is no view primitive. There is only view
+ // primitive descriptor. And the reorder uses source memory as input but
+ // traverses it according to a view in_submem_pd.
+
+ auto cpu_engine = engine(engine::cpu, 0);
+ MklDnnData<T> src(&cpu_engine);
+ MklDnnData<T> output(&cpu_engine);
+
+ // Populate offsets and sizes in memory::dims format based on vector.
+ memory::dims begin_dims = {};
+ begin_dims.resize(begin.size());
+ for (size_t i = 0; i < begin.size(); ++i) begin_dims[i] = begin[i];
+ memory::dims size_dims = {};
+ bool empty = false;
+ size_dims.resize(size.size());
+ for (size_t i = 0; i < size.size(); ++i) {
+ size_dims[i] = size[i];
+ if (size_dims[i] == 0) empty = true;
+ }
+
+ Tensor* output_tensor = nullptr;
+ MklDnnShape output_mkl_shape;
+
+ // If no dimension is selected in slice, the result should be empty.
+ // Just return an empty output tensor, and a dummy Mkl-shape tensor.
+ if (empty) { // for empty dims
+ auto shape_to = MklDnnDimsToTFShape(size_dims);
+ AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
+ output_mkl_shape);
+ return;
+ }
+
+ // Step 1 (as per above description) - Create memory for user data.
+ // We use blocked format here to describe input tensor.
+ const Tensor& input_tensor = MklGetInput(context, 0);
+ MklDnnShape input_mkl_shape;
+ GetMklShape(context, 0, &input_mkl_shape);
+
+ if (input_mkl_shape.IsMklTensor()) {
+ auto input_mkl_format = input_mkl_shape.GetTfDataFormat();
+ auto input_tf_format = MklDnnDataFormatToTFDataFormat(input_mkl_format);
+ begin_dims = MklDnnDimsInNCHW(begin_dims, input_tf_format);
+ size_dims = MklDnnDimsInNCHW(size_dims, input_tf_format);
+ auto input_md = input_mkl_shape.GetMklLayout();
+ src.SetUsrMem(input_md, &input_tensor);
+ } else {
+ // Initialize input dimensions and strides to be used when input is not
+ // in MklDnn layout.
+ memory::dims input_dims, input_strides;
+ input_dims = TFShapeToMklDnnDims(input_tensor.shape());
+ input_strides = CalculateTFStrides(input_dims);
+ // Create input memory descriptor.
+ auto input_md =
+ MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
+ src.SetUsrMem(input_md, &input_tensor);
+ }
+
+ // Step 2 - create view primitive descriptor
+ auto view_pd =
+ view::primitive_desc(src.GetUsrMemPrimDesc(), size_dims, begin_dims)
+ .dst_primitive_desc();
+ auto output_strides = CalculateTFStrides(size_dims);
+ auto output_md =
+ MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
+ auto output_pd = memory::primitive_desc(output_md, cpu_engine);
+
+ // Step 3 - Create memory for output. If input is in MklDnn layout, then
+ // output is also in MklDnn layout. Otherwise, output is in Tensorflow
+ // layout.
+ AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
+ &output_tensor, &output_mkl_shape);
+ DCHECK(output_tensor);
+ DCHECK_EQ(input_mkl_shape.IsMklTensor(), output_mkl_shape.IsMklTensor());
+ output.SetUsrMem(output_md, output_tensor);
+
+ std::vector<primitive> net;
+ // Step 4 - create reorder primitive desc between view_pd and output_pd.
+ auto reorder_pd =
+ reorder::primitive_desc(view_pd, output.GetUsrMemPrimDesc());
+ // Step 5 - create reorder primitive itself.
+ net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *output.GetUsrMem()));
+ // Execute the reorder primitive.
+ stream(stream::kind::eager).submit(net).wait();
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+ string(e.message) + ", in file " + string(__FILE__) +
+ ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
+ private:
+ void AllocateOutputTensor(OpKernelContext* context,
+ const MklDnnShape& input_mkl_shape,
+ memory::primitive_desc* output_pd,
+ const memory::dims& output_dims,
+ Tensor** output_tensor,
+ MklDnnShape* output_mkl_shape) {
+ DCHECK(output_tensor);
+ DCHECK(output_mkl_shape);
+
+ TensorShape output_tf_shape;
+
+ if (input_mkl_shape.IsMklTensor()) {
+ // Since input tensor is in Mkl layout, output tensor will be in Mkl
+ // layout.
+
+ // Allocate shape of Mkl tensor.
+ output_mkl_shape->SetMklTensor(true);
+ output_mkl_shape->SetMklLayout(output_pd);
+ output_mkl_shape->SetElemType(MklDnnType<T>());
+ output_mkl_shape->SetTfLayout(input_mkl_shape.GetDimension(), output_dims,
+ input_mkl_shape.GetTfDataFormat());
+
+ output_tf_shape.AddDim((output_pd->get_size() / sizeof(T)) + 1);
+ } else {
+ // If input is not in Mkl layout, then output won't be in Mkl layout.
+ output_mkl_shape->SetMklTensor(false);
+ output_tf_shape = MklDnnDimsToTFShape(output_dims);
+ }
+
+ AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
+ *output_mkl_shape);
+ }
+};
+
+// MKL-DNN Slice registration
+#define REGISTER_MKL_SLICE(type) \
+ REGISTER_KERNEL_BUILDER(Name("_MklSlice") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("begin") \
+ .HostMemory("size") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklDnnSliceOp<CPUDevice, type>);
+
+TF_CALL_float(REGISTER_MKL_SLICE);
+#undef REGISTER_MKL_SLICE
+
+} // namespace tensorflow
+
+#endif // !INTEL_MKL_ML_ONLY
+#endif // INTEL_MKL
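
ValidateMklInputs in the new file follows the standard Slice convention: size[i] == -1 expands to "everything from begin[i] to the end of dimension i", and the op is an identity when every dimension is taken whole. A minimal sketch of that normalization (illustrative helper, not the kernel's actual signature):

    #include <iostream>
    #include <vector>

    // Expands -1 sizes and reports whether the slice copies the whole input.
    bool NormalizeSlice(const std::vector<long long>& dims,
                        const std::vector<long long>& begin,
                        std::vector<long long>* size) {
      bool identity = true;
      for (size_t i = 0; i < dims.size(); ++i) {
        if ((*size)[i] == -1) (*size)[i] = dims[i] - begin[i];
        identity = identity && begin[i] == 0 && (*size)[i] == dims[i];
      }
      return identity;
    }

    int main() {
      std::vector<long long> size = {-1, 3};
      // Input [4, 5] with begin [1, 2]: size expands to [3, 3]; not identity.
      bool id = NormalizeSlice({4, 5}, {1, 2}, &size);
      std::cout << id << " " << size[0] << " " << size[1] << "\n";  // -> 0 3 3
    }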
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index fc1c9003aa..3979e4b53a 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -97,7 +97,20 @@ class PartitionedCallOp : public AsyncOpKernel {
OP_REQUIRES_ASYNC(ctx, fbody != nullptr,
errors::Internal("Could not find handle ", handle),
done);
+ OP_REQUIRES_ASYNC(
+ ctx, args.size() == fbody->arg_nodes.size(),
+ errors::InvalidArgument(
+ "Wrong number of arguments to the op; function expects ",
+ fbody->arg_nodes.size(), " but PartitionedCall received ",
+ args.size()),
+ done);
+ // We need to pass the global op_registry as the default_registry when
+ // creating the graph, so that graph optimization passes can look up all
+ // possible ops by name.
auto graph = tensorflow::MakeUnique<Graph>(fbody->graph->flib_def());
+ FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
+ TF_CHECK_OK(
+ graph.get()->AddFunctionLibrary(global_flib.ToProto()));
CopyGraph(*fbody->graph, graph.get());
OP_REQUIRES_OK_ASYNC(ctx, PinResourceArgs(graph.get(), args), done);
@@ -250,9 +263,11 @@ class PartitionedCallOp : public AsyncOpKernel {
VLOG(3) << "Partitioned function '" << func_.name() << "', yielding "
<< partitions.size() << " shards.";
- const FunctionLibraryDefinition* flib_def = &graph->flib_def();
for (const auto& partition : partitions) {
- std::unique_ptr<Graph> subgraph(new Graph(flib_def));
+ std::unique_ptr<Graph> subgraph(new Graph(graph->flib_def()));
+ FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
+ TF_CHECK_OK(
+ subgraph.get()->AddFunctionLibrary(global_flib.ToProto()));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = true;
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index 04a53697c0..3810d817ca 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -489,13 +489,15 @@ class RandomGammaOp : public OpKernel {
Name("RandomGamma").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
RandomGammaOp<TYPE>)
-#define REGISTER_INT(IntType) \
- REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .HostMemory("minval") \
- .HostMemory("maxval") \
- .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+ template struct functor::FillPhiloxRandom< \
+ CPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_CPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<CPUDevice, IntType>);
TF_CALL_half(REGISTER);
@@ -538,14 +540,16 @@ TF_CALL_int64(REGISTER_INT);
random::TruncatedNormalDistribution< \
random::SingleSampleAdapter<random::PhiloxRandom>, TYPE>>);
-#define REGISTER_INT(IntType) \
- REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("minval") \
- .HostMemory("maxval") \
- .TypeConstraint<int32>("T") \
- .TypeConstraint<IntType>("Tout"), \
+#define REGISTER_INT(IntType) \
+ template struct functor::FillPhiloxRandom< \
+ GPUDevice, random::UniformDistribution<random::PhiloxRandom, IntType>>; \
+ REGISTER_KERNEL_BUILDER(Name("RandomUniformInt") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("shape") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<int32>("T") \
+ .TypeConstraint<IntType>("Tout"), \
RandomUniformIntOp<GPUDevice, IntType>);
TF_CALL_half(REGISTER);
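
The new `template struct functor::FillPhiloxRandom<...>;` lines are explicit template instantiations: they force the functor to be compiled in this translation unit so the kernels registered here link against a definition that would otherwise only be emitted elsewhere. The mechanism in miniature (toy type, not the TF functor):

    #include <iostream>

    template <typename T>
    struct Square {
      T operator()(T x) const { return x * x; }
    };

    // Forces Square<int> to be emitted in this translation unit, so other
    // files that see only a declaration still find the symbol at link time.
    template struct Square<int>;

    int main() { std::cout << Square<int>()(7) << "\n"; }  // -> 49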
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 26705a8d34..678d675c4a 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -51,7 +51,9 @@ limitations under the License.
#define EIGEN_USE_GPU
#endif
-#include "tensorflow/core/kernels/resource_variable_ops.h"
+#include <memory>
+#include <vector>
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -60,10 +62,12 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
+#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -72,6 +76,8 @@ limitations under the License.
namespace tensorflow {
REGISTER_RESOURCE_HANDLE_KERNEL(Var);
+REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
+ ResourceHandlesOp<Var>);
ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
@@ -101,13 +107,58 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) {
ctx->set_output(0, t);
}
+ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
+ int n;
+ OP_REQUIRES_OK(c, c->GetAttr("N", &n));
+ OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
+ OP_REQUIRES(c, n == dtypes_.size(),
+ errors::InvalidArgument(
+ "Mismatched number of arguments to ReadVariablesOp (", n,
+ " vs. ", dtypes_.size(), ")"));
+}
+
+void ReadVariablesOp::Compute(OpKernelContext* ctx) {
+ std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables(
+ dtypes_.size());
+ std::vector<const ResourceHandle*> handles(dtypes_.size());
+ for (size_t i = 0; i < dtypes_.size(); ++i) {
+ handles[i] = &HandleFromInput(ctx, i);
+ }
+ const auto status = LookupResources(ctx, handles, &variables);
+ OP_REQUIRES(ctx, status.ok(),
+ errors::FailedPrecondition(
+ "Error while reading resource variable. This could mean that "
+ "the variable was uninitialized. ",
+ status.ToString()));
+
+ for (size_t i = 0; i < dtypes_.size(); ++i) {
+ // We're acquiring a reference to the underlying buffer while
+ // holding a shared lock to guarantee ordering of reads and
+ // writes.
+ tf_shared_lock ml(*variables[i]->mu());
+ const Tensor& t = *variables[i]->tensor();
+ OP_REQUIRES(ctx, dtypes_[i] == t.dtype(),
+ errors::InvalidArgument(
+ "Trying to read variable ", handles[i]->name(),
+ " from Container: ", handles[i]->container(),
+ " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
+ " got ", DataTypeString(t.dtype())));
+ ctx->set_output(i, t);
+ }
+}
+
REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
ReadVariableOp);
+REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
+ ReadVariablesOp);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
ReadVariableOp);
+REGISTER_KERNEL_BUILDER(
+ Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"),
+ ReadVariablesOp);
#define REGISTER_GPU_KERNELS(type) \
namespace functor { \
@@ -122,11 +173,20 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("resource") \
.TypeConstraint<type>("dtype"), \
ResourceHandleOp<Var>)
-
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
+
+REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
+ .Device(DEVICE_GPU)
+ .HostMemory("resources")
+ .TypeConstraint("dtypes",
+ {DT_INT64, DT_COMPLEX64,
+ DT_COMPLEX128, DT_HALF, DT_FLOAT,
+ DT_DOUBLE, DT_BOOL, DT_VARIANT}),
+ ResourceHandlesOp<Var>);
+
#endif // GOOGLE_CUDA
template <typename T>
@@ -366,6 +426,12 @@ class AssignUpdateVariableOp : public OpKernel {
// ADD if value's refcount was 1.
mutex_lock ml(*variable->mu());
Tensor* var_tensor = variable->tensor();
+ OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()),
+ errors::InvalidArgument("Cannot update variable with shape ",
+ var_tensor->shape().DebugString(),
+ " using a Tensor with shape ",
+ value.shape().DebugString(),
+ ", shapes must be equal."));
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, var_tensor));
functor::DenseUpdate<Device, T, Op> update_functor;
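
ReadVariablesOp above takes a tf_shared_lock on each variable's mutex, so a batched read is ordered against concurrent writers while still letting readers run in parallel. The equivalent standard-library pattern, as a sketch:

    #include <iostream>
    #include <mutex>
    #include <shared_mutex>

    struct Var {
      mutable std::shared_mutex mu;
      double value = 0.0;
    };

    // Readers share the lock; writers hold it exclusively, so a read never
    // observes a half-written value.
    double Read(const Var& v) {
      std::shared_lock<std::shared_mutex> l(v.mu);
      return v.value;
    }

    void Write(Var& v, double x) {
      std::unique_lock<std::shared_mutex> l(v.mu);
      v.value = x;
    }

    int main() {
      Var v;
      Write(v, 3.14);
      std::cout << Read(v) << "\n";  // -> 3.14
    }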
diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h
index 9b60106f13..cffb732c38 100644
--- a/tensorflow/core/kernels/resource_variable_ops.h
+++ b/tensorflow/core/kernels/resource_variable_ops.h
@@ -28,6 +28,16 @@ class ReadVariableOp : public OpKernel {
DataType dtype_;
};
+class ReadVariablesOp : public OpKernel {
+ public:
+ explicit ReadVariablesOp(OpKernelConstruction* c);
+ void Compute(OpKernelContext* ctx) override;
+ bool IsExpensive() override { return false; }
+
+ private:
+ DataTypeVector dtypes_;
+};
+
class DestroyResourceOp : public OpKernel {
public:
explicit DestroyResourceOp(OpKernelConstruction* ctx);
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 77594479cb..a006c69297 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -228,191 +228,6 @@ class SliceOp : public OpKernel {
}
};
-#ifdef INTEL_MKL
-template <typename Device, typename T>
-class MklSliceOp : public OpKernel {
- public:
- explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- TensorShape output_shape;
- gtl::InlinedVector<int64, 4> begin;
- gtl::InlinedVector<int64, 4> size;
- Tensor* result = nullptr;
- bool done = false;
- SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result,
- &done);
- if (!context->status().ok() || done == true) return;
-
- const Tensor& input = context->input(0);
- const int input_dims = input.dims();
-
- if (output_shape.num_elements() > 0) {
- if (std::is_same<Device, CPUDevice>::value && input_dims == 2 &&
- DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
- auto input = context->input(0).tensor<T, 2>();
- auto output = result->tensor<T, 2>();
- // TODO(agarwal): Consider multi-threading this loop for cases where
- // size[0] is very large.
- for (int i = 0; i < size[0]; ++i) {
- const int64 row = begin[0] + i;
- if (i + 1 < size[0]) {
- port::prefetch<port::PREFETCH_HINT_T0>(&output(i + 1, 0));
- port::prefetch<port::PREFETCH_HINT_T0>(&input(row + 1, begin[1]));
- }
- memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T));
- }
- return;
- }
-#define HANDLE_DIM(NDIM) \
- if (input_dims == NDIM) { \
- HandleCase<NDIM>(context, begin, size, result); \
- return; \
- }
-
- HANDLE_DIM(1);
- HANDLE_DIM(2);
- HANDLE_DIM(3);
- HANDLE_DIM(4);
- HANDLE_DIM(5);
- HANDLE_DIM(6);
- HANDLE_DIM(7);
-
-#undef HANDLE_DIM
-
- OP_REQUIRES(
- context, false,
- errors::Unimplemented("SliceOp : Unhandled input dimensions"));
- }
- }
-
- private:
- // Helper function for DoesSliceShapeDifferInOnly1D. Checks if the following
- // criteria matches for slice_dim: if indices for slice are 0 in all dims
- // except slice_dim and if sizes of all the dimensions of the slice are same
- // as the sizes of all the dimensions of the input except slice_dim, then
- // returns True. Otherwise, returns False.
- bool DoesSliceShapeDifferInOnly1DHelper(const TensorShape& input_shape,
- const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size,
- int slice_dim) {
- for (int dim = 0; dim < 4; dim++) {
- if (dim != slice_dim &&
- (begin[dim] != 0 || size[dim] != input_shape.dim_size(dim))) {
- return false;
- }
- }
- return true;
- }
-
- // Is 'input' tensor being sliced over a single dimension out of 4?
- //
- // This check is applicable in the context of Slice of a 4-D tensor in
- // NHWC or NCHW format over channel dimension.
- //
- // If indices for slice are 0 in all dims except one dimension and if sizes of
- // all dimensions of slice are same as sizes of all dimensions of inputs
- // except that dimension, then we are slicing over a single dimension.
- //
- // Returns True if Slicing over a single dimension, and sets slice_dim
- // to the number of the dimension that satisfies criteria.
- bool DoesSliceShapeDifferInOnly1D(const TensorShape& input_shape,
- const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size,
- int* slice_dim) {
- for (int dim = 0; dim < 4; dim++) {
- if (DoesSliceShapeDifferInOnly1DHelper(input_shape, begin, size, dim)) {
- *slice_dim = dim;
- return true;
- }
- }
- return false;
- }
-
- template <int NDIM>
- void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size, Tensor* result) {
- int slice_dim = -1;
- TensorShape in_shape = context->input(0).shape();
- // Special case for handling 4-D tensor slice when shape of the slice
- // differs from the input tensor in only 1 out of 4 dimensions.
- // This case arises in the context of Slice of 4-D tensor in NHWC or NCHW
- // format over channel dimension.
- if (NDIM == 4 &&
- DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
- size_t in_strides[4] = {
- (size_t)in_shape.dim_size(1) * in_shape.dim_size(2) *
- in_shape.dim_size(3),
- (size_t)in_shape.dim_size(2) * in_shape.dim_size(3),
- (size_t)in_shape.dim_size(3), (size_t)1};
-
- size_t out_strides[4] = {(size_t)size[1] * size[2] * size[3],
- (size_t)size[2] * size[3], (size_t)size[3],
- (size_t)1};
-
- T* in_buf = const_cast<T*>(
- const_cast<const T*>(context->input(0).flat<T>().data()));
- T* op_buf = result->flat<T>().data();
-
- if (slice_dim == 1) {
- /* data format = NCHW */
-
-#pragma omp parallel for
- for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
- T* ip = in_buf + (d0 * in_strides[0]);
- T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
-#pragma omp parallel for
- for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
- T* ip1 = ip + (d1 * in_strides[1]);
- T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
- // For NCHW, H and W will be contiguous. So we can copy
- // both with one memcpy.
- memcpy(static_cast<void*>(op1), static_cast<void*>(ip1),
- sizeof(T) * in_strides[1]);
- }
- }
- return;
- } else if (slice_dim == 3) {
- /* data_format = NHWC */
-
-#pragma omp parallel for
- for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
- T* ip = in_buf + (d0 * in_strides[0]);
- T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
-#pragma omp parallel for
- for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
- T* ip1 = ip + (d1 * in_strides[1]);
- T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
-#pragma omp parallel for
- for (ssize_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
- T* ip2 = ip1 + (d2 * in_strides[2]);
- T* ip3 = ip2 + begin[3];
- T* op2 = op1 + ((d2 - begin[2]) * out_strides[2]);
- T* op3 = op2;
- memcpy(static_cast<void*>(op3), static_cast<void*>(ip3),
- sizeof(T) * size[3]);
- }
- }
- }
- return;
- }
- // slice_dim is not 1 or 3, then we fallback to Eigen implementation.
- }
-
- Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
- Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
- for (int i = 0; i < NDIM; ++i) {
- indices[i] = begin[i];
- sizes[i] = size[i];
- }
-
- functor::Slice<Device, T, NDIM>()(
- context->eigen_device<Device>(), result->tensor<T, NDIM>(),
- context->input(0).tensor<T, NDIM>(), indices, sizes);
- }
-};
-#endif
-
// Forward declarations of the functor specializations for declared in the
// sharded source files.
namespace functor {
@@ -440,7 +255,6 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N);
#undef DECLARE_CPU_SPEC
} // namespace functor
-#ifndef INTEL_MKL
#define REGISTER_SLICE(type) \
REGISTER_KERNEL_BUILDER(Name("Slice") \
.Device(DEVICE_CPU) \
@@ -452,19 +266,6 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N);
TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
#undef REGISTER_SLICE
-#else
-#define REGISTER_SLICE(type) \
- REGISTER_KERNEL_BUILDER(Name("Slice") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .HostMemory("begin") \
- .HostMemory("size"), \
- MklSliceOp<CPUDevice, type>)
-
-TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
-TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
-#undef REGISTER_SLICE
-#endif // INTEL_MKL
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
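
The deleted MklSliceOp's fast paths depended on row-major strides for a 4-D tensor; the arithmetic it used is just a running product of the trailing dimensions:

    #include <cstddef>
    #include <iostream>

    int main() {
      // Element strides for a row-major [N, H, W, C] tensor, matching the
      // in_strides computation in the removed HandleCase<4>.
      const std::size_t H = 4, W = 4, C = 3;
      const std::size_t strides[4] = {H * W * C, W * C, C, 1};
      std::cout << strides[0] << " " << strides[1] << " " << strides[2] << " "
                << strides[3] << "\n";  // -> 48 12 3 1
    }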
diff --git a/tensorflow/core/kernels/stateless_random_ops.cc b/tensorflow/core/kernels/stateless_random_ops.cc
index eab176c7fb..925f5291a6 100644
--- a/tensorflow/core/kernels/stateless_random_ops.cc
+++ b/tensorflow/core/kernels/stateless_random_ops.cc
@@ -113,74 +113,109 @@ class StatelessRandomOp : public StatelessRandomOpBase {
}
};
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomUniform") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<CPUDevice, random::UniformDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomNormal") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<CPUDevice, random::NormalDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessTruncatedNormal") \
- .Device(DEVICE_CPU) \
- .HostMemory("shape") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp< \
- CPUDevice, \
- random::TruncatedNormalDistribution< \
- random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+template <typename Device, typename IntType>
+class StatelessRandomUniformIntOp : public StatelessRandomOpBase {
+ public:
+ using StatelessRandomOpBase::StatelessRandomOpBase;
-TF_CALL_half(REGISTER);
-TF_CALL_float(REGISTER);
-TF_CALL_double(REGISTER);
+ void Fill(OpKernelContext* context, random::PhiloxRandom random,
+ Tensor* output) override {
+ const Tensor& minval = context->input(2);
+ const Tensor& maxval = context->input(3);
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(minval.shape()),
+ errors::InvalidArgument("minval must be 0-D, got shape ",
+ minval.shape().DebugString()));
+ OP_REQUIRES(context, TensorShapeUtils::IsScalar(maxval.shape()),
+ errors::InvalidArgument("maxval must be 0-D, got shape ",
+ maxval.shape().DebugString()));
+
+ // Verify that minval < maxval. Note that we'll never reach this point for
+ // empty output. Zero impossible things are fine.
+ const auto lo = minval.scalar<IntType>()();
+ const auto hi = maxval.scalar<IntType>()();
+ OP_REQUIRES(
+ context, lo < hi,
+ errors::InvalidArgument("Need minval < maxval, got ", lo, " >= ", hi));
+
+ // Build distribution
+ typedef random::UniformDistribution<random::PhiloxRandom, IntType>
+ Distribution;
+ Distribution dist(lo, hi);
+
+ auto flat = output->flat<IntType>();
+ // Reuse the compute kernels from the stateful random ops
+ functor::FillPhiloxRandom<Device, Distribution>()(
+ context, context->eigen_device<Device>(), random, flat.data(),
+ flat.size(), dist);
+ }
+};
-#undef REGISTER
+#define REGISTER(DEVICE, TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessRandomUniform") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp<DEVICE##Device, random::UniformDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessRandomNormal") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp<DEVICE##Device, random::NormalDistribution< \
+ random::PhiloxRandom, TYPE> >); \
+ REGISTER_KERNEL_BUILDER( \
+ Name("StatelessTruncatedNormal") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomOp< \
+ DEVICE##Device, \
+ random::TruncatedNormalDistribution< \
+ random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+
+#define REGISTER_INT(DEVICE, TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("StatelessRandomUniformInt") \
+ .Device(DEVICE_##DEVICE) \
+ .HostMemory("shape") \
+ .HostMemory("seed") \
+ .HostMemory("minval") \
+ .HostMemory("maxval") \
+ .TypeConstraint<TYPE>("dtype"), \
+ StatelessRandomUniformIntOp<DEVICE##Device, TYPE>);
+
+#define REGISTER_CPU(TYPE) REGISTER(CPU, TYPE)
+#define REGISTER_GPU(TYPE) REGISTER(GPU, TYPE)
+#define REGISTER_INT_CPU(TYPE) REGISTER_INT(CPU, TYPE)
+#define REGISTER_INT_GPU(TYPE) REGISTER_INT(GPU, TYPE)
+
+TF_CALL_half(REGISTER_CPU);
+TF_CALL_bfloat16(REGISTER_CPU);
+TF_CALL_float(REGISTER_CPU);
+TF_CALL_double(REGISTER_CPU);
+TF_CALL_int32(REGISTER_INT_CPU);
+TF_CALL_int64(REGISTER_INT_CPU);
#if GOOGLE_CUDA
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomUniform") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<GPUDevice, random::UniformDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessRandomNormal") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp<GPUDevice, random::NormalDistribution< \
- random::PhiloxRandom, TYPE> >); \
- REGISTER_KERNEL_BUILDER( \
- Name("StatelessTruncatedNormal") \
- .Device(DEVICE_GPU) \
- .HostMemory("shape") \
- .HostMemory("seed") \
- .TypeConstraint<TYPE>("dtype"), \
- StatelessRandomOp< \
- GPUDevice, \
- random::TruncatedNormalDistribution< \
- random::SingleSampleAdapter<random::PhiloxRandom>, TYPE> >);
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+TF_CALL_int32(REGISTER_INT_GPU);
+TF_CALL_int64(REGISTER_INT_GPU);
-TF_CALL_half(REGISTER);
-TF_CALL_float(REGISTER);
-TF_CALL_double(REGISTER);
+#endif // GOOGLE_CUDA
#undef REGISTER
-
-#endif // GOOGLE_CUDA
+#undef REGISTER_INT
+#undef REGISTER_CPU
+#undef REGISTER_GPU
+#undef REGISTER_INT_CPU
+#undef REGISTER_INT_GPU
} // namespace
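
StatelessRandomUniformIntOp validates minval < maxval on the host and then samples a half-open range deterministically from the seed. A hedged sketch of that contract (std::mt19937 stands in for the Philox generator the kernel actually uses):

    #include <iostream>
    #include <random>
    #include <stdexcept>

    // Same seed -> same value, mirroring the stateless contract; the range
    // check matches the kernel's "Need minval < maxval" error.
    int StatelessUniformInt(unsigned seed, int lo, int hi) {
      if (!(lo < hi)) throw std::invalid_argument("Need minval < maxval");
      std::mt19937 gen(seed);  // stand-in for random::PhiloxRandom
      return std::uniform_int_distribution<int>(lo, hi - 1)(gen);  // [lo, hi)
    }

    int main() {
      std::cout << StatelessUniformInt(42, 0, 10) << " "
                << StatelessUniformInt(42, 0, 10) << "\n";  // two equal draws
    }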
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index f0575de4d9..3e8a4c5b72 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -149,7 +149,7 @@ class StridedSliceOp : public OpKernel {
// NDIM and T
if (is_simple_slice && std::is_same<Device, CPUDevice>::value &&
input_dims == 2 && processing_shape.dims() == 2 &&
- final_shape.dims() == 2) {
+ final_shape.dims() == 2 && new_axis_mask == 0) {
MemCpyFunctor<T> functor;
if (functor.Copy(input, begin, end, result)) {
return;
diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc
index 3a9803a052..92c73220d8 100644
--- a/tensorflow/core/kernels/string_util.cc
+++ b/tensorflow/core/kernels/string_util.cc
@@ -16,10 +16,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
-namespace {
-inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
-} // namespace
-
namespace tensorflow {
// Sets unit value based on str.
diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h
index 390cf57702..d40e93ea33 100644
--- a/tensorflow/core/kernels/string_util.h
+++ b/tensorflow/core/kernels/string_util.h
@@ -30,6 +30,9 @@ enum class UnicodeEncoding { UTF8 };
// TODO(edloper): Add support for: UTF32_CHAR, etc.
enum class CharUnit { BYTE, UTF8_CHAR };
+// Whether or not the given byte is a trailing (continuation) byte of a
+// multi-byte UTF-8 character.
+inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
+
// Sets `encoding` based on `str`.
Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding);
@@ -40,6 +43,47 @@ Status ParseCharUnit(const string& str, CharUnit* unit);
// Result may be incorrect if the input string is not valid UTF-8.
int32 UTF8StrLen(const string& string);
+// Get the next UTF8 character position starting at the given position and
+// skipping the given number of characters. Position is a byte offset, and
+// should never be `null`. The function returns true if successful. However,
+// if the end of the string is reached before the requested characters have
+// been skipped, the position will point to the end of the string and the
+// function returns false.
+template <typename T>
+bool ForwardNUTF8CharPositions(const StringPiece in,
+ const T num_utf8_chars_to_shift, T* pos) {
+ const size_t size = in.size();
+ T utf8_chars_counted = 0;
+ while (utf8_chars_counted < num_utf8_chars_to_shift && *pos < size) {
+ // move forward one utf-8 character
+ do {
+ ++*pos;
+ } while (*pos < size && IsTrailByte(in[*pos]));  // bounds check first
+ ++utf8_chars_counted;
+ }
+ return utf8_chars_counted == num_utf8_chars_to_shift;
+}
+
+// Get the previous UTF8 character position starting at the given position and
+// skipping the given number of characters. Position is a byte offset with a
+// positive value, relative to the beginning of the string, and should never be
+// `null`. The function returns true if successful. However, if the beginning
+// of the string is reached before the requested characters have been skipped,
+// the position will point to the beginning of the string and the function
+// returns false.
+template <typename T>
+bool BackNUTF8CharPositions(const StringPiece in,
+ const T num_utf8_chars_to_shift, T* pos) {
+ const size_t start = 0;
+ T utf8_chars_counted = 0;
+ while (utf8_chars_counted < num_utf8_chars_to_shift && (*pos > start)) {
+ // move back one utf-8 character
+ do {
+ --*pos;
+ } while (IsTrailByte(in[*pos]) && *pos > start);
+ ++utf8_chars_counted;
+ }
+ return utf8_chars_counted == num_utf8_chars_to_shift;
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
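
IsTrailByte works because UTF-8 is self-synchronizing: continuation bytes have the form 0b10xxxxxx, which as a signed char is below -0x40. A standalone demonstration of scanning with that test (not using the header above):

    #include <cstddef>
    #include <iostream>
    #include <string>

    inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }

    // Count UTF-8 characters by counting non-continuation bytes.
    std::size_t Utf8Len(const std::string& s) {
      std::size_t n = 0;
      for (char c : s)
        if (!IsTrailByte(c)) ++n;
      return n;
    }

    int main() {
      // "héllo": 6 bytes, 5 characters ('é' is the 2-byte sequence C3 A9).
      std::cout << Utf8Len("h\xC3\xA9llo") << "\n";  // -> 5
    }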
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index 07f1d6e767..93c427039d 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/string_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
@@ -37,7 +38,11 @@ namespace tensorflow {
template <typename T>
class SubstrOp : public OpKernel {
public:
- using OpKernel::OpKernel;
+ explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string unit;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
+ OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
+ }
void Compute(OpKernelContext* context) override {
// Get inputs
@@ -69,11 +74,23 @@ class SubstrOp : public OpKernel {
tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
StringPiece in(input(i));
- OP_REQUIRES(
- context, FastBoundsCheck(std::abs(pos), in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i).assign(sub_in.data(), sub_in.size());
}
} else {
@@ -84,11 +101,23 @@ class SubstrOp : public OpKernel {
StringPiece in(input(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
- OP_REQUIRES(
- context, FastBoundsCheck(std::abs(pos), in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i).assign(sub_in.data(), sub_in.size());
}
}
@@ -151,12 +180,24 @@ class SubstrOp : public OpKernel {
StringPiece in(input_bcast(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
- OP_REQUIRES(
- context,
- FastBoundsCheck(std::abs(pos), input_bcast(i).size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context,
+ FastBoundsCheck(byte_pos, input_bcast(i).size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i).assign(sub_in.data(), sub_in.size());
}
break;
@@ -205,12 +246,24 @@ class SubstrOp : public OpKernel {
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
const T len =
tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
- OP_REQUIRES(
- context, FastBoundsCheck(std::abs(pos), in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for ",
- "string b'", in, "' at index (", i,
- ", ", j, ")"));
- StringPiece sub_in = in.substr(AdjustedPosIndex(pos, in), len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index (",
+ i, ", ", j, ")"));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
output(i, j).assign(sub_in.data(), sub_in.size());
}
}
@@ -227,12 +280,73 @@ class SubstrOp : public OpKernel {
private:
  // This adjusts the requested position. Note that it does not perform any
  // bounds checks.
- T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
+ static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
if (pos_requested < 0) {
return s.size() + pos_requested;
}
return pos_requested;
}
+
+  // Returns true on success; returns false if the `pos` argument is out of
+  // range for the string.
+ static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos,
+ T* len) {
+ if (*pos >= 0) {
+ return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len);
+ } else {
+ return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len);
+ }
+ }
+
+ static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos,
+ const T len, T* char_pos,
+ T* char_len) {
+ *char_pos = 0;
+ // Determine byte position of the substring start.
+ if (!ForwardNUTF8CharPositions(in, pos, char_pos)) {
+ return false;
+ }
+    // Determine the byte position of the end of the substring. The length is
+    // capped at the end of the string; whether the string actually contains
+    // that many characters is deliberately ignored.
+ *char_len = *char_pos;
+ ForwardNUTF8CharPositions(in, len, char_len);
+    // The length in bytes is the end position of the substring minus the start.
+ *char_len = *char_len - *char_pos;
+ return true;
+ }
+
+ // This function expects a negative position relative to the end of the
+ // string, but will update the character position to a positive number
+ // relative to the beginning of the string.
+ static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos,
+ const T len, T* char_pos,
+ T* char_len) {
+    // Initially treat the length as the position of the end of the substring.
+ *char_len = in.size();
+    // This is the number of characters to skip from the end of the string to
+    // arrive at the position where the substring should end.
+ T utf8_chars_to_skip = -pos - len;
+ if (utf8_chars_to_skip < 0) {
+ utf8_chars_to_skip = 0;
+ }
+ // Find the byte position where the substring should end using the computed
+ // number of characters to skip.
+ if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) {
+ return false;
+ }
+    // Next, determine where the substring should begin: skip the magnitude of
+    // the (negative) requested position minus the characters already skipped.
+ *char_pos = *char_len;
+ if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) {
+ return false;
+ }
+    // The length in bytes is the end position of the substring minus the start.
+ *char_len = *char_len - *char_pos;
+ return true;
+ }
+
+ CharUnit unit_ = CharUnit::BYTE;
};
#define REGISTER_SUBSTR(type) \
diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc
index 2e07050260..ea6b1ed500 100644
--- a/tensorflow/core/kernels/substr_op_test.cc
+++ b/tensorflow/core/kernels/substr_op_test.cc
@@ -42,7 +42,7 @@ limitations under the License.
namespace tensorflow {
// Test data from the TensorFlow README.md.
-const char* lines[] = {
+const char* ascii_lines[] = {
"**TensorFlow** is an open source software library for numerical "
"computation using data flow graphs.",
"The graph nodes represent mathematical operations, while the graph edges "
@@ -64,17 +64,76 @@ const char* lines[] = {
"backwards compatibility guarantee like C++, Go, Java, JavaScript and "
"Swift."};
+const char* unicode_lines[] = {
+ "TensorFlow\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe4\xbd\xbf\xe7\x94\xa8\xe6"
+ "\x95\xb0\xe6\x8d\xae\xe6\xb5\x81\xe5\x9b\xbe\xe8\xbf\x9b\xe8\xa1\x8c\xe6"
+ "\x95\xb0\xe5\x80\xbc\xe8\xae\xa1\xe7\xae\x97\xe7\x9a\x84\xe5\xbc\x80\xe6"
+ "\xba\x90\xe8\xbd\xaf\xe4\xbb\xb6\xe5\xba\x93\xe3\x80\x82",
+ "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\x8a\x82\xe7\x82\xb9\xe8\xa1\xa8\xe7\xa4\xba"
+ "\xe6\x95\xb0\xe5\xad\xa6\xe8\xbf\x90\xe7\xae\x97\xef\xbc\x8c\xe8\x80\x8c"
+ "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\xbe\xb9\xe7\xbc\x98\xe8\xa1\xa8\xe7\xa4\xba"
+ "\xe5\x9c\xa8\xe5\xae\x83\xe4\xbb\xac\xe4\xb9\x8b\xe9\x97\xb4\xe6\xb5\x81"
+ "\xe5\x8a\xa8\xe7\x9a\x84\xe5\xa4\x9a\xe7\xbb\xb4\xe6\x95\xb0\xe6\x8d\xae"
+ "\xe9\x98\xb5\xe5\x88\x97\xef\xbc\x88\xe5\xbc\xa0\xe9\x87\x8f\xef\xbc\x89"
+ "\xe3\x80\x82",
+ "\xe8\xbf\x99\xe7\xa7\x8d\xe7\x81\xb5\xe6\xb4\xbb\xe7\x9a\x84\xe4\xbd\x93"
+ "\xe7\xb3\xbb\xe7\xbb\x93\xe6\x9e\x84\xe4\xbd\xbf\xe6\x82\xa8\xe5\x8f\xaf"
+ "\xe4\xbb\xa5\xe5\xb0\x86\xe8\xae\xa1\xe7\xae\x97\xe9\x83\xa8\xe7\xbd\xb2"
+ "\xe5\x88\xb0\xe6\xa1\x8c\xe9\x9d\xa2\xef\xbc\x8c\xe6\x9c\x8d\xe5\x8a\xa1"
+ "\xe5\x99\xa8\xe6\x88\x96\xe7\xa7\xbb\xe5\x8a\xa8\xe8\xae\xbe\xe5\xa4\x87"
+ "\xe4\xb8\xad\xe7\x9a\x84\xe4\xb8\x80\xe4\xb8\xaa\xe6\x88\x96\xe5\xa4\x9a"
+ "\xe4\xb8\xaa CPU\xe6\x88\x96GPU\xef\xbc\x8c\xe8\x80\x8c\xe6\x97\xa0\xe9"
+ "\x9c\x80\xe9\x87\x8d\xe5\x86\x99\xe4\xbb\xa3\xe7\xa0\x81\xe3\x80\x82",
+ "TensorFlow\xe8\xbf\x98\xe5\x8c\x85\xe6\x8b\xac[TensorBoard]\xef\xbc\x88"
+ "https://www.tensorflow.org/guide/summaries_and_tensorboard\xef\xbc\x89\xef"
+ "\xbc\x8c\xe8\xbf\x99\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe6\x95\xb0\xe6"
+ "\x8d\xae\xe5\x8f\xaf\xe8\xa7\x86\xe5\x8c\x96\xe5\xb7\xa5\xe5\x85\xb7\xe5"
+ "\x8c\x85\xe3\x80\x82",
+ "TensorFlow\xe6\x9c\x80\xe5\x88\x9d\xe6\x98\xaf\xe7\x94\xb1\xe7\xa0\x94\xe7"
+ "\xa9\xb6\xe4\xba\xba\xe5\x91\x98\xe5\x92\x8c\xe5\xb7\xa5\xe7\xa8\x8b\xe5"
+ "\xb8\x88\xe5\x9c\xa8Google\xe6\x9c\xba\xe5\x99\xa8\xe6\x99\xba\xe8\x83\xbd"
+ "\xe7\xa0\x94\xe7\xa9\xb6\xe7\xbb\x84\xe7\xbb\x87\xe7\x9a\x84Google Brain"
+ "\xe5\x9b\xa2\xe9\x98\x9f\xe5\xbc\x80\xe5\x8f\x91\xe7\x9a\x84\xef\xbc\x8c"
+ "\xe7\x9b\xae\xe7\x9a\x84\xe6\x98\xaf\xe8\xbf\x9b\xe8\xa1\x8c\xe6\x9c\xba"
+ "\xe5\x99\xa8\xe5\xad\xa6\xe4\xb9\xa0\xe5\x92\x8c\xe6\xb7\xb1\xe5\xba\xa6"
+ "\xe7\xa5\x9e\xe7\xbb\x8f\xe7\xbd\x91\xe7\xbb\x9c\xe7\xa0\x94\xe7\xa9\xb6"
+ "\xe3\x80\x82",
+ "\xe8\xaf\xa5\xe7\xb3\xbb\xe7\xbb\x9f\xe8\xb6\xb3\xe4\xbb\xa5\xe9\x80\x82"
+ "\xe7\x94\xa8\xe4\xba\x8e\xe5\x90\x84\xe7\xa7\x8d\xe5\x85\xb6\xe4\xbb\x96"
+ "\xe9\xa2\x86\xe5\x9f\x9f\xe4\xb9\x9f\xe6\x98\xaf\xe5\xa6\x82\xe6\xad\xa4"
+ "\xe3\x80\x82",
+ "TensorFlow\xe6\x8f\x90\xe4\xbe\x9b\xe7\xa8\xb3\xe5\xae\x9a\xe7\x9a\x84"
+ "Python API\xe5\x92\x8c C API\xef\xbc\x8c\xe4\xbb\xa5\xe5\x8f\x8a\xe6\xb2"
+ "\xa1\xe6\x9c\x89 API\xe5\x90\x91\xe5\x90\x8e\xe5\x85\xbc\xe5\xae\xb9\xe6"
+ "\x80\xa7\xe4\xbf\x9d\xe8\xaf\x81\xef\xbc\x8c\xe5\xa6\x82 C ++\xef\xbc\x8c"
+ "Go\xef\xbc\x8cJava\xef\xbc\x8cJavaScript\xe5\x92\x8cSwift\xe3\x80\x82",
+};
+
+const char* const kByteUnit = "BYTE";
+const char* const kUTF8Unit = "UTF8_CHAR";
+
Tensor GetTestTensor(int batch) {
- const int sz = TF_ARRAYSIZE(lines);
+ const int sz = TF_ARRAYSIZE(ascii_lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = ascii_lines[i % sz];
+ }
+ return t;
+}
+
+Tensor GetTestUTF8Tensor(int batch) {
+ const int sz = TF_ARRAYSIZE(unicode_lines);
Tensor t(DT_STRING, {batch});
auto s = t.flat<string>();
for (int i = 0; i < batch; ++i) {
- s(i) = lines[i % sz];
+ s(i) = unicode_lines[i % sz];
}
return t;
}
-Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) {
+Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len,
+ const char* const unit) {
Graph* g = new Graph(OpRegistry::Global());
Tensor position(DT_INT32, TensorShape({}));
position.flat<int32>().setConstant(pos);
@@ -85,21 +144,46 @@ Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len) {
.Input(test::graph::Constant(g, input))
.Input(test::graph::Constant(g, position))
.Input(test::graph::Constant(g, length))
+ .Attr("unit", unit)
.Finalize(g, nullptr /* node */));
return g;
}
-void BM_Substr(int iters, int batch_size) {
+void BM_SubstrByte(int iters, int batch_size) {
testing::StopTiming();
testing::ItemsProcessed(static_cast<int64>(iters));
testing::UseRealTime();
Tensor input = GetTestTensor(batch_size);
- Graph* g = SetupSubstrGraph(input, 3, 30);
+ Graph* g = SetupSubstrGraph(input, 3, 30, kByteUnit);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+void BM_SubstrUTF8(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestUTF8Tensor(batch_size);
+ Graph* g = SetupSubstrGraph(input, 3, 30, kUTF8Unit);
testing::StartTiming();
test::Benchmark("cpu", g).Run(iters);
}
-BENCHMARK(BM_Substr)->Arg(1)->Arg(8)->Arg(16)->Arg(32)->Arg(64)->Arg(128)->Arg(
- 256);
+BENCHMARK(BM_SubstrByte)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+BENCHMARK(BM_SubstrUTF8)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
} // end namespace tensorflow
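
Aside for readers of this hunk: the UTF8_CHAR path converts character-based `pos`/`len` into byte offsets by scanning UTF-8 lead bytes. The helpers it calls (ForwardNUTF8CharPositions, BackNUTF8CharPositions) live in the newly included string_util.h and are not shown in this diff; the following standalone sketch, with a hypothetical helper name, illustrates the same forward scan and the resulting byte arithmetic.

#include <cstddef>
#include <iostream>
#include <string>

// Hypothetical standalone analogue of the forward scan: advance *byte_pos past
// n UTF-8 characters by stepping over one lead byte and any continuation bytes
// (0b10xxxxxx) per character. Returns false if the string ends first.
static bool ForwardNChars(const std::string& s, int n, std::size_t* byte_pos) {
  for (int chars = 0; chars < n; ++chars) {
    if (*byte_pos >= s.size()) return false;
    ++*byte_pos;  // consume the lead byte
    while (*byte_pos < s.size() &&
           (static_cast<unsigned char>(s[*byte_pos]) & 0xC0) == 0x80) {
      ++*byte_pos;  // consume continuation bytes of the same character
    }
  }
  return true;
}

int main() {
  // "Hi" followed by two 3-byte CJK characters: 8 bytes, 4 characters.
  const std::string s = "Hi\xe4\xb8\xad\xe6\x96\x87";
  std::size_t start = 0;
  ForwardNChars(s, 2, &start);  // pos = 2 chars -> byte offset 2
  std::size_t end = start;
  ForwardNChars(s, 1, &end);    // len = 1 char -> 3 bytes further
  std::cout << s.substr(start, end - start) << "\n";  // prints the first CJK char
  return 0;
}
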
diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc
index 83b83fcdb9..4262a5404b 100644
--- a/tensorflow/core/kernels/training_op_helpers.cc
+++ b/tensorflow/core/kernels/training_op_helpers.cc
@@ -15,14 +15,16 @@ limitations under the License.
#include "tensorflow/core/kernels/training_op_helpers.h"
+#include "tensorflow/core/util/ptr_util.h"
+
namespace tensorflow {
-mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
+mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
+ Var** maybe_resource) {
+ *maybe_resource = nullptr;
if (ctx->input_dtype(input) == DT_RESOURCE) {
- Var* var;
- if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
- core::ScopedUnref scoped_unref(var);
- return var->mu();
+ if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
+ return (*maybe_resource)->mu();
} else {
ctx->CtxFailureWithWarning(
errors::Internal("Invalid variable reference."));
@@ -33,12 +35,13 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
}
// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
-// in address order to mitigate deadlock. Returns a vector of acquired mutexes.
-// Safe to pass duplicates - will only lock each distinct mutex once. If
-// do_lock is false, returns immediately. Note that this silently doesn't lock
-// mutexes for invalid variable references; in all usages this is followed by
-// GetInputTensor which will signal a failure.
-std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
+// in address order to mitigate deadlock. Returns a structure that, when
+// deleted, will release the acquired mutexes. Safe to pass duplicates - will
+// only lock each distinct mutex once. If do_lock is false, returns
+// immediately. Note that this silently doesn't lock mutexes for invalid
+// variable references; in all usages this is followed by GetInputTensor which
+// will signal a failure.
+VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
bool any_resource = false;
for (auto i : input_ids) {
@@ -47,14 +50,16 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
break;
}
}
- std::vector<mutex_lock> locks;
if (!do_lock && !any_resource) {
- return locks;
+ return VariableInputLockHolder({}, {});
}
+ std::vector<Var*> vars;
std::vector<mutex*> mutexes;
std::vector<int> acquire_order;
for (auto input : input_ids) {
- mutex* mutex = GetTrainingVariableMutex(ctx, input);
+ Var* var;
+ mutex* mutex = GetTrainingVariableMutex(ctx, input, &var);
+ if (var) vars.push_back(var);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(mutexes.size());
@@ -64,13 +69,19 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
+ std::unique_ptr<std::vector<mutex_lock>> locks =
+ MakeUnique<std::vector<mutex_lock>>();
+ locks->reserve(acquire_order.size());
+
for (auto input : acquire_order) {
- mutex* mu = GetTrainingVariableMutex(ctx, input);
+ Var* var;
+ mutex* mu = GetTrainingVariableMutex(ctx, input, &var);
+ core::ScopedUnref scoped_unref(var);
if (mu != nullptr) {
- locks.emplace_back(*mu);
+ locks->emplace_back(*mu);
}
}
- return locks;
+ return VariableInputLockHolder(std::move(vars), std::move(locks));
}
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h
index 071cb371a7..9f173a80f7 100644
--- a/tensorflow/core/kernels/training_op_helpers.h
+++ b/tensorflow/core/kernels/training_op_helpers.h
@@ -23,9 +23,42 @@ limitations under the License.
namespace tensorflow {
-mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input);
+// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
+//
+// If `input` corresponds to a `DT_RESOURCE`-type variable input,
+// `*maybe_resource` will be updated to contain the underlying resource, and the
+// caller will be responsible for calling `Unref()` on that resource.
+mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
+ Var** maybe_resource);
-std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
+// Utility structure that releases a sequence of borrowed mutexes when it is
+// deleted.
+struct VariableInputLockHolder {
+ public:
+ VariableInputLockHolder(std::vector<Var*> vars,
+ std::unique_ptr<std::vector<mutex_lock>> locks)
+ : vars_(std::move(vars)), locks_(std::move(locks)) {}
+
+ VariableInputLockHolder(VariableInputLockHolder&& other)
+ : vars_(std::move(other.vars_)), locks_(std::move(other.locks_)) {}
+
+ ~VariableInputLockHolder() {
+ // Release the locks before unreffing the Vars, because each lock
+ // is potentially borrowed from a Var in vars_.
+ locks_.reset();
+ for (Var* var : vars_) {
+ var->Unref();
+ }
+ }
+
+ private:
+ std::vector<Var*> vars_;
+ // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
+ // because a `std::vector<mutex_lock>` is not movable on all platforms.
+ std::unique_ptr<std::vector<mutex_lock>> locks_;
+};
+
+VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids);
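
A minimal usage sketch (a hypothetical kernel fragment, not part of this diff) showing why the holder exists: the caller keeps a single object alive for the duration of the update, and destruction releases the locks before the Var refs, matching the destructor ordering above. `use_exclusive_lock_` is an assumed member attr, as in the training kernels below.

// Sketch of a caller consuming the new return type.
void Compute(OpKernelContext* ctx) override {
  VariableInputLockHolder locks = MaybeLockVariableInputMutexesInOrder(
      ctx, use_exclusive_lock_, {0, 1, 2});
  // ... read variable inputs 0..2 and apply the update while locked ...
}  // `locks` destroyed here: mutexes released first, then the Vars unreffed.
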
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 9a07ded17d..acf162deec 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -561,7 +561,9 @@ class ApplyAdadeltaOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- mutex* mu = GetTrainingVariableMutex(ctx, 0);
+ Var* resource;
+ mutex* mu = GetTrainingVariableMutex(ctx, 0, &resource);
+ core::ScopedUnref scoped_unref(resource);
if (use_exclusive_lock_ && mu != nullptr) {
mutex_lock l1(*mu);
// Don't try to acquire a lock on the second ref as they share the same
@@ -710,7 +712,9 @@ class SparseApplyAdadeltaOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- mutex* mu = GetTrainingVariableMutex(ctx, 0);
+ Var* var;
+ mutex* mu = GetTrainingVariableMutex(ctx, 0, &var);
+ core::ScopedUnref scoped_unref(var);
// mu_accum is actually the same mutex as mu_var since currently we use a
// global mutex.
//
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 0f0f65c5a3..48e392c070 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -218,7 +218,7 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
perm, out);
}
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
@@ -230,11 +230,8 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
.TypeConstraint<T>("T") \
.HostMemory("perm"), \
MklConjugateTransposeCpuOp);
-TF_CALL_ALL_TYPES(REGISTER);
-#undef REGISTER
-
-#else // INTEL_MKL
+#else // INTEL_MKL && ENABLE_MKL
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
@@ -246,9 +243,10 @@ TF_CALL_ALL_TYPES(REGISTER);
.TypeConstraint<T>("T") \
.HostMemory("perm"), \
ConjugateTransposeCpuOp);
+#endif // INTEL_MKL && ENABLE_MKL
+
TF_CALL_ALL_TYPES(REGISTER)
#undef REGISTER
-#endif // INTEL_MKL
#if GOOGLE_CUDA
Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
diff --git a/tensorflow/core/kernels/unicode_script_op.cc b/tensorflow/core/kernels/unicode_script_op.cc
new file mode 100644
index 0000000000..085e397eba
--- /dev/null
+++ b/tensorflow/core/kernels/unicode_script_op.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "unicode/errorcode.h" // TF:icu
+#include "unicode/uscript.h" // TF:icu
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+class UnicodeScriptOp : public OpKernel {
+ public:
+ explicit UnicodeScriptOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(context, context->input("input", &input_tensor));
+ const auto& input_flat = input_tensor->flat<int32>();
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<int32>();
+
+ icu::ErrorCode status;
+ for (int i = 0; i < input_flat.size(); i++) {
+ UScriptCode script_code = uscript_getScript(input_flat(i), status);
+ if (status.isSuccess()) {
+ output_flat(i) = script_code;
+ } else {
+ output_flat(i) = -1;
+ status.reset();
+ }
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("UnicodeScript").Device(DEVICE_CPU),
+ UnicodeScriptOp);
+
+} // namespace tensorflow
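
As a sanity check of the ICU calls used above, here is a standalone program (outside TensorFlow, same `uscript_getScript` call) exercising a few sample code points; the printed script names are what the kernel's int32 output codes correspond to.

#include <iostream>

#include "unicode/errorcode.h"
#include "unicode/uscript.h"

int main() {
  icu::ErrorCode status;
  for (UChar32 cp : {0x41 /* 'A' */, 0x416 /* Cyrillic Zhe */, 0x4E2D /* CJK */}) {
    UScriptCode script = uscript_getScript(cp, status);
    if (status.isSuccess()) {
      // Expect "Latin", "Cyrillic", "Han" for the three sample code points.
      std::cout << std::hex << cp << " -> " << uscript_getName(script) << "\n";
    } else {
      std::cout << std::hex << cp << " -> -1\n";  // the kernel's error sentinel
      status.reset();
    }
  }
  return 0;
}
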
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index 3559baa18e..3bdcfc90b8 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -108,7 +108,7 @@ class UniqueOp : public OpKernel {
std::unordered_map<T, TIndex> uniq;
uniq.reserve(2 * N);
- for (int64 i = 0, j = 0; i < N; ++i) {
+ for (Eigen::Index i = 0, j = 0; i < N; ++i) {
auto it = uniq.insert(std::make_pair(Tin(i), j));
idx_vec(i) = it.first->second;
if (it.second) {
@@ -131,19 +131,20 @@ class UniqueOp : public OpKernel {
// General implementation when unique is run over multiple elements.
auto Tin = input.shaped<T, 3>(new_sizes);
- auto hash_fn = [&Tin](const int64& key) {
+ auto hash_fn = [&Tin](const Eigen::Index& key) {
size_t h = 0;
- for (int64 i = 0; i < Tin.dimension(0); i++) {
- for (int64 j = 0; j < Tin.dimension(2); j++) {
+ for (Eigen::Index i = 0; i < Tin.dimension(0); i++) {
+ for (Eigen::Index j = 0; j < Tin.dimension(2); j++) {
h = Hash64Combine(h, hash<T>{}(Tin(i, key, j)));
}
}
return h;
};
- auto equal_to_fn = [&Tin](const int64& lhs, const int64& rhs) {
- for (int64 i = 0; i < Tin.dimension(0); i++) {
- for (int64 j = 0; j < Tin.dimension(2); j++) {
+ auto equal_to_fn = [&Tin](const Eigen::Index& lhs,
+ const Eigen::Index& rhs) {
+ for (Eigen::Index i = 0; i < Tin.dimension(0); i++) {
+ for (Eigen::Index j = 0; j < Tin.dimension(2); j++) {
if (Tin(i, lhs, j) != Tin(i, rhs, j)) {
return false;
}
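
The Eigen::Index change above is type hygiene: the loop counters and lambda key types now match Eigen's own index type instead of assuming int64. The surrounding idea, deduplicating slices by storing indices as map keys while hashing and comparing the slice contents in place, is shown in miniature in this standalone sketch (toy data, hypothetical names):

#include <cstddef>
#include <functional>
#include <iostream>
#include <unordered_map>
#include <vector>

int main() {
  // Deduplicate columns of a 2-D array by storing *column indices* in the map
  // and hashing/comparing the column contents, so no column is ever copied.
  std::vector<std::vector<int>> data = {{1, 2, 1}, {3, 4, 3}};  // 2 rows, 3 cols
  const std::size_t rows = data.size(), cols = data[0].size();

  auto hash_fn = [&](const std::size_t& col) {
    std::size_t h = 0;
    for (std::size_t i = 0; i < rows; ++i)
      h = h * 31 + std::hash<int>{}(data[i][col]);
    return h;
  };
  auto equal_fn = [&](const std::size_t& lhs, const std::size_t& rhs) {
    for (std::size_t i = 0; i < rows; ++i)
      if (data[i][lhs] != data[i][rhs]) return false;
    return true;
  };

  std::unordered_map<std::size_t, std::size_t, decltype(hash_fn),
                     decltype(equal_fn)>
      uniq(cols, hash_fn, equal_fn);
  for (std::size_t j = 0, next = 0; j < cols; ++j)
    if (uniq.insert({j, next}).second) ++next;

  std::cout << uniq.size() << " unique columns\n";  // prints "2 unique columns"
  return 0;
}
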
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 442686c92a..f55562ec99 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -133,6 +133,14 @@ Status TransposeShapeFn(InferenceContext* c) {
} else {
rank = perm->NumElements();
}
+ if (!c->RankKnown(input) && rank < 2) {
+ // A permutation array containing a single element is ambiguous. It could
+ // indicate either a scalar or a 1-dimensional array, both of which the
+ // transpose op returns unchanged.
+ c->set_output(0, input);
+ return Status::OK();
+ }
+
std::vector<DimensionHandle> dims;
dims.resize(rank);
TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
@@ -1531,37 +1539,6 @@ REGISTER_OP("Size")
.Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ScalarShape);
-namespace {
-
-// This SliceHelper processes the output shape of the `slice`
-// when the tensor of `sizes` is available.
-template <typename T>
-Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
- const Tensor* sizes_value,
- std::vector<DimensionHandle>* dims) {
- auto sizes_vec = sizes_value->vec<T>();
- for (int i = 0; i < sizes_value->NumElements(); ++i) {
- DimensionHandle dim = c->Dim(c->input(0), i);
- if (sizes_vec(i) != -1) {
- auto dim_val = c->Value(dim);
- if (sizes_vec(i) < 0) {
- return errors::InvalidArgument(
- "Out of bounds slicing on dimension ", i, " of length ", dim_val,
- ": sizes vector cannot be < -1, but was ", sizes_vec(i));
- }
-
- dims->emplace_back(c->MakeDim(sizes_vec(i)));
- } else {
- DimensionHandle result;
- TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
- dims->emplace_back(result);
- }
- }
-
- return Status::OK();
-}
-} // namespace
-
// --------------------------------------------------------------------------
REGISTER_OP("Slice")
.Input("input: T")
@@ -1570,83 +1547,22 @@ REGISTER_OP("Slice")
.Output("output: T")
.Attr("T: type")
.Attr("Index: {int32,int64}")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle input = c->input(0);
- ShapeHandle begin_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
- ShapeHandle sizes_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
-
- // Merge to check compatibility of begin and sizes tensors.
- TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
+ .SetShapeFn(shape_inference::SliceShape);
- DimensionHandle ndims = c->Dim(begin_shape, 0);
- if (c->ValueKnown(ndims)) {
- TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
- }
-
- // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
- // values, even though the `begin` value does not represent a shape.
- ShapeHandle begin_value;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
-
- // We check the tensor value here and will only use
- // `MakeShapeFromShapeTensor` when `sizes_value` is null.
- // The reason is that `sizes`might contain -1, which can't
- // be represented (-1 in the ShapeHandle would mean "unknown".
- const Tensor* sizes_value = c->input_tensor(2);
-
- if (sizes_value != nullptr) {
- TF_RETURN_IF_ERROR(
- c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
- std::vector<DimensionHandle> dims;
- // If the begin and sizes tensors are available, then
- // we can be precise about the shape of the output.
- if (sizes_value->dtype() == DT_INT64) {
- TF_RETURN_IF_ERROR(
- SliceHelper<int64>(c, begin_value, sizes_value, &dims));
- } else {
- TF_RETURN_IF_ERROR(
- SliceHelper<int32>(c, begin_value, sizes_value, &dims));
- }
-
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
- } else {
- // In case `sizes` is not available (`sizes_value` is null),
- // we could try to use `MakeShapeFromShapeTensor` here.
- // If sizes contain -1, we will simply consider it as `Unknown`.
- // This is less than ideal but still an improvement of shape inference.
- // The following is an example that returns [None, 1, None] with this
- // code path:
- // z = tf.zeros((1, 2, 3))
- // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
- // m.get_shape().as_list()
- ShapeHandle sizes_value;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
- if (c->RankKnown(sizes_value)) {
- TF_RETURN_IF_ERROR(
- c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
- std::vector<DimensionHandle> dims;
- dims.reserve(c->Rank(sizes_value));
- for (int i = 0; i < c->Rank(sizes_value); ++i) {
- dims.emplace_back(c->Dim(sizes_value, i));
- }
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
- }
-
- // We might know the rank of the input.
- if (c->RankKnown(input)) {
- c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
- return Status::OK();
- } else {
- return shape_inference::UnknownShape(c);
- }
- }
-
- return Status::OK();
- });
+#ifdef INTEL_MKL
+REGISTER_OP("_MklSlice")
+ .Input("input: T")
+ .Input("begin: Index")
+ .Input("size: Index")
+ .Input("mkl_input: uint8")
+ .Input("mkl_begin: uint8")
+ .Input("mkl_size: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: type")
+ .Attr("Index: {int32,int64}")
+ .SetShapeFn(shape_inference::SliceShape);
+#endif
REGISTER_OP("StridedSlice")
.Input("input: T")
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index 03dab390a7..1c29cd2491 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -975,6 +975,7 @@ TEST(ArrayOpsTest, Transpose_ShapeFn) {
INFER_OK(op, "?;[2]", "[?,?]");
INFER_OK(op, "[?,?];[2]", "[d0_1,d0_0]");
INFER_OK(op, "[1,?];[2]", "[d0_1,d0_0]");
+ INFER_OK(op, "?;[0]", "in0");
// Invalid arguments.
perm = test::AsTensor<int32>({1, 2});
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 86d4c6b421..0753316724 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -21532,6 +21532,421 @@ op {
}
}
op {
+ name: "ExperimentalAssertNextDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "transformations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalCSVDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "compression_type"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "buffer_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "header"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "field_delim"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "use_quote_delim"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "na_value"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "select_cols"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "record_defaults"
+ type_list_attr: "output_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalDirectedInterleaveDataset"
+ input_arg {
+ name: "selector_input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "data_input_datasets"
+ type: DT_VARIANT
+ number_attr: "N"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalFunctionBufferingResource"
+ input_arg {
+ name: "string_arg"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "target_device"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "buffer_size"
+ type: "int"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceGetNext"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceReset"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIdentityIndexedDataset"
+ input_arg {
+ name: "size"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalIndexedDatasetGet"
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "index"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIndexedDatasetMaterialize"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIteratorGetDevice"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "device"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalLMDBDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalMaterializedIndexDatasetHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "thread_pool"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "num_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ type: "int"
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "display_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalUniqueDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Expm1"
input_arg {
name: "x"
@@ -24105,6 +24520,158 @@ op {
}
}
op {
+ name: "FusedBatchNorm"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "offset"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "mean"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "variance"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "batch_mean"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "batch_variance"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_1"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_2"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
+ name: "FusedBatchNormGrad"
+ input_arg {
+ name: "y_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reserve_space_1"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reserve_space_2"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "x_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "scale_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "offset_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_3"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_4"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "FusedBatchNormGrad"
input_arg {
name: "y_backprop"
@@ -24168,6 +24735,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -24345,6 +24918,179 @@ op {
}
}
op {
+ name: "FusedBatchNormGradV2"
+ input_arg {
+ name: "y_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "reserve_space_1"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "reserve_space_2"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "x_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "scale_backprop"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "offset_backprop"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_3"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_4"
+ type_attr: "U"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "U"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
+ name: "FusedBatchNormV2"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "offset"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "mean"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "variance"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "batch_mean"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "batch_variance"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_1"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_2"
+ type_attr: "U"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "U"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "FusedBatchNormV2"
input_arg {
name: "x"
@@ -24392,6 +25138,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
}
}
@@ -24502,6 +25249,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -26317,6 +27070,52 @@ op {
is_stateful: true
}
op {
+ name: "If"
+ input_arg {
+ name: "cond"
+ type_attr: "Tcond"
+ }
+ input_arg {
+ name: "input"
+ type_list_attr: "Tin"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "Tcond"
+ type: "type"
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "then_branch"
+ type: "func"
+ }
+ attr {
+ name: "else_branch"
+ type: "func"
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "Igamma"
input_arg {
name: "a"
@@ -29768,6 +30567,52 @@ op {
}
}
op {
+ name: "MapDefun"
+ input_arg {
+ name: "arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "captured_inputs"
+ type_list_attr: "Tcaptured"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Tcaptured"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+}
+op {
name: "MapIncompleteSize"
output_arg {
name: "size"
@@ -44518,6 +45363,59 @@ op {
is_stateful: true
}
op {
+ name: "ReduceDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "initial_state"
+ type_list_attr: "Tstate"
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Tstate"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "ReduceJoin"
input_arg {
name: "inputs"
@@ -58933,6 +59831,14 @@ op {
name: "stats_aggregator"
type: DT_RESOURCE
}
+ input_arg {
+ name: "tag"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "counter_prefix"
+ type: DT_STRING
+ }
output_arg {
name: "handle"
type: DT_VARIANT
@@ -69991,6 +70897,62 @@ op {
}
}
op {
+ name: "StatelessRandomNormal"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessRandomUniform"
input_arg {
name: "shape"
@@ -70088,6 +71050,118 @@ op {
}
}
op {
+ name: "StatelessRandomUniform"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
+ name: "StatelessRandomUniformInt"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ input_arg {
+ name: "minval"
+ type_attr: "dtype"
+ }
+ input_arg {
+ name: "maxval"
+ type_attr: "dtype"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessTruncatedNormal"
input_arg {
name: "shape"
@@ -70185,6 +71259,62 @@ op {
}
}
op {
+ name: "StatelessTruncatedNormal"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessWhile"
input_arg {
name: "input"
@@ -70984,6 +72114,48 @@ op {
}
}
op {
+ name: "Substr"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "pos"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "len"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
+}
+op {
name: "Sum"
input_arg {
name: "input"
@@ -74573,6 +75745,17 @@ op {
}
}
op {
+ name: "UnicodeScript"
+ input_arg {
+ name: "input"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+}
+op {
name: "UniformCandidateSampler"
input_arg {
name: "true_classes"
@@ -75981,6 +77164,39 @@ op {
is_stateful: true
}
op {
+ name: "While"
+ input_arg {
+ name: "input"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "cond"
+ type: "func"
+ }
+ attr {
+ name: "body"
+ type: "func"
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "WholeFileReader"
output_arg {
name: "reader_handle"
@@ -76283,6 +77499,62 @@ op {
is_stateful: true
}
op {
+ name: "Xdivy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
+ name: "Xlogy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "ZerosLike"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 1ada623cf5..ec22eee874 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -185,6 +185,8 @@ REGISTER_OP("ParseExampleDataset")
REGISTER_OP("SetStatsAggregatorDataset")
.Input("input_dataset: variant")
.Input("stats_aggregator: resource")
+ .Input("tag: string")
+ .Input("counter_prefix: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
@@ -756,6 +758,19 @@ REGISTER_OP("DatasetToSingleElement")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(IteratorGetNextShapeFn);
+REGISTER_OP("ReduceDataset")
+ .Input("input_dataset: variant")
+ .Input("initial_state: Tstate")
+ .Input("other_arguments: Targuments")
+ .Output("components: output_types")
+ .Attr("f: func")
+ .Attr("Tstate: list(type) >= 1")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Attr("use_inter_op_parallelism: bool = true")
+ .SetShapeFn(IteratorGetNextShapeFn);
+
REGISTER_OP("IteratorToStringHandle")
.Input("resource_handle: resource")
.Output("string_handle: string")
@@ -888,14 +903,18 @@ REGISTER_OP("ModelDataset")
REGISTER_OP("MapDefun")
.Input("arguments: Targuments")
+ .Input("captured_inputs: Tcaptured")
.Output("output: output_types")
.Attr("Targuments: list(type) >= 1")
+ .Attr("Tcaptured: list(type) >= 0 = []")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("f: func")
.SetShapeFn([](shape_inference::InferenceContext* c) {
std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ DataTypeVector t_args;
+ TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
"`output_shapes` must be the same length as `output_types` (",
@@ -903,10 +922,11 @@ REGISTER_OP("MapDefun")
}
int64 dim_zero = -1;
- for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) {
+ for (size_t i = 0; i < t_args.size(); ++i) {
if (c->Rank(c->input(i)) == 0) {
return errors::InvalidArgument(
- "Inputs must have rank at least 1. Input ", i, " has rank of 0");
+ "Arguments must have rank at least 1. Input ", i,
+ " has rank of 0.");
}
auto dim_handle = c->Dim(c->input(i), 0);
if (c->ValueKnown(dim_handle)) {
@@ -914,7 +934,7 @@ REGISTER_OP("MapDefun")
dim_zero = c->Value(dim_handle);
} else if (c->Value(dim_handle) != dim_zero) {
return errors::InvalidArgument(
- "Inputs must have the same dimension 0.");
+ "Arguments must have the same dimension 0.");
}
}
}
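
The net effect of the MapDefun change: `arguments` are still sliced along dimension 0 (and must agree on it), while the new `captured_inputs` are passed to `f` whole and are exempt from the rank >= 1 check, since the shape function now iterates only over `Targuments`. A minimal sketch of that contract, with hypothetical names and plain C++ containers standing in for tensors:

#include <iostream>
#include <stdexcept>
#include <vector>

// Each argument is sliced along dim 0; `captured` is passed whole every call.
template <typename T, typename F>
std::vector<T> MapDefunSketch(const std::vector<std::vector<T>>& arguments,
                              const T& captured, F f) {
  const std::size_t batch = arguments.empty() ? 0 : arguments[0].size();
  for (const auto& arg : arguments)
    if (arg.size() != batch)
      throw std::invalid_argument("Arguments must have the same dimension 0.");
  std::vector<T> out;
  out.reserve(batch);
  for (std::size_t i = 0; i < batch; ++i) {
    std::vector<T> slice;
    for (const auto& arg : arguments) slice.push_back(arg[i]);
    out.push_back(f(slice, captured));  // captured passed whole each time
  }
  return out;
}

int main() {
  auto f = [](const std::vector<int>& slice, int captured) {
    int s = captured;
    for (int v : slice) s += v;
    return s;
  };
  for (int v : MapDefunSketch<int>({{1, 2, 3}, {10, 20, 30}}, 100, f))
    std::cout << v << " ";  // prints "111 122 133"
  std::cout << "\n";
  return 0;
}
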
diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc
new file mode 100644
index 0000000000..f6bd5dce26
--- /dev/null
+++ b/tensorflow/core/ops/experimental_dataset_ops.cc
@@ -0,0 +1,207 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("ExperimentalDirectedInterleaveDataset")
+ .Input("selector_input_dataset: variant")
+ .Input("data_input_datasets: N * variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Attr("N: int >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalCSVDataset")
+ .Input("filenames: string")
+ .Input("compression_type: string")
+ .Input("buffer_size: int64")
+ .Input("header: bool")
+ .Input("field_delim: string")
+ .Input("use_quote_delim: bool")
+ .Input("na_value: string")
+ .Input("select_cols: int64")
+ .Input("record_defaults: output_types")
+ .Output("handle: variant")
+ .Attr("output_types: list({float,double,int32,int64,string}) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // `filenames` must be a scalar or a vector.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+ // `compression_type`, `buffer_size`, `header`, `field_delim`,
+ // `use_quote_delim`, `na_value` must be scalars
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
+ // `select_cols` must be a vector
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
+      // `record_defaults` must be a list of scalars
+ for (size_t i = 8; i < c->num_inputs(); ++i) {
+ shape_inference::ShapeHandle v;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+ if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
+ return errors::InvalidArgument(
+ "Shape of a default must be a length-0 or length-1 vector, or a "
+ "scalar.");
+ }
+ }
+ return shape_inference::ScalarShape(c);
+ });
+
+REGISTER_OP("ExperimentalIgnoreErrorsDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalUniqueDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalIteratorGetDevice")
+ .Input("resource: resource")
+ .Output("device: string")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResource")
+ .Input("string_arg: string")
+ .Input("target_device: string")
+ .Output("resource: resource")
+ .Attr("shared_name: string")
+ .Attr("container: string")
+ .Attr("f: func")
+ .Attr("buffer_size: int")
+ .Attr("output_types: list(type)")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResourceGetNext")
+ .Input("function_buffer_resource: resource")
+ .Attr("output_types: list(type)")
+ .Output("output: output_types")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResourceReset")
+ .Input("function_buffer_resource: resource")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalThreadPoolDataset")
+ .Input("input_dataset: variant")
+ .Input("thread_pool: resource")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalThreadPoolHandle")
+ .Output("handle: resource")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Attr("num_threads: int")
+ .Attr("max_intra_op_parallelism: int = 1")
+ .Attr("display_name: string")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''");
+
+REGISTER_OP("ExperimentalAssertNextDataset")
+ .Input("input_dataset: variant")
+ .Input("transformations: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // transformations should be a vector.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+ return shape_inference::ScalarShape(c);
+ });
+
+REGISTER_OP("ExperimentalLMDBDataset")
+ .Input("filenames: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalIdentityIndexedDataset")
+ .Input("size: uint64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(
+ shape_inference::ScalarShape); // TODO(saeta): check input shapes.
+
+///////////////////////////////////////////////////////////////////////////////
+// IndexedDataset Internals
+///////////////////////////////////////////////////////////////////////////////
+
+// Creates the handle.
+REGISTER_OP("ExperimentalMaterializedIndexDatasetHandle")
+ .Output("handle: resource")
+ .Attr("container: string")
+ .Attr("shared_name: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+// Actually materializes the dataset into the materialized handle.
+REGISTER_OP("ExperimentalIndexedDatasetMaterialize")
+ .Input("dataset: variant")
+ .Input("materialized: resource")
+ .SetShapeFn(shape_inference::NoOutputs);
+
+namespace {
+
+Status GetShapeFn(shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+        output_shapes.size(), " vs. ", c->num_outputs(), ")");
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+REGISTER_OP("ExperimentalIndexedDatasetGet")
+ .Input("materialized: resource")
+ .Input("index: uint64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(GetShapeFn);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index bda4a75c5d..22b4b07eff 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -110,8 +110,27 @@ REGISTER_OP("If")
.Attr("Tout: list(type) >= 0")
.Attr("then_branch: func")
.Attr("else_branch: func")
+ .Attr("output_shapes: list(shape) = []")
.SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ // If the `output_shapes` attr is set, use those as the output shapes;
+ // otherwise return unknown shapes.
+ if (output_shapes.empty()) return shape_inference::UnknownShape(c);
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as num outputs (",
+ output_shapes.size(), " vs. ", c->num_outputs(), ")");
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+ });
// TODO(drpng): remove this.
REGISTER_OP("_While")
@@ -150,10 +169,29 @@ REGISTER_OP("While")
.Attr("T: list(type) >= 0")
.Attr("cond: func")
.Attr("body: func")
+ .Attr("output_shapes: list(shape) = []")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
- for (int i = 0; i < c->num_outputs(); ++i) {
- c->set_output(i, c->input(i));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ // If the `output_shapes` attr is set, use those as the output shapes;
+ // otherwise use the input shapes.
+ if (!output_shapes.empty()) {
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as num outputs (",
+ output_shapes.size(), " vs. ", c->num_outputs(), ")");
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ } else {
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->input(i));
+ }
}
return Status::OK();
});
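With this change, `If` and `While` consult an optional `output_shapes` attr before falling back to the old behavior (unknown shapes for `If`, input-shape passthrough for `While`). A minimal sketch of attaching the attr when building a `While` node by hand with the framework's test helpers; the function names and shapes are assumptions for illustration only:

  NameAttrList cond_fn, body_fn;
  cond_fn.set_name("LoopCond");  // hypothetical, pre-registered functions
  body_fn.set_name("LoopBody");
  NodeDef while_node;
  TF_CHECK_OK(NodeDefBuilder("loop", "While")
                  .Input(FakeInput({DT_FLOAT, DT_FLOAT}))  // two loop variables
                  .Attr("T", DataTypeVector{DT_FLOAT, DT_FLOAT})
                  .Attr("cond", cond_fn)
                  .Attr("body", body_fn)
                  .Attr("output_shapes",
                        std::vector<PartialTensorShape>{
                            PartialTensorShape({-1, 3}), PartialTensorShape({})})
                  .Finalize(&while_node));
  // Shape inference now reports [?,3] and [] instead of echoing the inputs.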
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 07f876cb90..55dcc50325 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -549,6 +549,40 @@ Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Pow", PowGrad);
+Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"zeros"}, "ZerosLike", {"x"}},
+ {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
+ {{"is_zero_cast"}, "Cast", {"is_x_zero"},
+ {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
+ {{"safe_logy"}, "Xlogy", {"is_zero_cast", "y"}},
+ {{"xlogygrad"}, "Xdivy", {"x", "y"}},
+ {{"gx"}, "Mul", {"safe_logy", "dz"}},
+ {{"gy"}, "Mul", {"xlogygrad", "dz"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Xlogy", XlogyGrad);
+
+Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"zeros"}, "ZerosLike", {"x"}},
+ {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
+ {{"is_zero_cast"}, "Cast", {"is_x_zero"},
+ {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
+ {{"safe_divy"}, "Xdivy", {"is_zero_cast", "y"}},
+ {{"y2"}, "Square", {"y"}},
+ {{"negy2"}, "Neg", {"y2"}},
+ {{"xdivygrad"}, "Xdivy", {"x", "negy2"}},
+ {{"gx"}, "Mul", {"safe_divy", "dz"}},
+ {{"gy"}, "Mul", {"xdivygrad", "dz"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Xdivy", XdivyGrad);
+
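For reference, these two FunctionDefs encode the piecewise gradients below. The NotEqual/Cast pair builds an indicator that is 1 where x != 0, clamping the x = 0 branch of each gradient to zero, consistent with xlogy(0, y) = 0 and xdivy(0, y) = 0 in the forward ops; the reference lambdas `g` and `h` in math_grad_test.cc below are exactly these expressions:

  \[
  \frac{\partial}{\partial x}\,\mathrm{xlogy}(x,y)=\begin{cases}\log y & x\neq 0\\ 0 & x=0\end{cases}
  \qquad
  \frac{\partial}{\partial y}\,\mathrm{xlogy}(x,y)=\begin{cases}x/y & x\neq 0\\ 0 & x=0\end{cases}
  \]
  \[
  \frac{\partial}{\partial x}\,\mathrm{xdivy}(x,y)=\begin{cases}1/y & x\neq 0\\ 0 & x=0\end{cases}
  \qquad
  \frac{\partial}{\partial y}\,\mathrm{xdivy}(x,y)=\begin{cases}-x/y^{2} & x\neq 0\\ 0 & x=0\end{cases}
  \]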
Status MaximumMinimumGradHelper(const string& comparator,
const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index 5ee79809ac..9fc6b34147 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -909,6 +909,46 @@ TEST_F(MathGradTest, ComplexPow) {
}
#endif // TENSORFLOW_USE_SYCL
+TEST_F(MathGradTest, Xlogy) {
+ auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f},
+ TensorShape({2, 3}));
+ auto y = test::AsTensor<float>({.5f, 2.f}, TensorShape({2, 1}));
+ Tensor dx;
+ Tensor dy;
+ auto g = [](float x, float y) -> float { return x == 0. ? 0. : std::log(y); };
+ auto h = [](float x, float y) -> float { return x == 0. ? 0. : x / y; };
+ SymGrad("Xlogy", x, y, &dx, &dy);
+ test::ExpectClose(
+ dx, test::AsTensor<float>({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f),
+ g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)},
+ TensorShape({2, 3})));
+ test::ExpectClose(
+ dy, test::AsTensor<float>({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f),
+ h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)},
+ TensorShape({2, 1})));
+}
+
+TEST_F(MathGradTest, Xdivy) {
+ auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f},
+ TensorShape({2, 3}));
+ auto y = test::AsTensor<float>({.5f, 2.f}, TensorShape({2, 1}));
+ Tensor dx;
+ Tensor dy;
+ auto g = [](float x, float y) -> float { return x == 0. ? 0. : 1 / y; };
+ auto h = [](float x, float y) -> float {
+ return x == 0. ? 0. : -x / (y * y);
+ };
+ SymGrad("Xdivy", x, y, &dx, &dy);
+ test::ExpectClose(
+ dx, test::AsTensor<float>({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f),
+ g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)},
+ TensorShape({2, 3})));
+ test::ExpectClose(
+ dy, test::AsTensor<float>({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f),
+ h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)},
+ TensorShape({2, 1})));
+}
+
TEST_F(MathGradTest, Maximum) {
auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
TensorShape({2, 3}));
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 717263a9b0..a9e5e7824d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -429,6 +429,20 @@ Returns (x - y)(x - y) element-wise.
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
+REGISTER_OP("Xlogy")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {half, float, double, complex64, complex128}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
+REGISTER_OP("Xdivy")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {half, float, double, complex64, complex128}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
#undef BINARY_FEWER
#undef BINARY_MORE
@@ -1423,7 +1437,24 @@ REGISTER_OP("Bincount")
.Attr("T: {int32, int64, float32, float64}")
.Output("bins: T")
.SetShapeFn([](InferenceContext* c) {
- c->set_output(0, c->UnknownShapeOfRank(1));
+ ShapeHandle unused;
+ // The input `size` must be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+
+ const Tensor* size_tensor = c->input_tensor(1);
+ if (size_tensor == nullptr) {
+ // Return unknown shape if size is not known.
+ c->set_output(0, c->UnknownShapeOfRank(1));
+ return Status::OK();
+ }
+
+ // Return `[size]` shape if size is known.
+ int32 size_val = size_tensor->scalar<int32>()();
+ if (size_val < 0) {
+ return errors::InvalidArgument("size (", size_val,
+ ") must be non-negative");
+ }
+ c->set_output(0, c->MakeShape({size_val}));
return Status::OK();
});
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index be4c3ed2b6..05379a7d69 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -559,4 +559,16 @@ TEST(MathOpsTest, QuantizedAdd_ShapeFn) {
INFER_ERROR("must be rank 0", op, "?;?;?;?;[3];?");
INFER_ERROR("must be rank 0", op, "?;?;?;?;?;[4]");
}
+
+TEST(MathOpsTest, Bincount_ShapeFn) {
+ ShapeInferenceTestOp op("Bincount");
+
+ // size should be scalar.
+ INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[1];?");
+
+ INFER_OK(op, "?;?;?", "[?]");
+ INFER_OK(op, "?;[];?", "[?]");
+ INFER_OK(op, "[?];[];?", "[?]");
+ INFER_OK(op, "[?];[];[?]", "[?]");
+}
} // end namespace tensorflow
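The cases above only exercise the unknown-size path of the new Bincount shape function. A hedged sketch of covering the constant-size path as well; `input_tensors` is the shape-inference test hook for feeding known values, and the test name is illustrative:

  TEST(MathOpsTest, Bincount_ShapeFn_KnownSize) {
    ShapeInferenceTestOp op("Bincount");
    Tensor size = test::AsScalar<int32>(10);  // known, non-negative size
    op.input_tensors = {nullptr, &size, nullptr};
    INFER_OK(op, "?;[];?", "[10]");  // output shape becomes [size]
  }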
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 6191a88e5b..a9ca69ad86 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -178,7 +178,7 @@ REGISTER_OP("FusedBatchNorm")
.Output("reserve_space_2: T")
.Attr("T: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormShape);
@@ -196,7 +196,7 @@ REGISTER_OP("FusedBatchNormV2")
.Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormShape);
@@ -213,7 +213,7 @@ REGISTER_OP("FusedBatchNormGrad")
.Output("reserve_space_4: T")
.Attr("T: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormGradShape);
@@ -231,7 +231,7 @@ REGISTER_OP("FusedBatchNormGradV2")
.Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormGradShape);
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 3ae4f1a59e..2048ad26ac 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -10039,6 +10039,421 @@ op {
}
}
op {
+ name: "ExperimentalAssertNextDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "transformations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalCSVDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "compression_type"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "buffer_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "header"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "field_delim"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "use_quote_delim"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "na_value"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "select_cols"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "record_defaults"
+ type_list_attr: "output_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalDirectedInterleaveDataset"
+ input_arg {
+ name: "selector_input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "data_input_datasets"
+ type: DT_VARIANT
+ number_attr: "N"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalFunctionBufferingResource"
+ input_arg {
+ name: "string_arg"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "target_device"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "buffer_size"
+ type: "int"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceGetNext"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceReset"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIdentityIndexedDataset"
+ input_arg {
+ name: "size"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalIndexedDatasetGet"
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "index"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIndexedDatasetMaterialize"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIteratorGetDevice"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "device"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalLMDBDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalMaterializedIndexDatasetHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "thread_pool"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "num_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ type: "int"
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "display_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalUniqueDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Expm1"
input_arg {
name: "x"
@@ -11459,6 +11874,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -11532,6 +11953,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -11616,6 +12043,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -11700,6 +12133,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -12737,6 +13176,14 @@ op {
name: "else_branch"
type: "func"
}
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ }
is_stateful: true
}
op {
@@ -14883,6 +15330,10 @@ op {
name: "arguments"
type_list_attr: "Targuments"
}
+ input_arg {
+ name: "captured_inputs"
+ type_list_attr: "Tcaptured"
+ }
output_arg {
name: "output"
type_list_attr: "output_types"
@@ -14894,6 +15345,15 @@ op {
minimum: 1
}
attr {
+ name: "Tcaptured"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
name: "output_types"
type: "list(type)"
has_minimum: true
@@ -22913,6 +23373,59 @@ op {
is_stateful: true
}
op {
+ name: "ReduceDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "initial_state"
+ type_list_attr: "Tstate"
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Tstate"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "ReduceJoin"
input_arg {
name: "inputs"
@@ -28171,6 +28684,14 @@ op {
name: "stats_aggregator"
type: DT_RESOURCE
}
+ input_arg {
+ name: "tag"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "counter_prefix"
+ type: DT_STRING
+ }
output_arg {
name: "handle"
type: DT_VARIANT
@@ -32525,6 +33046,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -32580,6 +33102,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -32613,6 +33136,62 @@ op {
}
}
op {
+ name: "StatelessRandomUniformInt"
+ input_arg {
+ name: "shape"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "seed"
+ type_attr: "Tseed"
+ }
+ input_arg {
+ name: "minval"
+ type_attr: "dtype"
+ }
+ input_arg {
+ name: "maxval"
+ type_attr: "dtype"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "Tseed"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "StatelessTruncatedNormal"
input_arg {
name: "shape"
@@ -32635,6 +33214,7 @@ op {
allowed_values {
list {
type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
}
@@ -33308,6 +33888,19 @@ op {
}
}
}
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
}
op {
name: "Sum"
@@ -35693,6 +36286,17 @@ op {
}
}
op {
+ name: "UnicodeScript"
+ input_arg {
+ name: "input"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+}
+op {
name: "UniformCandidateSampler"
input_arg {
name: "true_classes"
@@ -36500,6 +37104,14 @@ op {
name: "body"
type: "func"
}
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ }
is_stateful: true
}
op {
@@ -36805,6 +37417,62 @@ op {
is_stateful: true
}
op {
+ name: "Xdivy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
+ name: "Xlogy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "ZerosLike"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 26499540f1..adc9cd1486 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -19,6 +19,7 @@
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeAndType;
@@ -56,6 +57,36 @@ Status ReadVariableShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status ReadVariablesShapeFn(InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ DataTypeVector value_dtypes;
+ TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &value_dtypes));
+ if (n != value_dtypes.size()) {
+ return errors::InvalidArgument(
+ "Mismatched number of arguments to ReadVariablesOp");
+ }
+ for (int i = 0; i < n; ++i) {
+ ShapeAndType shape_and_type;
+ auto* handle_data = c->input_handle_shapes_and_types(i);
+ if (handle_data == nullptr || handle_data->empty()) {
+ shape_and_type.shape = c->UnknownShape();
+ shape_and_type.dtype = DT_INVALID;
+ } else {
+ shape_and_type = (*handle_data)[0];
+ if (shape_and_type.dtype != value_dtypes[i]) {
+ return errors::InvalidArgument(
+ "Trying to read variable with wrong dtype. "
+ "Expected ",
+ DataTypeString(shape_and_type.dtype), ", got ",
+ DataTypeString(value_dtypes[i]));
+ }
+ }
+ c->set_output(i, shape_and_type.shape);
+ }
+ return Status::OK();
+}
+
} // namespace
REGISTER_OP("VarHandleOp")
@@ -79,12 +110,53 @@ REGISTER_OP("VarHandleOp")
return Status::OK();
});
+REGISTER_OP("_VarHandlesOp")
+ .Attr("containers: list(string)")
+ .Attr("shared_names: list(string)")
+ .Attr("N: int >= 0")
+ .Attr("dtypes: list(type)")
+ .Attr("shapes: list(shape)")
+ .Output("resources: N * resource")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ DataTypeVector dtypes;
+ TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes));
+ std::vector<PartialTensorShape> shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
+ if (dtypes.size() != n) {
+ return errors::InvalidArgument("Mismatched number of dtypes (n=", n,
+ ", num dtypes=", dtypes.size(), ")");
+ }
+ if (shapes.size() != n) {
+ return errors::InvalidArgument("Mismatched number of shapes (n=", n,
+ ", num shapes=", shapes.size(), ")");
+ }
+ for (int i = 0; i < n; ++i) {
+ c->set_output(i, c->Scalar());
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &s));
+ c->set_output_handle_shapes_and_types(
+ i, std::vector<ShapeAndType>{{s, dtypes[i]}});
+ }
+
+ return Status::OK();
+ });
+
REGISTER_OP("ReadVariableOp")
.Input("resource: resource")
.Output("value: dtype")
.Attr("dtype: type")
.SetShapeFn(ReadVariableShapeFn);
+REGISTER_OP("_ReadVariablesOp")
+ .Attr("N: int >= 0")
+ .Input("resources: N * resource")
+ .Output("values: dtypes")
+ .Attr("dtypes: list(type)")
+ .SetShapeFn(ReadVariablesShapeFn);
+
Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
*g = FunctionDefHelper::Define(
diff --git a/tensorflow/core/ops/stateless_random_grad.cc b/tensorflow/core/ops/stateless_random_grad.cc
new file mode 100644
index 0000000000..331e1d0152
--- /dev/null
+++ b/tensorflow/core/ops/stateless_random_grad.cc
@@ -0,0 +1,23 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/function.h"
+
+namespace tensorflow {
+REGISTER_OP_NO_GRADIENT("StatelessRandomUniform");
+REGISTER_OP_NO_GRADIENT("StatelessRandomNormal");
+REGISTER_OP_NO_GRADIENT("StatelessTruncatedNormal");
+REGISTER_OP_NO_GRADIENT("StatelessMultinomial");
+} // end namespace tensorflow
diff --git a/tensorflow/core/ops/stateless_random_ops.cc b/tensorflow/core/ops/stateless_random_ops.cc
index 742709fb18..f919a21d60 100644
--- a/tensorflow/core/ops/stateless_random_ops.cc
+++ b/tensorflow/core/ops/stateless_random_ops.cc
@@ -19,42 +19,55 @@ limitations under the License.
namespace tensorflow {
using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-static Status StatelessShape(shape_inference::InferenceContext* context) {
+static Status StatelessShape(InferenceContext* c) {
// Check seed shape
ShapeHandle seed;
- TF_RETURN_IF_ERROR(context->WithRank(context->input(1), 1, &seed));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &seed));
DimensionHandle unused;
- TF_RETURN_IF_ERROR(context->WithValue(context->Dim(seed, 0), 2, &unused));
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
// Set output shape
ShapeHandle out;
- TF_RETURN_IF_ERROR(context->MakeShapeFromShapeTensor(0, &out));
- context->set_output(0, out);
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
+ c->set_output(0, out);
return Status::OK();
}
-#define REGISTER_STATELESS_OP(name) \
- REGISTER_OP(name) \
- .Input("shape: T") \
- .Input("seed: Tseed") \
- .Output("output: dtype") \
- .Attr("dtype: {half,float,double} = DT_FLOAT") \
- .Attr("T: {int32, int64} = DT_INT32") \
- .Attr("Tseed: {int32, int64} = DT_INT64") \
+#define REGISTER_STATELESS_OP(name) \
+ REGISTER_OP(name) \
+ .Input("shape: T") \
+ .Input("seed: Tseed") \
+ .Output("output: dtype") \
+ .Attr("dtype: {half,bfloat16,float,double} = DT_FLOAT") \
+ .Attr("T: {int32, int64} = DT_INT32") \
+ .Attr("Tseed: {int32, int64} = DT_INT64") \
.SetShapeFn(StatelessShape)
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessRandomUniform");
-
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessRandomNormal");
-
-// This op is exposed through contrib/stateless only. The interface may change.
REGISTER_STATELESS_OP("StatelessTruncatedNormal");
-// This op is exposed through contrib/stateless only. The interface may change.
+#undef REGISTER_STATELESS_OP
+
+REGISTER_OP("StatelessRandomUniformInt")
+ .Input("shape: T")
+ .Input("seed: Tseed")
+ .Input("minval: dtype")
+ .Input("maxval: dtype")
+ .Output("output: dtype")
+ .Attr("dtype: {int32, int64}")
+ .Attr("T: {int32, int64}")
+ .Attr("Tseed: {int32, int64} = DT_INT64")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ return StatelessShape(c);
+ });
+
REGISTER_OP("StatelessMultinomial")
.Input("logits: T")
.Input("num_samples: int32")
@@ -80,6 +93,4 @@ REGISTER_OP("StatelessMultinomial")
return Status::OK();
});
-#undef REGISTER_STATELESS_OP
-
} // namespace tensorflow
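The new StatelessRandomUniformInt shape function adds scalar checks for minval and maxval before delegating to StatelessShape, so the output shape is still derived from the `shape` input. An illustrative shape-inference sketch (not part of the change):

  ShapeInferenceTestOp op("StatelessRandomUniformInt");
  // minval (input 2) must be a scalar.
  INFER_ERROR("Shape must be rank 0 but is rank 1", op, "?;[2];[1];[]");
  // A rank-1 `shape` input of length 3 with unknown values yields an
  // unknown shape of rank 3.
  INFER_OK(op, "[3];[2];[];[]", "[?,?,?]");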
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index da1d2a6432..94d71a4113 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -223,6 +223,7 @@ REGISTER_OP("Substr")
.Input("len: T")
.Output("output: string")
.Attr("T: {int32, int64}")
+ .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle pos_shape = c->input(1);
ShapeHandle len_shape = c->input(2);
@@ -244,4 +245,9 @@ REGISTER_OP("Substr")
return shape_inference::BroadcastBinaryOpShapeFn(c);
});
+REGISTER_OP("UnicodeScript")
+ .Input("input: int32")
+ .Output("output: int32")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
} // namespace tensorflow
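UnicodeScript maps Unicode code points to script codes element-wise, hence the UnchangedShape shape function. A sketch of the assumed per-element semantics in terms of ICU4C's uscript API (the kernel itself is not shown in this diff):

  #include <unicode/uscript.h>

  UErrorCode status = U_ZERO_ERROR;
  // 0x0041 is 'A'; on success this presumably yields USCRIPT_LATIN.
  UScriptCode script = uscript_getScript(0x0041, &status);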
diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc b/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc
index f41b83ac34..affb68ebbb 100644
--- a/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc
+++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/platform/cloud/curl_http_request.h"
-#include "tensorflow/core/platform/cloud/retrying_utils.h"
namespace tensorflow {
@@ -25,21 +24,14 @@ namespace {
// The URL to retrieve metadata when running in Google Compute Engine.
constexpr char kGceMetadataBaseUrl[] = "http://metadata/computeMetadata/v1/";
-// The default initial delay between retries with exponential backoff.
-constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec
} // namespace
ComputeEngineMetadataClient::ComputeEngineMetadataClient(
- std::shared_ptr<HttpRequest::Factory> http_request_factory)
- : ComputeEngineMetadataClient(std::move(http_request_factory),
- kInitialRetryDelayUsec) {}
-
-ComputeEngineMetadataClient::ComputeEngineMetadataClient(
std::shared_ptr<HttpRequest::Factory> http_request_factory,
- int64 initial_retry_delay_usec)
+ const RetryConfig& config)
: http_request_factory_(std::move(http_request_factory)),
- initial_retry_delay_usec_(initial_retry_delay_usec) {}
+ retry_config_(config) {}
Status ComputeEngineMetadataClient::GetMetadata(
const string& path, std::vector<char>* response_buffer) {
@@ -52,8 +44,7 @@ Status ComputeEngineMetadataClient::GetMetadata(
return Status::OK();
};
- return RetryingUtils::CallWithRetries(get_metadata_from_gce,
- initial_retry_delay_usec_);
+ return RetryingUtils::CallWithRetries(get_metadata_from_gce, retry_config_);
}
} // namespace tensorflow
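The refactor replaces the bare initial-delay integer with a RetryConfig that carries the whole backoff policy. A minimal usage sketch; the fetch lambda and its callee are hypothetical:

  RetryConfig config(10000 /* init_delay_time_us = 10 ms */,
                     1000000 /* max_delay_time_us = 1 s */);
  std::function<Status()> fetch = []() { return FetchMetadataOnce(); };
  return RetryingUtils::CallWithRetries(fetch, config);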
diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
index 534ccf30b2..7f060327da 100644
--- a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
+++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/cloud/http_request.h"
+#include "tensorflow/core/platform/cloud/retrying_utils.h"
namespace tensorflow {
@@ -31,10 +32,11 @@ namespace tensorflow {
class ComputeEngineMetadataClient {
public:
explicit ComputeEngineMetadataClient(
- std::shared_ptr<HttpRequest::Factory> http_request_factory);
- ComputeEngineMetadataClient(
std::shared_ptr<HttpRequest::Factory> http_request_factory,
- int64 initial_retry_delay_usec);
+ const RetryConfig& config = RetryConfig(
+ 10000, /* init_delay_time_us = 10 ms */
+ 1000000 /* max_delay_time_us = 1 s */
+ ));
virtual ~ComputeEngineMetadataClient() {}
/// \brief Get the metadata value for a given attribute of the metadata
@@ -54,7 +56,7 @@ class ComputeEngineMetadataClient {
private:
std::shared_ptr<HttpRequest::Factory> http_request_factory_;
- const int64 initial_retry_delay_usec_;
+ const RetryConfig retry_config_;
TF_DISALLOW_COPY_AND_ASSIGN(ComputeEngineMetadataClient);
};
diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc b/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc
index 4c41ccaa0e..e891b4a5e9 100644
--- a/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc
+++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc
@@ -30,7 +30,8 @@ TEST(ComputeEngineMetadataClientTest, GetMetadata) {
std::shared_ptr<HttpRequest::Factory> http_factory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- ComputeEngineMetadataClient client(http_factory, 0);
+ ComputeEngineMetadataClient client(http_factory,
+ RetryConfig(0 /* init_delay_time_us */));
std::vector<char> result;
TF_EXPECT_OK(
@@ -56,7 +57,8 @@ TEST(ComputeEngineMetadataClientTest, RetryOnFailure) {
std::shared_ptr<HttpRequest::Factory> http_factory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- ComputeEngineMetadataClient client(http_factory, 0);
+ ComputeEngineMetadataClient client(http_factory,
+ RetryConfig(0 /* init_delay_time_us */));
std::vector<char> result;
TF_EXPECT_OK(
diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc b/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc
index f7477eca23..476e4f9c1f 100644
--- a/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc
+++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc
@@ -34,8 +34,8 @@ TEST_F(ComputeEngineZoneProviderTest, GetZone) {
auto httpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadata_client =
- std::make_shared<ComputeEngineMetadataClient>(httpRequestFactory, 0);
+ auto metadata_client = std::make_shared<ComputeEngineMetadataClient>(
+ httpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
ComputeEngineZoneProvider provider(metadata_client);
@@ -55,8 +55,8 @@ TEST_F(ComputeEngineZoneProviderTest, InvalidZoneString) {
auto httpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadata_client =
- std::make_shared<ComputeEngineMetadataClient>(httpRequestFactory, 0);
+ auto metadata_client = std::make_shared<ComputeEngineMetadataClient>(
+ httpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
ComputeEngineZoneProvider provider(metadata_client);
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 83ea8539ed..c61b68aeeb 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -333,14 +333,14 @@ class GcsWritableFile : public WritableFile {
GcsFileSystem* filesystem,
GcsFileSystem::TimeoutConfig* timeouts,
std::function<void()> file_cache_erase,
- int64 initial_retry_delay_usec)
+ RetryConfig retry_config)
: bucket_(bucket),
object_(object),
filesystem_(filesystem),
timeouts_(timeouts),
file_cache_erase_(std::move(file_cache_erase)),
sync_needed_(true),
- initial_retry_delay_usec_(initial_retry_delay_usec) {
+ retry_config_(retry_config) {
// TODO: to make it safer, outfile_ should be constructed from an FD
if (GetTmpFilename(&tmp_content_filename_).ok()) {
outfile_.open(tmp_content_filename_,
@@ -357,14 +357,14 @@ class GcsWritableFile : public WritableFile {
GcsFileSystem* filesystem, const string& tmp_content_filename,
GcsFileSystem::TimeoutConfig* timeouts,
std::function<void()> file_cache_erase,
- int64 initial_retry_delay_usec)
+ RetryConfig retry_config)
: bucket_(bucket),
object_(object),
filesystem_(filesystem),
timeouts_(timeouts),
file_cache_erase_(std::move(file_cache_erase)),
sync_needed_(true),
- initial_retry_delay_usec_(initial_retry_delay_usec) {
+ retry_config_(retry_config) {
tmp_content_filename_ = tmp_content_filename;
outfile_.open(tmp_content_filename_,
std::ofstream::binary | std::ofstream::app);
@@ -441,7 +441,7 @@ class GcsWritableFile : public WritableFile {
first_attempt = false;
return UploadToSession(session_uri, already_uploaded);
},
- initial_retry_delay_usec_);
+ retry_config_);
if (upload_status.code() == errors::Code::NOT_FOUND) {
// GCS docs recommend retrying the whole upload. We're relying on the
// RetryingFileSystem to retry the Sync() call.
@@ -586,7 +586,7 @@ class GcsWritableFile : public WritableFile {
GcsFileSystem::TimeoutConfig* timeouts_;
std::function<void()> file_cache_erase_;
bool sync_needed_; // whether there is buffered data that needs to be synced
- int64 initial_retry_delay_usec_;
+ RetryConfig retry_config_;
};
class GcsReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
@@ -791,7 +791,7 @@ GcsFileSystem::GcsFileSystem(
std::unique_ptr<ZoneProvider> zone_provider, size_t block_size,
size_t max_bytes, uint64 max_staleness, uint64 stat_cache_max_age,
size_t stat_cache_max_entries, uint64 matching_paths_cache_max_age,
- size_t matching_paths_cache_max_entries, int64 initial_retry_delay_usec,
+ size_t matching_paths_cache_max_entries, RetryConfig retry_config,
TimeoutConfig timeouts, const std::unordered_set<string>& allowed_locations,
std::pair<const string, const string>* additional_header)
: auth_provider_(std::move(auth_provider)),
@@ -806,7 +806,7 @@ GcsFileSystem::GcsFileSystem(
kCacheNeverExpire, kBucketLocationCacheMaxEntries)),
allowed_locations_(allowed_locations),
timeouts_(timeouts),
- initial_retry_delay_usec_(initial_retry_delay_usec),
+ retry_config_(retry_config),
additional_header_(additional_header) {}
Status GcsFileSystem::NewRandomAccessFile(
@@ -941,7 +941,7 @@ Status GcsFileSystem::NewWritableFile(const string& fname,
TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
result->reset(new GcsWritableFile(bucket, object, this, &timeouts_,
[this, fname]() { ClearFileCaches(fname); },
- initial_retry_delay_usec_));
+ retry_config_));
return Status::OK();
}
@@ -981,7 +981,7 @@ Status GcsFileSystem::NewAppendableFile(const string& fname,
TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
result->reset(new GcsWritableFile(
bucket, object, this, old_content_filename, &timeouts_,
- [this, fname]() { ClearFileCaches(fname); }, initial_retry_delay_usec_));
+ [this, fname]() { ClearFileCaches(fname); }, retry_config_));
return Status::OK();
}
@@ -1534,7 +1534,7 @@ Status GcsFileSystem::RenameObject(const string& src, const string& target) {
// on the server side, we can't just retry the whole RenameFile operation
// because the source object is already gone.
return RetryingUtils::DeleteWithRetries(
- [this, &src]() { return DeleteFile(src); }, initial_retry_delay_usec_);
+ [this, &src]() { return DeleteFile(src); }, retry_config_);
}
Status GcsFileSystem::IsDirectory(const string& fname) {
@@ -1590,8 +1590,7 @@ Status GcsFileSystem::DeleteRecursively(const string& dirname,
// and therefore RetryingFileSystem won't pay attention to the failures,
// we need to make sure these failures are properly retried.
const auto& delete_file_status = RetryingUtils::DeleteWithRetries(
- [this, &full_path]() { return DeleteFile(full_path); },
- initial_retry_delay_usec_);
+ [this, &full_path]() { return DeleteFile(full_path); }, retry_config_);
if (!delete_file_status.ok()) {
if (IsDirectory(full_path).ok()) {
// The object is a directory marker.
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index 71db707687..d0840a3046 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -93,7 +93,7 @@ class GcsFileSystem : public FileSystem {
uint64 stat_cache_max_age, size_t stat_cache_max_entries,
uint64 matching_paths_cache_max_age,
size_t matching_paths_cache_max_entries,
- int64 initial_retry_delay_usec, TimeoutConfig timeouts,
+ RetryConfig retry_config, TimeoutConfig timeouts,
const std::unordered_set<string>& allowed_locations,
std::pair<const string, const string>* additional_header);
@@ -332,7 +332,7 @@ class GcsFileSystem : public FileSystem {
GcsStatsInterface* stats_ = nullptr; // Not owned.
- /// The initial delay for exponential backoffs when retrying failed calls.
- const int64 initial_retry_delay_usec_ = 1000000L;
+ /// Retry configuration (exponential backoff) for failed calls.
+ RetryConfig retry_config_;
// Additional header material to be transmitted with all GCS requests
std::unique_ptr<std::pair<const string, const string>> additional_header_;
@@ -344,7 +344,8 @@ class GcsFileSystem : public FileSystem {
class RetryingGcsFileSystem : public RetryingFileSystem<GcsFileSystem> {
public:
RetryingGcsFileSystem()
- : RetryingFileSystem(std::unique_ptr<GcsFileSystem>(new GcsFileSystem)) {}
+ : RetryingFileSystem(std::unique_ptr<GcsFileSystem>(new GcsFileSystem),
+ RetryConfig(100000 /* init_delay_time_us */)) {}
};
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index 14376ad339..702802b185 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -24,6 +24,8 @@ namespace tensorflow {
namespace {
static GcsFileSystem::TimeoutConfig kTestTimeoutConfig(5, 1, 10, 20, 30);
+static RetryConfig kTestRetryConfig(0 /* init_delay_time_us */);
+
// Default (empty) constraint config
static std::unordered_set<string>* kAllowedLocationsDefault =
new std::unordered_set<string>();
@@ -62,16 +64,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) {
"Range: 6-11\n"
"Timeouts: 5 1 20\n",
"6789")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -108,9 +110,9 @@ TEST(GcsFileSystemTest,
0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- *kAllowedLocationsAuto, nullptr /* gcs additional header */);
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsAuto,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -150,9 +152,9 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) {
0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- *kAllowedLocationsAuto, nullptr /* gcs additional header */);
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsAuto,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
@@ -191,9 +193,9 @@ TEST(GcsFileSystemTest,
0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- *kAllowedLocationsAuto, nullptr /* gcs additional header */);
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsAuto,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
EXPECT_EQ(tensorflow::errors::FailedPrecondition(
@@ -216,16 +218,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache_DifferentN) {
"Range: 3-12\n"
"Timeouts: 5 1 20\n",
"3456789")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -283,7 +285,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */,
18 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -372,7 +374,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */,
18 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -414,17 +416,17 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) {
"Range: 8-15\n"
"Timeouts: 5 1 20\n",
"89abcdef")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */,
- 16 /* max bytes */, 3600 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 8 /* block size */, 16 /* max bytes */,
+ 3600 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
char scratch[100];
StringPiece result;
// There should only be two HTTP requests issued to GCS even though we iterate
@@ -492,7 +494,7 @@ TEST(GcsFileSystemTest,
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */,
18 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -513,17 +515,17 @@ TEST(GcsFileSystemTest,
TEST(GcsFileSystemTest, NewRandomAccessFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
- 0 /* read ahead bytes */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* read ahead bytes */, 0 /* max bytes */,
+ 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -547,16 +549,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_InconsistentRead) {
"012")});
// Set stat_cache_max_age to 1000s so that StatCache could work.
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 1e3 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 1e3 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Stat the file first so that the file stats are cached.
FileStatistics stat;
@@ -621,7 +623,7 @@ TEST(GcsFileSystemTest, NewWritableFile) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */,
8 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -703,16 +705,16 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceeds) {
"Timeouts: 5 1 30\n"
"Put body: t2\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -773,17 +775,17 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) {
"Range: 0-7\n"
"Timeouts: 5 1 20\n",
"01234567")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */,
- 8 /* max bytes */, 3600 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 8 /* block size */, 8 /* max bytes */,
+ 3600 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Pull the file's first block into the cache. This will trigger the first
// HTTP request to GCS.
std::unique_ptr<RandomAccessFile> rfile;
@@ -867,9 +869,9 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 2 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ 0 /* matching paths cache max entries */,
+ RetryConfig(2 /* init_delay_time_us */), kTestTimeoutConfig,
+ *kAllowedLocationsDefault, nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -918,16 +920,16 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) {
"Timeouts: 5 1 30\n"
"Put body: content1,content2\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -948,16 +950,16 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) {
TEST(GcsFileSystemTest, NewWritableFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1013,7 +1015,7 @@ TEST(GcsFileSystemTest, NewAppendableFile) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 32 /* block size */,
32 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1041,16 +1043,16 @@ TEST(GcsFileSystemTest, NewAppendableFile) {
TEST(GcsFileSystemTest, NewAppendableFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1075,16 +1077,16 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) {
"Range: 0-",
content.size() - 1, "\n", "Timeouts: 5 1 20\n"),
content)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<ReadOnlyMemoryRegion> region;
TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile(
@@ -1096,16 +1098,16 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) {
TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<ReadOnlyMemoryRegion> region;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1120,16 +1122,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsObject) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket/path/file1.txt"));
}
@@ -1150,16 +1152,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsFolder) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/subfolder/\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket/path/subfolder"));
}
@@ -1176,16 +1178,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsBucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"size\": \"100\"}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket1"));
TF_EXPECT_OK(fs.FileExists("gs://bucket1/"));
@@ -1206,16 +1208,16 @@ TEST(GcsFileSystemTest, FileExists_NotAsObjectOrFolder) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"items\": []}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::NOT_FOUND,
fs.FileExists("gs://bucket/path/file1.txt").code());
@@ -1233,16 +1235,16 @@ TEST(GcsFileSystemTest, FileExists_NotAsBucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
fs.FileExists("gs://bucket2/").code());
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1279,7 +1281,7 @@ TEST(GcsFileSystemTest, FileExists_StatCache) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1306,7 +1308,7 @@ TEST(GcsFileSystemTest, FileExists_DirectoryMark) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1322,16 +1324,16 @@ TEST(GcsFileSystemTest, GetChildren_NoItems) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1350,16 +1352,16 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/file3.txt\" }],"
"\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1379,16 +1381,16 @@ TEST(GcsFileSystemTest, GetChildren_SelfDirectoryMarker) {
" { \"name\": \"path/\" },"
" { \"name\": \"path/file3.txt\" }],"
"\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1407,16 +1409,16 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/file3.txt\" }],"
"\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children));
@@ -1432,16 +1434,16 @@ TEST(GcsFileSystemTest, GetChildren_Root) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket-a-b-c", &children));
@@ -1457,16 +1459,16 @@ TEST(GcsFileSystemTest, GetChildren_Empty) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1498,16 +1500,16 @@ TEST(GcsFileSystemTest, GetChildren_Pagination) {
" { \"name\": \"path/file4.txt\" },"
" { \"name\": \"path/file5.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children));
@@ -1525,16 +1527,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_NoWildcard) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/subpath/file2.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(
@@ -1553,16 +1555,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_BucketAndWildcard) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/*/*", &result));
@@ -1582,16 +1584,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_Matches) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file2.txt", &result));
@@ -1608,16 +1610,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) {
"{\"items\": [ "
" { \"name\": \"path/\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*", &result));
@@ -1634,16 +1636,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file3.txt", &result));
@@ -1652,16 +1654,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) {
TEST(GcsFileSystemTest, GetMatchingPaths_OnlyWildcard) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1686,16 +1688,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 3600 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 3600 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Repeated calls to fs.GetMatchingPaths on these patterns should not lead to
// any additional HTTP requests to GCS.
@@ -1729,16 +1731,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/subpath/file2.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 3600 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 3600 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// This loop should trigger the first HTTP request to GCS.
for (int i = 0; i < 10; i++) {
@@ -1800,7 +1802,7 @@ TEST(GcsFileSystemTest, DeleteFile) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */,
16 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1821,16 +1823,16 @@ TEST(GcsFileSystemTest, DeleteFile) {
TEST(GcsFileSystemTest, DeleteFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
fs.DeleteFile("gs://bucket/").code());
@@ -1871,7 +1873,7 @@ TEST(GcsFileSystemTest, DeleteFile_StatCacheRemoved) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */,
16 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1894,16 +1896,16 @@ TEST(GcsFileSystemTest, DeleteDir_Empty) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/"));
}
@@ -1923,16 +1925,16 @@ TEST(GcsFileSystemTest, DeleteDir_OnlyDirMarkerLeft) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/"));
}
@@ -1943,16 +1945,16 @@ TEST(GcsFileSystemTest, DeleteDir_BucketOnly) {
"name%2CnextPageToken&maxResults=2\nAuth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket"));
}
@@ -1965,16 +1967,16 @@ TEST(GcsFileSystemTest, DeleteDir_NonEmpty) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/file1.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::FAILED_PRECONDITION,
fs.DeleteDir("gs://bucket/path/").code());
@@ -1988,16 +1990,16 @@ TEST(GcsFileSystemTest, GetFileSize) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
uint64 size;
TF_EXPECT_OK(fs.GetFileSize("gs://bucket/file.txt", &size));
@@ -2006,16 +2008,16 @@ TEST(GcsFileSystemTest, GetFileSize) {
TEST(GcsFileSystemTest, GetFileSize_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
uint64 size;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -2092,16 +2094,16 @@ TEST(GcsFileSystemTest, RenameFile_Folder) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.RenameFile("gs://bucket/path1", "gs://bucket/path2/"));
}
@@ -2191,7 +2193,7 @@ TEST(GcsFileSystemTest, RenameFile_Object) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */,
64 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
// Do an initial read of the source and destination files to load their
@@ -2272,7 +2274,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
// Do an initial stat of the destination file to load its contents into the
@@ -2332,16 +2334,16 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(
fs.RenameFile("gs://bucket/path/src.txt", "gs://bucket/path/dst.txt"));
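// [Editor's note; not part of the diff.] The expectations for this test
// (partly above this hunk) stage a copy followed by a delete whose retried
// attempt returns 404; RenameFile is still expected to succeed, since a
// NotFound on a retried delete suggests the earlier attempt already
// removed the source object.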
@@ -2374,16 +2376,16 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) {
"Post: yes\n"
"Timeouts: 5 1 10\n",
"{\"done\": false}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(
errors::Code::UNIMPLEMENTED,
@@ -2399,16 +2401,16 @@ TEST(GcsFileSystemTest, Stat_Object) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/file.txt", &stat));
@@ -2433,16 +2435,16 @@ TEST(GcsFileSystemTest, Stat_Folder) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"subfolder/\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/subfolder", &stat));
@@ -2466,16 +2468,16 @@ TEST(GcsFileSystemTest, Stat_ObjectOrFolderNotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/path", &stat).code());
@@ -2487,16 +2489,16 @@ TEST(GcsFileSystemTest, Stat_Bucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/", &stat));
@@ -2511,16 +2513,16 @@ TEST(GcsFileSystemTest, Stat_BucketNotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/", &stat).code());
@@ -2556,7 +2558,7 @@ TEST(GcsFileSystemTest, Stat_Cache) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -2598,7 +2600,7 @@ TEST(GcsFileSystemTest, Stat_Cache_Flush) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
// There should be a single HTTP request to GCS for fs.Stat in this loop.
@@ -2628,16 +2630,16 @@ TEST(GcsFileSystemTest, Stat_FilenameEndingWithSlash) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"5\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/dir/", &stat));
@@ -2660,16 +2662,16 @@ TEST(GcsFileSystemTest, IsDirectory_NotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::NOT_FOUND,
fs.IsDirectory("gs://bucket/file.txt").code());
@@ -2691,16 +2693,16 @@ TEST(GcsFileSystemTest, IsDirectory_NotDirectoryButObject) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::FAILED_PRECONDITION,
fs.IsDirectory("gs://bucket/file.txt").code());
@@ -2722,16 +2724,16 @@ TEST(GcsFileSystemTest, IsDirectory_Yes) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"items\": [{\"name\": \"subfolder/\"}]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder"));
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder/"));
@@ -2749,16 +2751,16 @@ TEST(GcsFileSystemTest, IsDirectory_Bucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.IsDirectory("gs://bucket"));
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/"));
@@ -2770,16 +2772,16 @@ TEST(GcsFileSystemTest, IsDirectory_BucketNotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::NOT_FOUND, fs.IsDirectory("gs://bucket/").code());
}
@@ -2812,16 +2814,16 @@ TEST(GcsFileSystemTest, CreateDir_Folder) {
"Timeouts: 5 1 30\n"
"Put body: \n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath"));
TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath/"));
@@ -2839,16 +2841,16 @@ TEST(GcsFileSystemTest, CreateDir_Bucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.CreateDir("gs://bucket/"));
TF_EXPECT_OK(fs.CreateDir("gs://bucket"));
@@ -2911,16 +2913,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files,
@@ -3004,16 +3006,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) {
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files,
@@ -3039,16 +3041,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
EXPECT_EQ(error::Code::NOT_FOUND,
@@ -3130,7 +3132,7 @@ TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
add_header /* gcs additional header */);
@@ -3199,16 +3201,16 @@ TEST(GcsFileSystemTest, CreateHttpRequest) {
"Auth Token: fake_token\n"
"Header Hello: world\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<HttpRequest> request;
TF_EXPECT_OK(fs.CreateHttpRequest(&request));
@@ -3262,16 +3264,16 @@ TEST(GcsFileSystemTest, Stat_StatsRecording) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TestGcsStats stats;
fs.SetStats(&stats);
@@ -3289,16 +3291,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_StatsRecording) {
"Range: 0-5\n"
"Timeouts: 5 1 20\n",
"012345")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TestGcsStats stats;
fs.SetStats(&stats);
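
Every GcsFileSystem constructor in the hunks above swaps the bare 0 /* initial retry delay */ argument for a shared kTestRetryConfig constant. Its definition sits near the top of the test file, outside this excerpt; a minimal sketch of what it is assumed to look like, using the RetryConfig type introduced in retrying_utils.h further below:

    // Hypothetical reconstruction -- the real definition is not part of this
    // excerpt. A zero initial delay keeps the retry path exercised while
    // letting retried test cases run without sleeping.
    static const RetryConfig kTestRetryConfig(0 /* init_delay_time_us */);
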
diff --git a/tensorflow/core/platform/cloud/google_auth_provider_test.cc b/tensorflow/core/platform/cloud/google_auth_provider_test.cc
index 07b88a880f..ec31c5ee8c 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider_test.cc
+++ b/tensorflow/core/platform/cloud/google_auth_provider_test.cc
@@ -93,8 +93,8 @@ TEST_F(GoogleAuthProviderTest, EnvironmentVariable_Caching) {
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
oauth_client->return_token = "fake-token";
@@ -129,8 +129,8 @@ TEST_F(GoogleAuthProviderTest, GCloudRefreshToken) {
FakeEnv env;
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
@@ -178,8 +178,8 @@ TEST_F(GoogleAuthProviderTest, RunningOnGCE) {
FakeEnv env;
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
@@ -206,8 +206,8 @@ TEST_F(GoogleAuthProviderTest, OverrideForTesting) {
FakeEnv env;
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&empty_requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
@@ -228,8 +228,8 @@ TEST_F(GoogleAuthProviderTest, NothingAvailable) {
FakeEnv env;
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
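
ComputeEngineMetadataClient migrates the same way: the trailing int64 delay becomes a RetryConfig. A short sketch of the two call-site flavors, assuming only the constructor signature visible in the hunks above:

    // Tests: zero initial delay, so failed metadata lookups retry instantly.
    auto test_client = std::make_shared<ComputeEngineMetadataClient>(
        fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));

    // Production-style: a default-constructed config gives the documented
    // defaults (100 ms initial delay, 32 s cap, 10 retries).
    auto client = std::make_shared<ComputeEngineMetadataClient>(
        fakeHttpRequestFactory, RetryConfig());
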
diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h
index 941ab7ad65..5ce6670dc7 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system.h
+++ b/tensorflow/core/platform/cloud/retrying_file_system.h
@@ -34,9 +34,9 @@ template <typename Underlying>
class RetryingFileSystem : public FileSystem {
public:
RetryingFileSystem(std::unique_ptr<Underlying> base_file_system,
- int64 delay_microseconds = 1000000)
+ const RetryConfig& retry_config)
: base_file_system_(std::move(base_file_system)),
- initial_delay_microseconds_(delay_microseconds) {}
+ retry_config_(retry_config) {}
Status NewRandomAccessFile(
const string& filename,
@@ -55,7 +55,7 @@ class RetryingFileSystem : public FileSystem {
Status FileExists(const string& fname) override {
return RetryingUtils::CallWithRetries(
[this, &fname]() { return base_file_system_->FileExists(fname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status GetChildren(const string& dir, std::vector<string>* result) override {
@@ -63,7 +63,7 @@ class RetryingFileSystem : public FileSystem {
[this, &dir, result]() {
return base_file_system_->GetChildren(dir, result);
},
- initial_delay_microseconds_);
+ retry_config_);
}
Status GetMatchingPaths(const string& pattern,
@@ -72,31 +72,31 @@ class RetryingFileSystem : public FileSystem {
[this, &pattern, result]() {
return base_file_system_->GetMatchingPaths(pattern, result);
},
- initial_delay_microseconds_);
+ retry_config_);
}
Status Stat(const string& fname, FileStatistics* stat) override {
return RetryingUtils::CallWithRetries(
[this, &fname, stat]() { return base_file_system_->Stat(fname, stat); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status DeleteFile(const string& fname) override {
return RetryingUtils::DeleteWithRetries(
[this, &fname]() { return base_file_system_->DeleteFile(fname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status CreateDir(const string& dirname) override {
return RetryingUtils::CallWithRetries(
[this, &dirname]() { return base_file_system_->CreateDir(dirname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status DeleteDir(const string& dirname) override {
return RetryingUtils::DeleteWithRetries(
[this, &dirname]() { return base_file_system_->DeleteDir(dirname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status GetFileSize(const string& fname, uint64* file_size) override {
@@ -104,7 +104,7 @@ class RetryingFileSystem : public FileSystem {
[this, &fname, file_size]() {
return base_file_system_->GetFileSize(fname, file_size);
},
- initial_delay_microseconds_);
+ retry_config_);
}
Status RenameFile(const string& src, const string& target) override {
@@ -112,13 +112,13 @@ class RetryingFileSystem : public FileSystem {
[this, &src, &target]() {
return base_file_system_->RenameFile(src, target);
},
- initial_delay_microseconds_);
+ retry_config_);
}
Status IsDirectory(const string& dirname) override {
return RetryingUtils::CallWithRetries(
[this, &dirname]() { return base_file_system_->IsDirectory(dirname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status DeleteRecursively(const string& dirname, int64* undeleted_files,
@@ -128,7 +128,7 @@ class RetryingFileSystem : public FileSystem {
return base_file_system_->DeleteRecursively(dirname, undeleted_files,
undeleted_dirs);
},
- initial_delay_microseconds_);
+ retry_config_);
}
void FlushCaches() override { base_file_system_->FlushCaches(); }
@@ -137,7 +137,7 @@ class RetryingFileSystem : public FileSystem {
private:
std::unique_ptr<Underlying> base_file_system_;
- const int64 initial_delay_microseconds_;
+ const RetryConfig retry_config_;
TF_DISALLOW_COPY_AND_ASSIGN(RetryingFileSystem);
};
@@ -147,9 +147,8 @@ namespace retrying_internals {
class RetryingRandomAccessFile : public RandomAccessFile {
public:
RetryingRandomAccessFile(std::unique_ptr<RandomAccessFile> base_file,
- int64 delay_microseconds)
- : base_file_(std::move(base_file)),
- initial_delay_microseconds_(delay_microseconds) {}
+ const RetryConfig& retry_config)
+ : base_file_(std::move(base_file)), retry_config_(retry_config) {}
Status Read(uint64 offset, size_t n, StringPiece* result,
char* scratch) const override {
@@ -157,20 +156,19 @@ class RetryingRandomAccessFile : public RandomAccessFile {
[this, offset, n, result, scratch]() {
return base_file_->Read(offset, n, result, scratch);
},
- initial_delay_microseconds_);
+ retry_config_);
}
private:
std::unique_ptr<RandomAccessFile> base_file_;
- const int64 initial_delay_microseconds_;
+ const RetryConfig retry_config_;
};
class RetryingWritableFile : public WritableFile {
public:
RetryingWritableFile(std::unique_ptr<WritableFile> base_file,
- int64 delay_microseconds)
- : base_file_(std::move(base_file)),
- initial_delay_microseconds_(delay_microseconds) {}
+ const RetryConfig& retry_config)
+ : base_file_(std::move(base_file)), retry_config_(retry_config) {}
~RetryingWritableFile() override {
// Makes sure the retrying version of Close() is called in the destructor.
@@ -179,25 +177,24 @@ class RetryingWritableFile : public WritableFile {
Status Append(StringPiece data) override {
return RetryingUtils::CallWithRetries(
- [this, &data]() { return base_file_->Append(data); },
- initial_delay_microseconds_);
+ [this, &data]() { return base_file_->Append(data); }, retry_config_);
}
Status Close() override {
return RetryingUtils::CallWithRetries(
- [this]() { return base_file_->Close(); }, initial_delay_microseconds_);
+ [this]() { return base_file_->Close(); }, retry_config_);
}
Status Flush() override {
return RetryingUtils::CallWithRetries(
- [this]() { return base_file_->Flush(); }, initial_delay_microseconds_);
+ [this]() { return base_file_->Flush(); }, retry_config_);
}
Status Sync() override {
return RetryingUtils::CallWithRetries(
- [this]() { return base_file_->Sync(); }, initial_delay_microseconds_);
+ [this]() { return base_file_->Sync(); }, retry_config_);
}
private:
std::unique_ptr<WritableFile> base_file_;
- const int64 initial_delay_microseconds_;
+ const RetryConfig retry_config_;
};
} // namespace retrying_internals
@@ -210,9 +207,9 @@ Status RetryingFileSystem<Underlying>::NewRandomAccessFile(
[this, &filename, &base_file]() {
return base_file_system_->NewRandomAccessFile(filename, &base_file);
},
- initial_delay_microseconds_));
+ retry_config_));
result->reset(new retrying_internals::RetryingRandomAccessFile(
- std::move(base_file), initial_delay_microseconds_));
+ std::move(base_file), retry_config_));
return Status::OK();
}
@@ -224,9 +221,9 @@ Status RetryingFileSystem<Underlying>::NewWritableFile(
[this, &filename, &base_file]() {
return base_file_system_->NewWritableFile(filename, &base_file);
},
- initial_delay_microseconds_));
+ retry_config_));
result->reset(new retrying_internals::RetryingWritableFile(
- std::move(base_file), initial_delay_microseconds_));
+ std::move(base_file), retry_config_));
return Status::OK();
}
@@ -238,9 +235,9 @@ Status RetryingFileSystem<Underlying>::NewAppendableFile(
[this, &filename, &base_file]() {
return base_file_system_->NewAppendableFile(filename, &base_file);
},
- initial_delay_microseconds_));
+ retry_config_));
result->reset(new retrying_internals::RetryingWritableFile(
- std::move(base_file), initial_delay_microseconds_));
+ std::move(base_file), retry_config_));
return Status::OK();
}
@@ -252,7 +249,7 @@ Status RetryingFileSystem<Underlying>::NewReadOnlyMemoryRegionFromFile(
return base_file_system_->NewReadOnlyMemoryRegionFromFile(filename,
result);
},
- initial_delay_microseconds_);
+ retry_config_);
}
} // namespace tensorflow
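
With these changes the header threads a single immutable RetryConfig from the wrapper down into every wrapped file and operation. A minimal usage sketch, assuming an Underlying type that implements the FileSystem interface; MakeBaseFileSystem() is a hypothetical factory, and GcsFileSystem is wrapped in this style elsewhere in the tree:

    // Retries transient failures (e.g. error::UNAVAILABLE, per IsRetriable in
    // retrying_utils.cc) with exponential backoff starting at 100 ms.
    std::unique_ptr<GcsFileSystem> base = MakeBaseFileSystem();
    RetryingFileSystem<GcsFileSystem> fs(
        std::move(base), RetryConfig(100 * 1000 /* init_delay_time_us */));
    Status s = fs.FileExists("gs://bucket/object");
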
diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
index 5910fef1d2..868eea096c 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
@@ -184,7 +184,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_ImmediateSuccess) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->random_access_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped random access file.
std::unique_ptr<RandomAccessFile> random_access_file;
@@ -211,7 +212,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_SuccessWith3rdTry) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->random_access_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped random access file.
std::unique_ptr<RandomAccessFile> random_access_file;
@@ -235,7 +237,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_AllRetriesFailed) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->random_access_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped random access file.
std::unique_ptr<RandomAccessFile> random_access_file;
@@ -265,7 +268,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_NoRetriesForSomeErrors) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->random_access_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped random access file.
std::unique_ptr<RandomAccessFile> random_access_file;
@@ -291,7 +295,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_ImmediateSuccess) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped writable file.
std::unique_ptr<WritableFile> writable_file;
@@ -317,7 +322,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_SuccessWith3rdTry) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped writable file.
std::unique_ptr<WritableFile> writable_file;
@@ -343,7 +349,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_SuccessWith3rdTry_ViaDestructor) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped writable file.
std::unique_ptr<WritableFile> writable_file;
@@ -368,7 +375,8 @@ TEST(RetryingFileSystemTest, NewAppendableFile_SuccessWith3rdTry) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped appendable file.
std::unique_ptr<WritableFile> writable_file;
@@ -391,7 +399,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_AllRetriesFailed) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped writable file.
std::unique_ptr<WritableFile> writable_file;
@@ -412,7 +421,8 @@ TEST(RetryingFileSystemTest,
std::make_tuple("NewReadOnlyMemoryRegionFromFile", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::unique_ptr<ReadOnlyMemoryRegion> result;
TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile("filename.txt", &result));
@@ -423,7 +433,8 @@ TEST(RetryingFileSystemTest, NewReadOnlyMemoryRegionFromFile_AllRetriesFailed) {
CreateRetriableErrors("NewReadOnlyMemoryRegionFromFile", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::unique_ptr<ReadOnlyMemoryRegion> result;
const auto& status =
@@ -440,7 +451,8 @@ TEST(RetryingFileSystemTest, GetChildren_SuccessWith2ndTry) {
std::make_tuple("GetChildren", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.GetChildren("gs://path", &result));
@@ -450,7 +462,8 @@ TEST(RetryingFileSystemTest, GetChildren_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("GetChildren", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.GetChildren("gs://path", &result);
@@ -466,7 +479,8 @@ TEST(RetryingFileSystemTest, GetMatchingPaths_SuccessWith2ndTry) {
std::make_tuple("GetMatchingPaths", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://path/dir", &result));
@@ -477,7 +491,8 @@ TEST(RetryingFileSystemTest, GetMatchingPaths_AllRetriesFailed) {
CreateRetriableErrors("GetMatchingPaths", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.GetMatchingPaths("gs://path/dir", &result);
@@ -492,7 +507,8 @@ TEST(RetryingFileSystemTest, DeleteFile_SuccessWith2ndTry) {
std::make_tuple("DeleteFile", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.DeleteFile("gs://path/file.txt"));
@@ -502,7 +518,8 @@ TEST(RetryingFileSystemTest, DeleteFile_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("DeleteFile", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.DeleteFile("gs://path/file.txt");
@@ -517,7 +534,8 @@ TEST(RetryingFileSystemTest, CreateDir_SuccessWith2ndTry) {
std::make_tuple("CreateDir", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.CreateDir("gs://path/newdir"));
@@ -527,7 +545,8 @@ TEST(RetryingFileSystemTest, CreateDir_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("CreateDir", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.CreateDir("gs://path/newdir");
@@ -542,7 +561,8 @@ TEST(RetryingFileSystemTest, DeleteDir_SuccessWith2ndTry) {
std::make_tuple("DeleteDir", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.DeleteDir("gs://path/dir"));
@@ -552,7 +572,8 @@ TEST(RetryingFileSystemTest, DeleteDir_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("DeleteDir", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.DeleteDir("gs://path/dir");
@@ -568,7 +589,8 @@ TEST(RetryingFileSystemTest, GetFileSize_SuccessWith2ndTry) {
std::make_tuple("GetFileSize", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
uint64 size;
TF_EXPECT_OK(fs.GetFileSize("gs://path/file.txt", &size));
@@ -578,7 +600,8 @@ TEST(RetryingFileSystemTest, GetFileSize_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("GetFileSize", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
uint64 size;
const auto& status = fs.GetFileSize("gs://path/file.txt", &size);
@@ -593,7 +616,8 @@ TEST(RetryingFileSystemTest, RenameFile_SuccessWith2ndTry) {
std::make_tuple("RenameFile", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
TF_EXPECT_OK(fs.RenameFile("old_name", "new_name"));
}
@@ -602,7 +626,8 @@ TEST(RetryingFileSystemTest, RenameFile_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("RenameFile", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
const auto& status = fs.RenameFile("old_name", "new_name");
EXPECT_TRUE(
@@ -616,7 +641,8 @@ TEST(RetryingFileSystemTest, Stat_SuccessWith2ndTry) {
std::make_tuple("Stat", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("file_name", &stat));
@@ -626,7 +652,8 @@ TEST(RetryingFileSystemTest, Stat_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("Stat", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
FileStatistics stat;
const auto& status = fs.Stat("file_name", &stat);
@@ -639,7 +666,8 @@ TEST(RetryingFileSystemTest, FileExists_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("FileExists", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
const auto& status = fs.FileExists("file_name");
EXPECT_TRUE(
@@ -653,7 +681,8 @@ TEST(RetryingFileSystemTest, FileExists_SuccessWith2ndTry) {
std::make_tuple("FileExists", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
TF_EXPECT_OK(fs.FileExists("gs://path/dir"));
}
@@ -665,7 +694,8 @@ TEST(RetryingFileSystemTest, IsDirectory_SuccessWith2ndTry) {
std::make_tuple("IsDirectory", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
TF_EXPECT_OK(fs.IsDirectory("gs://path/dir"));
}
@@ -674,7 +704,8 @@ TEST(RetryingFileSystemTest, IsDirectory_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("IsDirectory", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
const auto& status = fs.IsDirectory("gs://path/dir");
EXPECT_TRUE(
@@ -689,7 +720,8 @@ TEST(RetryingFileSystemTest, DeleteRecursively_SuccessWith2ndTry) {
std::make_tuple("DeleteRecursively", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
int64 undeleted_files, undeleted_dirs;
TF_EXPECT_OK(
@@ -701,7 +733,8 @@ TEST(RetryingFileSystemTest, DeleteRecursively_AllRetriesFailed) {
CreateRetriableErrors("DeleteRecursively", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
int64 undeleted_files, undeleted_dirs;
const auto& status =
@@ -715,7 +748,8 @@ TEST(RetryingFileSystemTest, FlushCaches) {
ExpectedCalls none;
bool flushed = false;
std::unique_ptr<MockFileSystem> base_fs(new MockFileSystem(none, &flushed));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
fs.FlushCaches();
EXPECT_TRUE(flushed);
}
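
The CreateRetriableErrors helper these tests lean on is defined near the top of the file and does not appear in this excerpt. A plausible reconstruction, assuming the ExpectedCalls list of (method name, Status) tuples used above; eleven UNAVAILABLE results drive the wrapper through all ten retries and into the ABORTED error path:

    // Hypothetical reconstruction of the test helper (not shown in the diff).
    typedef std::vector<std::tuple<string, Status>> ExpectedCalls;

    ExpectedCalls CreateRetriableErrors(const string& method, int n) {
      ExpectedCalls expected_calls;
      expected_calls.reserve(n);
      for (int i = 0; i < n; ++i) {
        expected_calls.emplace_back(
            method,
            errors::Unavailable(strings::StrCat("Retriable error #", i)));
      }
      return expected_calls;
    }
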
diff --git a/tensorflow/core/platform/cloud/retrying_utils.cc b/tensorflow/core/platform/cloud/retrying_utils.cc
index d2df422024..cb0aecdd35 100644
--- a/tensorflow/core/platform/cloud/retrying_utils.cc
+++ b/tensorflow/core/platform/cloud/retrying_utils.cc
@@ -23,11 +23,6 @@ namespace tensorflow {
namespace {
-// In case of failure, every call will be retried kMaxRetries times.
-constexpr int kMaxRetries = 10;
-// Maximum backoff time in microseconds.
-constexpr int64 kMaximumBackoffMicroseconds = 32000000; // 32 seconds.
-
bool IsRetriable(error::Code code) {
switch (code) {
case error::UNAVAILABLE:
@@ -43,40 +38,41 @@ bool IsRetriable(error::Code code) {
} // namespace
Status RetryingUtils::CallWithRetries(const std::function<Status()>& f,
- const int64 initial_delay_microseconds) {
- return CallWithRetries(f, initial_delay_microseconds, [](int64 micros) {
- return Env::Default()->SleepForMicroseconds(micros);
- });
+ const RetryConfig& config) {
+ return CallWithRetries(
+ f,
+ [](int64 micros) { return Env::Default()->SleepForMicroseconds(micros); },
+ config);
}
Status RetryingUtils::CallWithRetries(
- const std::function<Status()>& f, const int64 initial_delay_microseconds,
- const std::function<void(int64)>& sleep_usec) {
+ const std::function<Status()>& f,
+ const std::function<void(int64)>& sleep_usec, const RetryConfig& config) {
int retries = 0;
while (true) {
auto status = f();
if (!IsRetriable(status.code())) {
return status;
}
- if (retries >= kMaxRetries) {
+ if (retries >= config.max_retries) {
// Return AbortedError, so that it doesn't get retried again somewhere
// at a higher level.
return Status(
error::ABORTED,
strings::StrCat(
- "All ", kMaxRetries,
+ "All ", config.max_retries,
" retry attempts failed. The last failure: ", status.ToString()));
}
int64 delay_micros = 0;
- if (initial_delay_microseconds > 0) {
+ if (config.init_delay_time_us > 0) {
const int64 random_micros = random::New64() % 1000000;
- delay_micros = std::min(initial_delay_microseconds << retries,
- kMaximumBackoffMicroseconds) +
+ delay_micros = std::min(config.init_delay_time_us << retries,
+ config.max_delay_time_us) +
random_micros;
}
LOG(INFO) << "The operation failed and will be automatically retried in "
<< (delay_micros / 1000000.0) << " seconds (attempt "
- << (retries + 1) << " out of " << kMaxRetries
+ << (retries + 1) << " out of " << config.max_retries
<< "), caused by: " << status.ToString();
sleep_usec(delay_micros);
retries++;
@@ -84,8 +80,7 @@ Status RetryingUtils::CallWithRetries(
}
Status RetryingUtils::DeleteWithRetries(
- const std::function<Status()>& delete_func,
- const int64 initial_delay_microseconds) {
+ const std::function<Status()>& delete_func, const RetryConfig& config) {
bool is_retried = false;
return RetryingUtils::CallWithRetries(
[delete_func, &is_retried]() {
@@ -96,7 +91,7 @@ Status RetryingUtils::DeleteWithRetries(
is_retried = true;
return status;
},
- initial_delay_microseconds);
+ config);
}
} // namespace tensorflow
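
The overload that takes a sleep function exists so callers and tests can observe the backoff schedule without actually sleeping. A sketch, assuming only the signatures above: record each requested delay with a lambda, then inspect the sequence.

    // Capture the delays CallWithRetries asks for instead of sleeping.
    std::vector<int64> requested_delays;
    auto fake_sleep = [&requested_delays](int64 micros) {
      requested_delays.push_back(micros);
    };
    auto always_fails = []() { return errors::Unavailable("Failed."); };

    Status status = RetryingUtils::CallWithRetries(
        always_fails, fake_sleep, RetryConfig(500000 /* init_delay_time_us */));
    // status.code() is error::ABORTED after max_retries attempts, and entry k
    // of requested_delays is min(500000 << k, max_delay_time_us) plus under
    // one second of random jitter.
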
diff --git a/tensorflow/core/platform/cloud/retrying_utils.h b/tensorflow/core/platform/cloud/retrying_utils.h
index 546b8d1c4a..1a7ce1b122 100644
--- a/tensorflow/core/platform/cloud/retrying_utils.h
+++ b/tensorflow/core/platform/cloud/retrying_utils.h
@@ -21,6 +21,26 @@ limitations under the License.
namespace tensorflow {
+// Default time before reporting failure: ~100 seconds.
+struct RetryConfig {
+ RetryConfig(int64 init_delay_time_us = 100 * 1000,
+ int64 max_delay_time_us = 32 * 1000 * 1000,
+ int max_retries = 10) {
+ this->init_delay_time_us = init_delay_time_us;
+ this->max_delay_time_us = max_delay_time_us;
+ this->max_retries = max_retries;
+ }
+
+ // In case of failure, every call will be retried max_retries times.
+ int max_retries;
+
+ // Initial backoff time in microseconds.
+ int64 init_delay_time_us;
+
+ // Maximum backoff time in microseconds.
+ int64 max_delay_time_us;
+};
+
class RetryingUtils {
public:
/// \brief Retries the function in case of failure with exponential backoff.
@@ -31,18 +51,19 @@ class RetryingUtils {
/// retries.
/// If all retries failed, returns the last error status.
static Status CallWithRetries(const std::function<Status()>& f,
- const int64 initial_delay_microseconds);
+ const RetryConfig& config);
+
/// sleep_usec is a function that sleeps for the given number of microseconds.
static Status CallWithRetries(const std::function<Status()>& f,
- const int64 initial_delay_microseconds,
- const std::function<void(int64)>& sleep_usec);
+ const std::function<void(int64)>& sleep_usec,
+ const RetryConfig& config);
/// \brief A retrying wrapper for a function that deletes a resource.
///
/// The function takes care of the scenario when a delete operation
/// returns a failure but succeeds under the hood: if a retry returns
/// NOT_FOUND, the whole operation is considered a success.
static Status DeleteWithRetries(const std::function<Status()>& delete_func,
- const int64 initial_delay_microseconds);
+ const RetryConfig& config);
};
} // namespace tensorflow
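
The "~100 seconds" in the struct comment follows directly from the defaults: sleep k (for k = 0..9) lasts roughly min(0.1 s * 2^k, 32 s), plus up to one second of jitter per attempt, which sums to about 83 s before jitter. A self-contained sketch that reproduces the schedule, with the jitter term omitted for determinism:

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    int main() {
      const int64_t init_us = 100 * 1000;       // 100 ms initial delay
      const int64_t max_us = 32 * 1000 * 1000;  // 32 s cap
      const int max_retries = 10;
      int64_t total_us = 0;
      for (int k = 0; k < max_retries; ++k) {
        const int64_t delay = std::min(init_us << k, max_us);
        total_us += delay;
        std::printf("retry %d sleeps %.1f s\n", k + 1, delay / 1e6);
      }
      std::printf("total: %.1f s before jitter\n", total_us / 1e6);  // 83.1 s
      return 0;
    }
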
diff --git a/tensorflow/core/platform/cloud/retrying_utils_test.cc b/tensorflow/core/platform/cloud/retrying_utils_test.cc
index 1b6527618a..75fe8a98f4 100644
--- a/tensorflow/core/platform/cloud/retrying_utils_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_utils_test.cc
@@ -30,7 +30,8 @@ TEST(RetryingUtilsTest, CallWithRetries_RetryDelays) {
};
std::function<Status()> f = []() { return errors::Unavailable("Failed."); };
- const auto& status = RetryingUtils::CallWithRetries(f, 500000L, sleep);
+ const auto& status = RetryingUtils::CallWithRetries(
+ f, sleep, RetryConfig(500000 /* init_delay_time_us */));
EXPECT_EQ(errors::Code::ABORTED, status.code());
EXPECT_TRUE(str_util::StrContains(
status.error_message(),
@@ -60,8 +61,10 @@ TEST(RetryingUtilsTest, CallWithRetries_NotFoundIsNotRetried) {
results.erase(results.begin());
return result;
};
- EXPECT_EQ(errors::Code::NOT_FOUND,
- RetryingUtils::CallWithRetries(f, 0).code());
+ EXPECT_EQ(
+ errors::Code::NOT_FOUND,
+ RetryingUtils::CallWithRetries(f, RetryConfig(0 /* init_delay_time_us */))
+ .code());
}
TEST(RetryingUtilsTest, CallWithRetries_ImmediateSuccess) {
@@ -74,7 +77,8 @@ TEST(RetryingUtilsTest, CallWithRetries_ImmediateSuccess) {
results.erase(results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::CallWithRetries(f, 1L, sleep));
+ TF_EXPECT_OK(RetryingUtils::CallWithRetries(
+ f, sleep, RetryConfig(1L /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, CallWithRetries_EventualSuccess) {
@@ -86,7 +90,8 @@ TEST(RetryingUtilsTest, CallWithRetries_EventualSuccess) {
results.erase(results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::CallWithRetries(f, 0));
+ TF_EXPECT_OK(RetryingUtils::CallWithRetries(
+ f, RetryConfig(0 /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, DeleteWithRetries_ImmediateSuccess) {
@@ -96,7 +101,8 @@ TEST(RetryingUtilsTest, DeleteWithRetries_ImmediateSuccess) {
delete_results.erase(delete_results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(delete_func, 0));
+ TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, DeleteWithRetries_EventualSuccess) {
@@ -106,7 +112,8 @@ TEST(RetryingUtilsTest, DeleteWithRetries_EventualSuccess) {
delete_results.erase(delete_results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(delete_func, 0));
+ TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, DeleteWithRetries_PermissionDeniedNotRetried) {
@@ -118,7 +125,9 @@ TEST(RetryingUtilsTest, DeleteWithRetries_PermissionDeniedNotRetried) {
return result;
};
EXPECT_EQ(errors::Code::PERMISSION_DENIED,
- RetryingUtils::DeleteWithRetries(delete_func, 0).code());
+ RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */))
+ .code());
}
TEST(RetryingUtilsTest, DeleteWithRetries_SuccessThroughFileNotFound) {
@@ -129,7 +138,8 @@ TEST(RetryingUtilsTest, DeleteWithRetries_SuccessThroughFileNotFound) {
delete_results.erase(delete_results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(delete_func, 0));
+ TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, DeleteWithRetries_FirstNotFoundReturnedAsIs) {
@@ -140,7 +150,9 @@ TEST(RetryingUtilsTest, DeleteWithRetries_FirstNotFoundReturnedAsIs) {
return result;
};
EXPECT_EQ(error::NOT_FOUND,
- RetryingUtils::DeleteWithRetries(delete_func, 0).code());
+ RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */))
+ .code());
}
} // namespace
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index bb841aeab7..d884c1aa7c 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -615,11 +615,7 @@ def tf_kernel_tests_linkstatic():
def tf_additional_lib_defines():
"""Additional defines needed to build TF libraries."""
- return select({
- "//tensorflow:with_jemalloc_linux_x86_64": ["TENSORFLOW_USE_JEMALLOC"],
- "//tensorflow:with_jemalloc_linux_ppc64le": ["TENSORFLOW_USE_JEMALLOC"],
- "//conditions:default": [],
- })
+ return []
def tf_additional_lib_deps():
"""Additional dependencies needed to build TF libraries."""
@@ -631,64 +627,45 @@ def tf_additional_lib_deps():
] + if_static(
["@nsync//:nsync_cpp"],
["@nsync//:nsync_headers"],
- ) + select({
- "//tensorflow:with_jemalloc_linux_x86_64_dynamic": ["@jemalloc//:jemalloc_headers"],
- "//tensorflow:with_jemalloc_linux_ppc64le_dynamic": ["@jemalloc//:jemalloc_headers"],
- "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
- "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
- "//conditions:default": [],
- })
+ )
def tf_additional_core_deps():
return select({
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
"//tensorflow/core/platform/cloud:gcs_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_hdfs_support_windows_override": [],
- "//tensorflow:with_hdfs_support_android_override": [],
- "//tensorflow:with_hdfs_support_ios_override": [],
- "//tensorflow:with_hdfs_support": [
- "//tensorflow/core/platform/hadoop:hadoop_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support_android_override": [],
- "//tensorflow:with_aws_support_ios_override": [],
- "//tensorflow:with_aws_support": [
"//tensorflow/core/platform/s3:s3_file_system",
+ "//tensorflow/core/platform/hadoop:hadoop_file_system",
],
- "//conditions:default": [],
})
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_op_deps():
return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
"//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
"//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
],
- "//conditions:default": [],
})
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_kernel_deps():
return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
"//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
"//tensorflow/contrib/cloud/kernels:gcs_config_ops",
],
- "//conditions:default": [],
})
def tf_lib_proto_parsing_deps():
@@ -738,11 +715,7 @@ def tf_additional_binary_deps():
"//tensorflow/stream_executor:cuda_platform",
"//tensorflow/core/platform/default/build_config:cuda",
],
- ) + select({
- "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
- "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
- "//conditions:default": [],
- }) + [
+ ) + [
# TODO(allenl): Split these out into their own shared objects (they are
# here because they are shared between contrib/ op shared objects and
# core).
diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h
index 5b237c4736..5732271f15 100644
--- a/tensorflow/core/platform/env.h
+++ b/tensorflow/core/platform/env.h
@@ -228,6 +228,10 @@ class Env {
/// |suffix|. Returns true on success.
bool CreateUniqueFileName(string* prefix, const string& suffix);
+ /// \brief Returns the runfiles directory if running under bazel, or the
+ /// directory containing the executable if not running under bazel.
+ virtual string GetRunfilesDir() = 0;
+
// TODO(jeff,sanjay): Add back thread/thread-pool support if needed.
// TODO(jeff,sanjay): if needed, tighten spec so relative to epoch, or
// provide a routine to get the absolute time.
@@ -360,6 +364,8 @@ class EnvWrapper : public Env {
return target_->FormatLibraryFileName(name, version);
}
+ string GetRunfilesDir() override { return target_->GetRunfilesDir(); }
+
private:
void GetLocalTempDirectories(std::vector<string>* list) override {
target_->GetLocalTempDirectories(list);
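A minimal usage sketch for the new accessor (the helper name and the testdata path below are hypothetical; only Env::GetRunfilesDir() comes from this change):

    #include <string>

    #include "tensorflow/core/lib/strings/strcat.h"
    #include "tensorflow/core/platform/env.h"

    // Resolves a data file against the bazel runfiles tree when it exists,
    // otherwise against the executable's own directory, per the contract
    // documented on Env::GetRunfilesDir() above.
    std::string TestDataPath() {
      tensorflow::Env* env = tensorflow::Env::Default();
      return tensorflow::strings::StrCat(env->GetRunfilesDir(),
                                         "/testdata/foo.index");
    }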
diff --git a/tensorflow/core/platform/posix/env.cc b/tensorflow/core/platform/posix/env.cc
index 418874d340..af95d8201e 100644
--- a/tensorflow/core/platform/posix/env.cc
+++ b/tensorflow/core/platform/posix/env.cc
@@ -119,6 +119,17 @@ class PosixEnv : public Env {
return tensorflow::internal::FormatLibraryFileName(name, version);
}
+ string GetRunfilesDir() override {
+ string bin_path = this->GetExecutablePath();
+ string runfiles_path = bin_path + ".runfiles/org_tensorflow";
+ Status s = this->IsDirectory(runfiles_path);
+ if (s.ok()) {
+ return runfiles_path;
+ } else {
+ return bin_path.substr(0, bin_path.find_last_of("/\\"));
+ }
+ }
+
private:
void GetLocalTempDirectories(std::vector<string>* list) override;
};
diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc
index b46b9927cd..acdd7798ea 100644
--- a/tensorflow/core/platform/posix/port.cc
+++ b/tensorflow/core/platform/posix/port.cc
@@ -13,10 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef TENSORFLOW_USE_JEMALLOC
-#include "jemalloc/jemalloc.h"
-#endif
-
#include "absl/base/internal/sysinfo.h"
#include "tensorflow/core/platform/cpu_info.h"
@@ -101,11 +97,7 @@ void* AlignedMalloc(size_t size, int minimum_alignment) {
// memory aligned to at least the size of a pointer.
const int required_alignment = sizeof(void*);
if (minimum_alignment < required_alignment) return Malloc(size);
-#ifdef TENSORFLOW_USE_JEMALLOC
- int err = jemalloc_posix_memalign(&ptr, minimum_alignment, size);
-#else
int err = posix_memalign(&ptr, minimum_alignment, size);
-#endif
if (err != 0) {
return nullptr;
} else {
@@ -116,29 +108,11 @@ void* AlignedMalloc(size_t size, int minimum_alignment) {
void AlignedFree(void* aligned_memory) { Free(aligned_memory); }
-void* Malloc(size_t size) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_malloc(size);
-#else
- return malloc(size);
-#endif
-}
+void* Malloc(size_t size) { return malloc(size); }
-void* Realloc(void* ptr, size_t size) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_realloc(ptr, size);
-#else
- return realloc(ptr, size);
-#endif
-}
+void* Realloc(void* ptr, size_t size) { return realloc(ptr, size); }
-void Free(void* ptr) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- jemalloc_free(ptr);
-#else
- free(ptr);
-#endif
-}
+void Free(void* ptr) { free(ptr); }
void* NUMAMalloc(int node, size_t size, int minimum_alignment) {
return AlignedMalloc(size, minimum_alignment);
@@ -146,9 +120,7 @@ void* NUMAMalloc(int node, size_t size, int minimum_alignment) {
void NUMAFree(void* ptr, size_t size) { Free(ptr); }
-int NUMAGetMemAffinity(const void* addr) {
- return kNUMANoAffinity;
-}
+int NUMAGetMemAffinity(const void* addr) { return kNUMANoAffinity; }
void MallocExtension_ReleaseToSystem(std::size_t num_bytes) {
// No-op.
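With the jemalloc branches removed, the shims above are thin libc wrappers. A self-contained sketch of the AlignedMalloc fallback that survives the cleanup (posix_memalign rejects alignments below sizeof(void*), so small requests go through plain malloc, which already returns pointer-aligned memory):

    #include <cstdlib>

    // Mirrors the retained logic: fall back to malloc for alignments smaller
    // than sizeof(void*), otherwise delegate to posix_memalign and translate
    // its error code into a null pointer.
    void* AlignedMallocSketch(size_t size, int minimum_alignment) {
      if (minimum_alignment < static_cast<int>(sizeof(void*))) {
        return malloc(size);
      }
      void* ptr = nullptr;
      return posix_memalign(&ptr, minimum_alignment, size) == 0 ? ptr : nullptr;
    }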
diff --git a/tensorflow/core/platform/windows/env.cc b/tensorflow/core/platform/windows/env.cc
index 68ee3595a2..f26ccd1662 100644
--- a/tensorflow/core/platform/windows/env.cc
+++ b/tensorflow/core/platform/windows/env.cc
@@ -160,6 +160,17 @@ class WindowsEnv : public Env {
return filename;
}
+ string GetRunfilesDir() override {
+ string bin_path = this->GetExecutablePath();
+ string runfiles_path = bin_path + ".runfiles\\org_tensorflow";
+ Status s = this->IsDirectory(runfiles_path);
+ if (s.ok()) {
+ return runfiles_path;
+ } else {
+ return bin_path.substr(0, bin_path.find_last_of("/\\"));
+ }
+ }
+
private:
void GetLocalTempDirectories(std::vector<string>* list) override;
diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc
index 5375f56372..911ea1902f 100644
--- a/tensorflow/core/platform/windows/port.cc
+++ b/tensorflow/core/platform/windows/port.cc
@@ -13,10 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef TENSORFLOW_USE_JEMALLOC
-#include "jemalloc/jemalloc.h"
-#endif
-
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
@@ -70,55 +66,16 @@ void NUMASetThreadNodeAffinity(int node) {}
int NUMAGetThreadNodeAffinity() { return kNUMANoAffinity; }
void* AlignedMalloc(size_t size, int minimum_alignment) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- void* ptr = NULL;
- // posix_memalign requires that the requested alignment be at least
- // sizeof(void*). In this case, fall back on malloc which should return
- // memory aligned to at least the size of a pointer.
- const int required_alignment = sizeof(void*);
- if (minimum_alignment < required_alignment) return Malloc(size);
- int err = jemalloc_posix_memalign(&ptr, minimum_alignment, size);
- if (err != 0) {
- return NULL;
- } else {
- return ptr;
- }
-#else
return _aligned_malloc(size, minimum_alignment);
-#endif
}
-void AlignedFree(void* aligned_memory) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- jemalloc_free(aligned_memory);
-#else
- _aligned_free(aligned_memory);
-#endif
-}
+void AlignedFree(void* aligned_memory) { _aligned_free(aligned_memory); }
-void* Malloc(size_t size) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_malloc(size);
-#else
- return malloc(size);
-#endif
-}
+void* Malloc(size_t size) { return malloc(size); }
-void* Realloc(void* ptr, size_t size) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_realloc(ptr, size);
-#else
- return realloc(ptr, size);
-#endif
-}
+void* Realloc(void* ptr, size_t size) { return realloc(ptr, size); }
-void Free(void* ptr) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_free(ptr);
-#else
- return free(ptr);
-#endif
-}
+void Free(void* ptr) { return free(ptr); }
void* NUMAMalloc(int node, size_t size, int minimum_alignment) {
return AlignedMalloc(size, minimum_alignment);
diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD
index af034bdd7d..2bf371276e 100644
--- a/tensorflow/core/profiler/BUILD
+++ b/tensorflow/core/profiler/BUILD
@@ -40,7 +40,6 @@ tf_proto_library(
name = "protos_all",
srcs = glob(["**/*.proto"]),
cc_api_version = 2,
- java_api_version = 2,
protodeps = tf_additional_all_protos(),
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 85cd02350a..104ab039cb 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -453,6 +453,11 @@ message RunOptions {
// same group_key value (in a distributed computation where tasks
// run disjoint graphs).
int64 collective_graph_key = 1;
+ // If true, then operations (using the inter-op pool) across all
+ // session::run() calls will be centrally scheduled, optimizing for (median
+ // and tail) latency.
+ // Consider using this option for CPU-bound workloads like inference.
+ bool use_run_handler_pool = 2;
};
Experimental experimental = 8;
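Assuming the standard protobuf-generated accessors for the field above, a caller would opt into the shared run-handler pool like this (a sketch, not part of the change):

    #include "tensorflow/core/protobuf/config.pb.h"

    // Enables central scheduling of inter-op work across all Session::Run()
    // calls, which the comment above recommends for CPU-bound inference.
    tensorflow::RunOptions MakeRunOptions() {
      tensorflow::RunOptions run_options;
      run_options.mutable_experimental()->set_use_run_handler_pool(true);
      return run_options;
    }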
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index bb8f88336d..8c31468ff5 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -75,8 +75,10 @@ message RewriterConfig {
// Try to allocate some independent Op outputs contiguously in order to
// merge or eliminate downstream Ops (off by default).
Toggle scoped_allocator_optimization = 15;
- // Force small ops onto the CPU (default is ON).
+ // Force small ops onto the CPU (default is OFF).
Toggle pin_to_host_optimization = 18;
+ // Disable the entire meta optimizer (off by default).
+ bool disable_meta_optimizer = 19;
// Controls how many times we run the optimizers in meta optimizer (default
// is once).
@@ -143,8 +145,8 @@ message RewriterConfig {
// not configurable (in contrast to memory optimization passes through the
// meta-optimizer) and act only on manual op annotations.
//
- // Custom registered optimizers will be run after the base optimizers, in
- // the order that they are specified.
+ // Custom optimizers (see custom_optimizers) that are not part of this
+ // schedule will be run after - in the order that they were specified.
repeated string optimizers = 100;
// Message to describe custom graph optimizer and its parameters
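For reference, a sketch of flipping the new switch through the generated C++ API (field names follow the proto above; RewriterConfig hangs off ConfigProto.graph_options.rewrite_options):

    #include "tensorflow/core/protobuf/config.pb.h"

    // Disables the entire Grappler meta optimizer for a session, instead of
    // toggling individual passes one by one.
    tensorflow::ConfigProto MakeConfig() {
      tensorflow::ConfigProto config;
      config.mutable_graph_options()
          ->mutable_rewrite_options()
          ->set_disable_meta_optimizer(true);
      return config;
    }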
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index cf7ffd8149..04aaea4f89 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -2039,8 +2039,8 @@ class MklPrimitiveFactory {
/// Function to check whether primitive memory optimization is enabled
static inline bool IsPrimitiveMemOptEnabled() {
bool is_primitive_mem_opt_enabled = true;
- TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE", true,
- &is_primitive_mem_opt_enabled));
+ TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true,
+ &is_primitive_mem_opt_enabled));
return is_primitive_mem_opt_enabled;
}
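The spelling fix is user-visible: ReadBoolFromEnvVar only consults the exact name it is given, so the correctly spelled variable now takes effect. A sketch of the same pattern:

    #include "tensorflow/core/lib/core/status.h"
    #include "tensorflow/core/util/env_var.h"

    // Reads TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE, keeping the default of true
    // when the variable is unset; mirrors the helper above.
    bool PrimitiveMemOptEnabled() {
      bool enabled = true;
      TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar(
          "TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", /*default_val=*/true, &enabled));
      return enabled;
    }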
@@ -2095,9 +2095,8 @@ static inline memory::format get_desired_format(int channel,
fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c;
} else if (port::TestCPUFeature(port::CPUFeature::AVX2) &&
(channel % 8) == 0) {
- fmt_desired = is_2d
- ? memory::format::nChw8c
- : memory::format::ncdhw; //not support avx2 for 3d yet.
+ fmt_desired = is_2d ? memory::format::nChw8c
+ : memory::format::ncdhw; // no avx2 support for 3d yet.
} else {
fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw;
}
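Distilled from the branch above for the 2-D case (the AVX512 condition sits partly outside this hunk, so its exact form is assumed from context): the blocked layout follows the widest SIMD feature whose block size divides the channel count.

    // Sketch, reusing the mkldnn/port symbols from this header: nChw16c when
    // AVX512F is present and channel % 16 == 0 (assumed), nChw8c when AVX2 is
    // present and channel % 8 == 0, otherwise plain nchw.
    inline memory::format Desired2dFormat(int channel) {
      if (port::TestCPUFeature(port::CPUFeature::AVX512F) && channel % 16 == 0)
        return memory::format::nChw16c;
      if (port::TestCPUFeature(port::CPUFeature::AVX2) && channel % 8 == 0)
        return memory::format::nChw8c;
      return memory::format::nchw;
    }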
@@ -2209,7 +2208,8 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
// utility function to determine if it is conv 1x1 and stride != 1
// for purpose of temporarily disabling primitive reuse
-inline bool IsConv1x1StrideNot1(memory::dims filter_dims, memory::dims strides) {
+inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
+ memory::dims strides) {
if (filter_dims.size() != 4 || strides.size() != 2) return false;
return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc
index c081ceae57..e01058dff6 100644
--- a/tensorflow/core/util/port.cc
+++ b/tensorflow/core/util/port.cc
@@ -38,10 +38,10 @@ bool CudaSupportsHalfMatMulAndConv() {
}
bool IsMklEnabled() {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
return true;
#else
return false;
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
}
} // end namespace tensorflow
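Callers need no changes; the query above simply becomes stricter. A trivial usage sketch (port.h declares IsMklEnabled):

    #include "tensorflow/core/util/port.h"

    // True only for builds compiled with both INTEL_MKL and ENABLE_MKL under
    // the tightened gate above.
    bool BuiltWithMkl() { return tensorflow::IsMklEnabled(); }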
diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD
index 648358606c..f40ec9b752 100644
--- a/tensorflow/core/util/tensor_bundle/BUILD
+++ b/tensorflow/core/util/tensor_bundle/BUILD
@@ -64,6 +64,11 @@ cc_library(
tf_cc_test(
name = "tensor_bundle_test",
srcs = ["tensor_bundle_test.cc"],
+ data = glob(["testdata/**"]),
+ tags = [
+ "nomsan",
+ "notsan",
+ ],
deps = [
":tensor_bundle",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index ea8a259d1a..2dcb57a1f9 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -64,27 +64,36 @@ namespace {
// Reads "num_elements" string elements from file[offset, offset+size) into the
// length-N "destination". Discards the original content of "destination".
//
-// Checksums the string lengths (as restored uint32, not varint32 bytes) and
-// string bytes, and stores it into "actual_crc32c".
+// Checksums the string lengths (as restored uint32 or uint64, not varint64
+// bytes) and string bytes, and stores it into "actual_crc32c".
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
size_t offset, size_t size, string* destination,
uint32* actual_crc32c) {
if (size == 0) return Status::OK();
CHECK_GT(size, 0);
- // Reads "num_elements" varint32's from "buffered_file".
+ // Reads "num_elements" varint64's from "buffered_file".
TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
- std::vector<uint32> string_lengths(num_elements);
+ std::vector<uint64> string_lengths(num_elements);
for (size_t i = 0; i < num_elements; ++i) {
- TF_RETURN_IF_ERROR(buffered_file->ReadVarint32(&string_lengths[i]));
+ TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
+ if (string_lengths[i] <= UINT32_MAX) {
+ // We need to do this because older checkpoints only used uint32s and we
+ // should still support them.
+ const uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
+ *actual_crc32c = crc32c::Extend(
+ *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
+ sizeof(uint32));
+ } else {
+ *actual_crc32c = crc32c::Extend(
+ *actual_crc32c, reinterpret_cast<const char*>(&string_lengths[i]),
+ sizeof(uint64));
+ }
}
if (offset + size < buffered_file->Tell()) {
return errors::DataLoss("String lengths longer than expected offset ",
offset + size);
}
- *actual_crc32c =
- crc32c::Value(reinterpret_cast<const char*>(string_lengths.data()),
- sizeof(uint32) * num_elements);
// Reads the length-checksum.
uint32 length_checksum = 0;
@@ -104,7 +113,7 @@ Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
// Reads the actual string bytes.
for (size_t i = 0; i < num_elements; ++i) {
- const uint32 string_length = string_lengths[i];
+ const uint64 string_length = string_lengths[i];
string* buffer = &destination[i];
buffer->resize(string_length);
@@ -218,8 +227,8 @@ Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
size_t* bytes_written, uint32* crc32c) {
// On-disk format:
- // [varint32 len0]..[varint32 lenL][4 byte cksum on lengths][string bytes]
- // Var "crc32c" checksums the string lengths (as uint32, not varint32 bytes),
+ // [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
+ // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
// the length-checksum, and all the string bytes.
DCHECK_EQ(val.dtype(), DT_STRING);
const string* strings = GetStringBackingBuffer(val);
@@ -230,12 +239,21 @@ Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
*crc32c = 0;
for (int64 i = 0; i < val.NumElements(); ++i) {
const string* elem = &strings[i];
- DCHECK_EQ(elem->size(), static_cast<uint32>(elem->size()));
- const uint32 elem_size = static_cast<uint32>(elem->size());
-
- core::PutVarint32(&lengths, elem_size);
- *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
- sizeof(uint32));
+ DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
+ const uint64 elem_size = static_cast<uint64>(elem->size());
+
+ core::PutVarint64(&lengths, elem_size);
+ if (elem_size <= UINT32_MAX) {
+ // We need to do this because older checkpoints only used uint32s and we
+ // should still support them.
+ const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
+ *crc32c = crc32c::Extend(*crc32c,
+ reinterpret_cast<const char*>(&elem_size_uint32),
+ sizeof(uint32));
+ } else {
+ *crc32c = crc32c::Extend(
+ *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
+ }
}
TF_RETURN_IF_ERROR(out->Append(lengths));
*bytes_written = lengths.size();
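Reader and writer must agree on how lengths feed the checksum, or pre-existing checkpoints would fail validation. A standalone restatement of the rule both sides now share:

    #include <cstdint>

    #include "tensorflow/core/lib/hash/crc32c.h"

    // Lengths that fit in 32 bits are checksummed exactly as the old format
    // did (as a uint32), so old checkpoints keep validating; only lengths
    // above UINT32_MAX are checksummed as a full uint64.
    uint32_t ExtendLengthCrc(uint32_t crc, uint64_t length) {
      if (length <= UINT32_MAX) {
        const uint32_t length32 = static_cast<uint32_t>(length);
        return tensorflow::crc32c::Extend(
            crc, reinterpret_cast<const char*>(&length32), sizeof(length32));
      }
      return tensorflow::crc32c::Extend(
          crc, reinterpret_cast<const char*>(&length), sizeof(length));
    }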
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index 59c42baa06..9567e4750b 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -39,6 +39,11 @@ string Prefix(const string& prefix) {
return strings::StrCat(testing::TmpDir(), "/", prefix);
}
+string TestdataPrefix(const string& prefix) {
+ return strings::StrCat(testing::TensorFlowSrcRoot(),
+ "/core/util/tensor_bundle/testdata/", prefix);
+}
+
template <typename T>
Tensor Constant(T v, TensorShape shape) {
Tensor ret(DataTypeToEnum<T>::value, shape);
@@ -458,7 +463,26 @@ TEST(TensorBundleTest, NonStandardShapes) {
TestNonStandardShapes<qint8>();
}
+TEST(TensorBundleTest, StringTensorsOldFormat) {
+ // Test a string tensor bundle made with a previous version of the code that
+ // used varint32s to store string lengths (we now use varint64s).
+ BundleReader reader(Env::Default(), TestdataPrefix("old_string_tensors/foo"));
+ TF_ASSERT_OK(reader.status());
+ EXPECT_EQ(AllTensorKeys(&reader),
+ std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));
+
+ Expect<string>(&reader, "string_tensor", Tensor(DT_STRING, TensorShape({1})));
+ Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
+ Expect<string>(
+ &reader, "strs",
+ test::AsTensor<string>({"hello", "", "x01", string(1 << 10, 'c')}));
+ Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
+}
+
TEST(TensorBundleTest, StringTensors) {
+ constexpr size_t kLongLength = static_cast<size_t>(UINT32_MAX) + 1;
+ Tensor long_string_tensor(DT_STRING, TensorShape({1}));
+
{
BundleWriter writer(Env::Default(), Prefix("foo"));
TF_EXPECT_OK(writer.Add("string_tensor",
@@ -467,6 +491,12 @@ TEST(TensorBundleTest, StringTensors) {
TF_EXPECT_OK(writer.Add(
"strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')})));
+
+ // Requires a 64-bit length.
+ string* backing_string = long_string_tensor.flat<string>().data();
+ backing_string->assign(kLongLength, 'd');
+ TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));
+
// Mixes in some floats.
TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18)));
TF_ASSERT_OK(writer.Finish());
@@ -474,9 +504,9 @@ TEST(TensorBundleTest, StringTensors) {
{
BundleReader reader(Env::Default(), Prefix("foo"));
TF_ASSERT_OK(reader.status());
- EXPECT_EQ(
- AllTensorKeys(&reader),
- std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));
+ EXPECT_EQ(AllTensorKeys(&reader),
+ std::vector<string>({"floats", "long_scalar", "scalar",
+ "string_tensor", "strs"}));
Expect<string>(&reader, "string_tensor",
Tensor(DT_STRING, TensorShape({1})));
@@ -484,7 +514,35 @@ TEST(TensorBundleTest, StringTensors) {
Expect<string>(
&reader, "strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')}));
+
Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
+
+ // We don't use the Expect function so we can re-use the
+ // `long_string_tensor` buffer for reading out long_scalar to keep memory
+ // usage reasonable.
+ EXPECT_TRUE(reader.Contains("long_scalar"));
+ DataType dtype;
+ TensorShape shape;
+ TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape));
+ EXPECT_EQ(DT_STRING, dtype);
+ EXPECT_EQ(TensorShape({1}), shape);
+
+ // Zero-out the string so that we can be sure the new one is read in.
+ string* backing_string = long_string_tensor.flat<string>().data();
+ backing_string->assign("");
+
+ // Read long_scalar and check it contains kLongLength 'd's.
+ TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
+ ASSERT_EQ(backing_string, long_string_tensor.flat<string>().data());
+ EXPECT_EQ(kLongLength, backing_string->length());
+ for (char c : *backing_string) {
+ // Not using ASSERT_EQ('d', c) because this way is twice as fast due to
+ // compiler optimizations.
+ if (c != 'd') {
+ FAIL() << "long_scalar is not full of 'd's as expected.";
+ break;
+ }
+ }
}
}
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README
new file mode 100644
index 0000000000..428d3ef79e
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README
@@ -0,0 +1,3 @@
+This tensor bundle was generated from cl/214343133, before string tensor
+lengths were written as varint64s. This is here to check backwards
+compatibility between the new code and old checkpoints.
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001
new file mode 100644
index 0000000000..23b488e5fe
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001
Binary files differ
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index
new file mode 100644
index 0000000000..a22a69e6e1
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index
Binary files differ
diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc
index 1e5a9c5712..489999d1e8 100644
--- a/tensorflow/core/util/util.cc
+++ b/tensorflow/core/util/util.cc
@@ -120,4 +120,20 @@ string SliceDebugString(const TensorShape& shape, const int64 flat) {
return result;
}
+#ifdef INTEL_MKL
+bool DisableMKL() {
+ enum MklStatus { MKL_DEFAULT = 0, MKL_ON = 1, MKL_OFF = 2 };
+ static MklStatus status = MKL_DEFAULT;
+ if (status == MKL_DEFAULT) {
+ char* tf_disable_mkl = getenv("TF_DISABLE_MKL");
+ if ((tf_disable_mkl != NULL) && (std::stoi(tf_disable_mkl) == 1)) {
+ VLOG(2) << "TF-MKL: Disabling MKL";
+ status = MKL_OFF;
+ } else {
+ status = MKL_ON;
+ }
+ }
+ return status == MKL_OFF ? true : false;
+}
+#endif // INTEL_MKL
} // namespace tensorflow
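A sketch of how an MKL-aware call site might combine the build-time gate with the new runtime switch (the wrapper is hypothetical; only DisableMKL() is added by this change):

    #include "tensorflow/core/util/util.h"

    // Honors TF_DISABLE_MKL=1 at runtime on MKL builds; e.g. run
    // `TF_DISABLE_MKL=1 ./trainer` to force the default (non-MKL) kernels.
    bool UseMklAtRuntime() {
    #ifdef INTEL_MKL
      return !tensorflow::DisableMKL();
    #else
      return false;
    #endif
    }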
diff --git a/tensorflow/core/util/util.h b/tensorflow/core/util/util.h
index 93dfd51ab5..4aa47aa48a 100644
--- a/tensorflow/core/util/util.h
+++ b/tensorflow/core/util/util.h
@@ -56,6 +56,11 @@ string PrintMemory(const char* ptr, size_t n);
// "tensor", "tensor[i]", "tensor[i, j]", etc.
string SliceDebugString(const TensorShape& shape, const int64 flat);
+// Disable MKL at runtime.
+#ifdef INTEL_MKL
+bool DisableMKL();
+#endif // INTEL_MKL
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_UTIL_H_