Diffstat (limited to 'tensorflow/core')
-rw-r--r--tensorflow/core/BUILD315
-rw-r--r--tensorflow/core/api_def/api_test.cc4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt34
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt29
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt40
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt22
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt31
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt27
-rw-r--r--tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt9
-rw-r--r--tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt26
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt21
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt58
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt25
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt13
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt13
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt35
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt49
-rw-r--r--tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt3
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Fill.pbtxt10
-rw-r--r--tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_HostConst.pbtxt11
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt20
-rw-r--r--tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt45
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt23
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt31
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt14
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt43
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt29
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt41
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt30
-rw-r--r--tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt17
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt13
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ParseExampleDataset.pbtxt69
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ParseSequenceExample.pbtxt112
-rw-r--r--tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt19
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt26
-rw-r--r--tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ResourceApplyAdam.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt5
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt29
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StaticRegexReplace.pbtxt26
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt2
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt38
-rw-r--r--tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt30
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Substr.pbtxt16
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt12
-rw-r--r--tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt14
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt28
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt21
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt20
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt20
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt7
-rw-r--r--tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt45
-rw-r--r--tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt23
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ParseExampleDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ParseSequenceExample.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt8
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ScatterNdSub.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_ScatterSub.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Substr.pbtxt8
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Tile.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt6
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt6
-rw-r--r--tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt6
-rw-r--r--tensorflow/core/common_runtime/base_collective_executor.cc150
-rw-r--r--tensorflow/core/common_runtime/base_collective_executor.h20
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.cc23
-rw-r--r--tensorflow/core/common_runtime/bfc_allocator.h22
-rw-r--r--tensorflow/core/common_runtime/broadcaster.cc300
-rw-r--r--tensorflow/core/common_runtime/buf_rendezvous.h6
-rw-r--r--tensorflow/core/common_runtime/collective_executor_mgr.h6
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.cc252
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local.h23
-rw-r--r--tensorflow/core/common_runtime/collective_param_resolver_local_test.cc204
-rw-r--r--tensorflow/core/common_runtime/collective_rma_local.h8
-rw-r--r--tensorflow/core/common_runtime/collective_util.cc83
-rw-r--r--tensorflow/core/common_runtime/collective_util.h38
-rw-r--r--tensorflow/core/common_runtime/constant_folding.cc22
-rw-r--r--tensorflow/core/common_runtime/constant_folding.h6
-rw-r--r--tensorflow/core/common_runtime/copy_tensor.cc91
-rw-r--r--tensorflow/core/common_runtime/debugger_state_interface.h6
-rw-r--r--tensorflow/core/common_runtime/device.h16
-rw-r--r--tensorflow/core/common_runtime/device_factory.h6
-rw-r--r--tensorflow/core/common_runtime/device_mgr.h6
-rw-r--r--tensorflow/core/common_runtime/device_resolver_local.h6
-rw-r--r--tensorflow/core/common_runtime/device_set.h6
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc114
-rw-r--r--tensorflow/core/common_runtime/direct_session.h39
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc143
-rw-r--r--tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc19
-rw-r--r--tensorflow/core/common_runtime/dma_helper.h6
-rw-r--r--tensorflow/core/common_runtime/eager/BUILD12
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.cc24
-rw-r--r--tensorflow/core/common_runtime/eager/attr_builder.h19
-rw-r--r--tensorflow/core/common_runtime/eager/context.cc60
-rw-r--r--tensorflow/core/common_runtime/eager/context.h31
-rw-r--r--tensorflow/core/common_runtime/eager/execute.cc109
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.cc16
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device.h11
-rw-r--r--tensorflow/core/common_runtime/eager/kernel_and_device_test.cc4
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.cc17
-rw-r--r--tensorflow/core/common_runtime/eager/tensor_handle.h1
-rw-r--r--tensorflow/core/common_runtime/eval_const_tensor.cc18
-rw-r--r--tensorflow/core/common_runtime/executor.cc253
-rw-r--r--tensorflow/core/common_runtime/executor.h14
-rw-r--r--tensorflow/core/common_runtime/function.cc68
-rw-r--r--tensorflow/core/common_runtime/function.h6
-rw-r--r--tensorflow/core/common_runtime/function_test.cc22
-rw-r--r--tensorflow/core/common_runtime/gpu/cuda_host_allocator.h12
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc50
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h51
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc146
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc15
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h18
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc30
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h28
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc80
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc306
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.h36
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device_test.cc19
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_event_mgr.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id.h32
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_manager.cc38
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_manager.h12
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc32
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_utils.h37
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_init.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_process_state.cc175
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_process_state.h58
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_stream_util.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_util.h6
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc9
-rw-r--r--tensorflow/core/common_runtime/gpu/pool_allocator_test.cc68
-rw-r--r--tensorflow/core/common_runtime/gpu_device_context.h9
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.cc49
-rw-r--r--tensorflow/core/common_runtime/graph_execution_state.h7
-rw-r--r--tensorflow/core/common_runtime/graph_optimizer.cc4
-rw-r--r--tensorflow/core/common_runtime/graph_optimizer.h5
-rw-r--r--tensorflow/core/common_runtime/graph_runner.cc4
-rw-r--r--tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc440
-rw-r--r--tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h (renamed from tensorflow/core/common_runtime/broadcaster.h)58
-rw-r--r--tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc (renamed from tensorflow/core/common_runtime/broadcaster_test.cc)239
-rw-r--r--tensorflow/core/common_runtime/kernel_benchmark_testlib.h6
-rw-r--r--tensorflow/core/common_runtime/local_device.h6
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.cc43
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.h5
-rw-r--r--tensorflow/core/common_runtime/lower_if_op_test.cc4
-rw-r--r--tensorflow/core/common_runtime/lower_while_op.cc427
-rw-r--r--tensorflow/core/common_runtime/lower_while_op.h (renamed from tensorflow/core/util/status_util.h)29
-rw-r--r--tensorflow/core/common_runtime/lower_while_op_test.cc249
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator.h186
-rw-r--r--tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc4
-rw-r--r--tensorflow/core/common_runtime/optimization_registry.h11
-rw-r--r--tensorflow/core/common_runtime/parallel_concat_optimizer.cc6
-rw-r--r--tensorflow/core/common_runtime/placer.cc54
-rw-r--r--tensorflow/core/common_runtime/placer.h8
-rw-r--r--tensorflow/core/common_runtime/placer_test.cc49
-rw-r--r--tensorflow/core/common_runtime/pool_allocator.cc46
-rw-r--r--tensorflow/core/common_runtime/pool_allocator.h27
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.cc70
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.h5
-rw-r--r--tensorflow/core/common_runtime/process_state.cc71
-rw-r--r--tensorflow/core/common_runtime/process_state.h15
-rw-r--r--tensorflow/core/common_runtime/process_util.cc31
-rw-r--r--tensorflow/core/common_runtime/renamed_device.h16
-rw-r--r--tensorflow/core/common_runtime/rendezvous_mgr.h6
-rw-r--r--tensorflow/core/common_runtime/rendezvous_util.cc38
-rw-r--r--tensorflow/core/common_runtime/rendezvous_util.h3
-rw-r--r--tensorflow/core/common_runtime/ring_reducer.cc419
-rw-r--r--tensorflow/core/common_runtime/ring_reducer.h55
-rw-r--r--tensorflow/core/common_runtime/ring_reducer_test.cc172
-rw-r--r--tensorflow/core/common_runtime/session_factory.h6
-rw-r--r--tensorflow/core/common_runtime/session_ref.cc170
-rw-r--r--tensorflow/core/common_runtime/session_ref.h86
-rw-r--r--tensorflow/core/common_runtime/session_state.cc2
-rw-r--r--tensorflow/core/common_runtime/single_threaded_cpu_device.h1
-rw-r--r--tensorflow/core/common_runtime/step_stats_collector.cc182
-rw-r--r--tensorflow/core/common_runtime/step_stats_collector.h149
-rw-r--r--tensorflow/core/common_runtime/sycl/sycl_allocator.h6
-rw-r--r--tensorflow/core/common_runtime/threadpool_device.cc22
-rw-r--r--tensorflow/core/common_runtime/threadpool_device.h1
-rw-r--r--tensorflow/core/common_runtime/visitable_allocator.h79
-rw-r--r--tensorflow/core/debug/debug_callback_registry.h6
-rw-r--r--tensorflow/core/debug/debug_graph_utils.cc4
-rw-r--r--tensorflow/core/debug/debug_graph_utils.h6
-rw-r--r--tensorflow/core/debug/debug_grpc_testlib.h6
-rw-r--r--tensorflow/core/debug/debug_io_utils.cc61
-rw-r--r--tensorflow/core/debug/debug_io_utils.h35
-rw-r--r--tensorflow/core/debug/debug_io_utils_test.cc46
-rw-r--r--tensorflow/core/debug/debug_node_key.h6
-rw-r--r--tensorflow/core/debug/debugger_state_impl.cc3
-rw-r--r--tensorflow/core/debug/debugger_state_impl.h6
-rw-r--r--tensorflow/core/distributed_runtime/BUILD1
-rw-r--r--tensorflow/core/distributed_runtime/graph_mgr.cc13
-rw-r--r--tensorflow/core/distributed_runtime/master.cc51
-rw-r--r--tensorflow/core/distributed_runtime/master_env.h2
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc21
-rw-r--r--tensorflow/core/distributed_runtime/message_wrappers.h2
-rw-r--r--tensorflow/core/distributed_runtime/remote_device.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel.cc14
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel.h2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc65
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc2
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h3
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_testlib.h6
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc5
-rw-r--r--tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc2
-rw-r--r--tensorflow/core/distributed_runtime/tensor_coding.h3
-rw-r--r--tensorflow/core/distributed_runtime/test_utils.h14
-rw-r--r--tensorflow/core/distributed_runtime/worker_cache.h2
-rw-r--r--tensorflow/core/distributed_runtime/worker_cache_wrapper.h4
-rw-r--r--tensorflow/core/distributed_runtime/worker_session.cc5
-rw-r--r--tensorflow/core/example/example.proto8
-rw-r--r--tensorflow/core/example/example_parser_configuration.h2
-rw-r--r--tensorflow/core/example/feature_util.h11
-rw-r--r--tensorflow/core/framework/allocator.cc29
-rw-r--r--tensorflow/core/framework/allocator.h39
-rw-r--r--tensorflow/core/framework/allocator_registry.h1
-rw-r--r--tensorflow/core/framework/attr_value_util.h6
-rw-r--r--tensorflow/core/framework/attr_value_util_test.cc1
-rw-r--r--tensorflow/core/framework/bfloat16.h6
-rw-r--r--tensorflow/core/framework/cancellation.cc10
-rw-r--r--tensorflow/core/framework/cancellation.h15
-rw-r--r--tensorflow/core/framework/cancellation_test.cc52
-rw-r--r--tensorflow/core/framework/collective.cc102
-rw-r--r--tensorflow/core/framework/collective.h119
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc113
-rw-r--r--tensorflow/core/framework/common_shape_fns.h9
-rw-r--r--tensorflow/core/framework/control_flow.h6
-rw-r--r--tensorflow/core/framework/dataset.cc85
-rw-r--r--tensorflow/core/framework/dataset.h317
-rw-r--r--tensorflow/core/framework/dataset_stateful_op_whitelist.h33
-rw-r--r--tensorflow/core/framework/device_base.h22
-rw-r--r--tensorflow/core/framework/fake_input.h6
-rw-r--r--tensorflow/core/framework/function.cc34
-rw-r--r--tensorflow/core/framework/function.h31
-rw-r--r--tensorflow/core/framework/function_testlib.cc101
-rw-r--r--tensorflow/core/framework/function_testlib.h15
-rw-r--r--tensorflow/core/framework/graph_def_util.h6
-rw-r--r--tensorflow/core/framework/kernel_def_builder.h6
-rw-r--r--tensorflow/core/framework/log_memory.h6
-rw-r--r--tensorflow/core/framework/lookup_interface.h6
-rw-r--r--tensorflow/core/framework/memory_types.h6
-rw-r--r--tensorflow/core/framework/model.cc420
-rw-r--r--tensorflow/core/framework/model.h408
-rw-r--r--tensorflow/core/framework/node_def_builder.cc17
-rw-r--r--tensorflow/core/framework/node_def_builder.h6
-rw-r--r--tensorflow/core/framework/node_def_util.cc26
-rw-r--r--tensorflow/core/framework/node_def_util.h15
-rw-r--r--tensorflow/core/framework/node_def_util_test.cc42
-rw-r--r--tensorflow/core/framework/numeric_op.h6
-rw-r--r--tensorflow/core/framework/numeric_types.h6
-rw-r--r--tensorflow/core/framework/op.h26
-rw-r--r--tensorflow/core/framework/op_def_builder.cc24
-rw-r--r--tensorflow/core/framework/op_def_builder.h20
-rw-r--r--tensorflow/core/framework/op_def_util.cc9
-rw-r--r--tensorflow/core/framework/op_def_util.h11
-rw-r--r--tensorflow/core/framework/op_gen_lib.cc2
-rw-r--r--tensorflow/core/framework/op_gen_lib.h6
-rw-r--r--tensorflow/core/framework/op_kernel.cc22
-rw-r--r--tensorflow/core/framework/op_kernel.h37
-rw-r--r--tensorflow/core/framework/op_segment.cc8
-rw-r--r--tensorflow/core/framework/op_segment.h4
-rw-r--r--tensorflow/core/framework/queue_interface.h6
-rw-r--r--tensorflow/core/framework/reader_base.h6
-rw-r--r--tensorflow/core/framework/reader_interface.h6
-rw-r--r--tensorflow/core/framework/reader_op_kernel.h6
-rw-r--r--tensorflow/core/framework/register_types.h11
-rw-r--r--tensorflow/core/framework/register_types_traits.h6
-rw-r--r--tensorflow/core/framework/resource_mgr.cc11
-rw-r--r--tensorflow/core/framework/resource_mgr.h133
-rw-r--r--tensorflow/core/framework/resource_op_kernel.h6
-rw-r--r--tensorflow/core/framework/run_handler.cc249
-rw-r--r--tensorflow/core/framework/run_handler.h95
-rw-r--r--tensorflow/core/framework/run_handler_util.cc57
-rw-r--r--tensorflow/core/framework/run_handler_util.h43
-rw-r--r--tensorflow/core/framework/run_handler_util_test.cc93
-rw-r--r--tensorflow/core/framework/selective_registration.h6
-rw-r--r--tensorflow/core/framework/session_state.h6
-rw-r--r--tensorflow/core/framework/shape_inference.cc3
-rw-r--r--tensorflow/core/framework/shape_inference_testutil.h2
-rw-r--r--tensorflow/core/framework/stats_aggregator.h3
-rw-r--r--tensorflow/core/framework/tensor.cc138
-rw-r--r--tensorflow/core/framework/tensor.h26
-rw-r--r--tensorflow/core/framework/tensor_slice.h6
-rw-r--r--tensorflow/core/framework/tensor_test.cc97
-rw-r--r--tensorflow/core/framework/tensor_types.h6
-rw-r--r--tensorflow/core/framework/tensor_util.h7
-rw-r--r--tensorflow/core/framework/tracking_allocator.h6
-rw-r--r--tensorflow/core/framework/type_index.h6
-rw-r--r--tensorflow/core/framework/type_traits.h6
-rw-r--r--tensorflow/core/framework/types.h9
-rw-r--r--tensorflow/core/framework/variant.cc25
-rw-r--r--tensorflow/core/framework/variant.h66
-rw-r--r--tensorflow/core/framework/variant_encode_decode.h38
-rw-r--r--tensorflow/core/framework/variant_op_copy_test.cc6
-rw-r--r--tensorflow/core/framework/variant_op_registry.cc85
-rw-r--r--tensorflow/core/framework/variant_op_registry.h222
-rw-r--r--tensorflow/core/framework/variant_op_registry_test.cc96
-rw-r--r--tensorflow/core/framework/variant_tensor_data.cc22
-rw-r--r--tensorflow/core/framework/variant_tensor_data.h16
-rw-r--r--tensorflow/core/framework/variant_test.cc15
-rw-r--r--tensorflow/core/graph/algorithm.h6
-rw-r--r--tensorflow/core/graph/colors.h6
-rw-r--r--tensorflow/core/graph/control_flow.h6
-rw-r--r--tensorflow/core/graph/costmodel.h6
-rw-r--r--tensorflow/core/graph/default_device.h6
-rw-r--r--tensorflow/core/graph/gradients.cc41
-rw-r--r--tensorflow/core/graph/graph.cc33
-rw-r--r--tensorflow/core/graph/graph.h13
-rw-r--r--tensorflow/core/graph/graph_constructor.cc18
-rw-r--r--tensorflow/core/graph/graph_constructor.h6
-rw-r--r--tensorflow/core/graph/graph_constructor_test.cc14
-rw-r--r--tensorflow/core/graph/graph_def_builder.cc4
-rw-r--r--tensorflow/core/graph/graph_def_builder.h8
-rw-r--r--tensorflow/core/graph/graph_partition.cc2
-rw-r--r--tensorflow/core/graph/graph_partition.h6
-rw-r--r--tensorflow/core/graph/mkl_graph_util.h2
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.cc99
-rw-r--r--tensorflow/core/graph/mkl_layout_pass.h6
-rw-r--r--tensorflow/core/graph/mkl_layout_pass_test.cc24
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass.cc19
-rw-r--r--tensorflow/core/graph/mkl_tfconversion_pass_test.cc4
-rw-r--r--tensorflow/core/graph/node_builder.cc9
-rw-r--r--tensorflow/core/graph/node_builder.h10
-rw-r--r--tensorflow/core/graph/optimizer_cse.h6
-rw-r--r--tensorflow/core/graph/quantize_training.h6
-rw-r--r--tensorflow/core/graph/subgraph.h6
-rw-r--r--tensorflow/core/graph/tensor_id.cc2
-rw-r--r--tensorflow/core/graph/testlib.cc52
-rw-r--r--tensorflow/core/graph/testlib.h17
-rw-r--r--tensorflow/core/graph/types.h6
-rw-r--r--tensorflow/core/graph/while_context.cc2
-rw-r--r--tensorflow/core/graph/while_context.h6
-rw-r--r--tensorflow/core/grappler/clusters/cluster.cc3
-rw-r--r--tensorflow/core/grappler/clusters/single_machine.cc6
-rw-r--r--tensorflow/core/grappler/clusters/utils.cc13
-rw-r--r--tensorflow/core/grappler/clusters/utils.h2
-rw-r--r--tensorflow/core/grappler/clusters/utils_test.cc22
-rw-r--r--tensorflow/core/grappler/clusters/virtual_cluster.cc2
-rw-r--r--tensorflow/core/grappler/costs/analytical_cost_estimator.cc1
-rw-r--r--tensorflow/core/grappler/costs/graph_memory.cc1
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc251
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.h4
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc422
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt7
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator.cc19
-rw-r--r--tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc73
-rw-r--r--tensorflow/core/grappler/costs/utils.cc20
-rw-r--r--tensorflow/core/grappler/costs/utils.h2
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc11
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler_test.cc9
-rw-r--r--tensorflow/core/grappler/graph_analyzer/BUILD139
-rw-r--r--tensorflow/core/grappler/graph_analyzer/gen_node.cc148
-rw-r--r--tensorflow/core/grappler/graph_analyzer/gen_node.h167
-rw-r--r--tensorflow/core/grappler/graph_analyzer/gen_node_test.cc491
-rw-r--r--tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc341
-rw-r--r--tensorflow/core/grappler/graph_analyzer/graph_analyzer.h154
-rw-r--r--tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc569
-rw-r--r--tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc98
-rw-r--r--tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h31
-rw-r--r--tensorflow/core/grappler/graph_analyzer/hash_tools.h47
-rw-r--r--tensorflow/core/grappler/graph_analyzer/hash_tools_test.cc (renamed from tensorflow/core/util/status_util_test.cc)36
-rw-r--r--tensorflow/core/grappler/graph_analyzer/map_tools.h46
-rw-r--r--tensorflow/core/grappler/graph_analyzer/sig_node.cc453
-rw-r--r--tensorflow/core/grappler/graph_analyzer/sig_node.h304
-rw-r--r--tensorflow/core/grappler/graph_analyzer/sig_node_test.cc1235
-rw-r--r--tensorflow/core/grappler/graph_analyzer/subgraph.cc235
-rw-r--r--tensorflow/core/grappler/graph_analyzer/subgraph.h189
-rw-r--r--tensorflow/core/grappler/graph_analyzer/subgraph_test.cc348
-rw-r--r--tensorflow/core/grappler/graph_analyzer/test_tools.cc296
-rw-r--r--tensorflow/core/grappler/graph_analyzer/test_tools.h120
-rw-r--r--tensorflow/core/grappler/graph_view.cc46
-rw-r--r--tensorflow/core/grappler/graph_view.h11
-rw-r--r--tensorflow/core/grappler/graph_view_test.cc97
-rw-r--r--tensorflow/core/grappler/grappler_item.cc1
-rw-r--r--tensorflow/core/grappler/grappler_item.h9
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.cc9
-rw-r--r--tensorflow/core/grappler/grappler_item_builder.h2
-rw-r--r--tensorflow/core/grappler/grappler_item_builder_test.cc23
-rw-r--r--tensorflow/core/grappler/inputs/utils.cc7
-rw-r--r--tensorflow/core/grappler/inputs/utils.h4
-rw-r--r--tensorflow/core/grappler/op_types.cc147
-rw-r--r--tensorflow/core/grappler/op_types.h5
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD148
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc201
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h2
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc131
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc152
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.h4
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc310
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD257
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion.cc136
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion.h47
-rw-r--r--tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc84
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_rename.cc51
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_rename_test.cc42
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils.cc176
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils.h108
-rw-r--r--tensorflow/core/grappler/optimizers/data/function_utils_test.cc164
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils.cc167
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils.h57
-rw-r--r--tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc27
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_test_utils.cc49
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_test_utils.h36
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.cc127
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h66
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils_test.cc106
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc289
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h55
-rw-r--r--tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc84
-rw-r--r--tensorflow/core/grappler/optimizers/data/latency_all_edges.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc12
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc27
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc27
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion.cc42
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_fusion_test.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization.cc103
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization.h47
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc85
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc282
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.h (renamed from tensorflow/core/grappler/optimizers/data/function_rename.h)14
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc211
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination.cc19
-rw-r--r--tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc43
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc6
-rw-r--r--tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/BUILD70
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc49
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc59
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h52
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc47
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h75
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc52
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc612
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.h97
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc939
-rw-r--r--tensorflow/core/grappler/optimizers/debug_stripper.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/debug_stripper_test.cc29
-rw-r--r--tensorflow/core/grappler/optimizers/evaluation_utils.cc1
-rw-r--r--tensorflow/core/grappler/optimizers/evaluation_utils.h1
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc111
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector.h115
-rw-r--r--tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc138
-rw-r--r--tensorflow/core/grappler/optimizers/function_api_info.cc167
-rw-r--r--tensorflow/core/grappler/optimizers/function_api_info.h80
-rw-r--r--tensorflow/core/grappler/optimizers/function_api_info_test.cc160
-rw-r--r--tensorflow/core/grappler/optimizers/function_optimizer.cc19
-rw-r--r--tensorflow/core/grappler/optimizers/memory_optimizer.cc11
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc116
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.h7
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer_test.cc202
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc363
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h62
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc236
-rw-r--r--tensorflow/core/grappler/optimizers/remapper.cc10
-rw-r--r--tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc2
-rw-r--r--tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc1
-rw-r--r--tensorflow/core/grappler/optimizers/shape_optimizer.cc16
-rw-r--r--tensorflow/core/grappler/utils.cc13
-rw-r--r--tensorflow/core/grappler/utils.h108
-rw-r--r--tensorflow/core/grappler/utils/BUILD29
-rw-r--r--tensorflow/core/grappler/utils/functions.cc93
-rw-r--r--tensorflow/core/grappler/utils/functions.h20
-rw-r--r--tensorflow/core/grappler/utils/functions_test.cc31
-rw-r--r--tensorflow/core/grappler/utils/grappler_test.cc9
-rw-r--r--tensorflow/core/grappler/utils/scc.h7
-rw-r--r--tensorflow/core/grappler/utils/symbolic_shapes.cc (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes.cc)2
-rw-r--r--tensorflow/core/grappler/utils/symbolic_shapes.h (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes.h)6
-rw-r--r--tensorflow/core/grappler/utils/symbolic_shapes_test.cc (renamed from tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc)2
-rw-r--r--tensorflow/core/grappler/utils_test.cc55
-rw-r--r--tensorflow/core/kernels/BUILD386
-rw-r--r--tensorflow/core/kernels/adjust_contrast_op.h6
-rw-r--r--tensorflow/core/kernels/adjust_hue_op.h6
-rw-r--r--tensorflow/core/kernels/adjust_saturation_op.h6
-rw-r--r--tensorflow/core/kernels/aggregate_ops.h6
-rw-r--r--tensorflow/core/kernels/aggregate_ops_cpu.h6
-rw-r--r--tensorflow/core/kernels/argmax_op.h6
-rw-r--r--tensorflow/core/kernels/assign_op.h6
-rw-r--r--tensorflow/core/kernels/avgpooling_op.h6
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_complex.cc10
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_impl.h5
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_real.cc9
-rw-r--r--tensorflow/core/kernels/batch_norm_op.h6
-rw-r--r--tensorflow/core/kernels/batching_util/BUILD5
-rw-r--r--tensorflow/core/kernels/betainc_op.h6
-rw-r--r--tensorflow/core/kernels/bias_op.cc13
-rw-r--r--tensorflow/core/kernels/bias_op.h6
-rw-r--r--tensorflow/core/kernels/bincount_op.h6
-rw-r--r--tensorflow/core/kernels/bincount_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/boosted_trees/BUILD16
-rw-r--r--tensorflow/core/kernels/boosted_trees/boosted_trees.proto13
-rw-r--r--tensorflow/core/kernels/boosted_trees/prediction_ops.cc38
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantile_ops.cc453
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/BUILD65
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h96
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h132
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer_test.cc99
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h330
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream_test.cc276
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h344
-rw-r--r--tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary_test.cc223
-rw-r--r--tensorflow/core/kernels/boosted_trees/resources.cc26
-rw-r--r--tensorflow/core/kernels/bounds_check.h6
-rw-r--r--tensorflow/core/kernels/broadcast_to_op.h6
-rw-r--r--tensorflow/core/kernels/bucketize_op.h6
-rw-r--r--tensorflow/core/kernels/candidate_sampler_ops.cc6
-rw-r--r--tensorflow/core/kernels/cast_op.cc8
-rw-r--r--tensorflow/core/kernels/cast_op.h6
-rw-r--r--tensorflow/core/kernels/collective_ops.cc59
-rw-r--r--tensorflow/core/kernels/colorspace_op.h6
-rw-r--r--tensorflow/core/kernels/concat_lib.h6
-rw-r--r--tensorflow/core/kernels/concat_lib_cpu.h5
-rw-r--r--tensorflow/core/kernels/conditional_accumulator.h12
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base.cc13
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base.h9
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_base_op.h9
-rw-r--r--tensorflow/core/kernels/conditional_accumulator_op.cc3
-rw-r--r--tensorflow/core/kernels/constant_op.cc43
-rw-r--r--tensorflow/core/kernels/constant_op.h20
-rw-r--r--tensorflow/core/kernels/control_flow_ops.h6
-rw-r--r--tensorflow/core/kernels/conv_2d.h51
-rw-r--r--tensorflow/core/kernels/conv_3d.h49
-rw-r--r--tensorflow/core/kernels/conv_grad_filter_ops.cc3
-rw-r--r--tensorflow/core/kernels/conv_grad_input_ops.cc6
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.cc13
-rw-r--r--tensorflow/core/kernels/conv_grad_ops.h10
-rw-r--r--tensorflow/core/kernels/conv_grad_ops_3d.cc1330
-rw-r--r--tensorflow/core/kernels/conv_ops.cc340
-rw-r--r--tensorflow/core/kernels/conv_ops.h50
-rw-r--r--tensorflow/core/kernels/conv_ops_3d.cc20
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu.h6
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu_3.cu.cc81
-rw-r--r--tensorflow/core/kernels/cross_op.h6
-rw-r--r--tensorflow/core/kernels/cuda_solvers.h5
-rw-r--r--tensorflow/core/kernels/cudnn_pooling_gpu.h6
-rw-r--r--tensorflow/core/kernels/cwise_op_div.cc3
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_div.cu.cc1
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc (renamed from tensorflow/core/lib/gtl/optional.cc)17
-rw-r--r--tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc26
-rw-r--r--tensorflow/core/kernels/cwise_op_select.cc56
-rw-r--r--tensorflow/core/kernels/cwise_op_xdivy.cc38
-rw-r--r--tensorflow/core/kernels/cwise_op_xlogy.cc41
-rw-r--r--tensorflow/core/kernels/cwise_op_zeta.cc5
-rw-r--r--tensorflow/core/kernels/cwise_ops.h75
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.cc4
-rw-r--r--tensorflow/core/kernels/cwise_ops_common.h6
-rw-r--r--tensorflow/core/kernels/cwise_ops_gpu_common.cu.h6
-rw-r--r--tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h6
-rw-r--r--tensorflow/core/kernels/cwise_ops_gradients.h6
-rw-r--r--tensorflow/core/kernels/data/BUILD85
-rw-r--r--tensorflow/core/kernels/data/batch_dataset_op.cc9
-rw-r--r--tensorflow/core/kernels/data/cache_dataset_ops.cc18
-rw-r--r--tensorflow/core/kernels/data/captured_function.cc225
-rw-r--r--tensorflow/core/kernels/data/captured_function.h34
-rw-r--r--tensorflow/core/kernels/data/concatenate_dataset_op.cc16
-rw-r--r--tensorflow/core/kernels/data/dataset_ops.cc2
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.cc41
-rw-r--r--tensorflow/core/kernels/data/dataset_utils.h14
-rw-r--r--tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data/experimental/BUILD139
-rw-r--r--tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc156
-rw-r--r--tensorflow/core/kernels/data/experimental/csv_dataset_op.cc860
-rw-r--r--tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc281
-rw-r--r--tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc156
-rw-r--r--tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc141
-rw-r--r--tensorflow/core/kernels/data/experimental/indexed_dataset.cc375
-rw-r--r--tensorflow/core/kernels/data/experimental/indexed_dataset.h119
-rw-r--r--tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc218
-rw-r--r--tensorflow/core/kernels/data/experimental/prefetching_kernels.cc482
-rw-r--r--tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc220
-rw-r--r--tensorflow/core/kernels/data/experimental/unique_dataset_op.cc224
-rw-r--r--tensorflow/core/kernels/data/filter_by_component_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data/filter_dataset_op.cc75
-rw-r--r--tensorflow/core/kernels/data/flat_map_dataset_op.cc29
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.cc72
-rw-r--r--tensorflow/core/kernels/data/generator_dataset_op.h3
-rw-r--r--tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc28
-rw-r--r--tensorflow/core/kernels/data/group_by_window_dataset_op.cc76
-rw-r--r--tensorflow/core/kernels/data/interleave_dataset_op.cc31
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.cc278
-rw-r--r--tensorflow/core/kernels/data/iterator_ops.h9
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc160
-rw-r--r--tensorflow/core/kernels/data/map_dataset_op.cc35
-rw-r--r--tensorflow/core/kernels/data/map_defun_op.cc236
-rw-r--r--tensorflow/core/kernels/data/model_dataset_op.cc183
-rw-r--r--tensorflow/core/kernels/data/multi_device_iterator_ops.cc637
-rw-r--r--tensorflow/core/kernels/data/optimize_dataset_op.cc48
-rw-r--r--tensorflow/core/kernels/data/optional_ops.cc24
-rw-r--r--tensorflow/core/kernels/data/optional_ops.h2
-rw-r--r--tensorflow/core/kernels/data/padded_batch_dataset_op.cc9
-rw-r--r--tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc672
-rw-r--r--tensorflow/core/kernels/data/parallel_map_dataset_op.cc66
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc184
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.h12
-rw-r--r--tensorflow/core/kernels/data/parse_example_dataset_op.cc369
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner.cc15
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner.h2
-rw-r--r--tensorflow/core/kernels/data/prefetch_autotuner_test.cc2
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.cc73
-rw-r--r--tensorflow/core/kernels/data/prefetch_dataset_op.h2
-rw-r--r--tensorflow/core/kernels/data/random_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data/range_dataset_op.cc11
-rw-r--r--tensorflow/core/kernels/data/reader_dataset_ops.cc16
-rw-r--r--tensorflow/core/kernels/data/repeat_dataset_op.cc46
-rw-r--r--tensorflow/core/kernels/data/scan_dataset_op.cc37
-rw-r--r--tensorflow/core/kernels/data/shuffle_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data/single_threaded_executor.cc380
-rw-r--r--tensorflow/core/kernels/data/single_threaded_executor.h62
-rw-r--r--tensorflow/core/kernels/data/single_threaded_executor_test.cc332
-rw-r--r--tensorflow/core/kernels/data/skip_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data/slide_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data/sql/driver_manager.cc4
-rw-r--r--tensorflow/core/kernels/data/sql/driver_manager.h4
-rw-r--r--tensorflow/core/kernels/data/sql/query_connection.h3
-rw-r--r--tensorflow/core/kernels/data/sql/sqlite_query_connection.cc4
-rw-r--r--tensorflow/core/kernels/data/sql/sqlite_query_connection.h4
-rw-r--r--tensorflow/core/kernels/data/sql_dataset_ops.cc9
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc108
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_ops.cc13
-rw-r--r--tensorflow/core/kernels/data/stats_dataset_ops.cc208
-rw-r--r--tensorflow/core/kernels/data/take_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data/tensor_dataset_op.cc25
-rw-r--r--tensorflow/core/kernels/data/tensor_queue_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data/tensor_slice_dataset_op.cc19
-rw-r--r--tensorflow/core/kernels/data/unbatch_dataset_op.cc21
-rw-r--r--tensorflow/core/kernels/data/window_dataset.cc16
-rw-r--r--tensorflow/core/kernels/data/window_dataset.h2
-rw-r--r--tensorflow/core/kernels/data/window_dataset_op.cc223
-rw-r--r--tensorflow/core/kernels/data/writer_ops.cc14
-rw-r--r--tensorflow/core/kernels/data/zip_dataset_op.cc8
-rw-r--r--tensorflow/core/kernels/data_format_ops.h6
-rw-r--r--tensorflow/core/kernels/debug_ops.h19
-rw-r--r--tensorflow/core/kernels/decode_bmp_op.cc7
-rw-r--r--tensorflow/core/kernels/decode_csv_op.cc3
-rw-r--r--tensorflow/core/kernels/dense_update_functor.h6
-rw-r--r--tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/dequantize_op.cc2
-rw-r--r--tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc8
-rw-r--r--tensorflow/core/kernels/dynamic_stitch_op.cc4
-rw-r--r--tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h602
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions.h50
-rw-r--r--tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc31
-rw-r--r--tensorflow/core/kernels/eigen_benchmark.h304
-rw-r--r--tensorflow/core/kernels/eigen_benchmark_cpu_test.cc422
-rw-r--r--tensorflow/core/kernels/eigen_cuboid_convolution.h1438
-rw-r--r--tensorflow/core/kernels/eigen_spatial_convolutions.h342
-rw-r--r--tensorflow/core/kernels/eigen_volume_patch.h1
-rw-r--r--tensorflow/core/kernels/example_parsing_ops.cc165
-rw-r--r--tensorflow/core/kernels/extract_image_patches_op.h6
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op.cc197
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op.h58
-rw-r--r--tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc38
-rw-r--r--tensorflow/core/kernels/fake_quant_ops_functor.h6
-rw-r--r--tensorflow/core/kernels/fill_functor.h6
-rw-r--r--tensorflow/core/kernels/fractional_pool_common.h6
-rw-r--r--tensorflow/core/kernels/fused_batch_norm_op.h6
-rw-r--r--tensorflow/core/kernels/fuzzing/BUILD2
-rw-r--r--tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc45
-rw-r--r--tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc32
-rw-r--r--tensorflow/core/kernels/gather_functor.h7
-rw-r--r--tensorflow/core/kernels/gather_nd_op.h6
-rw-r--r--tensorflow/core/kernels/gather_nd_op_cpu_impl.h21
-rw-r--r--tensorflow/core/kernels/gemm_functors.h5
-rw-r--r--tensorflow/core/kernels/gpu_utils.h3
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transfer_utils.h6
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transferer.cc2
-rw-r--r--tensorflow/core/kernels/hexagon/graph_transferer.h2
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc2
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h6
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h2
-rw-r--r--tensorflow/core/kernels/hexagon/soc_interface.h6
-rw-r--r--tensorflow/core/kernels/hinge-loss.h6
-rw-r--r--tensorflow/core/kernels/histogram_op.h6
-rw-r--r--tensorflow/core/kernels/histogram_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/host_constant_op.cc78
-rw-r--r--tensorflow/core/kernels/host_constant_op.h42
-rw-r--r--tensorflow/core/kernels/i_remote_fused_graph_executor.h6
-rw-r--r--tensorflow/core/kernels/identity_n_op.h6
-rw-r--r--tensorflow/core/kernels/identity_op.h6
-rw-r--r--tensorflow/core/kernels/image_resizer_state.h6
-rw-r--r--tensorflow/core/kernels/immutable_constant_op.h6
-rw-r--r--tensorflow/core/kernels/initializable_lookup_table.cc8
-rw-r--r--tensorflow/core/kernels/initializable_lookup_table.h63
-rw-r--r--tensorflow/core/kernels/inplace_ops_functor.h6
-rw-r--r--tensorflow/core/kernels/l2loss_op.h6
-rw-r--r--tensorflow/core/kernels/linalg_ops_common.h6
-rw-r--r--tensorflow/core/kernels/list_kernels.cc24
-rw-r--r--tensorflow/core/kernels/list_kernels.cu.cc18
-rw-r--r--tensorflow/core/kernels/list_kernels.h149
-rw-r--r--tensorflow/core/kernels/logging_ops.cc54
-rw-r--r--tensorflow/core/kernels/logging_ops_test.cc22
-rw-r--r--tensorflow/core/kernels/logistic-loss.h8
-rw-r--r--tensorflow/core/kernels/lookup_table_init_op.cc4
-rw-r--r--tensorflow/core/kernels/lookup_table_init_op.h6
-rw-r--r--tensorflow/core/kernels/lookup_table_op.cc99
-rw-r--r--tensorflow/core/kernels/lookup_table_op.h15
-rw-r--r--tensorflow/core/kernels/lookup_util.h51
-rw-r--r--tensorflow/core/kernels/loss.h6
-rw-r--r--tensorflow/core/kernels/loss_test.cc174
-rw-r--r--tensorflow/core/kernels/map_stage_op.cc10
-rw-r--r--tensorflow/core/kernels/matmul_op.cc8
-rw-r--r--tensorflow/core/kernels/matmul_op.h6
-rw-r--r--tensorflow/core/kernels/matrix_band_part_op.h6
-rw-r--r--tensorflow/core/kernels/matrix_diag_op.h6
-rw-r--r--tensorflow/core/kernels/matrix_exponential_op.cc1
-rw-r--r--tensorflow/core/kernels/matrix_set_diag_op.h6
-rw-r--r--tensorflow/core/kernels/matrix_solve_ls_op_impl.h5
-rw-r--r--tensorflow/core/kernels/maxpooling_op.h6
-rw-r--r--tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc4
-rw-r--r--tensorflow/core/kernels/mirror_pad_op.h7
-rw-r--r--tensorflow/core/kernels/mirror_pad_op_cpu_impl.h6
-rw-r--r--tensorflow/core/kernels/mkl_aggregate_ops.cc20
-rw-r--r--tensorflow/core/kernels/mkl_avgpooling_op.cc51
-rw-r--r--tensorflow/core/kernels/mkl_batch_matmul_op.cc2
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc193
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc174
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc186
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.h414
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops_test.cc407
-rw-r--r--tensorflow/core/kernels/mkl_input_conversion_op.cc17
-rw-r--r--tensorflow/core/kernels/mkl_matmul_op.cc6
-rw-r--r--tensorflow/core/kernels/mkl_maxpooling_op.cc59
-rw-r--r--tensorflow/core/kernels/mkl_pooling_ops_common.cc157
-rw-r--r--tensorflow/core/kernels/mkl_pooling_ops_common.h132
-rw-r--r--tensorflow/core/kernels/mkl_relu_op.cc607
-rw-r--r--tensorflow/core/kernels/mkl_slice_op.cc358
-rw-r--r--tensorflow/core/kernels/mkl_softmax_op.cc38
-rw-r--r--tensorflow/core/kernels/multinomial_op.cc2
-rw-r--r--tensorflow/core/kernels/multinomial_op.h6
-rw-r--r--tensorflow/core/kernels/neon/depthwiseconv_float.h6
-rw-r--r--tensorflow/core/kernels/no_op.h6
-rw-r--r--tensorflow/core/kernels/non_max_suppression_op.cc132
-rw-r--r--tensorflow/core/kernels/nth_element_op.h6
-rw-r--r--tensorflow/core/kernels/one_hot_op.h6
-rw-r--r--tensorflow/core/kernels/ops_testutil.h6
-rw-r--r--tensorflow/core/kernels/ops_util.h6
-rw-r--r--tensorflow/core/kernels/pad_op.h6
-rw-r--r--tensorflow/core/kernels/padding_fifo_queue.h6
-rw-r--r--tensorflow/core/kernels/parameterized_truncated_normal_op.cc31
-rw-r--r--tensorflow/core/kernels/parameterized_truncated_normal_op.h6
-rw-r--r--tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc94
-rw-r--r--tensorflow/core/kernels/poisson-loss.h109
-rw-r--r--tensorflow/core/kernels/pooling_ops_3d.h6
-rw-r--r--tensorflow/core/kernels/pooling_ops_3d_gpu.h6
-rw-r--r--tensorflow/core/kernels/pooling_ops_common.h6
-rw-r--r--tensorflow/core/kernels/priority_queue.h6
-rw-r--r--tensorflow/core/kernels/qr_op_complex128.cc12
-rw-r--r--tensorflow/core/kernels/qr_op_complex64.cc6
-rw-r--r--tensorflow/core/kernels/qr_op_double.cc12
-rw-r--r--tensorflow/core/kernels/qr_op_float.cc12
-rw-r--r--tensorflow/core/kernels/qr_op_impl.h7
-rw-r--r--tensorflow/core/kernels/queue_base.h4
-rw-r--r--tensorflow/core/kernels/queue_ops.cc2
-rw-r--r--tensorflow/core/kernels/random_op.cc10
-rw-r--r--tensorflow/core/kernels/random_op.h6
-rw-r--r--tensorflow/core/kernels/random_poisson_op.h6
-rw-r--r--tensorflow/core/kernels/range_sampler.h6
-rw-r--r--tensorflow/core/kernels/range_sampler_test.cc22
-rw-r--r--tensorflow/core/kernels/record_yielder.h6
-rw-r--r--tensorflow/core/kernels/reduction_gpu_kernels.cu.h17
-rw-r--r--tensorflow/core/kernels/reduction_ops.h6
-rw-r--r--tensorflow/core/kernels/reduction_ops_common.h6
-rw-r--r--tensorflow/core/kernels/reduction_ops_max.cc2
-rw-r--r--tensorflow/core/kernels/regex_full_match_op.cc33
-rw-r--r--tensorflow/core/kernels/regex_replace_op.cc80
-rw-r--r--tensorflow/core/kernels/regex_replace_op_test.cc137
-rw-r--r--tensorflow/core/kernels/relu_op.cc27
-rw-r--r--tensorflow/core/kernels/relu_op.h6
-rw-r--r--tensorflow/core/kernels/relu_op_functor.h6
-rw-r--r--tensorflow/core/kernels/relu_op_gpu.cu.cc35
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_utils.cc26
-rw-r--r--tensorflow/core/kernels/reshape_op.h6
-rw-r--r--tensorflow/core/kernels/resize_bilinear_op.cc34
-rw-r--r--tensorflow/core/kernels/resize_nearest_neighbor_op.cc75
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc84
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.h10
-rw-r--r--tensorflow/core/kernels/reverse_op.h6
-rw-r--r--tensorflow/core/kernels/reverse_sequence_op.cc5
-rw-r--r--tensorflow/core/kernels/reverse_sequence_op.h6
-rw-r--r--tensorflow/core/kernels/save_restore_tensor.cc9
-rw-r--r--tensorflow/core/kernels/save_restore_tensor.h6
-rw-r--r--tensorflow/core/kernels/save_restore_v2_ops.cc4
-rw-r--r--tensorflow/core/kernels/scan_ops.h6
-rw-r--r--tensorflow/core/kernels/scatter_functor.h6
-rw-r--r--tensorflow/core/kernels/scatter_functor_gpu.cu.h6
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc1
-rw-r--r--tensorflow/core/kernels/sdca_internal.cc2
-rw-r--r--tensorflow/core/kernels/sdca_ops.cc3
-rw-r--r--tensorflow/core/kernels/searchsorted_op.cc249
-rw-r--r--tensorflow/core/kernels/searchsorted_op.h52
-rw-r--r--tensorflow/core/kernels/searchsorted_op_gpu.cu.cc126
-rw-r--r--tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h5
-rw-r--r--tensorflow/core/kernels/sendrecv_ops.h6
-rw-r--r--tensorflow/core/kernels/set_kernels.cc14
-rw-r--r--tensorflow/core/kernels/shape_op_test.cc10
-rw-r--r--tensorflow/core/kernels/shape_ops.cc93
-rw-r--r--tensorflow/core/kernels/shape_ops.h14
-rw-r--r--tensorflow/core/kernels/slice_op.cc199
-rw-r--r--tensorflow/core/kernels/slice_op.h6
-rw-r--r--tensorflow/core/kernels/smooth-hinge-loss.h6
-rw-r--r--tensorflow/core/kernels/snapshot_op.h6
-rw-r--r--tensorflow/core/kernels/softmax_op_functor.h6
-rw-r--r--tensorflow/core/kernels/softplus_op.cc11
-rw-r--r--tensorflow/core/kernels/softplus_op.h6
-rw-r--r--tensorflow/core/kernels/softsign_op.cc11
-rw-r--r--tensorflow/core/kernels/softsign_op.h6
-rw-r--r--tensorflow/core/kernels/sparse_conditional_accumulator.h10
-rw-r--r--tensorflow/core/kernels/sparse_conditional_accumulator_op.cc4
-rw-r--r--tensorflow/core/kernels/sparse_matmul_op.h6
-rw-r--r--tensorflow/core/kernels/sparse_softmax_op.cc2
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_add_op.h6
-rw-r--r--tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h6
-rw-r--r--tensorflow/core/kernels/sparse_xent_op.h6
-rw-r--r--tensorflow/core/kernels/split_lib.h6
-rw-r--r--tensorflow/core/kernels/split_lib_gpu.cu.cc1
-rw-r--r--tensorflow/core/kernels/split_op.cc7
-rw-r--r--tensorflow/core/kernels/squared-loss.h6
-rw-r--r--tensorflow/core/kernels/stack_ops.cc26
-rw-r--r--tensorflow/core/kernels/strided_slice_op.cc41
-rw-r--r--tensorflow/core/kernels/strided_slice_op.h6
-rw-r--r--tensorflow/core/kernels/strided_slice_op_impl.h6
-rw-r--r--tensorflow/core/kernels/string_format_op.cc65
-rw-r--r--tensorflow/core/kernels/string_format_op_test.cc66
-rw-r--r--tensorflow/core/kernels/string_length_op.cc62
-rw-r--r--tensorflow/core/kernels/string_split_op.cc111
-rw-r--r--tensorflow/core/kernels/string_split_op_test.cc129
-rw-r--r--tensorflow/core/kernels/string_strip_op.cc2
-rw-r--r--tensorflow/core/kernels/string_util.cc59
-rw-r--r--tensorflow/core/kernels/string_util.h89
-rw-r--r--tensorflow/core/kernels/substr_op.cc184
-rw-r--r--tensorflow/core/kernels/substr_op_test.cc189
-rw-r--r--tensorflow/core/kernels/svd_op_impl.h5
-rw-r--r--tensorflow/core/kernels/tensor_array.cc3
-rw-r--r--tensorflow/core/kernels/tensor_array.h9
-rw-r--r--tensorflow/core/kernels/tensor_array_ops.cc19
-rw-r--r--tensorflow/core/kernels/tile_functor.h6
-rw-r--r--tensorflow/core/kernels/tile_ops_impl.h6
-rw-r--r--tensorflow/core/kernels/topk_op.h6
-rw-r--r--tensorflow/core/kernels/topk_op_gpu.cu.cc6
-rw-r--r--tensorflow/core/kernels/training_op_helpers.cc44
-rw-r--r--tensorflow/core/kernels/training_op_helpers.h43
-rw-r--r--tensorflow/core/kernels/training_ops.cc60
-rw-r--r--tensorflow/core/kernels/training_ops.h6
-rw-r--r--tensorflow/core/kernels/transpose_op.cc10
-rw-r--r--tensorflow/core/kernels/typed_conditional_accumulator_base.h11
-rw-r--r--tensorflow/core/kernels/unicode_script_op.cc53
-rw-r--r--tensorflow/core/kernels/unravel_index_op.cc10
-rw-r--r--tensorflow/core/kernels/variable_ops.h6
-rw-r--r--tensorflow/core/kernels/where_op.h6
-rw-r--r--tensorflow/core/kernels/where_op_gpu.cu.h13
-rw-r--r--tensorflow/core/kernels/whole_file_read_ops.cc2
-rw-r--r--tensorflow/core/kernels/xent_op.h6
-rw-r--r--tensorflow/core/lib/bfloat16/bfloat16.h4
-rw-r--r--tensorflow/core/lib/core/arena.h6
-rw-r--r--tensorflow/core/lib/core/bits.h6
-rw-r--r--tensorflow/core/lib/core/casts.h6
-rw-r--r--tensorflow/core/lib/core/coding.h6
-rw-r--r--tensorflow/core/lib/core/errors.h24
-rw-r--r--tensorflow/core/lib/core/notification.h6
-rw-r--r--tensorflow/core/lib/core/raw_coding.h6
-rw-r--r--tensorflow/core/lib/core/status.cc2
-rw-r--r--tensorflow/core/lib/core/status.h1
-rw-r--r--tensorflow/core/lib/core/status_test_util.h6
-rw-r--r--tensorflow/core/lib/core/stringpiece.cc54
-rw-r--r--tensorflow/core/lib/core/stringpiece.h125
-rw-r--r--tensorflow/core/lib/core/stringpiece_test.cc4
-rw-r--r--tensorflow/core/lib/core/threadpool.cc49
-rw-r--r--tensorflow/core/lib/core/threadpool.h20
-rw-r--r--tensorflow/core/lib/core/threadpool_test.cc61
-rw-r--r--tensorflow/core/lib/gtl/array_slice.h294
-rw-r--r--tensorflow/core/lib/gtl/array_slice_internal.h269
-rw-r--r--tensorflow/core/lib/gtl/array_slice_test.cc666
-rw-r--r--tensorflow/core/lib/gtl/cleanup.h6
-rw-r--r--tensorflow/core/lib/gtl/inlined_vector.h671
-rw-r--r--tensorflow/core/lib/gtl/inlined_vector_test.cc898
-rw-r--r--tensorflow/core/lib/gtl/optional.h859
-rw-r--r--tensorflow/core/lib/gtl/optional_test.cc1098
-rw-r--r--tensorflow/core/lib/gtl/priority_queue_util.h6
-rw-r--r--tensorflow/core/lib/hash/crc32c.h6
-rw-r--r--tensorflow/core/lib/hash/hash.h6
-rw-r--r--tensorflow/core/lib/histogram/histogram.h6
-rw-r--r--tensorflow/core/lib/io/block_builder.h1
-rw-r--r--tensorflow/core/lib/io/buffered_inputstream.h6
-rw-r--r--tensorflow/core/lib/io/inputstream_interface.h4
-rw-r--r--tensorflow/core/lib/io/path.cc6
-rw-r--r--tensorflow/core/lib/io/path.h7
-rw-r--r--tensorflow/core/lib/io/path_test.cc6
-rw-r--r--tensorflow/core/lib/io/proto_encode_helper.h6
-rw-r--r--tensorflow/core/lib/io/random_inputstream.h6
-rw-r--r--tensorflow/core/lib/io/record_reader.cc56
-rw-r--r--tensorflow/core/lib/io/record_reader.h39
-rw-r--r--tensorflow/core/lib/io/record_reader_writer_test.cc7
-rw-r--r--tensorflow/core/lib/io/record_writer.cc15
-rw-r--r--tensorflow/core/lib/io/record_writer.h40
-rw-r--r--tensorflow/core/lib/io/recordio_test.cc2
-rw-r--r--tensorflow/core/lib/io/table.h6
-rw-r--r--tensorflow/core/lib/io/table_builder.h6
-rw-r--r--tensorflow/core/lib/io/table_options.h6
-rw-r--r--tensorflow/core/lib/io/table_test.cc8
-rw-r--r--tensorflow/core/lib/io/zlib_outputbuffer.cc2
-rw-r--r--tensorflow/core/lib/io/zlib_outputbuffer.h2
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_handle.h6
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_mem.cc6
-rw-r--r--tensorflow/core/lib/jpeg/jpeg_mem.h6
-rw-r--r--tensorflow/core/lib/math/math_util.h6
-rw-r--r--tensorflow/core/lib/monitoring/collection_registry.cc8
-rw-r--r--tensorflow/core/lib/monitoring/collection_registry.h5
-rw-r--r--tensorflow/core/lib/monitoring/metric_def.h5
-rw-r--r--tensorflow/core/lib/png/png_io.h1
-rw-r--r--tensorflow/core/lib/random/distribution_sampler.h6
-rw-r--r--tensorflow/core/lib/random/philox_random.h6
-rw-r--r--tensorflow/core/lib/random/random_distributions.h6
-rw-r--r--tensorflow/core/lib/random/simple_philox.h6
-rw-r--r--tensorflow/core/lib/strings/numbers.h10
-rw-r--r--tensorflow/core/lib/strings/str_util.cc5
-rw-r--r--tensorflow/core/lib/strings/str_util.h8
-rw-r--r--tensorflow/core/lib/strings/strcat.h43
-rw-r--r--tensorflow/core/lib/strings/strcat_test.cc8
-rw-r--r--tensorflow/core/lib/strings/stringprintf.h6
-rw-r--r--tensorflow/core/lib/wav/wav_io.cc5
-rw-r--r--tensorflow/core/ops/array_grad.cc21
-rw-r--r--tensorflow/core/ops/array_ops.cc316
-rw-r--r--tensorflow/core/ops/array_ops_test.cc34
-rw-r--r--tensorflow/core/ops/boosted_trees_ops.cc127
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt2703
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops.cc9
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops_test.cc11
-rw-r--r--tensorflow/core/ops/data_flow_ops.cc2
-rw-r--r--tensorflow/core/ops/dataset_ops.cc119
-rw-r--r--tensorflow/core/ops/experimental_dataset_ops.cc207
-rw-r--r--tensorflow/core/ops/functional_ops.cc44
-rw-r--r--tensorflow/core/ops/image_ops.cc92
-rw-r--r--tensorflow/core/ops/linalg_ops.cc2
-rw-r--r--tensorflow/core/ops/list_ops.cc51
-rw-r--r--tensorflow/core/ops/logging_ops.cc26
-rw-r--r--tensorflow/core/ops/lookup_ops.cc143
-rw-r--r--tensorflow/core/ops/math_grad.cc47
-rw-r--r--tensorflow/core/ops/math_grad_test.cc112
-rw-r--r--tensorflow/core/ops/math_ops.cc21
-rw-r--r--tensorflow/core/ops/math_ops_test.cc3
-rw-r--r--tensorflow/core/ops/nn_ops.cc201
-rw-r--r--tensorflow/core/ops/ops.pbtxt1774
-rw-r--r--tensorflow/core/ops/parsing_ops.cc100
-rw-r--r--tensorflow/core/ops/parsing_ops_test.cc89
-rw-r--r--tensorflow/core/ops/resource_variable_ops.cc72
-rw-r--r--tensorflow/core/ops/sdca_ops.cc2
-rw-r--r--tensorflow/core/ops/stateless_random_grad.cc23
-rw-r--r--tensorflow/core/ops/string_ops.cc53
-rw-r--r--tensorflow/core/platform/abi.cc4
-rw-r--r--tensorflow/core/platform/abi.h9
-rw-r--r--tensorflow/core/platform/cloud/auth_provider.h6
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_metadata_client.cc15
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_metadata_client.h10
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc6
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_zone_provider.cc2
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc8
-rw-r--r--tensorflow/core/platform/cloud/curl_http_request.cc4
-rw-r--r--tensorflow/core/platform/cloud/gcs_dns_cache.h6
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc44
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.h7
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system_test.cc1280
-rw-r--r--tensorflow/core/platform/cloud/google_auth_provider.h6
-rw-r--r--tensorflow/core/platform/cloud/google_auth_provider_test.cc20
-rw-r--r--tensorflow/core/platform/cloud/http_request.h6
-rw-r--r--tensorflow/core/platform/cloud/http_request_fake.h6
-rw-r--r--tensorflow/core/platform/cloud/oauth_client.cc4
-rw-r--r--tensorflow/core/platform/cloud/oauth_client_test.cc6
-rw-r--r--tensorflow/core/platform/cloud/retrying_file_system.h69
-rw-r--r--tensorflow/core/platform/cloud/retrying_file_system_test.cc104
-rw-r--r--tensorflow/core/platform/cloud/retrying_utils.cc35
-rw-r--r--tensorflow/core/platform/cloud/retrying_utils.h29
-rw-r--r--tensorflow/core/platform/cloud/retrying_utils_test.cc32
-rw-r--r--tensorflow/core/platform/context.h6
-rw-r--r--tensorflow/core/platform/cord.h (renamed from tensorflow/core/kernels/warn_about_ints.h)21
-rw-r--r--tensorflow/core/platform/cpu_feature_guard.h6
-rw-r--r--tensorflow/core/platform/cpu_info.h6
-rw-r--r--tensorflow/core/platform/default/build_config.bzl1170
-rw-r--r--tensorflow/core/platform/default/build_config_root.bzl86
-rw-r--r--tensorflow/core/platform/default/cord.h21
-rw-r--r--tensorflow/core/platform/default/device_tracer.cc10
-rw-r--r--tensorflow/core/platform/default/integral_types.h6
-rw-r--r--tensorflow/core/platform/default/logging.h6
-rw-r--r--tensorflow/core/platform/default/mutex.h6
-rw-r--r--tensorflow/core/platform/default/protobuf.h2
-rw-r--r--tensorflow/core/platform/default/protobuf_compiler.h (renamed from tensorflow/core/kernels/warn_about_ints.cc)22
-rw-r--r--tensorflow/core/platform/default/thread_annotations.h6
-rw-r--r--tensorflow/core/platform/default/tracing_impl.h6
-rw-r--r--tensorflow/core/platform/denormal.h6
-rw-r--r--tensorflow/core/platform/dynamic_annotations.h6
-rw-r--r--tensorflow/core/platform/env.cc4
-rw-r--r--tensorflow/core/platform/env.h6
-rw-r--r--tensorflow/core/platform/env_test.cc7
-rw-r--r--tensorflow/core/platform/file_system.cc2
-rw-r--r--tensorflow/core/platform/file_system.h11
-rw-r--r--tensorflow/core/platform/file_system_helper.cc2
-rw-r--r--tensorflow/core/platform/file_system_test.cc2
-rw-r--r--tensorflow/core/platform/hadoop/hadoop_file_system.cc8
-rw-r--r--tensorflow/core/platform/host_info.h6
-rw-r--r--tensorflow/core/platform/init_main.h6
-rw-r--r--tensorflow/core/platform/load_library.h6
-rw-r--r--tensorflow/core/platform/logging.h6
-rw-r--r--tensorflow/core/platform/macros.h6
-rw-r--r--tensorflow/core/platform/mem.h6
-rw-r--r--tensorflow/core/platform/mutex.h6
-rw-r--r--tensorflow/core/platform/net.h6
-rw-r--r--tensorflow/core/platform/png.h6
-rw-r--r--tensorflow/core/platform/posix/env.cc11
-rw-r--r--tensorflow/core/platform/posix/error.h2
-rw-r--r--tensorflow/core/platform/posix/port.cc42
-rw-r--r--tensorflow/core/platform/posix/posix_file_system.cc2
-rw-r--r--tensorflow/core/platform/posix/posix_file_system.h2
-rw-r--r--tensorflow/core/platform/posix/subprocess.h6
-rw-r--r--tensorflow/core/platform/prefetch.h6
-rw-r--r--tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h6
-rw-r--r--tensorflow/core/platform/profile_utils/clock_cycle_profiler.h6
-rw-r--r--tensorflow/core/platform/profile_utils/cpu_utils.h6
-rw-r--r--tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h6
-rw-r--r--tensorflow/core/platform/protobuf.h6
-rw-r--r--tensorflow/core/platform/protobuf_compiler.h25
-rw-r--r--tensorflow/core/platform/protobuf_internal.h6
-rw-r--r--tensorflow/core/platform/s3/s3_file_system.cc6
-rw-r--r--tensorflow/core/platform/setround.h6
-rw-r--r--tensorflow/core/platform/snappy.h6
-rw-r--r--tensorflow/core/platform/stacktrace_handler.h6
-rw-r--r--tensorflow/core/platform/subprocess.h6
-rw-r--r--tensorflow/core/platform/test.h6
-rw-r--r--tensorflow/core/platform/test_benchmark.h6
-rw-r--r--tensorflow/core/platform/thread_annotations.h6
-rw-r--r--tensorflow/core/platform/tracing.h13
-rw-r--r--tensorflow/core/platform/types.h6
-rw-r--r--tensorflow/core/platform/windows/cpu_info.h6
-rw-r--r--tensorflow/core/platform/windows/env.cc11
-rw-r--r--tensorflow/core/platform/windows/integral_types.h6
-rw-r--r--tensorflow/core/platform/windows/port.cc51
-rw-r--r--tensorflow/core/platform/windows/subprocess.h6
-rw-r--r--tensorflow/core/platform/windows/windows_file_system.cc2
-rw-r--r--tensorflow/core/platform/windows/windows_file_system.h2
-rw-r--r--tensorflow/core/profiler/BUILD1
-rw-r--r--tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h2
-rw-r--r--tensorflow/core/profiler/internal/advisor/tfprof_advisor.h6
-rw-r--r--tensorflow/core/profiler/internal/tfprof_code.cc4
-rw-r--r--tensorflow/core/profiler/tfprof_options.h6
-rw-r--r--tensorflow/core/protobuf/config.proto16
-rw-r--r--tensorflow/core/protobuf/debug.proto6
-rw-r--r--tensorflow/core/protobuf/replay_log.proto47
-rw-r--r--tensorflow/core/protobuf/rewriter_config.proto8
-rw-r--r--tensorflow/core/public/session.h6
-rw-r--r--tensorflow/core/public/version.h8
-rw-r--r--tensorflow/core/util/activation_mode.h6
-rw-r--r--tensorflow/core/util/bcast.h6
-rw-r--r--tensorflow/core/util/command_line_flags.cc2
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_entry.h2
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_scorer.h2
-rw-r--r--tensorflow/core/util/ctc/ctc_beam_search.h19
-rw-r--r--tensorflow/core/util/ctc/ctc_decoder.h2
-rw-r--r--tensorflow/core/util/ctc/ctc_loss_util.h2
-rw-r--r--tensorflow/core/util/cuda_kernel_helper.h31
-rw-r--r--tensorflow/core/util/device_name_utils.h6
-rw-r--r--tensorflow/core/util/env_var.cc8
-rw-r--r--tensorflow/core/util/env_var.h5
-rw-r--r--tensorflow/core/util/events_writer.h6
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing.cc230
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing.h3
-rw-r--r--tensorflow/core/util/example_proto_fast_parsing_test.cc2
-rw-r--r--tensorflow/core/util/example_proto_helper.cc53
-rw-r--r--tensorflow/core/util/example_proto_helper.h61
-rw-r--r--tensorflow/core/util/guarded_philox_random.h6
-rw-r--r--tensorflow/core/util/mirror_pad_mode.h6
-rw-r--r--tensorflow/core/util/mkl_util.h163
-rw-r--r--tensorflow/core/util/padding.h6
-rw-r--r--tensorflow/core/util/port.cc4
-rw-r--r--tensorflow/core/util/port.h6
-rw-r--r--tensorflow/core/util/saved_tensor_slice_util.h6
-rw-r--r--tensorflow/core/util/sparse/group_iterator.cc10
-rw-r--r--tensorflow/core/util/sparse/group_iterator.h4
-rw-r--r--tensorflow/core/util/sparse/sparse_tensor.h14
-rw-r--r--tensorflow/core/util/strided_slice_op.cc2
-rw-r--r--tensorflow/core/util/tensor_bundle/BUILD5
-rw-r--r--tensorflow/core/util/tensor_bundle/naming.h7
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle.cc68
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle.h6
-rw-r--r--tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc66
-rw-r--r--tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README3
-rw-r--r--tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 bin 0 -> 1080 bytes
-rw-r--r--tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index bin 0 -> 211 bytes
-rw-r--r--tensorflow/core/util/tensor_format.cc4
-rw-r--r--tensorflow/core/util/tensor_format.h1
-rw-r--r--tensorflow/core/util/tensor_slice_reader.h6
-rw-r--r--tensorflow/core/util/tensor_slice_reader_cache.h6
-rw-r--r--tensorflow/core/util/tensor_slice_writer.h6
-rw-r--r--tensorflow/core/util/util.cc16
-rw-r--r--tensorflow/core/util/util.h11
-rw-r--r--tensorflow/core/util/work_sharder.cc2
-rw-r--r--tensorflow/core/util/work_sharder.h9
1158 files changed, 53985 insertions, 16402 deletions
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 084f128674..900a0e11c4 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -85,11 +85,12 @@ load(
"tf_cc_tests",
"tf_copts",
"tf_cuda_library",
+ "tf_features_nomodules_if_android",
"tf_gen_op_libs",
"tf_generate_proto_text_sources",
"tf_genrule_cmd_append_to_srcs",
"tf_opts_nortti_if_android",
- "tf_features_nomodules_if_android",
+ "transitive_hdrs",
)
load("//tensorflow:tensorflow.bzl", "tf_cc_test_mkl")
load("//tensorflow:tensorflow.bzl", "tf_cc_test_gpu")
@@ -120,6 +121,7 @@ load(
"tf_additional_libdevice_srcs",
"tf_additional_minimal_lib_srcs",
"tf_additional_mpi_lib_defines",
+ "tf_additional_proto_compiler_hdrs",
"tf_additional_proto_hdrs",
"tf_additional_proto_srcs",
"tf_additional_test_deps",
@@ -127,6 +129,7 @@ load(
"tf_additional_verbs_lib_defines",
"tf_jspb_proto_library",
"tf_kernel_tests_linkstatic",
+ "tf_lib_proto_compiler_deps",
"tf_lib_proto_parsing_deps",
"tf_nano_proto_library",
"tf_platform_hdrs",
@@ -141,14 +144,17 @@ load(
)
load(
"//tensorflow/core:platform/default/build_config_root.bzl",
+ "if_dynamic_kernels",
"if_static",
"tf_cuda_tests_tags",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda_is_configured")
load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library")
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
+ "mkl_deps",
)
exports_files(["ops/ops.pbtxt"])
@@ -165,6 +171,7 @@ COMMON_PROTO_SRCS = [
"example/example.proto",
"example/feature.proto",
"framework/allocation_description.proto",
+ "framework/api_def.proto",
"framework/attr_value.proto",
"framework/cost_graph.proto",
"framework/device_attributes.proto",
@@ -176,7 +183,6 @@ COMMON_PROTO_SRCS = [
"framework/log_memory.proto",
"framework/node_def.proto",
"framework/op_def.proto",
- "framework/api_def.proto",
"framework/reader_base.proto",
"framework/remote_fused_graph_execute_info.proto",
"framework/resource_handle.proto",
@@ -233,7 +239,6 @@ tf_proto_library(
srcs = [],
cc_api_version = 2,
default_header = True,
- java_api_version = 2,
js_api_version = 2,
protodeps = [
":protos_all_proto",
@@ -266,6 +271,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+java_proto_library(
+ name = "example_java_proto",
+ visibility = ["//visibility:public"],
+ deps = [":example_protos"],
+)
+
closure_proto_library(
name = "example_protos_closure",
visibility = ["//visibility:public"],
@@ -296,6 +307,7 @@ filegroup(
name = "platform_base_hdrs",
srcs = [
"platform/byte_order.h",
+ "platform/cord.h",
"platform/env_time.h",
"platform/logging.h",
"platform/macros.h",
@@ -372,6 +384,7 @@ cc_library(
":lib_platform",
":platform_base",
"//tensorflow/core/platform/default/build_config:port",
+ "@com_google_absl//absl/base",
"@snappy",
],
)
@@ -608,10 +621,22 @@ cc_library(
copts = tf_copts(),
deps = tf_lib_proto_parsing_deps() + [
":platform_base",
+ "@com_google_absl//absl/strings",
"@double_conversion//:double-conversion",
],
)
+cc_library(
+ name = "lib_proto_compiler",
+ hdrs = [
+ "platform/protobuf_compiler.h",
+ ] + tf_additional_proto_compiler_hdrs(),
+ copts = tf_copts(),
+ deps = tf_lib_proto_compiler_deps() + [
+ ":lib_proto_parsing",
+ ],
+)
+
# This build rule (along with :lib_internal, :framework, and
# :framework_internal) purposefully omits the definitions of many declared
# symbols, which are included in //tensorflow:libtensorflow_framework.so. Using
@@ -654,8 +679,11 @@ cc_library(
"lib/io/table_builder.h",
"lib/io/table_options.h",
"lib/math/math_util.h",
+ "lib/monitoring/collected_metrics.h",
+ "lib/monitoring/collection_registry.h",
"lib/monitoring/counter.h",
"lib/monitoring/gauge.h",
+ "lib/monitoring/metric_def.h",
"lib/monitoring/sampler.h",
"lib/random/distribution_sampler.h",
"lib/random/philox_random.h",
@@ -676,6 +704,21 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":lib_internal",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:optional",
+ ],
+)
+
+cc_library(
+ name = "feature_util",
+ srcs = ["example/feature_util.cc"],
+ hdrs = ["example/feature_util.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":core_stringpiece",
+ ":lib_proto_parsing",
+ ":protos_all_cc",
],
)
@@ -683,6 +726,7 @@ cc_library(
name = "abi",
srcs = ["platform/abi.cc"],
hdrs = ["platform/abi.h"],
+ deps = [":platform_base"],
)
cc_library(
@@ -712,10 +756,12 @@ cc_library(
# required to use tf_cc_test, and that rule will change / into _
cc_library(
name = "core_stringpiece",
- srcs = ["lib/core/stringpiece.cc"],
hdrs = ["lib/core/stringpiece.h"],
copts = tf_copts(),
- deps = [":platform_base"],
+ deps = [
+ ":platform_base",
+ "@com_google_absl//absl/strings",
+ ],
)
# Test support library needed for all tests
@@ -735,7 +781,10 @@ cc_library(
"util/reporter.h",
],
copts = tf_copts(),
- linkopts = ["-lm"],
+ linkopts = select({
+ "//tensorflow:windows": [],
+ "//conditions:default": ["-lm"],
+ }),
visibility = ["//visibility:public"],
deps = [
":lib",
@@ -832,7 +881,6 @@ tf_cuda_library(
"util/bcast.h",
"util/cuda_kernel_helper.h",
"util/device_name_utils.h",
- "util/env_var.h",
"util/events_writer.h",
"util/example_proto_fast_parsing.h",
"util/example_proto_helper.h",
@@ -847,7 +895,6 @@ tf_cuda_library(
"util/sparse/sparse_tensor.h",
"util/stat_summarizer.h",
"util/stat_summarizer_options.h",
- "util/status_util.h",
"util/stream_executor_util.h",
"util/strided_slice_op.h",
"util/tensor_format.h",
@@ -860,7 +907,6 @@ tf_cuda_library(
"util/work_sharder.h",
] + select({
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": [
"util/memmapped_file_system.h",
"util/memmapped_file_system_writer.h",
@@ -915,15 +961,6 @@ cc_library(
)
cc_library(
- name = "status_util",
- hdrs = ["util/status_util.h"],
- deps = [
- ":graph",
- ":lib",
- ],
-)
-
-cc_library(
name = "reader_base",
srcs = ["framework/reader_base.cc"],
hdrs = ["framework/reader_base.h"],
@@ -1007,6 +1044,7 @@ tf_gen_op_libs(
"dataset_ops",
"decode_proto_ops",
"encode_proto_ops",
+ "experimental_dataset_ops",
"function_ops",
"functional_ops",
"image_ops",
@@ -1023,7 +1061,6 @@ tf_gen_op_libs(
"random_grad",
"random_ops",
"remote_fused_graph_ops",
- "resource_variable_ops",
"rpc_ops",
"scoped_allocator_ops",
"sdca_ops",
@@ -1034,7 +1071,6 @@ tf_gen_op_libs(
"spectral_ops",
"state_ops",
"stateless_random_ops",
- "string_ops",
"summary_ops",
"training_ops",
],
@@ -1042,6 +1078,13 @@ tf_gen_op_libs(
tf_gen_op_libs(
op_lib_names = [
+ "string_ops",
+ ],
+ deps = ["@com_google_absl//absl/strings"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = [
"array_ops",
],
deps = [":protos_all_cc"],
@@ -1059,6 +1102,14 @@ tf_gen_op_libs(
deps = ["//tensorflow/core/kernels:debug_ops"],
)
+tf_gen_op_libs(
+ is_external = False,
+ op_lib_names = [
+ "resource_variable_ops",
+ ],
+ deps = [":lib"],
+)
+
# And one for all user ops
cc_library(
name = "user_ops_op_lib",
@@ -1124,6 +1175,7 @@ cc_library(
":dataset_ops_op_lib",
":decode_proto_ops_op_lib",
":encode_proto_ops_op_lib",
+ ":experimental_dataset_ops_op_lib",
":function_ops_op_lib",
":functional_ops_op_lib",
":image_ops_op_lib",
@@ -1190,6 +1242,7 @@ cc_library(
srcs = [
"ops/math_grad.cc",
"ops/random_grad.cc",
+ "ops/stateless_random_grad.cc",
],
linkstatic = 1, # Needed since alwayslink is broken in bazel b/27630669
visibility = ["//visibility:public"],
@@ -1253,8 +1306,8 @@ cc_library(
# This includes implementations of all kernels built into TensorFlow.
cc_library(
- name = "all_kernels",
- visibility = ["//visibility:public"],
+ name = "all_kernels_statically_linked",
+ visibility = ["//visibility:private"],
deps = [
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:audio",
@@ -1297,6 +1350,7 @@ cc_library(
"//tensorflow/core/kernels:rpc_op",
"//tensorflow/core/kernels:scoped_allocator_ops",
"//tensorflow/core/kernels:sdca_ops",
+ "//tensorflow/core/kernels:searchsorted_op",
"//tensorflow/core/kernels:set_kernels",
"//tensorflow/core/kernels:sparse",
"//tensorflow/core/kernels:state",
@@ -1322,7 +1376,9 @@ cc_library(
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
+ "//tensorflow/core/kernels:mkl_slice_op",
"//tensorflow/core/kernels:mkl_softmax_op",
+ "//tensorflow/core/kernels:mkl_transpose_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
"//tensorflow/core/kernels:mkl_aggregate_ops",
]) + if_cuda([
@@ -1331,6 +1387,15 @@ cc_library(
]),
)
+cc_library(
+ name = "all_kernels",
+ visibility = ["//visibility:public"],
+ deps = if_dynamic_kernels(
+ [],
+ otherwise = [":all_kernels_statically_linked"],
+ ),
+)
+
tf_cuda_library(
name = "tensorflow_opensource",
copts = tf_copts(),
@@ -1394,9 +1459,11 @@ cc_library(
":test",
":testlib_ops",
"//tensorflow/cc:scope",
+ "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/kernels:constant_op",
"//tensorflow/core/kernels:ops_testutil",
"//tensorflow/core/kernels:ops_util",
+ "//tensorflow/core/kernels:random_ops",
],
)
@@ -1556,6 +1623,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ ":mobile_additional_lib_deps",
":protos_all_cc_impl",
":stats_calculator_portable",
"//third_party/eigen3",
@@ -1566,6 +1634,13 @@ cc_library(
alwayslink = 1,
)
+cc_library(
+ name = "mobile_additional_lib_deps",
+ deps = tf_additional_lib_deps() + [
+ "@com_google_absl//absl/strings",
+ ],
+)
+
# Native library support for iOS applications.
#
# bazel build --config=ios_x86_64 \
@@ -1597,6 +1672,7 @@ cc_library(
copts = tf_copts() + ["-Os"] + ["-std=c++11"],
visibility = ["//visibility:public"],
deps = [
+ ":mobile_additional_lib_deps",
":protos_all_cc_impl",
":stats_calculator_portable",
"//third_party/eigen3",
@@ -1877,6 +1953,13 @@ tf_pyclif_proto_library(
)
tf_pyclif_proto_library(
+ name = "protobuf/config_pyclif",
+ proto_lib = ":protos_all_cc",
+ proto_srcfile = "protobuf/config.proto",
+ visibility = ["//visibility:public"],
+)
+
+tf_pyclif_proto_library(
name = "protobuf/device_properties_pyclif",
proto_lib = ":protos_all_cc",
proto_srcfile = "protobuf/device_properties.proto",
@@ -1993,9 +2076,6 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [
"lib/io/zlib_compression_options.h",
"lib/io/zlib_inputstream.h",
"lib/io/zlib_outputbuffer.h",
- "lib/monitoring/collected_metrics.h",
- "lib/monitoring/collection_registry.h",
- "lib/monitoring/metric_def.h",
"lib/monitoring/mobile_counter.h",
"lib/monitoring/mobile_gauge.h",
"lib/monitoring/mobile_sampler.h",
@@ -2018,6 +2098,7 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [
"platform/snappy.h",
"platform/tensor_coding.h",
"platform/tracing.h",
+ "util/env_var.h",
]
# Replicated for lib_internal and lib_internal_impl.
@@ -2036,7 +2117,6 @@ cc_library(
linkopts = select({
"//tensorflow:freebsd": [],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//tensorflow:android": [],
"//conditions:default": [
"-ldl",
@@ -2044,7 +2124,9 @@ cc_library(
],
}),
deps = tf_additional_lib_deps() + [
+ "@com_google_absl//absl/strings",
"//third_party/eigen3",
+ "@com_google_absl//absl/base:core_headers",
"//tensorflow/core/platform/default/build_config:platformlib",
] + if_static([":lib_internal_impl"]),
)
@@ -2057,11 +2139,11 @@ cc_library(
"platform/*.cc",
"platform/profile_utils/**/*.cc",
"framework/resource_handle.cc",
+ "util/env_var.cc",
],
exclude = [
"**/*test*",
"framework/variant.cc",
- "lib/core/stringpiece.cc",
"lib/hash/crc32c_accelerate.cc",
"lib/gif/**/*",
"lib/jpeg/**/*",
@@ -2076,7 +2158,6 @@ cc_library(
) + tf_additional_lib_srcs(
exclude = [
"**/*test*",
- "lib/core/stringpiece.cc",
"platform/**/cuda.h",
"platform/**/cuda_libdevice_path.cc",
"platform/**/stream_executor.h",
@@ -2126,7 +2207,6 @@ cc_library(
linkopts = select({
"//tensorflow:freebsd": [],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": ["-ldl"],
}),
deps = [
@@ -2151,7 +2231,6 @@ cc_library(
linkopts = select({
"//tensorflow:freebsd": [],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": ["-ldl"],
}),
deps = [
@@ -2183,13 +2262,13 @@ cc_library(
linkopts = select({
"//tensorflow:freebsd": [],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": ["-ldl"],
}),
deps = [
":lib",
":lib_internal",
"//tensorflow/core/platform/default/build_config:png",
+ "@com_google_absl//absl/strings",
"@zlib_archive//:zlib",
],
)
@@ -2205,7 +2284,7 @@ cc_library(
"platform/macros.h",
"platform/platform.h",
"platform/types.h",
- ],
+ ] + if_windows(["platform/windows/integral_types.h"]),
copts = tf_copts(),
linkopts = ["-ldl"],
deps = [
@@ -2240,6 +2319,8 @@ cc_library(
deps = [
"//tensorflow/core/platform/default/build_config:jpeg",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/strings",
],
)
@@ -2248,6 +2329,8 @@ cc_library(
srcs = if_android([
"lib/gif/gif_io.cc",
"platform/gif.h",
+ "lib/strings/strcat.h",
+ "lib/strings/numbers.h",
]),
hdrs = [
"lib/bfloat16/bfloat16.h",
@@ -2269,6 +2352,8 @@ cc_library(
deps = [
"//tensorflow/core/platform/default/build_config:gif",
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/base:core_headers",
+ "@com_google_absl//absl/strings",
],
)
@@ -2296,6 +2381,7 @@ cc_library(
linkopts = ["-ldl"],
deps = [
"//tensorflow/core/platform/default/build_config:logging",
+ "@com_google_absl//absl/strings",
"@png_archive//:png",
],
)
@@ -2305,7 +2391,6 @@ tf_proto_library(
srcs = ERROR_CODES_PROTO_SRCS,
cc_api_version = 2,
default_header = True,
- java_api_version = 2,
js_api_version = 2,
provide_cc_alias = True,
)
@@ -2326,7 +2411,6 @@ tf_proto_library(
srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS,
cc_api_version = 2,
default_header = True,
- java_api_version = 2,
js_api_version = 2,
protodeps = [
":error_codes_proto",
@@ -2338,6 +2422,7 @@ tf_generate_proto_text_sources(
srcs = COMMON_PROTO_SRCS,
protodeps = ERROR_CODES_PROTO_SRCS,
srcs_relative_dir = "tensorflow/core/",
+ visibility = ["//visibility:public"],
deps = [
":error_codes_proto_text",
":lib_internal",
@@ -2405,12 +2490,13 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"framework/op_segment.h",
"framework/rendezvous.h", # only needed for tests
"framework/resource_var.h",
+ "framework/run_handler.h",
+ "framework/run_handler_util.h",
"framework/tensor_reference.h",
"framework/tracking_allocator.h", # only needed for tests
"framework/unique_tensor_references.h",
"framework/variant.h",
"util/command_line_flags.h",
- "util/env_var.h",
"util/equal_graph_def.h",
"util/presized_cuckoo_map.h",
"util/tensor_slice_set.h",
@@ -2439,6 +2525,12 @@ tf_cuda_library(
cc_header_only_library(
name = "framework_internal_headers_lib",
+ # Fully depend on external repositories, because identifying the headers
+ # is fragile.
+ extra_deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
deps = [
":lib",
":lib_internal",
@@ -2450,6 +2542,7 @@ cc_header_only_library(
cc_header_only_library(
name = "core_cpu_headers_lib",
+ visibility = ["//visibility:public"],
deps = [
":core_cpu_lib",
],
@@ -2475,6 +2568,7 @@ tf_cuda_library(
"**/*test*",
"**/*main.cc",
"example/example_parser_configuration.*",
+ "example/feature_util.cc",
"util/reporter.cc",
"framework/fake_input.*",
"framework/op_gen_lib.*",
@@ -2484,10 +2578,10 @@ tf_cuda_library(
"util/memmapped_file_system_writer.*",
"util/stats_calculator.*",
"util/version_info.cc",
+ "util/env_var.cc",
],
) + select({
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
"//conditions:default": [
"util/memmapped_file_system.cc",
"util/memmapped_file_system_writer.cc",
@@ -2496,14 +2590,15 @@ tf_cuda_library(
hdrs = FRAMEWORK_INTERNAL_PUBLIC_HEADERS,
copts = tf_copts(),
linkopts = select({
- "//tensorflow:freebsd": [],
+ "//tensorflow:freebsd": ["-lm"],
"//tensorflow:windows": [],
- "//tensorflow:windows_msvc": [],
- "//conditions:default": ["-ldl"],
- }) + [
- "-lm",
- ],
+ "//conditions:default": [
+ "-ldl",
+ "-lm",
+ ],
+ }),
deps = [
+ ":feature_util",
":lib",
":lib_internal",
":protos_all_proto_text",
@@ -2517,17 +2612,18 @@ tf_cuda_library(
] + if_static(
extra_deps = ["@protobuf_archive//:protobuf"],
otherwise = ["@protobuf_archive//:protobuf_headers"],
- ) + if_mkl(
- [
- "//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
- ),
+ ) + mkl_deps(),
alwayslink = 1,
)
cc_header_only_library(
name = "framework_headers_lib",
+ # Fully depend on external repositories, because identifying the headers
+ # is fragile.
+ extra_deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
visibility = ["//visibility:public"],
deps = [
":framework",
@@ -2537,6 +2633,12 @@ cc_header_only_library(
cc_header_only_library(
name = "stream_executor_headers_lib",
+ # Fully depend on external repositories, because identifying the headers
+ # is fragile.
+ extra_deps = [
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/types:span",
+ ],
visibility = ["//visibility:public"],
deps = [
":stream_executor",
@@ -2579,6 +2681,7 @@ tf_cuda_library(
# TODO(josh11b): Is this needed, or can we just use ":protos_all_cc"?
cc_library(
name = "protos_cc",
+ visibility = ["//visibility:public"],
deps = ["//tensorflow/core/platform/default/build_config:protos_cc"],
)
@@ -2688,12 +2791,13 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/allocator_retry.h",
"common_runtime/base_collective_executor.h",
"common_runtime/bfc_allocator.h",
- "common_runtime/broadcaster.h",
+ "common_runtime/hierarchical_tree_broadcaster.h",
"common_runtime/buf_rendezvous.h",
"common_runtime/build_graph_options.h",
"common_runtime/collective_executor_mgr.h",
"common_runtime/collective_param_resolver_local.h",
"common_runtime/collective_rma_local.h",
+ "common_runtime/collective_util.h",
"common_runtime/constant_folding.h",
"common_runtime/copy_tensor.h",
"common_runtime/costmodel_manager.h",
@@ -2706,6 +2810,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/graph_optimizer.h",
"common_runtime/local_device.h",
"common_runtime/lower_if_op.h",
+ "common_runtime/lower_while_op.h",
"common_runtime/memory_types.h",
"common_runtime/mkl_cpu_allocator.h",
"common_runtime/optimization_registry.h",
@@ -2724,7 +2829,6 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/stats_publisher_interface.h",
"common_runtime/step_stats_collector.h",
"common_runtime/threadpool_device.h",
- "common_runtime/visitable_allocator.h",
"common_runtime/process_state.h",
"common_runtime/pool_allocator.h",
"graph/gradients.h",
@@ -2738,12 +2842,12 @@ tf_cuda_library(
"common_runtime/allocator_retry.cc",
"common_runtime/base_collective_executor.cc",
"common_runtime/bfc_allocator.cc",
- "common_runtime/broadcaster.cc",
"common_runtime/buf_rendezvous.cc",
"common_runtime/build_graph_options.cc",
"common_runtime/collective_executor_mgr.cc",
"common_runtime/collective_param_resolver_local.cc",
"common_runtime/collective_rma_local.cc",
+ "common_runtime/collective_util.cc",
"common_runtime/constant_folding.cc",
"common_runtime/copy_tensor.cc",
"common_runtime/costmodel_manager.cc",
@@ -2758,8 +2862,10 @@ tf_cuda_library(
"common_runtime/function.cc",
"common_runtime/graph_optimizer.cc",
"common_runtime/graph_runner.cc",
+ "common_runtime/hierarchical_tree_broadcaster.cc",
"common_runtime/local_device.cc",
"common_runtime/lower_if_op.cc",
+ "common_runtime/lower_while_op.cc",
"common_runtime/memory_types.cc",
"common_runtime/mkl_cpu_allocator.cc",
"common_runtime/optimization_registry.cc",
@@ -2803,12 +2909,7 @@ tf_cuda_library(
":protos_all_cc",
"//third_party/eigen3",
"//tensorflow/core/grappler:grappler_item",
- ] + if_mkl(
- [
- "//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
- ),
+ ] + mkl_deps(),
alwayslink = 1,
)
@@ -2848,12 +2949,7 @@ tf_cuda_library(
"//tensorflow/core/grappler/optimizers:meta_optimizer",
"//third_party/eigen3",
"//tensorflow/core/kernels:required",
- ] + if_mkl(
- [
- "//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ],
- ) + tf_additional_core_deps() + if_static([":core_cpu_impl"]),
+ ] + mkl_deps() + tf_additional_core_deps() + if_static([":core_cpu_impl"]),
alwayslink = 1,
)
@@ -2882,6 +2978,7 @@ tf_cuda_library(
":core_cpu_internal",
":device_tracer",
":framework",
+ ":framework_internal",
":graph",
":lib",
":lib_internal",
@@ -2919,7 +3016,7 @@ tf_cuda_library(
"platform/device_tracer.h",
],
copts = tf_copts(),
- cuda_deps = tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps(),
+ cuda_deps = if_cuda_is_configured(tf_additional_cupti_wrapper_deps() + tf_additional_device_tracer_cuda_deps()),
visibility = ["//visibility:private"],
deps = [
":core_cpu_internal",
@@ -2928,12 +3025,16 @@ tf_cuda_library(
] + tf_additional_device_tracer_deps(),
)
-cc_library(
- name = "session_ref",
- srcs = ["common_runtime/session_ref.cc"],
- hdrs = ["common_runtime/session_ref.h"],
- copts = tf_copts(),
- deps = [":core_cpu_base"],
+tf_proto_library_cc(
+ name = "replay_log_proto",
+ srcs = ["protobuf/replay_log.proto"],
+ cc_api_version = 2,
+ protodeps = [
+ ":master_proto",
+ ] + tf_additional_all_protos(),
+ visibility = [
+ "//tensorflow:internal",
+ ],
)
cc_library(
@@ -3146,7 +3247,10 @@ cc_library(
testonly = 1,
srcs = ["platform/test_main.cc"],
copts = tf_copts(),
- linkopts = ["-lm"],
+ linkopts = select({
+ "//tensorflow:windows": [],
+ "//conditions:default": ["-lm"],
+ }),
visibility = ["//tensorflow:internal"],
deps = [
":lib",
@@ -3191,18 +3295,15 @@ tf_cc_tests(
"lib/core/status_test.cc",
"lib/core/stringpiece_test.cc",
"lib/core/threadpool_test.cc",
- "lib/gtl/array_slice_test.cc",
"lib/gtl/cleanup_test.cc",
"lib/gtl/compactptrset_test.cc",
"lib/gtl/edit_distance_test.cc",
"lib/gtl/flatmap_test.cc",
"lib/gtl/flatset_test.cc",
- "lib/gtl/inlined_vector_test.cc",
"lib/gtl/int_type_test.cc",
"lib/gtl/iterator_range_test.cc",
"lib/gtl/manual_constructor_test.cc",
"lib/gtl/map_util_test.cc",
- "lib/gtl/optional_test.cc",
"lib/gtl/top_n_test.cc",
"lib/hash/crc32c_test.cc",
"lib/hash/hash_test.cc",
@@ -3528,7 +3629,6 @@ tf_cc_tests(
"util/semver_test.cc",
"util/sparse/sparse_tensor_test.cc",
"util/stat_summarizer_test.cc",
- "util/status_util_test.cc",
"util/tensor_format_test.cc",
"util/tensor_slice_reader_test.cc",
"util/tensor_slice_set_test.cc",
@@ -3553,7 +3653,6 @@ tf_cc_tests(
":ops",
":protos_all_cc",
":protos_test_cc",
- ":status_util",
":test",
":test_main",
":testlib",
@@ -3651,10 +3750,10 @@ tf_cc_tests_gpu(
)
tf_cc_tests_gpu(
- name = "broadcaster_test",
- size = "small",
+ name = "hierarchical_tree_broadcaster_test",
+ size = "medium",
srcs = [
- "common_runtime/broadcaster_test.cc",
+ "common_runtime/hierarchical_tree_broadcaster_test.cc",
],
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags(),
@@ -3692,6 +3791,7 @@ tf_cc_test_mkl(
":core_cpu_internal",
":framework",
":framework_internal",
+ ":lib",
":test",
":test_main",
":testlib",
@@ -3738,6 +3838,7 @@ tf_cc_test_mkl(
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
+ "//tensorflow/core/kernels:mkl_slice_op",
"//tensorflow/core/kernels:mkl_softmax_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
]),
@@ -3857,11 +3958,7 @@ tf_cuda_only_cc_test(
":test",
":test_main",
"//third_party/eigen3",
- ] + if_mkl(
- [
- "//third_party/mkl:intel_binary_blob",
- ],
- ),
+ ] + mkl_deps(),
)
tf_cc_test_gpu(
@@ -4029,6 +4126,19 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "framework_run_handler_util_test",
+ size = "small",
+ srcs = ["framework/run_handler_util_test.cc"],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":framework_internal",
+ ":lib",
+ ":test",
+ ":test_main",
+ ],
+)
+
tf_cuda_cc_test(
name = "common_runtime_direct_session_test",
size = "small",
@@ -4050,6 +4160,7 @@ tf_cuda_cc_test(
":testlib",
"//third_party/eigen3",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/core/kernels:collective_ops",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:dense_update_ops",
@@ -4091,6 +4202,7 @@ tf_cc_test(
"//tensorflow/cc:cc_ops",
# Link with support for TensorFlow Debugger (tfdbg).
"//tensorflow/core/debug",
+ "//tensorflow/core/kernels:collective_ops",
"//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:dense_update_ops",
@@ -4575,6 +4687,29 @@ tf_cc_tests(
],
)
+tf_cc_tests(
+ name = "common_runtime_lower_while_op_test",
+ size = "small",
+ srcs = ["common_runtime/lower_while_op_test.cc"],
+ deps = [
+ ":all_kernels",
+ ":core_cpu",
+ ":core_cpu_internal",
+ ":direct_session",
+ ":framework",
+ ":framework_internal",
+ ":lib",
+ ":test",
+ ":test_main",
+ ":testlib",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/cc:client_session",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:ops",
+ ],
+)
+
# Test data
filegroup(
name = "image_testdata",
@@ -4650,6 +4785,18 @@ cc_library(
] + tf_additional_libdevice_deps(),
)
+transitive_hdrs(
+ name = "headers",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:stream_executor",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets go here (must be at the end).
diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc
index ae03a61ae6..51812caeb2 100644
--- a/tensorflow/core/api_def/api_test.cc
+++ b/tensorflow/core/api_def/api_test.cc
@@ -59,8 +59,8 @@ void GetGoldenApiDefs(Env* env, const string& api_files_dir,
file_contents = PBTxtFromMultiline(file_contents);
ApiDefs api_defs;
- CHECK(tensorflow::protobuf::TextFormat::ParseFromString(file_contents,
- &api_defs))
+ QCHECK(tensorflow::protobuf::TextFormat::ParseFromString(file_contents,
+ &api_defs))
<< "Failed to load " << file_path;
CHECK_EQ(api_defs.op_size(), 1);
(*name_to_api_def)[api_defs.op(0).graph_op_name()] = api_defs.op(0);
diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt
index b90f5473c8..6341eeda32 100644
--- a/tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ApplyAdam.pbtxt
@@ -82,7 +82,7 @@ END
}
summary: "Update \'*var\' according to the Adam algorithm."
description: <<END
-$$lr_t := \text{learning_rate} * \sqrt{(1 - beta_2^t) / (1 - beta_1^t)}$$
+$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
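For reference, a minimal NumPy sketch of the corrected update rule documented above; the function and variable names (adam_step, var, m, v, g, t) are illustrative only and are not part of the TensorFlow API.

    import numpy as np

    def adam_step(var, m, v, g, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        # lr_t := learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
        lr_t = lr * np.sqrt(1.0 - beta2**t) / (1.0 - beta1**t)
        m = beta1 * m + (1.0 - beta1) * g          # m_t
        v = beta2 * v + (1.0 - beta2) * g * g      # v_t
        var = var - lr_t * m / (np.sqrt(v) + eps)  # variable update
        return var, m, v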
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt
new file mode 100644
index 0000000000..cdaeb5091c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesBucketize.pbtxt
@@ -0,0 +1,34 @@
+op {
+ graph_op_name: "BoostedTreesBucketize"
+ visibility: HIDDEN
+ in_arg {
+ name: "float_values"
+ description: <<END
+float; List of Rank 2 Tensor each containing float values for a single feature.
+END
+ }
+ in_arg {
+ name: "bucket_boundaries"
+ description: <<END
+float; List of Rank 1 Tensors each containing the bucket boundaries for a single
+feature.
+END
+ }
+ out_arg {
+ name: "buckets"
+ description: <<END
+int; List of Rank 2 Tensors each containing the bucketized values for a single feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+inferred int; number of features.
+END
+ }
+ summary: "Bucketize each feature based on bucket boundaries."
+ description: <<END
+An op that returns a list of float tensors, where each tensor represents the
+bucketized values for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt
new file mode 100644
index 0000000000..20da1295f6
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesCreateQuantileStreamResource.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "BoostedTreesCreateQuantileStreamResource"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource; Handle to quantile stream resource.
+END
+ }
+ in_arg {
+ name: "epsilon"
+ description: <<END
+float; The required approximation error of the stream resource.
+END
+ }
+ in_arg {
+ name: "num_streams"
+ description: <<END
+int; The number of streams managed by the resource that shares the same epsilon.
+END
+ }
+ attr {
+ name: "max_elements"
+ description : <<END
+int; The maximum number of data points that can be fed to the stream.
+END
+ }
+ summary: "Create the Resource for Quantile Streams."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt
new file mode 100644
index 0000000000..ca111af312
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesMakeQuantileSummaries.pbtxt
@@ -0,0 +1,40 @@
+op {
+ graph_op_name: "BoostedTreesMakeQuantileSummaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "float_values"
+ description: <<END
+float; List of Rank 2 Tensors each containing values for a single feature.
+END
+ }
+ in_arg {
+ name: "example_weights"
+ description: <<END
+float; Rank 1 Tensor with weights per instance.
+END
+ }
+ in_arg {
+ name: "epsilon"
+ description: <<END
+float; The required maximum approximation error.
+END
+ }
+ out_arg {
+ name: "summaries"
+ description: <<END
+float; List of Rank 2 Tensors each containing the quantile summary (value, weight,
+min_rank, max_rank) of a single feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+int; Inferred from the size of float_values.
+The number of float features.
+END
+ }
+ summary: "Makes the summary of quantiles for the batch."
+ description: <<END
+An op that takes a list of tensors and outputs the quantile summaries for each tensor.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
new file mode 100644
index 0000000000..bbeecbf32b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceAddSummaries.pbtxt
@@ -0,0 +1,22 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceAddSummaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ in_arg {
+ name: "summaries"
+ description: <<END
+string; List of Rank 2 Tensor each containing the summaries for a single feature.
+END
+ }
+ summary: "Add the quantile summaries to each quantile stream resource."
+ description: <<END
+An op that adds a list of quantile summaries to a quantile stream resource. Each
+summary Tensor is rank 2, containing summaries (value, weight, min_rank, max_rank)
+for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt
new file mode 100644
index 0000000000..2fd94efa10
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceFlush.pbtxt
@@ -0,0 +1,31 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceFlush"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ in_arg {
+ name: "num_buckets",
+ description: <<END
+int; approximate number of buckets unless using generate_quantiles.
+END
+ }
+ attr {
+ name: "generate_quantiles"
+ description: <<END
+bool; If True, the output will be the num_quantiles for each stream where the ith
+entry is the ith quantile of the input with an approximation error of epsilon.
+Duplicate values may be present.
+If False, the output will be the points in the histogram that we got which roughly
+translates to 1/epsilon boundaries and without any duplicates.
+Default to False.
+END
+ }
+ summary: "Flush the summaries for a quantile stream resource."
+ description: <<END
+An op that flushes the summaries for a quantile stream resource.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
new file mode 100644
index 0000000000..206672802f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceGetBucketBoundaries.pbtxt
@@ -0,0 +1,27 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource handle referring to a QuantileStreamResource.
+END
+ }
+ out_arg {
+ name: "bucket_boundaries"
+ description: <<END
+float; List of Rank 1 Tensors each containing the bucket boundaries for a feature.
+END
+ }
+ attr {
+ name: "num_features"
+ description: <<END
+inferred int; number of features to get bucket boundaries for.
+END
+ }
+ summary: "Generate the bucket boundaries for each feature based on accumulated summaries."
+ description: <<END
+An op that returns a list of float tensors for a quantile stream resource. Each
+tensor is Rank 1 containing bucket boundaries for a single feature.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt
new file mode 100644
index 0000000000..cb7786c051
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_BoostedTreesQuantileStreamResourceHandleOp.pbtxt
@@ -0,0 +1,5 @@
+op {
+ graph_op_name: "BoostedTreesQuantileStreamResourceHandleOp"
+ visibility: HIDDEN
+ summary: "Creates a handle to a BoostedTreesQuantileStreamResource."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
index e39213cbc7..440800704e 100644
--- a/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DecodeCSV.pbtxt
@@ -11,7 +11,8 @@ END
name: "record_defaults"
description: <<END
One tensor per column of the input record, with either a
-scalar default value for that column or empty if the column is required.
+scalar default value for that column or an empty vector if the column is
+required.
END
}
out_arg {
diff --git a/tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt b/tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt
new file mode 100644
index 0000000000..5604a1a89e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_DivNoNan.pbtxt
@@ -0,0 +1,9 @@
+op {
+ graph_op_name: "DivNoNan"
+ summary: "Returns 0 if the denominator is zero."
+ description: <<END
+
+*NOTE*: `DivNoNan` supports broadcasting. More about broadcasting
+[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+END
+}
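A small NumPy sketch of the semantics summarized above (element-wise x / y, returning 0 wherever the denominator is 0, with broadcasting); it mirrors the documented behaviour but is not the TensorFlow kernel.

    import numpy as np

    def div_no_nan(x, y):
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        x, y = np.broadcast_arrays(x, y)
        safe_y = np.where(y == 0.0, 1.0, y)        # avoid dividing by zero
        return np.where(y == 0.0, 0.0, x / safe_y)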
diff --git a/tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt b/tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt
new file mode 100644
index 0000000000..1658472209
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_EnsureShape.pbtxt
@@ -0,0 +1,26 @@
+op {
+ graph_op_name: "EnsureShape"
+ in_arg {
+ name: "input"
+ description: <<END
+A tensor, whose shape is to be validated.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A tensor with the same shape and contents as the input tensor or value.
+END
+ }
+ attr {
+ name: "shape"
+ description: <<END
+The expected (possibly partially specified) shape of the input tensor.
+END
+ }
+ summary: "Ensures that the tensor's shape matches the expected shape."
+ description: <<END
+Raises an error if the input tensor's shape does not match the specified shape.
+Returns the input tensor otherwise.
+END
+}
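A brief usage sketch of the behaviour described above, assuming the op is exposed in Python as tf.ensure_shape; that endpoint name is an assumption here, since only the base api_def is part of this diff.

    import tensorflow as tf

    x = tf.placeholder(tf.float32)
    # Fails at run time if the fed value is not rank 2 with 3 columns.
    y = tf.ensure_shape(x, (None, 3))

    with tf.Session() as sess:
        sess.run(y, feed_dict={x: [[1.0, 2.0, 3.0]]})  # passes the shape check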
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt
new file mode 100644
index 0000000000..fa8fc96bb2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalAssertNextDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt
new file mode 100644
index 0000000000..5fd88e7a0c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalCSVDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt
new file mode 100644
index 0000000000..ac1f9719fe
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt
@@ -0,0 +1,21 @@
+op {
+ graph_op_name: "ExperimentalDirectedInterleaveDataset"
+ in_arg {
+ name: "selector_input_dataset"
+ description: <<END
+A dataset of scalar `DT_INT64` elements that determines which of the
+`N` data inputs should produce the next output element.
+END
+ }
+ in_arg {
+ name: "data_input_datasets"
+ description: <<END
+`N` datasets with the same type that will be interleaved according to
+the values of `selector_input_dataset`.
+END
+ }
+ summary: <<END
+A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt
new file mode 100644
index 0000000000..66511eff60
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt
@@ -0,0 +1,58 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResource"
+ in_arg {
+ name: "string_arg"
+ description: <<END
+String argument to the function call.
+END
+ }
+ in_arg {
+ name: "target_device"
+ description: <<END
+Target device to execute the function on.
+END
+ }
+ out_arg {
+ name: "resource"
+ description: <<END
+Handle to the resource created.
+END
+ }
+ attr {
+ name: "shared_name"
+ description: <<END
+If non-empty, this resource will be shared under the given name across
+multiple sessions.
+END
+ }
+ attr {
+ name: "container"
+ description: <<END
+If non-empty, this resource is placed in the given container.
+Otherwise, a default container is used.
+END
+ }
+ attr {
+ name: "f"
+ description: <<END
+Function to be executed.
+END
+ }
+ attr {
+ name: "buffer_size"
+ description: <<END
+Size of the buffer.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ summary: <<END
+Creates a resource that fills up a buffer by making function calls.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt
new file mode 100644
index 0000000000..bf4b66b22b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt
@@ -0,0 +1,25 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResourceGetNext"
+ in_arg {
+ name: "function_buffer_resource"
+ description: <<END
+The FunctionBufferingResource handle.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A list of return values.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ summary: <<END
+Gets the next element from a FunctionBufferingResource.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt
new file mode 100644
index 0000000000..729718ddb3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResourceReset"
+ in_arg {
+ name: "function_buffer_resource"
+ description: <<END
+The FunctionBufferingResource handle.
+END
+ }
+ summary: <<END
+Resets the FunctionBufferingResource.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt
new file mode 100644
index 0000000000..fe266c111f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIdentityIndexedDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt
new file mode 100644
index 0000000000..d42546516d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalIgnoreErrorsDataset"
+ summary: <<END
+Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt
new file mode 100644
index 0000000000..e285f87e10
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIndexedDatasetGet"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt
new file mode 100644
index 0000000000..60c32473b5
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIndexedDatasetMaterialize"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt
new file mode 100644
index 0000000000..b72b229e9a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalIteratorGetDevice"
+ summary: <<END
+Returns the name of the device on which `resource` has been placed.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt
new file mode 100644
index 0000000000..b38b23a51d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalLMDBDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt
new file mode 100644
index 0000000000..9676b9d284
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalMaterializedIndexDatasetHandle"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt
new file mode 100644
index 0000000000..d73b5bfda3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalThreadPoolDataset"
+ in_arg {
+ name: "thread_pool"
+ description: <<END
+A resource produced by the ThreadPoolHandle op.
+END
+ }
+ summary: <<END
+Creates a dataset that uses a custom thread pool to compute `input_dataset`.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt
new file mode 100644
index 0000000000..48bf93406c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt
@@ -0,0 +1,35 @@
+op {
+ graph_op_name: "ExperimentalThreadPoolHandle"
+ out_arg {
+ name: "handle"
+ description: <<END
+A resource that can be consumed by one or more ExperimentalThreadPoolDataset
+ops.
+END
+ }
+ attr {
+ name: "num_threads"
+ description: <<END
+The number of threads in the thread pool.
+END
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ description: <<END
+The maximum degree of parallelism to use within operations that execute on this
+threadpool.
+END
+ }
+ attr {
+ name: "display_name"
+ description: <<END
+A human-readable name for the threads that may be visible in some
+visualizations.
+END
+ }
+ summary: <<END
+Creates a dataset that uses a custom thread pool to compute `input_dataset`.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt
new file mode 100644
index 0000000000..68ed797a0c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalUniqueDataset"
+ summary: <<END
+Creates a dataset that contains the unique elements of `input_dataset`.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt
new file mode 100644
index 0000000000..3c8a455983
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExtractVolumePatches.pbtxt
@@ -0,0 +1,49 @@
+op {
+ graph_op_name: "ExtractVolumePatches"
+ in_arg {
+ name: "input"
+ description: <<END
+5-D Tensor with shape `[batch, in_planes, in_rows, in_cols, depth]`.
+END
+ }
+ out_arg {
+ name: "patches"
+ description: <<END
+5-D Tensor with shape `[batch, out_planes, out_rows, out_cols,
+ksize_planes * ksize_rows * ksize_cols * depth]` containing patches
+with size `ksize_planes x ksize_rows x ksize_cols x depth` vectorized
+in the "depth" dimension. Note `out_planes`, `out_rows` and `out_cols`
+are the dimensions of the output patches.
+END
+ }
+ attr {
+ name: "ksizes"
+ description: <<END
+The size of the sliding window for each dimension of `input`.
+END
+ }
+ attr {
+ name: "strides"
+ description: <<END
+1-D of length 5. How far the centers of two consecutive patches are in
+`input`. Must be: `[1, stride_planes, stride_rows, stride_cols, 1]`.
+END
+ }
+ attr {
+ name: "padding"
+ description: <<END
+The type of padding algorithm to use.
+
+We specify the size-related attributes as:
+
+```python
+ ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1]
+  strides = [1, stride_planes, stride_rows, stride_cols, 1]
+```
+END
+ }
+ summary: <<END
+Extract `patches` from `input` and put them in the "depth" output
+dimension. 3D extension of `extract_image_patches`.
+END
+}
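
A quick sketch of the patch-extraction shapes described above. The Python endpoint name `tf.extract_volume_patches` is an assumption (the diff only defines the op's base API), so treat this as an illustration of the `ksizes`/`strides`/`padding` attributes rather than a definitive usage example.

```python
import tensorflow as tf

# Hypothetical usage; assumes the op is exposed as tf.extract_volume_patches.
x = tf.reshape(tf.range(128.0), [2, 4, 4, 4, 1])   # [batch, planes, rows, cols, depth]
patches = tf.extract_volume_patches(
    x,
    ksizes=[1, 2, 2, 2, 1],
    strides=[1, 2, 2, 2, 1],
    padding="VALID")
# With VALID padding: out_planes = out_rows = out_cols = 2, and the last
# dimension is ksize_planes * ksize_rows * ksize_cols * depth = 8,
# so patches.shape == [2, 2, 2, 2, 8].
```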
diff --git a/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt
deleted file mode 100644
index ffd01ba5cc..0000000000
--- a/tensorflow/core/api_def/base_api/api_def_FeatureStatsDataset.pbtxt
+++ /dev/null
@@ -1,3 +0,0 @@
-op {
- graph_op_name: "FeatureStatsDataset"
-}
diff --git a/tensorflow/core/api_def/base_api/api_def_Fill.pbtxt b/tensorflow/core/api_def/base_api/api_def_Fill.pbtxt
index 58262a385c..37d1a9dcbf 100644
--- a/tensorflow/core/api_def/base_api/api_def_Fill.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Fill.pbtxt
@@ -27,5 +27,15 @@ For example:
fill([2, 3], 9) ==> [[9, 9, 9]
[9, 9, 9]]
```
+
+`tf.fill` differs from `tf.constant` in a few ways:
+
+* `tf.fill` only supports scalar contents, whereas `tf.constant` supports
+ Tensor values.
+* `tf.fill` creates an Op in the computation graph that constructs the actual
+ Tensor value at runtime. This is in contrast to `tf.constant` which embeds
+ the entire Tensor into the graph with a `Const` node.
+* Because `tf.fill` evaluates at graph runtime, it supports dynamic shapes
+ based on other runtime Tensors, unlike `tf.constant`.
END
}
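
A minimal sketch of the `tf.fill` vs `tf.constant` distinction documented above, assuming eager execution:

```python
import tensorflow as tf

const = tf.constant(9, shape=[2, 3])   # value embedded in the graph as a Const node
filled = tf.fill([2, 3], 9)            # value constructed by an op at runtime

# Because tf.fill evaluates at runtime, its dims may come from another tensor:
n = tf.shape(tf.zeros([4, 5]))[0]      # a runtime scalar (4)
dynamic = tf.fill([n, 2], 1.0)         # shape [4, 2]; tf.constant cannot do this
```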
diff --git a/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt
index a0e42dd02c..9f3f9b276b 100644
--- a/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_GatherNd.pbtxt
@@ -123,5 +123,7 @@ Batched indexing into a 3-tensor:
[['a1', 'b1'], ['c1', 'd1']]]
output = [['b0', 'b1'], ['d0', 'c1']]
```
+
+See also `tf.gather` and `tf.batch_gather`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt
index 162ef2b033..c6104da4a6 100644
--- a/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_GatherV2.pbtxt
@@ -54,5 +54,7 @@ params.shape[axis + 1:]` where:
Note that on CPU, if an out of bound index is found, an error is returned.
On GPU, if an out of bound index is found, a 0 is stored in the
corresponding output value.
+
+See also `tf.batch_gather` and `tf.gather_nd`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_HostConst.pbtxt b/tensorflow/core/api_def/base_api/api_def_HostConst.pbtxt
new file mode 100644
index 0000000000..9d04a01f6f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_HostConst.pbtxt
@@ -0,0 +1,11 @@
+op {
+ graph_op_name: "HostConst"
+ attr {
+ name: "value"
+ description: <<END
+Attr `value` is the tensor to return.
+END
+ }
+ visibility: SKIP
+ summary: "Returns a constant tensor on the host. Only for writing C++ tests."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
index 40d7d371ca..7142a0e3f2 100644
--- a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
@@ -9,7 +9,7 @@ The lower regularized incomplete Gamma function is defined as:
where
-\\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\)
+\\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\)
is the lower incomplete Gamma function.
diff --git a/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
new file mode 100644
index 0000000000..758eeb96f0
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_IsBoostedTreesQuantileStreamResourceInitialized.pbtxt
@@ -0,0 +1,20 @@
+op {
+ graph_op_name: "IsBoostedTreesQuantileStreamResourceInitialized"
+ visibility: HIDDEN
+ in_arg {
+ name: "quantile_stream_resource_handle"
+ description: <<END
+resource; The reference to quantile stream resource handle.
+END
+ }
+ out_arg {
+ name: "is_initialized"
+ description: <<END
+bool; True if the resource is initialized, False otherwise.
+END
+ }
+ summary: "Checks whether a quantile stream has been initialized."
+ description: <<END
+An Op that checks if quantile stream resource is initialized.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt
new file mode 100644
index 0000000000..5ce825ae04
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_LowerBound.pbtxt
@@ -0,0 +1,45 @@
+op {
+ graph_op_name: "LowerBound"
+ visibility: HIDDEN
+ in_arg {
+ name: "sorted_inputs"
+ description: <<END
+2-D Tensor where each row is ordered.
+END
+ }
+ in_arg {
+ name: "values"
+ description: <<END
+2-D Tensor with the same number of rows as `sorted_inputs`. Contains the
+values that will be searched for in `sorted_inputs`.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A `Tensor` with the same shape as `values`. It contains the first scalar index
+into the last dimension where values can be inserted without changing the
+ordered property.
+END
+ }
+  summary: "Applies lower_bound(sorted_inputs, values) along each row."
+ description: <<END
+Each set of rows with the same index in (sorted_inputs, values) is treated
+independently. The resulting row is the equivalent of calling
+`np.searchsorted(sorted_inputs, values, side='left')`.
+
+The result is not a global index to the entire
+`Tensor`, but rather just the index in the last dimension.
+
+A 2-D example:
+ sorted_sequence = [[0, 3, 9, 9, 10],
+ [1, 2, 3, 4, 5]]
+ values = [[2, 4, 9],
+ [0, 2, 6]]
+
+ result = LowerBound(sorted_sequence, values)
+
+ result == [[1, 2, 2],
+ [0, 1, 5]]
+END
+}
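
Since the description defines each row as `np.searchsorted(..., side='left')`, the 2-D example can be reproduced directly in NumPy (semantics illustration only):

```python
import numpy as np

sorted_sequence = np.array([[0, 3, 9, 9, 10],
                            [1, 2, 3, 4, 5]])
values = np.array([[2, 4, 9],
                   [0, 2, 6]])

# Row-wise lower bound, matching the example above.
result = np.stack([np.searchsorted(s, v, side='left')
                   for s, v in zip(sorted_sequence, values)])
print(result)  # [[1 2 2]
               #  [0 1 5]]
```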
diff --git a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
index 4433693759..d158f4b502 100644
--- a/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_MapDefun.pbtxt
@@ -4,16 +4,23 @@ op {
in_arg {
name: "arguments"
description: <<END
- A list of tensors whose types are Targuments, corresponding to the inputs the
- function should be mapped over.
+ A list of tensors whose types are `Targuments`, corresponding to the inputs
+ the function should be mapped over.
+END
+ }
+ in_arg {
+ name: "captured_inputs"
+ description: <<END
+ A list of tensors whose types are `Tcaptured`, corresponding to the captured
+ inputs of the defun.
END
}
out_arg {
name: "output"
description: <<END
- A list of output tensors whose types are output_types and whose dimensions 0
- are the same as the dimensions 0 of the tensors in arguments, and whose
- remaining dimensions correspond to those in output_shapes.
+ A list of output tensors whose types are `output_types` and whose dimensions
+ 0 are the same as the dimensions 0 of the tensors in `arguments`, and whose
+ remaining dimensions correspond to those in `output_shapes`.
END
}
attr {
@@ -21,6 +28,10 @@ END
description: "A list of types."
}
attr {
+ name: "Tcaptured"
+ description: "A list of types."
+ }
+ attr {
name: "output_types"
description: "A list of types."
}
@@ -29,6 +40,6 @@ END
description: "A list of shapes."
}
summary: <<END
- Maps a function on the list of tensors unpacked from inputs on dimension 0.
+ Maps a function on the list of tensors unpacked from arguments on dimension 0.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt b/tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt
index d7b56aec87..46da1de1c3 100644
--- a/tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_MatrixExponential.pbtxt
@@ -1,32 +1,5 @@
op {
graph_op_name: "MatrixExponential"
- in_arg {
- name: "input"
- description: <<END
-Shape is `[..., M, M]`.
-END
- }
- out_arg {
- name: "output"
- description: <<END
-Shape is `[..., M, M]`.
-
-@compatibility(scipy)
-Equivalent to scipy.linalg.expm
-@end_compatibility
-END
- }
- summary: "Computes the matrix exponential of one or more square matrices:"
- description: <<END
-\\(exp(A) = \sum_{n=0}^\infty A^n/n!\\)
-
-The exponential is computed using a combination of the scaling and squaring
-method and the Pade approximation. Details can be founds in:
-Nicholas J. Higham, "The scaling and squaring method for the matrix exponential
-revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.
-
-The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-form square matrices. The output is a tensor of the same shape as the input
-containing the exponential for all input submatrices `[..., :, :]`.
-END
+ visibility: SKIP
+ summary: "Deprecated, use python implementation tf.linalg.matrix_exponential."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt
new file mode 100644
index 0000000000..171add16d4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ModelDataset.pbtxt
@@ -0,0 +1,14 @@
+op {
+ graph_op_name: "ModelDataset"
+ visibility: HIDDEN
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ summary: "Identity transformation that models performance."
+ description: <<END
+Identity transformation that models performance.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt
new file mode 100644
index 0000000000..4b0a5d8f65
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIterator.pbtxt
@@ -0,0 +1,43 @@
+op {
+ graph_op_name: "MultiDeviceIterator"
+ out_arg {
+ name: "handle"
+ description: <<END
+Handle to the resource created.
+END
+ }
+ attr {
+ name: "devices"
+ description: <<END
+A list of devices the iterator works across.
+END
+ }
+ attr {
+ name: "shared_name"
+ description: <<END
+If non-empty, this resource will be shared under the given name
+across multiple sessions.
+END
+ }
+ attr {
+ name: "container"
+ description: <<END
+If non-empty, this resource is placed in the given container.
+Otherwise, a default container is used.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Creates a MultiDeviceIterator resource."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt
new file mode 100644
index 0000000000..adaacd8ab7
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorFromStringHandle.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "MultiDeviceIteratorFromStringHandle"
+ in_arg {
+ name: "string_handle"
+ description: <<END
+String representing the resource.
+END
+ }
+ out_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIterator resource.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Generates a MultiDeviceIterator resource from its provided string handle."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt
new file mode 100644
index 0000000000..f9be9188cc
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorGetNextFromShard.pbtxt
@@ -0,0 +1,41 @@
+op {
+ graph_op_name: "MultiDeviceIteratorGetNextFromShard"
+ in_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIterator resource.
+END
+ }
+ in_arg {
+ name: "shard_num"
+ description: <<END
+Integer representing which shard to fetch data for.
+END
+ }
+ in_arg {
+ name: "incarnation_id"
+ description: <<END
+Which incarnation of the MultiDeviceIterator is running.
+END
+ }
+ out_arg {
+ name: "components"
+ description: <<END
+Result of the get_next on the dataset.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Gets next element for the provided shard number."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt
new file mode 100644
index 0000000000..6b54fa1307
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorInit.pbtxt
@@ -0,0 +1,30 @@
+op {
+ graph_op_name: "MultiDeviceIteratorInit"
+ in_arg {
+ name: "dataset"
+ description: <<END
+Dataset to be iterated upon.
+END
+ }
+ in_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIteratorResource.
+END
+ }
+ in_arg {
+ name: "max_buffer_size"
+ description: <<END
+The maximum size of the host-side per-device buffer to keep.
+END
+ }
+ out_arg {
+ name: "incarnation_id"
+ description: <<END
+An int64 indicating which incarnation of the MultiDeviceIterator
+is running.
+END
+ }
+ summary: "Initializes the multi device iterator with the given dataset."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt
new file mode 100644
index 0000000000..1f1fdf99b4
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_MultiDeviceIteratorToStringHandle.pbtxt
@@ -0,0 +1,17 @@
+op {
+ graph_op_name: "MultiDeviceIteratorToStringHandle"
+ in_arg {
+ name: "multi_device_iterator"
+ description: <<END
+A MultiDeviceIterator resource.
+END
+ }
+ out_arg {
+ name: "string_handle"
+ description: <<END
+A string representing the resource.
+END
+ }
+ summary: "Produces a string handle for the given MultiDeviceIterator."
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt
new file mode 100644
index 0000000000..27bc4013c3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ParallelInterleaveDatasetV2.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ParallelInterleaveDatasetV2"
+ visibility: HIDDEN
+ attr {
+ name: "f"
+ description: <<END
+A function mapping elements of `input_dataset`, concatenated with
+`other_arguments`, to a Dataset variant that contains elements matching
+`output_types` and `output_shapes`.
+END
+ }
+ summary: "Creates a dataset that applies `f` to the outputs of `input_dataset`."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ParseExampleDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParseExampleDataset.pbtxt
new file mode 100644
index 0000000000..3de2f18fc2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ParseExampleDataset.pbtxt
@@ -0,0 +1,69 @@
+op {
+ graph_op_name: "ParseExampleDataset"
+ in_arg {
+ name: "dense_defaults"
+ description: <<END
+A dict mapping string keys to `Tensor`s.
+The keys of the dict must match the dense_keys of the feature.
+END
+ }
+ attr {
+ name: "sparse_keys"
+ description: <<END
+A list of string keys in the examples features.
+The results for these keys will be returned as `SparseTensor` objects.
+END
+ }
+ attr {
+ name: "dense_keys"
+ description: <<END
+A list of Ndense string Tensors (scalars).
+The keys expected in the Examples features associated with dense values.
+END
+ }
+ attr {
+ name: "sparse_types"
+ description: <<END
+A list of `DTypes` of the same length as `sparse_keys`.
+Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+and `tf.string` (`BytesList`) are supported.
+END
+ }
+ attr {
+ name: "Tdense"
+ description: <<END
+A list of DTypes of the same length as `dense_keys`.
+Only `tf.float32` (`FloatList`), `tf.int64` (`Int64List`),
+and `tf.string` (`BytesList`) are supported.
+
+END
+ }
+ attr {
+ name: "dense_shapes"
+ description: <<END
+List of tuples with the same length as `dense_keys`.
+The shape of the data for each dense feature referenced by `dense_keys`.
+Required for any input tensors identified by `dense_keys`. Must be
+either fully defined, or may contain an unknown first dimension.
+An unknown first dimension means the feature is treated as having
+a variable number of blocks, and the output shape along this dimension
+is considered unknown at graph build time. Padding is applied for
+minibatch elements smaller than the maximum number of blocks for the
+given feature along this dimension.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ attr {
+ name: "output_shapes"
+ description: <<END
+The list of shapes being produced.
+END
+ }
+ summary: "Transforms `input_dataset` containing `Example` protos as vectors of DT_STRING into a dataset of `Tensor` or `SparseTensor` objects representing the parsed features."
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ParseSequenceExample.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParseSequenceExample.pbtxt
new file mode 100644
index 0000000000..b1cb9a696d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ParseSequenceExample.pbtxt
@@ -0,0 +1,112 @@
+op {
+ graph_op_name: "ParseSequenceExample"
+ in_arg {
+ name: "serialized"
+ description: <<END
+A vector containing binary serialized SequenceExample protos.
+END
+ }
+ in_arg {
+ name: "debug_name"
+ description: <<END
+A vector containing the names of the serialized protos.
+May contain, for example, table key (descriptive) name for the
+corresponding serialized proto. This is purely useful for debugging
+purposes, and the presence of values here has no effect on the output.
+May also be an empty vector if no name is available.
+END
+ }
+ in_arg {
+ name: "context_dense_defaults"
+ description: <<END
+A list of Ncontext_dense Tensors (some may be empty).
+context_dense_defaults[j] provides default values
+when the SequenceExample's context map lacks context_dense_key[j].
+If an empty Tensor is provided for context_dense_defaults[j],
+then the Feature context_dense_keys[j] is required.
+The input type is inferred from context_dense_defaults[j], even when it's
+empty. If context_dense_defaults[j] is not empty, its shape must match
+context_dense_shapes[j].
+END
+ }
+ attr {
+ name: "feature_list_dense_missing_assumed_empty"
+ description: <<END
+A vector listing the
+FeatureList keys which may be missing from the SequenceExamples. If the
+associated FeatureList is missing, it is treated as empty. By default,
+any FeatureList not listed in this vector must exist in the SequenceExamples.
+END
+ }
+ attr {
+ name: "context_sparse_keys"
+ description: <<END
+A list of Ncontext_sparse string Tensors (scalars).
+The keys expected in the Examples' features associated with context_sparse
+values.
+END
+ }
+ attr {
+ name: "context_dense_keys"
+ description: <<END
+A list of Ncontext_dense string Tensors (scalars).
+The keys expected in the SequenceExamples' context features associated with
+dense values.
+END
+ }
+ attr {
+ name: "feature_list_sparse_keys"
+ description: <<END
+A list of Nfeature_list_sparse string Tensors
+(scalars). The keys expected in the FeatureLists associated with sparse
+values.
+END
+ }
+ attr {
+ name: "feature_list_dense_keys"
+ description: <<END
+A list of Nfeature_list_dense string Tensors (scalars).
+The keys expected in the SequenceExamples' feature_lists associated
+with lists of dense values.
+END
+ }
+ attr {
+ name: "context_sparse_types"
+ description: <<END
+A list of Ncontext_sparse types; the data types of data in
+each context Feature given in context_sparse_keys.
+Currently the ParseSequenceExample supports DT_FLOAT (FloatList),
+DT_INT64 (Int64List), and DT_STRING (BytesList).
+END
+ }
+ attr {
+ name: "context_dense_shapes"
+ description: <<END
+A list of Ncontext_dense shapes; the shapes of data in
+each context Feature given in context_dense_keys.
+The number of elements in the Feature corresponding to context_dense_key[j]
+must always equal context_dense_shapes[j].NumEntries().
+The shape of context_dense_values[j] will match context_dense_shapes[j].
+END
+ }
+ attr {
+ name: "feature_list_sparse_types"
+ description: <<END
+A list of Nfeature_list_sparse types; the data types
+of data in each FeatureList given in feature_list_sparse_keys.
+Currently the ParseSequenceExample supports DT_FLOAT (FloatList),
+DT_INT64 (Int64List), and DT_STRING (BytesList).
+END
+ }
+ attr {
+ name: "feature_list_dense_shapes"
+ description: <<END
+A list of Nfeature_list_dense shapes; the shapes of
+data in each FeatureList given in feature_list_dense_keys.
+The shape of each Feature in the FeatureList corresponding to
+feature_list_dense_key[j] must always equal
+feature_list_dense_shapes[j].NumEntries().
+END
+ }
+ summary: "Transforms a vector of brain.SequenceExample protos (as strings) into typed tensors."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000000..4cb8955dcb
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,19 @@
+op {
+ graph_op_name: "PrintV2"
+ in_arg {
+ name: "input"
+ description: <<END
+The string scalar to print.
+END
+ }
+ attr {
+ name: "output_stream"
+ description: <<END
+A string specifying the output stream or logging level to print to.
+END
+ }
+ summary: "Prints a string scalar."
+ description: <<END
+Prints a string scalar to the desired output_stream.
+END
+}
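
PrintV2 consumes a single pre-formatted string scalar; the usual way to reach it is through a Python wrapper that does the formatting. Assuming that wrapper is `tf.print` (an assumption, not stated in this diff), a minimal sketch:

```python
import sys
import tensorflow as tf

t = tf.constant([[1, 2], [3, 4]])
# tf.print formats its arguments into one string scalar and hands it to the op.
tf.print("tensor:", t, output_stream=sys.stderr)
```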
diff --git a/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt
new file mode 100644
index 0000000000..08414b3e68
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ReduceDataset.pbtxt
@@ -0,0 +1,26 @@
+op {
+ visibility: HIDDEN
+ graph_op_name: "ReduceDataset"
+ in_arg {
+ name: "input_dataset"
+ description: <<END
+A variant tensor representing the input dataset.
+END
+ }
+ in_arg {
+ name: "initial_state"
+ description: <<END
+A nested structure of tensors, representing the initial state of the
+transformation.
+END
+ }
+ attr {
+ name: "f"
+ description: <<END
+A function that maps `(old_state, input_element)` to `new_state`. It must take
+two arguments and return a nested structure of tensors. The structure of
+`new_state` must match the structure of `initial_state`.
+END
+ }
+ summary: "Reduces the input dataset to a singleton using a reduce function."
+}
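
A minimal sketch of the reduce semantics, assuming eager execution and that the op is surfaced as `tf.data.Dataset.reduce` (an assumption):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(5)   # 0, 1, 2, 3, 4
# f maps (old_state, input_element) -> new_state; here the state is a scalar sum.
total = ds.reduce(tf.constant(0, dtype=tf.int64),
                  lambda state, element: state + element)
print(total)  # tf.Tensor(10, shape=(), dtype=int64)
```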
diff --git a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
index 8cef243aee..30fd97a0d7 100644
--- a/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_RegexFullMatch.pbtxt
@@ -9,7 +9,7 @@ END
in_arg {
name: "pattern"
description: <<END
-A 1-D string tensor of the regular expression to match the input.
+A scalar string tensor containing the regular expression to match the input.
END
}
out_arg {
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceApplyAdam.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceApplyAdam.pbtxt
index ad0aeac004..2dcd136ae3 100644
--- a/tensorflow/core/api_def/base_api/api_def_ResourceApplyAdam.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceApplyAdam.pbtxt
@@ -76,7 +76,7 @@ END
}
summary: "Update \'*var\' according to the Adam algorithm."
description: <<END
-$$lr_t := \text{learning_rate} * \sqrt{(1 - beta_2^t) / (1 - beta_1^t)}$$
+$$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
$$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
$$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
$$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
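
A plain-Python sketch of a single update step using the corrected `lr_t` formula above (illustration of the equations, not the kernel implementation):

```python
import math

learning_rate, beta_1, beta_2, epsilon = 0.001, 0.9, 0.999, 1e-8
var, m, v = 1.0, 0.0, 0.0
g, t = 0.5, 1   # gradient and step counter

lr_t = learning_rate * math.sqrt(1 - beta_2**t) / (1 - beta_1**t)
m = beta_1 * m + (1 - beta_1) * g
v = beta_2 * v + (1 - beta_2) * g * g
var = var - lr_t * m / (math.sqrt(v) + epsilon)
print(round(var, 4))  # ~0.999
```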
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt
index 1a75e67c0c..e400c7402b 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterNdUpdate.pbtxt
@@ -70,5 +70,7 @@ The resulting update to ref would look like this:
See `tf.scatter_nd` for more details about how to make updates to
slices.
+
+See also `tf.scatter_update` and `tf.batch_scatter_update`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
index 4804908afc..4037dee432 100644
--- a/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_ScatterUpdate.pbtxt
@@ -59,5 +59,7 @@ Requires `updates.shape = indices.shape + ref.shape[1:]` or `updates.shape = []`
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/ScatterUpdate.png" alt>
</div>
+
+See also `tf.batch_scatter_update` and `tf.scatter_nd_update`.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
index 5e2912fcdd..d33a36ce06 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMax.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
@@ -16,8 +16,9 @@ END
}
summary: "Computes the maximum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \max_j(data_j)\\) where `max` is over `j` such
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
index a7d85b3f4e..afdc39da96 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMean.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
@@ -16,8 +16,9 @@ END
}
summary: "Computes the mean along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \frac{\sum_j data_j}{N}\\) where `mean` is
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
index 74fc598218..026b5b3991 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentMin.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
@@ -16,8 +16,9 @@ END
}
summary: "Computes the minimum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \min_j(data_j)\\) where `min` is over `j` such
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
index 4c4363e524..a168eed87f 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentProd.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
@@ -16,8 +16,9 @@ END
}
summary: "Computes the product along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \prod_j data_j\\) where the product is over `j` such
diff --git a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
index 583ab3904f..876b860824 100644
--- a/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SegmentSum.pbtxt
@@ -3,7 +3,7 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
+A 1-D tensor whose size is equal to the size of `data`'s
first dimension. Values should be sorted and can be repeated.
END
}
@@ -16,8 +16,9 @@ END
}
summary: "Computes the sum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
\\(output_i = \sum_j data_j\\) where sum is over `j` such
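
The segment ops above share the same formula; a NumPy sketch of the sorted segment-sum case (semantics only, not the TensorFlow kernel):

```python
import numpy as np

data = np.array([1, 2, 3, 4, 5])
segment_ids = np.array([0, 0, 1, 1, 2])   # 1-D, same size as data's first dimension

num_segments = segment_ids.max() + 1
output = np.zeros(num_segments, dtype=data.dtype)
np.add.at(output, segment_ids, data)      # output_i = sum of data_j with segment_ids[j] == i
print(output)  # [3 7 5]
```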
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt
index 866e04e97b..138a6366c8 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMean.pbtxt
@@ -21,8 +21,9 @@ END
}
summary: "Computes the mean along sparse segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Like `SegmentMean`, but `segment_ids` can have rank less than `data`'s first
dimension, selecting a subset of dimension 0, specified by `indices`.
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
index af4bc75fa0..b8073d88ac 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentMeanWithNumSegments.pbtxt
@@ -30,7 +30,8 @@ END
Like `SparseSegmentMean`, but allows missing ids in `segment_ids`. If an id is
missing, the `output` tensor at that position will be zeroed.
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt
index 194bcea726..945bbdcf62 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtN.pbtxt
@@ -23,7 +23,8 @@ END
description: <<END
N is the size of the segment being reduced.
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
index 8b502928a5..ff328c8a61 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSqrtNWithNumSegments.pbtxt
@@ -32,7 +32,8 @@ N is the size of the segment being reduced.
Like `SparseSegmentSqrtN`, but allows missing ids in `segment_ids`. If an id is
missing, the `output` tensor at that position will be zeroed.
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt
index dfd50bf273..a68e14607f 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSum.pbtxt
@@ -21,8 +21,9 @@ END
}
summary: "Computes the sum along sparse segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Like `SegmentSum`, but `segment_ids` can have rank less than `data`'s first
dimension, selecting a subset of dimension 0, specified by `indices`.
diff --git a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
index 3bc16577ff..aa5c1fc8d0 100644
--- a/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_SparseSegmentSumWithNumSegments.pbtxt
@@ -30,8 +30,9 @@ END
Like `SparseSegmentSum`, but allows missing ids in `segment_ids`. If an id is
missing, the `output` tensor at that position will be zeroed.
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
For example:
diff --git a/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt
new file mode 100644
index 0000000000..6d9d9908ca
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StaticRegexFullMatch.pbtxt
@@ -0,0 +1,29 @@
+op {
+ graph_op_name: "StaticRegexFullMatch"
+ in_arg {
+ name: "input"
+ description: <<END
+A string tensor of the text to be processed.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A bool tensor with the same shape as `input`.
+END
+ }
+ attr {
+ name: "pattern"
+ description: "The regular expression to match the input."
+ }
+ summary: "Check if the input matches the regex pattern."
+ description: <<END
+The input is a string tensor of any shape. The pattern is the
+regular expression to be matched with every element of the input tensor.
+The boolean values (True or False) of the output tensor indicate
+if the input matches the regex pattern provided.
+
+The pattern follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StaticRegexReplace.pbtxt b/tensorflow/core/api_def/base_api/api_def_StaticRegexReplace.pbtxt
new file mode 100644
index 0000000000..e382bcec81
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StaticRegexReplace.pbtxt
@@ -0,0 +1,26 @@
+op {
+ graph_op_name: "StaticRegexReplace"
+ in_arg {
+ name: "input"
+ description: "The text to be processed."
+ }
+ out_arg {
+ name: "output"
+ description: "The text after applying pattern and rewrite."
+ }
+ attr {
+ name: "pattern"
+ description: "The regular expression to match the input."
+ }
+ attr {
+ name: "rewrite"
+    description: "The rewrite to be applied to the matched expression."
+ }
+ attr {
+ name: "replace_global"
+ description: "If True, the replacement is global, otherwise the replacement\nis done only on the first match."
+ }
+ summary: "Replaces the match of pattern in input with rewrite."
+ description: "It follows the re2 syntax (https://github.com/google/re2/wiki/Syntax)"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt b/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt
index 8d6fc04847..9a89a4e8e7 100644
--- a/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_StridedSlice.pbtxt
@@ -32,7 +32,7 @@ END
description: <<END
a bitmask where a bit i being 1 means to ignore the begin
value and instead use the largest interval possible. At runtime
-begin[i] will be replaced with `[0, n-1) if `stride[i] > 0` or
+begin[i] will be replaced with `[0, n-1)` if `stride[i] > 0` or
`[-1, n-1]` if `stride[i] < 0`
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000000..a82dae9e48
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,38 @@
+op {
+ graph_op_name: "StringFormat"
+ in_arg {
+ name: "inputs"
+ description: <<END
+The list of tensors to format into the placeholder string.
+END
+ }
+
+ out_arg {
+ name: "output"
+ description: <<END
+The resulting string scalar.
+END
+ }
+ attr {
+ name: "template"
+ description: <<END
+A string, the template to format tensor summaries into.
+END
+ }
+ attr {
+ name: "placeholder"
+ description: <<END
+A string; at each occurrence of this placeholder in the template, a subsequent tensor summary will be inserted.
+END
+ }
+ attr {
+ name: "summarize"
+ description: <<END
+When formatting the tensor summaries, print the first and last `summarize` entries of each tensor dimension.
+END
+ }
+ summary: "Formats a string template using a list of tensors."
+ description: <<END
+Formats a string template using a list of tensors, pretty-printing tensor summaries.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt
new file mode 100644
index 0000000000..7d2fbcd00b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_StringLength.pbtxt
@@ -0,0 +1,30 @@
+op {
+ graph_op_name: "StringLength"
+ attr {
+ name: "unit"
+ description: <<END
+The unit that is counted to compute string length. One of: `"BYTE"` (for
+the number of bytes in each string) or `"UTF8_CHAR"` (for the number of UTF-8
+encoded Unicode code points in each string). Results are undefined
+if `unit=UTF8_CHAR` and the `input` strings do not contain structurally
+valid UTF-8.
+END
+ }
+ in_arg {
+ name: "input"
+ description: <<END
+The string for which to compute the length.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+Integer tensor that has the same shape as `input`. The output contains the
+element-wise string lengths of `input`.
+END
+ }
+ summary: "String lengths of `input`."
+ description: <<END
+Computes the length of each string given in the input tensor.
+END
+}
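
A short sketch of the `unit` attribute, assuming the op is exposed as `tf.strings.length` with a `unit` keyword (an assumption based on the attr above):

```python
import tensorflow as tf

s = tf.constant(["hello", "école"])            # "école" is 5 code points, 6 bytes in UTF-8
print(tf.strings.length(s, unit="BYTE"))       # [5 6]
print(tf.strings.length(s, unit="UTF8_CHAR"))  # [5 5]
```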
diff --git a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
index 8fc1e5cba3..fe0fcc9508 100644
--- a/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Substr.pbtxt
@@ -18,6 +18,16 @@ END
Scalar defining the number of characters to include in each substring
END
}
+ attr {
+ name: "unit"
+ description: <<END
+The unit that is used to create the substring. One of: `"BYTE"` (for
+defining position and length by bytes) or `"UTF8_CHAR"` (for the UTF-8
+encoded Unicode code points). The default is `"BYTE"`. Results are undefined if
+`unit=UTF8_CHAR` and the `input` strings do not contain structurally valid
+UTF-8.
+END
+ }
out_arg {
name: "output"
description: <<END
@@ -32,8 +42,10 @@ For each string in the input `Tensor`, creates a substring starting at index
If `len` defines a substring that would extend beyond the length of the input
string, then as many characters as possible are used.
-If `pos` is negative or specifies a character index larger than any of the input
-strings, then an `InvalidArgumentError` is thrown.
+A negative `pos` indicates distance within the string backwards from the end.
+
+If `pos` specifies an index which is out of range for any of the input strings,
+then an `InvalidArgumentError` is thrown.
`pos` and `len` must have the same shape, otherwise a `ValueError` is thrown on
Op creation.
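
A sketch of the new negative-`pos` and `unit` behavior, assuming the op is exposed as `tf.strings.substr` (an assumption):

```python
import tensorflow as tf

s = tf.constant(["Hello TensorFlow"])
print(tf.strings.substr(s, pos=0, len=5))                    # [b'Hello']
print(tf.strings.substr(s, pos=-10, len=10))                 # [b'TensorFlow'], counted from the end
print(tf.strings.substr(s, pos=0, len=3, unit="UTF8_CHAR"))  # first 3 code points
```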
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt
new file mode 100644
index 0000000000..3022fccb1e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListGather.pbtxt
@@ -0,0 +1,12 @@
+op {
+ graph_op_name: "TensorListGather"
+ summary: "Creates a Tensor by indexing into the TensorList."
+ description: <<END
+Each row in the produced Tensor corresponds to the element in the TensorList
+specified by the given index (see `tf.gather`).
+
+input_handle: The input tensor list.
+indices: The indices used to index into the list.
+values: The tensor.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt b/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt
new file mode 100644
index 0000000000..35194b353e
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_TensorListScatter.pbtxt
@@ -0,0 +1,14 @@
+op {
+ graph_op_name: "TensorListScatter"
+ summary: "Creates a TensorList by indexing into a Tensor."
+ description: <<END
+Each member of the TensorList corresponds to one row of the input tensor,
+specified by the given index (see `tf.gather`).
+
+tensor: The input tensor.
+indices: The indices used to index into the list.
+element_shape: The shape of the elements in the list (can be less specified than
+ the shape of the tensor).
+output_handle: The TensorList.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt
new file mode 100644
index 0000000000..7898fe8d6b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_UnicodeScript.pbtxt
@@ -0,0 +1,28 @@
+op {
+ graph_op_name: "UnicodeScript"
+ endpoint {
+ name: "UnicodeScript"
+ }
+ in_arg {
+ name: "input"
+ description: <<END
+A Tensor of int32 Unicode code points.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A Tensor of int32 script codes corresponding to each input code point.
+END
+ }
+ summary: <<END
+Determine the script codes of a given tensor of Unicode integer code points.
+END
+ description: <<END
+This operation converts Unicode code points to script codes corresponding to
+each code point. Script codes correspond to International Components for
+Unicode (ICU) UScriptCode values. See http://icu-project.org/apiref/icu4c/uscript_8h.html.
+Returns -1 (USCRIPT_INVALID_CODE) for invalid codepoints. Output shape will
+match input shape.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
index 4ca6780c95..7a60e4387a 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMax.pbtxt
@@ -3,33 +3,36 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
-END
+A tensor whose shape is a prefix of `data.shape`.
+END
}
out_arg {
name: "output"
description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has same shape as data, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension which has size
+`num_segments`.
END
}
summary: "Computes the maximum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
Instead of computing the sum over segments, it computes the maximum such that:
-\\(output_i = \max_j data_j\\) where max is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \max_{j...} data[j...]\\) where max is over tuples `j...` such
+that `segment_ids[j...] == i`.
If the maximum is empty for a given segment ID `i`, it outputs the smallest
possible value for the specific numeric type,
`output[i] = numeric_limits<T>::lowest()`.
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
+
<div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
<img style="width:100%" src="https://www.tensorflow.org/images/UnsortedSegmentMax.png" alt>
</div>
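
A NumPy sketch of the revised semantics: `segment_ids` may be any prefix-shaped tensor and negative ids are dropped (illustrated here for the 1-D case; not the TensorFlow kernel):

```python
import numpy as np

data = np.array([5, 1, 7, 2, 3])
segment_ids = np.array([0, 0, 1, -1, 1])   # the element with id -1 is ignored
num_segments = 2

lowest = np.iinfo(data.dtype).min          # numeric_limits<T>::lowest() for empty segments
output = np.full(num_segments, lowest)
for sid, value in zip(segment_ids, data):
    if sid >= 0:
        output[sid] = max(output[sid], value)
print(output)  # [5 7]
```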
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
index 55ea69b5dd..7e139ddf4d 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentMin.pbtxt
@@ -3,31 +3,35 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
+A tensor whose shape is a prefix of `data.shape`.
END
}
out_arg {
name: "output"
description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has same shape as data, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension which has size
+`num_segments`.
END
}
summary: "Computes the minimum along segments of a tensor."
description: <<END
-Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
Instead of computing the sum over segments, it computes the minimum such that:
-\\(output_i = \min_j data_j\\) where min is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \min_{j...} data[j...]\\) where min is over tuples `j...` such
+that `segment_ids[j...] == i`.
If the minimum is empty for a given segment ID `i`, it outputs the largest
possible value for the specific numeric type,
`output[i] = numeric_limits<T>::max()`.
+
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
index 577ff53d60..9c8ea3b620 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentProd.pbtxt
@@ -3,30 +3,34 @@ op {
in_arg {
name: "segment_ids"
description: <<END
-A 1-D tensor whose rank is equal to the rank of `data`'s
-first dimension.
+A tensor whose shape is a prefix of `data.shape`.
END
}
out_arg {
name: "output"
description: <<END
-Has same shape as data, except for dimension 0 which
-has size `num_segments`.
+Has same shape as data, except for the first `segment_ids.rank`
+dimensions, which are replaced with a single dimension which has size
+`num_segments`.
END
}
summary: "Computes the product along segments of a tensor."
description: <<END
-Read @{$math_ops#segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#segmentation)
+for an explanation of segments.
This operator is similar to the unsorted segment sum operator found
[(here)](../../../api_docs/python/math_ops.md#UnsortedSegmentSum).
Instead of computing the sum over segments, it computes the product of all
entries belonging to a segment such that:
-\\(output_i = \prod_j data_j\\) where the product is over `j` such
-that `segment_ids[j] == i`.
+\\(output_i = \prod_{j...} data[j...]\\) where the product is over tuples
+`j...` such that `segment_ids[j...] == i`.
If there is no entry for a given segment ID `i`, it outputs 1.
+
+If the given segment ID `i` is negative, then the corresponding value is
+dropped, and will not be included in the result.
END
}
diff --git a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
index 9aeabd030d..7e5d9265c2 100644
--- a/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_UnsortedSegmentSum.pbtxt
@@ -16,11 +16,12 @@ END
}
summary: "Computes the sum along segments of a tensor."
description: <<END
-Read @{$math_ops#Segmentation$the section on segmentation} for an explanation of
-segments.
+Read
+[the section on segmentation](https://tensorflow.org/api_guides/python/math_ops#Segmentation)
+for an explanation of segments.
Computes a tensor such that
-\\(output[i] = sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
+\\(output[i] = \sum_{j...} data[j...]\\) where the sum is over tuples `j...` such
that `segment_ids[j...] == i`. Unlike `SegmentSum`, `segment_ids`
need not be sorted and need not cover all values in the full
range of valid values.
diff --git a/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt b/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt
new file mode 100644
index 0000000000..0630f6e82a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_UpperBound.pbtxt
@@ -0,0 +1,45 @@
+op {
+ graph_op_name: "UpperBound"
+ visibility: HIDDEN
+ in_arg {
+ name: "sorted_inputs"
+ description: <<END
+2-D Tensor where each row is ordered.
+END
+ }
+ in_arg {
+ name: "values"
+ description: <<END
+2-D Tensor with the same number of rows as `sorted_inputs`. Contains the
+values that will be searched for in `sorted_inputs`.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A `Tensor` with the same shape as `values`. It contains the last scalar index
+into the last dimension where values can be inserted without changing the
+ordered property.
+END
+ }
+ summary: "Applies upper_bound(sorted_search_values, values) along each row."
+ description: <<END
+Each set of rows with the same index in (sorted_inputs, values) is treated
+independently. The resulting row is the equivalent of calling
+`np.searchsorted(sorted_inputs, values, side='right')`.
+
+The result is not a global index into the entire
+`Tensor`, but rather just the index in the last dimension.
+
+A 2-D example:
+ sorted_sequence = [[0, 3, 9, 9, 10],
+ [1, 2, 3, 4, 5]]
+ values = [[2, 4, 9],
+ [0, 2, 6]]
+
+ result = UpperBound(sorted_sequence, values)
+
+ result == [[1, 2, 4],
+ [0, 2, 5]]
+END
+}
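
The documented behaviour matches a row-wise np.searchsorted with side='right'; the following NumPy sketch reuses the numbers from the example above and is illustrative only:

    import numpy as np

    sorted_sequence = np.array([[0, 3, 9, 9, 10],
                                [1, 2, 3, 4, 5]])
    values = np.array([[2, 4, 9],
                       [0, 2, 6]])
    # Each (row, values-row) pair is handled independently, as described above.
    result = np.stack([np.searchsorted(row, vals, side='right')
                       for row, vals in zip(sorted_sequence, values)])
    print(result)  # -> [[1 2 4]
                   #     [0 2 5]]
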
diff --git a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
index 1bc3660479..01387b7527 100644
--- a/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_WindowDataset.pbtxt
@@ -2,10 +2,31 @@ op {
visibility: HIDDEN
graph_op_name: "WindowDataset"
in_arg {
- name: "window_size"
+ name: "size"
description: <<END
A scalar representing the number of elements to accumulate in a window.
END
}
+ in_arg {
+ name: "shift"
+ description: <<END
+A scalar representing the number of input elements by which the sliding window
+moves forward in each iteration. It must be positive.
+END
+ }
+ in_arg {
+ name: "stride"
+ description: <<END
+A scalar representing the stride between consecutive input elements within a
+window. It must be positive.
+END
+ }
+ in_arg {
+ name: "drop_remainder"
+ description: <<END
+A scalar representing whether a window should be dropped if its size is
+smaller than `size`.
+END
+ }
summary: "A dataset that creates window datasets from the input dataset."
}
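
For reference, the four inputs documented above appear to line up with the arguments of the Python-level tf.data.Dataset.window transformation; a minimal sketch, assuming a TF 2.x eager environment and values chosen only for illustration:

    import tensorflow as tf

    ds = tf.data.Dataset.range(7).window(size=3, shift=2, stride=1,
                                         drop_remainder=True)
    for window in ds:                       # each element is itself a dataset
        print(list(window.as_numpy_iterator()))
    # -> [0, 1, 2], [2, 3, 4], [4, 5, 6]; the trailing [6] window is dropped
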
diff --git a/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt b/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt
new file mode 100644
index 0000000000..ca107abc6b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Xdivy.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "Xdivy"
+ summary: "Returns 0 if x == 0, and x / y otherwise, elementwise."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt b/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt
new file mode 100644
index 0000000000..da625f7836
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_Xlogy.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "Xlogy"
+ summary: "Returns 0 if x == 0, and x * log(y) otherwise, elementwise."
+}
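
Both new ops above follow the convention that the x == 0 case wins even where the y-dependent branch would be inf or NaN. A NumPy sketch of that convention (not the TF kernels; sample values made up):

    import numpy as np

    def xdivy(x, y):
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.where(x == 0.0, 0.0, x / y)

    def xlogy(x, y):
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.where(x == 0.0, 0.0, x * np.log(y))

    print(xdivy(np.array([0.0, 2.0]), np.array([0.0, 4.0])))   # -> [0.  0.5]
    print(xlogy(np.array([0.0, 2.0]), np.array([0.0, np.e])))  # -> [0. 2.]
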
diff --git a/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt b/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt
new file mode 100644
index 0000000000..3d937c745c
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_defTensorListPushBackBatch.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListPushBackBatch"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
index 9552fc92e3..e395e333bf 100644
--- a/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_BatchToSpaceND.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "BatchToSpaceND"
endpoint {
- name: "manip.batch_to_space_nd"
+ name: "batch_to_space_nd"
}
endpoint {
- name: "batch_to_space_nd"
+ name: "manip.batch_to_space_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt b/tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt
new file mode 100644
index 0000000000..1bf3fba3c6
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_DivNoNan.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "DivNoNan"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt b/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt
new file mode 100644
index 0000000000..44f25b5d93
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_EmptyTensorList.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "EmptyTensorList"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt
new file mode 100644
index 0000000000..4414d973ac
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_EnsureShape.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "EnsureShape"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt b/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt
deleted file mode 100644
index 7f721f4fb7..0000000000
--- a/tensorflow/core/api_def/python_api/api_def_FeatureStatsDataset.pbtxt
+++ /dev/null
@@ -1,4 +0,0 @@
-op {
- graph_op_name: "FeatureStatsDataset"
- visibility: HIDDEN
-}
diff --git a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
index 71257c8855..598f23bde3 100644
--- a/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_GatherNd.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "GatherNd"
endpoint {
- name: "manip.gather_nd"
+ name: "gather_nd"
}
endpoint {
- name: "gather_nd"
+ name: "manip.gather_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_ParseExampleDataset.pbtxt b/tensorflow/core/api_def/python_api/api_def_ParseExampleDataset.pbtxt
new file mode 100644
index 0000000000..45826b6fdc
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ParseExampleDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ParseExampleDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ParseSequenceExample.pbtxt b/tensorflow/core/api_def/python_api/api_def_ParseSequenceExample.pbtxt
new file mode 100644
index 0000000000..4a7e75ba0e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ParseSequenceExample.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ParseSequenceExample"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
new file mode 100644
index 0000000000..e22d980424
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_PrintV2.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "PrintV2"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
index b17806b338..5020844204 100644
--- a/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
@@ -1,10 +1,4 @@
op {
graph_op_name: "RegexReplace"
- endpoint {
- name: "strings.regex_replace"
- }
- endpoint {
- name: "regex_replace"
- deprecated: true
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
index c469665b66..b3d596de7a 100644
--- a/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Reshape.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "Reshape"
endpoint {
- name: "manip.reshape"
+ name: "reshape"
}
endpoint {
- name: "reshape"
+ name: "manip.reshape"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
index 77f595927b..51478b7c34 100644
--- a/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ReverseV2.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "ReverseV2"
endpoint {
- name: "manip.reverse"
+ name: "reverse"
}
endpoint {
- name: "reverse"
+ name: "manip.reverse"
deprecated: true
}
endpoint {
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
index a65a19b542..85888da45a 100644
--- a/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterNd.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "ScatterNd"
endpoint {
- name: "manip.scatter_nd"
+ name: "scatter_nd"
}
endpoint {
- name: "scatter_nd"
+ name: "manip.scatter_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterNdSub.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterNdSub.pbtxt
new file mode 100644
index 0000000000..c1edef8c9d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterNdSub.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ScatterNdSub"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_ScatterSub.pbtxt b/tensorflow/core/api_def/python_api/api_def_ScatterSub.pbtxt
new file mode 100644
index 0000000000..f1a4cccbc3
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_ScatterSub.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ScatterSub"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
index af323a6cf3..146b97f444 100644
--- a/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_SpaceToBatchND.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "SpaceToBatchND"
endpoint {
- name: "manip.space_to_batch_nd"
+ name: "space_to_batch_nd"
}
endpoint {
- name: "space_to_batch_nd"
+ name: "manip.space_to_batch_nd"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt
new file mode 100644
index 0000000000..d3c70190dd
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessMultinomial"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt
new file mode 100644
index 0000000000..e294325fb8
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessRandomNormal"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt
new file mode 100644
index 0000000000..95d414c54a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessRandomUniform"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt
new file mode 100644
index 0000000000..c72bdda94a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessTruncatedNormal"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
new file mode 100644
index 0000000000..8f0b1db45d
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringFormat.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StringFormat"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt
new file mode 100644
index 0000000000..df012414e3
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StringLength.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StringLength"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
index 4778d7927c..4fb9ee56e9 100644
--- a/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Substr.pbtxt
@@ -1,10 +1,4 @@
op {
graph_op_name: "Substr"
- endpoint {
- name: "strings.substr"
- }
- endpoint {
- name: "substr"
- deprecated: true
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt
new file mode 100644
index 0000000000..45fc55e71e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListConcatLists.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListConcatLists"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt
new file mode 100644
index 0000000000..e1ad713e7f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListElementShape.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListElementShape"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt
new file mode 100644
index 0000000000..4aaefba3c5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListFromTensor.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListFromTensor"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt
new file mode 100644
index 0000000000..aaf607d70e
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListGather.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListGather"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt
new file mode 100644
index 0000000000..3bb5f39cbc
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListGetItem.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListGetItem"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt
new file mode 100644
index 0000000000..a04c20bb8a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListLength.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListLength"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt
new file mode 100644
index 0000000000..9287162f22
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListPopBack.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListPopBack"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt
new file mode 100644
index 0000000000..da2bc11721
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListPushBack.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListPushBack"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt
new file mode 100644
index 0000000000..77e63747d5
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListReserve.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListReserve"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt
new file mode 100644
index 0000000000..0015189d7f
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListScatter.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListScatter"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt
new file mode 100644
index 0000000000..4999ee7ad9
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListSetItem.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListSetItem"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt b/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt
new file mode 100644
index 0000000000..2dc7b2784b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_TensorListStack.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "TensorListStack"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
index c34061c941..1d8695f1fd 100644
--- a/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_Tile.pbtxt
@@ -1,10 +1,10 @@
op {
graph_op_name: "Tile"
endpoint {
- name: "manip.tile"
+ name: "tile"
}
endpoint {
- name: "tile"
+ name: "manip.tile"
deprecated: true
}
}
diff --git a/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt b/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt
new file mode 100644
index 0000000000..a884a46143
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_UnicodeScript.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "UnicodeScript"
+ endpoint {
+ name: "strings.unicode_script"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt b/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt
new file mode 100644
index 0000000000..984442ba2b
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Xdivy.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "Xdivy"
+ endpoint {
+ name: "math.xdivy"
+ }
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt b/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt
new file mode 100644
index 0000000000..b4a5299256
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_Xlogy.pbtxt
@@ -0,0 +1,6 @@
+op {
+ graph_op_name: "Xlogy"
+ endpoint {
+ name: "math.xlogy"
+ }
+}
diff --git a/tensorflow/core/common_runtime/base_collective_executor.cc b/tensorflow/core/common_runtime/base_collective_executor.cc
index 637b43c844..5b01f7fa03 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.cc
+++ b/tensorflow/core/common_runtime/base_collective_executor.cc
@@ -14,13 +14,28 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/base_collective_executor.h"
-#include "tensorflow/core/common_runtime/broadcaster.h"
+#include <algorithm>
+#include <functional>
+#include <utility>
+
#include "tensorflow/core/common_runtime/copy_tensor.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/common_runtime/ring_reducer.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
#define VALUE_IN_DEBUG_STRING false
@@ -83,7 +98,7 @@ class CollectiveAdapterImpl : public CollectiveAdapter {
// If necessary, flatten output.
void Flatten() {
- if (old_shape_.dims() > 1) {
+ if (old_shape_.dims() != 1) {
TensorShape new_shape = TensorShape({old_shape_.num_elements()});
DMAHelper::UnsafeSetShape(&output_, new_shape);
}
@@ -211,104 +226,67 @@ void BaseCollectiveExecutor::ExecuteAsync(OpKernelContext* ctx,
};
Tensor* output = ctx->mutable_output(0);
- string error;
- switch (col_params.instance.type) {
- case REDUCTION_COLLECTIVE: {
- // TODO(tucker): support other reduction algorithms,
- // e.g. tree-reduce, hybrid tree/ring, delegate-to-NCCL, etc.
- const Tensor* input = &ctx->input(0);
- RingReducer* reducer =
- CreateReducer(ctx, CtxParams(ctx), col_params, exec_key, step_id_,
- input, output, &error);
- if (!reducer) {
- done_safe(errors::Internal(error));
- return;
- }
- // Run in an I/O thread, so as not to starve the executor threads.
- // TODO(tucker): Instead of forking every per-device Collective
- // Op off into its own thread, consider queuing them on a
- // fixed-size thread-pool dedicated to running CollectiveOps.
- SchedClosure([reducer, done_safe]() {
- reducer->Run([reducer, done_safe](const Status& s) {
- done_safe(s);
- delete reducer;
- });
- });
- } break;
-
- case BROADCAST_COLLECTIVE: {
- Broadcaster* broadcaster = CreateBroadcaster(
- ctx, CtxParams(ctx), col_params, exec_key, step_id_, output, &error);
- if (!broadcaster) {
- done_safe(errors::Internal(error));
- return;
- }
- // Run in an I/O thread, so as not to starve the executor threads.
- SchedClosure([broadcaster, done_safe]() {
- broadcaster->Run([broadcaster, done_safe](const Status& s) {
- done_safe(s);
- delete broadcaster;
- });
- });
- } break;
-
- default:
- done_safe(errors::Internal("Unimplemented CollectiveType ",
- col_params.instance.type));
+ const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE ||
+ (col_params.instance.type == BROADCAST_COLLECTIVE &&
+ col_params.is_source))
+ ? &ctx->input(0)
+ : nullptr;
+ CollectiveImplementationInterface* col_impl = nullptr;
+ Status status = CreateCollective(col_params, &col_impl);
+ if (!status.ok()) {
+ done_safe(status);
+ DCHECK_EQ(nullptr, col_impl);
+ return;
}
-}
-
-RingReducer* BaseCollectiveExecutor::CreateReducer(
- OpKernelContext* ctx, OpKernelContext::Params* params,
- const CollectiveParams& col_params, const string& exec_key, int64 step_id,
- const Tensor* input, Tensor* output, string* error) {
- switch (col_params.instance.data_type) {
- case DT_INT32:
- if (col_params.group.device_type == DEVICE_GPU) {
- *error =
- "Collective Reduce does not support datatype DT_INT32 on "
- "DEVICE_GPU";
- return nullptr;
- }
- TF_FALLTHROUGH_INTENDED;
- case DT_FLOAT:
- case DT_DOUBLE:
- case DT_INT64:
- return new RingReducer(this, dev_mgr_, ctx, params, col_params, exec_key,
- step_id, input, output);
- break;
- default:
- *error = strings::StrCat("Collective Reduce does not support datatype ",
- col_params.instance.data_type);
- return nullptr;
+ CollectiveContext* col_ctx =
+ new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params,
+ exec_key, step_id_, input, output);
+ status = col_impl->InitializeCollectiveContext(col_ctx);
+ if (!status.ok()) {
+ done_safe(status);
+ delete col_ctx;
+ delete col_impl;
+ return;
}
+ // Run in an I/O thread, so as not to starve the executor threads.
+ // TODO(b/80529858): Instead of forking every per-device Collective
+ // Op off into its own thread, consider queuing them on a
+ // fixed-size thread-pool dedicated to running CollectiveOps.
+ SchedClosure([col_impl, col_ctx, done_safe]() {
+ col_impl->Run([col_impl, col_ctx, done_safe](const Status& s) {
+ done_safe(s);
+ delete col_ctx;
+ delete col_impl;
+ });
+ });
}
-Broadcaster* BaseCollectiveExecutor::CreateBroadcaster(
- OpKernelContext* ctx, OpKernelContext::Params* params,
- const CollectiveParams& col_params, const string& exec_key, int64 step_id,
- Tensor* output, string* error) {
+Status BaseCollectiveExecutor::CreateCollective(
+ const CollectiveParams& col_params,
+ CollectiveImplementationInterface** col_impl) {
+ *col_impl = nullptr;
+ Status status;
switch (col_params.instance.data_type) {
case DT_INT32:
if (col_params.group.device_type == DEVICE_GPU) {
- *error =
- "Collective Broadcast does not support datatype DT_INT32 on "
- "DEVICE_GPU";
- return nullptr;
+ status = errors::Internal(
+ "CollectiveImplementation does not support datatype DT_INT32 on "
+ "DEVICE_GPU");
}
TF_FALLTHROUGH_INTENDED;
case DT_FLOAT:
case DT_DOUBLE:
case DT_INT64: {
- return new Broadcaster(this, dev_mgr_, ctx, params, col_params, exec_key,
- step_id, output);
- } break;
+ status = CollectiveRegistry::Lookup(
+ col_params.instance.impl_details.collective_name, col_impl);
+ break;
+ }
default:
- *error =
- strings::StrCat("Collective Broadcast does not support datatype ",
- DataTypeString(col_params.instance.data_type));
- return nullptr;
+ status = errors::Internal(
+ "CollectiveImplementation does not support datatype ",
+ col_params.instance.data_type);
}
+ return status;
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/base_collective_executor.h b/tensorflow/core/common_runtime/base_collective_executor.h
index 3af9286264..360ce4db7b 100644
--- a/tensorflow/core/common_runtime/base_collective_executor.h
+++ b/tensorflow/core/common_runtime/base_collective_executor.h
@@ -15,15 +15,17 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_BASE_COLLECTIVE_EXECUTOR_H_
+#include <memory>
#include <string>
+
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
namespace tensorflow {
-class Broadcaster;
+class CollectiveImplementation;
class DeviceMgr;
-class RingReducer;
+class Device;
// Helper interface that aliases regular subfields of a Tensor as separate
// Tensors for in-place update.
@@ -133,18 +135,8 @@ class BaseCollectiveExecutor : public CollectiveExecutor {
std::unique_ptr<PerStepCollectiveRemoteAccess> remote_access_;
private:
- RingReducer* CreateReducer(OpKernelContext* ctx,
- OpKernelContext::Params* params,
- const CollectiveParams& col_params,
- const string& exec_key, int64 step_id,
- const Tensor* input, Tensor* output,
- string* error);
-
- Broadcaster* CreateBroadcaster(OpKernelContext* ctx,
- OpKernelContext::Params* params,
- const CollectiveParams& col_params,
- const string& exec_key, int64 step_id,
- Tensor* output, string* error);
+ Status CreateCollective(const CollectiveParams& col_params,
+ CollectiveImplementationInterface** col_impl);
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/bfc_allocator.cc b/tensorflow/core/common_runtime/bfc_allocator.cc
index 3bf0532491..3843ea9e60 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/bfc_allocator.cc
@@ -31,7 +31,7 @@ namespace tensorflow {
BFCAllocator::BFCAllocator(SubAllocator* sub_allocator, size_t total_memory,
bool allow_growth, const string& name)
- : suballocator_(sub_allocator),
+ : sub_allocator_(sub_allocator),
name_(name),
free_chunks_list_(kInvalidChunkHandle),
next_allocation_id_(1) {
@@ -72,7 +72,7 @@ BFCAllocator::~BFCAllocator() {
VLOG(2) << "Number of regions allocated: "
<< region_manager_.regions().size();
for (const auto& region : region_manager_.regions()) {
- suballocator_->Free(region.ptr(), region.memory_size());
+ sub_allocator_->Free(region.ptr(), region.memory_size());
}
for (BinNum b = 0; b < kNumBins; b++) {
@@ -108,7 +108,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
// Try allocating.
size_t bytes = std::min(curr_region_allocation_bytes_, available_bytes);
- void* mem_addr = suballocator_->Alloc(alignment, bytes);
+ void* mem_addr = sub_allocator_->Alloc(alignment, bytes);
if (mem_addr == nullptr && !started_backpedal_) {
// Only backpedal once.
started_backpedal_ = true;
@@ -119,7 +119,7 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
while (mem_addr == nullptr) {
bytes = RoundedBytes(bytes * kBackpedalFactor);
if (bytes < rounded_bytes) break;
- mem_addr = suballocator_->Alloc(alignment, bytes);
+ mem_addr = sub_allocator_->Alloc(alignment, bytes);
}
}
@@ -158,10 +158,6 @@ bool BFCAllocator::Extend(size_t alignment, size_t rounded_bytes) {
// Insert the chunk into the right bin.
InsertFreeChunkIntoBin(h);
- // Invoke visitors on newly allocated region.
- for (const auto& visitor : region_visitors_) {
- visitor(mem_addr, bytes);
- }
return true;
}
@@ -490,15 +486,6 @@ void BFCAllocator::FreeAndMaybeCoalesce(BFCAllocator::ChunkHandle h) {
InsertFreeChunkIntoBin(coalesced_chunk);
}
-void BFCAllocator::AddAllocVisitor(Visitor visitor) {
- VLOG(1) << "AddVisitor";
- mutex_lock l(lock_);
- region_visitors_.push_back(visitor);
- for (const auto& region : region_manager_.regions()) {
- visitor(region.ptr(), region.memory_size());
- }
-}
-
bool BFCAllocator::TracksAllocationSizes() { return true; }
size_t BFCAllocator::RequestedSize(const void* ptr) {
@@ -596,7 +583,7 @@ string BFCAllocator::RenderOccupancy() {
region_offset += region.memory_size();
}
- return std::string(rendered, resolution);
+ return string(rendered, resolution);
}
void BFCAllocator::DumpMemoryLog(size_t num_bytes) {
diff --git a/tensorflow/core/common_runtime/bfc_allocator.h b/tensorflow/core/common_runtime/bfc_allocator.h
index 580e61e2ea..2d74bf2b28 100644
--- a/tensorflow/core/common_runtime/bfc_allocator.h
+++ b/tensorflow/core/common_runtime/bfc_allocator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_BFC_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_BFC_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BFC_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_BFC_ALLOCATOR_H_
#include <array>
#include <memory>
@@ -23,7 +23,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/allocator_retry.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
@@ -42,7 +42,7 @@ namespace tensorflow {
// coalescing. One assumption we make is that the process using this
// allocator owns pretty much all of the memory, and that nearly
// all requests to allocate memory go through this interface.
-class BFCAllocator : public VisitableAllocator {
+class BFCAllocator : public Allocator {
public:
// Takes ownership of sub_allocator.
BFCAllocator(SubAllocator* sub_allocator, size_t total_memory,
@@ -55,11 +55,6 @@ class BFCAllocator : public VisitableAllocator {
const AllocationAttributes& allocation_attr) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
-
- // Does nothing, because memory is never freed.
- void AddFreeVisitor(Visitor visitor) override {}
-
bool TracksAllocationSizes() override;
size_t RequestedSize(const void* ptr) override;
@@ -309,7 +304,7 @@ class BFCAllocator : public VisitableAllocator {
};
// Returns 'bytes' rounded up to the next highest kMinAllocationSize.
- size_t RoundedBytes(size_t bytes);
+ static size_t RoundedBytes(size_t bytes);
// Try to add a new memory region that can satisfy an allocation of
// 'rounded_bytes' bytes. Returns true on success and false on
@@ -423,7 +418,7 @@ class BFCAllocator : public VisitableAllocator {
// of the available memory.
bool started_backpedal_ = false;
- std::unique_ptr<SubAllocator> suballocator_;
+ std::unique_ptr<SubAllocator> sub_allocator_;
string name_;
// Structures mutable after construction
@@ -435,9 +430,6 @@ class BFCAllocator : public VisitableAllocator {
// Pointer to head of linked list of free Chunks
ChunkHandle free_chunks_list_ GUARDED_BY(lock_);
- // Called once on each region, ASAP.
- std::vector<Visitor> region_visitors_ GUARDED_BY(lock_);
-
// Counter containing the next unique identifier to assign to a
// newly-created chunk.
int64 next_allocation_id_ GUARDED_BY(lock_);
@@ -451,4 +443,4 @@ class BFCAllocator : public VisitableAllocator {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_BFC_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BFC_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/broadcaster.cc b/tensorflow/core/common_runtime/broadcaster.cc
deleted file mode 100644
index e1c6b21939..0000000000
--- a/tensorflow/core/common_runtime/broadcaster.cc
+++ /dev/null
@@ -1,300 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/common_runtime/broadcaster.h"
-
-#include "tensorflow/core/common_runtime/collective_rma_local.h"
-#include "tensorflow/core/common_runtime/device_mgr.h"
-#include "tensorflow/core/common_runtime/dma_helper.h"
-#include "tensorflow/core/lib/core/notification.h"
-#include "tensorflow/core/platform/env.h"
-
-// Set true for greater intelligibility of debug mode log messages.
-#define READABLE_KEYS false
-
-namespace tensorflow {
-
-namespace {
-// Key to be used for BufRendezvous by Broadcaster.
-string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank,
- int dst_rank) {
- if (READABLE_KEYS) {
- return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv,
- "):src(", src_rank, "):dst(", dst_rank, ")");
- } else {
- // TODO(tucker): Try a denser format, e.g. a 64 or 128 bit hash.
- return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank);
- }
-}
-} // namespace
-
-Broadcaster::Broadcaster(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
- OpKernelContext* ctx, OpKernelContext::Params* params,
- const CollectiveParams& col_params,
- const string& exec_key, int64 step_id, Tensor* output)
- : col_exec_(col_exec),
- dev_mgr_(dev_mgr),
- ctx_(ctx),
- col_params_(col_params),
- exec_key_(exec_key),
- rank_(col_params.subdiv_rank[0]),
- is_source_(col_params.is_source),
- output_(output),
- done_(nullptr),
- device_(nullptr) {}
-
-void Broadcaster::Run(StatusCallback done) {
- // The optimal data transfer choreography is going to very platform dependent.
- // That will be addressed by later improvements here or by platform-specific
- // overrides of collective broadcast. The initial version is simply
- // a binary tree that completely ignores DeviceLocality.
- done_ = std::move(done);
-
- // Get the device for which we're executing and look up its locality.
- status_ = dev_mgr_->LookupDevice(
- col_params_.instance.device_names[col_params_.default_rank], &device_);
- if (!status_.ok()) {
- done_(status_);
- return;
- }
- CHECK(device_);
- device_locality_ = device_->attributes().locality();
-
- RunTree();
-}
-
-// Binary tree parent/child relations are trivial to calculate, i.e.
-// device at rank r is the parent of 2r+1 and 2r+2. The one exception
-// is if the source is not rank 0. We treat that case as though the
-// source is appended to the front of the rank ordering as well as
-// continuing to occupy its current position. Hence we calculate as
-// though each device's rank is actually r+1, then subtract 1 again to
-// get the descendent ranks. If the source is not rank 0 then its
-// descendants include both {0,1} and the descendents of its current
-// position. Where a non-0-rank source is a descendent of another
-// device, no send to it is necessary.
-
-/* static*/
-int Broadcaster::TreeRecvFrom(const CollectiveParams& cp, int subdiv) {
- DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
- int my_rank = cp.subdiv_rank[subdiv];
- if (-1 == my_rank) return -1;
-
- const auto& impl = cp.instance.impl_details;
- DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
- int source_rank = impl.subdiv_source_rank[subdiv];
- if (my_rank == source_rank) return -1;
- if (source_rank == 0) {
- return (my_rank - 1) / 2;
- } else {
- int predecessor_rank = (my_rank / 2) - 1;
- return (predecessor_rank < 0) ? source_rank : predecessor_rank;
- }
-}
-
-/* static */
-void Broadcaster::TreeSendTo(const CollectiveParams& cp, int subdiv,
- std::vector<int>* targets) {
- DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
- int my_rank = cp.subdiv_rank[subdiv];
- if (-1 == my_rank) return;
-
- const auto& impl = cp.instance.impl_details;
- DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
- int source_rank = impl.subdiv_source_rank[subdiv];
-
- int group_size = 0;
- for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) {
- if (impl.subdiv_permutations[subdiv][i] >= 0) {
- group_size++;
- }
- }
-
- targets->clear();
- int successor_rank = 0;
- if (source_rank == 0) {
- successor_rank = (2 * my_rank) + 1;
- } else {
- successor_rank = (2 * (my_rank + 1));
- }
- DCHECK_NE(successor_rank, my_rank);
- if (cp.is_source && source_rank != 0) {
- // The source sends to rank 0,1 in addition to its positional
- // descendants.
- if (group_size > 1) {
- targets->push_back(0);
- }
- if (group_size > 2 && source_rank != 1) {
- targets->push_back(1);
- }
- }
- for (int i = 0; i < 2; ++i) {
- if (successor_rank < group_size && successor_rank != source_rank) {
- targets->push_back(successor_rank);
- }
- ++successor_rank;
- }
-}
-
-// Executes a hierarchical tree broadcast.
-// Each subdiv is a broadcast between a subset of the devices.
-// If there is only one task, there is one subdiv comprising a broadcast between
-// all devices belonging to the task.
-// If there are n tasks, n>1, then there are n+1 subdivs. In the first (global)
-// subdiv, one device from each task participates in a binary tree broadcast.
-// Each task receives a copy of the tensor on one device via this broadcast.
-// Subsequent subdivs correspond to intra-task broadcasts. Subdiv i+1
-// corresponds to broadcast between all devices on task i. Thus, each task
-// participates in at most 2 subdivs.
-void Broadcaster::RunTree() {
- int num_subdivs = static_cast<int>(col_params_.subdiv_rank.size());
- // TODO(ayushd): this is easily improved when a node participates in both
- // first and second subdivision. It would first send to its descendents in
- // the first subdiv, then wait until all pending ops are finished before
- // sending to descendents in second subdiv. A better implementation would
- // collapse the two send blocks.
- for (int si = 0; si < num_subdivs; si++) {
- int my_rank = col_params_.subdiv_rank[si];
- // If rank is -1, this device does not participate in this subdiv.
- if (-1 == my_rank) continue;
- int source_rank = col_params_.instance.impl_details.subdiv_source_rank[si];
- if (VLOG_IS_ON(1)) {
- string subdiv_buf;
- for (int r : col_params_.instance.impl_details.subdiv_permutations[si]) {
- strings::StrAppend(&subdiv_buf, r, ",");
- }
- VLOG(1) << "Running Broadcast tree device=" << device_->name()
- << " subdiv=" << si << " perm=" << subdiv_buf
- << " my_rank=" << my_rank << " source_rank=" << source_rank;
- }
-
- mutex mu; // also guards status_ while callbacks are pending
- int pending_count = 0; // GUARDED_BY(mu)
- condition_variable all_done;
-
- if (my_rank >= 0 && my_rank != source_rank) {
- // Begin by receiving the value.
- int recv_from_rank = TreeRecvFrom(col_params_, si);
- Notification note;
- DispatchRecv(si, recv_from_rank, my_rank, output_,
- [this, &mu, &note](const Status& s) {
- mutex_lock l(mu);
- status_.Update(s);
- note.Notify();
- });
- note.WaitForNotification();
- }
-
- // Then forward value to all descendent devices.
- if (my_rank >= 0 && status_.ok()) {
- std::vector<int> send_to_ranks;
- TreeSendTo(col_params_, si, &send_to_ranks);
- for (int i = 0; i < send_to_ranks.size(); ++i) {
- int target_rank = send_to_ranks[i];
- {
- mutex_lock l(mu);
- ++pending_count;
- }
- DispatchSend(si, target_rank, my_rank,
- (is_source_ ? &ctx_->input(0) : output_),
- [this, &mu, &pending_count, &all_done](const Status& s) {
- mutex_lock l(mu);
- status_.Update(s);
- --pending_count;
- if (pending_count == 0) {
- all_done.notify_all();
- }
- });
- }
- }
-
- // For the original source device, we copy input to output if they are
- // different.
- // If there is only 1 subdiv, we do this in that subdiv. If there is more
- // than 1 subdiv, then the original source device will participate in 2
- // subdivs - the global inter-task broadcast and one local intra-task
- // broadcast. In this case, we perform the copy in the second subdiv for
- // this device.
- if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) {
- VLOG(2) << "copying input to output for device=" << device_->name()
- << " subdiv=" << si;
- const Tensor* input = &ctx_->input(0);
- if (input != output_ &&
- (DMAHelper::base(input) != DMAHelper::base(output_))) {
- {
- mutex_lock l(mu);
- ++pending_count;
- }
- DeviceContext* op_dev_ctx = ctx_->op_device_context();
- CollectiveRemoteAccessLocal::MemCpyAsync(
- op_dev_ctx, op_dev_ctx, device_, device_, ctx_->input_alloc_attr(0),
- ctx_->output_alloc_attr(0), input, output_, 0, /*stream_index*/
- [this, &mu, &pending_count, &all_done](const Status& s) {
- mutex_lock l(mu);
- status_.Update(s);
- --pending_count;
- if (0 == pending_count) {
- all_done.notify_all();
- }
- });
- }
- }
-
- // Then wait for all pending actions to complete.
- {
- mutex_lock l(mu);
- if (pending_count > 0) {
- all_done.wait(l);
- }
- }
- }
- VLOG(2) << "device=" << device_->name() << " return status " << status_;
- done_(status_);
-}
-
-void Broadcaster::DispatchSend(int subdiv, int dst_rank, int src_rank,
- const Tensor* src_tensor,
- const StatusCallback& done) {
- string send_buf_key = BroadcastBufKey(exec_key_, subdiv, src_rank, dst_rank);
- int dst_idx =
- col_params_.instance.impl_details.subdiv_permutations[subdiv][dst_rank];
- VLOG(1) << "DispatchSend " << send_buf_key << " from_device "
- << device_->name() << " to_device "
- << col_params_.instance.device_names[dst_idx] << " subdiv=" << subdiv
- << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx;
- col_exec_->PostToPeer(col_params_.instance.device_names[dst_idx],
- col_params_.instance.task_names[dst_idx], send_buf_key,
- device_, ctx_->op_device_context(),
- ctx_->output_alloc_attr(0), src_tensor,
- device_locality_, done);
-}
-
-void Broadcaster::DispatchRecv(int subdiv, int src_rank, int dst_rank,
- Tensor* dst_tensor, const StatusCallback& done) {
- string recv_buf_key = BroadcastBufKey(exec_key_, subdiv, src_rank, dst_rank);
- int src_idx =
- col_params_.instance.impl_details.subdiv_permutations[subdiv][src_rank];
- VLOG(1) << "DispatchRecv " << recv_buf_key << " from_device "
- << col_params_.instance.device_names[src_idx] << " to_device "
- << device_->name() << " subdiv=" << subdiv << " src_rank=" << src_rank
- << " src_idx=" << src_idx;
- col_exec_->RecvFromPeer(col_params_.instance.device_names[src_idx],
- col_params_.instance.task_names[src_idx],
- col_params_.task.is_local[src_idx], recv_buf_key,
- device_, ctx_->op_device_context(),
- ctx_->output_alloc_attr(0), dst_tensor,
- device_locality_, 0 /*stream_index*/, done);
-}
-
-} // namespace tensorflow
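
The binary-tree rank arithmetic described in the comments of the file removed above (the broadcast path is now reached through the registered "HierarchicalTreeBroadcast" collective, per the param-resolver change below) can be sketched in a few lines of Python. This is an illustrative transcription of the deleted TreeRecvFrom/TreeSendTo logic, not the replacement implementation:

    def tree_recv_from(my_rank, source_rank):
        # Rank this device receives from within one subdivision.
        if my_rank == source_rank:
            return -1                        # the source receives from nobody
        if source_rank == 0:
            return (my_rank - 1) // 2
        predecessor = (my_rank // 2) - 1
        return source_rank if predecessor < 0 else predecessor

    def tree_send_to(my_rank, source_rank, group_size, is_source):
        targets = []
        successor = 2 * my_rank + 1 if source_rank == 0 else 2 * (my_rank + 1)
        if is_source and source_rank != 0:
            # A non-rank-0 source also feeds ranks 0 and 1 directly.
            if group_size > 1:
                targets.append(0)
            if group_size > 2 and source_rank != 1:
                targets.append(1)
        for _ in range(2):
            if successor < group_size and successor != source_rank:
                targets.append(successor)
            successor += 1
        return targets

    print(tree_send_to(0, 0, 5, True))   # -> [1, 2]
    print(tree_recv_from(3, 0))          # -> 1
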
diff --git a/tensorflow/core/common_runtime/buf_rendezvous.h b/tensorflow/core/common_runtime/buf_rendezvous.h
index 9eb9f060f6..065bbd008b 100644
--- a/tensorflow/core/common_runtime/buf_rendezvous.h
+++ b/tensorflow/core/common_runtime/buf_rendezvous.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
-#define TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
#include <functional>
#include <string>
@@ -100,4 +100,4 @@ class BufRendezvous {
void PurgeTable(const Status& s, HookTable* table);
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BUF_RENDEZVOUS_H_
diff --git a/tensorflow/core/common_runtime/collective_executor_mgr.h b/tensorflow/core/common_runtime/collective_executor_mgr.h
index 9de6ab8968..d53aca85b9 100644
--- a/tensorflow/core/common_runtime/collective_executor_mgr.h
+++ b/tensorflow/core/common_runtime/collective_executor_mgr.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
-#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -72,4 +72,4 @@ class CollectiveExecutorMgr : public CollectiveExecutorMgrInterface {
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_EXECUTOR_MGR_H_
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.cc b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
index 2a14493a67..7cb90de3c7 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.cc
@@ -14,7 +14,20 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/collective_param_resolver_local.h"
+#include <stddef.h>
+#include <algorithm>
+#include <unordered_map>
+#include <utility>
+
#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/cancellation.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -179,7 +192,9 @@ void OrderTaskDeviceMap(TaskDeviceMap* tdm) {
int next_rank = 0;
while (true) {
selected.insert(next_device);
- DevRec* dr = &(*tdm)[next_device];
+ auto next_dev_it = tdm->find(next_device);
+ CHECK(next_dev_it != tdm->end());
+ DevRec* dr = &next_dev_it->second;
dr->local_rank = next_rank;
++next_rank;
if (selected.size() == tdm->size()) {
@@ -193,9 +208,15 @@ void OrderTaskDeviceMap(TaskDeviceMap* tdm) {
parsed_name.id = il.device_id();
string endpoint_device =
DeviceNameUtils::ParsedNameToString(parsed_name);
+ // Skip the device if we've already seen it.
if (selected.find(endpoint_device) != selected.end()) {
continue;
}
+ // Skip the device if it is not participating in this collective
+ // instance.
+ if (tdm->find(endpoint_device) == tdm->end()) {
+ continue;
+ }
if (best_link == nullptr || il.strength() > best_link->strength()) {
best_link = &il;
}
@@ -319,206 +340,6 @@ void SortDevicesAndTasks(CollectiveParams* cp) {
}
} // namespace
-int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task) {
- int num_tasks = static_cast<int>(dev_per_task.size());
- int task_lo = 0;
- int task_hi;
- for (int ti = 0; ti < num_tasks; ti++) {
- task_hi = task_lo + dev_per_task[ti];
- if (task_lo <= device_rank && device_rank < task_hi) return ti;
- task_lo += dev_per_task[ti];
- }
- LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi
- << " devices";
- return -1;
-}
-
-void CollectiveParamResolverLocal::GenerateBcastSubdivPerms(
- const string& device, int source_rank, const std::vector<int>& dev_per_task,
- CollectiveParams* cp) {
- if (VLOG_IS_ON(1)) {
- string dpt_buf;
- for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";");
- VLOG(1) << "GenerateBcastSubdivPerms device=" << device
- << " source_rank=" << source_rank << " dev_per_task=" << dpt_buf;
- }
- int num_tasks = cp->group.num_tasks;
- // If there is just 1 task, then execute binary tree broadcast over all
- // devices. Otherwise, the first subdiv is inter-task broadcast, and then
- // there are N more subdivs, where N is #task.
- int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0);
- int total_num_devices = 0;
- for (int num_dev : dev_per_task) total_num_devices += num_dev;
-
- cp->instance.impl_details.subdiv_permutations.resize(num_subdivs);
- cp->subdiv_rank.reserve(num_subdivs);
- cp->instance.impl_details.subdiv_source_rank.reserve(num_subdivs);
-
- // Inter-task subdiv. Pick one device from each task - this is the source
- // device if it belongs to that task, or device 0 for that task. If a device
- // does not participate in the subdiv, set subdiv_rank to -1.
- if (num_tasks > 1) {
- const int sdi = 0;
- std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
- CHECK_EQ(perm.size(), 0);
- int device_count = 0;
- int source_task = GetDeviceTask(source_rank, dev_per_task);
- for (int ti = 0; ti < cp->group.num_tasks; ti++) {
- bool participate = false;
- if (source_task == ti) {
- // Source device belongs to this task.
- perm.push_back(source_rank);
- participate = cp->instance.device_names[source_rank] == device;
- } else {
- // Source does not belong to this task, choose dev 0.
- perm.push_back(device_count);
- participate = cp->instance.device_names[device_count] == device;
- }
- if (participate) cp->subdiv_rank.push_back(ti);
- device_count += dev_per_task[ti];
- }
- if (cp->subdiv_rank.empty()) cp->subdiv_rank.push_back(-1);
- cp->instance.impl_details.subdiv_source_rank.push_back(source_task);
- }
-
- // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set
- // source to dev 0 for that task if it does not contain original source, else
- // set to rank of original source. If a device does not participate in the
- // subdiv, set subdiv_rank to -1;
- int abs_di = 0;
- for (int ti = 0; ti < cp->group.num_tasks; ti++) {
- const int sdi = ti + (num_tasks > 1 ? 1 : 0);
- std::vector<int>& perm = cp->instance.impl_details.subdiv_permutations[sdi];
- CHECK_EQ(perm.size(), 0);
- bool participate = false;
- int subdiv_source = 0;
- for (int di = 0; di < dev_per_task[ti]; di++) {
- perm.push_back(abs_di);
- if (cp->instance.device_names[abs_di] == device) {
- participate = true;
- cp->subdiv_rank.push_back(di);
- }
- if (abs_di == source_rank) subdiv_source = di;
- abs_di++;
- }
- if (!participate) cp->subdiv_rank.push_back(-1);
- cp->instance.impl_details.subdiv_source_rank.push_back(subdiv_source);
- }
-
- for (int sri = 0; sri < num_subdivs; sri++) {
- CHECK_GE(cp->instance.impl_details.subdiv_source_rank[sri], 0);
- }
-}
-
-// Establish the requested number of subdivision permutations based on the
-// ring order implicit in the device order.
-/*static*/
-void CollectiveParamResolverLocal::GenerateSubdivPerms(const string& device,
- int source_rank,
- CollectiveParams* cp) {
- // Each subdiv permutation is a ring formed by rotating each
- // single-task subsequence of devices by an offset. This makes most
- // sense when each task has the same number of devices but we can't
- // depend on that being the case so we'll compute something that
- // works in any case.
-
- // Start by counting the devices in each task.
- // Precondition: device_names must be sorted so that all devices in
- // the same task are adjacent.
- VLOG(2) << "Sorted task names: "
- << str_util::Join(cp->instance.task_names, ", ");
- std::vector<int> dev_per_task;
- const string* prior_task_name = &cp->instance.task_names[0];
- int dev_count = 1;
- for (int di = 1; di < cp->group.group_size; ++di) {
- if (cp->instance.task_names[di] != *prior_task_name) {
- dev_per_task.push_back(dev_count);
- dev_count = 1;
- prior_task_name = &cp->instance.task_names[di];
- } else {
- ++dev_count;
- }
- }
- dev_per_task.push_back(dev_count);
- CHECK_EQ(cp->group.num_tasks, dev_per_task.size());
-
- CHECK(cp->instance.type == REDUCTION_COLLECTIVE ||
- cp->instance.type == BROADCAST_COLLECTIVE);
- if (cp->instance.type == REDUCTION_COLLECTIVE) {
- // Generate a ring permutation for each requested offset.
- CHECK_GT(cp->instance.impl_details.subdiv_offsets.size(), 0);
- VLOG(2) << "Setting up perms for cp " << cp << " subdiv_permutations "
- << &cp->instance.impl_details.subdiv_permutations;
- cp->instance.impl_details.subdiv_permutations.resize(
- cp->instance.impl_details.subdiv_offsets.size());
- cp->subdiv_rank.resize(cp->instance.impl_details.subdiv_offsets.size(), -1);
- for (int sdi = 0; sdi < cp->instance.impl_details.subdiv_offsets.size();
- ++sdi) {
- std::vector<int>& perm =
- cp->instance.impl_details.subdiv_permutations[sdi];
- CHECK_EQ(perm.size(), 0);
- int offset = cp->instance.impl_details.subdiv_offsets[sdi];
- // A negative subdivision offset is interpreted as follows:
- // 1. Reverse the local device ordering.
- // 2. Begin the subdivision at abs(offset) in the reversed ordering.
- bool reverse = false;
- if (offset < 0) {
- offset = abs(offset);
- reverse = true;
- }
- int prior_dev_count = 0; // sum over prior worker device counts
- for (int ti = 0; ti < cp->group.num_tasks; ++ti) {
- for (int di = 0; di < dev_per_task[ti]; ++di) {
- int di_offset = (di + offset) % dev_per_task[ti];
- int offset_di =
- reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
- // Device index in global subdivision permutation.
- int permuted_di = prior_dev_count + offset_di;
- int rank = static_cast<int>(perm.size());
- perm.push_back(permuted_di);
- if (cp->instance.device_names[permuted_di] == device) {
- CHECK_EQ(permuted_di, cp->default_rank);
- cp->subdiv_rank[sdi] = rank;
- }
- }
- prior_dev_count += dev_per_task[ti];
- }
- CHECK_EQ(cp->group.group_size, perm.size());
- }
- } else if (cp->instance.type == BROADCAST_COLLECTIVE) {
- GenerateBcastSubdivPerms(device, source_rank, dev_per_task, cp);
- }
-
- if (VLOG_IS_ON(1)) {
- // Log the computed ring order for each subdiv.
- string buf;
- for (int sdi = 0;
- sdi < cp->instance.impl_details.subdiv_permutations.size(); ++sdi) {
- buf = strings::StrCat("Subdiv ", sdi, " device order:\n");
- for (int di = 0;
- di < cp->instance.impl_details.subdiv_permutations[sdi].size();
- ++di) {
- int idx = cp->instance.impl_details.subdiv_permutations[sdi][di];
- if (idx >= 0) {
- CHECK_GT(cp->instance.device_names.size(), idx);
- strings::StrAppend(&buf, cp->instance.device_names[idx], "\n");
- }
- }
- strings::StrAppend(&buf, " subdiv_offsets: ");
- for (auto o : cp->instance.impl_details.subdiv_offsets)
- strings::StrAppend(&buf, o, " ");
- strings::StrAppend(&buf, " SubdivRank: ");
- for (auto d : cp->subdiv_rank) strings::StrAppend(&buf, d, " ");
- if (cp->instance.type == BROADCAST_COLLECTIVE) {
- strings::StrAppend(&buf, " subdiv_source_rank: ");
- for (auto src : cp->instance.impl_details.subdiv_source_rank)
- strings::StrAppend(&buf, src, " ");
- }
- VLOG(1) << buf;
- }
- }
-}
-
void CollectiveParamResolverLocal::CompleteTaskIsLocal(const string& task_name,
CollectiveParams* cp) {
cp->task.is_local.resize(cp->group.group_size, false);
@@ -594,6 +415,10 @@ void CollectiveParamResolverLocal::InitInstanceSharedParams(
});
}
+// NOTE(ayushd): The DeviceLocality objects in localities will have LocalLinks
+// to all devices that they are physically connected to and visible to the
+// TensorFlow runtime. This set of devices may be a superset of the devices
+// participating in this instance of collectives.
void CollectiveParamResolverLocal::CompleteDefaultRanking(
const GroupRec* gr, const CollectiveParams* cp, InstanceRec* ir,
const std::vector<DeviceLocality>& localities) {
@@ -697,7 +522,6 @@ void CollectiveParamResolverLocal::CallInitInstanceSharedParams(
InitInstanceSharedParams(
gr, cp, ir,
[this, ir, done](const Status& s) UNLOCK_FUNCTION(ir->out_mu) {
- DCHECK(!ir->out_mu.try_lock());
DCHECK(ir->out_mu_available);
ir->status.Update(s);
ir->out_mu.unlock();
@@ -785,29 +609,39 @@ void CollectiveParamResolverLocal::CompleteInstanceFromInitializedIRec(
// Populate the fields common across task, also default_rank.
SetDefaultRank(device, cp);
CompleteTaskIsLocal(task_name_, cp);
+ // TODO(b/113171733): we need a better way to pick the collective
+ // implementation. The ideal way would take the topology and link strength
+ // into account when picking a particular implementation.
+ cp->instance.impl_details.collective_name =
+ (cp->instance.type == BROADCAST_COLLECTIVE) ? "HierarchicalTreeBroadcast"
+ : "RingReduce";
+ CollectiveImplementationInterface* col_impl;
+ Status lookup_status = CollectiveRegistry::LookupParamResolverInstance(
+ cp->instance.impl_details.collective_name, &col_impl);
+ if (!lookup_status.ok()) {
+ done(lookup_status);
+ return;
+ }
// If broadcast, may need to wait for source discovery.
if (cp->instance.type == BROADCAST_COLLECTIVE) {
CompleteInstanceSource(ir, cp, is_source,
- [this, ir, device, cp, done](InstanceRec* irec) {
+ [col_impl, ir, device, cp, done](InstanceRec* irec) {
CHECK_EQ(ir, irec);
Status s;
- int source_rank;
{
mutex_lock l(irec->out_mu);
irec->WaitForOutMu(l);
s = irec->status;
- source_rank = irec->source_rank;
+ cp->source_rank = irec->source_rank;
}
if (s.ok()) {
- GenerateSubdivPerms(device, source_rank, cp);
+ s = col_impl->InitializeCollectiveParams(cp);
}
done(s);
});
- return;
} else {
- GenerateSubdivPerms(device, 0, cp);
+ done(col_impl->InitializeCollectiveParams(cp));
}
- done(Status::OK());
}
void CollectiveParamResolverLocal::CompleteInstanceSource(InstanceRec* ir,
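The hunk above swaps the removed GenerateSubdivPerms/GenerateBcastSubdivPerms helpers for a lookup through the CollectiveRegistry, so the chosen CollectiveImplementationInterface now fills in the subdivision details itself. A minimal sketch of that flow, using only calls that appear in this diff (the free-function wrapper is illustrative, not part of the patch):

    #include "tensorflow/core/framework/collective.h"
    #include "tensorflow/core/lib/core/errors.h"

    namespace tensorflow {

    // Sketch: pick an implementation by name, then let it populate the
    // implementation-specific fields (subdiv permutations, ranks, ...).
    Status InitParamsViaRegistry(CollectiveParams* cp) {
      cp->instance.impl_details.collective_name =
          (cp->instance.type == BROADCAST_COLLECTIVE) ? "HierarchicalTreeBroadcast"
                                                      : "RingReduce";
      CollectiveImplementationInterface* col_impl = nullptr;
      TF_RETURN_IF_ERROR(CollectiveRegistry::LookupParamResolverInstance(
          cp->instance.impl_details.collective_name, &col_impl));
      return col_impl->InitializeCollectiveParams(cp);
    }

    }  // namespace tensorflow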
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local.h b/tensorflow/core/common_runtime/collective_param_resolver_local.h
index 2e2aa801d9..c5c3497e28 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local.h
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local.h
@@ -12,10 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
-#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#include <functional>
+#include <memory>
+#include <set>
#include <string>
+#include <vector>
#include "tensorflow/core/framework/collective.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -79,6 +83,7 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
// Used to complete/verify CollInstance.
struct InstanceRec;
+
typedef std::function<void(InstanceRec*)> IRConsumer;
struct InstanceRec {
// This structure has two mutexes so that a possibly long
@@ -212,18 +217,6 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
void CallbackWithStatus(const InstanceRecCallback& done, InstanceRec* irec)
LOCKS_EXCLUDED(irec->out_mu);
- friend class CollectiveParamResolverLocalTest;
- // Establishes the requested number of subdivision permutations based on the
- // ring order implicit in the device order.
- static void GenerateSubdivPerms(const string& device, int source_rank,
- CollectiveParams* cp);
- // Establishes the subdivisions for broadcast op. The first subdiv executes
- // binary tree bcast with one device per task. Each subsequent subdiv
- // executes intra-task binary tree broadcast.
- static void GenerateBcastSubdivPerms(const string& device, int source_rank,
- const std::vector<int>& dev_per_task,
- CollectiveParams* cp);
-
const DeviceMgr* dev_mgr_;
DeviceResolverInterface* dev_resolver_; // Not owned.
string task_name_;
@@ -237,4 +230,4 @@ class CollectiveParamResolverLocal : public ParamResolverInterface {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_PARAM_RESOLVER_LOCAL_H_
diff --git a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
index 9ea23b72d2..9e1e2e8d5b 100644
--- a/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
+++ b/tensorflow/core/common_runtime/collective_param_resolver_local_test.cc
@@ -44,31 +44,6 @@ class CollectiveParamResolverLocalTest : public ::testing::Test {
task_name));
}
- void GenSubdivPerms(const string& device, int source_rank,
- CollectiveParams* cp) {
- CollectiveParamResolverLocal::GenerateSubdivPerms(device, source_rank, cp);
- }
-
- // Calls GenerateBcastSubdivPerms for device at `device_rank`. Checks if the
- // generated subdiv perms, ranks, and source ranks match the expected values.
- void BcastSubdivPerms(
- CollectiveParams* cp, const std::vector<int>& dev_per_task,
- int device_rank, int source_rank,
- const std::vector<std::vector<int>>& expected_subdiv_perms,
- const std::vector<int>& expected_subdiv_rank,
- const std::vector<int>& expected_subdiv_source_rank) {
- cp->subdiv_rank.clear();
- cp->instance.impl_details.subdiv_permutations.clear();
- cp->instance.impl_details.subdiv_source_rank.clear();
- CollectiveParamResolverLocal::GenerateBcastSubdivPerms(
- cp->instance.device_names[device_rank], source_rank, dev_per_task, cp);
- EXPECT_EQ(expected_subdiv_perms,
- cp->instance.impl_details.subdiv_permutations);
- EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
- EXPECT_EQ(expected_subdiv_source_rank,
- cp->instance.impl_details.subdiv_source_rank);
- }
-
std::vector<Device*> devices_;
std::unique_ptr<DeviceMgr> device_mgr_;
std::unique_ptr<DeviceResolverLocal> drl_;
@@ -114,7 +89,6 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsReduction1Task) {
cps[i].instance.device_names[j]);
EXPECT_TRUE(cps[i].task.is_local[j]);
}
- EXPECT_EQ(cps[i].subdiv_rank[0], i);
EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank.size(), 0);
EXPECT_FALSE(cps[i].is_source);
EXPECT_EQ(cps[i].default_rank, i);
@@ -161,188 +135,10 @@ TEST_F(CollectiveParamResolverLocalTest, CompleteParamsBroadcast1Task) {
cps[i].instance.device_names[j]);
EXPECT_TRUE(cps[i].task.is_local[j]);
}
- ASSERT_GT(cps[i].subdiv_rank.size(), 0);
- EXPECT_EQ(cps[i].subdiv_rank[0], i);
- ASSERT_GT(cps[i].instance.impl_details.subdiv_source_rank.size(), 0);
- EXPECT_EQ(cps[i].instance.impl_details.subdiv_source_rank[0], 1);
EXPECT_EQ(cps[i].is_source, (i == 1));
EXPECT_EQ(cps[i].default_rank, i);
EXPECT_TRUE(cps[i].instance.same_num_devices_per_task);
}
}
-TEST_F(CollectiveParamResolverLocalTest, GenerateSubdivPerms) {
- static const int kNumDevsPerTask = 8;
- static const int kNumTasks = 3;
- static const int kNumDevs = kNumDevsPerTask * kNumTasks;
- CollectiveParams cp;
- std::vector<string> device_names;
- std::vector<string> task_names;
- cp.group.group_key = 1;
- cp.group.group_size = kNumDevs;
- cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = kNumTasks;
- cp.instance.instance_key = 3;
- cp.instance.type = REDUCTION_COLLECTIVE;
- cp.instance.data_type = DataType(DT_FLOAT);
- cp.instance.shape = TensorShape({5});
- cp.instance.impl_details.subdiv_offsets.push_back(0);
- cp.is_source = false;
- for (int i = 0; i < kNumDevs; ++i) {
- int task_id = i / kNumDevsPerTask;
- int dev_id = i % kNumDevsPerTask;
- string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
- task_names.push_back(task_name);
- string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
- device_names.push_back(device_name);
- cp.instance.task_names.push_back(task_name);
- cp.instance.device_names.push_back(device_name);
- }
-
- int test_rank = 0;
- cp.default_rank = test_rank;
- cp.instance.impl_details.subdiv_offsets = {0, 4};
- GenSubdivPerms(cp.instance.device_names[test_rank], 0, &cp);
- std::vector<int> expected_0 = {0, 1, 2, 3, 4, 5, 6, 7,
- 8, 9, 10, 11, 12, 13, 14, 15,
- 16, 17, 18, 19, 20, 21, 22, 23};
- std::vector<int> expected_1 = {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15,
- 8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19};
- for (int i = 0; i < kNumDevs; ++i) {
- EXPECT_EQ(expected_0[i],
- cp.instance.impl_details.subdiv_permutations[0][i]);
- EXPECT_EQ(expected_1[i],
- cp.instance.impl_details.subdiv_permutations[1][i]);
- }
- EXPECT_EQ(0, cp.subdiv_rank[0]);
- EXPECT_EQ(4, cp.subdiv_rank[1]);
-
- test_rank = 3;
- cp.default_rank = test_rank;
- cp.instance.impl_details.subdiv_offsets = {3, -3};
- cp.instance.impl_details.subdiv_permutations.clear();
- GenSubdivPerms(cp.instance.device_names[test_rank], 0, &cp);
- expected_0 = {3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14,
- 15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18};
- expected_1 = {4, 3, 2, 1, 0, 7, 6, 5, 12, 11, 10, 9,
- 8, 15, 14, 13, 20, 19, 18, 17, 16, 23, 22, 21};
- for (int i = 0; i < kNumDevs; ++i) {
- EXPECT_EQ(expected_0[i],
- cp.instance.impl_details.subdiv_permutations[0][i]);
- EXPECT_EQ(expected_1[i],
- cp.instance.impl_details.subdiv_permutations[1][i]);
- }
- EXPECT_EQ(0, cp.subdiv_rank[0]);
- EXPECT_EQ(1, cp.subdiv_rank[1]);
-}
-
-TEST_F(CollectiveParamResolverLocalTest, GenerateBcastSubdivPerms1Task8GPU) {
- CollectiveParams cp;
- cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = 1;
- cp.instance.type = BROADCAST_COLLECTIVE;
- for (int i = 0; i < 8; i++) {
- string dev_name =
- strings::StrCat("/job:worker/replica:0/task:0/device:GPU:", i);
- cp.instance.device_names.push_back(dev_name);
- }
- std::vector<int> dev_per_task = {8};
-
- // source 0 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 0, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0},
- {0});
-
- // source 2 device 2
- BcastSubdivPerms(&cp, dev_per_task, 2, 2, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2},
- {2});
-
- // source 2 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 2, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0},
- {2});
-}
-
-TEST_F(CollectiveParamResolverLocalTest, GenerateBcastSubdivPerms4Tasks8GPU) {
- CollectiveParams cp;
- cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = 4;
- cp.instance.type = BROADCAST_COLLECTIVE;
- for (int ti = 0; ti < cp.group.num_tasks; ti++) {
- for (int di = 0; di < 8; di++) {
- string dev_name = strings::StrCat("/job:worker/replica:0/task:", ti,
- "/device:GPU:", di);
- cp.instance.device_names.push_back(dev_name);
- }
- }
- std::vector<int> dev_per_task = {8, 8, 8, 8};
-
- // source 0 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 0,
- {{0, 8, 16, 24},
- {0, 1, 2, 3, 4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13, 14, 15},
- {16, 17, 18, 19, 20, 21, 22, 23},
- {24, 25, 26, 27, 28, 29, 30, 31}},
- {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
-
- // source 2 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 2,
- {{2, 8, 16, 24},
- {0, 1, 2, 3, 4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13, 14, 15},
- {16, 17, 18, 19, 20, 21, 22, 23},
- {24, 25, 26, 27, 28, 29, 30, 31}},
- {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
-
- // source 9 device 9
- BcastSubdivPerms(&cp, dev_per_task, 9, 9,
- {{0, 9, 16, 24},
- {0, 1, 2, 3, 4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13, 14, 15},
- {16, 17, 18, 19, 20, 21, 22, 23},
- {24, 25, 26, 27, 28, 29, 30, 31}},
- {1, -1, 1, -1, -1}, {1, 0, 1, 0, 0});
-}
-
-TEST_F(CollectiveParamResolverLocalTest,
- GenerateBcastSubdivPerms4TasksVariableGPU) {
- CollectiveParams cp;
- cp.group.device_type = DeviceType("GPU");
- cp.group.num_tasks = 4;
- std::vector<int> dev_per_task = {4, 4, 6, 8};
- for (int ti = 0; ti < cp.group.num_tasks; ti++) {
- for (int di = 0; di < dev_per_task[ti]; di++) {
- string dev_name = strings::StrCat("/job:worker/replica:0/task:", ti,
- "/device:GPU:", di);
- cp.instance.device_names.push_back(dev_name);
- }
- }
-
- // source 0 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 0,
- {{0, 4, 8, 14},
- {0, 1, 2, 3},
- {4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13},
- {14, 15, 16, 17, 18, 19, 20, 21}},
- {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
-
- // source 2 device 0
- BcastSubdivPerms(&cp, dev_per_task, 0, 2,
- {{2, 4, 8, 14},
- {0, 1, 2, 3},
- {4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13},
- {14, 15, 16, 17, 18, 19, 20, 21}},
- {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
-
- // source 9 device 5
- BcastSubdivPerms(&cp, dev_per_task, 5, 9,
- {{0, 4, 9, 14},
- {0, 1, 2, 3},
- {4, 5, 6, 7},
- {8, 9, 10, 11, 12, 13},
- {14, 15, 16, 17, 18, 19, 20, 21}},
- {-1, -1, 1, -1, -1}, {2, 0, 0, 1, 0});
-}
-
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_rma_local.h b/tensorflow/core/common_runtime/collective_rma_local.h
index dbb2e67c7d..2188087957 100644
--- a/tensorflow/core/common_runtime/collective_rma_local.h
+++ b/tensorflow/core/common_runtime/collective_rma_local.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
-#define TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_
#include "tensorflow/core/common_runtime/buf_rendezvous.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/collective.h"
@@ -34,7 +34,7 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
virtual ~CollectiveRemoteAccessLocal() {}
- void StartAbort(const Status& s);
+ void StartAbort(const Status& s) override;
void RecvFromPeer(const string& peer_device, const string& peer_task,
bool peer_is_local, const string& key, Device* to_device,
@@ -89,4 +89,4 @@ class CollectiveRemoteAccessLocal : public PerStepCollectiveRemoteAccess {
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_ACCESS_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_RMA_LOCAL_H_
diff --git a/tensorflow/core/common_runtime/collective_util.cc b/tensorflow/core/common_runtime/collective_util.cc
new file mode 100644
index 0000000000..195521a078
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_util.cc
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/collective_util.h"
+
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace collective_util {
+
+/*static*/
+Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr,
+ const string& device_name, Device** device,
+ DeviceLocality* device_locality) {
+ if (!dev_mgr) {
+ return errors::Internal("Required non-null dev_mgr ", dev_mgr,
+ " for InitializeDeviceAndLocality");
+ }
+
+ Status status = dev_mgr->LookupDevice(device_name, device);
+ if (status.ok()) {
+ CHECK(*device);
+ *device_locality = (*device)->attributes().locality();
+ } else {
+ LOG(ERROR) << "Failed to find device " << device_name;
+ for (auto d : dev_mgr->ListDevices()) {
+ LOG(ERROR) << "Available devices " << d->name();
+ }
+ }
+ return status;
+}
+
+/*static*/
+string SubdivPermDebugString(const CollectiveParams& col_params) {
+ const auto& subdiv_perms =
+ col_params.instance.impl_details.subdiv_permutations;
+ string buf;
+ for (int sdi = 0; sdi < subdiv_perms.size(); ++sdi) {
+ strings::StrAppend(&buf, "Subdiv ", sdi, " device order:\n");
+ for (int di = 0; di < subdiv_perms[sdi].size(); ++di) {
+ int idx = subdiv_perms[sdi][di];
+ if (idx >= 0) {
+ CHECK_GT(col_params.instance.device_names.size(), idx);
+ strings::StrAppend(&buf, col_params.instance.device_names[idx], "\n");
+ }
+ }
+ strings::StrAppend(&buf, " subdiv_offsets: ");
+ for (auto o : col_params.instance.impl_details.subdiv_offsets)
+ strings::StrAppend(&buf, o, " ");
+ strings::StrAppend(&buf, " SubdivRank: ");
+ for (auto d : col_params.subdiv_rank) strings::StrAppend(&buf, d, " ");
+ if (col_params.instance.type == BROADCAST_COLLECTIVE) {
+ strings::StrAppend(&buf, " subdiv_source_rank: ");
+ for (auto src : col_params.instance.impl_details.subdiv_source_rank)
+ strings::StrAppend(&buf, src, " ");
+ }
+ strings::StrAppend(&buf, "\n");
+ }
+ return buf;
+}
+
+} // namespace collective_util
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/collective_util.h b/tensorflow/core/common_runtime/collective_util.h
new file mode 100644
index 0000000000..ebb5731bec
--- /dev/null
+++ b/tensorflow/core/common_runtime/collective_util.h
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/framework/collective.h"
+#include "tensorflow/core/framework/device_attributes.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace collective_util {
+
+Status InitializeDeviceAndLocality(const DeviceMgr* dev_mgr,
+ const string& device_name, Device** device,
+ DeviceLocality* device_locality);
+string SubdivPermDebugString(const CollectiveParams& col_params);
+
+} // namespace collective_util
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_COLLECTIVE_UTIL_H_
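The debug-string logic dropped from the param resolver earlier in this patch is now available as collective_util::SubdivPermDebugString. A short usage sketch against the new header (the wrapper function name is illustrative):

    #include "tensorflow/core/common_runtime/collective_util.h"
    #include "tensorflow/core/platform/logging.h"

    // Sketch: log the computed ring order for every subdivision, as the inline
    // VLOG block removed from the resolver used to do.
    void LogSubdivPerms(const tensorflow::CollectiveParams& col_params) {
      VLOG(1) << tensorflow::collective_util::SubdivPermDebugString(col_params);
    }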
diff --git a/tensorflow/core/common_runtime/constant_folding.cc b/tensorflow/core/common_runtime/constant_folding.cc
index b5a51d2526..e81e61b633 100644
--- a/tensorflow/core/common_runtime/constant_folding.cc
+++ b/tensorflow/core/common_runtime/constant_folding.cc
@@ -37,6 +37,8 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/denormal.h"
+#include "tensorflow/core/platform/setround.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
@@ -59,6 +61,7 @@ bool ReadPartialShapesFromShapeMap(
shape_map,
std::vector<PartialTensorShape>* input_shapes) {
CHECK(shape_map != nullptr);
+ input_shapes->resize(n->num_inputs());
for (const Edge* in : n->in_edges()) {
// Don't need to check if incoming control edges have known shapes.
if (in->IsControlEdge()) continue;
@@ -69,7 +72,9 @@ bool ReadPartialShapesFromShapeMap(
}
const auto& known_shape = known_shape_iter->second;
CHECK_GT(known_shape.size(), in->src_output()) << known_shape_iter->first;
- input_shapes->push_back(known_shape[in->src_output()]);
+ DCHECK_GE(in->dst_input(), 0);
+ DCHECK_LT(in->dst_input(), input_shapes->size());
+ (*input_shapes)[in->dst_input()] = known_shape[in->src_output()];
}
return true;
}
@@ -465,19 +470,19 @@ bool ReplaceTensorWithConstant(
const ConstantFoldNameGenerator& generate_new_name) {
// Be conservative when replacing a tensor with a constant, when not
// running on CPU.
- // 1) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
+ // 1) Do not replace another constant.
+ // 2) If the destination tensor is not an int32 tensor, and has HOST_MEMORY
// constraint, do not replace it.
- // 2) If the destination tensor is an int32 tensor, but has DEVICE_MEMORY
+ // 3) If the destination tensor is an int32 tensor, and has DEVICE_MEMORY
// constraint, do not replace it.
- // 3) If the constant op created does not have a kernel implementation
- // for the device, do not use it.
// 4) If the size of the constant in bytes is too large (>
// max_constant_in_bytes), do not replace it. This prevents the size of the
// Graph from growing too large.
+ // 5) If the constant op created does not have a kernel implementation
+ // for the device, do not use it.
// TODO(keveman): Consider adding a new constant op that has a kernel
// implementation for all types, but with HostMemory constraint on it's
// output.
- // 5) Do not replace another constant.
if (tensor.first->IsConstant()) {
return false;
}
@@ -553,6 +558,11 @@ bool ReplaceTensorWithConstant(
Status ConstantFold(const ConstantFoldingOptions& opts,
FunctionLibraryRuntime* function_library, Env* env,
Device* partition_device, Graph* graph, bool* was_mutated) {
+ // TensorFlow flushes denormals to zero and rounds to nearest, so we do
+ // the same here.
+ port::ScopedFlushDenormal flush;
+ port::ScopedSetRound round(FE_TONEAREST);
+
DumpGraph("Before", graph);
ConstantFoldNameGenerator generate_new_name = opts.generate_new_name;
if (generate_new_name == nullptr) {
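The two RAII guards added to ConstantFold pin the floating-point environment for the duration of a folding pass so folded results match what kernels compute at run time. A minimal sketch of the pattern, assuming only the two headers added in this hunk (the function name is illustrative):

    #include <cfenv>

    #include "tensorflow/core/platform/denormal.h"
    #include "tensorflow/core/platform/setround.h"

    // Both guards restore the previous floating-point settings when they go
    // out of scope, so only the folding pass runs with kernel-like semantics.
    void FoldWithKernelFloatSemantics() {
      tensorflow::port::ScopedFlushDenormal flush;
      tensorflow::port::ScopedSetRound round(FE_TONEAREST);
      // ... evaluate constant subgraphs here ...
    }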
diff --git a/tensorflow/core/common_runtime/constant_folding.h b/tensorflow/core/common_runtime/constant_folding.h
index 84598880bb..a9a84f761b 100644
--- a/tensorflow/core/common_runtime/constant_folding.h
+++ b/tensorflow/core/common_runtime/constant_folding.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_CONSTANT_FOLDING_H_
-#define TENSORFLOW_COMMON_RUNTIME_CONSTANT_FOLDING_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/function.h"
@@ -66,4 +66,4 @@ Status ConstantFold(const ConstantFoldingOptions& opts,
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_CONSTANT_FOLDING_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_CONSTANT_FOLDING_H_
diff --git a/tensorflow/core/common_runtime/copy_tensor.cc b/tensorflow/core/common_runtime/copy_tensor.cc
index f8cb854b52..6e2eb66b94 100644
--- a/tensorflow/core/common_runtime/copy_tensor.cc
+++ b/tensorflow/core/common_runtime/copy_tensor.cc
@@ -61,26 +61,33 @@ void CopyHostToDevice(const Tensor* input, Allocator* cpu_allocator,
status_cb->Unref();
};
auto copier = std::bind(
- [dst, recv_dev_context, out_allocator, status_cb](
- StatusCallback wrapped_done_,
- // Begin unbound arguments
- const Tensor& from, Tensor* to) {
- if (!DMAHelper::CanUseDMA(&from)) {
- Status err = errors::InvalidArgument(
- "During Variant Host->Device Copy: "
- "non-DMA-copy attempted of tensor type: ",
- DataTypeString(from.dtype()));
- status_cb->UpdateStatus(err);
- return err;
- }
- if (status_cb->ok()) {
+ [dst, recv_dev_context, out_allocator, status_cb, cpu_allocator,
+ edge_name](StatusCallback wrapped_done_,
+ // Begin unbound arguments
+ const Tensor& from, Tensor* to) {
+ if (from.dtype() == DT_VARIANT) {
status_cb->Ref();
- *to = Tensor(out_allocator, from.dtype(), from.shape());
- recv_dev_context->CopyCPUTensorToDevice(&from, dst, to,
- wrapped_done_);
+ CopyHostToDevice(&from, cpu_allocator, out_allocator, edge_name,
+ dst, to, recv_dev_context, wrapped_done_);
return Status::OK();
} else {
- return status_cb->status();
+ if (!DMAHelper::CanUseDMA(&from)) {
+ Status err = errors::InvalidArgument(
+ "During Variant Host->Device Copy: "
+ "non-DMA-copy attempted of tensor type: ",
+ DataTypeString(from.dtype()));
+ status_cb->UpdateStatus(err);
+ return err;
+ }
+ if (status_cb->ok()) {
+ status_cb->Ref();
+ *to = Tensor(out_allocator, from.dtype(), from.shape());
+ recv_dev_context->CopyCPUTensorToDevice(&from, dst, to,
+ wrapped_done_);
+ return Status::OK();
+ } else {
+ return status_cb->status();
+ }
}
},
std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
@@ -119,26 +126,33 @@ void CopyDeviceToHost(const Tensor* input, Allocator* cpu_allocator,
status_cb->Unref();
};
auto copier = std::bind(
- [edge_name, src, send_dev_context, out_allocator, status_cb](
- StatusCallback wrapped_done_,
- // Begin unbound arguments
- const Tensor& from, Tensor* to) {
- if (!DMAHelper::CanUseDMA(&from)) {
- Status err = errors::InvalidArgument(
- "During Variant Device->Host Copy: "
- "non-DMA-copy attempted of tensor type: ",
- DataTypeString(from.dtype()));
- status_cb->UpdateStatus(err);
- return err;
- }
- if (status_cb->ok()) {
+ [edge_name, src, send_dev_context, out_allocator, status_cb,
+ cpu_allocator](StatusCallback wrapped_done_,
+ // Begin unbound arguments
+ const Tensor& from, Tensor* to) {
+ if (from.dtype() == DT_VARIANT) {
status_cb->Ref();
- *to = Tensor(out_allocator, from.dtype(), from.shape());
- send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
- wrapped_done_);
+ CopyDeviceToHost(&from, cpu_allocator, out_allocator, edge_name,
+ src, to, send_dev_context, wrapped_done_);
return Status::OK();
} else {
- return status_cb->status();
+ if (!DMAHelper::CanUseDMA(&from)) {
+ Status err = errors::InvalidArgument(
+ "During Variant Device->Host Copy: "
+ "non-DMA-copy attempted of tensor type: ",
+ DataTypeString(from.dtype()));
+ status_cb->UpdateStatus(err);
+ return err;
+ }
+ if (status_cb->ok()) {
+ status_cb->Ref();
+ *to = Tensor(out_allocator, from.dtype(), from.shape());
+ send_dev_context->CopyDeviceTensorToCPU(&from, edge_name, src, to,
+ wrapped_done_);
+ return Status::OK();
+ } else {
+ return status_cb->status();
+ }
}
},
std::move(wrapped_done), std::placeholders::_1, std::placeholders::_2);
@@ -347,7 +361,12 @@ namespace {
static Status WrappedTensorDeviceCopy(
const Tensor& from, Tensor* to,
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy) {
- if (DMAHelper::CanUseDMA(&from)) {
+ if (from.dtype() == DT_VARIANT) {
+ // TODO(b/116349787): Implement support for nested variants.
+ return errors::Unimplemented(
+ "Support for copying nested variants to device has not yet been "
+ "implemented.");
+ } else if (DMAHelper::CanUseDMA(&from)) {
TF_RETURN_IF_ERROR(copy(from, to));
} else {
*to = from;
@@ -358,7 +377,7 @@ static Status WrappedTensorDeviceCopy(
#define REGISTER_WRAPPED_TENSOR_COPY(DIRECTION) \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- Tensor, DIRECTION, "tensorflow::Tensor", WrappedTensorDeviceCopy)
+ Tensor, DIRECTION, WrappedTensorDeviceCopy)
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_WRAPPED_TENSOR_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
diff --git a/tensorflow/core/common_runtime/debugger_state_interface.h b/tensorflow/core/common_runtime/debugger_state_interface.h
index e0fa983373..797a0ade53 100644
--- a/tensorflow/core/common_runtime/debugger_state_interface.h
+++ b/tensorflow/core/common_runtime/debugger_state_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
-#define TENSORFLOW_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
#include <memory>
@@ -117,4 +117,4 @@ class DebugGraphDecoratorRegistry {
} // end namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEBUGGER_STATE_INTERFACE_H_
diff --git a/tensorflow/core/common_runtime/device.h b/tensorflow/core/common_runtime/device.h
index b537666492..2ef1547cd9 100644
--- a/tensorflow/core/common_runtime/device.h
+++ b/tensorflow/core/common_runtime/device.h
@@ -26,8 +26,8 @@ limitations under the License.
// * Task numbers are within the specified replica, so there are as
// many "task zeros" as replicas.
-#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_H_
-#define TENSORFLOW_COMMON_RUNTIME_DEVICE_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
#include <memory>
#include <string>
@@ -101,11 +101,21 @@ class Device : public DeviceBase {
}
}
+ // If true, and tracing is enabled, the `tracing::ScopedAnnotation()` tracing
+ // mechanism will be used instead of `tracing::ScopedActivity()`. Some devices
+ // may override this method to use annotations, which enable child activities
+ // (such as GPU kernel launches) to be related to the OpKernel invocation.
+ virtual bool TraceUsingAnnotations() const { return false; }
+
// Blocks until all operations queued on the device at the time of
// the call have completed. Returns any error pending on the device
// at completion.
virtual Status Sync() = 0;
+ // Override this to return true for devices that require a Sync() call before
+ // session completion.
+ virtual bool RequiresSyncOnCompletion() const { return false; }
+
// Optionally modify the device's GraphDef before execution.
//
// This method should be considered experimental and is supplied to enable
@@ -183,4 +193,4 @@ class Device : public DeviceBase {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_H_
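Device gains two opt-in hooks above, TraceUsingAnnotations() and RequiresSyncOnCompletion(), both defaulting to false. A sketch of a hypothetical device that overrides them (the class name and the trivial Sync() are illustrative only):

    #include "tensorflow/core/common_runtime/device.h"

    namespace tensorflow {

    class MyAcceleratorDevice : public Device {
     public:
      using Device::Device;

      // Relate child activities (e.g. GPU kernel launches) to the OpKernel by
      // tracing with ScopedAnnotation instead of ScopedActivity.
      bool TraceUsingAnnotations() const override { return true; }

      // Ask the runtime to wait for Sync() before declaring the step complete.
      bool RequiresSyncOnCompletion() const override { return true; }

      Status Sync() override { return Status::OK(); }
    };

    }  // namespace tensorflow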
diff --git a/tensorflow/core/common_runtime/device_factory.h b/tensorflow/core/common_runtime/device_factory.h
index 10eb62afa8..db50226fe8 100644
--- a/tensorflow/core/common_runtime/device_factory.h
+++ b/tensorflow/core/common_runtime/device_factory.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
-#define TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_
#include <string>
#include <vector>
@@ -126,4 +126,4 @@ class Registrar {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_FACTORY_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_FACTORY_H_
diff --git a/tensorflow/core/common_runtime/device_mgr.h b/tensorflow/core/common_runtime/device_mgr.h
index cd93f76324..c1ff10d9b5 100644
--- a/tensorflow/core/common_runtime/device_mgr.h
+++ b/tensorflow/core/common_runtime/device_mgr.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_
-#define TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
#include <string>
#include <unordered_map>
@@ -77,4 +77,4 @@ class DeviceMgr {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_MGR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_MGR_H_
diff --git a/tensorflow/core/common_runtime/device_resolver_local.h b/tensorflow/core/common_runtime/device_resolver_local.h
index 098eccdf84..bb6ff2efa0 100644
--- a/tensorflow/core/common_runtime/device_resolver_local.h
+++ b/tensorflow/core/common_runtime/device_resolver_local.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
-#define TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
#include <string>
@@ -45,4 +45,4 @@ class DeviceResolverLocal : public DeviceResolverInterface {
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_RESOLVER_LOCAL_H_
diff --git a/tensorflow/core/common_runtime/device_set.h b/tensorflow/core/common_runtime/device_set.h
index 4cd56e583c..c384d46e97 100644
--- a/tensorflow/core/common_runtime/device_set.h
+++ b/tensorflow/core/common_runtime/device_set.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
-#define TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_
#include <memory>
#include <unordered_map>
@@ -86,4 +86,4 @@ class DeviceSet {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DEVICE_SET_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DEVICE_SET_H_
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 0695278c0d..458e133b68 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/run_handler.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -244,6 +245,21 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool,
#endif // __ANDROID__
}
+static RunHandlerPool* GetOrCreateRunHandlerPool(
+ const SessionOptions& options) {
+ static RunHandlerPool* pool =
+ new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options));
+ return pool;
+}
+
+bool DirectSession::ShouldUseRunHandlerPool() const {
+ if (options_.config.session_inter_op_thread_pool_size() > 0 ||
+ options_.config.use_per_session_threads()) {
+ return false;
+ }
+ return true;
+}
+
DirectSession::DirectSession(const SessionOptions& options,
const DeviceMgr* device_mgr,
DirectSessionFactory* const factory)
@@ -363,7 +379,7 @@ Status DirectSession::MaybeInitializeExecutionState(
Status DirectSession::Create(const GraphDef& graph) {
TF_RETURN_IF_ERROR(init_error_);
if (graph.node_size() > 0) {
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
if (graph_created_) {
return errors::AlreadyExists(
"A Graph has already been created for this session.");
@@ -375,7 +391,7 @@ Status DirectSession::Create(const GraphDef& graph) {
Status DirectSession::Extend(const GraphDef& graph) {
TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
return ExtendLocked(graph);
}
@@ -451,8 +467,22 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
RunState run_state(step_id, &devices_);
run_state.rendez = new IntraProcessRendezvous(device_mgr_.get());
#ifndef __ANDROID__
- // Set up for collectives if the RunOption declares a key.
- if (run_options.experimental().collective_graph_key() > 0) {
+ // Set up for collectives if ExecutorsAndKeys declares a key.
+ if (executors_and_keys->collective_graph_key !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ if (run_options.experimental().collective_graph_key() !=
+ BuildGraphOptions::kNoCollectiveGraphKey) {
+ // If a collective_graph_key was specified in run_options, ensure that it
+ // matches what came out of GraphExecutionState::BuildGraph().
+ if (run_options.experimental().collective_graph_key() !=
+ executors_and_keys->collective_graph_key) {
+ return errors::Internal(
+ "collective_graph_key in RunOptions ",
+ run_options.experimental().collective_graph_key(),
+ " should match collective_graph_key from optimized graph ",
+ executors_and_keys->collective_graph_key);
+ }
+ }
if (!collective_executor_mgr_) {
std::unique_ptr<DeviceResolverInterface> drl(
new DeviceResolverLocal(device_mgr_.get()));
@@ -568,16 +598,37 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
}
}
- Executor::Args::Runner default_runner = [this,
- pool](Executor::Args::Closure c) {
- SchedClosure(pool, std::move(c));
- };
+ std::unique_ptr<RunHandler> handler;
+ if (ShouldUseRunHandlerPool() &&
+ run_options.experimental().use_run_handler_pool()) {
+ // Non-null only when a global inter-op pool is used.
+ VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
+ handler = GetOrCreateRunHandlerPool(options_)->Get();
+ }
+ auto* handler_ptr = handler.get();
+
+ Executor::Args::Runner default_runner = nullptr;
+
+ if (pool == nullptr) {
+ default_runner = [](Executor::Args::Closure c) { c(); };
+ } else if (handler_ptr != nullptr) {
+ default_runner = [handler_ptr](Executor::Args::Closure c) {
+ handler_ptr->ScheduleInterOpClosure(std::move(c));
+ };
+ } else {
+ default_runner = [this, pool](Executor::Args::Closure c) {
+ SchedClosure(pool, std::move(c));
+ };
+ }
+
for (const auto& item : executors_and_keys->items) {
- // TODO(zhengxq): support partial run.
- // TODO(zhengxq): if the device picks its own threadpool, we need to assign
+ // TODO(azaks): support partial run.
+ // TODO(azaks): if the device picks its own threadpool, we need to assign
// less threads to the main compute pool by default.
thread::ThreadPool* device_thread_pool =
item.device->tensorflow_device_thread_pool();
+ // TODO(crk): Investigate usage of RunHandlerPool when using device specific
+ // thread pool(s).
if (!device_thread_pool) {
args.runner = default_runner;
} else {
@@ -602,7 +653,7 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
if (tracer) {
TF_RETURN_IF_ERROR(tracer->Stop());
- TF_RETURN_IF_ERROR(tracer->Collect(args.stats_collector));
+ TF_RETURN_IF_ERROR(tracer->Collect(run_state.collector.get()));
}
{
@@ -618,8 +669,8 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
&session_state_));
}
- if (args.stats_collector) {
- args.stats_collector->Finalize();
+ if (run_state.collector) {
+ run_state.collector->Finalize();
}
// Build and return the cost model as instructed.
@@ -634,7 +685,7 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
}
mutex_lock l(executor_lock_);
- args.stats_collector->BuildCostModel(&cost_model_manager_, device_to_graph);
+ run_state.collector->BuildCostModel(&cost_model_manager_, device_to_graph);
// annotate stats onto cost graph.
CostGraphDef* cost_graph = run_metadata->mutable_cost_graph();
@@ -678,10 +729,16 @@ Status DirectSession::Run(const RunOptions& run_options,
// Check if we already have an executor for these arguments.
ExecutorsAndKeys* executors_and_keys;
RunStateArgs run_state_args(run_options.debug_options());
+ run_state_args.collective_graph_key =
+ run_options.experimental().collective_graph_key();
TF_RETURN_IF_ERROR(GetOrCreateExecutors(input_tensor_names, output_names,
target_nodes, &executors_and_keys,
&run_state_args));
+ {
+ mutex_lock l(collective_graph_key_lock_);
+ collective_graph_key_ = executors_and_keys->collective_graph_key;
+ }
// Configure a call frame for the step, which we use to feed and
// fetch values to and from the executors.
@@ -1116,6 +1173,8 @@ Status DirectSession::CreateExecutors(
BuildGraphOptions options;
options.callable_options = callable_options;
options.use_function_convention = !run_state_args->is_partial_run;
+ options.collective_graph_key =
+ callable_options.run_options().experimental().collective_graph_key();
std::unique_ptr<FunctionInfo> func_info(new FunctionInfo);
std::unique_ptr<ExecutorsAndKeys> ek(new ExecutorsAndKeys);
@@ -1123,9 +1182,9 @@ Status DirectSession::CreateExecutors(
ek->callable_options = callable_options;
std::unordered_map<string, std::unique_ptr<Graph>> graphs;
- TF_RETURN_IF_ERROR(CreateGraphs(options, &graphs, &func_info->flib_def,
- run_state_args, &ek->input_types,
- &ek->output_types));
+ TF_RETURN_IF_ERROR(CreateGraphs(
+ options, &graphs, &func_info->flib_def, run_state_args, &ek->input_types,
+ &ek->output_types, &ek->collective_graph_key));
if (run_state_args->is_partial_run) {
ek->graph = std::move(run_state_args->graph);
@@ -1150,7 +1209,7 @@ Status DirectSession::CreateExecutors(
int graph_def_version;
{
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
graph_def_version =
execution_state_->original_graph_def().versions().producer();
}
@@ -1180,14 +1239,11 @@ Status DirectSession::CreateExecutors(
auto opseg = device->op_segment();
params.create_kernel = [this, lib, opseg](const NodeDef& ndef,
OpKernel** kernel) {
- // We do not share the kernel via the OpSegment if the node is
- // stateless, or a function.
// NOTE(mrry): We must not share function kernels (implemented
// using `CallOp`) between subgraphs, because `CallOp::handle_`
// is tied to a particular subgraph. Even if the function itself
// is stateful, the `CallOp` that invokes it is not.
- if (!lib->IsStateful(ndef.op()) ||
- lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
+ if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
return lib->CreateKernel(ndef, kernel);
}
auto create_fn = [lib, &ndef](OpKernel** kernel) {
@@ -1200,13 +1256,11 @@ Status DirectSession::CreateExecutors(
create_fn);
};
params.delete_kernel = [lib](OpKernel* kernel) {
- // If the node is stateful, opseg owns it. Otherwise, delete it.
- if (kernel && !lib->IsStateful(kernel->type_string())) {
+ if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string()))
delete kernel;
- }
};
- optimizer.Optimize(lib, options_.env, device, &iter->second,
+ optimizer.Optimize(lib, options_.env, device, &partition_graph,
/*shape_map=*/nullptr);
// TensorFlow Debugger (tfdbg) inserts debug nodes in the graph.
@@ -1353,6 +1407,9 @@ Status DirectSession::GetOrCreateExecutors(
}
*callable_options.mutable_run_options()->mutable_debug_options() =
run_state_args->debug_options;
+ callable_options.mutable_run_options()
+ ->mutable_experimental()
+ ->set_collective_graph_key(run_state_args->collective_graph_key);
std::unique_ptr<ExecutorsAndKeys> ek;
std::unique_ptr<FunctionInfo> func_info;
TF_RETURN_IF_ERROR(
@@ -1379,8 +1436,8 @@ Status DirectSession::CreateGraphs(
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types,
- DataTypeVector* output_types) {
- mutex_lock l(graph_def_lock_);
+ DataTypeVector* output_types, int64* collective_graph_key) {
+ mutex_lock l(graph_state_lock_);
std::unique_ptr<ClientGraph> client_graph;
std::unique_ptr<GraphExecutionState> temp_exec_state_holder;
@@ -1403,6 +1460,7 @@ Status DirectSession::CreateGraphs(
TF_RETURN_IF_ERROR(
execution_state->BuildGraph(subgraph_options, &client_graph));
}
+ *collective_graph_key = client_graph->collective_graph_key;
if (subgraph_options.callable_options.feed_size() !=
client_graph->feed_types.size()) {
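The run-handler path added above is driven by a per-step RunOptions flag, exercised by the UseRunHandlerPool test later in this patch. A sketch of how a caller might opt a single Run() into the global pool (the free function is illustrative; the flag has no effect when the session configures per-session or multiple inter-op pools, per ShouldUseRunHandlerPool()):

    #include <vector>

    #include "tensorflow/core/protobuf/config.pb.h"
    #include "tensorflow/core/public/session.h"

    tensorflow::Status RunWithHandlerPool(tensorflow::Session* session) {
      tensorflow::RunOptions run_options;
      run_options.mutable_experimental()->set_use_run_handler_pool(true);
      std::vector<tensorflow::Tensor> outputs;
      // Inter-op closures for this step are scheduled through the shared
      // RunHandlerPool instead of the session's default thread pool.
      return session->Run(run_options, /*inputs=*/{}, /*output_tensor_names=*/{},
                          /*target_node_names=*/{}, &outputs,
                          /*run_metadata=*/nullptr);
    }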
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 72a2be4816..3a168bbe3f 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_
-#define TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
#include <atomic>
#include <memory>
@@ -117,6 +117,9 @@ class DirectSession : public Session {
::tensorflow::Status ReleaseCallable(CallableHandle handle) override;
private:
+ // For access to collective_graph_key_.
+ friend class DirectSessionCollectiveTest;
+
// We create one executor and its dependent library runtime for
// every partition.
struct PerPartitionExecutorsAndLib {
@@ -150,6 +153,8 @@ class DirectSession : public Session {
DataTypeVector output_types;
CallableOptions callable_options;
+
+ int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
};
// A FunctionInfo object is created for every unique set of feeds/fetches.
@@ -203,13 +208,14 @@ class DirectSession : public Session {
string handle;
std::unique_ptr<Graph> graph;
const DebugOptions& debug_options;
+ int64 collective_graph_key = BuildGraphOptions::kNoCollectiveGraphKey;
};
// Initializes the base execution state given the 'graph',
// if not already initialized.
Status MaybeInitializeExecutionState(const GraphDef& graph,
bool* out_already_initialized)
- EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
+ EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);
// Retrieves an already existing set of executors to run 'inputs' and
// 'outputs', or creates and caches them for future use.
@@ -234,15 +240,18 @@ class DirectSession : public Session {
std::unordered_map<string, std::unique_ptr<Graph>>* outputs,
std::unique_ptr<FunctionLibraryDefinition>* flib_def,
RunStateArgs* run_state_args, DataTypeVector* input_types,
- DataTypeVector* output_types);
+ DataTypeVector* output_types, int64* collective_graph_key);
::tensorflow::Status RunInternal(int64 step_id, const RunOptions& run_options,
CallFrameInterface* call_frame,
ExecutorsAndKeys* executors_and_keys,
RunMetadata* run_metadata);
+ // Returns whether inter-op execution uses a global pool.
+ bool ShouldUseRunHandlerPool() const;
+
::tensorflow::Status ExtendLocked(const GraphDef& graph)
- EXCLUSIVE_LOCKS_REQUIRED(graph_def_lock_);
+ EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);
::tensorflow::Status ResourceHandleToInputTensor(
const Tensor& resource_tensor, Tensor* retrieved_tensor);
@@ -283,7 +292,7 @@ class DirectSession : public Session {
}
::tensorflow::Status CheckGraphCreated(const char* method) {
- mutex_lock l(graph_def_lock_);
+ mutex_lock l(graph_state_lock_);
if (!graph_created_) {
return errors::InvalidArgument(
"Session was not created with a graph before ", method, "!");
@@ -307,10 +316,8 @@ class DirectSession : public Session {
DeviceSet device_set_;
string session_handle_;
- bool graph_created_ GUARDED_BY(graph_def_lock_) = false;
-
- mutex graph_def_lock_;
- GraphDef graph_def_ GUARDED_BY(graph_def_lock_);
+ mutex graph_state_lock_;
+ bool graph_created_ GUARDED_BY(graph_state_lock_) = false;
// The thread-pools to use for running ops, with a bool indicating if the pool
// is owned.
@@ -361,11 +368,11 @@ class DirectSession : public Session {
// nodes can not be moved to a different device. Maps node names to
// device names.
std::unordered_map<string, string> stateful_placements_
- GUARDED_BY(graph_def_lock_);
+ GUARDED_BY(graph_state_lock_);
// Execution_state; used when placing the entire graph.
std::unique_ptr<GraphExecutionState> execution_state_
- GUARDED_BY(graph_def_lock_);
+ GUARDED_BY(graph_state_lock_);
// The function library, before any rewrites or optimizations have been
// performed. In particular, CreateGraphs() may need to modify the function
@@ -380,7 +387,7 @@ class DirectSession : public Session {
std::atomic<int64> edge_name_counter_ = {0};
std::atomic<int64> handle_name_counter_ = {0};
- // For generating step ids that are unique across all sessions.
+ // For generating step ids that are unique across this session.
static std::atomic_int_fast64_t step_id_counter_;
// Global timeout for all blocking operations in this session.
@@ -389,7 +396,9 @@ class DirectSession : public Session {
// Manages all the cost models for the graphs executed in this session.
CostModelManager cost_model_manager_;
- Executor::Args::NodeOutputsCallback node_outputs_callback_ = nullptr;
+ // For testing collective graph key generation.
+ mutex collective_graph_key_lock_;
+ int64 collective_graph_key_ GUARDED_BY(collective_graph_key_lock_) = -1;
TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
@@ -399,4 +408,4 @@ class DirectSession : public Session {
} // end namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DIRECT_SESSION_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DIRECT_SESSION_H_
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 4b51b20bb1..a6440c55ad 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -625,6 +625,34 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) {
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
+TEST_F(DirectSessionMinusAXTest, UseRunHandlerPool) {
+ Initialize({3, 2, -1, 0});
+ auto session = CreateSession();
+ ASSERT_TRUE(session != nullptr);
+ TF_ASSERT_OK(session->Create(def_));
+ std::vector<std::pair<string, Tensor>> inputs;
+
+ // Request two targets: one fetch output and one non-fetched output.
+ std::vector<string> output_names = {y_ + ":0"};
+ std::vector<string> target_nodes = {y_neg_};
+ std::vector<Tensor> outputs;
+
+ // Prepares RunOptions and RunMetadata
+ RunOptions run_options;
+ run_options.mutable_experimental()->set_use_run_handler_pool(true);
+
+ Status s = session->Run(run_options, inputs, output_names, target_nodes,
+ &outputs, nullptr);
+ TF_ASSERT_OK(s);
+
+ ASSERT_EQ(1, outputs.size());
+ // The first output should be initialized and have the correct
+ // output.
+ auto mat = outputs[0].matrix<float>();
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ EXPECT_FLOAT_EQ(5.0, mat(0, 0));
+}
+
TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
GraphDef def;
Graph g(OpRegistry::Global());
@@ -1255,7 +1283,7 @@ TEST(DirectSessionTest, RunHandleTest) {
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs.size());
- ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
+ const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
Tensor string_handle(DT_STRING, {});
string_handle.flat<string>().setConstant(resource_handle.name());
@@ -1308,7 +1336,7 @@ TEST(DirectSessionTest, RunHandleTest_Callable) {
ASSERT_TRUE(s.ok());
ASSERT_EQ(1, outputs.size());
- ResourceHandle resource_handle = outputs[0].scalar<ResourceHandle>()();
+ const ResourceHandle& resource_handle = outputs[0].scalar<ResourceHandle>()();
Tensor string_handle(DT_STRING, {});
string_handle.flat<string>().setConstant(resource_handle.name());
@@ -2218,4 +2246,115 @@ BENCHMARK(BM_FeedFetch)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
BENCHMARK(BM_FeedFetchCallable)->Arg(1)->Arg(2)->Arg(5)->Arg(10);
} // namespace
+
+class DirectSessionCollectiveTest : public ::testing::Test {
+ public:
+ // Creates a graph with CollectiveOps inside functions and runs it. Returns
+ // the generated collective_graph_key.
+ Status RunGraphWithCollectiveFunctions(bool add_unused_function,
+ int64* collective_graph_key) {
+ GraphDef g = CreateGraph(add_unused_function);
+ const Tensor t1 =
+ test::AsTensor<float>({0.1, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1});
+ const Tensor t2 =
+ test::AsTensor<float>({0.3, 1.3, 2.3, 3.3, 4.3, 5.3, 6.3, 7.3});
+ auto session = CreateSession();
+ TF_RETURN_IF_ERROR(session->Create(g));
+ std::vector<Tensor> outputs;
+ TF_RETURN_IF_ERROR(
+ session->Run({{"input0:0", t1}, {"input1:0", t2}}, {},
+ {"collective_call0:0", "collective_call1:0"}, &outputs));
+ DirectSession* direct_session = static_cast<DirectSession*>(session.get());
+ {
+ mutex_lock l(direct_session->collective_graph_key_lock_);
+ *collective_graph_key = direct_session->collective_graph_key_;
+ }
+ return Status::OK();
+ }
+
+ private:
+ // Creates a function with name `function_name` and a single CollectiveReduce
+ // node with instance key set as `instance_key`.
+ FunctionDef CollectiveFunction(const string& function_name,
+ int instance_key) {
+ return FunctionDefHelper::Define(
+ // Function name
+ function_name,
+ // In def
+ {"arg:float"},
+ // Out def
+ {"reduce:float"},
+ // Attr def
+ {},
+ // Node def
+ {{
+ {"reduce"},
+ "CollectiveReduce",
+ {"arg"},
+ {{"group_size", 2},
+ {"group_key", 1},
+ {"instance_key", instance_key},
+ {"subdiv_offsets", gtl::ArraySlice<int32>({0})},
+ {"merge_op", "Add"},
+ {"final_op", "Div"},
+ {"T", DT_FLOAT}},
+ }});
+ }
+
+ NodeDef Input(int id) {
+ AttrValue dtype_attr;
+ SetAttrValue(DT_FLOAT, &dtype_attr);
+ NodeDef input;
+ input.set_name(strings::StrCat("input", id));
+ input.set_op("Placeholder");
+ input.mutable_attr()->insert({"dtype", dtype_attr});
+ return input;
+ }
+
+ NodeDef CollectiveCall(const string& op, const string& input, int cpu_id) {
+ NodeDef collective_call;
+ collective_call.set_name(strings::StrCat("collective_call", cpu_id));
+ collective_call.set_op(op);
+ collective_call.add_input(input);
+ collective_call.set_device(
+ strings::StrCat("/job:localhost/replica:0/task:0/device:CPU:", cpu_id));
+ return collective_call;
+ }
+
+ // Creates a GraphDef that adds two CollectiveFunctions, one each on CPU0 and
+ // CPU1, with instance_key 1, and appropriate placeholder inputs. If
+ // `add_unused_function` is true, adds another CollectiveFunction with
+ // instance_key 2 that is not invoked in the graph.
+ GraphDef CreateGraph(bool add_unused_function) {
+ GraphDef g;
+ FunctionDef collective_function =
+ CollectiveFunction("CollectiveFunction1", 1);
+ FunctionDefLibrary* lib = g.mutable_library();
+ *lib->add_function() = collective_function;
+ if (add_unused_function) {
+ FunctionDef unused_function =
+ CollectiveFunction("CollectiveFunction2", 2);
+ *lib->add_function() = unused_function;
+ }
+
+ *g.add_node() = Input(0);
+ *g.add_node() = Input(1);
+ // CollectiveReduce on CPU0 with instance_key 1.
+ *g.add_node() = CollectiveCall("CollectiveFunction1", "input0", 0);
+ // CollectiveReduce on CPU1 with instance_key 1.
+ *g.add_node() = CollectiveCall("CollectiveFunction1", "input1", 1);
+
+ return g;
+ }
+};
+
+TEST_F(DirectSessionCollectiveTest,
+ TestCollectiveGraphKeyUsesOnlyCalledFunctions) {
+ int64 key1;
+ TF_ASSERT_OK(RunGraphWithCollectiveFunctions(false, &key1));
+ int64 key2;
+ TF_ASSERT_OK(RunGraphWithCollectiveFunctions(true, &key2));
+ ASSERT_EQ(key1, key2);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
index 0b096a14a3..2c63b8704e 100644
--- a/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_with_tracking_alloc_test.cc
@@ -77,6 +77,9 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
options.config.mutable_graph_options()
->mutable_rewrite_options()
->set_min_graph_nodes(-1);
+ options.config.mutable_graph_options()
+ ->mutable_rewrite_options()
+ ->set_pin_to_host_optimization(RewriterConfig::OFF);
std::unique_ptr<Session> session(NewSession(options));
TF_ASSERT_OK(session->Create(def));
std::vector<std::pair<string, Tensor>> inputs;
@@ -105,7 +108,7 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
EXPECT_EQ(2, shape.dim(0).size());
EXPECT_EQ(1, shape.dim(1).size());
if (node->name() == y->name()) {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// if MKL is used, it goes through various additional
// graph rewrite pass. In TF, everytime a graph pass
// happens, "constant" nodes are allocated
@@ -114,16 +117,16 @@ TEST(DirectSessionWithTrackingAllocTest, CostModelTest) {
// which increments the value of AllocationId.
// Thus AllocationId becomes more than TF if MKL
// is used. Now IDs for MKL are 8 more than TF.
- EXPECT_EQ(29, cm->AllocationId(node, 0));
-#else
EXPECT_EQ(21, cm->AllocationId(node, 0));
-#endif
- } else {
-#ifdef INTEL_MKL
- EXPECT_EQ(30, cm->AllocationId(node, 0));
#else
+ EXPECT_EQ(13, cm->AllocationId(node, 0));
+#endif // INTEL_MKL && ENABLE_MKL
+ } else {
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
EXPECT_EQ(22, cm->AllocationId(node, 0));
-#endif
+#else
+ EXPECT_EQ(14, cm->AllocationId(node, 0));
+#endif // INTEL_MKL && ENABLE_MKL
}
}
EXPECT_LE(0, cm->MaxExecutionTime(node));
diff --git a/tensorflow/core/common_runtime/dma_helper.h b/tensorflow/core/common_runtime/dma_helper.h
index cdfce1f366..4a76cff1e3 100644
--- a/tensorflow/core/common_runtime/dma_helper.h
+++ b/tensorflow/core/common_runtime/dma_helper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_DMA_HELPER_H_
-#define TENSORFLOW_COMMON_RUNTIME_DMA_HELPER_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_DMA_HELPER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_DMA_HELPER_H_
#include "tensorflow/core/framework/tensor.h"
@@ -35,4 +35,4 @@ class DMAHelper {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_DMA_HELPER_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_DMA_HELPER_H_
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index 7f28f3b793..7b74c67c85 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -147,10 +147,11 @@ tf_cuda_library(
"kernel_and_device.h",
],
visibility = ["//tensorflow:internal"],
- deps = select({
+ deps = [
+ "@farmhash_archive//:farmhash",
+ ] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
- "//util/hash:farmhash_fingerprint",
],
"//conditions:default": [
"//tensorflow/core:core_cpu_lib",
@@ -219,11 +220,13 @@ tf_cuda_library(
visibility = ["//tensorflow:internal"],
deps = [
":kernel_and_device",
- "//tensorflow/c:c_api",
+ "@farmhash_archive//:farmhash",
+ # Only the TF_AttrType enum is required, so pull in just the C headers.
+ # TODO(b/113535673): Break this dependency and avoid the C header completely.
+ "//tensorflow/c:c_api_headers",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
- "//util/hash:farmhash_fingerprint",
],
"//conditions:default": [
"//tensorflow/core:core_cpu",
@@ -249,6 +252,7 @@ tf_cc_test(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc
index 92307d78f2..5c8369de87 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.cc
+++ b/tensorflow/core/common_runtime/eager/attr_builder.cc
@@ -103,7 +103,6 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out) {
return *this; \
}
-DEFINE_SET_ATTR(StringPiece, string_attrs_);
DEFINE_SET_ATTR(float, float_attrs_);
DEFINE_SET_ATTR(int, int_attrs_);
DEFINE_SET_ATTR(bool, bool_attrs_);
@@ -119,9 +118,6 @@ AttrBuilder& AttrBuilder::NumInputs(int n) {
void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
bool include_those_in_node_def) const {
- for (const auto& p : string_attrs_) {
- SetInAttrValueMap(m, p.first, p.second);
- }
for (const auto& p : int_attrs_) {
SetInAttrValueMap(m, p.first, p.second);
}
@@ -140,6 +136,22 @@ void AttrBuilder::FillAttrValueMap(AttrValueMap* m,
m->insert(*it);
}
}
+  // For any attr-value pairs that exist in the op def (from the op registry)
+  // but not in `m`, fill them into `m`, so that we can run a TFE_Op without
+  // having to specify all the default attr values (e.g. for matmul, the
+  // `transpose_a` attr defaults to false).
+ const OpDef* op_def = nullptr;
+ Status s = OpDefForOp(op_name_.c_str(), &op_def);
+  // A failure here is expected if this op is a custom function, which would
+  // not be present in the op registry.
+ if (!s.ok()) return;
+
+ DCHECK(op_def);
+ for (const auto& attr_def : op_def->attr()) {
+ if (attr_def.has_default_value() && !m->count(attr_def.name())) {
+ SetInAttrValueMap(m, attr_def.name(), attr_def.default_value());
+ }
+ }
}
const NodeDef& AttrBuilder::BuildNodeDef() {
@@ -211,10 +223,6 @@ tensorflow::Fprint128 AttrBuilder::CacheKey(const string& device) const {
// not been called.
if (node_def_finalized_) return f;
}
- for (const auto& p : string_attrs_) {
- CombineUnordered(
- CacheKeyHelper(p.first, tensorflow::Fingerprint128(p.second)), &f);
- }
for (const auto& p : int_attrs_) {
CombineUnordered(CacheKeyHelper(p.first, static_cast<uint64>(p.second)),
&f);
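
To make the effect of the new default-attr filling in FillAttrValueMap concrete, here is a minimal caller-side sketch, not part of the patch. It assumes the AttrBuilder("MatMul") construction pattern that kernel_and_device_test.cc uses further down in this diff; the attr names come from the standard MatMul op registration.

    // Sketch only: attrs that were never passed to Set() are now picked up
    // from the registered OpDef defaults when the NodeDef is built.
    NodeDef ndef(AttrBuilder("MatMul")
                     .Set("T", DT_FLOAT)
                     .NumInputs(2)
                     .BuildNodeDef());
    // ndef.attr() now also carries transpose_a = false and transpose_b =
    // false, so a TFE_Op built from this NodeDef can run without the caller
    // spelling out every default attribute.
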
diff --git a/tensorflow/core/common_runtime/eager/attr_builder.h b/tensorflow/core/common_runtime/eager/attr_builder.h
index 929b1b8296..c114ea4ba0 100644
--- a/tensorflow/core/common_runtime/eager/attr_builder.h
+++ b/tensorflow/core/common_runtime/eager/attr_builder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_C_EAGER_RUNTIME_H_
-#define TENSORFLOW_C_EAGER_RUNTIME_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
// Support for eager execution of TensorFlow kernels.
@@ -110,6 +110,12 @@ class AttrBuilder {
using AttrVec = tensorflow::gtl::InlinedVector<std::pair<StringPiece, T>, 2>;
void MayBeInitializeNodeDef();
+ // Fill `m` with the attr-value pairs set via AttrBuilder::Set() so far, as
+ // well as any default attr-value pairs from the associated op_def, if there
+ // is one.
+ //
+ // If `include_those_in_node_def` is true, also include any attr-value pairs
+ // from `node_def_`.
void FillAttrValueMap(AttrValueMap* m, bool include_those_in_node_def) const;
template <class T>
@@ -122,16 +128,15 @@ class AttrBuilder {
AttrValue attr_value;
if (found == nullptr) {
SetAttrValue(value, &attr_value);
- m->insert(AttrValueMap::value_type(attr_name.ToString(), attr_value));
+ m->insert(AttrValueMap::value_type(string(attr_name), attr_value));
} else {
// TODO(ashankar): Do what is done in
// NodeDefBuilder::CheckInconsistency(attr_name, *found, attr_value);
SetAttrValue(std::forward<T>(value), &attr_value);
- (*m)[attr_name.ToString()] = attr_value;
+ (*m)[string(attr_name)] = attr_value;
}
}
- AttrVec<StringPiece> string_attrs_;
AttrVec<int> int_attrs_;
AttrVec<float> float_attrs_;
AttrVec<bool> bool_attrs_;
@@ -143,8 +148,6 @@ class AttrBuilder {
}; // namespace tensorflow
template <>
-AttrBuilder& AttrBuilder::Set(StringPiece attr_name, StringPiece&& value);
-template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, int&& value);
template <>
AttrBuilder& AttrBuilder::Set(StringPiece attr_name, float&& value);
@@ -157,4 +160,4 @@ AttrBuilder& AttrBuilder::Set(StringPiece attr_name,
} // namespace tensorflow
-#endif // TENSORFLOW_C_EAGER_RUNTIME_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EAGER_ATTR_BUILDER_H_
diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc
index 5bdd547c7f..f23cefb33d 100644
--- a/tensorflow/core/common_runtime/eager/context.cc
+++ b/tensorflow/core/common_runtime/eager/context.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eager/context.h"
+#include "tensorflow/core/common_runtime/device_set.h"
#include "tensorflow/core/common_runtime/process_util.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/blocking_counter.h"
@@ -25,40 +26,63 @@ namespace {
bool ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val) {
bool val;
- if (ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
+ if (tensorflow::ReadBoolFromEnvVar(env_var_name, default_val, &val).ok()) {
return val;
}
return default_val;
}
+std::unique_ptr<thread::ThreadPool> EagerThreadPool(
+ const SessionOptions& opts) {
+ SessionOptions opts_copy(opts);
+ if (opts_copy.config.inter_op_parallelism_threads() == 0) {
+ // Eager defaults to a single thread when no threads are specified.
+ opts_copy.config.set_inter_op_parallelism_threads(1);
+ }
+
+ return std::unique_ptr<thread::ThreadPool>(
+ NewThreadPoolFromSessionOptions(opts_copy));
+}
+
} // namespace
EagerContext::EagerContext(const SessionOptions& opts,
ContextDevicePlacementPolicy default_policy,
- bool async, std::unique_ptr<DeviceMgr> device_mgr,
+ bool async,
+ std::unique_ptr<const DeviceMgr> device_mgr,
Rendezvous* rendezvous)
+ : EagerContext(opts, default_policy, async, device_mgr.release(),
+ /*device_mgr_owned*/ true, rendezvous) {}
+
+EagerContext::EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy,
+ bool async, const DeviceMgr* device_mgr,
+ bool device_mgr_owned, Rendezvous* rendezvous)
: policy_(default_policy),
- local_device_manager_(std::move(device_mgr)),
- local_unowned_device_manager_(nullptr),
- devices_(local_device_manager_->ListDevices()),
+ devices_(device_mgr->ListDevices()),
rendezvous_(rendezvous),
- thread_pool_(NewThreadPoolFromSessionOptions(opts)),
+ thread_pool_(EagerThreadPool(opts)),
pflr_(new ProcessFunctionLibraryRuntime(
- local_device_manager_.get(), opts.env, TF_GRAPH_DEF_VERSION,
- &func_lib_def_, {}, thread_pool_.get())),
+ device_mgr, opts.env, TF_GRAPH_DEF_VERSION, &func_lib_def_, {},
+ thread_pool_.get())),
log_device_placement_(opts.config.log_device_placement()),
num_active_steps_(0),
async_default_(async),
+ log_memory_(LogMemory::IsEnabled()),
env_(opts.env),
- use_send_tensor_rpc_(false) {
- InitDeviceMapAndAsync();
- if (opts.config.inter_op_parallelism_threads() > 0) {
- runner_ = [this](std::function<void()> closure) {
- this->thread_pool_->Schedule(closure);
- };
+ use_send_tensor_rpc_(false),
+ pin_small_ops_to_cpu_(ReadBoolFromEnvVar(
+ "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", true)) {
+ if (device_mgr_owned) {
+ local_device_manager_.reset(device_mgr);
+ local_unowned_device_manager_ = nullptr;
} else {
- runner_ = [](std::function<void()> closure) { closure(); };
+ local_unowned_device_manager_ = device_mgr;
}
+ InitDeviceMapAndAsync();
+ runner_ = [this](std::function<void()> closure) {
+ this->thread_pool_->Schedule(std::move(closure));
+ };
}
void EagerContext::InitDeviceMapAndAsync() {
@@ -78,6 +102,12 @@ void EagerContext::InitDeviceMapAndAsync() {
}
}
}
+
+ DeviceSet ds;
+ for (Device* d : devices_) {
+ ds.AddDevice(d);
+ }
+ prioritized_device_type_list_ = ds.PrioritizedDeviceTypeList();
}
bool EagerContext::Async() const {
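
The EagerThreadPool helper above only rewrites the config when the caller left inter_op_parallelism_threads at its zero default. A minimal sketch, not part of the patch, of a caller that wants more eager inter-op parallelism:

    // Any explicit value survives EagerThreadPool(); only the 0 default is
    // rewritten to a single thread.
    tensorflow::SessionOptions opts;
    opts.config.set_inter_op_parallelism_threads(8);
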
diff --git a/tensorflow/core/common_runtime/eager/context.h b/tensorflow/core/common_runtime/eager/context.h
index 21c5bdf8e9..15eeaa8066 100644
--- a/tensorflow/core/common_runtime/eager/context.h
+++ b/tensorflow/core/common_runtime/eager/context.h
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/eager/eager_client.h"
#include "tensorflow/core/distributed_runtime/server_lib.h"
#endif
+#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/rendezvous.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
@@ -65,10 +66,17 @@ enum ContextDevicePlacementPolicy {
class EagerContext {
public:
- explicit EagerContext(const SessionOptions& opts,
- ContextDevicePlacementPolicy default_policy, bool async,
- std::unique_ptr<DeviceMgr> device_mgr,
- Rendezvous* rendezvous);
+ // TODO: remove this constructor once we migrate all callers to the next one.
+ EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy, bool async,
+ std::unique_ptr<const DeviceMgr> device_mgr,
+ Rendezvous* rendezvous);
+
+ EagerContext(const SessionOptions& opts,
+ ContextDevicePlacementPolicy default_policy, bool async,
+ const DeviceMgr* device_mgr, bool device_mgr_owned,
+ Rendezvous* rendezvous);
+
~EagerContext();
// Returns the function library runtime for the given device.
@@ -93,6 +101,9 @@ class EagerContext {
// TODO(apassos) make this return a constant reference
std::vector<Device*>* devices() { return &devices_; }
+ const std::vector<DeviceType>& prioritized_device_type_list() {
+ return prioritized_device_type_list_;
+ }
// Clears the kernel caches.
void ClearCaches();
@@ -131,6 +142,7 @@ class EagerContext {
void AddKernelToCache(Fprint128 cache_key, KernelAndDevice* kernel);
bool LogDevicePlacement() { return log_device_placement_; }
+ bool LogMemory() { return log_memory_; }
Rendezvous* GetRendezvous() { return rendezvous_; }
@@ -190,6 +202,7 @@ class EagerContext {
// EagerService.SendTensor RPC. If false, _Send/_Recv ops should be used
// instead (which in turn use WorkerService.RecvTensor RPCs).
bool UseSendTensorRPC() { return use_send_tensor_rpc_; }
+ bool PinSmallOpsToCPU() { return pin_small_ops_to_cpu_; }
private:
void InitDeviceMapAndAsync();
@@ -204,11 +217,13 @@ class EagerContext {
thread_local_policies_ GUARDED_BY(policy_map_mu_);
// Only one of the below is set.
- std::unique_ptr<DeviceMgr> local_device_manager_;
- DeviceMgr* local_unowned_device_manager_;
+ std::unique_ptr<const DeviceMgr> local_device_manager_;
+ const DeviceMgr* local_unowned_device_manager_;
+ std::unique_ptr<DeviceMgr> remote_device_manager_;
// Devices owned by device_manager
std::vector<Device*> devices_;
+ std::vector<DeviceType> prioritized_device_type_list_;
// None of these devices are owned.
gtl::FlatMap<string, Device*, StringPieceHasher> devices_map_;
Rendezvous* rendezvous_;
@@ -249,11 +264,12 @@ class EagerContext {
std::unordered_map<std::thread::id, bool> thread_local_async_
GUARDED_BY(async_map_mu_);
+ const bool log_memory_;
+
Env* const env_;
#ifndef __ANDROID__
void CloseRemoteContexts();
- std::unique_ptr<DeviceMgr> remote_device_manager_;
// The server_ is not const since we release it when the context is destroyed.
// Therefore the server_ object is not marked as const (even though it should
@@ -278,6 +294,7 @@ class EagerContext {
#endif
bool use_send_tensor_rpc_;
+ const bool pin_small_ops_to_cpu_;
};
} // namespace tensorflow
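
A minimal sketch, not part of the patch, of the new non-owning EagerContext constructor declared above. The `opts`, `device_mgr`, and `rendezvous` objects are assumed to be created and kept alive by the caller, and DEVICE_PLACEMENT_SILENT stands in for whichever ContextDevicePlacementPolicy the caller wants:

    // The caller retains ownership of device_mgr; the context only borrows it.
    tensorflow::EagerContext ctx(opts, tensorflow::DEVICE_PLACEMENT_SILENT,
                                 /*async=*/false, device_mgr,
                                 /*device_mgr_owned=*/false, rendezvous);
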
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
index 46065f399c..a52f933d75 100644
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ -192,17 +192,14 @@ Status ValidateInputTypeAndPlacement(EagerContext* ctx, Device* op_device,
}
Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
- DeviceSet ds;
- for (Device* d : *ctx->devices()) {
- ds.AddDevice(d);
- }
DeviceTypeVector final_devices;
- auto status = SupportedDeviceTypesForNode(ds.PrioritizedDeviceTypeList(),
- ndef, &final_devices);
- if (!status.ok()) return status;
+ TF_RETURN_IF_ERROR(SupportedDeviceTypesForNode(
+ ctx->prioritized_device_type_list(), ndef, &final_devices));
if (final_devices.empty()) {
- return errors::Internal("Could not find valid device for node ",
- ndef.DebugString());
+ return errors::Internal(
+ "Could not find valid device for node.\nNode: ", SummarizeNodeDef(ndef),
+ "\nAll kernels registered for op ", ndef.op(), " :\n",
+ KernelsRegisteredForOp(ndef.op()));
}
for (Device* d : *ctx->devices()) {
if (d->device_type() == final_devices[0].type_string()) {
@@ -211,7 +208,7 @@ Status SelectDevice(const NodeDef& ndef, EagerContext* ctx, Device** device) {
}
}
return errors::Unknown("Could not find a device for node ",
- ndef.DebugString());
+ SummarizeNodeDef(ndef));
}
Status GetOutputDTypes(EagerOperation* op, DataTypeVector* output_dtypes) {
@@ -254,26 +251,6 @@ Status EagerLocalExecute(EagerOperation* op,
EagerContext* ctx = op->EagerContext();
auto status = ctx->GetStatus();
if (!status.ok()) return status;
- // Ensure all resource-touching ops run in the device the resource is,
- // regardless of anything else that has been specified. This is identical to
- // the graph mode behavior.
- for (int i = 0; i < op->Inputs().size(); ++i) {
- Device* input_op_device = nullptr;
- status = op->Inputs()[i]->OpDevice(&input_op_device);
- if (!status.ok()) return status;
- VLOG(2) << "for op " << op->Name() << " input " << i << " "
- << DataTypeString(op->Inputs()[i]->dtype) << " "
- << (input_op_device == nullptr ? "cpu" : input_op_device->name())
- << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name());
- if (op->Inputs()[i]->dtype == DT_RESOURCE &&
- (input_op_device != op->Device() || input_op_device == nullptr)) {
- Device* d = input_op_device == nullptr ? ctx->HostCPU() : input_op_device;
- VLOG(1) << "Changing device of operation " << op->Name() << " to "
- << d->name() << " because input #" << i
- << " is a resource in this device.";
- op->SetDevice(d);
- }
- }
Device* device = op->Device();
Fprint128 cache_key = op->MutableAttrs()->CacheKey(
@@ -299,7 +276,7 @@ Status EagerLocalExecute(EagerOperation* op,
LOG(INFO) << "Executing op " << ndef.op() << " in device "
<< device->name();
}
- kernel = new KernelAndDevice(ctx->GetRendezvous());
+ kernel = new KernelAndDevice(ctx->GetRendezvous(), ctx->LogMemory());
auto* flr = ctx->func_lib(device);
if (flr == nullptr) {
@@ -602,11 +579,81 @@ Status EagerRemoteExecute(EagerOperation* op, TensorHandle** retvals,
return Status::OK();
#endif
}
+
+// The Op device may be updated if:
+// - A resource-touching input is specified: all resource-touching ops run on
+// the device where the resource lives, regardless of anything else that has
+// been specified. This is identical to the graph mode behavior.
+//
+// - All op inputs are on the CPU and are small (at most 64 elements) integer
+// tensors (int32/int64). This can be disabled by setting the environment
+// variable "TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING" to "0" or "false".
+Status MaybeUpdateOpDevice(EagerOperation* op) {
+ EagerContext* ctx = op->EagerContext();
+ bool device_set_for_resource_variable = false;
+ bool all_inputs_eligible_for_cpu_pinning = ctx->PinSmallOpsToCPU();
+
+ for (int i = 0; i < op->Inputs().size(); ++i) {
+ Device* input_op_device = nullptr;
+ TF_RETURN_IF_ERROR(op->Inputs()[i]->OpDevice(&input_op_device));
+ VLOG(2) << "for op " << op->Name() << " input " << i << " "
+ << DataTypeString(op->Inputs()[i]->dtype) << " "
+ << (input_op_device == nullptr ? "cpu" : input_op_device->name())
+ << " " << (op->Device() == nullptr ? "cpu" : op->Device()->name());
+ if (op->Inputs()[i]->dtype == DT_RESOURCE &&
+ (input_op_device != op->Device() || input_op_device == nullptr)) {
+ Device* d = input_op_device == nullptr ? ctx->HostCPU() : input_op_device;
+ VLOG(1) << "Changing device of operation " << op->Name() << " to "
+ << d->name() << " because input #" << i
+ << " is a resource in this device.";
+ op->SetDevice(d);
+
+ device_set_for_resource_variable = true;
+ all_inputs_eligible_for_cpu_pinning = false;
+ } else if (all_inputs_eligible_for_cpu_pinning) {
+ TensorHandle* handle = op->Inputs()[i];
+
+ // Input is on CPU.
+ if (input_op_device != nullptr && input_op_device != ctx->HostCPU()) {
+ all_inputs_eligible_for_cpu_pinning = false;
+ continue;
+ }
+
+ if (handle->dtype != DataType::DT_INT32 &&
+ handle->dtype != DataType::DT_INT64) {
+ all_inputs_eligible_for_cpu_pinning = false;
+ continue;
+ }
+
+ int64 num_elements;
+ TF_RETURN_IF_ERROR(handle->NumElements(&num_elements));
+ if (num_elements > 64) {
+ all_inputs_eligible_for_cpu_pinning = false;
+ }
+ }
+ }
+
+  // Ops without inputs are usually ops that generate a tensor in some way and
+  // typically need to be present on whatever device they are scheduled on
+  // (e.g. VarHandleOp or _Recv).
+ // TODO(nareshmodi): Is it possible there is no int32/int64 CPU kernel for
+ // an op, but there is a GPU kernel?
+ if (!op->Inputs().empty() && all_inputs_eligible_for_cpu_pinning) {
+    VLOG(1) << "Forcing op " << op->Name()
+            << " to be on the CPU since all input tensors have an "
+               "int32/int64 dtype and are small (at most 64 elements).";
+ op->SetDevice(ctx->HostCPU());
+ }
+
+ return Status::OK();
+}
} // namespace
Status EagerExecute(EagerOperation* op,
gtl::InlinedVector<TensorHandle*, 2>* retvals,
int* num_retvals) {
+ TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
+
bool op_is_local = IsLocal(op->EagerContext(), op->Device());
if (op_is_local) {
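
The pinning heuristic added in MaybeUpdateOpDevice above is controlled by an environment variable that the EagerContext reads once, in its constructor. A minimal sketch, not part of the patch, of opting out:

    #include <cstdlib>

    // Must run before the EagerContext is created, since the flag is cached
    // by the context constructor.
    setenv("TF_EAGER_ENABLE_SMALL_TENSOR_CPU_PINNING", "0", /*overwrite=*/1);
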
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.cc b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
index 3d61ff4dc2..83d8425477 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.cc
@@ -32,21 +32,6 @@ limitations under the License.
namespace tensorflow {
// static
-Status KernelAndDevice::InitOp(Device* device, const NodeDef& ndef,
- KernelAndDevice* out) {
- OpKernel* k = nullptr;
- Status s = CreateOpKernel(device->device_type().c_str(), device,
- device->GetAllocator(AllocatorAttributes()),
- nullptr, ndef, TF_GRAPH_DEF_VERSION, &k);
- out->device_ = device;
- out->kernel_.reset(k);
- out->flib_ = nullptr;
- out->runner_ = nullptr;
- out->default_runner_ = [](std::function<void()> f) { f(); };
- return s;
-}
-
-// static
Status KernelAndDevice::Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
std::function<void(std::function<void()>)>* runner,
KernelAndDevice* out) {
@@ -95,6 +80,7 @@ Status KernelAndDevice::Run(ScopedStepContainer* step_container,
params.slice_reader_cache = &slice_reader_cache_;
params.rendezvous = rendez_;
params.cancellation_manager = &cm_;
+ params.log_memory = log_memory_;
if (stats != nullptr) {
params.track_allocations = true;
}
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device.h b/tensorflow/core/common_runtime/eager/kernel_and_device.h
index 0ef419cbaa..04151a1171 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device.h
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device.h
@@ -52,12 +52,12 @@ class KernelAndDevice {
static Status Init(const NodeDef& ndef, FunctionLibraryRuntime* flib,
std::function<void(std::function<void()>)>* runner,
KernelAndDevice* out);
- // TODO(ashankar): Remove this
- static Status InitOp(Device* device, const NodeDef& ndef,
- KernelAndDevice* out);
- KernelAndDevice(tensorflow::Rendezvous* rendez)
- : device_(nullptr), flib_(nullptr), rendez_(rendez) {}
+ KernelAndDevice(tensorflow::Rendezvous* rendez, bool log_memory)
+ : device_(nullptr),
+ flib_(nullptr),
+ rendez_(rendez),
+ log_memory_(log_memory) {}
// TODO(ashankar): Handle list-valued inputs.
Status Run(std::vector<Tensor>* inputs, std::vector<Tensor>* outputs,
@@ -87,6 +87,7 @@ class KernelAndDevice {
DataTypeVector output_dtypes_;
std::function<void(std::function<void()>)>* runner_;
std::function<void(std::function<void()>)> default_runner_;
+ const bool log_memory_;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
index 6abe98f53c..da280b2317 100644
--- a/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
+++ b/tensorflow/core/common_runtime/eager/kernel_and_device_test.cc
@@ -104,7 +104,7 @@ void BM_KernelAndDeviceInit(int iters) {
.NumInputs(2)
.BuildNodeDef());
TestEnv env;
- KernelAndDevice k(nullptr);
+ KernelAndDevice k(nullptr, false);
tensorflow::testing::StartTiming();
for (int i = 0; i < iters; ++i) {
TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
@@ -127,7 +127,7 @@ void BM_KernelAndDeviceRun(int iters) {
.NumInputs(inputs.size())
.BuildNodeDef());
TestEnv env;
- KernelAndDevice kernel(nullptr);
+ KernelAndDevice kernel(nullptr, false);
TF_CHECK_OK(KernelAndDevice::Init(ndef, env.function_library_runtime(),
nullptr, &kernel));
tensorflow::testing::StartTiming();
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index 85b0b79bce..d58724cbfa 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -125,7 +125,6 @@ Status TensorHandle::Shape(tensorflow::TensorShape* shape) {
Status TensorHandle::NumDims(int* num_dims) {
if (IsRemote()) {
TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
- CHECK(remote_shape_ != nullptr);
*num_dims = remote_shape_->dims();
} else {
TF_RETURN_IF_ERROR(WaitReady());
@@ -153,6 +152,21 @@ Status TensorHandle::Dim(int dim_index, int64* dim) {
return Status::OK();
}
+Status TensorHandle::NumElements(int64* num_elements) {
+ if (IsRemote()) {
+ TF_RETURN_IF_ERROR(WaitForNode(remote_shape_node_id_, false));
+ *num_elements = remote_shape_->num_elements();
+ } else {
+ TF_RETURN_IF_ERROR(WaitReady());
+ DCHECK(IsReady());
+ DCHECK(num_elements != nullptr);
+
+ *num_elements = tensor_.NumElements();
+ }
+
+ return Status::OK();
+}
+
Status TensorHandle::RemoteAddress(int64* op_id, int32* output_num) {
if (!IsRemote()) {
return errors::FailedPrecondition(
@@ -193,7 +207,6 @@ Status TensorHandle::CopyToDevice(EagerContext* ctx, tensorflow::Device* dstd,
// has device type XLA_CPU, and the other CPU.
const bool both_on_cpu = src_cpu && dst_cpu;
if (is_same_device || both_on_cpu) {
- dstd = dst_cpu ? nullptr : dstd;
*output = new tensorflow::TensorHandle(*src, dstd, dstd, ctx);
return tensorflow::Status::OK();
}
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index 1bc9c6531a..e55f1a0338 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -113,6 +113,7 @@ class TensorHandle : public core::RefCounted {
Status NumDims(int* num_dims);
Status Dim(int dim_index, int64* dim);
+ Status NumElements(int64* num_elements);
// Return the op_id and output num if the handle refers to a remote tensor.
Status RemoteAddress(int64* op_id, int32* output_num);
diff --git a/tensorflow/core/common_runtime/eval_const_tensor.cc b/tensorflow/core/common_runtime/eval_const_tensor.cc
index c1542f1f57..87749da7af 100644
--- a/tensorflow/core/common_runtime/eval_const_tensor.cc
+++ b/tensorflow/core/common_runtime/eval_const_tensor.cc
@@ -113,6 +113,13 @@ Status TryToInferTensorOutputFromInputShapes(const Edge& edge,
return Status::OK();
}
+// Returns true if 'node' has a registered CPU kernel.
+bool HasCpuKernel(const Node& node) {
+ return FindKernelDef(DeviceType(DEVICE_CPU), node.def(), /*def=*/nullptr,
+ /*kernel_class_name=*/nullptr)
+ .ok();
+}
+
// Extracts the subgraph ending at 'target_node' that is statically computable
// and inserts into 'out_graph'. If statically computable, 'is_constant_graph'
// will be set to true.
@@ -136,6 +143,12 @@ Status ExtractConstantSubgraph(
return Status::OK();
}
+ // Since constant-folding runs on the CPU, do not attempt to constant-fold
+ // operators that have no CPU kernel.
+ if (!HasCpuKernel(target_node)) {
+ return Status::OK();
+ }
+
// TODO(skyewm): should more of the filtering applied in input nodes below be
// applied to target_node here?
@@ -201,6 +214,11 @@ Status ExtractConstantSubgraph(
return Status::OK();
}
+ if (!HasCpuKernel(*current_node)) {
+ *is_constant_graph = false;
+ return Status::OK();
+ }
+
// If there is nothing more to recurse down, see if
// the generator node is a constant.
if (current_node->num_inputs() == 0) {
diff --git a/tensorflow/core/common_runtime/executor.cc b/tensorflow/core/common_runtime/executor.cc
index c2fac4c2c8..2c48084cab 100644
--- a/tensorflow/core/common_runtime/executor.cc
+++ b/tensorflow/core/common_runtime/executor.cc
@@ -72,141 +72,49 @@ bool IsInitializationOp(const Node* node) {
return node->op_def().allows_uninitialized_input();
}
-// Sets the timeline_label field of *node_stats, using data from *node.
-// Returns true iff the node is a transfer node.
-// TODO(tucker): merge with the DetailText function in session.cc
-// in a common location.
-bool SetTimelineLabel(const Node* node, NodeExecStatsWrapper* stats) {
- bool is_transfer_node = false;
- if (!stats) {
- return is_transfer_node;
- }
- string memory;
- for (auto& all : stats->stats()->memory()) {
- int64 tot = all.total_bytes();
- if (tot >= 0.1 * 1048576.0) {
- int64 peak = all.peak_bytes();
- if (peak > 0) {
- memory =
- strings::StrCat(memory, "[", all.allocator_name(),
- strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0,
- peak / 1048576.0));
- } else {
- memory = strings::StrCat(memory, "[", all.allocator_name(),
- strings::Printf(" %.1fMB] ", tot / 1048576.0));
- }
- }
- }
- const AttrSlice attrs = node->attrs();
- string text;
- if (IsSend(node)) {
- string tensor_name;
- TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
- string recv_device;
- TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
- "(", tensor_name, " @", recv_device);
- is_transfer_node = true;
- } else if (IsRecv(node)) {
- string tensor_name;
- TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
- string send_device;
- TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
- text = strings::StrCat(memory, node->name(), " = ", node->type_string(),
- "(", tensor_name, " @", send_device);
- is_transfer_node = true;
- } else {
- text =
- strings::StrCat(memory, node->name(), " = ", node->type_string(), "(",
- str_util::Join(node->requested_inputs(), ", "), ")");
- }
- stats->stats()->set_timeline_label(text);
- return is_transfer_node;
-}
-
// Helper routines for collecting step stats.
namespace nodestats {
-inline int64 NowInUsec() { return Env::Default()->NowMicros(); }
inline int64 NowInNsec() { return Env::Default()->NowNanos(); }
-void SetScheduled(NodeExecStatsWrapper* stats, int64 nanos) {
+void SetScheduled(NodeExecStatsInterface* stats, int64 micros) {
if (!stats) return;
- stats->stats()->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
- stats->stats()->set_scheduled_nanos(nanos);
+ stats->SetScheduled(micros * EnvTime::kMicrosToNanos);
}
-void SetAllStart(NodeExecStatsWrapper* stats) {
+void SetAllStart(NodeExecStatsInterface* stats) {
if (!stats) return;
- int64 now_nanos = NowInNsec();
- stats->stats()->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
- stats->stats()->set_all_start_nanos(now_nanos);
+ stats->RecordExecutorStarted();
}
-void SetOpStart(NodeExecStatsWrapper* stats) {
+void SetOpStart(NodeExecStatsInterface* stats) {
if (!stats) return;
- NodeExecStats* nt = stats->stats();
- DCHECK_NE(nt->all_start_micros(), 0);
- DCHECK_NE(nt->all_start_nanos(), 0);
- int64 now_nanos = NowInNsec();
- nt->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- nt->all_start_micros());
- nt->set_op_start_rel_nanos(now_nanos - nt->all_start_nanos());
+ stats->RecordComputeStarted();
}
-void SetOpEnd(NodeExecStatsWrapper* stats) {
+void SetOpEnd(NodeExecStatsInterface* stats) {
if (!stats) return;
- NodeExecStats* nt = stats->stats();
- DCHECK_NE(nt->all_start_micros(), 0);
- DCHECK_NE(nt->all_start_nanos(), 0);
- int64 now_nanos = NowInNsec();
- nt->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- nt->all_start_micros());
- nt->set_op_end_rel_nanos(now_nanos - nt->all_start_nanos());
+ stats->RecordComputeEnded();
}
-void SetAllEnd(NodeExecStatsWrapper* stats) {
+void SetAllEnd(NodeExecStatsInterface* stats) {
if (!stats) return;
- NodeExecStats* nt = stats->stats();
- DCHECK_NE(nt->all_start_micros(), 0);
- DCHECK_NE(nt->all_start_nanos(), 0);
- int64 now_nanos = NowInNsec();
- nt->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
- nt->all_start_micros());
- nt->set_all_end_rel_nanos(now_nanos - nt->all_start_nanos());
+ stats->RecordExecutorEnded();
}
-void SetOutput(NodeExecStatsWrapper* stats, int slot, const Tensor* v) {
+void SetOutput(NodeExecStatsInterface* stats, int slot, const Tensor* v) {
if (!stats) return;
- DCHECK(v);
- NodeOutput* no = stats->stats()->add_output();
- no->set_slot(slot);
- v->FillDescription(no->mutable_tensor_description());
+ stats->SetOutput(slot, v);
}
-void SetMemory(NodeExecStatsWrapper* stats, OpKernelContext* ctx) {
+void SetMemory(NodeExecStatsInterface* stats, OpKernelContext* ctx) {
if (!stats) return;
-
- for (const auto& allocator_pair : ctx->wrapped_allocators()) {
- stats->AddAllocation(allocator_pair.first, allocator_pair.second);
- }
- auto* ms = stats->stats()->mutable_memory_stats();
- ms->set_temp_memory_size(ctx->temp_memory_allocated());
- for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
- ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
- }
- ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+ stats->SetMemory(ctx);
}
-void SetReferencedTensors(NodeExecStatsWrapper* stats,
+void SetReferencedTensors(NodeExecStatsInterface* stats,
const TensorReferenceVector& tensors) {
if (!stats) return;
- // be careful not to increment the reference count on any tensor
- // while recording the information
- for (size_t i = 0; i < tensors.size(); ++i) {
- AllocationDescription* description =
- stats->stats()->add_referenced_tensor();
- tensors.at(i).FillDescription(description);
- }
+ stats->SetReferencedTensors(tensors);
}
} // namespace nodestats
@@ -235,6 +143,8 @@ struct NodeItem {
bool kernel_is_async : 1; // True iff kernel->AsAsync() != nullptr
bool is_merge : 1; // True iff IsMerge(node)
bool is_enter : 1; // True iff IsEnter(node)
+ bool is_constant_enter : 1; // True iff IsEnter(node) and
+ // node->GetAttr("is_constant") == true.
bool is_exit : 1; // True iff IsExit(node)
bool is_control_trigger : 1; // True iff IsControlTrigger(node)
bool is_sink : 1; // True iff IsSink(node)
@@ -718,6 +628,14 @@ Status ExecutorImpl::Initialize() {
item->kernel_is_async = (item->kernel->AsAsync() != nullptr);
item->is_merge = IsMerge(n);
item->is_enter = IsEnter(n);
+ if (item->is_enter) {
+ bool is_constant_enter;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(n->attrs(), "is_constant", &is_constant_enter));
+ item->is_constant_enter = is_constant_enter;
+ } else {
+ item->is_constant_enter = false;
+ }
item->is_exit = IsExit(n);
item->is_control_trigger = IsControlTrigger(n);
item->is_sink = IsSink(n);
@@ -1319,7 +1237,10 @@ class ExecutorState {
TensorStore* tensor_store_;
// Step-local container.
ScopedStepContainer* step_container_;
- StepStatsCollector* stats_collector_;
+ StepStatsCollectorInterface* const stats_collector_;
+ const tracing::TraceCollector* const trace_collector_;
+ const tracing::EventCollector* const event_collector_;
+
// QUESTION: Make it a checkpoint::TensorSliceReaderCacheWrapper
// instead of a pointer? (avoids having to delete).
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache_;
@@ -1328,6 +1249,7 @@ class ExecutorState {
CancellationManager* cancellation_manager_;
Executor::Args::Runner runner_;
bool sync_on_finish_;
+ const bool trace_using_annotations_;
// Owned.
@@ -1384,7 +1306,7 @@ class ExecutorState {
// After item->kernel computation is done, processes its outputs.
Status ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
- EntryVector* outputs, NodeExecStatsWrapper* stats);
+ EntryVector* outputs, NodeExecStatsInterface* stats);
// After processing the outputs, propagates the outputs to their dsts.
// Contents of *outputs are left in an indeterminate state after
@@ -1395,7 +1317,7 @@ class ExecutorState {
// "node" just finishes. Takes ownership of "stats". Returns true if
// execution has completed.
bool NodeDone(const Status& s, const Node* node, const TaggedNodeSeq& ready,
- NodeExecStatsWrapper* stats,
+ NodeExecStatsInterface* stats,
TaggedNodeReadyQueue* inline_ready);
// Schedule all the expensive nodes in 'ready', and put all the inexpensive
@@ -1442,12 +1364,16 @@ ExecutorState::ExecutorState(const Executor::Args& args, ExecutorImpl* impl)
tensor_store_(args.tensor_store),
step_container_(args.step_container),
stats_collector_(args.stats_collector),
+ trace_collector_(tracing::GetTraceCollector()),
+ event_collector_(
+ tracing::GetEventCollector(tracing::EventCategory::kCompute)),
slice_reader_cache_(new checkpoint::TensorSliceReaderCacheWrapper),
call_frame_(args.call_frame),
impl_(impl),
cancellation_manager_(args.cancellation_manager),
runner_(args.runner),
sync_on_finish_(args.sync_on_finish),
+ trace_using_annotations_(impl->params_.device->TraceUsingAnnotations()),
num_outstanding_ops_(0) {
// We start the entire execution in iteration 0 of the root frame
// so let us create the root frame and the state for iteration 0.
@@ -1565,6 +1491,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
const Status fill_status =
device->FillContextMap(graph, &device_context_map_);
if (!fill_status.ok()) {
+ delete this;
done(fill_status);
return;
}
@@ -1575,6 +1502,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
ready.push_back(TaggedNode{n, root_frame_, 0, false});
}
if (ready.empty()) {
+ delete this;
done(Status::OK());
} else {
num_outstanding_ops_ = ready.size();
@@ -1594,7 +1522,7 @@ void ExecutorState::RunAsync(Executor::DoneCallback done) {
struct ExecutorState::AsyncState {
AsyncState(const OpKernelContext::Params& p, const TaggedNode& _tagged_node,
const NodeItem* _item, Entry* _first_input,
- NodeExecStatsWrapper* _stats)
+ NodeExecStatsInterface* _stats)
: saved_inputs(*p.inputs),
saved_input_device_contexts(*p.input_device_contexts),
saved_input_alloc_attrs(*p.input_alloc_attrs),
@@ -1619,7 +1547,7 @@ struct ExecutorState::AsyncState {
const NodeItem* item;
Entry* first_input;
OpKernelContext ctx;
- NodeExecStatsWrapper* stats;
+ NodeExecStatsInterface* stats;
private:
OpKernelContext::Params* ParamsButClearingEigenGPUDevice(
@@ -1631,6 +1559,32 @@ struct ExecutorState::AsyncState {
}
};
+// Returns true if `item` might be traced by the given trace and event
+// collectors. Returns false only if `item` definitely will not be traced.
+bool MightTrace(const NodeItem& item,
+ const tracing::TraceCollector* trace_collector,
+ const tracing::EventCollector* event_collector,
+ bool using_annotations) {
+  // Tracing will only be enabled if either `event_collector` is non-null,
+ // or `trace_collector` is non-null and enabled for this particular kernel.
+ // Although `tracing::ScopedActivity`,
+ // `tracing::ScopedAnnotation`, and `tracing::ScopedRegion` check subsets of
+ // these properties internally in their constructors, the cost of passing the
+ // necessary arguments to them can be significant, so we avoid constructing
+ // them in the common case (when we know they will not be used).
+ if (event_collector != nullptr) {
+ return true;
+ }
+ if (trace_collector) {
+ if (using_annotations) {
+ return trace_collector->IsEnabledForAnnotations();
+ } else {
+ return trace_collector->IsEnabledForActivities(item.kernel_is_expensive);
+ }
+ }
+ return false;
+}
+
void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
const GraphView& gview = impl_->gview_;
TaggedNodeSeq ready;
@@ -1664,7 +1618,8 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
params.stats_collector = stats_collector_;
Status s;
- NodeExecStatsWrapper* stats = nullptr;
+ NodeExecStatsInterface* stats = nullptr;
+
EntryVector outputs;
bool completed = false;
inline_ready.push_back(tagged_node);
@@ -1694,15 +1649,14 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
if (stats_collector_ && !tagged_node.is_dead) {
// track allocations if and only if we are collecting statistics
params.track_allocations = true;
- stats = new NodeExecStatsWrapper;
- stats->stats()->set_node_name(node->name());
+ stats = stats_collector_->CreateNodeExecStats(node);
nodestats::SetScheduled(stats, scheduled_nsec);
nodestats::SetAllStart(stats);
}
if (vlog_) {
VLOG(1) << "Process node: " << id << " step " << params.step_id << " "
- << SummarizeNode(*node) << " is dead: " << tagged_node.is_dead
+ << SummarizeNode(*node) << (tagged_node.is_dead ? " is dead" : "")
<< " device: " << device->name();
}
@@ -1753,7 +1707,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
auto done = [this, state]() {
Device* device = impl_->params_.device;
- NodeExecStatsWrapper* stats = state->stats; // Shorthand
+ NodeExecStatsInterface* stats = state->stats; // Shorthand
Entry* first_input = state->first_input; // Shorthand
nodestats::SetOpEnd(stats);
@@ -1764,7 +1718,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
VLOG(2) << "Async kernel done: " << state->item->node->id()
<< " step " << step_id_ << " "
<< SummarizeNode(*state->item->node)
- << " is dead: " << state->tagged_node.is_dead
+ << (state->tagged_node.is_dead ? " is dead" : "")
<< " device: " << device->name();
}
@@ -1802,7 +1756,32 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
// Synchronous computes.
OpKernelContext ctx(&params, item.num_outputs);
nodestats::SetOpStart(stats);
- device->Compute(CHECK_NOTNULL(op_kernel), &ctx);
+
+ if (TF_PREDICT_FALSE(MightTrace(item, trace_collector_,
+ event_collector_,
+ trace_using_annotations_))) {
+ const string& op_name = op_kernel->name();
+ tracing::ScopedRegion region(tracing::EventCategory::kCompute,
+ op_name);
+ if (trace_using_annotations_) {
+ // The OpKernel may create child activities (such as GPU kernel
+ // launches), so use a `ScopedAnnotation` to relate these activities
+ // in the trace.
+ tracing::ScopedAnnotation activity(op_name,
+ op_kernel->type_string());
+ device->Compute(op_kernel, &ctx);
+ } else {
+ // Use the cheaper `ScopedActivity` to trace just the OpKernel
+ // execution.
+ tracing::ScopedActivity activity(op_name, op_kernel->type_string(),
+ item.kernel_is_expensive);
+ device->Compute(op_kernel, &ctx);
+ }
+ } else {
+ // In the common case, avoid creating any tracing objects.
+ device->Compute(op_kernel, &ctx);
+ }
+
nodestats::SetOpEnd(stats);
s = ProcessOutputs(item, &ctx, &outputs, stats);
if (s.ok() && impl_->device_record_tensor_accesses_) {
@@ -1818,7 +1797,7 @@ void ExecutorState::Process(TaggedNode tagged_node, int64 scheduled_nsec) {
if (vlog_) {
VLOG(2) << "Synchronous kernel done: " << id << " step "
<< params.step_id << " " << SummarizeNode(*node)
- << " is dead: " << tagged_node.is_dead
+ << (tagged_node.is_dead ? " is dead: " : "")
<< " device: " << device->name();
}
@@ -1944,7 +1923,7 @@ Status ExecutorState::PrepareInputs(const NodeItem& item, Entry* first_input,
Status ExecutorState::ProcessOutputs(const NodeItem& item, OpKernelContext* ctx,
EntryVector* outputs,
- NodeExecStatsWrapper* stats) {
+ NodeExecStatsInterface* stats) {
const Node* node = item.node;
DCHECK_EQ(0, outputs->size());
outputs->resize(item.num_outputs);
@@ -2079,15 +2058,12 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
is_frame_done = input_frame->DecrementOutstandingOpsLocked(
&impl_->gview_, input_iter, ready);
} else if (item->is_enter) {
- bool is_constant;
- const Status s = GetNodeAttr(node->attrs(), "is_constant", &is_constant);
- DCHECK(s.ok()) << s;
FindOrCreateChildFrame(input_frame, input_iter, node, &output_frame);
output_iter = 0;
{
const NodeItem* item = impl_->gview_.node(node->id());
mutex_lock l(output_frame->mu);
- if (is_constant) {
+ if (item->is_constant_enter) {
// Propagate to all active iterations if this is a loop invariant.
output_frame->AddLoopInv(item, (*outputs)[0], ready);
} else {
@@ -2162,15 +2138,15 @@ void ExecutorState::PropagateOutputs(const TaggedNode& tagged_node,
bool ExecutorState::NodeDone(const Status& s, const Node* node,
const TaggedNodeSeq& ready,
- NodeExecStatsWrapper* stats,
+ NodeExecStatsInterface* stats,
TaggedNodeReadyQueue* inline_ready) {
nodestats::SetAllEnd(stats);
- if (stats_collector_ != nullptr && !SetTimelineLabel(node, stats)) {
- // Only record non-transfer nodes.
- // Transfers 'stats' ownership to 'stats_collector_'.
- stats_collector_->Save(impl_->params_.device->name(), stats);
- } else if (stats) {
- delete stats;
+ if (stats) {
+ if (stats_collector_) {
+ stats->Done(impl_->params_.device->name());
+ } else {
+ delete stats;
+ }
}
bool abort_run = false;
@@ -2392,13 +2368,15 @@ void ExecutorState::Finish() {
auto done_cb = std::move(done_cb_);
auto runner = std::move(runner_);
mu_.unlock();
- if (sync_on_finish_ && status.ok()) {
+ Device* device = impl_->params_.device;
+ if ((sync_on_finish_ && status.ok()) || device->RequiresSyncOnCompletion()) {
// Block until the device has finished all queued operations. For
// devices like GPUs that continue to execute Ops after their Compute
// methods have completed, this ensures that control is not returned to
// the user until the step (and its side-effects) has actually completed.
- status = impl_->params_.device->Sync();
+ status.Update(device->Sync());
}
+
delete this;
CHECK(done_cb != nullptr);
runner([=]() { done_cb(status); });
@@ -2502,8 +2480,7 @@ void ExecutorState::DeleteFrame(FrameState* frame, TaggedNodeSeq* ready) {
}
if (dst_ready) {
if (IsControlTrigger(dst_node)) dst_dead = false;
- ready->push_back(
- TaggedNode(dst_node, parent_frame, parent_iter, dst_dead));
+ ready->emplace_back(dst_node, parent_frame, parent_iter, dst_dead);
parent_iter_state->outstanding_ops++;
}
}
@@ -2627,7 +2604,7 @@ void ExecutorState::FrameState::ActivateNodes(const NodeItem* item,
// Add dst to the ready queue if it's ready
if (dst_ready) {
if (dst_item->is_control_trigger) dst_dead = false;
- ready->push_back(TaggedNode(dst_item->node, this, iter, dst_dead));
+ ready->emplace_back(dst_item->node, this, iter, dst_dead);
iter_state->outstanding_ops++;
}
}
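
The executor now records per-node statistics only through NodeExecStatsInterface objects handed out by the collector. The calls made above imply roughly the following shape; this is an inferred sketch only, and the actual declaration (in step_stats_collector.h) may differ in detail:

    class NodeExecStatsInterface {
     public:
      virtual ~NodeExecStatsInterface() {}
      // Hands the finished stats back to the collector for `device`.
      virtual void Done(const string& device) = 0;
      virtual void RecordExecutorStarted() = 0;
      virtual void RecordComputeStarted() = 0;
      virtual void RecordComputeEnded() = 0;
      virtual void RecordExecutorEnded() = 0;
      virtual void SetScheduled(int64 nanos) = 0;
      virtual void SetMemory(OpKernelContext* ctx) = 0;
      virtual void SetOutput(int slot, const Tensor* tensor) = 0;
      virtual void SetReferencedTensors(const TensorReferenceVector& tensors) = 0;
    };
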
diff --git a/tensorflow/core/common_runtime/executor.h b/tensorflow/core/common_runtime/executor.h
index cd01b43aea..34bf73972f 100644
--- a/tensorflow/core/common_runtime/executor.h
+++ b/tensorflow/core/common_runtime/executor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/rendezvous.h"
@@ -83,7 +83,7 @@ class Executor {
struct Args {
int64 step_id = 0;
Rendezvous* rendezvous = nullptr;
- StepStatsCollector* stats_collector = nullptr;
+ StepStatsCollectorInterface* stats_collector = nullptr;
CallFrameInterface* call_frame = nullptr;
CancellationManager* cancellation_manager = nullptr;
SessionState* session_state = nullptr;
@@ -97,12 +97,6 @@ class Executor {
typedef std::function<void()> Closure;
typedef std::function<void(Closure)> Runner;
Runner runner = nullptr;
-
- // A callback that is invoked each time a node has finished executing.
- typedef std::function<Status(const string& node_name, const int output_slot,
- const Tensor* tensor, const bool is_ref,
- OpKernelContext* ctx)>
- NodeOutputsCallback;
};
typedef std::function<void(const Status&)> DoneCallback;
virtual void RunAsync(const Args& args, DoneCallback done) = 0;
@@ -235,4 +229,4 @@ void DeleteNonCachedKernel(OpKernel* kernel);
} // end namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_EXECUTOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_EXECUTOR_H_
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index 54bbe84b57..472865ca43 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -310,6 +310,7 @@ class CallOp : public AsyncOpKernel {
opts.step_container = ctx->step_container();
opts.stats_collector = ctx->stats_collector();
opts.runner = ctx->runner();
+ opts.collective_executor = ctx->collective_executor();
std::vector<Tensor> args;
args.reserve(ctx->num_inputs());
for (int i = 0; i < ctx->num_inputs(); ++i) {
@@ -346,9 +347,10 @@ const FunctionBody* FunctionLibraryRuntimeImpl::GetFunctionBody(Handle h) {
return nullptr;
}
- mutex_lock l(mu_);
- CHECK_EQ(1, items_.count(local_handle));
- return items_[local_handle]->func_graph;
+ tf_shared_lock l(mu_);
+ auto iter = items_.find(local_handle);
+ CHECK(iter != items_.end());
+ return iter->second->func_graph;
}
Status FunctionLibraryRuntimeImpl::CreateKernel(const NodeDef& ndef,
@@ -412,9 +414,8 @@ Status FunctionLibraryRuntimeImpl::CreateKernel(
device_type, device_, device_->GetAllocator(AllocatorAttributes()), &ndef,
&fbody->fdef.signature(), this, fbody->arg_types, input_memory_types,
fbody->ret_types, output_memory_types, graph_def_version_, &s);
- *kernel = new CallOp(handle, &construction);
- if (!s.ok()) {
- delete *kernel;
+ if (s.ok()) {
+ *kernel = new CallOp(handle, &construction);
}
return s;
}
@@ -555,6 +556,12 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
next_handle_++;
}
}
+
+ if (options.create_kernels_eagerly) {
+ Item* item;
+ TF_RETURN_IF_ERROR(GetOrCreateItem(*handle, &item));
+ }
+
return Status::OK();
}
@@ -607,11 +614,14 @@ void PruneFunctionBody(Graph* g) {
std::unordered_set<const Node*> nodes;
for (auto n : g->nodes()) {
// NOTE(mrry): "_Retval" nodes are stateful, and so will be added
- // to the seed set of `nodes`.
+ // to the seed set of `nodes`. "_Arg" nodes are also stateful, but we
+ // specifically exclude them as seeds, to avoid unconditionally executing
+ // unused argument nodes (e.g. in a function like `lambda x, y: y`).
// TODO(mrry): Investigate whether the `n->IsControlFlow()` test is
// still needed. It would be preferable to prune entire loops and/or
// conditionals if they are not used in the graph.
- if (n->IsControlFlow() || n->op_def().is_stateful()) {
+ if (n->IsControlFlow() ||
+ (n->op_def().is_stateful() && n->type_string() != kArgOp)) {
nodes.insert(n);
}
}
@@ -627,7 +637,7 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
const FunctionLibraryDefinition* lib_def;
string executor_type;
{
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
fbody = (*item)->func_graph;
lib_def = (*item)->overlay_lib;
executor_type = (*item)->executor_type;
@@ -676,12 +686,13 @@ Status FunctionLibraryRuntimeImpl::CreateItem(Handle handle, Item** item) {
Status FunctionLibraryRuntimeImpl::GetOrCreateItem(Handle handle, Item** item) {
LocalHandle local_handle = parent_->GetHandleOnDevice(device_name_, handle);
{
- mutex_lock l(mu_);
- if (items_.count(local_handle) == 0) {
+ tf_shared_lock l(mu_);
+ auto iter = items_.find(local_handle);
+ if (iter == items_.end()) {
return errors::NotFound("Function handle ", handle,
" is not valid. Likely an internal error.");
}
- *item = items_[local_handle].get();
+ *item = iter->second.get();
if ((*item)->exec != nullptr) {
return Status::OK();
}
@@ -916,29 +927,18 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
}
DCHECK(run_opts.runner != nullptr);
- Executor::Args* exec_args = new Executor::Args;
+ Executor::Args exec_args;
// Inherit the step_id from the caller.
- exec_args->step_id = run_opts.step_id;
- exec_args->rendezvous = run_opts.rendezvous;
- exec_args->stats_collector = run_opts.stats_collector;
- exec_args->cancellation_manager = run_opts.cancellation_manager;
- exec_args->collective_executor = run_opts.collective_executor;
- exec_args->step_container = run_opts.step_container;
- exec_args->runner = *run_opts.runner;
- exec_args->call_frame = frame;
-
- item->exec->RunAsync(
- // Executor args
- *exec_args,
- // Done callback.
- std::bind(
- [item, frame, exec_args](DoneCallback done,
- // Start unbound arguments.
- const Status& status) {
- delete exec_args;
- done(status);
- },
- std::move(done), std::placeholders::_1));
+ exec_args.step_id = run_opts.step_id;
+ exec_args.rendezvous = run_opts.rendezvous;
+ exec_args.stats_collector = run_opts.stats_collector;
+ exec_args.cancellation_manager = run_opts.cancellation_manager;
+ exec_args.collective_executor = run_opts.collective_executor;
+ exec_args.step_container = run_opts.step_container;
+ exec_args.runner = *run_opts.runner;
+ exec_args.call_frame = frame;
+
+ item->exec->RunAsync(exec_args, std::move(done));
}
bool FunctionLibraryRuntimeImpl::IsStateful(const string& func) {
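
With the create_kernels_eagerly handling added to Instantiate above, kernel construction errors can surface at instantiation time rather than at the first Run. A minimal sketch, not part of the patch; `flr` (a FunctionLibraryRuntime*) and the function name "MyFunc" are hypothetical:

    FunctionLibraryRuntime::InstantiateOptions inst_opts;
    inst_opts.create_kernels_eagerly = true;
    FunctionLibraryRuntime::Handle handle;
    // Fails here, rather than at Run(), if the function's kernels cannot be
    // created.
    TF_RETURN_IF_ERROR(
        flr->Instantiate("MyFunc", AttrSlice(), inst_opts, &handle));
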
diff --git a/tensorflow/core/common_runtime/function.h b/tensorflow/core/common_runtime/function.h
index a274f1ef51..eeca66f5d0 100644
--- a/tensorflow/core/common_runtime/function.h
+++ b/tensorflow/core/common_runtime/function.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
-#define TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
#include <functional>
#include <memory>
@@ -170,4 +170,4 @@ Status FunctionDefToBodyHelper(
FunctionBody** fbody);
} // end namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_FUNCTION_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_FUNCTION_H_
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 120f480198..7bab9be9a6 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -802,9 +802,9 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
// Name
"SquareAndAddOneWithStatefulNodes",
// Args
- {"x: int32"},
+ {"x: int32", "y: float32"},
// Return values
- {"y: int32"},
+ {"z: int32"},
// Attrs
{},
// Nodes
@@ -822,12 +822,13 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
"RandomUniform",
{"shape"},
{{"T", T}, {"dtype", DT_FLOAT}}},
- // y = Add<T>(a, o)
- {{"y"}, "Add", {"a", "o"}, {{"T", T}}}});
+ // z = Add<T>(a, o)
+ {{"z"}, "Add", {"a", "o"}, {{"T", T}}}});
Init({stateful_func});
auto x = test::AsTensor<int32>({1, 2, 3, 4});
- Tensor y;
+ auto y = test::AsTensor<float>({1.0, 2.0, 3.0, 4.0});
+ Tensor z;
FunctionLibraryRuntime::Handle handle;
TF_CHECK_OK(
@@ -837,18 +838,19 @@ TEST_F(FunctionLibraryRuntimeTest, PruneBody) {
StepStatsCollector stats_collector(&stats);
FunctionLibraryRuntime::Options opts;
opts.stats_collector = &stats_collector;
- TF_CHECK_OK(Run(flr0_, handle, opts, {x}, {&y}));
+ TF_CHECK_OK(Run(flr0_, handle, opts, {x, y}, {&z}));
TF_CHECK_OK(flr0_->ReleaseHandle(handle));
TF_CHECK_OK(InstantiateAndRun(flr0_, "SquareAndAddOneWithStatefulNodes", {},
- {x}, {&y}));
- test::ExpectTensorEqual<int>(y, test::AsTensor<int32>({2, 5, 10, 17}));
+ {x, y}, {&z}));
+ test::ExpectTensorEqual<int>(z, test::AsTensor<int32>({2, 5, 10, 17}));
stats_collector.FinalizeAndSwap(&stats);
- // Note that we do not expect the nodes named "x1", "x2", or "x3" to execute.
+ // Note that we do not expect the nodes named "y", "x1", "x2", or "x3" to
+ // execute.
std::set<string> expected_node_names(
- {"_SOURCE", "shape", "x", "o", "a", "keep_me", "y", "y_RetVal"});
+ {"_SOURCE", "shape", "x", "o", "a", "keep_me", "z", "z_RetVal"});
std::set<string> executed_node_names;
for (const auto& node_stats : stats.dev_stats()[0].node_stats()) {
executed_node_names.insert(node_stats.node_name());
diff --git a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
index 636cd43575..6bd29ef775 100644
--- a/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/cuda_host_allocator.h
@@ -26,8 +26,12 @@ namespace tensorflow {
class CUDAHostAllocator : public SubAllocator {
public:
// Note: stream_exec cannot be null.
- explicit CUDAHostAllocator(se::StreamExecutor* stream_exec)
- : stream_exec_(stream_exec) {
+ explicit CUDAHostAllocator(se::StreamExecutor* stream_exec, int numa_node,
+ const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : SubAllocator(alloc_visitors, free_visitors),
+ stream_exec_(stream_exec),
+ numa_node_(numa_node) {
CHECK(stream_exec_ != nullptr);
}
~CUDAHostAllocator() override {}
@@ -39,19 +43,23 @@ class CUDAHostAllocator : public SubAllocator {
if (ptr == nullptr) {
LOG(WARNING) << "could not allocate pinned host memory of size: "
<< num_bytes;
+ return ptr;
}
+ VisitAlloc(ptr, numa_node_, num_bytes);
}
return ptr;
}
void Free(void* ptr, size_t num_bytes) override {
if (ptr != nullptr) {
+ VisitFree(ptr, numa_node_, num_bytes);
stream_exec_->HostMemoryDeallocate(ptr);
}
}
private:
se::StreamExecutor* stream_exec_; // not owned, non-null
+ const int numa_node_;
TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator);
};
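
The new constructor threads NUMA and visitor information through to the SubAllocator base. A minimal sketch, not part of the patch, of registering a logging allocation visitor; the Visitor signature (pointer, NUMA index, byte count) is an assumption inferred from the VisitAlloc/VisitFree calls above:

    std::vector<SubAllocator::Visitor> alloc_visitors, free_visitors;
    alloc_visitors.push_back([](void* ptr, int numa_node, size_t num_bytes) {
      VLOG(1) << "pinned alloc of " << num_bytes << " bytes on NUMA node "
              << numa_node;
    });
    CUDAHostAllocator* host_allocator = new CUDAHostAllocator(
        stream_exec, /*numa_node=*/0, alloc_visitors, free_visitors);
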
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
index 2d4c8d0201..42021e51f3 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc
@@ -22,18 +22,48 @@ limitations under the License.
namespace tensorflow {
-GPUBFCAllocator::GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
- const string& name)
- : GPUBFCAllocator(cuda_gpu_id, total_memory, GPUOptions(), name) {}
+bool GPUBFCAllocator::GetAllowGrowthValue(const GPUOptions& gpu_options) {
+ const char* force_allow_growth_string =
+ std::getenv("TF_FORCE_GPU_ALLOW_GROWTH");
+ if (force_allow_growth_string == nullptr) {
+ return gpu_options.allow_growth();
+ }
+
+ if (strcmp("false", force_allow_growth_string) == 0) {
+ if (gpu_options.allow_growth()) {
+ LOG(WARNING)
+ << "Overriding allow_growth setting because the"
+ << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original"
+ << " config value was " << gpu_options.allow_growth() << ".";
+ }
+ return false;
+ } else if (strcmp("true", force_allow_growth_string) == 0) {
+ if (!gpu_options.allow_growth()) {
+ LOG(WARNING)
+ << "Overriding allow_growth setting because the"
+ << " TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original"
+ << " config value was " << gpu_options.allow_growth() << ".";
+ }
+ return true;
+ }
+
+ LOG(ERROR)
+ << "The TF_FORCE_GPU_ALLOW_GROWTH environment variable is set but could"
+ << " not be parsed: \"" << force_allow_growth_string << "\". Valid"
+ << " values are \"true\" or \"false\". Using original config value"
+ << " of " << gpu_options.allow_growth() << ".";
+ return gpu_options.allow_growth();
+}
+
+GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator,
+ size_t total_memory, const string& name)
+ : GPUBFCAllocator(sub_allocator, total_memory, GPUOptions(), name) {}
-GPUBFCAllocator::GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
+GPUBFCAllocator::GPUBFCAllocator(GPUMemAllocator* sub_allocator,
+ size_t total_memory,
const GPUOptions& gpu_options,
const string& name)
- : BFCAllocator(
- new GPUMemAllocator(
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie(),
- gpu_options.per_process_gpu_memory_fraction() > 1.0 ||
- gpu_options.experimental().use_unified_memory()),
- total_memory, gpu_options.allow_growth(), name) {}
+ : BFCAllocator(sub_allocator, total_memory,
+ GPUBFCAllocator::GetAllowGrowthValue(gpu_options), name) {}
} // namespace tensorflow
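GetAllowGrowthValue gives the TF_FORCE_GPU_ALLOW_GROWTH environment variable precedence over GPUOptions.allow_growth: the strings "true" and "false" (case-sensitive) override the config value, logging a warning when the override disagrees with it, and any other value falls back to the config after an error log. The decision, condensed as an illustrative sketch (not the patch code; assumes <cstdlib> and <cstring>):

  // Same precedence as GetAllowGrowthValue above, minus the logging.
  bool EffectiveAllowGrowth(const GPUOptions& gpu_options) {
    const char* force = std::getenv("TF_FORCE_GPU_ALLOW_GROWTH");
    if (force == nullptr) return gpu_options.allow_growth();  // no override
    if (strcmp(force, "true") == 0) return true;              // force growth on
    if (strcmp(force, "false") == 0) return false;            // force growth off
    return gpu_options.allow_growth();  // unparseable value: keep config
  }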
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
index a3e0d0734f..d4c9cee89a 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
#include <memory>
#include <string>
@@ -31,28 +31,20 @@ limitations under the License.
namespace tensorflow {
-// A GPU memory allocator that implements a 'best-fit with coalescing'
-// algorithm.
-class GPUBFCAllocator : public BFCAllocator {
- public:
- // 'cuda_gpu_id' refers to the ID of the GPU device within
- // the process and must reference a valid ID in the process.
- GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
- const string& name);
- GPUBFCAllocator(CudaGpuId cuda_gpu_id, size_t total_memory,
- const GPUOptions& gpu_options, const string& name);
- virtual ~GPUBFCAllocator() {}
-
- TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
-};
-
// Suballocator for GPU memory.
class GPUMemAllocator : public SubAllocator {
public:
+ // 'platform_gpu_id' refers to the ID of the GPU device within
+ // the process and must reference a valid ID in the process.
// Note: stream_exec cannot be null.
explicit GPUMemAllocator(se::StreamExecutor* stream_exec,
- bool use_unified_memory)
- : stream_exec_(stream_exec), use_unified_memory_(use_unified_memory) {
+ PlatformGpuId gpu_id, bool use_unified_memory,
+ const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : SubAllocator(alloc_visitors, free_visitors),
+ stream_exec_(stream_exec),
+ gpu_id_(gpu_id),
+ use_unified_memory_(use_unified_memory) {
CHECK(stream_exec_ != nullptr);
}
~GPUMemAllocator() override {}
@@ -65,12 +57,14 @@ class GPUMemAllocator : public SubAllocator {
} else {
ptr = stream_exec_->AllocateArray<char>(num_bytes).opaque();
}
+ VisitAlloc(ptr, gpu_id_.value(), num_bytes);
}
return ptr;
}
void Free(void* ptr, size_t num_bytes) override {
if (ptr != nullptr) {
+ VisitFree(ptr, gpu_id_.value(), num_bytes);
if (use_unified_memory_) {
stream_exec_->UnifiedMemoryDeallocate(ptr);
} else {
@@ -82,11 +76,28 @@ class GPUMemAllocator : public SubAllocator {
private:
se::StreamExecutor* stream_exec_; // not owned, non-null
+ const PlatformGpuId gpu_id_;
const bool use_unified_memory_ = false;
TF_DISALLOW_COPY_AND_ASSIGN(GPUMemAllocator);
};
+// A GPU memory allocator that implements a 'best-fit with coalescing'
+// algorithm.
+class GPUBFCAllocator : public BFCAllocator {
+ public:
+ GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory,
+ const string& name);
+ GPUBFCAllocator(GPUMemAllocator* sub_allocator, size_t total_memory,
+ const GPUOptions& gpu_options, const string& name);
+ ~GPUBFCAllocator() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GPUBFCAllocator);
+
+ private:
+ static bool GetAllowGrowthValue(const GPUOptions& gpu_options);
+};
+
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_BFC_ALLOCATOR_H_
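Construction changes accordingly: the caller builds the GPUMemAllocator itself, optionally with alloc/free visitors passed in up front, and hands it to GPUBFCAllocator, which owns it from then on (the tests below new it and never delete it). A minimal sketch for platform GPU 0:

  PlatformGpuId platform_gpu_id(0);
  SubAllocator::Visitor trace_alloc = [](void* ptr, int gpu_id, size_t bytes) {
    VLOG(2) << "GPU " << gpu_id << " alloc of " << bytes << " bytes at " << ptr;
  };
  GPUMemAllocator* sub_allocator = new GPUMemAllocator(
      GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
      platform_gpu_id, /*use_unified_memory=*/false,
      /*alloc_visitors=*/{trace_alloc}, /*free_visitors=*/{});
  GPUBFCAllocator allocator(sub_allocator, /*total_memory=*/1LL << 30,
                            GPUOptions(), "GPU_0_bfc");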
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
index 67caeb3495..60e82ed13b 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator_test.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_id_utils.h"
#include "tensorflow/core/common_runtime/gpu/gpu_init.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -46,7 +47,11 @@ static void CheckStats(Allocator* a, int64 num_allocs, int64 bytes_in_use,
}
TEST(GPUBFCAllocatorTest, NoDups) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
CheckStats(&a, 0, 0, 0, 0);
// Allocate a lot of raw pointers
@@ -75,7 +80,11 @@ TEST(GPUBFCAllocatorTest, NoDups) {
}
TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
// Allocate 256 raw pointers of sizes between 100 bytes and about
// a meg
random::PhiloxRandom philox(123, 17);
@@ -133,7 +142,11 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocations) {
}
TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
CheckStats(&a, 0, 0, 0, 0);
float* first_ptr = a.Allocate<float>(1024);
@@ -168,18 +181,30 @@ TEST(GPUBFCAllocatorTest, ExerciseCoalescing) {
}
TEST(GPUBFCAllocatorTest, AllocateZeroBufSize) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
float* ptr = a.Allocate<float>(0);
EXPECT_EQ(nullptr, ptr);
}
TEST(GPUBFCAllocatorTest, TracksSizes) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
EXPECT_EQ(true, a.TracksAllocationSizes());
}
TEST(GPUBFCAllocatorTest, AllocatedVsRequested) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
float* t1 = a.Allocate<float>(1);
EXPECT_EQ(4, a.RequestedSize(t1));
EXPECT_EQ(256, a.AllocatedSize(t1));
@@ -187,8 +212,12 @@ TEST(GPUBFCAllocatorTest, AllocatedVsRequested) {
}
TEST(GPUBFCAllocatorTest, TestCustomMemoryLimit) {
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
// Configure a 1MiB byte limit
- GPUBFCAllocator a(CudaGpuId(0), 1 << 20, "GPU_0_bfc");
+ GPUBFCAllocator a(sub_allocator, 1 << 20, "GPU_0_bfc");
float* first_ptr = a.Allocate<float>(1 << 6);
float* second_ptr = a.Allocate<float>(1 << 20);
@@ -203,7 +232,11 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocationsWithGrowth) {
options.set_allow_growth(true);
// Max of 2GiB, but starts out small.
- GPUBFCAllocator a(CudaGpuId(0), 1LL << 31, options, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1LL << 31, options, "GPU_0_bfc");
// Allocate 10 raw pointers of sizes between 100 bytes and about
// 64 megs.
@@ -264,8 +297,15 @@ TEST(GPUBFCAllocatorTest, AllocationsAndDeallocationsWithGrowth) {
}
TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) {
- GPUBFCAllocator a(CudaGpuId(0), 1UL << 60, "GPU_0_bfc");
- GPUBFCAllocator b(CudaGpuId(0), 1UL << 60, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1UL << 60, "GPU_0_bfc");
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator b(sub_allocator, 1UL << 60, "GPU_0_bfc");
void* amem = a.AllocateRaw(1, 1);
void* bmem = b.AllocateRaw(1, 1 << 30);
a.DeallocateRaw(amem);
@@ -273,7 +313,11 @@ TEST(GPUBFCAllocatorTest, DISABLED_AllocatorReceivesZeroMemory) {
}
static void BM_Allocation(int iters) {
- GPUBFCAllocator a(CudaGpuId(0), 1uLL << 33, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc");
// Exercise a few different allocation sizes
std::vector<size_t> sizes = {256, 4096, 16384, 524288,
512, 1048576, 10485760, 104857600,
@@ -289,7 +333,11 @@ static void BM_Allocation(int iters) {
BENCHMARK(BM_Allocation);
static void BM_AllocationThreaded(int iters, int num_threads) {
- GPUBFCAllocator a(CudaGpuId(0), 1uLL << 33, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1uLL << 33, "GPU_0_bfc");
thread::ThreadPool pool(Env::Default(), "test", num_threads);
std::atomic_int_fast32_t count(iters);
mutex done_lock;
@@ -325,7 +373,11 @@ BENCHMARK(BM_AllocationThreaded)->Arg(1)->Arg(4)->Arg(16);
// A more complex benchmark that defers deallocation of an object for
// "delay" allocations.
static void BM_AllocationDelayed(int iters, int delay) {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
// Exercise a few different allocation sizes
std::vector<int> sizes = {256, 4096, 16384, 4096, 512, 1024, 1024};
int size_index = 0;
@@ -358,12 +410,18 @@ BENCHMARK(BM_AllocationDelayed)->Arg(1)->Arg(10)->Arg(100)->Arg(1000);
class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
protected:
+ void SetUp() override { CHECK_EQ(unsetenv("TF_FORCE_GPU_ALLOW_GROWTH"), 0); }
+
// The following test methods are called from tests. The reason for this is
// that this class is a friend class to BFCAllocator, but tests are not, so
// only methods inside this class can access private members of BFCAllocator.
void TestBinDebugInfo() {
- GPUBFCAllocator a(CudaGpuId(0), 1 << 30, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 << 30, "GPU_0_bfc");
std::vector<void*> initial_ptrs;
std::vector<size_t> initial_ptrs_allocated_sizes;
@@ -441,7 +499,11 @@ class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
}
void TestLog2FloorNonZeroSlow() {
- GPUBFCAllocator a(CudaGpuId(0), 1 /* total_memory */, "GPU_0_bfc");
+ PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator a(sub_allocator, 1 /* total_memory */, "GPU_0_bfc");
EXPECT_EQ(-1, a.Log2FloorNonZeroSlow(0));
EXPECT_EQ(0, a.Log2FloorNonZeroSlow(1));
EXPECT_EQ(1, a.Log2FloorNonZeroSlow(2));
@@ -450,6 +512,56 @@ class GPUBFCAllocatorPrivateMethodsTest : public ::testing::Test {
EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1024));
EXPECT_EQ(10, a.Log2FloorNonZeroSlow(1025));
}
+
+ void TestForceAllowGrowth() {
+ PlatformGpuId platform_gpu_id(0);
+ GPUOptions options;
+ // Unset flag value uses provided option.
+ unsetenv("TF_FORCE_GPU_ALLOW_GROWTH");
+ options.set_allow_growth(true);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator unset_flag_allocator(sub_allocator, 1LL << 31, options,
+ "GPU_0_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ unset_flag_allocator.curr_region_allocation_bytes_);
+
+ // Unparseable flag value uses provided option.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "unparseable", 1);
+ options.set_allow_growth(true);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator unparsable_flag_allocator(sub_allocator, 1LL << 31, options,
+ "GPU_1_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ unparsable_flag_allocator.curr_region_allocation_bytes_);
+
+ // Max of 2GiB total memory. Env variable set forces allow_growth, which
+ // does an initial allocation of 1MiB.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "true", 1);
+ options.set_allow_growth(false);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator force_allow_growth_allocator(sub_allocator, 1LL << 31,
+ options, "GPU_2_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(size_t{1048576}),
+ force_allow_growth_allocator.curr_region_allocation_bytes_);
+
+ // If env variable forces allow_growth disabled, all available memory is
+ // allocated.
+ setenv("TF_FORCE_GPU_ALLOW_GROWTH", "false", 1);
+ options.set_allow_growth(true);
+ sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUBFCAllocator force_no_allow_growth_allocator(sub_allocator, 1LL << 31,
+ options, "GPU_3_bfc");
+ EXPECT_EQ(GPUBFCAllocator::RoundedBytes(1LL << 31),
+ force_no_allow_growth_allocator.curr_region_allocation_bytes_);
+ }
};
TEST_F(GPUBFCAllocatorPrivateMethodsTest, BinDebugInfo) { TestBinDebugInfo(); }
@@ -458,6 +570,10 @@ TEST_F(GPUBFCAllocatorPrivateMethodsTest, Log2FloorNonZeroSlow) {
TestLog2FloorNonZeroSlow();
}
+TEST_F(GPUBFCAllocatorPrivateMethodsTest, ForceAllowGrowth) {
+ TestForceAllowGrowth();
+}
+
} // namespace tensorflow
#endif // GOOGLE_CUDA
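The new fixture-level SetUp() plus the explicit setenv/unsetenv calls keep the TF_FORCE_GPU_ALLOW_GROWTH-sensitive cases hermetic. A hypothetical RAII guard (not part of the patch) that achieves the same thing with less repetition, assuming POSIX setenv/unsetenv:

  // Hypothetical helper: the override disappears when the scope ends.
  class ScopedEnvVar {
   public:
    ScopedEnvVar(const char* name, const char* value) : name_(name) {
      setenv(name_, value, /*overwrite=*/1);
    }
    ~ScopedEnvVar() { unsetenv(name_); }

   private:
    const char* const name_;
  };

  {
    ScopedEnvVar force_growth("TF_FORCE_GPU_ALLOW_GROWTH", "true");
    // ... construct a GPUBFCAllocator and check its initial region size ...
  }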
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
index 934a57a5fb..d85ca8892f 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
@@ -27,10 +27,11 @@ limitations under the License.
namespace tensorflow {
-GPUcudaMallocAllocator::GPUcudaMallocAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id)
+GPUcudaMallocAllocator::GPUcudaMallocAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
- stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ stream_exec_ =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
}
GPUcudaMallocAllocator::~GPUcudaMallocAllocator() { delete base_allocator_; }
@@ -60,14 +61,6 @@ void GPUcudaMallocAllocator::DeallocateRaw(void* ptr) {
#endif // GOOGLE_CUDA
}
-void GPUcudaMallocAllocator::AddAllocVisitor(Visitor visitor) {
- return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUcudaMallocAllocator::AddFreeVisitor(Visitor visitor) {
- return base_allocator_->AddFreeVisitor(visitor);
-}
-
bool GPUcudaMallocAllocator::TracksAllocationSizes() { return false; }
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
index 5043fac797..8df3724bc4 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
@@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_CUDA_MALLOC_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_CUDA_MALLOC_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_CUDAMALLOC_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_CUDAMALLOC_ALLOCATOR_H_
#include <memory>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
@@ -29,20 +29,18 @@ namespace tensorflow {
// An allocator that wraps a GPU allocator and adds debugging
// functionality that verifies that users do not write outside their
// allocated memory.
-class GPUcudaMallocAllocator : public VisitableAllocator {
+class GPUcudaMallocAllocator : public Allocator {
public:
- explicit GPUcudaMallocAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id);
+ explicit GPUcudaMallocAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id);
~GPUcudaMallocAllocator() override;
string Name() override { return "gpu_debug"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
- void AddFreeVisitor(Visitor visitor) override;
bool TracksAllocationSizes() override;
private:
- VisitableAllocator* base_allocator_ = nullptr; // owned
+ Allocator* base_allocator_ = nullptr; // owned
se::StreamExecutor* stream_exec_; // Not owned.
@@ -51,4 +49,4 @@ class GPUcudaMallocAllocator : public VisitableAllocator {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_CUDAMALLOC_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_CUDAMALLOC_ALLOCATOR_H_
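Because the wrapper is now a plain Allocator, it composes with any Allocator*, and, as the delete in its destructor shows, it takes ownership of what it wraps. A minimal sketch, reusing a GPUMemAllocator built as in the BFC tests above:

  Allocator* base = new GPUBFCAllocator(sub_allocator, 1LL << 30, "GPU_0_bfc");
  // debug_allocator now owns `base` and deletes it in its own destructor.
  Allocator* debug_allocator =
      new GPUcudaMallocAllocator(base, PlatformGpuId(0));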
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
index e4c834b30d..989ddbe4af 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
@@ -73,10 +73,11 @@ void InitMask(se::StreamExecutor* exec, void* ptr, int64* mask) {
// -----------------------------------------------------------------------------
// GPUDebugAllocator
// -----------------------------------------------------------------------------
-GPUDebugAllocator::GPUDebugAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id)
+GPUDebugAllocator::GPUDebugAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
- stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ stream_exec_ =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
}
GPUDebugAllocator::~GPUDebugAllocator() { delete base_allocator_; }
@@ -111,14 +112,6 @@ void GPUDebugAllocator::DeallocateRaw(void* ptr) {
base_allocator_->DeallocateRaw(ptr);
}
-void GPUDebugAllocator::AddAllocVisitor(Visitor visitor) {
- return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUDebugAllocator::AddFreeVisitor(Visitor visitor) {
- return base_allocator_->AddFreeVisitor(visitor);
-}
-
bool GPUDebugAllocator::TracksAllocationSizes() { return true; }
size_t GPUDebugAllocator::RequestedSize(const void* ptr) {
@@ -158,10 +151,11 @@ bool GPUDebugAllocator::CheckFooter(void* ptr) {
// -----------------------------------------------------------------------------
// GPUNanResetAllocator
// -----------------------------------------------------------------------------
-GPUNanResetAllocator::GPUNanResetAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id)
+GPUNanResetAllocator::GPUNanResetAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id)
: base_allocator_(allocator) {
- stream_exec_ = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ stream_exec_ =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
}
GPUNanResetAllocator::~GPUNanResetAllocator() { delete base_allocator_; }
@@ -200,14 +194,6 @@ void GPUNanResetAllocator::DeallocateRaw(void* ptr) {
base_allocator_->DeallocateRaw(ptr);
}
-void GPUNanResetAllocator::AddAllocVisitor(Visitor visitor) {
- return base_allocator_->AddAllocVisitor(visitor);
-}
-
-void GPUNanResetAllocator::AddFreeVisitor(Visitor visitor) {
- return base_allocator_->AddFreeVisitor(visitor);
-}
-
size_t GPUNanResetAllocator::RequestedSize(const void* ptr) {
return base_allocator_->RequestedSize(ptr);
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
index c49ec2a566..17757a106c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -13,15 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
#include <memory>
#include <string>
#include <unordered_map>
#include "tensorflow/core/common_runtime/gpu/gpu_id.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
@@ -31,16 +31,14 @@ namespace tensorflow {
// An allocator that wraps a GPU allocator and adds debugging
// functionality that verifies that users do not write outside their
// allocated memory.
-class GPUDebugAllocator : public VisitableAllocator {
+class GPUDebugAllocator : public Allocator {
public:
- explicit GPUDebugAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id);
+ explicit GPUDebugAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id);
~GPUDebugAllocator() override;
string Name() override { return "gpu_debug"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
- void AddFreeVisitor(Visitor visitor) override;
bool TracksAllocationSizes() override;
size_t RequestedSize(const void* ptr) override;
size_t AllocatedSize(const void* ptr) override;
@@ -53,7 +51,7 @@ class GPUDebugAllocator : public VisitableAllocator {
bool CheckFooter(void* ptr);
private:
- VisitableAllocator* base_allocator_ = nullptr; // owned
+ Allocator* base_allocator_ = nullptr; // owned
se::StreamExecutor* stream_exec_; // Not owned.
@@ -63,23 +61,21 @@ class GPUDebugAllocator : public VisitableAllocator {
// An allocator that wraps a GPU allocator and resets the memory on
// allocation and free to 'NaN', helping to identify cases where the
// user forgets to initialize the memory.
-class GPUNanResetAllocator : public VisitableAllocator {
+class GPUNanResetAllocator : public Allocator {
public:
- explicit GPUNanResetAllocator(VisitableAllocator* allocator,
- CudaGpuId cuda_gpu_id);
+ explicit GPUNanResetAllocator(Allocator* allocator,
+ PlatformGpuId platform_gpu_id);
~GPUNanResetAllocator() override;
string Name() override { return "gpu_nan_reset"; }
void* AllocateRaw(size_t alignment, size_t num_bytes) override;
void DeallocateRaw(void* ptr) override;
- void AddAllocVisitor(Visitor visitor) override;
- void AddFreeVisitor(Visitor visitor) override;
size_t RequestedSize(const void* ptr) override;
size_t AllocatedSize(const void* ptr) override;
void GetStats(AllocatorStats* stats) override;
void ClearStats() override;
private:
- VisitableAllocator* base_allocator_ = nullptr; // owned
+ Allocator* base_allocator_ = nullptr; // owned
se::StreamExecutor* stream_exec_; // Not owned.
@@ -88,4 +84,4 @@ class GPUNanResetAllocator : public VisitableAllocator {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_DEBUG_ALLOCATOR_H_
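The same ownership-passing composition applies to both debug wrappers; the tests below build the full stack with the NaN-reset wrapper outermost, per the test comment "NaN reset must be the outer-most allocator." The nesting, pulled out as a sketch:

  Allocator* allocator = new GPUNanResetAllocator(
      new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
                            platform_gpu_id),
      platform_gpu_id);
  // Each wrapper deletes the allocator it wraps, so deleting `allocator`
  // tears down the whole chain.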
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
index 236a0afa0b..aca08a7e33 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
@@ -34,10 +34,14 @@ namespace tensorflow {
namespace {
TEST(GPUDebugAllocatorTest, OverwriteDetection_None) {
- const CudaGpuId cuda_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
- auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
+ auto stream_exec =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
for (int s : {8}) {
std::vector<int64> cpu_array(s);
@@ -58,11 +62,14 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Header) {
for (int s : {8, 211}) {
EXPECT_DEATH(
{
- const CudaGpuId cuda_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
auto stream_exec =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<int64> cpu_array(s);
memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
@@ -91,11 +98,14 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) {
for (int s : {8, 22}) {
EXPECT_DEATH(
{
- const CudaGpuId cuda_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
auto stream_exec =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<int64> cpu_array(s);
memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
@@ -121,10 +131,14 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) {
}
TEST(GPUDebugAllocatorTest, ResetToNan) {
- const CudaGpuId cuda_gpu_id(0);
- GPUNanResetAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
- auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUNanResetAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
+ auto stream_exec =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<float> cpu_array(1024);
std::vector<float> cpu_array_result(1024);
@@ -161,13 +175,17 @@ TEST(GPUDebugAllocatorTest, ResetToNan) {
}
TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) {
- const CudaGpuId cuda_gpu_id(0);
+ const PlatformGpuId platform_gpu_id(0);
// NaN reset must be the outer-most allocator.
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUNanResetAllocator a(
- new GPUDebugAllocator(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id),
- cuda_gpu_id);
- auto stream_exec = GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id),
+ platform_gpu_id);
+ auto stream_exec =
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
std::vector<float> cpu_array(1024);
std::vector<float> cpu_array_result(1024);
@@ -204,18 +222,24 @@ TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) {
}
TEST(GPUDebugAllocatorTest, TracksSizes) {
- const CudaGpuId cuda_gpu_id(0);
- GPUDebugAllocator a(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id);
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
+ GPUDebugAllocator a(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id);
EXPECT_EQ(true, a.TracksAllocationSizes());
}
TEST(GPUDebugAllocatorTest, AllocatedVsRequested) {
- const CudaGpuId cuda_gpu_id(0);
+ const PlatformGpuId platform_gpu_id(0);
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id, false /*use_unified_memory*/, {}, {});
GPUNanResetAllocator a(
- new GPUDebugAllocator(new GPUBFCAllocator(cuda_gpu_id, 1 << 30, ""),
- cuda_gpu_id),
- cuda_gpu_id);
+ new GPUDebugAllocator(new GPUBFCAllocator(sub_allocator, 1 << 30, ""),
+ platform_gpu_id),
+ platform_gpu_id);
float* t1 = a.Allocate<float>(1);
EXPECT_EQ(4, a.RequestedSize(t1));
EXPECT_EQ(256, a.AllocatedSize(t1));
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 3292ef2f62..d8ebdeff5d 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -41,7 +41,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/common_runtime/local_device.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -105,9 +104,9 @@ class EigenCudaStreamDevice : public ::Eigen::StreamInterface {
reinterpret_cast<unsigned int*>(scratch + Eigen::kCudaScratchSize);
stream_ = cuda_stream;
allocator_ = alloc;
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
- device_prop_ = &Eigen::m_deviceProperties[cuda_gpu_id.value()];
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+ device_prop_ = &Eigen::m_deviceProperties[platform_gpu_id.value()];
}
const cudaStream_t& stream() const override { return *stream_; }
@@ -285,6 +284,38 @@ BaseGPUDevice::~BaseGPUDevice() {
for (auto ctx : device_contexts_) ctx->Unref();
}
+// This should be idempotent if already initialized.
+Status BaseGPUDevice::InitScratchBuffers() {
+ mutex_lock l(scratch_init_mutex_);
+ if (scratch_.size() < max_streams_) {
+ for (int i = 0; i < max_streams_; i++) {
+ DCHECK(streams_[i]);
+ if (scratch_.size() > i && scratch_[i]) continue;
+ size_t scratch_buffer_size =
+ Eigen::kCudaScratchSize + sizeof(unsigned int);
+ void* scratch_buffer = gpu_allocator_->AllocateRaw(
+ Allocator::kAllocatorAlignment, scratch_buffer_size);
+ if (scratch_buffer == nullptr) {
+ return errors::FailedPrecondition(
+ "Failed to allocate scratch buffer for device ",
+ tf_gpu_id_.value());
+ }
+ se::DeviceMemory<char> mem(
+ se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
+
+ bool ok = executor_->SynchronousMemZero(
+ &mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
+ if (!ok) {
+ return errors::FailedPrecondition(
+ "Failed to memcopy into scratch buffer for device ",
+ tf_gpu_id_.value());
+ }
+ scratch_.push_back(static_cast<char*>(scratch_buffer));
+ }
+ }
+ return Status::OK();
+}
+
Status BaseGPUDevice::Init(const SessionOptions& options) {
auto executor_status = GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id_);
if (!executor_status.status().ok()) {
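InitScratchBuffers() moves the per-stream scratch allocation out of Init() (see the next hunk) and makes it lazy and idempotent: it is guarded by scratch_init_mutex_ and skips streams that already have a buffer, so the first ReinitializeGpuDevice() call pays the cost and later calls are cheap. The general shape of that pattern, as a sketch with hypothetical member names:

  Status EnsureInitialized() {
    mutex_lock l(init_mu_);
    if (initialized_) return Status::OK();   // later calls: cheap no-op
    TF_RETURN_IF_ERROR(DoExpensiveInit());   // first call does the real work
    initialized_ = true;
    return Status::OK();
  }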
@@ -303,27 +334,6 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
for (int i = 0; i < max_streams_; i++) {
streams_.push_back(StreamGroupFactory::Global().GetOrCreate(
tf_gpu_id_, i, executor_, options.config.gpu_options()));
-
- size_t scratch_buffer_size = Eigen::kCudaScratchSize + sizeof(unsigned int);
- void* scratch_buffer = gpu_allocator_->AllocateRaw(
- Allocator::kAllocatorAlignment, scratch_buffer_size);
- if (scratch_buffer == nullptr) {
- return errors::FailedPrecondition(
- "Failed to allocate scratch buffer for device ", tf_gpu_id_.value());
- }
- scratch_.push_back(static_cast<char*>(scratch_buffer));
-
- se::DeviceMemory<char> mem(
- se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
-
- bool ok = executor_->SynchronousMemZero(
- &mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
- if (!ok) {
- return errors::FailedPrecondition(
- "Failed to memcopy into scratch buffer for device ",
- tf_gpu_id_.value());
- }
-
device_contexts_.push_back(new GPUDeviceContext(
i, streams_.back()->compute, streams_.back()->host_to_device,
streams_.back()->device_to_host, streams_.back()->device_to_device));
@@ -332,9 +342,10 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
gpu_device_info_->stream = streams_[0]->compute;
gpu_device_info_->default_context = device_contexts_[0];
gpu_device_info_->event_mgr = em_.get();
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id));
- gpu_device_info_->gpu_id = cuda_gpu_id.value();
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id_, &platform_gpu_id));
+ gpu_device_info_->gpu_id = platform_gpu_id.value();
set_tensorflow_gpu_device_info(gpu_device_info_);
// Whether and how the GPU device uses its own threadpool.
@@ -423,9 +434,6 @@ Status BaseGPUDevice::FillContextMap(const Graph* graph,
}
void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
- tracing::ScopedRegion region(tracing::EventCategory::kCompute,
- op_kernel->name());
-
// NOTE(tucker): We need to discriminate between Eigen GPU
// operations and all others. If an operation is Eigen
// implemented (or otherwise tries to launch a cuda kernel
@@ -439,8 +447,6 @@ void BaseGPUDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
context->SetStatus(errors::Internal(
"Invalid synchronous 'Compute' on GPU for '_Recv' op"));
} else {
- tracing::ScopedAnnotation annotation(op_kernel->name(),
- op_kernel->type_string());
ComputeHelper(op_kernel, context);
}
}
@@ -690,9 +696,9 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
Eigen::GpuDevice device_;
};
-// Parse 'visible_device_list' into a list of CUDA GPU ids.
+// Parse 'visible_device_list' into a list of platform GPU ids.
Status ParseVisibleDeviceList(const string& visible_device_list,
- std::vector<CudaGpuId>* visible_gpu_order) {
+ std::vector<PlatformGpuId>* visible_gpu_order) {
visible_gpu_order->clear();
se::Platform* gpu_manager = GPUMachineManager();
@@ -707,26 +713,28 @@ Status ParseVisibleDeviceList(const string& visible_device_list,
} else {
const std::vector<string> order_str =
str_util::Split(visible_device_list, ',');
- for (const string& cuda_gpu_id_str : order_str) {
- int32 cuda_gpu_id;
- if (!strings::safe_strto32(cuda_gpu_id_str, &cuda_gpu_id)) {
+ for (const string& platform_gpu_id_str : order_str) {
+ int32 platform_gpu_id;
+ if (!strings::safe_strto32(platform_gpu_id_str, &platform_gpu_id)) {
return errors::InvalidArgument(
"Could not parse entry in 'visible_device_list': '",
- cuda_gpu_id_str, "'. visible_device_list = ", visible_device_list);
+ platform_gpu_id_str, "'. visible_device_list = ",
+ visible_device_list);
}
- if (cuda_gpu_id < 0 || cuda_gpu_id >= gpu_manager->VisibleDeviceCount()) {
+ if (platform_gpu_id < 0 ||
+ platform_gpu_id >= gpu_manager->VisibleDeviceCount()) {
return errors::InvalidArgument(
- "'visible_device_list' listed an invalid GPU id '", cuda_gpu_id,
+ "'visible_device_list' listed an invalid GPU id '", platform_gpu_id,
"' but visible device count is ",
gpu_manager->VisibleDeviceCount());
}
- visible_gpu_order->push_back(CudaGpuId(cuda_gpu_id));
+ visible_gpu_order->push_back(PlatformGpuId(platform_gpu_id));
}
}
// Validate no repeats.
- std::set<CudaGpuId> visible_device_set(visible_gpu_order->begin(),
- visible_gpu_order->end());
+ std::set<PlatformGpuId> visible_device_set(visible_gpu_order->begin(),
+ visible_gpu_order->end());
if (visible_device_set.size() != visible_gpu_order->size()) {
return errors::InvalidArgument(
"visible_device_list contained a duplicate entry: ",
@@ -737,8 +745,8 @@ Status ParseVisibleDeviceList(const string& visible_device_list,
Status VerifyVirtualDeviceSettings(
const size_t num_gpus_to_use, const GPUOptions& gpu_options,
- const std::vector<CudaGpuId>& visible_gpu_order,
- const std::vector<CudaGpuId>& valid_cuda_gpu_ids) {
+ const std::vector<PlatformGpuId>& visible_gpu_order,
+ const std::vector<PlatformGpuId>& valid_platform_gpu_ids) {
const auto& virtual_devices = gpu_options.experimental().virtual_devices();
CHECK(!virtual_devices.empty());
if (gpu_options.per_process_gpu_memory_fraction() > 0) {
@@ -760,11 +768,11 @@ Status VerifyVirtualDeviceSettings(
" #GPUs in visible_device_list: ", visible_gpu_order.size(),
" virtual_devices.size(): ", virtual_devices.size());
}
- if (valid_cuda_gpu_ids.size() != virtual_devices.size()) {
+ if (valid_platform_gpu_ids.size() != virtual_devices.size()) {
return errors::Unknown(
"The number of valid GPUs doesn't match the number of elements in "
"the virtual_devices list.",
- " #valid GPUs: ", valid_cuda_gpu_ids.size(),
+ " #valid GPUs: ", valid_platform_gpu_ids.size(),
" virtual_devices.size(): ", virtual_devices.size());
}
return Status::OK();
@@ -806,18 +814,18 @@ int64 MinSystemMemory(int64 available_memory) {
}
// Get the memory limit for the virtual device being created on GPU with
-// 'cuda_gpu_id', when that virtual device is the only virtual device being
+// 'platform_gpu_id', when that virtual device is the only virtual device being
// created on that GPU.
Status SingleVirtualDeviceMemoryLimit(const GPUOptions& gpu_options,
- CudaGpuId cuda_gpu_id,
+ PlatformGpuId platform_gpu_id,
int64* memory_limit) {
int64 total_memory = 0;
int64 available_memory = 0;
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
if (!se->DeviceMemoryUsage(&available_memory, &total_memory)) {
return errors::Unknown("Failed to query available memory for GPU ",
- cuda_gpu_id.value());
+ platform_gpu_id.value());
}
int64 allocated_memory = 0;
@@ -867,10 +875,11 @@ PerOpGpuDevice* BaseGPUDevice::MakeGpuDevice() {
return new ConcretePerOpGpuDevice();
}
-void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
- PerOpGpuDevice* device,
- DeviceContext* dc,
- Allocator* allocator) {
+Status BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
+ PerOpGpuDevice* device,
+ DeviceContext* dc,
+ Allocator* allocator) {
+ TF_RETURN_IF_ERROR(InitScratchBuffers());
if (dc) {
const GPUDeviceContext* gpu_dc = static_cast<GPUDeviceContext*>(dc);
const int stream_id = gpu_dc->stream_id();
@@ -881,6 +890,7 @@ void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
} else {
ReinitializeDevice(context, device, 0, allocator);
}
+ return Status::OK();
}
Allocator* BaseGPUDevice::GetScopedAllocator(AllocatorAttributes attr,
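ReinitializeGpuDevice() now returns Status instead of void because the lazy InitScratchBuffers() call can fail, so callers have to propagate the error. A hypothetical call site (variable names illustrative only):

  PerOpGpuDevice* eigen_gpu_device = device->MakeGpuDevice();
  TF_RETURN_IF_ERROR(device->ReinitializeGpuDevice(
      op_context, eigen_gpu_device, device_context, allocator));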
@@ -916,17 +926,22 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
num_gpus_to_use = iter->second;
}
const auto& gpu_options = options.config.gpu_options();
- std::vector<CudaGpuId> visible_gpu_order;
- TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(),
- &visible_gpu_order));
-
- std::vector<CudaGpuId> valid_cuda_gpu_ids;
- TF_RETURN_IF_ERROR(GetValidDeviceIds(visible_gpu_order, &valid_cuda_gpu_ids));
- if (num_gpus_to_use > valid_cuda_gpu_ids.size()) {
- num_gpus_to_use = valid_cuda_gpu_ids.size();
- }
+ std::vector<PlatformGpuId> visible_gpu_order;
+ std::vector<PlatformGpuId> valid_platform_gpu_ids;
// If we aren't going to use any GPUs, don't initialize them.
- if (num_gpus_to_use > 0 && !valid_cuda_gpu_ids.empty()) {
+ // We don't want to call ParseVisibleDeviceList if num_gpus_to_use is 0,
+ // because it treats an empty gpu_options.visible_device_list as 'all GPUs are
+ // visible'.
+ if (num_gpus_to_use > 0) {
+ TF_RETURN_IF_ERROR(ParseVisibleDeviceList(gpu_options.visible_device_list(),
+ &visible_gpu_order));
+ TF_RETURN_IF_ERROR(
+ GetValidDeviceIds(visible_gpu_order, &valid_platform_gpu_ids));
+ }
+ if (num_gpus_to_use > valid_platform_gpu_ids.size()) {
+ num_gpus_to_use = valid_platform_gpu_ids.size();
+ }
+ if (!valid_platform_gpu_ids.empty()) {
// Save the original device.
int original_device = 0;
cudaError_t err = cudaGetDevice(&original_device);
@@ -936,17 +951,18 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
}
// Force to implicitly initialize CUDA runtime on each valid GPU before
// CreateGPUDevice().
- for (CudaGpuId cuda_gpu_id : valid_cuda_gpu_ids) {
- err = cudaSetDevice(cuda_gpu_id.value());
+ for (PlatformGpuId platform_gpu_id : valid_platform_gpu_ids) {
+ err = cudaSetDevice(platform_gpu_id.value());
if (err != cudaSuccess) {
- return errors::Internal("cudaSetDevice() on GPU:", cuda_gpu_id.value(),
- " failed. Status: ", cudaGetErrorString(err));
+ return errors::Internal("cudaSetDevice() on GPU:",
+ platform_gpu_id.value(), " failed. Status: ",
+ cudaGetErrorString(err));
}
err = cudaFree(nullptr);
if (err != cudaSuccess) {
- return errors::Internal(
- "CUDA runtime implicit initialization on GPU:", cuda_gpu_id.value(),
- " failed. Status: ", cudaGetErrorString(err));
+ return errors::Internal("CUDA runtime implicit initialization on GPU:",
+ platform_gpu_id.value(), " failed. Status: ",
+ cudaGetErrorString(err));
}
}
// Reset to the original device.
@@ -972,10 +988,10 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
LOG(INFO) << line_buf;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
line_buf = strings::StrCat(visible_gpu_order[i].value(), ": ");
- CudaGpuId cuda_id_i = visible_gpu_order[i];
+ PlatformGpuId gpu_id_i = visible_gpu_order[i];
for (int j = 0; j < visible_gpu_order.size(); ++j) {
- CudaGpuId cuda_id_j = visible_gpu_order[j];
- if (im.directed_links.find({cuda_id_i, cuda_id_j}) !=
+ PlatformGpuId gpu_id_j = visible_gpu_order[j];
+ if (im.directed_links.find({gpu_id_i, gpu_id_j}) !=
im.directed_links.end()) {
line_buf.append("Y ");
} else {
@@ -988,22 +1004,23 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
const auto& virtual_devices = gpu_options.experimental().virtual_devices();
if (!virtual_devices.empty()) {
- TF_RETURN_IF_ERROR(VerifyVirtualDeviceSettings(
- num_gpus_to_use, gpu_options, visible_gpu_order, valid_cuda_gpu_ids));
+ TF_RETURN_IF_ERROR(VerifyVirtualDeviceSettings(num_gpus_to_use, gpu_options,
+ visible_gpu_order,
+ valid_platform_gpu_ids));
// We've verified that num_gpus_to_use >= virtual_devices.size().
num_gpus_to_use = virtual_devices.size();
CHECK(gpu_options.visible_device_list().empty() ||
- valid_cuda_gpu_ids == visible_gpu_order);
+ valid_platform_gpu_ids == visible_gpu_order);
}
int next_tf_gpu_id = 0;
std::vector<int64> memory_limit_bytes;
for (int i = 0; i < num_gpus_to_use; ++i) {
- const CudaGpuId cuda_gpu_id = valid_cuda_gpu_ids[i];
+ const PlatformGpuId platform_gpu_id = valid_platform_gpu_ids[i];
if (virtual_devices.empty() ||
virtual_devices.Get(i).memory_limit_mb_size() == 0) {
int64 single_virtual_device_memory_limit = 0;
TF_RETURN_IF_ERROR(SingleVirtualDeviceMemoryLimit(
- gpu_options, cuda_gpu_id, &single_virtual_device_memory_limit));
+ gpu_options, platform_gpu_id, &single_virtual_device_memory_limit));
memory_limit_bytes.push_back(single_virtual_device_memory_limit);
} else {
const auto& memory_limit_mb = virtual_devices.Get(i).memory_limit_mb();
@@ -1016,7 +1033,7 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
TfGpuId tf_gpu_id(next_tf_gpu_id);
++next_tf_gpu_id;
TF_RETURN_IF_ERROR(
- GpuIdManager::InsertTfCudaGpuIdPair(tf_gpu_id, cuda_gpu_id));
+ GpuIdManager::InsertTfPlatformGpuIdPair(tf_gpu_id, platform_gpu_id));
}
}
const int num_tf_gpus = next_tf_gpu_id;
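The renaming makes the two id spaces explicit: entries in visible_device_list are platform (CUDA) GPU ids, and TF GPU ids are handed out in list order and registered through GpuIdManager::InsertTfPlatformGpuIdPair. As an illustration of the resulting mapping, assuming three physical GPUs, visible_device_list = "2,0", and both devices passing the validity filters:

  // visible_device_list = "2,0"
  //   /device:GPU:0  (TfGpuId 0)  ->  PlatformGpuId 2
  //   /device:GPU:1  (TfGpuId 1)  ->  PlatformGpuId 0
  PlatformGpuId platform_gpu_id;
  TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(TfGpuId(0), &platform_gpu_id));
  CHECK_EQ(2, platform_gpu_id.value());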
@@ -1041,7 +1058,7 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
return Status::OK();
}
-static string GetShortDeviceDescription(CudaGpuId cuda_gpu_id,
+static string GetShortDeviceDescription(PlatformGpuId platform_gpu_id,
const se::DeviceDescription& desc) {
int cc_major;
int cc_minor;
@@ -1050,9 +1067,8 @@ static string GetShortDeviceDescription(CudaGpuId cuda_gpu_id,
cc_minor = 0;
}
// LINT.IfChange
- return strings::StrCat("device: ", cuda_gpu_id.value(),
- ", name: ", desc.name(),
- ", pci bus id: ", desc.pci_bus_id(),
+ return strings::StrCat("device: ", platform_gpu_id.value(), ", name: ",
+ desc.name(), ", pci bus id: ", desc.pci_bus_id(),
", compute capability: ", cc_major, ".", cc_minor);
// LINT.ThenChange(//tensorflow/python/platform/test.py)
}
@@ -1067,12 +1083,13 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
const string device_name =
strings::StrCat(name_prefix, "/device:GPU:", tf_gpu_id.value());
GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
int numa_node = dev_locality.numa_node();
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
const se::DeviceDescription& desc = se->GetDeviceDescription();
GPUProcessState* process_state = GPUProcessState::singleton();
Allocator* gpu_allocator = process_state->GetGPUAllocator(
@@ -1093,11 +1110,11 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
// TODO(laigd): report error if memory_limit doesn't match stats.bytes_limit.
BaseGPUDevice* gpu_device = CreateGPUDevice(
options, device_name, static_cast<Bytes>(stats.bytes_limit), dev_locality,
- tf_gpu_id, GetShortDeviceDescription(cuda_gpu_id, desc), gpu_allocator,
- ProcessState::singleton()->GetCPUAllocator(numa_node));
+ tf_gpu_id, GetShortDeviceDescription(platform_gpu_id, desc),
+ gpu_allocator, ProcessState::singleton()->GetCPUAllocator(numa_node));
LOG(INFO) << "Created TensorFlow device (" << device_name << " with "
<< (stats.bytes_limit >> 20) << " MB memory) -> physical GPU ("
- << GetShortDeviceDescription(cuda_gpu_id, desc) << ")";
+ << GetShortDeviceDescription(platform_gpu_id, desc) << ")";
TF_RETURN_IF_ERROR(gpu_device->Init(options));
devices->push_back(gpu_device);
@@ -1105,18 +1122,21 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
}
namespace {
-std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>>
+std::unique_ptr<std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>>
GetPeerAccessMap(se::Platform* platform,
- const std::vector<CudaGpuId>& visible_gpu_order) {
- std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>> map(
- new std::map<std::pair<CudaGpuId, CudaGpuId>, bool>);
- for (CudaGpuId cuda_gpu_i : visible_gpu_order) {
- for (CudaGpuId cuda_gpu_j : visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
+ std::unique_ptr<std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>> map(
+ new std::map<std::pair<PlatformGpuId, PlatformGpuId>, bool>);
+ for (PlatformGpuId platform_gpu_i : visible_gpu_order) {
+ for (PlatformGpuId platform_gpu_j : visible_gpu_order) {
se::StreamExecutor* from =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_i)
+ .ValueOrDie();
se::StreamExecutor* to =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie();
- (*map)[{cuda_gpu_i, cuda_gpu_j}] = from->CanEnablePeerAccessTo(to);
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_j)
+ .ValueOrDie();
+ (*map)[{platform_gpu_i, platform_gpu_j}] =
+ from->CanEnablePeerAccessTo(to);
}
}
@@ -1126,19 +1146,19 @@ GetPeerAccessMap(se::Platform* platform,
} // namespace
Status BaseGPUDeviceFactory::GetInterconnectMaps(
- const std::vector<CudaGpuId>& visible_gpu_order, se::Platform* gpu_manager,
- std::vector<InterconnectMap>* maps) {
+ const std::vector<PlatformGpuId>& visible_gpu_order,
+ se::Platform* gpu_manager, std::vector<InterconnectMap>* maps) {
// The default interconnect map is obtained from the StreamExecutor.
auto access_map = GetPeerAccessMap(gpu_manager, visible_gpu_order);
maps->resize(1);
InterconnectMap& imap = maps->at(0);
imap.name = "StreamExecutor";
imap.strength = InterconnectMap::kStreamExecutorStrength;
- for (CudaGpuId cuda_id_i : visible_gpu_order) {
- for (CudaGpuId cuda_id_j : visible_gpu_order) {
- if (cuda_id_i == cuda_id_j) continue;
- if ((*access_map)[{cuda_id_i, cuda_id_j}]) {
- imap.directed_links.insert({cuda_id_i, cuda_id_j});
+ for (PlatformGpuId gpu_id_i : visible_gpu_order) {
+ for (PlatformGpuId gpu_id_j : visible_gpu_order) {
+ if (gpu_id_i == gpu_id_j) continue;
+ if ((*access_map)[{gpu_id_i, gpu_id_j}]) {
+ imap.directed_links.insert({gpu_id_i, gpu_id_j});
}
}
}
@@ -1153,13 +1173,14 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
all_tf_gpu_ids.push_back(TfGpuId(i));
}
for (TfGpuId tf_gpu_id : all_tf_gpu_ids) {
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
// Get GPU bus_id from its reported NUMA affinity. Because GPUs are
// virtualized in some environments, we can't just use the GPU id.
// NUMA locales are indexed from 0, buses are indexed from 1.
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie();
const se::DeviceDescription& desc = se->GetDeviceDescription();
int numa_node = desc.numa_node();
if (numa_node < 0) {
@@ -1169,7 +1190,8 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
// may run into trouble later with data transfer operations. The
// trouble may manifest as slower than expected performance, or
// outright failures.
- LOG(INFO) << "Could not identify NUMA node of CUDA gpu id " << cuda_gpu_id
+ LOG(INFO) << "Could not identify NUMA node of platform GPU id "
+ << platform_gpu_id
<< ", defaulting to 0. Your kernel may not have been built "
<< "with NUMA support.";
numa_node = 0;
@@ -1182,10 +1204,10 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
LocalLinks* links = dev_locality.mutable_links();
for (const InterconnectMap& imap : interconnects) {
for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) {
- CudaGpuId cuda_gpu_dst;
+ PlatformGpuId platform_gpu_dst;
TF_RETURN_IF_ERROR(
- GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst));
- if (imap.directed_links.find({cuda_gpu_id, cuda_gpu_dst}) !=
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_dst, &platform_gpu_dst));
+ if (imap.directed_links.find({platform_gpu_id, platform_gpu_dst}) !=
imap.directed_links.end()) {
InterconnectLink* ilink = links->add_link();
ilink->set_device_id(tf_gpu_dst.value());
@@ -1199,10 +1221,10 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
// add high strength links to the others.
for (TfGpuId tf_gpu_dst : all_tf_gpu_ids) {
if (tf_gpu_id == tf_gpu_dst) continue;
- CudaGpuId cuda_gpu_dst;
+ PlatformGpuId platform_gpu_dst;
TF_RETURN_IF_ERROR(
- GpuIdManager::TfToCudaGpuId(tf_gpu_dst, &cuda_gpu_dst));
- if (cuda_gpu_id == cuda_gpu_dst) {
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_dst, &platform_gpu_dst));
+ if (platform_gpu_id == platform_gpu_dst) {
InterconnectLink* ilink = links->add_link();
ilink->set_device_id(tf_gpu_dst.value());
ilink->set_type("SAME_DEVICE");
@@ -1211,9 +1233,9 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
}
(*localities)[tf_gpu_id] = dev_locality;
- VLOG(1) << "GPUDevice CudaGpuId " << cuda_gpu_id << " TfGpuId " << tf_gpu_id
- << " on bus " << dev_locality.bus_id() << " numa: " << numa_node
- << " pci: " << desc.pci_bus_id()
+ VLOG(1) << "GPUDevice PlatformGpuId " << platform_gpu_id << " TfGpuId "
+ << tf_gpu_id << " on bus " << dev_locality.bus_id()
+ << " numa: " << numa_node << " pci: " << desc.pci_bus_id()
<< " DeviceLocality: " << dev_locality.DebugString();
}
return Status::OK();
@@ -1221,14 +1243,14 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
static int GetDefaultMinGPUMultiprocessorCount(
se::Platform* gpu_manager,
- const std::vector<CudaGpuId>& visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
static const int kDefaultMinGPUMultiprocessorCount = 8;
// Find the highest multi-processor count across all visible GPUs.
int max_count = -1;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
auto exec_status =
- GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, visible_gpu_order[i]);
+ GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_order[i]);
if (!exec_status.ok()) {
continue;
}
@@ -1247,7 +1269,7 @@ static int GetDefaultMinGPUMultiprocessorCount(
static int GetMinGPUMultiprocessorCount(
se::Platform* gpu_manager,
- const std::vector<CudaGpuId>& visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
const char* tf_min_gpu_core_count = getenv("TF_MIN_GPU_MULTIPROCESSOR_COUNT");
if (tf_min_gpu_core_count == nullptr ||
@@ -1325,18 +1347,20 @@ std::vector<CudaVersion> GetSupportedCudaComputeCapabilities() {
}
Status EnablePeerAccess(se::Platform* platform,
- const std::vector<CudaGpuId>& visible_gpu_order) {
+ const std::vector<PlatformGpuId>& visible_gpu_order) {
int possible_peer_count = 0;
int enabled_peer_count = 0;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
- const CudaGpuId cuda_gpu_i = visible_gpu_order[i];
+ const PlatformGpuId platform_gpu_i = visible_gpu_order[i];
for (int j = 0; j < visible_gpu_order.size(); ++j) {
- const CudaGpuId cuda_gpu_j = visible_gpu_order[j];
+ const PlatformGpuId platform_gpu_j = visible_gpu_order[j];
// We have already validated that ExecutorForDevice() calls return OK.
se::StreamExecutor* from =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_i)
+ .ValueOrDie();
se::StreamExecutor* to =
- GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(platform, platform_gpu_j)
+ .ValueOrDie();
if (from->CanEnablePeerAccessTo(to)) {
++possible_peer_count;
@@ -1344,7 +1368,8 @@ Status EnablePeerAccess(se::Platform* platform,
if (!status.ok()) {
LOG(WARNING)
<< "Unable to enable peer access between device ordinals "
- << cuda_gpu_i << " and " << cuda_gpu_j << ", status: " << status;
+ << platform_gpu_i << " and " << platform_gpu_j
+ << ", status: " << status;
} else {
++enabled_peer_count;
}
@@ -1367,22 +1392,23 @@ Status EnablePeerAccess(se::Platform* platform,
} // namespace
Status BaseGPUDeviceFactory::GetValidDeviceIds(
- const std::vector<CudaGpuId>& visible_gpu_order,
- std::vector<CudaGpuId>* ids) {
+ const std::vector<PlatformGpuId>& visible_gpu_order,
+ std::vector<PlatformGpuId>* ids) {
se::Platform* gpu_manager = GPUMachineManager();
bool new_gpu_found = false;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
- const CudaGpuId cuda_gpu_id = visible_gpu_order[i];
+ const PlatformGpuId visible_gpu_id = visible_gpu_order[i];
- // Only perform this once per visible cuda gpu id.
- if (visible_gpu_initialized_[cuda_gpu_id.value()]) {
+ // Only perform this once per visible platform gpu id.
+ if (visible_gpu_initialized_[visible_gpu_id.value()]) {
continue;
}
- visible_gpu_initialized_[cuda_gpu_id.value()] = true;
+ visible_gpu_initialized_[visible_gpu_id.value()] = true;
new_gpu_found = true;
- auto executor = GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, cuda_gpu_id);
+ auto executor =
+ GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_id);
if (!executor.ok()) {
return executor.status();
}
@@ -1430,9 +1456,9 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
// Filter out devices that don't have the right capability or power.
for (int i = 0; i < visible_gpu_order.size(); ++i) {
- const CudaGpuId visible_gpu_id = visible_gpu_order[i];
+ const PlatformGpuId visible_gpu_id = visible_gpu_order[i];
auto exec_status =
- GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, visible_gpu_id);
+ GpuIdUtil::ExecutorForPlatformGpuId(gpu_manager, visible_gpu_id);
if (!exec_status.ok()) {
LOG(INFO) << "Ignoring visible gpu device " << visible_gpu_id
<< " whose executor is in invalid state: "
@@ -1481,7 +1507,7 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
if (!ids->empty()) {
std::vector<int> raw_ids(ids->size());
std::transform(ids->begin(), ids->end(), raw_ids.begin(),
- [](CudaGpuId id) -> int { return id.value(); });
+ [](PlatformGpuId id) -> int { return id.value(); });
LOG(INFO) << "Adding visible gpu devices: "
<< str_util::Join(raw_ids, ", ");
}
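For readers outside this file, the bus id recorded in DeviceLocality above follows directly from the NUMA comment earlier in the hunk: NUMA locales are indexed from 0 while buses are indexed from 1, and an unknown NUMA affinity falls back to node 0. A minimal sketch of that relationship (illustrative only, not part of the patch; the helper name is made up):

int SketchBusIdFromNumaNode(int numa_node) {
  // A negative value means the kernel could not report NUMA affinity; the
  // code above logs a warning and defaults to node 0 in that case.
  if (numa_node < 0) numa_node = 0;
  // NUMA locales are indexed from 0, buses from 1.
  return numa_node + 1;
}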
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index 56d03d7a8c..674e8384d5 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -65,6 +65,11 @@ class BaseGPUDevice : public LocalDevice {
// completes.
bool RequiresRecordingAccessedTensors() const override;
+ // GPU kernel execution requires us to use `tracing::ScopedAnnotation()`
+ // rather than `tracing::ScopedActivity()`, in order to relate asynchronously
+ // launched GPU kernels to the OpKernel.
+ bool TraceUsingAnnotations() const { return true; }
+
void ConsumeListOfAccessedTensors(
DeviceContext* device_context,
const TensorReferenceVector& tensor_refs) override;
@@ -86,15 +91,16 @@ class BaseGPUDevice : public LocalDevice {
// The caller owns the returned device.
PerOpGpuDevice* MakeGpuDevice() override;
- void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
- DeviceContext* dc, Allocator* allocator) override;
+ Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
+ DeviceContext* dc,
+ Allocator* allocator) override;
- // Returns the CUDA GPU id of this device within the native driver system;
+ // Returns the platform GPU id of this device within the native driver system;
// e.g., for CUDA this is the ordinal of the GPU within the system.
int gpu_id() const {
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id_, &cuda_gpu_id));
- return cuda_gpu_id.value();
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id_, &platform_gpu_id));
+ return platform_gpu_id.value();
}
// The executor that provides control for the device; e.g., for CUDA this
@@ -125,6 +131,7 @@ class BaseGPUDevice : public LocalDevice {
class StreamGroupFactory;
gtl::InlinedVector<StreamGroup*, 4> streams_;
+ mutex scratch_init_mutex_;
gtl::InlinedVector<char*, 4> scratch_;
std::vector<GPUDeviceContext*> device_contexts_;
GpuDeviceInfo* gpu_device_info_ = nullptr;
@@ -135,6 +142,9 @@ class BaseGPUDevice : public LocalDevice {
std::unique_ptr<EventMgr> em_;
std::unique_ptr<thread::ThreadPool> thread_pool_;
+  // Initialize scratch buffers used by Eigen.
+ Status InitScratchBuffers();
+
void ReinitializeDevice(OpKernelContext* context, PerOpGpuDevice* device,
int stream_id, Allocator* allocator);
@@ -168,14 +178,14 @@ class BaseGPUDeviceFactory : public DeviceFactory {
int32 strength;
static const int kSameDeviceStrength;
static const int kStreamExecutorStrength;
- std::set<std::pair<CudaGpuId, CudaGpuId>> directed_links;
+ std::set<std::pair<PlatformGpuId, PlatformGpuId>> directed_links;
};
protected:
// Populates *maps with interconnect maps for all local direct access
// pathways between GPUs.
virtual Status GetInterconnectMaps(
- const std::vector<CudaGpuId>& visible_gpu_order,
+ const std::vector<PlatformGpuId>& visible_gpu_order,
se::Platform* gpu_manager, std::vector<InterconnectMap>* maps);
struct TfGpuIdHash {
@@ -207,16 +217,16 @@ class BaseGPUDeviceFactory : public DeviceFactory {
Allocator* gpu_allocator,
Allocator* cpu_allocator) = 0;
- // Returns into 'ids' the list of valid CUDA GPU ids, in the order that
+ // Returns into 'ids' the list of valid platform GPU ids, in the order that
// they should map to TF GPU ids "/device:GPU:0", "/device:GPU:1", etc,
// based upon 'visible_gpu_order' which was generated by parsing
// GPUOptions::visible_device_list which is a comma-separated list of CUDA GPU
// ids.
- Status GetValidDeviceIds(const std::vector<CudaGpuId>& visible_gpu_order,
- std::vector<CudaGpuId>* ids);
+ Status GetValidDeviceIds(const std::vector<PlatformGpuId>& visible_gpu_order,
+ std::vector<PlatformGpuId>* ids);
- // visible_gpu_initialized_[cuda_gpu_id] is true if visible GPU cuda_gpu_id
- // has been initialized by the process.
+ // visible_gpu_initialized_[platform_gpu_id] is true if visible GPU
+ // platform_gpu_id has been initialized by the process.
std::unordered_map<int, bool> visible_gpu_initialized_;
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
index daf59f0560..36294094e9 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device_test.cc
@@ -30,18 +30,21 @@ namespace tensorflow {
namespace {
const char* kDeviceNamePrefix = "/job:localhost/replica:0/task:0";
-int64 GetTotalGPUMemory(CudaGpuId gpu_id) {
+int64 GetTotalGPUMemory(PlatformGpuId gpu_id) {
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(GPUMachineManager(), gpu_id)
+ .ValueOrDie();
int64 total_memory, available_memory;
CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));
return total_memory;
}
-Status GetComputeCapability(CudaGpuId gpu_id, int* cc_major, int* cc_minor) {
+Status GetComputeCapability(PlatformGpuId gpu_id, int* cc_major,
+ int* cc_minor) {
se::StreamExecutor* se =
- GpuIdUtil::ExecutorForCudaGpuId(GPUMachineManager(), gpu_id).ValueOrDie();
+ GpuIdUtil::ExecutorForPlatformGpuId(GPUMachineManager(), gpu_id)
+ .ValueOrDie();
if (!se->GetDeviceDescription().cuda_compute_capability(cc_major, cc_minor)) {
*cc_major = 0;
*cc_minor = 0;
@@ -223,7 +226,7 @@ TEST_F(GPUDeviceTest, MultipleVirtualDevices) {
// error.
TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) {
int cc_major, cc_minor;
- TF_ASSERT_OK(GetComputeCapability(CudaGpuId(0), &cc_major, &cc_minor));
+ TF_ASSERT_OK(GetComputeCapability(PlatformGpuId(0), &cc_major, &cc_minor));
// Exit early while running on Pascal or later GPUs.
if (cc_major >= 6) {
return;
@@ -244,10 +247,10 @@ TEST_F(GPUDeviceTest, UnifiedMemoryUnavailableOnPrePascalGpus) {
// more memory than what is available on the device.
TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
static constexpr double kGpuMemoryFraction = 1.2;
- static constexpr CudaGpuId kCudaGpuId(0);
+ static constexpr PlatformGpuId kPlatformGpuId(0);
int cc_major, cc_minor;
- TF_ASSERT_OK(GetComputeCapability(kCudaGpuId, &cc_major, &cc_minor));
+ TF_ASSERT_OK(GetComputeCapability(kPlatformGpuId, &cc_major, &cc_minor));
// Exit early if running on pre-Pascal GPUs.
if (cc_major < 6) {
LOG(INFO)
@@ -262,7 +265,7 @@ TEST_F(GPUDeviceTest, UnifiedMemoryAllocation) {
ASSERT_EQ(1, devices.size());
int64 memory_limit = devices[0]->attributes().memory_limit();
- ASSERT_EQ(memory_limit, static_cast<int64>(GetTotalGPUMemory(kCudaGpuId) *
+ ASSERT_EQ(memory_limit, static_cast<int64>(GetTotalGPUMemory(kPlatformGpuId) *
kGpuMemoryFraction));
AllocatorAttributes allocator_attributes = AllocatorAttributes();
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
index f0a109cc10..2d406b676e 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
#include <deque>
#include <vector>
@@ -203,4 +203,4 @@ class EventMgr {
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_EVENT_MGR_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id.h b/tensorflow/core/common_runtime/gpu/gpu_id.h
index 2a6caea296..f0d9022821 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id.h
@@ -25,10 +25,10 @@ namespace tensorflow {
// physical machine, it can be filtered by CUDA environment variable
// CUDA_VISIBLE_DEVICES. Note that this id is not visible to Tensorflow, but
// result after filtering by CUDA_VISIBLE_DEVICES is visible to TF and is
-// called CUDA GPU id as below. See
+// called platform GPU id as below. See
// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#env-vars
// for more details.
-// - CUDA GPU id (also called *visible* GPU id in
+// - *platform* GPU id (also called *visible* GPU id in
// third_party/tensorflow/core/protobuf/config.proto): this is the id that is
// visible to Tensorflow after filtering by CUDA_VISIBLE_DEVICES, and is
// generated by the CUDA GPU driver. It starts from 0 and is used for CUDA API
@@ -39,14 +39,14 @@ namespace tensorflow {
// field of the device name "/device:GPU:<id>", and is also the identifier of
// a BaseGPUDevice. Note that the configuration allows us to create multiple
// BaseGPUDevice per GPU hardware in order to use multi CUDA streams on the
-// hardware, so the mapping between TF GPU id and CUDA GPU id is not a 1:1
+// hardware, so the mapping between TF GPU id and platform GPU id is not a 1:1
// mapping, see the example below.
//
// For example, assuming that in the machine we have GPU device with index 0, 1,
// 2 and 3 (physical GPU id). Setting "CUDA_VISIBLE_DEVICES=1,2,3" will create
-// the following mapping between CUDA GPU id and physical GPU id:
+// the following mapping between platform GPU id and physical GPU id:
//
-// CUDA GPU id -> physical GPU id
+// platform GPU id -> physical GPU id
// 0 -> 1
// 1 -> 2
// 2 -> 3
@@ -56,32 +56,32 @@ namespace tensorflow {
//
// Assuming we configure the Session to create one BaseGPUDevice per GPU
// hardware, then setting GPUOptions::visible_device_list to "2,0" will create
-// the following mappting between TF GPU id and CUDA GPU id:
+// the following mapping between TF GPU id and platform GPU id:
//
-// TF GPU id -> CUDA GPU ID
+// TF GPU id -> platform GPU ID
// 0 (i.e. /device:GPU:0) -> 2
// 1 (i.e. /device:GPU:1) -> 0
//
-// Note that CUDA GPU id 1 is filtered out by GPUOptions::visible_device_list,
-// so it won't be used by the TF process.
+// Note that platform GPU id 1 is filtered out by
+// GPUOptions::visible_device_list, so it won't be used by the TF process.
//
// On the other hand, if we configure it to create 2 BaseGPUDevice per GPU
// hardware, then setting GPUOptions::visible_device_list to "2,0" will create
-// the following mappting between TF GPU id and CUDA GPU id:
+// the following mapping between TF GPU id and platform GPU id:
//
-// TF GPU id -> CUDA GPU ID
+// TF GPU id -> platform GPU ID
// 0 (i.e. /device:GPU:0) -> 2
// 1 (i.e. /device:GPU:1) -> 2
// 2 (i.e. /device:GPU:2) -> 0
// 3 (i.e. /device:GPU:3) -> 0
//
-// We create strong-typed integer classes for both TF GPU id and CUDA GPU id to
-// minimize programming errors and improve code readability. Except for the
+// We create strong-typed integer classes for both TF GPU id and platform GPU id
+// to minimize programming errors and improve code readability. Except for the
// StreamExecutor interface (as we don't change its API), whenever we need a
-// TF GPU id (or CUDA GPU id) we should use TfGpuId (or CudaGpuId) instead of a
-// raw integer.
+// TF GPU id (or platform GPU id) we should use TfGpuId (or PlatformGpuId)
+// instead of a raw integer.
TF_LIB_GTL_DEFINE_INT_TYPE(TfGpuId, int32);
-TF_LIB_GTL_DEFINE_INT_TYPE(CudaGpuId, int32);
+TF_LIB_GTL_DEFINE_INT_TYPE(PlatformGpuId, int32);
} // namespace tensorflow
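To make the renamed id types concrete, here is a minimal usage sketch (illustrative only, not part of the patch; the function name is hypothetical and the "2,0" mapping is taken from the visible_device_list example in the comment above):

// Assumes gpu_id.h and gpu_id_manager.h from this change.
Status SketchRegisterAndLookup() {
  // With GPUOptions::visible_device_list = "2,0", /device:GPU:0 maps to
  // platform GPU 2 and /device:GPU:1 maps to platform GPU 0.
  TF_RETURN_IF_ERROR(
      GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(0), PlatformGpuId(2)));
  TF_RETURN_IF_ERROR(
      GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(1), PlatformGpuId(0)));
  // Look the mapping back up; returns NotFound for an unregistered TfGpuId.
  PlatformGpuId platform_gpu_id;
  TF_RETURN_IF_ERROR(
      GpuIdManager::TfToPlatformGpuId(TfGpuId(0), &platform_gpu_id));
  DCHECK_EQ(2, platform_gpu_id.value());
  return Status::OK();
}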
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
index b5099dc8ef..2b40730119 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.cc
@@ -26,26 +26,27 @@ limitations under the License.
namespace tensorflow {
namespace {
-// Manages the map between TfGpuId and CUDA GPU id.
-class TfToCudaGpuIdMap {
+// Manages the map between TfGpuId and platform GPU id.
+class TfToPlatformGpuIdMap {
public:
- static TfToCudaGpuIdMap* singleton() {
- static auto* id_map = new TfToCudaGpuIdMap;
+ static TfToPlatformGpuIdMap* singleton() {
+ static auto* id_map = new TfToPlatformGpuIdMap;
return id_map;
}
- Status Insert(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id) LOCKS_EXCLUDED(mu_) {
+ Status Insert(TfGpuId tf_gpu_id, PlatformGpuId platform_gpu_id)
+ LOCKS_EXCLUDED(mu_) {
std::pair<IdMapType::iterator, bool> result;
{
mutex_lock lock(mu_);
- result = id_map_.insert({tf_gpu_id.value(), cuda_gpu_id.value()});
+ result = id_map_.insert({tf_gpu_id.value(), platform_gpu_id.value()});
}
- if (!result.second && cuda_gpu_id.value() != result.first->second) {
+ if (!result.second && platform_gpu_id.value() != result.first->second) {
return errors::AlreadyExists(
"TensorFlow device (GPU:", tf_gpu_id.value(),
") is being mapped to "
"multiple CUDA devices (",
- cuda_gpu_id.value(), " now, and ", result.first->second,
+ platform_gpu_id.value(), " now, and ", result.first->second,
" previously), which is not supported. "
"This may be the result of providing different GPU configurations "
"(ConfigProto.gpu_options, for example different visible_device_list)"
@@ -56,17 +57,17 @@ class TfToCudaGpuIdMap {
return Status::OK();
}
- bool Find(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) const
+ bool Find(TfGpuId tf_gpu_id, PlatformGpuId* platform_gpu_id) const
LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
auto result = id_map_.find(tf_gpu_id.value());
if (result == id_map_.end()) return false;
- *cuda_gpu_id = result->second;
+ *platform_gpu_id = result->second;
return true;
}
private:
- TfToCudaGpuIdMap() = default;
+ TfToPlatformGpuIdMap() = default;
void TestOnlyReset() LOCKS_EXCLUDED(mu_) {
mutex_lock lock(mu_);
@@ -78,17 +79,18 @@ class TfToCudaGpuIdMap {
IdMapType id_map_ GUARDED_BY(mu_);
friend class ::tensorflow::GpuIdManager;
- TF_DISALLOW_COPY_AND_ASSIGN(TfToCudaGpuIdMap);
+ TF_DISALLOW_COPY_AND_ASSIGN(TfToPlatformGpuIdMap);
};
} // namespace
-Status GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id,
- CudaGpuId cuda_gpu_id) {
- return TfToCudaGpuIdMap::singleton()->Insert(tf_gpu_id, cuda_gpu_id);
+Status GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId tf_gpu_id,
+ PlatformGpuId platform_gpu_id) {
+ return TfToPlatformGpuIdMap::singleton()->Insert(tf_gpu_id, platform_gpu_id);
}
-Status GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) {
- if (TfToCudaGpuIdMap::singleton()->Find(tf_gpu_id, cuda_gpu_id)) {
+Status GpuIdManager::TfToPlatformGpuId(TfGpuId tf_gpu_id,
+ PlatformGpuId* platform_gpu_id) {
+ if (TfToPlatformGpuIdMap::singleton()->Find(tf_gpu_id, platform_gpu_id)) {
return Status::OK();
}
return errors::NotFound("TensorFlow device GPU:", tf_gpu_id.value(),
@@ -96,7 +98,7 @@ Status GpuIdManager::TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id) {
}
void GpuIdManager::TestOnlyReset() {
- TfToCudaGpuIdMap::singleton()->TestOnlyReset();
+ TfToPlatformGpuIdMap::singleton()->TestOnlyReset();
}
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
index 491d92ccdd..62df4310c4 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager.h
@@ -21,15 +21,17 @@ limitations under the License.
namespace tensorflow {
-// Class that maintains a map from TfGpuId to CudaGpuId, and manages the
+// Class that maintains a map from TfGpuId to PlatformGpuId, and manages the
// translation between them.
class GpuIdManager {
public:
- // Adds a mapping from tf_gpu_id to cuda_gpu_id.
- static Status InsertTfCudaGpuIdPair(TfGpuId tf_gpu_id, CudaGpuId cuda_gpu_id);
+ // Adds a mapping from tf_gpu_id to platform_gpu_id.
+ static Status InsertTfPlatformGpuIdPair(TfGpuId tf_gpu_id,
+ PlatformGpuId platform_gpu_id);
- // Gets the cuda_gpu_id associated with tf_gpu_id. Returns OK if found.
- static Status TfToCudaGpuId(TfGpuId tf_gpu_id, CudaGpuId* cuda_gpu_id);
+ // Gets the platform_gpu_id associated with tf_gpu_id. Returns OK if found.
+ static Status TfToPlatformGpuId(TfGpuId tf_gpu_id,
+ PlatformGpuId* platform_gpu_id);
// Clears the map. Used in unit tests only.
static void TestOnlyReset();
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
index a663ec7051..8bf3c6a308 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_manager_test.cc
@@ -22,38 +22,38 @@ limitations under the License.
namespace tensorflow {
namespace {
-CudaGpuId TfToCudaGpuId(TfGpuId tf) {
- CudaGpuId cuda;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf, &cuda));
- return cuda;
+PlatformGpuId TfToPlatformGpuId(TfGpuId tf) {
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf, &platform_gpu_id));
+ return platform_gpu_id;
}
TEST(GpuIdManagerTest, Basics) {
TfGpuId key_0(0);
- CudaGpuId value_0(0);
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0));
- EXPECT_EQ(value_0, TfToCudaGpuId(key_0));
+ PlatformGpuId value_0(0);
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_0, value_0));
+ EXPECT_EQ(value_0, TfToPlatformGpuId(key_0));
// Multiple calls to map the same value is ok.
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_0, value_0));
- EXPECT_EQ(value_0, TfToCudaGpuId(key_0));
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_0, value_0));
+ EXPECT_EQ(value_0, TfToPlatformGpuId(key_0));
// Map a different TfGpuId to a different value.
TfGpuId key_1(3);
- CudaGpuId value_1(2);
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_1, value_1));
- EXPECT_EQ(value_1, TfToCudaGpuId(key_1));
+ PlatformGpuId value_1(2);
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_1, value_1));
+ EXPECT_EQ(value_1, TfToPlatformGpuId(key_1));
// Mapping a different TfGpuId to the same value is ok.
TfGpuId key_2(10);
- TF_ASSERT_OK(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_1));
- EXPECT_EQ(value_1, TfToCudaGpuId(key_2));
+ TF_ASSERT_OK(GpuIdManager::InsertTfPlatformGpuIdPair(key_2, value_1));
+ EXPECT_EQ(value_1, TfToPlatformGpuId(key_2));
// Mapping the same TfGpuId to a different value.
- ASSERT_FALSE(GpuIdManager::InsertTfCudaGpuIdPair(key_2, value_0).ok());
+ ASSERT_FALSE(GpuIdManager::InsertTfPlatformGpuIdPair(key_2, value_0).ok());
// Getting a nonexistent mapping.
- ASSERT_FALSE(GpuIdManager::TfToCudaGpuId(TfGpuId(100), &value_0).ok());
+ ASSERT_FALSE(GpuIdManager::TfToPlatformGpuId(TfGpuId(100), &value_0).ok());
}
} // namespace
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
index b9c66b3328..b1f10fb1dc 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
@@ -24,34 +24,37 @@ limitations under the License.
namespace tensorflow {
-// Utility methods for translation between Tensorflow GPU ids and CUDA GPU ids.
+// Utility methods for translation between Tensorflow GPU ids and platform GPU
+// ids.
class GpuIdUtil {
public:
// Convenient methods for getting the associated executor given a TfGpuId or
- // CudaGpuId.
- static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId(
- se::Platform* gpu_manager, CudaGpuId cuda_gpu_id) {
- return gpu_manager->ExecutorForDevice(cuda_gpu_id.value());
+ // PlatformGpuId.
+ static se::port::StatusOr<se::StreamExecutor*> ExecutorForPlatformGpuId(
+ se::Platform* gpu_manager, PlatformGpuId platform_gpu_id) {
+ return gpu_manager->ExecutorForDevice(platform_gpu_id.value());
}
- static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId(
- CudaGpuId cuda_gpu_id) {
- return ExecutorForCudaGpuId(GPUMachineManager(), cuda_gpu_id);
+ static se::port::StatusOr<se::StreamExecutor*> ExecutorForPlatformGpuId(
+ PlatformGpuId platform_gpu_id) {
+ return ExecutorForPlatformGpuId(GPUMachineManager(), platform_gpu_id);
}
static se::port::StatusOr<se::StreamExecutor*> ExecutorForTfGpuId(
TfGpuId tf_gpu_id) {
- CudaGpuId cuda_gpu_id;
- TF_RETURN_IF_ERROR(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
- return ExecutorForCudaGpuId(cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ TF_RETURN_IF_ERROR(
+ GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+ return ExecutorForPlatformGpuId(platform_gpu_id);
}
- // Verify that the cuda_gpu_id associated with a TfGpuId is legitimate.
+ // Verify that the platform_gpu_id associated with a TfGpuId is legitimate.
static void CheckValidTfGpuId(TfGpuId tf_gpu_id) {
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
const int visible_device_count = GPUMachineManager()->VisibleDeviceCount();
- CHECK_LT(cuda_gpu_id.value(), visible_device_count)
- << "cuda_gpu_id is outside discovered device range."
- << " TF GPU id: " << tf_gpu_id << " CUDA GPU id: " << cuda_gpu_id
+ CHECK_LT(platform_gpu_id.value(), visible_device_count)
+ << "platform_gpu_id is outside discovered device range."
+ << " TF GPU id: " << tf_gpu_id
+ << " platform GPU id: " << platform_gpu_id
<< " visible device count: " << visible_device_count;
}
};
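A short usage sketch of the renamed GpuIdUtil helpers (illustrative only; the wrapper function is hypothetical and assumes the headers above plus a TfGpuId-to-PlatformGpuId mapping already registered with GpuIdManager):

se::StreamExecutor* SketchExecutorForTfGpu(TfGpuId tf_gpu_id) {
  // CHECK-fails if the mapped platform GPU id is outside the range of
  // devices discovered by the StreamExecutor platform.
  GpuIdUtil::CheckValidTfGpuId(tf_gpu_id);
  // Resolves the TF id to a platform id via GpuIdManager, then asks the
  // platform for the executor at that ordinal.
  return GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
}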
diff --git a/tensorflow/core/common_runtime/gpu/gpu_init.h b/tensorflow/core/common_runtime/gpu/gpu_init.h
index bfd7a77f83..4e1f06ac83 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_init.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_init.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_INIT_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_INIT_H_
#include "tensorflow/core/lib/core/status.h"
@@ -36,4 +36,4 @@ stream_executor::Platform* GPUMachineManager();
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_INIT_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_INIT_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
index b18688174d..3e95374fda 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.cc
@@ -76,12 +76,16 @@ GPUProcessState::GPUProcessState() : gpu_device_enabled_(false) {
// This function is defined for debugging problems with the allocators.
GPUProcessState::~GPUProcessState() {
CHECK_EQ(this, instance_);
- for (auto p : gpu_allocators_) {
- delete p;
- }
instance_ = nullptr;
}
+int GPUProcessState::BusIdForGPU(TfGpuId tf_gpu_id) {
+ // Return the NUMA node associated with the GPU's StreamExecutor.
+ se::StreamExecutor* se =
+ GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
+ return se->GetDeviceDescription().numa_node();
+}
+
Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
TfGpuId tf_gpu_id,
size_t total_bytes) {
@@ -93,64 +97,63 @@ Allocator* GPUProcessState::GetGPUAllocator(const GPUOptions& options,
if (tf_gpu_id.value() >= static_cast<int64>(gpu_allocators_.size())) {
gpu_allocators_.resize(tf_gpu_id.value() + 1);
- if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
- gpu_al_.resize(tf_gpu_id.value() + 1);
}
- if (gpu_allocators_[tf_gpu_id.value()] == nullptr) {
- VisitableAllocator* gpu_allocator;
-
+ AllocatorParts& allocator_parts = gpu_allocators_[tf_gpu_id.value()];
+ if (allocator_parts.allocator.get() == nullptr) {
// Validate allocator types.
if (!allocator_type.empty() && allocator_type != "BFC") {
LOG(ERROR) << "Invalid allocator type: " << allocator_type;
return nullptr;
}
- CudaGpuId cuda_gpu_id;
- TF_CHECK_OK(GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id));
- gpu_allocator =
- new GPUBFCAllocator(cuda_gpu_id, total_bytes, options,
+ PlatformGpuId platform_gpu_id;
+ TF_CHECK_OK(GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id));
+ int bus_id = BusIdForGPU(tf_gpu_id);
+ while (bus_id >= gpu_visitors_.size()) {
+ gpu_visitors_.push_back({});
+ }
+ GPUMemAllocator* sub_allocator = new GPUMemAllocator(
+ GpuIdUtil::ExecutorForPlatformGpuId(platform_gpu_id).ValueOrDie(),
+ platform_gpu_id,
+ (options.per_process_gpu_memory_fraction() > 1.0 ||
+ options.experimental().use_unified_memory()),
+ gpu_visitors_[bus_id], {});
+ Allocator* gpu_allocator =
+ new GPUBFCAllocator(sub_allocator, total_bytes, options,
strings::StrCat("GPU_", tf_gpu_id.value(), "_bfc"));
// If true, checks for memory overwrites by writing
// distinctive patterns on both ends of allocated memory.
if (useCudaMemoryGuardAllocator()) {
- gpu_allocator = new GPUDebugAllocator(gpu_allocator, cuda_gpu_id);
- gpu_allocator = new GPUNanResetAllocator(gpu_allocator, cuda_gpu_id);
+ gpu_allocator = new GPUDebugAllocator(gpu_allocator, platform_gpu_id);
+ gpu_allocator = new GPUNanResetAllocator(gpu_allocator, platform_gpu_id);
} else if (useCudaMallocAllocator()) {
// If true, passes all allocation requests through to cudaMalloc
// useful for doing memory debugging with tools like cuda-memcheck
// **WARNING** probably will not work in a multi-gpu scenario
- gpu_allocator = new GPUcudaMallocAllocator(gpu_allocator, cuda_gpu_id);
- }
- gpu_allocators_[tf_gpu_id.value()] = gpu_allocator;
-
- // If there are any pending AllocVisitors for this bus, add
- // them now.
- se::StreamExecutor* se =
- GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
- int bus_id = se->GetDeviceDescription().numa_node();
- if (bus_id >= 0 && bus_id < static_cast<int64>(gpu_visitors_.size())) {
- for (const auto& v : gpu_visitors_[bus_id]) {
- gpu_allocator->AddAllocVisitor(v);
- }
+ gpu_allocator =
+ new GPUcudaMallocAllocator(gpu_allocator, platform_gpu_id);
}
+
+ Allocator* recording_allocator = nullptr;
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
ProcessState::MemDesc md;
md.loc = ProcessState::MemDesc::GPU;
- md.dev_index = cuda_gpu_id.value();
+ md.dev_index = platform_gpu_id.value();
md.gpu_registered = false;
md.nic_registered = true;
- if (static_cast<int64>(gpu_al_.size()) <= tf_gpu_id.value()) {
- gpu_al_.resize(tf_gpu_id.value() + 1);
- }
- gpu_al_[tf_gpu_id.value()] = new internal::RecordingAllocator(
+ recording_allocator = new internal::RecordingAllocator(
&process_state_->mem_desc_map_, gpu_allocator, md, &mu_);
}
+ allocator_parts = {std::unique_ptr<Allocator>(gpu_allocator), sub_allocator,
+ std::unique_ptr<Allocator>(recording_allocator)};
+ }
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
+ return allocator_parts.recording_allocator.get();
+ } else {
+ return allocator_parts.allocator.get();
}
- if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
- return gpu_al_[tf_gpu_id.value()];
- return gpu_allocators_[tf_gpu_id.value()];
#else
LOG(FATAL) << "GPUAllocator unavailable. Not compiled with --config=cuda.";
return nullptr;
@@ -172,11 +175,12 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
tf_shared_lock lock(mu_);
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types &&
- static_cast<int>(cuda_al_.size()) > 0) {
- return cuda_al_[0];
+ !cuda_host_allocators_.empty() &&
+ cuda_host_allocators_[0].recording_allocator != nullptr) {
+ return cuda_host_allocators_[0].recording_allocator.get();
}
if (static_cast<int>(cuda_host_allocators_.size()) > numa_node) {
- return cuda_host_allocators_[0];
+ return cuda_host_allocators_[0].allocator.get();
}
}
@@ -190,7 +194,7 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
// it knows is valid.
se::StreamExecutor* se = nullptr;
for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) {
- if (gpu_allocators_[i] != nullptr) {
+ if (gpu_allocators_[i].allocator != nullptr) {
se = GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
break;
}
@@ -199,6 +203,15 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
CHECK_NE(nullptr, se);
while (static_cast<int>(cuda_host_allocators_.size()) <= numa_node) {
+ while (cuda_host_alloc_visitors_.size() <= numa_node) {
+ cuda_host_alloc_visitors_.push_back({});
+ }
+ while (cuda_host_free_visitors_.size() <= numa_node) {
+ cuda_host_free_visitors_.push_back({});
+ }
+ SubAllocator* sub_allocator = new CUDAHostAllocator(
+ se, numa_node, cuda_host_alloc_visitors_[numa_node],
+ cuda_host_free_visitors_[numa_node]);
// TODO(zheng-xq): evaluate whether 64GB by default is the best choice.
int64 cuda_host_mem_limit_in_mb = -1;
Status status = ReadInt64FromEnvVar("TF_CUDA_HOST_MEM_LIMIT_IN_MB",
@@ -208,62 +221,92 @@ Allocator* GPUProcessState::GetCUDAHostAllocator(int numa_node) {
LOG(ERROR) << "GetCUDAHostAllocator: " << status.error_message();
}
int64 cuda_host_mem_limit = cuda_host_mem_limit_in_mb * (1LL << 20);
- VisitableAllocator* allocator =
- new BFCAllocator(new CUDAHostAllocator(se), cuda_host_mem_limit,
+ Allocator* allocator =
+ new BFCAllocator(sub_allocator, cuda_host_mem_limit,
true /*allow_growth*/, "cuda_host_bfc" /*name*/);
- if (LogMemory::IsEnabled()) {
+ if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
// Wrap the allocator to track allocation ids for better logging
// at the cost of performance.
- allocator = new TrackingVisitableAllocator(allocator, true);
+ allocator = new TrackingAllocator(allocator, true);
}
- cuda_host_allocators_.push_back(allocator);
+ cuda_host_allocators_.push_back({std::unique_ptr<Allocator>(allocator),
+ sub_allocator,
+ std::unique_ptr<Allocator>(nullptr)});
+ AllocatorParts& allocator_parts = cuda_host_allocators_.back();
if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
ProcessState::MemDesc md;
md.loc = ProcessState::MemDesc::CPU;
md.dev_index = 0;
md.gpu_registered = true;
md.nic_registered = false;
- cuda_al_.push_back(new internal::RecordingAllocator(
- &process_state_->mem_desc_map_, cuda_host_allocators_.back(), md,
- &mu_));
+ allocator_parts.recording_allocator.reset(
+ new internal::RecordingAllocator(&process_state_->mem_desc_map_,
+ allocator_parts.allocator.get(), md,
+ &mu_));
}
}
- if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types)
- return cuda_al_[0];
- return cuda_host_allocators_[0];
+ if (process_state_->ProcessState::FLAGS_brain_gpu_record_mem_types) {
+ return cuda_host_allocators_[0].recording_allocator.get();
+ } else {
+ return cuda_host_allocators_[0].allocator.get();
+ }
}
void GPUProcessState::AddGPUAllocVisitor(int bus_id,
- const AllocVisitor& visitor) {
- CHECK(process_state_);
+ const SubAllocator::Visitor& visitor) {
#if GOOGLE_CUDA
mutex_lock lock(mu_);
- for (int i = 0; i < static_cast<int64>(gpu_allocators_.size()); ++i) {
- se::StreamExecutor* se =
- GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
- if (gpu_allocators_[i] &&
- (se->GetDeviceDescription().numa_node() + 1) == bus_id) {
- gpu_allocators_[i]->AddAllocVisitor(visitor);
- }
- }
+ CHECK(gpu_allocators_.empty()) // Crash OK
+ << "AddGPUAllocVisitor must be called before "
+ "first call to GetGPUAllocator.";
while (bus_id >= static_cast<int64>(gpu_visitors_.size())) {
- gpu_visitors_.push_back(std::vector<AllocVisitor>());
+ gpu_visitors_.push_back(std::vector<SubAllocator::Visitor>());
}
gpu_visitors_[bus_id].push_back(visitor);
#endif // GOOGLE_CUDA
}
+void GPUProcessState::AddCUDAHostAllocVisitor(
+ int numa_node, const SubAllocator::Visitor& visitor) {
+#if GOOGLE_CUDA
+ mutex_lock lock(mu_);
+ CHECK(cuda_host_allocators_.empty()) // Crash OK
+ << "AddCUDAHostAllocVisitor must be called before "
+ "first call to GetCUDAHostAllocator.";
+ while (numa_node >= static_cast<int64>(cuda_host_alloc_visitors_.size())) {
+ cuda_host_alloc_visitors_.push_back(std::vector<SubAllocator::Visitor>());
+ }
+ cuda_host_alloc_visitors_[numa_node].push_back(visitor);
+#endif // GOOGLE_CUDA
+}
+
+void GPUProcessState::AddCUDAHostFreeVisitor(
+ int numa_node, const SubAllocator::Visitor& visitor) {
+#if GOOGLE_CUDA
+ mutex_lock lock(mu_);
+ CHECK(cuda_host_allocators_.empty()) // Crash OK
+ << "AddCUDAHostFreeVisitor must be called before "
+ "first call to GetCUDAHostAllocator.";
+ while (numa_node >= static_cast<int64>(cuda_host_free_visitors_.size())) {
+ cuda_host_free_visitors_.push_back(std::vector<SubAllocator::Visitor>());
+ }
+ cuda_host_free_visitors_[numa_node].push_back(visitor);
+#endif // GOOGLE_CUDA
+}
+
void GPUProcessState::TestOnlyReset() {
- process_state_->ProcessState::TestOnlyReset();
+ if (process_state_) {
+ process_state_->ProcessState::TestOnlyReset();
+ }
{
mutex_lock lock(mu_);
gpu_device_enabled_ = false;
+ gpu_allocators_.clear();
gpu_visitors_.clear();
- gtl::STLDeleteElements(&gpu_allocators_);
- gtl::STLDeleteElements(&cuda_host_allocators_);
- gtl::STLDeleteElements(&gpu_al_);
- gtl::STLDeleteElements(&cuda_al_);
+ cuda_host_allocators_.clear();
+ cuda_host_alloc_visitors_.clear();
+ cuda_host_free_visitors_.clear();
}
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_process_state.h b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
index cb41c3c6bd..43e9a31660 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_process_state.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_process_state.h
@@ -32,7 +32,6 @@ limitations under the License.
namespace tensorflow {
class Allocator;
-class VisitableAllocator;
class PoolAllocator;
// Singleton that manages per-process state when GPUs are present.
@@ -72,18 +71,30 @@ class GPUProcessState {
virtual Allocator* GetCUDAHostAllocator(int numa_node);
- // Registers a function to be called once on every new Region
- // allocated by every GPURegionAllocator proximate to the specified
- // bus. The AllocVisitor is provided with a memory pointer and the
- // size of the area it identifies. The pointer is not guaranteed to
- // be valid after the call terminates. The intention is for this
- // interface to be used for network device memory registration.
- // "bus_id" is platform-specific. On many platforms it
- // should be 0. On machines with multiple PCIe buses, it should be
- // the index of one of the PCIe buses. If the bus_id is invalid,
- // results are undefined.
- typedef std::function<void(void*, size_t)> AllocVisitor;
- virtual void AddGPUAllocVisitor(int bus_id, const AllocVisitor& visitor);
+ // Registers a Visitor to be invoked on new chunks of memory allocated by the
+ // SubAllocator of every GPU proximate to the specified bus. The AllocVisitor
+ // is provided with a memory pointer, a GPU id, and the size of the area it
+ // identifies. The pointer is not guaranteed to be valid after the call
+ // terminates. The intention is for this interface to be used for network
+ // device memory registration. "bus_id" is platform-specific. On many
+ // platforms it should be 0. On machines with multiple PCIe buses, it should
+ // be the index of one of the PCIe buses (maybe the NUMA node at which the
+ // PCIe is rooted). If the bus_id is invalid, results are undefined.
+ virtual void AddGPUAllocVisitor(int bus_id,
+ const SubAllocator::Visitor& visitor);
+
+ // Registers a Visitor to be invoked on new chunks of memory allocated by
+ // the SubAllocator of the CUDAHostAllocator for the given numa_node.
+ virtual void AddCUDAHostAllocVisitor(int numa_node,
+ const SubAllocator::Visitor& visitor);
+
+ // Registers a Visitor to be invoked on each chunk handed back for freeing to
+ // the SubAllocator of the CUDAHostAllocator for the given numa_node.
+ virtual void AddCUDAHostFreeVisitor(int numa_node,
+ const SubAllocator::Visitor& visitor);
+
+ // Returns bus_id for the given GPU id.
+ virtual int BusIdForGPU(TfGpuId tf_gpu_id);
protected:
GPUProcessState();
@@ -103,16 +114,21 @@ class GPUProcessState {
mutex mu_;
- std::vector<VisitableAllocator*> gpu_allocators_ GUARDED_BY(mu_);
- std::vector<std::vector<AllocVisitor>> gpu_visitors_ GUARDED_BY(mu_);
- std::vector<Allocator*> cuda_host_allocators_ GUARDED_BY(mu_);
+ struct AllocatorParts {
+ std::unique_ptr<Allocator> allocator;
+ SubAllocator* sub_allocator; // owned by allocator
+ std::unique_ptr<Allocator> recording_allocator;
+ };
+ std::vector<AllocatorParts> gpu_allocators_ GUARDED_BY(mu_);
+ std::vector<std::vector<SubAllocator::Visitor>> gpu_visitors_ GUARDED_BY(mu_);
- virtual ~GPUProcessState();
+ std::vector<AllocatorParts> cuda_host_allocators_ GUARDED_BY(mu_);
+ std::vector<std::vector<SubAllocator::Visitor>> cuda_host_alloc_visitors_
+ GUARDED_BY(mu_);
+ std::vector<std::vector<SubAllocator::Visitor>> cuda_host_free_visitors_
+ GUARDED_BY(mu_);
- // Optional RecordingAllocators that wrap the corresponding
- // Allocators for runtime attribute use analysis.
- std::vector<Allocator*> gpu_al_ GUARDED_BY(mu_);
- std::vector<Allocator*> cuda_al_ GUARDED_BY(mu_);
+ virtual ~GPUProcessState();
friend class GPUDeviceTest;
};
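The new Visitor registration methods must run before the corresponding allocators are created, since the visitors are now passed to each SubAllocator at construction time (the CHECKs added in gpu_process_state.cc enforce this). A hedged sketch of the intended call pattern (illustrative only; the function name is made up and the visitor body is a placeholder for e.g. network-device memory registration):

void SketchRegisterVisitors() {
  GPUProcessState* ps = GPUProcessState::singleton();
  SubAllocator::Visitor visitor = [](void* ptr, int index, int64 size) {
    // Register [ptr, ptr + size); 'index' is the bus id for GPU memory and
    // the numa node for CUDA host memory.
  };
  // Must precede the first GetGPUAllocator()/GetCUDAHostAllocator() call.
  ps->AddGPUAllocVisitor(/*bus_id=*/0, visitor);
  ps->AddCUDAHostAllocVisitor(/*numa_node=*/0, visitor);
  ps->AddCUDAHostFreeVisitor(/*numa_node=*/0, visitor);
}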
diff --git a/tensorflow/core/common_runtime/gpu/gpu_stream_util.h b/tensorflow/core/common_runtime/gpu/gpu_stream_util.h
index 771c158267..c61ada96ef 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_stream_util.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_stream_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
#include <unordered_map>
@@ -42,4 +42,4 @@ Status AssignStreams(const Graph* graph, const AssignStreamsOpts& opts,
} // namespace gpu_stream_util
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_STREAM_UTIL_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.h b/tensorflow/core/common_runtime/gpu/gpu_util.h
index 57687a8364..8ac3febb01 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_UTIL_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_UTIL_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
@@ -108,4 +108,4 @@ class GPUUtil {
};
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_GPU_UTIL_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_GPU_UTIL_H_
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
index ea1b04feeb..4bc88ffc8c 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
#include "tensorflow/core/common_runtime/gpu_device_context.h"
#include "tensorflow/core/framework/tensor.h"
@@ -36,4 +37,12 @@ void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor,
GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done);
}
+Status GPUDeviceContext::ThenExecute(Device* device, se::Stream* stream,
+ std::function<void()> func) {
+ const DeviceBase::GpuDeviceInfo* gpu_info =
+ device->tensorflow_gpu_device_info();
+ gpu_info->event_mgr->ThenExecute(stream, func);
+ return Status::OK();
+}
+
} // namespace tensorflow
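The ThenExecute override added here defers host-side work until everything currently enqueued on the stream has completed, by handing the closure to the device's EventMgr. A minimal usage sketch (illustrative only; the function name, callers, and callback body are assumptions):

Status SketchDeferredHostWork(GPUDeviceContext* ctx, Device* device,
                              se::Stream* stream) {
  // Runs on an EventMgr thread once the GPU work already enqueued on
  // 'stream' has finished.
  return ctx->ThenExecute(device, stream,
                          []() { VLOG(1) << "stream work finished"; });
}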
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
index 583bff2c07..6b2f6547b0 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
@@ -31,7 +31,8 @@ TEST(PoolAllocatorTest, ZeroSizeBuffers) {
2 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
EXPECT_EQ(nullptr, pool.AllocateRaw(4 /*alignment*/, 0 /*num_bytes*/));
@@ -49,7 +50,8 @@ TEST(PoolAllocatorTest, ZeroSizePool) {
0 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
EXPECT_EQ(0, pool.get_from_pool_count());
@@ -82,7 +84,8 @@ TEST(PoolAllocatorTest, Alignment) {
0 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
for (int i = 0; i < 16; ++i) {
size_t alignment = 1 << i;
@@ -97,8 +100,8 @@ TEST(PoolAllocatorTest, Alignment) {
TEST(PoolAllocatorTest, AutoResize) {
PoolAllocator pool(2 /*pool_size_limit*/, true /*auto_resize*/,
- new BasicCPUAllocator(0 /*numa_node*/), new NoopRounder,
- "pool");
+ new BasicCPUAllocator(0 /*numa_node*/, {}, {}),
+ new NoopRounder, "pool");
// Alloc/dealloc 10 sizes just a few times, confirming pool size
// stays at 2.
@@ -123,14 +126,32 @@ TEST(PoolAllocatorTest, AutoResize) {
}
TEST(PoolAllocatorTest, CudaHostAllocator) {
+ int alloc_count = 0;
+ int64 alloc_size = 0;
+ SubAllocator::Visitor alloc_visitor =
+ [&alloc_count, &alloc_size](void* ptr, int numa_node, int64 size) {
+ ++alloc_count;
+ alloc_size += size;
+ };
+ int free_count = 0;
+ int64 free_size = 0;
+ SubAllocator::Visitor free_visitor =
+ [&free_count, &free_size](void* ptr, int numa_node, int64 size) {
+ ++free_count;
+ free_size += size;
+ };
se::Platform* platform =
se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
- PoolAllocator pool(
- 2 /*pool_size_limit*/, false /*auto_resize*/,
- new CUDAHostAllocator(
- platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
- new NoopRounder, "pool");
+ CUDAHostAllocator* sub_allocator = new CUDAHostAllocator(
+ platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
+ .ValueOrDie(),
+ 0 /*numa_node*/, {alloc_visitor}, {free_visitor});
+ PoolAllocator pool(2 /*pool_size_limit*/, false /*auto_resize*/,
+ sub_allocator, new NoopRounder, "pool");
+ EXPECT_EQ(0, alloc_count);
+ EXPECT_EQ(0, alloc_size);
+ EXPECT_EQ(0, free_count);
+ EXPECT_EQ(0, free_size);
// Repeatedly Get a 16-byte value, confirming that there's only
// one real allocation.
@@ -138,6 +159,10 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
EXPECT_EQ(0, pool.get_from_pool_count());
EXPECT_EQ(1, pool.allocated_count());
EXPECT_NE(nullptr, p1_16);
+ EXPECT_EQ(1, alloc_count); // Underlying suballoc of 16 bytes
+ // Each suballocation includes a 16B ChunkPrefix.
+ static const int kChunkPrefixSize = 16;
+ EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size);
pool.DeallocateRaw(p1_16);
// Pool contents {16}
EXPECT_EQ(1, pool.put_count());
@@ -148,6 +173,9 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
pool.DeallocateRaw(p2_16); // Put it back.
// Pool contents {16}
EXPECT_EQ(2, pool.put_count());
+ EXPECT_EQ(1, alloc_count); // Underlying suballoc of 16 bytes
+ EXPECT_EQ(16 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(0, free_count);
// Get two more values of different sizes.
void* p3_4 = pool.AllocateRaw(4, 4);
@@ -160,6 +188,9 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
void* p4_2 = pool.AllocateRaw(4, 2); // Get a third size buffer.
EXPECT_NE(nullptr, p4_2);
EXPECT_EQ(0, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(0, free_count);
// The pool is full: when we put back p4_2, the 16-byte buffer
// should be evicted since it was least recently inserted.
@@ -167,6 +198,10 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
// Pool contents {2, 4}
EXPECT_EQ(4, pool.put_count());
EXPECT_EQ(1, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(1, free_count);
+ EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size);
// Re-getting and putting size 2 or 4 should not alter pool size or
// num-evicted.
@@ -180,12 +215,20 @@ TEST(PoolAllocatorTest, CudaHostAllocator) {
EXPECT_EQ(6, pool.put_count());
EXPECT_EQ(3, pool.allocated_count());
EXPECT_EQ(1, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(1, free_count);
+ EXPECT_EQ(16 + (free_count * kChunkPrefixSize), free_size);
pool.Clear();
EXPECT_EQ(0, pool.get_from_pool_count());
EXPECT_EQ(0, pool.put_count());
EXPECT_EQ(0, pool.allocated_count());
EXPECT_EQ(0, pool.evicted_count());
+ EXPECT_EQ(3, alloc_count);
+ EXPECT_EQ(16 + 4 + 2 + (alloc_count * kChunkPrefixSize), alloc_size);
+ EXPECT_EQ(3, free_count);
+ EXPECT_EQ(16 + 4 + 2 + (free_count * kChunkPrefixSize), free_size);
}
TEST(PoolAllocatorTest, Pow2Rounder) {
@@ -206,7 +249,8 @@ TEST(PoolAllocatorTest, Name) {
2 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
- .ValueOrDie()),
+ .ValueOrDie(),
+ 0 /*numa_node*/, {}, {}),
new NoopRounder, "pool");
EXPECT_EQ("pool", pool.Name());
}
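As a worked check of the visitor accounting in the CudaHostAllocator test above: every pool request is prefixed with a 16-byte ChunkPrefix before it reaches the SubAllocator, so after the 16-, 4- and 2-byte requests the alloc visitor has seen alloc_size = 16 + 4 + 2 + 3 * 16 = 70 bytes across alloc_count = 3 calls. The free visitor only fires when a chunk is actually handed back to the SubAllocator: the single eviction of the 16-byte buffer accounts for 16 + 16 = 32 freed bytes, and after Clear() the free totals match the alloc totals at 70 bytes over 3 chunks, which is exactly what the EXPECT_EQ expressions with kChunkPrefixSize encode.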
diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h
index d697d878dc..3603808152 100644
--- a/tensorflow/core/common_runtime/gpu_device_context.h
+++ b/tensorflow/core/common_runtime/gpu_device_context.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
-#define TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_base.h"
@@ -60,6 +60,9 @@ class GPUDeviceContext : public DeviceContext {
void MaintainLifetimeOnStream(const Tensor* t,
se::Stream* stream) const override {}
+ Status ThenExecute(Device* device, se::Stream* stream,
+ std::function<void()> func) override;
+
private:
int stream_id_;
// The default primary stream to use for this context.
@@ -75,4 +78,4 @@ class GPUDeviceContext : public DeviceContext {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_GPU_DEVICE_CONTEXT_H_
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index c23b7d3699..4475fa979e 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/graph_execution_state.h"
#include <memory>
+#include <set>
#include <string>
#include <unordered_set>
#include <utility>
@@ -560,6 +561,10 @@ Status GraphExecutionState::OptimizeGraph(
grappler::GrapplerItem item;
item.id = "tf_graph";
graph_->ToGraphDef(&item.graph);
+ // TODO(b/114748242): Add a unit test to test this bug fix.
+ if (flib_def_) {
+ *item.graph.mutable_library() = flib_def_->ToProto();
+ }
item.fetch.insert(item.fetch.end(),
options.callable_options.fetch().begin(),
@@ -581,7 +586,7 @@ Status GraphExecutionState::OptimizeGraph(
if (id.second != 0) {
return errors::InvalidArgument("Unsupported feed: ", feed);
}
- feeds.insert(id.first.ToString());
+ feeds.emplace(id.first);
}
for (const TensorConnection& tensor_connection :
options.callable_options.tensor_connection()) {
@@ -590,7 +595,7 @@ Status GraphExecutionState::OptimizeGraph(
return errors::InvalidArgument("Unsupported feed: ",
tensor_connection.to_tensor());
}
- feeds.insert(id.first.ToString());
+ feeds.emplace(id.first);
}
for (const NodeDef& node : original_graph_def_.node()) {
if (feeds.find(node.name()) == feeds.end()) {
@@ -727,12 +732,50 @@ Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
+ int64 collective_graph_key = options.collective_graph_key;
+ if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
+ // BuildGraphOptions does not specify a collective_graph_key. Check all
+ // nodes in the Graph and FunctionLibraryDefinition for collective ops and
+ // if found, initialize a collective_graph_key as a hash of the ordered set
+ // of instance keys.
+ std::set<int32> instance_key_set;
+ for (Node* node : optimized_graph->nodes()) {
+ if (node->IsCollective()) {
+ int32 instance_key;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(node->attrs(), "instance_key", &instance_key));
+ instance_key_set.emplace(instance_key);
+ } else {
+ const FunctionDef* fdef = optimized_flib->Find(node->def().op());
+ if (fdef != nullptr) {
+ for (const NodeDef& ndef : fdef->node_def()) {
+ if (ndef.op() == "CollectiveReduce" ||
+ ndef.op() == "CollectiveBcastSend" ||
+ ndef.op() == "CollectiveBcastRecv") {
+ int32 instance_key;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(ndef, "instance_key", &instance_key));
+ instance_key_set.emplace(instance_key);
+ }
+ }
+ }
+ }
+ }
+ if (!instance_key_set.empty()) {
+ uint64 hash = 0x8774aa605c729c72ULL;
+ for (int32 instance_key : instance_key_set) {
+ hash = Hash64Combine(instance_key, hash);
+ }
+ collective_graph_key = hash;
+ }
+ }
+
// Copy the extracted graph in order to make its node ids dense,
// since the local CostModel used to record its stats is sized by
// the largest node id.
std::unique_ptr<ClientGraph> dense_copy(
new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
- rewrite_metadata.fetch_types));
+ rewrite_metadata.fetch_types, collective_graph_key));
CopyGraph(*optimized_graph, &dense_copy->graph);
// TODO(vrv): We should check invariants of the graph here.
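The collective_graph_key derivation above reduces to a small pure function: hash the sorted set of collective instance keys with a fixed seed. An equivalent standalone sketch (illustrative only; the helper name is made up, and Hash64Combine is the same hash used in the code above):

uint64 SketchCollectiveGraphKey(const std::set<int32>& instance_key_set) {
  uint64 hash = 0x8774aa605c729c72ULL;  // same seed as in BuildGraph
  for (int32 instance_key : instance_key_set) {
    // std::set iteration is ordered, so the key is deterministic for a given
    // set of instance keys regardless of graph node order.
    hash = Hash64Combine(instance_key, hash);
  }
  return hash;
}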
diff --git a/tensorflow/core/common_runtime/graph_execution_state.h b/tensorflow/core/common_runtime/graph_execution_state.h
index d44a24c87b..9cabe478a6 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.h
+++ b/tensorflow/core/common_runtime/graph_execution_state.h
@@ -50,17 +50,20 @@ struct GraphExecutionStateOptions {
// BuildGraphOptions.
struct ClientGraph {
explicit ClientGraph(std::unique_ptr<FunctionLibraryDefinition> flib,
- DataTypeVector feed_types, DataTypeVector fetch_types)
+ DataTypeVector feed_types, DataTypeVector fetch_types,
+ int64 collective_graph_key)
: flib_def(std::move(flib)),
graph(flib_def.get()),
feed_types(std::move(feed_types)),
- fetch_types(std::move(fetch_types)) {}
+ fetch_types(std::move(fetch_types)),
+ collective_graph_key(collective_graph_key) {}
// Each client-graph gets its own function library since optimization passes
// post rewrite for execution might want to introduce new functions.
std::unique_ptr<FunctionLibraryDefinition> flib_def;
Graph graph;
DataTypeVector feed_types;
DataTypeVector fetch_types;
+ int64 collective_graph_key;
};
// GraphExecutionState is responsible for generating an
diff --git a/tensorflow/core/common_runtime/graph_optimizer.cc b/tensorflow/core/common_runtime/graph_optimizer.cc
index 96ecfb41d4..37a979a8f1 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.cc
+++ b/tensorflow/core/common_runtime/graph_optimizer.cc
@@ -38,7 +38,8 @@ void GraphOptimizer::Optimize(
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
- const std::function<bool(const Node*)>& cse_consider_fn) {
+ const std::function<bool(const Node*)>& cse_consider_fn,
+ const std::function<bool(const Node*)>& cf_consider_fn) {
Graph* g = graph->get();
DumpGraph("Initial", g);
@@ -62,6 +63,7 @@ void GraphOptimizer::Optimize(
if (opts_.do_constant_folding()) {
ConstantFoldingOptions cf_opts;
cf_opts.shape_map = shape_map;
+ cf_opts.consider = cf_consider_fn;
if (opts_.max_folded_constant_in_bytes() > 0) {
cf_opts.max_constant_size_in_bytes =
opts_.max_folded_constant_in_bytes();
diff --git a/tensorflow/core/common_runtime/graph_optimizer.h b/tensorflow/core/common_runtime/graph_optimizer.h
index 80246281cd..789cc56942 100644
--- a/tensorflow/core/common_runtime/graph_optimizer.h
+++ b/tensorflow/core/common_runtime/graph_optimizer.h
@@ -45,12 +45,15 @@ class GraphOptimizer {
//
// If cse_consider_fn is not null then only nodes for which cse_consider_fn
// returns true will be considered for CSE.
+ // If cf_consider_fn is not null then only nodes for which cf_consider_fn
+ // returns true will be considered for CF.
void Optimize(
FunctionLibraryRuntime* runtime, Env* env, Device* device,
std::unique_ptr<Graph>* graph,
const std::unordered_map<string, std::vector<PartialTensorShape>>*
shape_map,
- const std::function<bool(const Node*)>& cse_consider_fn = nullptr);
+ const std::function<bool(const Node*)>& cse_consider_fn = nullptr,
+ const std::function<bool(const Node*)>& cf_consider_fn = nullptr);
const OptimizerOptions& options() { return opts_; }
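To illustrate the new parameter, here is a hedged usage sketch of the extended `Optimize` call that restricts both CSE and constant folding with per-node predicates. The surrounding objects (`optimizer`, `runtime`, `env`, `device`, `graph`, `shape_map`) are assumed to already exist; the opt-out attribute name `_no_constant_fold` is purely hypothetical.

    // Assumed available in the surrounding scope: GraphOptimizer optimizer;
    // FunctionLibraryRuntime* runtime; Env* env; Device* device;
    // std::unique_ptr<Graph> graph; and a shape_map pointer as declared above.
    auto cse_consider_fn = [](const Node* n) {
      // Illustrative policy: skip CSE for stateful nodes.
      return !n->op_def().is_stateful();
    };
    auto cf_consider_fn = [](const Node* n) {
      // Illustrative policy: skip constant folding for nodes carrying a
      // hypothetical opt-out attribute.
      return n->attrs().Find("_no_constant_fold") == nullptr;
    };
    optimizer.Optimize(runtime, env, device, &graph, shape_map, cse_consider_fn,
                       cf_consider_fn);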
diff --git a/tensorflow/core/common_runtime/graph_runner.cc b/tensorflow/core/common_runtime/graph_runner.cc
index 0a1797fa19..f9aef3af70 100644
--- a/tensorflow/core/common_runtime/graph_runner.cc
+++ b/tensorflow/core/common_runtime/graph_runner.cc
@@ -56,7 +56,7 @@ class SimpleRendezvous : public Rendezvous {
}
mutex_lock l(mu_);
- string edge_name = std::string(parsed.edge_name);
+ string edge_name(parsed.edge_name);
if (table_.count(edge_name) > 0) {
return errors::Internal("Send of an already sent tensor");
}
@@ -69,7 +69,7 @@ class SimpleRendezvous : public Rendezvous {
Tensor tensor;
Status status = Status::OK();
{
- string key = std::string(parsed.edge_name);
+ string key(parsed.edge_name);
mutex_lock l(mu_);
if (table_.count(key) <= 0) {
status = errors::Internal("Did not find key ", key);
diff --git a/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
new file mode 100644
index 0000000000..eae34997d9
--- /dev/null
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.cc
@@ -0,0 +1,440 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/common_runtime/collective_util.h"
+#include "tensorflow/core/common_runtime/device_mgr.h"
+#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
+
+// Set true for greater intelligibility of debug mode log messages.
+#define READABLE_KEYS false
+
+namespace tensorflow {
+
+namespace {
+// Key to be used for BufRendezvous by Broadcaster.
+string BroadcastBufKey(const string& exec_key, int subdiv, int src_rank,
+ int dst_rank) {
+ if (READABLE_KEYS) {
+ return strings::StrCat("broadcast(", exec_key, "):subdiv(", subdiv,
+ "):src(", src_rank, "):dst(", dst_rank, ")");
+ } else {
+ // TODO(b/78352018): Try a denser format, e.g. a 64 or 128 bit hash.
+ return strings::StrCat(exec_key, ":", subdiv, ":", src_rank, ":", dst_rank);
+ }
+}
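For reference, the two formats produce keys along these lines (hand-derived from the StrCat calls above; the exec_key value is a placeholder):

    // With exec_key = "42:0:0", subdiv = 1, src_rank = 0, dst_rank = 3:
    //   READABLE_KEYS == true  -> "broadcast(42:0:0):subdiv(1):src(0):dst(3)"
    //   READABLE_KEYS == false -> "42:0:0:1:0:3"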
+} // namespace
+
+HierarchicalTreeBroadcaster::HierarchicalTreeBroadcaster()
+ : col_ctx_(nullptr),
+ col_params_(nullptr),
+ done_(nullptr),
+ is_source_(false) {}
+
+int HierarchicalTreeBroadcaster::GetDeviceTask(
+ int device_rank, const std::vector<int>& dev_per_task) {
+ int num_tasks = static_cast<int>(dev_per_task.size());
+ int task_lo = 0;
+ int task_hi;
+ for (int ti = 0; ti < num_tasks; ti++) {
+ task_hi = task_lo + dev_per_task[ti];
+ if (task_lo <= device_rank && device_rank < task_hi) return ti;
+ task_lo = task_hi;
+ }
+ LOG(FATAL) << "Unexpected device rank " << device_rank << " for " << task_hi
+ << " devices";
+ return -1;
+}
+
+Status HierarchicalTreeBroadcaster::InitializeCollectiveParams(
+ CollectiveParams* col_params) {
+ CHECK_EQ(col_params->instance.type, BROADCAST_COLLECTIVE);
+ CHECK_EQ(col_params->instance.impl_details.collective_name,
+ "HierarchicalTreeBroadcast");
+ const string& device_name =
+ col_params->instance.device_names[col_params->default_rank];
+ // Start by counting the devices in each task.
+ // Precondition: device_names must be sorted so that all devices in
+ // the same task are adjacent.
+ VLOG(2) << "Sorted task names: "
+ << str_util::Join(col_params->instance.task_names, ", ");
+ std::vector<int> dev_per_task;
+ const string* prior_task_name = &col_params->instance.task_names[0];
+ int dev_count = 1;
+ for (int di = 1; di < col_params->group.group_size; ++di) {
+ if (col_params->instance.task_names[di] != *prior_task_name) {
+ dev_per_task.push_back(dev_count);
+ dev_count = 1;
+ prior_task_name = &col_params->instance.task_names[di];
+ } else {
+ ++dev_count;
+ }
+ }
+ dev_per_task.push_back(dev_count);
+ CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
+
+ if (VLOG_IS_ON(2)) {
+ string dpt_buf;
+ for (int dpt : dev_per_task) strings::StrAppend(&dpt_buf, dpt, ";");
+ VLOG(2) << "HierarchicalTreeBroadcaster::InitializeCollectiveParams device="
+ << device_name << " source_rank=" << col_params->source_rank
+ << " dev_per_task=" << dpt_buf;
+ }
+ int num_tasks = col_params->group.num_tasks;
+ // If there is just 1 task, then execute binary tree broadcast over all
+ // devices. Otherwise, the first subdiv is inter-task broadcast, and then
+ // there are N more subdivs, where N is the number of tasks.
+ int num_subdivs = num_tasks + (num_tasks > 1 ? 1 : 0);
+ int total_num_devices = 0;
+ for (int num_dev : dev_per_task) total_num_devices += num_dev;
+
+ col_params->instance.impl_details.subdiv_permutations.resize(num_subdivs);
+ col_params->subdiv_rank.reserve(num_subdivs);
+ col_params->instance.impl_details.subdiv_source_rank.reserve(num_subdivs);
+
+ // Inter-task subdiv. Pick one device from each task - this is the source
+ // device if it belongs to that task, or device 0 for that task. If a device
+ // does not participate in the subdiv, set subdiv_rank to -1.
+ if (num_tasks > 1) {
+ const int sdi = 0;
+ std::vector<int>& perm =
+ col_params->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ int device_count = 0;
+ int source_task = GetDeviceTask(col_params->source_rank, dev_per_task);
+ for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
+ bool participate = false;
+ if (source_task == ti) {
+ // Source device belongs to this task.
+ perm.push_back(col_params->source_rank);
+ participate =
+ col_params->instance.device_names[col_params->source_rank] ==
+ device_name;
+ } else {
+ // Source does not belong to this task, choose dev 0.
+ perm.push_back(device_count);
+ participate =
+ col_params->instance.device_names[device_count] == device_name;
+ }
+ if (participate) col_params->subdiv_rank.push_back(ti);
+ device_count += dev_per_task[ti];
+ }
+ if (col_params->subdiv_rank.empty()) col_params->subdiv_rank.push_back(-1);
+ col_params->instance.impl_details.subdiv_source_rank.push_back(source_task);
+ }
+
+ // Intra-task subdivs. Pick all devices in task ti for subdiv sdi. Set the
+ // source to dev 0 for that task if it does not contain the original source,
+ // else set it to the rank of the original source. If a device does not
+ // participate in the subdiv, set subdiv_rank to -1.
+ int abs_di = 0;
+ for (int ti = 0; ti < col_params->group.num_tasks; ti++) {
+ const int sdi = ti + (num_tasks > 1 ? 1 : 0);
+ std::vector<int>& perm =
+ col_params->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ bool participate = false;
+ int subdiv_source = 0;
+ for (int di = 0; di < dev_per_task[ti]; di++) {
+ perm.push_back(abs_di);
+ if (col_params->instance.device_names[abs_di] == device_name) {
+ participate = true;
+ col_params->subdiv_rank.push_back(di);
+ }
+ if (abs_di == col_params->source_rank) subdiv_source = di;
+ abs_di++;
+ }
+ if (!participate) col_params->subdiv_rank.push_back(-1);
+ col_params->instance.impl_details.subdiv_source_rank.push_back(
+ subdiv_source);
+ }
+
+ for (int sri = 0; sri < num_subdivs; sri++) {
+ CHECK_GE(col_params->instance.impl_details.subdiv_source_rank[sri], 0);
+ }
+
+ VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
+ return Status::OK();
+}
+
+Status HierarchicalTreeBroadcaster::InitializeCollectiveContext(
+ CollectiveContext* col_ctx) {
+ CHECK(col_ctx->dev_mgr);
+ col_ctx_ = col_ctx;
+ col_params_ = &col_ctx->col_params;
+ return collective_util::InitializeDeviceAndLocality(
+ col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
+ &col_ctx->device_locality);
+}
+
+void HierarchicalTreeBroadcaster::Run(StatusCallback done) {
+ CHECK(col_ctx_);
+ CHECK(col_params_);
+ done_ = std::move(done);
+ is_source_ = col_params_->is_source;
+ RunTree();
+}
+
+// Binary tree parent/child relations are trivial to calculate, i.e.
+// the device at rank r is the parent of ranks 2r+1 and 2r+2. The one
+// exception is if the source is not rank 0. We treat that case as though
+// the source were prepended to the rank ordering while also continuing to
+// occupy its current position. Hence we calculate as though each device's
+// rank is actually r+1, then subtract 1 again to get the descendant ranks.
+// If the source is not rank 0 then its descendants include both {0,1} and
+// the descendants of its current position. Where a non-0-rank source is a
+// descendant of another device, no send to it is necessary.
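As a hand-worked illustration of these rules (my own derivation, using the same CollectiveParams setup as the DEF_TL_TEST macro in the accompanying test file): with 4 ranks in a subdiv and source rank 2, rank 0 receives from 2 and sends to {3}, rank 1 receives from 2 and sends to nothing, rank 2 (the source) sends to {0, 1}, and rank 3 receives from 0 and sends to nothing. The sketch below checks the rank-0 case via the static helpers.

    // Sketch of exercising the static helpers for the rank-0 case above; the
    // CollectiveParams setup follows the DEF_TL_TEST macro in
    // hierarchical_tree_broadcaster_test.cc.
    CollectiveParams cp;
    cp.group.group_size = 4;
    cp.instance.impl_details.subdiv_source_rank = {2};
    cp.instance.impl_details.subdiv_permutations.push_back(std::vector<int>(4, 0));
    cp.subdiv_rank = {0};  // this device has rank 0 in subdiv 0
    cp.is_source = false;
    int recv_from = HierarchicalTreeBroadcaster::TreeRecvFrom(cp, 0);  // == 2
    std::vector<int> send_to;
    HierarchicalTreeBroadcaster::TreeSendTo(cp, 0, &send_to);          // == {3}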
+
+/* static*/
+int HierarchicalTreeBroadcaster::TreeRecvFrom(const CollectiveParams& cp,
+ int subdiv) {
+ DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
+ int my_rank = cp.subdiv_rank[subdiv];
+ if (-1 == my_rank) return -1;
+
+ const auto& impl = cp.instance.impl_details;
+ DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
+ int source_rank = impl.subdiv_source_rank[subdiv];
+ if (my_rank == source_rank) return -1;
+ if (source_rank == 0) {
+ return (my_rank - 1) / 2;
+ } else {
+ int predecessor_rank = (my_rank / 2) - 1;
+ return (predecessor_rank < 0) ? source_rank : predecessor_rank;
+ }
+}
+
+/* static */
+void HierarchicalTreeBroadcaster::TreeSendTo(const CollectiveParams& cp,
+ int subdiv,
+ std::vector<int>* targets) {
+ DCHECK_LT(subdiv, static_cast<int>(cp.subdiv_rank.size()));
+ int my_rank = cp.subdiv_rank[subdiv];
+ if (-1 == my_rank) return;
+
+ const auto& impl = cp.instance.impl_details;
+ DCHECK_LT(subdiv, static_cast<int>(impl.subdiv_source_rank.size()));
+ int source_rank = impl.subdiv_source_rank[subdiv];
+
+ int group_size = 0;
+ for (int i = 0; i < impl.subdiv_permutations[subdiv].size(); i++) {
+ if (impl.subdiv_permutations[subdiv][i] >= 0) {
+ group_size++;
+ }
+ }
+
+ targets->clear();
+ int successor_rank = 0;
+ if (source_rank == 0) {
+ successor_rank = (2 * my_rank) + 1;
+ } else {
+ successor_rank = (2 * (my_rank + 1));
+ }
+ DCHECK_NE(successor_rank, my_rank);
+ if (cp.is_source && source_rank != 0) {
+ // The source sends to ranks 0 and 1 in addition to its positional
+ // descendants.
+ if (group_size > 1) {
+ targets->push_back(0);
+ }
+ if (group_size > 2 && source_rank != 1) {
+ targets->push_back(1);
+ }
+ }
+ for (int i = 0; i < 2; ++i) {
+ if (successor_rank < group_size && successor_rank != source_rank) {
+ targets->push_back(successor_rank);
+ }
+ ++successor_rank;
+ }
+}
+
+// Executes a hierarchical tree broadcast.
+// Each subdiv is a broadcast between a subset of the devices.
+// If there is only one task, there is one subdiv comprising a broadcast between
+// all devices belonging to the task.
+// If there are n tasks, n>1, then there are n+1 subdivs. In the first (global)
+// subdiv, one device from each task participates in a binary tree broadcast.
+// Each task receives a copy of the tensor on one device via this broadcast.
+// Subsequent subdivs correspond to intra-task broadcasts. Subdiv i+1
+// corresponds to a broadcast among all devices on task i. Thus, each device
+// participates in at most 2 subdivs.
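To make the subdivision structure concrete, here is a small hand-worked layout (my own example, derived from InitializeCollectiveParams above rather than taken from the original change):

    // Hand-worked example: 2 tasks x 2 GPUs each, source = task1/GPU:1.
    //   subdiv 0 (inter-task):  {task0/GPU:0, task1/GPU:1}, subdiv source rank 1
    //   subdiv 1 (task 0 tree): {task0/GPU:0, task0/GPU:1}, subdiv source rank 0
    //   subdiv 2 (task 1 tree): {task1/GPU:0, task1/GPU:1}, subdiv source rank 1
    // Every device appears in exactly one intra-task subdiv; task0/GPU:0 and
    // task1/GPU:1 additionally appear in the inter-task subdiv.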
+void HierarchicalTreeBroadcaster::RunTree() {
+ int num_subdivs = static_cast<int>(col_params_->subdiv_rank.size());
+ // TODO(b/78352018): this is easily improved when a node participates in both
+ // the first and second subdivisions. It would first send to its descendants
+ // in the first subdiv, then wait until all pending ops are finished before
+ // sending to descendants in the second subdiv. A better implementation would
+ // collapse the two send blocks.
+ for (int si = 0; si < num_subdivs; si++) {
+ int my_rank = col_params_->subdiv_rank[si];
+ // If rank is -1, this device does not participate in this subdiv.
+ if (-1 == my_rank) continue;
+ int source_rank = col_params_->instance.impl_details.subdiv_source_rank[si];
+ if (VLOG_IS_ON(1)) {
+ string subdiv_buf;
+ for (int r : col_params_->instance.impl_details.subdiv_permutations[si]) {
+ strings::StrAppend(&subdiv_buf, r, ",");
+ }
+ VLOG(1) << "Running Broadcast tree device=" << col_ctx_->device_name
+ << " subdiv=" << si << " perm=" << subdiv_buf
+ << " my_rank=" << my_rank << " source_rank=" << source_rank;
+ }
+
+ mutex mu; // also guards status_ while callbacks are pending
+ int pending_count = 0; // GUARDED_BY(mu)
+ condition_variable all_done;
+
+ if (my_rank >= 0 && my_rank != source_rank) {
+ // Begin by receiving the value.
+ int recv_from_rank = TreeRecvFrom(*col_params_, si);
+ Notification note;
+ DispatchRecv(si, recv_from_rank, my_rank, col_ctx_->output,
+ [this, &mu, &note](const Status& s) {
+ mutex_lock l(mu);
+ status_.Update(s);
+ note.Notify();
+ });
+ note.WaitForNotification();
+ }
+
+ // Then forward value to all descendent devices.
+ if (my_rank >= 0 && status_.ok()) {
+ std::vector<int> send_to_ranks;
+ TreeSendTo(*col_params_, si, &send_to_ranks);
+ for (int i = 0; i < send_to_ranks.size(); ++i) {
+ int target_rank = send_to_ranks[i];
+ {
+ mutex_lock l(mu);
+ ++pending_count;
+ }
+ DispatchSend(si, target_rank, my_rank,
+ (is_source_ ? col_ctx_->input : col_ctx_->output),
+ [this, &mu, &pending_count, &all_done](const Status& s) {
+ mutex_lock l(mu);
+ status_.Update(s);
+ --pending_count;
+ if (pending_count == 0) {
+ all_done.notify_all();
+ }
+ });
+ }
+ }
+
+ // For the original source device, we copy input to output if they are
+ // different.
+ // If there is only 1 subdiv, we do this in that subdiv. If there is more
+ // than 1 subdiv, then the original source device will participate in 2
+ // subdivs - the global inter-task broadcast and one local intra-task
+ // broadcast. In this case, we perform the copy in the second subdiv for
+ // this device.
+ if (status_.ok() && is_source_ && (1 == num_subdivs || 0 != si)) {
+ VLOG(2) << "copying input to output for device=" << col_ctx_->device_name
+ << " subdiv=" << si;
+ if (col_ctx_->input != col_ctx_->output &&
+ (DMAHelper::base(col_ctx_->input) !=
+ DMAHelper::base(col_ctx_->output))) {
+ {
+ mutex_lock l(mu);
+ ++pending_count;
+ }
+ DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
+ CollectiveRemoteAccessLocal::MemCpyAsync(
+ op_dev_ctx, op_dev_ctx, col_ctx_->device, col_ctx_->device,
+ col_ctx_->op_ctx->input_alloc_attr(0),
+ col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
+ col_ctx_->output, 0, /*stream_index*/
+ [this, &mu, &pending_count, &all_done](const Status& s) {
+ mutex_lock l(mu);
+ status_.Update(s);
+ --pending_count;
+ if (0 == pending_count) {
+ all_done.notify_all();
+ }
+ });
+ }
+ }
+
+ // Then wait for all pending actions to complete.
+ {
+ mutex_lock l(mu);
+ if (pending_count > 0) {
+ all_done.wait(l);
+ }
+ }
+ }
+ VLOG(2) << "device=" << col_ctx_->device_name << " return status " << status_;
+ done_(status_);
+}
+
+void HierarchicalTreeBroadcaster::DispatchSend(int subdiv, int dst_rank,
+ int src_rank,
+ const Tensor* src_tensor,
+ const StatusCallback& done) {
+ string send_buf_key =
+ BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
+ int dst_idx =
+ col_params_->instance.impl_details.subdiv_permutations[subdiv][dst_rank];
+ VLOG(3) << "DispatchSend " << send_buf_key << " from_device "
+ << col_ctx_->device_name << " to_device "
+ << col_params_->instance.device_names[dst_idx] << " subdiv=" << subdiv
+ << " dst_rank=" << dst_rank << " dst_idx=" << dst_idx;
+ col_ctx_->col_exec->PostToPeer(col_params_->instance.device_names[dst_idx],
+ col_params_->instance.task_names[dst_idx],
+ send_buf_key, col_ctx_->device,
+ col_ctx_->op_ctx->op_device_context(),
+ col_ctx_->op_ctx->output_alloc_attr(0),
+ src_tensor, col_ctx_->device_locality, done);
+}
+
+void HierarchicalTreeBroadcaster::DispatchRecv(int subdiv, int src_rank,
+ int dst_rank, Tensor* dst_tensor,
+ const StatusCallback& done) {
+ string recv_buf_key =
+ BroadcastBufKey(col_ctx_->exec_key, subdiv, src_rank, dst_rank);
+ int src_idx =
+ col_params_->instance.impl_details.subdiv_permutations[subdiv][src_rank];
+ VLOG(3) << "DispatchRecv " << recv_buf_key << " from_device "
+ << col_params_->instance.device_names[src_idx] << " to_device "
+ << col_ctx_->device_name << " subdiv=" << subdiv
+ << " src_rank=" << src_rank << " src_idx=" << src_idx;
+ col_ctx_->col_exec->RecvFromPeer(
+ col_params_->instance.device_names[src_idx],
+ col_params_->instance.task_names[src_idx],
+ col_params_->task.is_local[src_idx], recv_buf_key, col_ctx_->device,
+ col_ctx_->op_ctx->op_device_context(),
+ col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
+ col_ctx_->device_locality, 0 /*stream_index*/, done);
+}
+
+REGISTER_COLLECTIVE(HierarchicalTreeBroadcast, HierarchicalTreeBroadcaster);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/broadcaster.h b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h
index 799228b161..ceb9baad30 100644
--- a/tensorflow/core/common_runtime/broadcaster.h
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h
@@ -12,25 +12,40 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_
#include <vector>
+
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/framework/collective.h"
-#include "tensorflow/core/framework/device_attributes.pb.h"
namespace tensorflow {
-// Tree-algorithm implementation of collective broadcast.
-class Broadcaster {
+// Hierarchical tree-algorithm implementation of collective broadcast.
+class HierarchicalTreeBroadcaster : public CollectiveImplementationInterface {
public:
- Broadcaster(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
- OpKernelContext* ctx, OpKernelContext::Params* params,
- const CollectiveParams& col_params, const string& exec_key,
- int64 step_id, Tensor* output);
+ HierarchicalTreeBroadcaster();
+ ~HierarchicalTreeBroadcaster() override = default;
+
+ // Establishes the subdiv permutations needed for a hierarchical broadcast.
+ // If all devices are local, establishes a single subdiv comprising all
+ // devices. If any devices are on a different task, establishes n+1 subdivs
+ // for n tasks.
+ // The first subdiv comprises one device per task and delivers a copy of the
+ // tensor to each task. Subdiv i+1 corresponds to a task-local tree-broadcast
+ // for task i.
+ Status InitializeCollectiveParams(CollectiveParams* col_params) override;
- void Run(StatusCallback done);
+ // Initializes members of CollectiveContext not yet initialized, i.e. device
+ // and device_locality. Also saves the CollectiveContext in this object.
+ Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+
+ // Begins async execution of the hierarchical tree broadcast.
+ // Must be called in a blockable thread.
+ // TODO(b/80529858): remove the previous warning when we have a dedicated
+ // collective threadpool.
+ void Run(StatusCallback done) override;
// Returns the rank of the device from which this device should receive
// its value, -1 if no value should be received.
@@ -42,32 +57,29 @@ class Broadcaster {
std::vector<int>* targets);
private:
+ // Get the task to which the device at `device_rank` belongs.
+ int GetDeviceTask(int device_rank, const std::vector<int>& dev_per_task);
+
// Sends `src_tensor` asynchronously from this device to device at `dst_rank`
// in `subdiv`. Calls `done` upon completion.
void DispatchSend(int subdiv, int dst_rank, int src_rank,
const Tensor* src_tensor, const StatusCallback& done);
+
// Receives a tensor into the memory buffer owned by `dst_tensor` at this
// device from device at `src_rank` in `subdiv`. Calls `done` upon
// completion.
void DispatchRecv(int subdiv, int src_rank, int dst_rank, Tensor* dst_tensor,
const StatusCallback& done);
+
// Executes the hierarchical broadcast defined by this op.
void RunTree();
- Status status_;
- CollectiveExecutor* col_exec_; // Not owned
- const DeviceMgr* dev_mgr_; // Not owned
- OpKernelContext* ctx_; // Not owned
- const CollectiveParams& col_params_;
- const string exec_key_;
- const int rank_;
- const bool is_source_;
- Tensor* output_; // Not owned
- std::unique_ptr<CollectiveAdapter> ca_;
+ CollectiveContext* col_ctx_; // Not owned
+ const CollectiveParams* col_params_; // Not owned
StatusCallback done_;
- Device* device_; // The device for which this instance labors
- DeviceLocality device_locality_;
+ Status status_;
+ bool is_source_;
};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_BROADCASTER_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_HIERARCHICAL_TREE_BROADCASTER_H_
diff --git a/tensorflow/core/common_runtime/broadcaster_test.cc b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
index 3960fc6c97..da0e359cf8 100644
--- a/tensorflow/core/common_runtime/broadcaster_test.cc
+++ b/tensorflow/core/common_runtime/hierarchical_tree_broadcaster_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/common_runtime/broadcaster.h"
+#include "tensorflow/core/common_runtime/hierarchical_tree_broadcaster.h"
#include <algorithm>
#include "tensorflow/core/common_runtime/base_collective_executor.h"
@@ -41,7 +41,7 @@ static int64 kStepId = 123;
// The test harness won't allow a mixture of fixture and non-fixture
// tests in one file, so this is a trivial fixture for tests that don't
-// need the heavy-weight BroadcasterTest fixture.
+// need the heavy-weight HierarchicalTreeBroadcasterTest fixture.
class TrivialTest : public ::testing::Test {
protected:
TrivialTest() {}
@@ -53,23 +53,23 @@ class TrivialTest : public ::testing::Test {
// R = tested rank
// RF = receive-from rank
// ST = send_to rank vector
-#define DEF_TL_TEST(D, S, R, RF, ST) \
- TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) { \
- CollectiveParams cp; \
- cp.group.group_size = D; \
- cp.instance.impl_details.subdiv_source_rank = {S}; \
- cp.instance.impl_details.subdiv_permutations.push_back( \
- std::vector<int>(D, 0)); \
- cp.subdiv_rank = {R}; \
- cp.is_source = (S == R); \
- EXPECT_EQ(RF, Broadcaster::TreeRecvFrom(cp, 0)); \
- std::vector<int> expected = ST; \
- std::vector<int> send_to; \
- Broadcaster::TreeSendTo(cp, 0, &send_to); \
- ASSERT_EQ(expected.size(), send_to.size()); \
- for (int i = 0; i < expected.size(); ++i) { \
- EXPECT_EQ(expected[i], send_to[i]); \
- } \
+#define DEF_TL_TEST(D, S, R, RF, ST) \
+ TEST_F(TrivialTest, TreeLinks_##D##Devs_##S##Source_##R##Rank) { \
+ CollectiveParams cp; \
+ cp.group.group_size = D; \
+ cp.instance.impl_details.subdiv_source_rank = {S}; \
+ cp.instance.impl_details.subdiv_permutations.push_back( \
+ std::vector<int>(D, 0)); \
+ cp.subdiv_rank = {R}; \
+ cp.is_source = (S == R); \
+ EXPECT_EQ(RF, HierarchicalTreeBroadcaster::TreeRecvFrom(cp, 0)); \
+ std::vector<int> expected = ST; \
+ std::vector<int> send_to; \
+ HierarchicalTreeBroadcaster::TreeSendTo(cp, 0, &send_to); \
+ ASSERT_EQ(expected.size(), send_to.size()); \
+ for (int i = 0; i < expected.size(); ++i) { \
+ EXPECT_EQ(expected[i], send_to[i]); \
+ } \
}
#define V(...) std::vector<int>({__VA_ARGS__})
@@ -130,7 +130,7 @@ DEF_TL_TEST(8, 7, 7, -1, V(0, 1))
// Wraps CollectiveRemoteAccessLocal with the ability to return an
// error status to the N'th action.
-// TODO(tucker): factor out of this file and ring_reducer_test.cc
+// TODO(b/113171733): factor out of this file and ring_reducer_test.cc
// into a single common source.
class FailTestRMA : public CollectiveRemoteAccessLocal {
public:
@@ -187,31 +187,32 @@ class FailTestRMA : public CollectiveRemoteAccessLocal {
int fail_after_ GUARDED_BY(mu_);
};
-class BroadcasterTest : public ::testing::Test {
+class HierarchicalTreeBroadcasterTest : public ::testing::Test {
protected:
- BroadcasterTest() : device_type_(DEVICE_CPU) {}
+ HierarchicalTreeBroadcasterTest() : device_type_(DEVICE_CPU) {}
- ~BroadcasterTest() override {
+ ~HierarchicalTreeBroadcasterTest() override {
stop_ = true;
- for (auto i : instances_) {
- delete i;
- }
+ for (auto i : instances_) delete i;
if (col_exec_) col_exec_->Unref();
}
- void SetUp() override {
-#if GOOGLE_CUDA
+#ifdef GOOGLE_CUDA
+ void InitGPUDevices() {
auto device_factory = DeviceFactory::GetFactory("GPU");
CHECK(device_factory);
SessionOptions options;
Status s = device_factory->CreateDevices(
options, "/job:worker/replica:0/task:0", &gpu_devices_);
CHECK(s.ok());
-#endif
}
+#endif
void Init(int num_workers, int num_devices_per_worker, DataType dtype,
const DeviceType& device_type, int fail_after) {
+#ifdef GOOGLE_CUDA
+ InitGPUDevices();
+#endif
VLOG(2) << "num_workers=" << num_workers
<< " num_devices_per_worker=" << num_devices_per_worker;
int total_num_devices = num_workers * num_devices_per_worker;
@@ -400,8 +401,6 @@ class BroadcasterTest : public ::testing::Test {
return GetKernel(node_def, device_type, device);
}
- void BuildColParams() {}
-
template <typename T>
void RunTest(DataType dtype, const DeviceType& device_type, int num_workers,
int num_devices, int tensor_len, int fail_after,
@@ -511,10 +510,47 @@ class BroadcasterTest : public ::testing::Test {
}
}
+ void RunSubdivPermsTest(
+ CollectiveParams* cp,
+ const std::vector<std::vector<int>>& expected_subdiv_perms,
+ const std::vector<int>& expected_subdiv_rank,
+ const std::vector<int>& expected_subdiv_source_rank) {
+ col_exec_ = nullptr;
+ cp->instance.impl_details.subdiv_permutations.clear();
+ cp->subdiv_rank.clear();
+ cp->instance.impl_details.subdiv_source_rank.clear();
+ // Create a stub broadcaster only for testing param initialization.
+ HierarchicalTreeBroadcaster broadcaster;
+ TF_CHECK_OK(broadcaster.InitializeCollectiveParams(cp));
+ EXPECT_EQ(expected_subdiv_perms,
+ cp->instance.impl_details.subdiv_permutations);
+ EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
+ EXPECT_EQ(expected_subdiv_source_rank,
+ cp->instance.impl_details.subdiv_source_rank);
+ }
+
+ void PrepColParamsForSubdivPermsTest(CollectiveParams* cp, int num_tasks,
+ int num_gpus) {
+ cp->group.device_type = DeviceType("GPU");
+ cp->group.num_tasks = num_tasks;
+ cp->group.group_size = num_tasks * num_gpus;
+ cp->instance.type = BROADCAST_COLLECTIVE;
+ cp->instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
+ for (int ti = 0; ti < num_tasks; ti++) {
+ string task_name = strings::StrCat("/job:worker/replica:0/task:", ti);
+ for (int di = 0; di < num_gpus; di++) {
+ string dev_name = strings::StrCat(task_name, "/device:GPU:", di);
+ cp->instance.task_names.push_back(task_name);
+ cp->instance.device_names.push_back(dev_name);
+ }
+ }
+ }
+
class DeviceInstance {
public:
DeviceInstance(int rank, const string& dev_name,
- const DeviceType& device_type, BroadcasterTest* parent)
+ const DeviceType& device_type,
+ HierarchicalTreeBroadcasterTest* parent)
: parent_(parent),
dev_name_(dev_name),
device_type_(device_type),
@@ -636,21 +672,20 @@ class BroadcasterTest : public ::testing::Test {
ctx.allocate_output(0, tensor_.shape(), &output_tensor_ptr));
}
CHECK_EQ(output_tensor_ptr, ctx.mutable_output(0));
+ const Tensor* input_tensor_ptr =
+ col_params_.is_source ? &tensor_ : nullptr;
// Prepare a Broadcaster instance.
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
- Broadcaster broadcaster(parent_->col_exec_, parent_->dev_mgr_.get(), &ctx,
- &op_params, col_params_, exec_key, kStepId,
- output_tensor_ptr);
-
- // Start execution in a threadpool then wait for completion.
- Notification notification;
- broadcaster.Run([this, &notification](Status s) {
- status_ = s;
- notification.Notify();
- });
- notification.WaitForNotification();
+ HierarchicalTreeBroadcaster broadcaster;
+ CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
+ &ctx, &op_params, col_params_, exec_key,
+ kStepId, input_tensor_ptr, output_tensor_ptr);
+ TF_CHECK_OK(broadcaster.InitializeCollectiveContext(&col_ctx));
+
+ // Run the broadcast.
+ broadcaster.Run([this](Status s) { status_ = s; });
if (status_.ok()) {
CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
}
@@ -658,15 +693,13 @@ class BroadcasterTest : public ::testing::Test {
dev_ctx->Unref();
}
- BroadcasterTest* parent_;
+ HierarchicalTreeBroadcasterTest* parent_;
string dev_name_;
DeviceType device_type_ = DEVICE_CPU;
int rank_;
Tensor tensor_;
Device* device_;
CollectiveParams col_params_;
- std::unique_ptr<CollectiveAdapter> ca_;
- std::unique_ptr<OpKernelContext> ctx_;
Status status_;
}; // class DeviceInstance
@@ -688,6 +721,118 @@ class BroadcasterTest : public ::testing::Test {
int failure_count_ GUARDED_BY(mu_) = 0;
};
+TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams1Task8GPU) {
+ CollectiveParams cp;
+ PrepColParamsForSubdivPermsTest(&cp, 1, 8);
+
+ // source 0 device 0
+ cp.source_rank = 0;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {0});
+
+ // source 2 device 2
+ cp.source_rank = 2;
+ cp.default_rank = 2;
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {2}, {2});
+
+ // source 2 device 0
+ cp.source_rank = 2;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7}}, {0}, {2});
+}
+
+TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4Tasks8GPU) {
+ CollectiveParams cp;
+ PrepColParamsForSubdivPermsTest(&cp, 4, 8);
+
+ // source 0 device 0
+ cp.source_rank = 0;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp,
+ {{0, 8, 16, 24},
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13, 14, 15},
+ {16, 17, 18, 19, 20, 21, 22, 23},
+ {24, 25, 26, 27, 28, 29, 30, 31}},
+ {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
+
+ // source 2 device 0
+ cp.source_rank = 2;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp,
+ {{2, 8, 16, 24},
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13, 14, 15},
+ {16, 17, 18, 19, 20, 21, 22, 23},
+ {24, 25, 26, 27, 28, 29, 30, 31}},
+ {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
+
+ // source 9 device 9
+ cp.source_rank = 9;
+ cp.default_rank = 9;
+ RunSubdivPermsTest(&cp,
+ {{0, 9, 16, 24},
+ {0, 1, 2, 3, 4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13, 14, 15},
+ {16, 17, 18, 19, 20, 21, 22, 23},
+ {24, 25, 26, 27, 28, 29, 30, 31}},
+ {1, -1, 1, -1, -1}, {1, 0, 1, 0, 0});
+}
+
+TEST_F(HierarchicalTreeBroadcasterTest, InitializeParams4TasksVariableGPU) {
+ CollectiveParams cp;
+ int num_tasks = 4;
+ cp.group.device_type = DeviceType("GPU");
+ cp.group.num_tasks = num_tasks;
+ cp.group.group_size = 0;
+ cp.instance.type = BROADCAST_COLLECTIVE;
+ cp.instance.impl_details.collective_name = "HierarchicalTreeBroadcast";
+ std::vector<int> dev_per_task = {4, 4, 6, 8};
+ for (int ti = 0; ti < cp.group.num_tasks; ti++) {
+ string task_name = strings::StrCat("/job:worker/replica:0/task:", ti);
+ for (int di = 0; di < dev_per_task[ti]; di++) {
+ string dev_name = strings::StrCat(task_name, "/device:GPU:", di);
+ cp.instance.task_names.push_back(task_name);
+ cp.instance.device_names.push_back(dev_name);
+ cp.group.group_size++;
+ }
+ }
+
+ // source 0 device 0
+ cp.source_rank = 0;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp,
+ {{0, 4, 8, 14},
+ {0, 1, 2, 3},
+ {4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13},
+ {14, 15, 16, 17, 18, 19, 20, 21}},
+ {0, 0, -1, -1, -1}, {0, 0, 0, 0, 0});
+
+ // source 2 device 0
+ cp.source_rank = 2;
+ cp.default_rank = 0;
+ RunSubdivPermsTest(&cp,
+ {{2, 4, 8, 14},
+ {0, 1, 2, 3},
+ {4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13},
+ {14, 15, 16, 17, 18, 19, 20, 21}},
+ {-1, 0, -1, -1, -1}, {0, 2, 0, 0, 0});
+
+ // source 9 device 5
+ cp.source_rank = 9;
+ cp.default_rank = 5;
+ RunSubdivPermsTest(&cp,
+ {{0, 4, 9, 14},
+ {0, 1, 2, 3},
+ {4, 5, 6, 7},
+ {8, 9, 10, 11, 12, 13},
+ {14, 15, 16, 17, 18, 19, 20, 21}},
+ {-1, -1, 1, -1, -1}, {2, 0, 0, 1, 0});
+}
+
+// TODO(b/113171733): change to use TEST_P.
// Tests of full broadcast algorithm, with different device and
// data types.
// B = data element type
@@ -697,7 +842,7 @@ class BroadcasterTest : public ::testing::Test {
// L = tensor length
// A = abort after count
#define DEF_TEST(B, T, W, D, L, A, F) \
- TEST_F(BroadcasterTest, \
+ TEST_F(HierarchicalTreeBroadcasterTest, \
DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Len##L##_Abt##A##_Fw##F) { \
DataType dtype = DT_##B; \
switch (dtype) { \
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
index 995a15a299..555b43f655 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
-#define TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
#include <string>
#include <vector>
@@ -65,4 +65,4 @@ class Benchmark {
} // end namespace test
} // end namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_KERNEL_BENCHMARK_TESTLIB_H_
diff --git a/tensorflow/core/common_runtime/local_device.h b/tensorflow/core/common_runtime/local_device.h
index 84a4f66db4..226f121bf3 100644
--- a/tensorflow/core/common_runtime/local_device.h
+++ b/tensorflow/core/common_runtime/local_device.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_
-#define TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_DEVICE_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_DEVICE_H_
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/device_attributes.pb.h"
@@ -54,4 +54,4 @@ class LocalDevice : public Device {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_LOCAL_DEVICE_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOCAL_DEVICE_H_
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index dfce7c23e7..a02084f223 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -38,11 +38,12 @@ class CondBuilder {
public:
enum Branch { kElseBranch = 0, kThenBranch = 1 };
- // Create a CondBuilder to create the lowering of If op. that has then and
+ // Create a CondBuilder to create the lowered form of `if_op` with then and
// else functions named `then_fn_name` and `else_fn_name` respectively in the
- // given graph.
+ // `graph`. The functions should be available in `flib`.
CondBuilder(Node* if_op, const string& then_fn_name,
- const string& else_fn_name, Graph* graph);
+ const string& else_fn_name, const FunctionLibraryDefinition& flib,
+ Graph* graph);
// Constructs the basic conditional control flow using switch and merge nodes.
Status CreatePivotNodes();
@@ -89,6 +90,7 @@ class CondBuilder {
Node* then_call_node_;
Node* else_call_node_;
Graph* graph_;
+ const FunctionLibraryDefinition& flib_;
string name_;
NodeBuilder then_call_builder_;
@@ -96,9 +98,11 @@ class CondBuilder {
};
CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
- const string& else_fn_name, Graph* graph)
+ const string& else_fn_name,
+ const FunctionLibraryDefinition& flib, Graph* graph)
: if_op_(if_op),
graph_(graph),
+ flib_(flib),
name_(if_op->name()),
then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()),
else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) {
@@ -193,15 +197,15 @@ Status CondBuilder::AddOutputs() {
return Status::OK();
}
-Status InlineCallInGraph(Node* n, Graph* g) {
- const auto& lib = g->flib_def();
- const FunctionDef* fdef = lib.Find(n->type_string());
+Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib,
+ Graph* g) {
+ const FunctionDef* fdef = flib.Find(n->type_string());
CHECK(fdef != nullptr);
FunctionBody* fbody;
TF_RETURN_IF_ERROR(
- FunctionDefToBodyHelper(*fdef, n->attrs(), &lib,
- [&lib](const string& op, const OpDef** sig) {
- return lib.LookUpOpDef(op, sig);
+ FunctionDefToBodyHelper(*fdef, n->attrs(), &flib,
+ [&flib](const string& op, const OpDef** sig) {
+ return flib.LookUpOpDef(op, sig);
},
&fbody));
// TODO(jpienaar): Improve this interface to make the need to delete it
@@ -219,8 +223,8 @@ Status CondBuilder::BuildLoweredIfOutput() {
}
Status CondBuilder::InlineCallNodes() {
- TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, graph_));
- TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, graph_));
+ TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, flib_, graph_));
+ TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, flib_, graph_));
return Status::OK();
}
@@ -240,6 +244,12 @@ Status LowerIfOpPass::Run(const GraphOptimizationPassOptions& options) {
return errors::Internal("Lowering If op requires a graph to be available.");
}
+ FunctionLibraryDefinition* flib = options.flib_def;
+ if (flib == nullptr) {
+ return errors::Internal(
+ "Lowering If op requires a FunctionLibraryDefinition to be available.");
+ }
+
// Match all the nodes that need to be rewritten.
gtl::InlinedVector<Node*, 2> matches;
for (Node* n : g->op_nodes()) {
@@ -251,12 +261,14 @@ Status LowerIfOpPass::Run(const GraphOptimizationPassOptions& options) {
}
}
for (Node* n : matches) {
- TF_RETURN_IF_ERROR(RewriteNode(n, g));
+ TF_RETURN_IF_ERROR(RewriteNode(n, *flib, g));
}
return Status::OK();
}
-Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) {
+Status LowerIfOpPass::RewriteNode(Node* n,
+ const FunctionLibraryDefinition& flib,
+ Graph* g) {
const AttrValue* then_attr = n->attrs().Find("then_branch");
if (then_attr == nullptr) {
return errors::InvalidArgument("Then branch function missing");
@@ -266,7 +278,8 @@ Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) {
return errors::InvalidArgument("Else branch function missing");
}
- CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), g);
+ CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), flib,
+ g);
TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
TF_RETURN_IF_ERROR(cb.AddInputs());
TF_RETURN_IF_ERROR(cb.AddOutputs());
diff --git a/tensorflow/core/common_runtime/lower_if_op.h b/tensorflow/core/common_runtime/lower_if_op.h
index a9ef39ae5c..5ab1123e3f 100644
--- a/tensorflow/core/common_runtime/lower_if_op.h
+++ b/tensorflow/core/common_runtime/lower_if_op.h
@@ -29,8 +29,9 @@ class LowerIfOpPass : public GraphOptimizationPass {
Status Run(const GraphOptimizationPassOptions& options) override;
private:
- // Rewrite the given If node `n` in graph `g` to use the switch-merge form.
- Status RewriteNode(Node* n, Graph* g);
+ // Rewrite the given If node `n` in graph `g` to use the switch-merge
+ // form. `flib` should contain the branch functions referenced by `n`.
+ Status RewriteNode(Node* n, const FunctionLibraryDefinition& flib, Graph* g);
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc
index 319a617b32..044a355d06 100644
--- a/tensorflow/core/common_runtime/lower_if_op_test.cc
+++ b/tensorflow/core/common_runtime/lower_if_op_test.cc
@@ -36,9 +36,7 @@ namespace tensorflow {
namespace {
Status Rewrite(std::unique_ptr<Graph>* graph) {
- FunctionDefLibrary flib;
- FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
-
+ FunctionLibraryDefinition flib_def((*graph)->flib_def());
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.flib_def = &flib_def;
diff --git a/tensorflow/core/common_runtime/lower_while_op.cc b/tensorflow/core/common_runtime/lower_while_op.cc
new file mode 100644
index 0000000000..1f5da133e9
--- /dev/null
+++ b/tensorflow/core/common_runtime/lower_while_op.cc
@@ -0,0 +1,427 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/lower_while_op.h"
+#include "tensorflow/core/common_runtime/lower_if_op.h"
+
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+
+namespace tensorflow {
+
+namespace {
+
+using NodeOut = NodeBuilder::NodeOut;
+
+// Helper to convert a functional While op to its lowered form.
+//
+// Example:
+//
+// Input graph:
+//
+// loop_var -> WhileOp<cond_func, body_func> -> consumer
+//
+// Output graph (top-to-bottom flow):
+//
+// loop_var
+// |
+// Enter
+// |
+// inlined_cond_func ---<--- Merge -----<----- NextIteration
+// | | |
+// V V ^
+// | | |
+// LoopCond ------>-------- Switch ---->---- inlined_body_func
+// |
+// Exit
+// |
+// consumer
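For orientation, here is a minimal sketch of the kind of node this pass rewrites: a functional `While` node whose `cond`/`body` attrs name functions assumed to already be registered in the graph's function library, marked with the same lowering attribute that `LowerWhileOpPass::Run` checks for below. The function names, node name, and `loop_vars` inputs are placeholders, not part of the original change.

    // Sketch only: "MyCond"/"MyBody" are hypothetical functions assumed to be in
    // the graph's FunctionLibraryDefinition, and loop_vars is an assumed
    // std::vector<NodeBuilder::NodeOut> holding the initial loop variables.
    NameAttrList cond_fn, body_fn;
    cond_fn.set_name("MyCond");
    body_fn.set_name("MyBody");
    Node* while_node;
    TF_CHECK_OK(NodeBuilder("my_while", "While", graph->op_registry())
                    .Input(loop_vars)
                    .Attr("cond", cond_fn)
                    .Attr("body", body_fn)
                    .Attr(LowerIfOpPass::kLowerUsingSwitchMergeAttr, true)
                    .Finalize(graph, &while_node));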
+class LowerWhileHelper {
+ public:
+ static Status Run(Node* while_op, const string& cond_fn_name,
+ const string& body_fn_name, Graph* graph) {
+ LowerWhileHelper helper(while_op, cond_fn_name, body_fn_name, graph);
+ return helper.RunInternal();
+ }
+
+ private:
+ // Create a LowerWhileHelper to create the lowered form of the While op whose
+ // cond and body functions are named `cond_fn_name` and `body_fn_name`
+ // respectively in the given graph.
+ LowerWhileHelper(Node* while_op, const string& cond_fn_name,
+ const string& body_fn_name, Graph* graph);
+
+ Status RunInternal();
+
+ // Creates an Enter node for each `while_op_` input and adds them to
+ // `enter_nodes_`. If the `while_op_` has an incoming control edge from a
+ // `src` node we add a control edge from `src` to each Enter node.
+ Status CreateEnterNodes();
+
+ // Creates a Merge node for each Enter node and adds to `merge_nodes_`.
+ // Initially now both inputs of a Merge node are the Enter node. Input at
+ // index 1 is later updated to the output of NextIteration node in
+ // `UpdateMergeNodes`.
+ Status CreateMergeNodes();
+
+ // Creates the call node for cond func and stores in `cond_call_node_`.
+ // This gets inlined later in `InlineCallNodes`.
+ Status CreateCondFuncCallNode();
+
+ // Creates a Switch node for each loop var and adds to `switch_nodes_`.
+ // Output at index 1 (true) of a Switch node is fed into the loop body.
+ // Output at index 0 (false) of a Switch node is fed into the Exit nodes.
+ Status CreateSwitchNodes();
+
+ // Creates the call node for body func and stores in `body_call_node_`.
+ // This gets inlined later in `InlineCallNodes`.
+ Status CreateBodyFuncCallNode();
+
+ // Creates an Exit node for each loop var and adds to `exit_nodes_`. These
+ // are fed into the consumers of the `while_op_`.
+ Status CreateExitNodes();
+
+ // Creates a NextIteration node for each loop var and adds it to
+ // `next_iterations_nodes_`.
+ Status CreateNextIterationNodes();
+
+ // Updates input at index 1 of each merge node created in `CreateMergeNodes`
+ // to use the output of NextIteration node created in
+ // `CreateNextIterationNodes` instead.
+ Status UpdateMergeNodes();
+
+ // Updates consumers of the original `while_op_` to instead use the outputs
+ // from the exit nodes in `exit_nodes_`. Also updates any outgoing control
+ // edges to depend on `lowered_while_output_` instead.
+ Status UpdateConsumers();
+
+ // Inlines the cond and body functions.
+ Status InlineCallNodes();
+
+ // Returns a unique name containing the name of the While op being rewritten
+ // (`name_`), the given infix, and a suffix to ensure it is unique within the
+ // graph.
+ string NewName(const string& infix);
+
+ // The original While op.
+ Node* while_op_;
+ // The call node for the cond branch. This gets inlined.
+ Node* cond_call_node_;
+ // The LoopCond node specifying the loop termination condition.
+ Node* loop_cond_node_;
+ // The call node for the body branch. This gets inlined.
+ Node* body_call_node_;
+ // The IdentityN node with the same outputs as the original While op.
+ Node* lowered_while_output_;
+ Graph* graph_;
+ // Name of the `while_op_`.
+ string name_;
+
+ NodeBuilder cond_call_builder_;
+ NodeBuilder body_call_builder_;
+
+ std::vector<Node*> enter_nodes_;
+ std::vector<Node*> merge_nodes_;
+ std::vector<Node*> switch_nodes_;
+ std::vector<Node*> exit_nodes_;
+ std::vector<Node*> next_iterations_nodes_;
+
+ size_t num_loop_inputs_;
+};
+
+LowerWhileHelper::LowerWhileHelper(Node* while_op, const string& cond_fn_name,
+ const string& body_fn_name, Graph* graph)
+ : while_op_(while_op),
+ graph_(graph),
+ name_(while_op->name()),
+ cond_call_builder_(NewName("cond"), cond_fn_name, graph->op_registry()),
+ body_call_builder_(NewName("body"), body_fn_name, graph->op_registry()),
+ num_loop_inputs_(while_op_->num_inputs()) {
+ // We intentionally `resize` instead of `reserve` space in `enter_nodes_`
+ // because we need to set its elements out of order in `CreateEnterNodes`.
+ enter_nodes_.resize(num_loop_inputs_);
+ merge_nodes_.reserve(num_loop_inputs_);
+ switch_nodes_.reserve(num_loop_inputs_);
+ exit_nodes_.reserve(num_loop_inputs_);
+ next_iterations_nodes_.reserve(num_loop_inputs_);
+}
+
+Status LowerWhileHelper::RunInternal() {
+ TF_RETURN_IF_ERROR(CreateEnterNodes());
+ TF_RETURN_IF_ERROR(CreateMergeNodes());
+ TF_RETURN_IF_ERROR(CreateCondFuncCallNode());
+ TF_RETURN_IF_ERROR(CreateSwitchNodes());
+ TF_RETURN_IF_ERROR(CreateBodyFuncCallNode());
+ TF_RETURN_IF_ERROR(CreateExitNodes());
+ TF_RETURN_IF_ERROR(CreateNextIterationNodes());
+ TF_RETURN_IF_ERROR(UpdateMergeNodes());
+ TF_RETURN_IF_ERROR(UpdateConsumers());
+ TF_RETURN_IF_ERROR(InlineCallNodes());
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateEnterNodes() {
+ // Note: `Node::input_edge` runs in O(num_inputs), so we use
+ // `Node::input_edges` instead so that the loop below runs in O(num_inputs)
+ // time and not O(num_inputs^2).
+ std::vector<const Edge*> edges;
+ TF_RETURN_IF_ERROR(while_op_->input_edges(&edges));
+ for (const Edge* edge : edges) {
+ Node* enter_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("enter"), "Enter", graph_->op_registry())
+ .Input(NodeOut(edge->src(), edge->src_output()))
+ .Attr("frame_name", name_)
+ .Finalize(graph_, &enter_node));
+ enter_nodes_[edge->dst_input()] = enter_node;
+ }
+ // Create a NoOp node that takes incoming control inputs of the original While
+ // op as control inputs and use it as a control input for all Enter nodes.
+ std::vector<Node*> control_inputs;
+ for (const Edge* e : while_op_->in_edges()) {
+ if (e->IsControlEdge()) {
+ control_inputs.push_back(e->src());
+ }
+ }
+ if (!control_inputs.empty()) {
+ Node* incoming_control_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("LoopControlInputs"), "NoOp", graph_->op_registry())
+ .ControlInputs(control_inputs)
+ .Finalize(graph_, &incoming_control_node));
+ for (Node* n : enter_nodes_) {
+ graph_->AddControlEdge(incoming_control_node, n);
+ }
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateMergeNodes() {
+ for (Node* enter_node : enter_nodes_) {
+ Node* merge_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("merge"), "Merge", graph_->op_registry())
+ .Input({NodeOut(enter_node, 0), NodeOut(enter_node, 0)})
+ .Finalize(graph_, &merge_node));
+ merge_nodes_.emplace_back(merge_node);
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateCondFuncCallNode() {
+ for (Node* merge_node : merge_nodes_) {
+ cond_call_builder_.Input(NodeOut(merge_node, 0));
+ }
+ TF_RETURN_IF_ERROR(cond_call_builder_.Finalize(graph_, &cond_call_node_));
+ // Add a control edge to make sure the Const nodes in the cond function
+ // are in the same frame as the rest of the function, otherwise
+ // `BuildControlFlowInfo` throws an error.
+ graph_->AddControlEdge(merge_nodes_[0], cond_call_node_);
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("LoopCond"), "LoopCond", graph_->op_registry())
+ .Input(NodeOut(cond_call_node_, 0))
+ .Finalize(graph_, &loop_cond_node_));
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateSwitchNodes() {
+ for (int i = 0; i < num_loop_inputs_; i++) {
+ string op_name;
+ {
+ const Node* input_node;
+ TF_RETURN_IF_ERROR(while_op_->input_node(i, &input_node));
+ op_name = strings::StrCat(input_node->name(), "_switch");
+ }
+ Node* switch_node;
+ string op_type = "Switch";
+ if (IsRefType(merge_nodes_[i]->output_type(0))) {
+ op_type = "RefSwitch";
+ }
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName(op_name), op_type, graph_->op_registry())
+ .Input(NodeOut(merge_nodes_[i], 0))
+ .Input(NodeOut(loop_cond_node_, 0))
+ .Finalize(graph_, &switch_node));
+ switch_nodes_.emplace_back(switch_node);
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateBodyFuncCallNode() {
+ for (Node* switch_node : switch_nodes_) {
+ body_call_builder_.Input(NodeOut(switch_node, 1));
+ }
+ TF_RETURN_IF_ERROR(body_call_builder_.Finalize(graph_, &body_call_node_));
+ // Add a control edge to make sure the Const nodes in the body function
+ // are in the same frame as the rest of the function, otherwise
+ // `BuildControlFlowInfo` throws an error.
+ // TODO(srbs): The choice of input at index 0 seems arbitrary (is it?);
+ // however, this is how tf.while_loop does it. Can this affect performance if
+ // the 0th node is not the first one to be ready? Can we speed that case up
+ // using some sort of multi-input Merge?
+ Node* body_control_node_;
+ string op_type = "Identity";
+ if (IsRefType(switch_nodes_[0]->output_type(1))) {
+ op_type = "RefIdentity";
+ }
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("loop_body_control"), op_type, graph_->op_registry())
+ .Input(NodeOut(switch_nodes_[0], 1))
+ .Finalize(graph_, &body_control_node_));
+ graph_->AddControlEdge(body_control_node_, body_call_node_);
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateExitNodes() {
+ std::vector<NodeOut> outputs;
+ outputs.reserve(num_loop_inputs_);
+ for (Node* switch_node : switch_nodes_) {
+ Node* exit_node;
+ TF_RETURN_IF_ERROR(
+ NodeBuilder(NewName("exit"), "Exit", graph_->op_registry())
+ .Input(NodeOut(switch_node, 0))
+ .Finalize(graph_, &exit_node));
+ exit_nodes_.emplace_back(exit_node);
+ outputs.emplace_back(NodeOut(exit_node, 0));
+ }
+
+ // Add an IdentityN node that has the same outputs and same name as the
+ // original functional While op. This is used for:
+ // 1. Rewiring the control edges that have the original While op as src.
+ // 2. Fetching the output of the While node by name in calls to sess.run.
+ NodeBuilder ib(name_, "IdentityN");
+ ib.Input(outputs);
+ TF_RETURN_IF_ERROR(ib.Finalize(graph_, &lowered_while_output_));
+ return Status::OK();
+}
+
+Status LowerWhileHelper::CreateNextIterationNodes() {
+ for (int i = 0; i < num_loop_inputs_; i++) {
+ Node* next_iteration;
+ TF_RETURN_IF_ERROR(NodeBuilder(NewName("next_iteration"), "NextIteration",
+ graph_->op_registry())
+ .Input(NodeOut(body_call_node_, i))
+ .Finalize(graph_, &next_iteration));
+ next_iterations_nodes_.emplace_back(next_iteration);
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::UpdateMergeNodes() {
+ for (int i = 0; i < num_loop_inputs_; i++) {
+ TF_RETURN_IF_ERROR(
+ graph_->UpdateEdge(next_iterations_nodes_[i], 0, merge_nodes_[i], 1));
+ }
+ return Status::OK();
+}
+
+Status LowerWhileHelper::UpdateConsumers() {
+ for (const Edge* e : while_op_->out_edges()) {
+ if (e->IsControlEdge()) {
+ graph_->AddControlEdge(lowered_while_output_, e->dst());
+ } else {
+ // Feed the outputs directly from the exit nodes so that downstream ops
+ // can start before all the outputs have been computed.
+ graph_->AddEdge(exit_nodes_[e->src_output()], 0, e->dst(),
+ e->dst_input());
+ }
+ }
+ return Status::OK();
+}
+
+string LowerWhileHelper::NewName(const string& infix) {
+ return graph_->NewName(strings::StrCat(name_, "/", infix));
+}
+
+Status InlineCallInGraph(Node* n, Graph* g) {
+ const auto& lib = g->flib_def();
+ const FunctionDef* fdef = lib.Find(n->type_string());
+ CHECK(fdef != nullptr);
+ FunctionBody* fbody;
+ TF_RETURN_IF_ERROR(
+ FunctionDefToBodyHelper(*fdef, n->attrs(), &lib,
+ [&lib](const string& op, const OpDef** sig) {
+ return lib.LookUpOpDef(op, sig);
+ },
+ &fbody));
+ // TODO(jpienaar): Improve this interface to make the need to delete it
+ // explicit.
+ InlineFunctionBody(g->flib_def(), g, n, fbody, false);
+ delete fbody;
+ return Status::OK();
+}
+
+Status LowerWhileHelper::InlineCallNodes() {
+ TF_RETURN_IF_ERROR(InlineCallInGraph(cond_call_node_, graph_));
+ TF_RETURN_IF_ERROR(InlineCallInGraph(body_call_node_, graph_));
+ return Status::OK();
+}
+
+} // namespace
+
+Status LowerWhileOpPass::Run(const GraphOptimizationPassOptions& options) {
+ if (options.partition_graphs != nullptr) {
+ return errors::Internal(
+ "Lowering While op should happen before partitioning.");
+ }
+ if (options.graph == nullptr) {
+ return Status::OK();
+ }
+
+ Graph* g = options.graph->get();
+ if (g == nullptr) {
+ return errors::Internal(
+ "Lowering While op requires a graph to be available.");
+ }
+
+ // Match all the nodes that need to be rewritten.
+ gtl::InlinedVector<Node*, 2> matches;
+ for (Node* n : g->op_nodes()) {
+ if (n->type_string() == "While") {
+ // Only rewrite if the While op is marked as needing to be lowered.
+ bool match;
+ Status s = GetNodeAttr(n->attrs(),
+ LowerIfOpPass::kLowerUsingSwitchMergeAttr, &match);
+ if (s.ok() && match) matches.push_back(n);
+ }
+ }
+ for (Node* n : matches) {
+ TF_RETURN_IF_ERROR(RewriteNode(n, g));
+ }
+ return Status::OK();
+}
+
+Status LowerWhileOpPass::RewriteNode(Node* n, Graph* g) {
+ const AttrValue* cond_attr = n->attrs().Find("cond");
+ if (cond_attr == nullptr) {
+ return errors::InvalidArgument("While cond function missing");
+ }
+ const AttrValue* body_attr = n->attrs().Find("body");
+ if (body_attr == nullptr) {
+ return errors::InvalidArgument("While body function missing");
+ }
+
+ TF_RETURN_IF_ERROR(LowerWhileHelper::Run(n, cond_attr->func().name(),
+ body_attr->func().name(), g));
+ g->RemoveNode(n);
+
+ return Status::OK();
+}
+
+REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 0,
+ LowerWhileOpPass);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/util/status_util.h b/tensorflow/core/common_runtime/lower_while_op.h
index ea92f61dce..eadafbeb91 100644
--- a/tensorflow/core/util/status_util.h
+++ b/tensorflow/core/common_runtime/lower_while_op.h
@@ -13,24 +13,25 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
-#define TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
-#include "tensorflow/core/graph/graph.h"
-#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-// Creates a tag to be used in an exception error message. This can be parsed by
-// the Python layer and replaced with information about the node.
-//
-// For example, error_format_tag(node, "${file}") returns
-// "^^node:NODE_NAME:${line}^^" which would be rewritten by the Python layer as
-// e.g. "file/where/node/was/created.py".
-inline string error_format_tag(const Node& node, const string& format) {
- return strings::StrCat("^^node:", node.name(), ":", format, "^^");
-}
+// Rewrite While ops to use lower level control flow primitives instead.
+class LowerWhileOpPass : public GraphOptimizationPass {
+ public:
+ Status Run(const GraphOptimizationPassOptions& options) override;
+
+ private:
+ // Rewrite the given While node `n` in graph `g` to use the lower level
+ // primitives Enter, Exit, Switch, Merge and NextIteration.
+ Status RewriteNode(Node* n, Graph* g);
+};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_UTIL_STATUS_UTIL_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_LOWER_WHILE_OP_H_
diff --git a/tensorflow/core/common_runtime/lower_while_op_test.cc b/tensorflow/core/common_runtime/lower_while_op_test.cc
new file mode 100644
index 0000000000..27cbada004
--- /dev/null
+++ b/tensorflow/core/common_runtime/lower_while_op_test.cc
@@ -0,0 +1,249 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/lower_while_op.h"
+#include "tensorflow/core/common_runtime/lower_if_op.h"
+
+#include "tensorflow/cc/client/client_session.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/ops/array_ops.h"
+#include "tensorflow/cc/ops/control_flow_ops_internal.h"
+#include "tensorflow/cc/ops/function_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/graph_runner.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/graph_def_builder_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+Status Rewrite(std::unique_ptr<Graph>* graph) {
+ FunctionDefLibrary flib;
+ FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
+
+ GraphOptimizationPassOptions opt_options;
+ opt_options.graph = graph;
+ opt_options.flib_def = &flib_def;
+ LowerWhileOpPass pass;
+ return pass.Run(opt_options);
+}
+
+TEST(LowerWhileOpTest, Simple) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ // Add test functions for cond and body.
+ FunctionDefLibrary f_lib_proto;
+ *f_lib_proto.add_function() = test::function::XTimesTwo();
+ *f_lib_proto.add_function() = test::function::LessThanOrEqualToN(8);
+ FunctionLibraryDefinition f_lib(OpRegistry::Global(), f_lib_proto);
+
+ Scope root = Scope::NewRootScope().ExitOnError();
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
+ auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
+ Node* while_node;
+ std::vector<NodeBuilder::NodeOut> inputs({NodeBuilder::NodeOut(a.node())});
+ AttrValue cond_func;
+ cond_func.mutable_func()->set_name("LessThanOrEqualToN");
+ AttrValue body_func;
+ body_func.mutable_func()->set_name("XTimesTwo");
+ TF_ASSERT_OK(NodeBuilder("while", "While", &f_lib)
+ .Input(inputs)
+ .Attr("T", {DT_INT32})
+ .Attr("cond", cond_func)
+ .Attr("body", body_func)
+ .Attr(LowerIfOpPass::kLowerUsingSwitchMergeAttr, true)
+ .Finalize(root.graph(), &while_node));
+ TF_ASSERT_OK(root.DoShapeInference(while_node));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ // The input graph has no lower level control flow primitives.
+ int node_called_while_count = 0;
+ for (const auto* op : graph->op_nodes()) {
+ ASSERT_FALSE(op->IsEnter());
+ ASSERT_FALSE(op->IsExit());
+ ASSERT_FALSE(op->IsSwitch());
+ ASSERT_FALSE(op->IsMerge());
+ ASSERT_FALSE(op->IsNextIteration());
+ ASSERT_FALSE(op->IsLoopCond());
+ if (op->name() == "while") {
+ node_called_while_count++;
+ }
+ }
+ ASSERT_EQ(node_called_while_count, 1);
+
+ TF_ASSERT_OK(Rewrite(&graph));
+
+ int enter_count = 0;
+ int exit_count = 0;
+ int switch_count = 0;
+ int merge_count = 0;
+ int next_iteration_count = 0;
+ node_called_while_count = 0;
+ for (const auto* op : graph->op_nodes()) {
+ if (op->IsEnter()) {
+ ++enter_count;
+ }
+ if (op->IsExit()) {
+ ++exit_count;
+ }
+ if (op->IsSwitch()) {
+ ++switch_count;
+ }
+ if (op->IsMerge()) {
+ ++merge_count;
+ }
+ if (op->IsNextIteration()) {
+ ++next_iteration_count;
+ }
+ if (op->name() == "while") {
+ node_called_while_count++;
+ }
+ ASSERT_NE(op->type_string(), "While");
+ }
+ // One node per loop input.
+ ASSERT_EQ(enter_count, 1);
+ ASSERT_EQ(exit_count, 1);
+ ASSERT_EQ(switch_count, 1);
+ ASSERT_EQ(merge_count, 1);
+ ASSERT_EQ(next_iteration_count, 1);
+ ASSERT_EQ(node_called_while_count, 1);
+
+ // Verify execution.
+ ClientSession session(root);
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(a.node()), Input::Initializer(1));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(feeds, {Output(while_node)}, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), 1);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 16);
+ }
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(a.node()), Input::Initializer(3));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(feeds, {Output(while_node)}, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), 1);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 12);
+ }
+}
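Note on the expected values: the loop applies the body (XTimesTwo) while the condition (LessThanOrEqualToN(8)) holds, so 1 doubles through 2, 4, 8 to 16 and 3 doubles through 6 to 12. A plain C++ trace of that semantics, illustrative only and not part of the patch:

```cpp
// Trace of what the lowered loop computes in the test above.
// cond: LessThanOrEqualToN(8); body: XTimesTwo.
int RunSimpleLoop(int x) {
  while (x <= 8) {
    x *= 2;
  }
  return x;  // RunSimpleLoop(1) == 16, RunSimpleLoop(3) == 12
}
```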
+
+TEST(LowerWhileOpTest, MultipleInputs) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+
+ // Add test functions for cond and body.
+ FunctionDefLibrary f_lib_proto;
+ *(f_lib_proto.add_function()) = test::function::XPlusOneXTimesY();
+ *(f_lib_proto.add_function()) = test::function::XYXLessThanOrEqualToN(4);
+ FunctionLibraryDefinition f_lib(OpRegistry::Global(), f_lib_proto);
+
+ Scope root = Scope::NewRootScope().ExitOnError();
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(f_lib_proto));
+ auto a = ops::_Arg(root.WithOpName("A"), DT_INT32, 0);
+ auto b = ops::_Arg(root.WithOpName("B"), DT_INT32, 1);
+ Node* while_node;
+ std::vector<NodeBuilder::NodeOut> inputs(
+ {NodeBuilder::NodeOut(a.node()), NodeBuilder::NodeOut(b.node())});
+ AttrValue cond_func;
+ cond_func.mutable_func()->set_name("XYXLessThanOrEqualToN");
+ AttrValue body_func;
+ body_func.mutable_func()->set_name("XPlusOneXTimesY");
+ TF_ASSERT_OK(NodeBuilder("while", "While", &f_lib)
+ .Input(inputs)
+ .Attr("T", {DT_INT32, DT_INT32})
+ .Attr("cond", cond_func)
+ .Attr("body", body_func)
+ .Attr(LowerIfOpPass::kLowerUsingSwitchMergeAttr, true)
+ .Finalize(root.graph(), &while_node));
+ TF_ASSERT_OK(root.DoShapeInference(while_node));
+ TF_ASSERT_OK(root.ToGraph(graph.get()));
+
+ // The input graph has no lower level control flow primitives.
+ for (const auto* op : graph->op_nodes()) {
+ ASSERT_FALSE(op->IsEnter());
+ ASSERT_FALSE(op->IsExit());
+ ASSERT_FALSE(op->IsSwitch());
+ ASSERT_FALSE(op->IsMerge());
+ ASSERT_FALSE(op->IsNextIteration());
+ ASSERT_FALSE(op->IsLoopCond());
+ }
+
+ TF_ASSERT_OK(Rewrite(&graph));
+
+ int enter_count = 0;
+ int exit_count = 0;
+ int switch_count = 0;
+ int merge_count = 0;
+ int next_iteration_count = 0;
+ for (const auto* op : graph->op_nodes()) {
+ if (op->IsEnter()) {
+ ++enter_count;
+ }
+ if (op->IsExit()) {
+ ++exit_count;
+ }
+ if (op->IsSwitch()) {
+ ++switch_count;
+ }
+ if (op->IsMerge()) {
+ ++merge_count;
+ }
+ if (op->IsNextIteration()) {
+ ++next_iteration_count;
+ }
+ ASSERT_NE(op->type_string(), "While");
+ }
+ // Two nodes per loop input.
+ ASSERT_EQ(enter_count, 2);
+ ASSERT_EQ(exit_count, 2);
+ ASSERT_EQ(switch_count, 2);
+ ASSERT_EQ(merge_count, 2);
+ ASSERT_EQ(next_iteration_count, 2);
+
+ // Verify execution.
+ ClientSession session(root);
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(a.node()), Input::Initializer(1));
+ feeds.emplace(Output(b.node()), Input::Initializer(1));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(
+ feeds, {Output(while_node, 0), Output(while_node, 1)}, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), 2);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 5);
+ EXPECT_EQ(out_tensors[1].scalar<int>()(), 24);
+ }
+ {
+ ClientSession::FeedType feeds;
+ feeds.emplace(Output(a.node()), Input::Initializer(3));
+ feeds.emplace(Output(b.node()), Input::Initializer(5));
+ std::vector<Tensor> out_tensors;
+ TF_ASSERT_OK(session.Run(
+ feeds, {Output(while_node, 0), Output(while_node, 1)}, &out_tensors));
+ ASSERT_EQ(out_tensors.size(), 2);
+ EXPECT_EQ(out_tensors[0].scalar<int>()(), 5);
+ EXPECT_EQ(out_tensors[1].scalar<int>()(), 60);
+ }
+}
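Note on the expected pairs (5, 24) and (5, 60): they are consistent with a body that produces x+1 and x*y (using the pre-increment x) while the condition tests x <= 4. The real XPlusOneXTimesY / XYXLessThanOrEqualToN definitions live in function_testlib, so the sketch below is an inference from the asserted outputs, not a quote of those functions:

```cpp
#include <utility>

// Trace consistent with the asserted outputs above; body and condition are
// inferred, not copied from function_testlib.
std::pair<int, int> RunMultiInputLoop(int x, int y) {
  while (x <= 4) {        // assumed cond: x <= N with N == 4
    int new_y = x * y;    // assumed body: y' = x * y (old x) ...
    x = x + 1;            // ... and x' = x + 1
    y = new_y;
  }
  return {x, y};          // (1, 1) -> (5, 24); (3, 5) -> (5, 60)
}
```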
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator.h b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
index 99bd43e090..429b19599b 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator.h
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator.h
@@ -23,10 +23,11 @@ limitations under the License.
#include <cstdlib>
#include "tensorflow/core/common_runtime/bfc_allocator.h"
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/common_runtime/pool_allocator.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mem.h"
+#include "tensorflow/core/platform/numa.h"
#ifndef INTEL_MKL_DNN_ONLY
#include "i_malloc.h"
@@ -38,19 +39,113 @@ typedef unsigned int uint;
namespace tensorflow {
-class MklSubAllocator : public SubAllocator {
+class MklSubAllocator : public BasicCPUAllocator {
public:
+ MklSubAllocator() : BasicCPUAllocator(port::kNUMANoAffinity, {}, {}) {}
~MklSubAllocator() override {}
+};
+
+// CPU allocator that handles small-size allocations by calling the
+// suballocator directly. It is essentially a wrapper around a suballocator
+// (which calls malloc and free directly), with added bookkeeping.
+class MklSmallSizeAllocator : public Allocator {
+ public:
+ MklSmallSizeAllocator(SubAllocator* sub_allocator, size_t total_memory,
+ const string& name)
+ : sub_allocator_(sub_allocator), name_(name) {
+ stats_.bytes_limit = total_memory;
+ }
+ ~MklSmallSizeAllocator() override {}
+
+ TF_DISALLOW_COPY_AND_ASSIGN(MklSmallSizeAllocator);
- void* Alloc(size_t alignment, size_t num_bytes) override {
- return port::AlignedMalloc(num_bytes, alignment);
+ inline string Name() override { return name_; }
+
+ void* AllocateRaw(size_t alignment, size_t num_bytes) override {
+ void* ptr = sub_allocator_->Alloc(alignment, num_bytes);
+ if (ptr != nullptr) {
+ std::pair<void*, size_t> map_val(ptr, num_bytes);
+ mutex_lock l(mutex_);
+ // Check that insertion in the hash map was successful.
+ CHECK(map_.insert(map_val).second);
+ // Increment statistics for small-size allocations.
+ IncrementStats(num_bytes);
+ }
+ return ptr;
}
- void Free(void* ptr, size_t num_bytes) override { port::AlignedFree(ptr); }
+
+ void DeallocateRaw(void* ptr) override {
+ if (ptr == nullptr) {
+ LOG(ERROR) << "tried to deallocate nullptr";
+ return;
+ }
+
+ mutex_lock l(mutex_);
+ auto map_iter = map_.find(ptr);
+ if (map_iter != map_.end()) {
+      // Free the memory via the suballocator and update the stats.
+ size_t dealloc_bytes = map_iter->second;
+ sub_allocator_->Free(ptr, dealloc_bytes);
+ DecrementStats(dealloc_bytes);
+ map_.erase(map_iter);
+ } else {
+ LOG(ERROR) << "tried to deallocate invalid pointer";
+ return;
+ }
+ }
+
+ inline bool IsSmallSizeAllocation(const void* ptr) const {
+ mutex_lock l(mutex_);
+ return map_.find(ptr) != map_.end();
+ }
+
+ void GetStats(AllocatorStats* stats) override {
+ mutex_lock l(mutex_);
+ *stats = stats_;
+ }
+
+ void ClearStats() override {
+ mutex_lock l(mutex_);
+ stats_.Clear();
+ }
+
+ private:
+ // Increment statistics for the allocator handling small allocations.
+ inline void IncrementStats(size_t alloc_size)
+ EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+ ++stats_.num_allocs;
+ stats_.bytes_in_use += alloc_size;
+ stats_.max_bytes_in_use =
+ std::max(stats_.max_bytes_in_use, stats_.bytes_in_use);
+ stats_.max_alloc_size =
+ std::max(alloc_size, static_cast<size_t>(stats_.max_alloc_size));
+ }
+
+ // Decrement statistics for the allocator handling small allocations.
+ inline void DecrementStats(size_t dealloc_size)
+ EXCLUSIVE_LOCKS_REQUIRED(mutex_) {
+ stats_.bytes_in_use -= dealloc_size;
+ }
+
+ SubAllocator* sub_allocator_; // Not owned by this class.
+
+ // Mutex for protecting updates to map of allocations.
+ mutable mutex mutex_;
+
+ // Allocator name
+ string name_;
+
+  // Hash map to keep track of "small" allocations.
+  // We do not use the BFC allocator for small allocations.
+ std::unordered_map<const void*, size_t> map_ GUARDED_BY(mutex_);
+
+ // Allocator stats for small allocs
+ AllocatorStats stats_ GUARDED_BY(mutex_);
};
/// CPU allocator for MKL that wraps BFC allocator and intercepts
/// and redirects memory allocation calls from MKL.
-class MklCPUAllocator : public VisitableAllocator {
+class MklCPUAllocator : public Allocator {
public:
// Constructor and other standard functions
@@ -62,7 +157,10 @@ class MklCPUAllocator : public VisitableAllocator {
MklCPUAllocator() { TF_CHECK_OK(Initialize()); }
- ~MklCPUAllocator() override { delete allocator_; }
+ ~MklCPUAllocator() override {
+ delete small_size_allocator_;
+ delete large_size_allocator_;
+ }
Status Initialize() {
VLOG(2) << "MklCPUAllocator: In MklCPUAllocator";
@@ -96,8 +194,15 @@ class MklCPUAllocator : public VisitableAllocator {
}
VLOG(1) << "MklCPUAllocator: Setting max_mem_bytes: " << max_mem_bytes;
- allocator_ = new BFCAllocator(new MklSubAllocator, max_mem_bytes,
- kAllowGrowth, kName);
+
+ sub_allocator_ = new MklSubAllocator();
+
+ // SubAllocator is owned by BFCAllocator, so we do not need to deallocate
+ // it in MklSmallSizeAllocator.
+ small_size_allocator_ =
+ new MklSmallSizeAllocator(sub_allocator_, max_mem_bytes, kName);
+ large_size_allocator_ =
+ new BFCAllocator(sub_allocator_, max_mem_bytes, kAllowGrowth, kName);
#ifndef INTEL_MKL_DNN_ONLY
// For redirecting all allocations from MKL to this allocator
// From: http://software.intel.com/en-us/node/528565
@@ -112,23 +217,46 @@ class MklCPUAllocator : public VisitableAllocator {
inline string Name() override { return kName; }
inline void* AllocateRaw(size_t alignment, size_t num_bytes) override {
- return allocator_->AllocateRaw(alignment, num_bytes);
+    // If the allocation size is less than the threshold, use the small-size
+    // allocator; otherwise use the large-size (BFC) allocator. We found that
+    // the BFC allocator does not deliver good performance for small
+    // allocations when inter_op_parallelism_threads is high.
+ return (num_bytes < kSmallAllocationsThreshold)
+ ? small_size_allocator_->AllocateRaw(alignment, num_bytes)
+ : large_size_allocator_->AllocateRaw(alignment, num_bytes);
}
inline void DeallocateRaw(void* ptr) override {
- allocator_->DeallocateRaw(ptr);
+    // Check if ptr is a "small" allocation. If it is, free it via the
+    // small-size allocator directly; otherwise let BFC handle the free.
+ if (small_size_allocator_->IsSmallSizeAllocation(ptr)) {
+ small_size_allocator_->DeallocateRaw(ptr);
+ } else {
+ large_size_allocator_->DeallocateRaw(ptr);
+ }
}
- void GetStats(AllocatorStats* stats) override { allocator_->GetStats(stats); }
-
- void ClearStats() override { allocator_->ClearStats(); }
-
- void AddAllocVisitor(Visitor visitor) override {
- allocator_->AddAllocVisitor(visitor);
+ void GetStats(AllocatorStats* stats) override {
+ AllocatorStats l_stats, s_stats;
+ small_size_allocator_->GetStats(&s_stats);
+ large_size_allocator_->GetStats(&l_stats);
+
+    // Combine statistics from the small-size and large-size allocators.
+ stats->num_allocs = l_stats.num_allocs + s_stats.num_allocs;
+ stats->bytes_in_use = l_stats.bytes_in_use + s_stats.bytes_in_use;
+ stats->max_bytes_in_use =
+ l_stats.max_bytes_in_use + s_stats.max_bytes_in_use;
+
+ // Since small-size allocations go to MklSmallSizeAllocator,
+ // max_alloc_size from large_size_allocator would be the maximum
+ // size allocated by MklCPUAllocator.
+ stats->max_alloc_size = l_stats.max_alloc_size;
+ stats->bytes_limit = std::max(s_stats.bytes_limit, l_stats.bytes_limit);
}
- void AddFreeVisitor(Visitor visitor) override {
- allocator_->AddFreeVisitor(visitor);
+ void ClearStats() override {
+ small_size_allocator_->ClearStats();
+ large_size_allocator_->ClearStats();
}
private:
@@ -148,24 +276,36 @@ class MklCPUAllocator : public VisitableAllocator {
Status s = Status(error::Code::UNIMPLEMENTED,
"Unimplemented case for hooking MKL function.");
TF_CHECK_OK(s); // way to assert with an error message
+ return nullptr; // return a value and make static code analyzers happy
}
static inline void* ReallocHook(void* ptr, size_t size) {
Status s = Status(error::Code::UNIMPLEMENTED,
"Unimplemented case for hooking MKL function.");
TF_CHECK_OK(s); // way to assert with an error message
+ return nullptr; // return a value and make static code analyzers happy
}
- /// Do we allow growth in BFC Allocator
+ // Do we allow growth in BFC Allocator
static const bool kAllowGrowth = true;
- /// Name
+ // Name
static constexpr const char* kName = "mklcpu";
- /// The alignment that we need for the allocations
+ // The alignment that we need for the allocations
static constexpr const size_t kAlignment = 64;
- VisitableAllocator* allocator_; // owned by this class
+ Allocator* large_size_allocator_; // owned by this class
+ MklSmallSizeAllocator* small_size_allocator_; // owned by this class.
+
+ SubAllocator* sub_allocator_; // not owned by this class
+
+  // Size in bytes that defines the upper bound for "small" allocations.
+  // Any allocation below this threshold is a "small" allocation.
+ static constexpr const size_t kSmallAllocationsThreshold = 4096;
+
+ // Prevent copying and assignment
+ TF_DISALLOW_COPY_AND_ASSIGN(MklCPUAllocator);
};
} // namespace tensorflow
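A hypothetical usage sketch of the size-based dispatch above (illustrative only; the instance and wrapper function are made up): requests below kSmallAllocationsThreshold (4096 bytes) are served by the malloc-backed MklSmallSizeAllocator, larger ones by the BFC allocator, and DeallocateRaw routes each pointer back via IsSmallSizeAllocation.

```cpp
// Hypothetical usage; not part of the patch.
void MklAllocatorDispatchExample() {
  MklCPUAllocator mkl_alloc;
  // 1 KB is below kSmallAllocationsThreshold (4096): small-size path.
  void* small = mkl_alloc.AllocateRaw(64 /*alignment*/, 1024);
  // 64 KB is above the threshold: BFC (large-size) path.
  void* large = mkl_alloc.AllocateRaw(64 /*alignment*/, 64 * 1024);
  // DeallocateRaw dispatches on IsSmallSizeAllocation(ptr).
  mkl_alloc.DeallocateRaw(small);
  mkl_alloc.DeallocateRaw(large);
}
```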
diff --git a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
index a67411cd2e..e08ab57638 100644
--- a/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
+++ b/tensorflow/core/common_runtime/mkl_cpu_allocator_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/common_runtime/mkl_cpu_allocator.h"
@@ -50,4 +50,4 @@ TEST(MKLBFCAllocatorTest, TestMaxLimit) {
} // namespace tensorflow
-#endif // INTEL_MKL
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/common_runtime/optimization_registry.h b/tensorflow/core/common_runtime/optimization_registry.h
index f5d265aa24..6fcd2afd27 100644
--- a/tensorflow/core/common_runtime/optimization_registry.h
+++ b/tensorflow/core/common_runtime/optimization_registry.h
@@ -132,11 +132,12 @@ class OptimizationPassRegistration {
#define REGISTER_OPTIMIZATION_UNIQ_HELPER(ctr, grouping, phase, optimization) \
REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization)
-#define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) \
- static optimization_registration::OptimizationPassRegistration \
- register_optimization_##ctr( \
- grouping, phase, \
- std::unique_ptr<GraphOptimizationPass>(new optimization()), \
+#define REGISTER_OPTIMIZATION_UNIQ(ctr, grouping, phase, optimization) \
+ static ::tensorflow::optimization_registration::OptimizationPassRegistration \
+ register_optimization_##ctr( \
+ grouping, phase, \
+ ::std::unique_ptr<::tensorflow::GraphOptimizationPass>( \
+ new optimization()), \
#optimization)
} // namespace tensorflow
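The fully qualified ::tensorflow:: and ::std:: names presumably let REGISTER_OPTIMIZATION be invoked from outside namespace tensorflow. A hedged sketch of such a caller (NoopPass and my_plugin are made up for illustration, assuming optimization_registry.h is included):

```cpp
// Hypothetical pass registered from a foreign namespace; illustration only.
namespace my_plugin {

class NoopPass : public ::tensorflow::GraphOptimizationPass {
 public:
  ::tensorflow::Status Run(
      const ::tensorflow::GraphOptimizationPassOptions& options) override {
    return ::tensorflow::Status::OK();  // no-op
  }
};

// Compiles here because the macro now spells out ::tensorflow:: internally.
REGISTER_OPTIMIZATION(::tensorflow::OptimizationPassRegistry::POST_PLACEMENT, 1,
                      NoopPass);

}  // namespace my_plugin
```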
diff --git a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
index f9f36443a8..6af4ca4d96 100644
--- a/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
+++ b/tensorflow/core/common_runtime/parallel_concat_optimizer.cc
@@ -50,8 +50,8 @@ class ParallelConcatRemovePass : public GraphOptimizationPass {
}
for (Node* n : matches) {
AttrSlice n_attrs = n->attrs();
- auto base_make_node = [n, g, &n_attrs](const string& op,
- const string& name) {
+ auto base_make_node = [n, &n_attrs](const string& op,
+ const string& name) {
NodeBuilder node_builder(name, op);
node_builder.Device(n->requested_device());
string colo;
@@ -60,7 +60,7 @@ class ParallelConcatRemovePass : public GraphOptimizationPass {
}
return node_builder;
};
- auto make_node = [n, g, &n_attrs, &base_make_node](string op) {
+ auto make_node = [n, g, &base_make_node](string op) {
return base_make_node(
op, g->NewName(strings::StrCat(n->name(), "/Internal")));
};
diff --git a/tensorflow/core/common_runtime/placer.cc b/tensorflow/core/common_runtime/placer.cc
index d581f45a90..3b59995433 100644
--- a/tensorflow/core/common_runtime/placer.cc
+++ b/tensorflow/core/common_runtime/placer.cc
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/util/status_util.h"
namespace tensorflow {
@@ -255,9 +254,11 @@ class ColocationGraph {
old_root_member.device_name,
allow_soft_placement_);
if (!s.ok()) {
- return errors::InvalidArgument("Cannot colocate nodes '", x.name(),
- "' and '", y.name(), ": ",
- s.error_message());
+ return errors::InvalidArgument(
+ "Cannot colocate nodes ",
+ errors::FormatColocationNodeForError(x.name()), " and ",
+ errors::FormatColocationNodeForError(y.name()), ": ",
+ s.error_message());
}
// Ensure that the common root has at least one supported device
@@ -268,8 +269,10 @@ class ColocationGraph {
old_root_member.supported_device_types);
if (new_root_member.supported_device_types.empty()) {
return errors::InvalidArgument(
- "Cannot colocate nodes '", x.name(), "' and '", y.name(),
- "' because no device type supports both of those nodes and the "
+ "Cannot colocate nodes ",
+ errors::FormatColocationNodeForError(x.name()), " and ",
+ errors::FormatColocationNodeForError(y.name()),
+ " because no device type supports both of those nodes and the "
"other nodes colocated with them.",
DebugInfo(x_root), DebugInfo(y_root));
}
@@ -377,8 +380,9 @@ class ColocationGraph {
// merged set device is different, so print both.
return errors::InvalidArgument(
"Could not satisfy explicit device specification '",
- node->requested_device(),
- "' because the node was colocated with a group of nodes that "
+ node->requested_device(), "' because the node ",
+ errors::FormatColocationNodeForError(node->name()),
+ " was colocated with a group of nodes that ",
"required incompatible device '",
DeviceNameUtils::ParsedNameToString(
members_[node_root].device_name),
@@ -810,10 +814,10 @@ Status Placer::Run() {
std::vector<Device*>* devices;
Status status = colocation_graph.GetDevicesForNode(node, &devices);
if (!status.ok()) {
- return AttachDef(errors::InvalidArgument(
- "Cannot assign a device for operation ",
- RichNodeName(node), ": ", status.error_message()),
- *node);
+ return AttachDef(
+ errors::InvalidArgument("Cannot assign a device for operation ",
+ node->name(), ": ", status.error_message()),
+ *node);
}
// Returns the first device in sorted devices list so we will always
@@ -857,10 +861,10 @@ Status Placer::Run() {
std::vector<Device*>* devices;
Status status = colocation_graph.GetDevicesForNode(node, &devices);
if (!status.ok()) {
- return AttachDef(errors::InvalidArgument(
- "Cannot assign a device for operation ",
- RichNodeName(node), ": ", status.error_message()),
- *node);
+ return AttachDef(
+ errors::InvalidArgument("Cannot assign a device for operation ",
+ node->name(), ": ", status.error_message()),
+ *node);
}
int assigned_device = -1;
@@ -926,22 +930,4 @@ void Placer::LogDeviceAssignment(const Node* node) const {
}
}
-bool Placer::ClientHandlesErrorFormatting() const {
- return options_ != nullptr &&
- options_->config.experimental().client_handles_error_formatting();
-}
-
-// Returns the node name in single quotes. If the client handles formatted
-// errors, appends a formatting tag which the client will reformat into, for
-// example, " (defined at filename:123)".
-string Placer::RichNodeName(const Node* node) const {
- string quoted_name = strings::StrCat("'", node->name(), "'");
- if (ClientHandlesErrorFormatting()) {
- string file_and_line = error_format_tag(*node, "${defined_at}");
- return strings::StrCat(quoted_name, file_and_line);
- } else {
- return quoted_name;
- }
-}
-
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/placer.h b/tensorflow/core/common_runtime/placer.h
index fce87269c5..f97ffe7372 100644
--- a/tensorflow/core/common_runtime/placer.h
+++ b/tensorflow/core/common_runtime/placer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_PLACER_H_
-#define TENSORFLOW_COMMON_RUNTIME_PLACER_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_H_
#include <string>
#include <unordered_map>
@@ -87,8 +87,6 @@ class Placer {
// placement if the SessionOptions entry in 'options_' requests it.
void AssignAndLog(int assigned_device, Node* node) const;
void LogDeviceAssignment(const Node* node) const;
- bool ClientHandlesErrorFormatting() const;
- string RichNodeName(const Node* node) const;
Graph* const graph_; // Not owned.
const DeviceSet* const devices_; // Not owned.
@@ -100,4 +98,4 @@ class Placer {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_PLACER_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLACER_H_
diff --git a/tensorflow/core/common_runtime/placer_test.cc b/tensorflow/core/common_runtime/placer_test.cc
index 87f2f2ceb9..9b8a95e3b6 100644
--- a/tensorflow/core/common_runtime/placer_test.cc
+++ b/tensorflow/core/common_runtime/placer_test.cc
@@ -800,11 +800,11 @@ TEST_F(PlacerTest, TestInvalidMultipleColocationGroups) {
}
Status s = Place(&g);
- EXPECT_TRUE(
- str_util::StrContains(s.error_message(),
- "Cannot colocate nodes 'foo' and 'in' because no "
- "device type supports both of those nodes and the "
- "other nodes colocated with them"));
+ EXPECT_TRUE(str_util::StrContains(
+ s.error_message(),
+ "Cannot colocate nodes {{colocation_node foo}} and "
+ "{{colocation_node in}} because no device type supports both of those "
+ "nodes and the other nodes colocated with them"));
}
TEST_F(PlacerTest, TestColocationGroupWithReferenceConnections) {
@@ -867,9 +867,9 @@ TEST_F(PlacerTest, TestColocationGroupWithUnsatisfiableReferenceConnections) {
Status s = Place(&g);
EXPECT_TRUE(str_util::StrContains(
s.error_message(),
- "Cannot colocate nodes 'var3' and 'assign3' because no "
- "device type supports both of those nodes and the other "
- "nodes colocated with them."));
+ "Cannot colocate nodes {{colocation_node var3}} and {{colocation_node "
+ "assign3}} because no device type supports both of those nodes and the "
+ "other nodes colocated with them."));
}
TEST_F(PlacerTest, TestColocationAndReferenceConnections) {
@@ -1154,36 +1154,12 @@ TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementFormatTag) {
}
SessionOptions options;
- options.config.mutable_experimental()->set_client_handles_error_formatting(
- true);
Status s = Place(&g, &options);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
LOG(WARNING) << s.error_message();
EXPECT_TRUE(str_util::StrContains(s.error_message(),
- "Cannot assign a device for operation 'in'"
- "^^node:in:${defined_at}^^"));
-}
-
-// Test that the "Cannot assign a device" error message does not contain a
-// format tag when not it shouldn't
-TEST_F(PlacerTest, TestNonexistentGpuNoAllowSoftPlacementNoFormatTag) {
- Graph g(OpRegistry::Global());
- { // Scope for temporary variables used to construct g.
- GraphDefBuilder b(GraphDefBuilder::kFailImmediately);
- ops::SourceOp("TestDevice",
- b.opts().WithName("in").WithDevice("/device:fakegpu:11"));
- TF_EXPECT_OK(BuildGraph(b, &g));
- }
-
- SessionOptions options;
- options.config.mutable_experimental()->set_client_handles_error_formatting(
- false);
- Status s = Place(&g, &options);
- EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(str_util::StrContains(
- s.error_message(), "Cannot assign a device for operation 'in'"));
- EXPECT_FALSE(str_util::StrContains(
- s.error_message(), "'in' (defined at ^^node:in:${file}:${line}^^)"));
+ "Cannot assign a device for operation in"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(), "{{node in}}"));
}
// Test that placement fails when a node requests an explicit device that is not
@@ -1289,8 +1265,9 @@ TEST_F(PlacerTest, TestUnsatisfiableConstraintWithReferenceConnections) {
Status s = Place(&g);
EXPECT_EQ(error::INVALID_ARGUMENT, s.code());
- EXPECT_TRUE(str_util::StrContains(
- s.error_message(), "Cannot colocate nodes 'var' and 'assign'"));
+ EXPECT_TRUE(str_util::StrContains(s.error_message(),
+ "Cannot colocate nodes {{colocation_node "
+ "var}} and {{colocation_node assign}}"));
}
// Test that a generator node follows its consumers (where there are several
diff --git a/tensorflow/core/common_runtime/pool_allocator.cc b/tensorflow/core/common_runtime/pool_allocator.cc
index 10a24ed14c..66dc8f3322 100644
--- a/tensorflow/core/common_runtime/pool_allocator.cc
+++ b/tensorflow/core/common_runtime/pool_allocator.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -39,8 +40,7 @@ PoolAllocator::PoolAllocator(size_t pool_size_limit, bool auto_resize,
auto_resize_(auto_resize),
pool_size_limit_(pool_size_limit),
allocator_(allocator),
- size_rounder_(size_rounder),
- allocation_begun_(false) {
+ size_rounder_(size_rounder) {
if (auto_resize) {
CHECK_LT(size_t{0}, pool_size_limit)
<< "size limit must be > 0 if auto_resize is true.";
@@ -92,7 +92,6 @@ ChunkPrefix* FindPrefix(void* user_ptr) {
} // namespace
void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
- if (!allocation_begun_) allocation_begun_ = true;
if (num_bytes == 0) return nullptr;
// If alignment is larger than kPoolAlignment, increase num_bytes so that we
@@ -128,9 +127,6 @@ void* PoolAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
return PrepareChunk(r, alignment, num_bytes);
} else {
void* ptr = allocator_->Alloc(kPoolAlignment, num_bytes);
- for (const auto& v : alloc_visitors_) {
- v(ptr, num_bytes);
- }
return PrepareChunk(ptr, alignment, num_bytes);
}
}
@@ -140,9 +136,6 @@ void PoolAllocator::DeallocateRaw(void* ptr) {
ChunkPrefix* cp = FindPrefix(ptr);
CHECK_LE((void*)cp, (void*)ptr);
if (!has_size_limit_ && !auto_resize_) {
- for (const auto& v : free_visitors_) {
- v(cp, cp->num_bytes);
- }
allocator_->Free(cp, cp->num_bytes);
} else {
mutex_lock lock(mutex_);
@@ -163,9 +156,6 @@ void PoolAllocator::Clear() {
mutex_lock lock(mutex_);
for (auto iter : pool_) {
PtrRecord* pr = iter.second;
- for (const auto& v : free_visitors_) {
- v(pr->ptr, pr->num_bytes);
- }
allocator_->Free(pr->ptr, pr->num_bytes);
delete pr;
}
@@ -220,9 +210,6 @@ void PoolAllocator::EvictOne() {
DCHECK(iter != pool_.end());
}
pool_.erase(iter);
- for (const auto& v : free_visitors_) {
- v(prec->ptr, prec->num_bytes);
- }
allocator_->Free(prec->ptr, prec->num_bytes);
delete prec;
++evicted_count_;
@@ -268,28 +255,19 @@ void PoolAllocator::EvictOne() {
}
}
-void PoolAllocator::AddAllocVisitor(Visitor visitor) {
- mutex_lock lock(mutex_);
- CHECK(!allocation_begun_)
- << "AddAllocVisitor may not be called after pool allocation "
- << "has begun.";
- alloc_visitors_.push_back(visitor);
-}
-
-void PoolAllocator::AddFreeVisitor(Visitor visitor) {
- mutex_lock lock(mutex_);
- CHECK(!allocation_begun_)
- << "AddFreeVisitor may not be called after pool allocation "
- << "has begun.";
- free_visitors_.push_back(visitor);
-}
-
void* BasicCPUAllocator::Alloc(size_t alignment, size_t num_bytes) {
- return port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
+ void* ptr = nullptr;
+ if (num_bytes > 0) {
+ ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
+ VisitAlloc(ptr, numa_node_, num_bytes);
+ }
+ return ptr;
}
void BasicCPUAllocator::Free(void* ptr, size_t num_bytes) {
- port::AlignedFree(ptr);
+ if (num_bytes > 0) {
+ VisitFree(ptr, numa_node_, num_bytes);
+ port::AlignedFree(ptr);
+ }
}
-
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/pool_allocator.h b/tensorflow/core/common_runtime/pool_allocator.h
index 607734445b..5b4623ba10 100644
--- a/tensorflow/core/common_runtime/pool_allocator.h
+++ b/tensorflow/core/common_runtime/pool_allocator.h
@@ -16,14 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_
#define TENSORFLOW_CORE_COMMON_RUNTIME_POOL_ALLOCATOR_H_
-// Simple LRU pool allocators for various flavors of CPU RAM that
-// implement the VisitableAllocator interface.
+// Simple LRU pool allocators for various flavors of CPU RAM.
#include <atomic>
#include <map>
#include <memory>
#include <vector>
-#include "tensorflow/core/common_runtime/visitable_allocator.h"
+#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -41,7 +40,7 @@ class RoundUpInterface {
// Size-limited pool of memory buffers obtained from a SubAllocator
// instance. Pool eviction policy is LRU.
-class PoolAllocator : public VisitableAllocator {
+class PoolAllocator : public Allocator {
public:
// "pool_size_limit" is the maximum number of returned, re-usable
// memory buffers to keep in the pool. If pool_size_limit == 0, the
@@ -64,14 +63,6 @@ class PoolAllocator : public VisitableAllocator {
void DeallocateRaw(void* ptr) override;
- // REQUIRES: The following functions may only be called prior
- // to the first Allocate*() call. Once allocation has begun, it is
- // illegal to register another visitor.
-
- void AddAllocVisitor(Visitor visitor) override;
-
- void AddFreeVisitor(Visitor visitor) override;
-
// Allocate an unused memory region of size "num_bytes". Fetch from
// the pool if available, otherwise call allocator_.
void* Get(size_t num_bytes);
@@ -141,12 +132,6 @@ class PoolAllocator : public VisitableAllocator {
int64 put_count_ GUARDED_BY(mutex_) = 0;
int64 allocated_count_ GUARDED_BY(mutex_) = 0;
int64 evicted_count_ GUARDED_BY(mutex_) = 0;
- // Write access to these is guarded by mutex_, but not read
- // access. They may only be modified prior to the first
- // allocation. Later attempts to modify will fail.
- std::vector<Visitor> alloc_visitors_;
- std::vector<Visitor> free_visitors_;
- std::atomic<bool> allocation_begun_;
};
// Do-nothing rounder. Passes through sizes unchanged.
@@ -166,7 +151,9 @@ class Pow2Rounder : public RoundUpInterface {
class BasicCPUAllocator : public SubAllocator {
public:
// Argument numa_node is currently ignored.
- explicit BasicCPUAllocator(int numa_node) : numa_node_(numa_node) {}
+ BasicCPUAllocator(int numa_node, const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : SubAllocator(alloc_visitors, free_visitors), numa_node_(numa_node) {}
~BasicCPUAllocator() override {}
@@ -176,6 +163,8 @@ class BasicCPUAllocator : public SubAllocator {
private:
int numa_node_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BasicCPUAllocator);
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index 6dac4c3acf..c43a9d7dc2 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -113,7 +113,7 @@ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
const string& key_prefix, int64 src_incarnation, int64 num_tensors,
DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs, Rendezvous* rendezvous,
- std::vector<Tensor>* received_tensors, const StatusCallback& done) {
+ std::vector<Tensor>* received_tensors, StatusCallback done) {
std::vector<string> keys;
for (int64 i = 0; i < num_tensors; ++i) {
string name = strings::StrCat(key_prefix, i);
@@ -121,9 +121,8 @@ void ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
target_device, name, FrameAndIter(0, 0));
keys.push_back(key);
}
- RecvOutputsFromRendezvousAsync(
- rendezvous, device_context, alloc_attrs, keys, received_tensors,
- [done](const Status& status) { done(status); });
+ RecvOutputsFromRendezvousAsync(rendezvous, device_context, alloc_attrs, keys,
+ received_tensors, std::move(done));
}
Status ProcessFunctionLibraryRuntime::GetDeviceIncarnation(
@@ -192,7 +191,7 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
const string& function_key) const {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
}
@@ -204,11 +203,12 @@ bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
FunctionLibraryRuntime::LocalHandle
ProcessFunctionLibraryRuntime::GetHandleOnDevice(
const string& device_name, FunctionLibraryRuntime::Handle handle) {
- mutex_lock l(mu_);
- if (function_data_.count(handle) == 0) {
+ tf_shared_lock l(mu_);
+ auto iter = function_data_.find(handle);
+ if (iter == function_data_.end()) {
return kInvalidLocalHandle;
}
- FunctionData* function_data = function_data_[handle].get();
+ FunctionData* function_data = iter->second.get();
if (function_data->target_device() != device_name) {
return kInvalidLocalHandle;
}
@@ -217,9 +217,10 @@ ProcessFunctionLibraryRuntime::GetHandleOnDevice(
string ProcessFunctionLibraryRuntime::GetDeviceName(
FunctionLibraryRuntime::Handle handle) {
- mutex_lock l(mu_);
- CHECK_EQ(1, function_data_.count(handle));
- FunctionData* function_data = function_data_[handle].get();
+ tf_shared_lock l(mu_);
+ auto iter = function_data_.find(handle);
+ CHECK(iter != function_data_.end());
+ FunctionData* function_data = iter->second.get();
return function_data->target_device();
}
@@ -302,13 +303,15 @@ void ProcessFunctionLibraryRuntime::Run(
string target_device;
FunctionLibraryRuntime::LocalHandle local_handle;
{
- mutex_lock l(mu_);
- if (function_data_.count(handle) == 0) {
+ tf_shared_lock l(mu_);
+ auto iter = function_data_.find(handle);
+ if (iter == function_data_.end()) {
done(errors::NotFound("Handle: ", handle, " not found."));
return;
}
- target_device = function_data_[handle]->target_device();
- local_handle = function_data_[handle]->local_handle();
+ FunctionData* function_data = iter->second.get();
+ target_device = function_data->target_device();
+ local_handle = function_data->local_handle();
}
flr = GetFLR(target_device);
if (flr != nullptr) {
@@ -339,26 +342,29 @@ void ProcessFunctionLibraryRuntime::Run(
opts.rets_alloc_attrs;
std::vector<Tensor>* remote_rets = new std::vector<Tensor>;
flr->Run(opts, handle, args, remote_rets,
- [source_device, target_device, target_incarnation, rendezvous,
- device_context, rets_alloc_attrs, remote_rets, rets,
- done](const Status& status) {
- if (!status.ok()) {
- delete remote_rets;
- done(status);
- return;
- }
- int64 num_returns = remote_rets->size();
- delete remote_rets;
- // Now receive the return values from the target.
- ReceiveTensorsAsync(target_device, source_device, "ret_",
- target_incarnation, num_returns,
- device_context, rets_alloc_attrs, rendezvous,
- rets, done);
- });
+ std::bind(
+ [source_device, target_device, target_incarnation, rendezvous,
+ device_context, rets_alloc_attrs, remote_rets,
+ rets](const Status& status,
+ FunctionLibraryRuntime::DoneCallback& done) {
+ if (!status.ok()) {
+ delete remote_rets;
+ done(status);
+ return;
+ }
+ int64 num_returns = remote_rets->size();
+ delete remote_rets;
+ // Now receive the return values from the target.
+ ReceiveTensorsAsync(target_device, source_device, "ret_",
+ target_incarnation, num_returns,
+ device_context, rets_alloc_attrs,
+ rendezvous, rets, std::move(done));
+ },
+ std::placeholders::_1, std::move(done)));
return;
}
if (parent_ != nullptr) {
- parent_->Run(opts, local_handle, args, rets, done);
+ parent_->Run(opts, local_handle, args, rets, std::move(done));
return;
}
done(errors::Internal("Could not find device"));
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index 69381dd34d..53815715d8 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
@@ -59,8 +60,6 @@ class ProcessFunctionLibraryRuntime {
const std::vector<AllocatorAttributes>& alloc_attrs,
Rendezvous* rendezvous);
- typedef std::function<void(const Status&)> StatusCallback;
-
// Receives `received_tensors` from `target_device` (originally sent from
// `source_device`) using `rendezvous`. Uses `key_prefix` to construct the
// keys to be retrieved. `device_context` should be for the device receiving
@@ -73,7 +72,7 @@ class ProcessFunctionLibraryRuntime {
DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
Rendezvous* rendezvous, std::vector<Tensor>* received_tensors,
- const StatusCallback& done);
+ StatusCallback done);
static const char kDefaultFLRDevice[];
// Returns the FunctionLibraryRuntime for the corresponding device_name.
diff --git a/tensorflow/core/common_runtime/process_state.cc b/tensorflow/core/common_runtime/process_state.cc
index 447338e7bd..bcaa37fc8a 100644
--- a/tensorflow/core/common_runtime/process_state.cc
+++ b/tensorflow/core/common_runtime/process_state.cc
@@ -71,20 +71,28 @@ ProcessState::MemDesc ProcessState::PtrType(const void* ptr) {
return MemDesc();
}
-VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) {
+Allocator* ProcessState::GetCPUAllocator(int numa_node) {
CHECK_GE(numa_node, 0);
if (!numa_enabled_) numa_node = 0;
mutex_lock lock(mu_);
while (cpu_allocators_.size() <= static_cast<size_t>(numa_node)) {
+    // If visitors have been defined, we need an Allocator built from a
+    // SubAllocator. Prefer BFCAllocator, but fall back to PoolAllocator
+    // depending on the TF_CPU_ALLOCATOR_USE_BFC env var setting.
+ const bool alloc_visitors_defined =
+ (!cpu_alloc_visitors_.empty() || !cpu_free_visitors_.empty());
bool use_bfc_allocator = false;
- // TODO(reedwm): Switch default to BGFAllocator if it's at least as fast and
- // efficient.
- Status status = ReadBoolFromEnvVar("TF_CPU_ALLOCATOR_USE_BFC", false,
- &use_bfc_allocator);
+ Status status = ReadBoolFromEnvVar(
+ "TF_CPU_ALLOCATOR_USE_BFC", alloc_visitors_defined, &use_bfc_allocator);
if (!status.ok()) {
LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
}
- VisitableAllocator* allocator;
+ Allocator* allocator = nullptr;
+ SubAllocator* sub_allocator =
+ (alloc_visitors_defined || use_bfc_allocator)
+ ? new BasicCPUAllocator(numa_enabled_ ? numa_node : -1,
+ cpu_alloc_visitors_, cpu_free_visitors_)
+ : nullptr;
if (use_bfc_allocator) {
// TODO(reedwm): evaluate whether 64GB by default is the best choice.
int64 cpu_mem_limit_in_mb = -1;
@@ -95,34 +103,63 @@ VisitableAllocator* ProcessState::GetCPUAllocator(int numa_node) {
LOG(ERROR) << "GetCPUAllocator: " << status.error_message();
}
int64 cpu_mem_limit = cpu_mem_limit_in_mb * (1LL << 20);
- allocator = new BFCAllocator(
- new BasicCPUAllocator(numa_enabled_ ? numa_node : -1), cpu_mem_limit,
- true /*allow_growth*/, "bfc_cpu_allocator_for_gpu" /*name*/);
+ DCHECK(sub_allocator);
+ allocator =
+ new BFCAllocator(sub_allocator, cpu_mem_limit, true /*allow_growth*/,
+ "bfc_cpu_allocator_for_gpu" /*name*/);
VLOG(2) << "Using BFCAllocator with memory limit of "
<< cpu_mem_limit_in_mb << " MB for ProcessState CPU allocator";
- } else {
- allocator = new PoolAllocator(
- 100 /*pool_size_limit*/, true /*auto_resize*/,
- new BasicCPUAllocator(numa_enabled_ ? numa_node : -1),
- new NoopRounder, "cpu_pool");
+ } else if (alloc_visitors_defined) {
+ DCHECK(sub_allocator);
+ allocator =
+ new PoolAllocator(100 /*pool_size_limit*/, true /*auto_resize*/,
+ sub_allocator, new NoopRounder, "cpu_pool");
VLOG(2) << "Using PoolAllocator for ProcessState CPU allocator "
<< "numa_enabled_=" << numa_enabled_
<< " numa_node=" << numa_node;
+ } else {
+ DCHECK(!sub_allocator);
+ allocator = cpu_allocator();
}
- if (LogMemory::IsEnabled()) {
+ if (LogMemory::IsEnabled() && !allocator->TracksAllocationSizes()) {
// Wrap the allocator to track allocation ids for better logging
// at the cost of performance.
- allocator = new TrackingVisitableAllocator(allocator, true);
+ allocator = new TrackingAllocator(allocator, true);
}
cpu_allocators_.push_back(allocator);
+ if (!sub_allocator) {
+ DCHECK(cpu_alloc_visitors_.empty() && cpu_free_visitors_.empty());
+ }
}
return cpu_allocators_[numa_node];
}
+void ProcessState::AddCPUAllocVisitor(SubAllocator::Visitor visitor) {
+ VLOG(1) << "AddCPUAllocVisitor";
+ mutex_lock lock(mu_);
+ CHECK_EQ(0, cpu_allocators_.size()) // Crash OK
+ << "AddCPUAllocVisitor must be called prior to first call to "
+ "ProcessState::GetCPUAllocator";
+ cpu_alloc_visitors_.push_back(std::move(visitor));
+}
+
+void ProcessState::AddCPUFreeVisitor(SubAllocator::Visitor visitor) {
+ mutex_lock lock(mu_);
+ CHECK_EQ(0, cpu_allocators_.size()) // Crash OK
+ << "AddCPUFreeVisitor must be called prior to first call to "
+ "ProcessState::GetCPUAllocator";
+ cpu_free_visitors_.push_back(std::move(visitor));
+}
+
void ProcessState::TestOnlyReset() {
mutex_lock lock(mu_);
+ // Don't delete this value because it's static.
+ Allocator* default_cpu_allocator = cpu_allocator();
mem_desc_map_.clear();
- gtl::STLDeleteElements(&cpu_allocators_);
+ for (Allocator* a : cpu_allocators_) {
+ if (a != default_cpu_allocator) delete a;
+ }
+ cpu_allocators_.clear();
gtl::STLDeleteElements(&cpu_al_);
}
diff --git a/tensorflow/core/common_runtime/process_state.h b/tensorflow/core/common_runtime/process_state.h
index 2892677333..cac312d849 100644
--- a/tensorflow/core/common_runtime/process_state.h
+++ b/tensorflow/core/common_runtime/process_state.h
@@ -30,7 +30,6 @@ limitations under the License.
namespace tensorflow {
class Allocator;
-class VisitableAllocator;
class PoolAllocator;
// Singleton that manages per-process state, e.g. allocation of
@@ -65,7 +64,15 @@ class ProcessState {
// Returns the one CPUAllocator used for the given numa_node.
// TEMPORARY: ignores numa_node.
- VisitableAllocator* GetCPUAllocator(int numa_node);
+ Allocator* GetCPUAllocator(int numa_node);
+
+ // Registers alloc visitor for the CPU allocator(s).
+ // REQUIRES: must be called before GetCPUAllocator.
+ void AddCPUAllocVisitor(SubAllocator::Visitor v);
+
+ // Registers free visitor for the CPU allocator(s).
+ // REQUIRES: must be called before GetCPUAllocator.
+ void AddCPUFreeVisitor(SubAllocator::Visitor v);
typedef std::unordered_map<const void*, MemDesc> MDMap;
@@ -87,7 +94,9 @@ class ProcessState {
mutex mu_;
- std::vector<VisitableAllocator*> cpu_allocators_ GUARDED_BY(mu_);
+ std::vector<Allocator*> cpu_allocators_ GUARDED_BY(mu_);
+ std::vector<SubAllocator::Visitor> cpu_alloc_visitors_ GUARDED_BY(mu_);
+ std::vector<SubAllocator::Visitor> cpu_free_visitors_ GUARDED_BY(mu_);
virtual ~ProcessState();
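A hedged usage sketch for the new visitor hooks. The three-argument Visitor signature is inferred from the VisitAlloc/VisitFree calls in pool_allocator.cc above, and ProcessState::singleton() is assumed to be the usual accessor; visitors must be registered before the first GetCPUAllocator call, as the CHECKs in process_state.cc enforce.

```cpp
// Hypothetical registration, assuming the usual TF logging/allocator headers.
void RegisterCpuAllocLogging() {
  tensorflow::ProcessState* ps = tensorflow::ProcessState::singleton();
  // Must run before the first GetCPUAllocator() call.
  ps->AddCPUAllocVisitor([](void* ptr, int numa_node, size_t num_bytes) {
    VLOG(1) << "cpu alloc " << num_bytes << " bytes on numa " << numa_node;
  });
  ps->AddCPUFreeVisitor([](void* ptr, int numa_node, size_t num_bytes) {
    VLOG(1) << "cpu free " << num_bytes << " bytes on numa " << numa_node;
  });
  // Subsequent CPU allocators (BFC or pool) will invoke these visitors.
  tensorflow::Allocator* a = ps->GetCPUAllocator(/*numa_node=*/0);
  (void)a;
}
```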
diff --git a/tensorflow/core/common_runtime/process_util.cc b/tensorflow/core/common_runtime/process_util.cc
index a5d31b75c7..e1dc08d645 100644
--- a/tensorflow/core/common_runtime/process_util.cc
+++ b/tensorflow/core/common_runtime/process_util.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/util.h"
namespace tensorflow {
@@ -56,24 +57,26 @@ int32 NumInterOpThreadsFromSessionOptions(const SessionOptions& options) {
const int32 inter_op = options.config.inter_op_parallelism_threads();
if (inter_op != 0) return inter_op;
#ifdef INTEL_MKL
- // MKL library executes ops in parallel using OMP threads
- // Set inter_op conservatively to avoid thread oversubscription that could
- // lead to severe perf degradations and OMP resource exhaustion
- int mkl_intra_op = 1;
+ if (!DisableMKL()) {
+ // MKL library executes ops in parallel using OMP threads
+ // Set inter_op conservatively to avoid thread oversubscription that could
+ // lead to severe perf degradations and OMP resource exhaustion
+ int mkl_intra_op = 1;
#ifdef _OPENMP
- mkl_intra_op = omp_get_max_threads();
+ mkl_intra_op = omp_get_max_threads();
#endif // _OPENMP
- CHECK_GE(mkl_intra_op, 1);
- const int32 mkl_inter_op = std::max(
- (port::NumSchedulableCPUs() + mkl_intra_op - 1) / mkl_intra_op, 2);
- VLOG(0) << "Creating new thread pool with default inter op setting: "
- << mkl_inter_op
- << ". Tune using inter_op_parallelism_threads for best performance.";
- return mkl_inter_op;
-#else
+ DCHECK_GE(mkl_intra_op, 1);
+ const int32 mkl_inter_op = std::max(
+ (port::NumSchedulableCPUs() + mkl_intra_op - 1) / mkl_intra_op, 2);
+ VLOG(0)
+ << "Creating new thread pool with default inter op setting: "
+ << mkl_inter_op
+ << ". Tune using inter_op_parallelism_threads for best performance.";
+ return mkl_inter_op;
+ }
+#endif // INTEL_MKL
// Default to using the number of cores available in the process.
return port::NumSchedulableCPUs();
-#endif // INTEL_MKL
}
thread::ThreadPool* NewThreadPoolFromSessionOptions(
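Worked example of the MKL default above (numbers are illustrative): with 56 schedulable CPUs and omp_get_max_threads() returning 14, the pool gets max((56 + 14 - 1) / 14, 2) = 4 inter-op threads; with 8 CPUs and 8 OMP threads it bottoms out at the floor of 2. A minimal sketch mirroring that expression:

```cpp
#include <algorithm>

// Illustrative arithmetic only; mirrors the MKL default expression above.
int DefaultMklInterOp(int num_schedulable_cpus, int mkl_intra_op) {
  return std::max((num_schedulable_cpus + mkl_intra_op - 1) / mkl_intra_op, 2);
}
// DefaultMklInterOp(56, 14) == 4; DefaultMklInterOp(8, 8) == 2.
```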
diff --git a/tensorflow/core/common_runtime/renamed_device.h b/tensorflow/core/common_runtime/renamed_device.h
index 103eee03b3..c00789a556 100644
--- a/tensorflow/core/common_runtime/renamed_device.h
+++ b/tensorflow/core/common_runtime/renamed_device.h
@@ -58,6 +58,15 @@ class RenamedDevice : public Device {
return underlying_->GetAllocator(attr);
}
+ Allocator* GetScopedAllocator(AllocatorAttributes attr,
+ int64 step_id) override {
+ return underlying_->GetScopedAllocator(attr, step_id);
+ }
+
+ ScopedAllocatorMgr* GetScopedAllocatorMgr() const override {
+ return underlying_->GetScopedAllocatorMgr();
+ }
+
const Eigen::ThreadPoolDevice* eigen_cpu_device() override {
return underlying_->eigen_cpu_device();
}
@@ -72,9 +81,10 @@ class RenamedDevice : public Device {
return underlying_->MakeGpuDevice();
}
- void ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
- DeviceContext* dc, Allocator* allocator) override {
- underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
+ Status ReinitializeGpuDevice(OpKernelContext* context, PerOpGpuDevice* device,
+ DeviceContext* dc,
+ Allocator* allocator) override {
+ return underlying_->ReinitializeGpuDevice(context, device, dc, allocator);
}
Status MakeTensorFromProto(const TensorProto& tensor_proto,
diff --git a/tensorflow/core/common_runtime/rendezvous_mgr.h b/tensorflow/core/common_runtime/rendezvous_mgr.h
index cb5848ede3..b4d8ab4eb2 100644
--- a/tensorflow/core/common_runtime/rendezvous_mgr.h
+++ b/tensorflow/core/common_runtime/rendezvous_mgr.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
-#define TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
#include <string>
#include <unordered_map>
@@ -87,4 +87,4 @@ class IntraProcessRendezvous : public Rendezvous {
} // end namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_RENDEZVOUS_MGR_H_
diff --git a/tensorflow/core/common_runtime/rendezvous_util.cc b/tensorflow/core/common_runtime/rendezvous_util.cc
index 92dc03812e..43ca3f1e3e 100644
--- a/tensorflow/core/common_runtime/rendezvous_util.cc
+++ b/tensorflow/core/common_runtime/rendezvous_util.cc
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/rendezvous_util.h"
+#include "tensorflow/core/platform/mutex.h"
+
+#include "tensorflow/core/util/reffed_status_callback.h"
namespace tensorflow {
@@ -54,7 +57,7 @@ void RecvOutputsFromRendezvousAsync(
Rendezvous* rendezvous, DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
- const StatusCallback& done) {
+ StatusCallback done) {
if (keys.empty()) {
done(Status::OK());
return;
@@ -85,13 +88,7 @@ void RecvOutputsFromRendezvousAsync(
alloc_attr);
}
- typedef struct {
- mutex mu;
- int64 done_counter;
- Status shared_status = Status::OK();
- } CallState;
- CallState* call_state = new CallState;
- call_state->done_counter = keys.size();
+ auto status_cb = new ReffedStatusCallback(std::move(done));
for (auto& p : arguments) {
const string& key = std::get<0>(p);
Tensor* val = std::get<1>(p);
@@ -99,13 +96,13 @@ void RecvOutputsFromRendezvousAsync(
Rendezvous::Args rendez_args;
rendez_args.device_context = device_context;
rendez_args.alloc_attrs = std::get<3>(p);
-
+ status_cb->Ref();
rendezvous->RecvAsync(
parsed, rendez_args,
- [val, done, key, call_state](const Status& s,
- const Rendezvous::Args& send_args,
- const Rendezvous::Args& recv_args,
- const Tensor& v, const bool is_dead) {
+ [val, key, status_cb](const Status& s,
+ const Rendezvous::Args& send_args,
+ const Rendezvous::Args& recv_args,
+ const Tensor& v, const bool is_dead) {
Status status = s;
if (status.ok()) {
*val = v;
@@ -114,20 +111,11 @@ void RecvOutputsFromRendezvousAsync(
" was not valid.");
}
}
- call_state->mu.lock();
- call_state->shared_status.Update(status);
- call_state->done_counter--;
- // If we are the last async call to return, call the done callback.
- if (call_state->done_counter == 0) {
- const Status& final_status = call_state->shared_status;
- call_state->mu.unlock();
- done(final_status);
- delete call_state;
- return;
- }
- call_state->mu.unlock();
+ status_cb->UpdateStatus(status);
+ status_cb->Unref();
});
}
+ status_cb->Unref();
}
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
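Note on the rendezvous_util.cc change above: the hand-rolled CallState (mutex, countdown, shared status) is replaced by a reference-counted callback. Each pending RecvAsync takes a ref before dispatch, every completion merges its status and drops a ref, and the trailing Unref() releases the initial reference so the done callback fires exactly once after the last receive completes. The sketch below is illustrative only -- it is not the actual tensorflow::ReffedStatusCallback -- and assumes nothing beyond tensorflow::Status and its Update() method, which the surrounding code already uses; the class name RefCountedStatusCallback is made up for the example.

// Illustrative only -- not tensorflow::ReffedStatusCallback. Shows the
// refcounting contract the new code relies on: construct with one reference,
// Ref() once per in-flight recv, and the aggregated callback runs when the
// last reference is dropped.
#include <functional>
#include <mutex>
#include <utility>

#include "tensorflow/core/lib/core/status.h"  // tensorflow::Status

namespace example {

class RefCountedStatusCallback {
 public:
  explicit RefCountedStatusCallback(
      std::function<void(const tensorflow::Status&)> done)
      : refs_(1), done_(std::move(done)) {}

  void Ref() {
    std::lock_guard<std::mutex> l(mu_);
    ++refs_;
  }

  // Merges a per-recv status into the aggregate (the first error wins).
  void UpdateStatus(const tensorflow::Status& s) {
    std::lock_guard<std::mutex> l(mu_);
    status_.Update(s);
  }

  // Drops one reference; the last drop invokes the callback and self-deletes.
  void Unref() {
    bool last;
    {
      std::lock_guard<std::mutex> l(mu_);
      last = (--refs_ == 0);
    }
    if (last) {
      done_(status_);
      delete this;
    }
  }

 private:
  std::mutex mu_;
  int refs_;
  tensorflow::Status status_;
  std::function<void(const tensorflow::Status&)> done_;
};

}  // namespace example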
diff --git a/tensorflow/core/common_runtime/rendezvous_util.h b/tensorflow/core/common_runtime/rendezvous_util.h
index aad910f6d8..deb9a7c822 100644
--- a/tensorflow/core/common_runtime/rendezvous_util.h
+++ b/tensorflow/core/common_runtime/rendezvous_util.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <map>
#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
@@ -42,7 +43,7 @@ void RecvOutputsFromRendezvousAsync(
Rendezvous* rendezvous, DeviceContext* device_context,
const std::vector<AllocatorAttributes>& alloc_attrs,
const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
- const StatusCallback& done);
+ StatusCallback done);
Status RecvOutputsFromRendezvous(Rendezvous* rendezvous, NamedTensors* out,
const Rendezvous::Args& args);
diff --git a/tensorflow/core/common_runtime/ring_reducer.cc b/tensorflow/core/common_runtime/ring_reducer.cc
index e26761703b..b1fe928ba7 100644
--- a/tensorflow/core/common_runtime/ring_reducer.cc
+++ b/tensorflow/core/common_runtime/ring_reducer.cc
@@ -14,16 +14,43 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/ring_reducer.h"
+#include <stdlib.h>
+#include <atomic>
+#include <functional>
+#include <utility>
+
#include "tensorflow/core/common_runtime/collective_rma_local.h"
+#include "tensorflow/core/common_runtime/collective_util.h"
#include "tensorflow/core/common_runtime/copy_tensor.h"
+#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/dma_helper.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
+#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
// Set true for greater intelligibility of debug mode log messages.
#define READABLE_KEYS false
+// The RingReduce algorithm exchanges chunks of the tensor between devices. The
+// chunk size depends on the number of subdivisions specified in the algorithm.
+// If the user does not specify the number of subdivisions, we infer it
+// dynamically so that the resulting chunk size does not exceed
+// kMaxChunkSizeBytes, empirically set at 4 MiB.
+constexpr size_t kMaxChunkSizeBytes = (4 * 1024 * 1024);
+// kMaxSubdivsPerDevice gives an upper bound on the number of subdivisions
+// generated dynamically. A reasonable value would be a small multiple of the
+// number of NICs adjacent to each device.
+constexpr int kMaxSubdivsPerDevice = 2;
namespace tensorflow {
namespace {
@@ -36,7 +63,8 @@ string RingReduceBufKey(const string& exec_key, int pass, int section,
return strings::StrCat("rred(", exec_key, "):pass(", pass, "):section(",
section, "):srcrank(", source_rank, ")");
} else {
- // TODO(tucker): Try out some kind of denser encoding, e.g. 128 bit hash.
+ // TODO(b/78352018): Try out some kind of denser encoding, e.g. 128 bit
+ // hash.
return strings::StrCat(exec_key, ":", pass, ":", section, ":", source_rank);
}
}
@@ -65,105 +93,203 @@ RingReducer::RingField* RingReducer::PCQueue::Dequeue() {
return rf;
}
-RingReducer::RingReducer(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
- OpKernelContext* ctx,
- OpKernelContext::Params* op_params,
- const CollectiveParams& col_params,
- const string& exec_key, int64 step_id,
- const Tensor* input, Tensor* output)
- : col_exec_(col_exec),
- dev_mgr_(dev_mgr),
- ctx_(ctx),
- op_params_(op_params),
- col_params_(col_params),
- exec_key_(exec_key),
- input_(input),
- output_(output),
- rank_(col_params.subdiv_rank[0]),
- step_id_(step_id),
- group_size_(col_params.group.group_size),
- num_subdivs_(static_cast<int>(
- col_params.instance.impl_details.subdiv_permutations.size())),
+RingReducer::RingReducer()
+ : col_ctx_(nullptr),
+ col_params_(nullptr),
done_(nullptr),
- device_(nullptr),
- device_name_(
- col_params_.instance.device_names[col_params_.default_rank]) {
- CHECK_GT(group_size_, 0);
- CHECK_GT(num_subdivs_, 0);
-}
+ group_size_(-1),
+ num_subdivs_(-1) {}
RingReducer::~RingReducer() { group_size_tensor_ready_.WaitForNotification(); }
-string RingReducer::TensorDebugString(Tensor tensor) {
- const DeviceBase::GpuDeviceInfo* gpu_device_info =
- ctx_->device()->tensorflow_gpu_device_info();
- if (gpu_device_info) {
- Tensor cpu_tensor(tensor.dtype(), tensor.shape());
- Notification note;
- gpu_device_info->default_context->CopyDeviceTensorToCPU(
- &tensor, "" /*tensor_name*/, device_, &cpu_tensor,
- [&note](const Status& s) {
- CHECK(s.ok());
- note.Notify();
- });
- note.WaitForNotification();
- return cpu_tensor.SummarizeValue(64);
- } else {
- return tensor.SummarizeValue(64);
+Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) {
+ if (col_params->instance.shape.num_elements() == 0) {
+ return errors::Internal("shape in CollectiveParams should be non-empty");
+ }
+ const int kAvgDevPerTask =
+ col_params->group.group_size / col_params->group.num_tasks;
+ const int kMaxNumSubdivs = kMaxSubdivsPerDevice * kAvgDevPerTask;
+ if (kMaxNumSubdivs <= 0) {
+ return errors::Internal("Unexpected kMaxNumSubdivs ", kMaxNumSubdivs,
+ " in RingReducer");
+ }
+ // NOTE(ayushd): If no subdiv_offsets have been specified, dynamically add
+ // as many offsets as needed so that the size of tensor chunks <=
+ // kMaxChunkSizeBytes. Empirically, chunks that are too small or too large
+ // lead to worse performance.
+ int num_subdivs = 0;
+ const size_t tensor_size = col_params->instance.shape.num_elements() *
+ DataTypeSize(col_params->instance.data_type);
+ size_t chunk_size;
+ do {
+ ++num_subdivs;
+ int num_chunks = col_params->group.group_size * num_subdivs;
+ chunk_size = tensor_size / num_chunks;
+ VLOG(2) << "num_subdivs " << num_subdivs << " num_chunks " << num_chunks
+ << " chunk_size " << chunk_size;
+ } while (chunk_size > kMaxChunkSizeBytes && num_subdivs < kMaxNumSubdivs);
+ if (num_subdivs <= 0) {
+ return errors::Internal("Unexpected num_subdivs ", num_subdivs,
+ " in RingReducer");
+ }
+
+ int subdiv_stride = kAvgDevPerTask / num_subdivs;
+ if (subdiv_stride == 0) subdiv_stride = 1;
+ col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs);
+ for (int sdi = 0; sdi < num_subdivs; ++sdi) {
+ int subdiv_offset = subdiv_stride * sdi;
+ if (sdi % 2 == 1) subdiv_offset *= -1;
+ col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset);
+ }
+
+ if (VLOG_IS_ON(2)) {
+ string subdiv_buf;
+ for (const int subdiv_offset :
+ col_params->instance.impl_details.subdiv_offsets) {
+ strings::StrAppend(&subdiv_buf, " ", subdiv_offset);
+ }
+ VLOG(2) << "Dynamically generated " << num_subdivs
+ << " subdiv_offsets:" << subdiv_buf << " tensor_size "
+ << tensor_size << " chunk_size " << chunk_size;
+ }
+
+ return Status::OK();
+}
+
+Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) {
+ // TODO(b/113171733): change CHECKs to return errors.
+ CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE);
+ CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce");
+ const string& device_name =
+ col_params->instance.device_names[col_params->default_rank];
+ // Each subdiv permutation is a ring formed by rotating each
+ // single-task subsequence of devices by an offset. This makes the most
+ // sense when each task has the same number of devices, but we can't
+ // depend on that being the case, so we compute something that works
+ // in any case.
+
+ // Start by counting the devices in each task.
+ // Precondition: device_names must be sorted so that all devices in
+ // the same task are adjacent.
+ VLOG(2) << "Sorted task names: "
+ << str_util::Join(col_params->instance.task_names, ", ");
+ std::vector<int> dev_per_task;
+ const string* prior_task_name = &col_params->instance.task_names[0];
+ int dev_count = 1;
+ for (int di = 1; di < col_params->group.group_size; ++di) {
+ if (col_params->instance.task_names[di] != *prior_task_name) {
+ dev_per_task.push_back(dev_count);
+ dev_count = 1;
+ prior_task_name = &col_params->instance.task_names[di];
+ } else {
+ ++dev_count;
+ }
+ }
+ dev_per_task.push_back(dev_count);
+ CHECK_EQ(col_params->group.num_tasks, dev_per_task.size());
+
+ if (col_params->instance.impl_details.subdiv_offsets.empty()) {
+ TF_RETURN_IF_ERROR(GenerateSubdivsInCollectiveParams(col_params));
+ }
+
+ // Generate a ring permutation for requested offset.
+ VLOG(2) << "Setting up perms for col_params " << col_params
+ << " subdiv_permutations "
+ << &col_params->instance.impl_details.subdiv_permutations;
+ col_params->instance.impl_details.subdiv_permutations.resize(
+ col_params->instance.impl_details.subdiv_offsets.size());
+ col_params->subdiv_rank.resize(
+ col_params->instance.impl_details.subdiv_offsets.size(), -1);
+ for (int sdi = 0;
+ sdi < col_params->instance.impl_details.subdiv_offsets.size(); ++sdi) {
+ std::vector<int>& perm =
+ col_params->instance.impl_details.subdiv_permutations[sdi];
+ CHECK_EQ(perm.size(), 0);
+ int offset = col_params->instance.impl_details.subdiv_offsets[sdi];
+ // A negative subdivision offset is interpreted as follows:
+ // 1. Reverse the local device ordering.
+ // 2. Begin the subdivision at abs(offset) in the reversed ordering.
+ bool reverse = false;
+ if (offset < 0) {
+ offset = abs(offset);
+ reverse = true;
+ }
+ int prior_dev_count = 0; // sum over prior worker device counts
+ for (int ti = 0; ti < col_params->group.num_tasks; ++ti) {
+ for (int di = 0; di < dev_per_task[ti]; ++di) {
+ int di_offset = (di + offset) % dev_per_task[ti];
+ int offset_di =
+ reverse ? (dev_per_task[ti] - (di_offset + 1)) : di_offset;
+ // Device index in global subdivision permutation.
+ int permuted_di = prior_dev_count + offset_di;
+ int rank = static_cast<int>(perm.size());
+ perm.push_back(permuted_di);
+ if (col_params->instance.device_names[permuted_di] == device_name) {
+ CHECK_EQ(permuted_di, col_params->default_rank);
+ col_params->subdiv_rank[sdi] = rank;
+ }
+ }
+ prior_dev_count += dev_per_task[ti];
+ }
+ CHECK_EQ(col_params->group.group_size, perm.size());
}
+
+ VLOG(2) << collective_util::SubdivPermDebugString(*col_params);
+ return Status::OK();
+}
+
+Status RingReducer::InitializeCollectiveContext(CollectiveContext* col_ctx) {
+ CHECK(col_ctx->dev_mgr);
+ col_ctx_ = col_ctx;
+ col_params_ = &col_ctx->col_params;
+ return collective_util::InitializeDeviceAndLocality(
+ col_ctx->dev_mgr, col_ctx->device_name, &col_ctx->device,
+ &col_ctx->device_locality);
}
void RingReducer::Run(StatusCallback done) {
+ CHECK(col_ctx_);
+ CHECK(col_params_);
done_ = std::move(done);
+ group_size_ = col_params_->group.group_size;
+ num_subdivs_ = static_cast<int>(
+ col_params_->instance.impl_details.subdiv_permutations.size());
+ CHECK_GT(num_subdivs_, 0);
- // Get local execution device.
if (VLOG_IS_ON(1)) {
string buf;
- for (int r = 0; r < col_params_.instance.device_names.size(); ++r) {
+ for (int r = 0; r < col_params_->instance.device_names.size(); ++r) {
strings::StrAppend(&buf, "dev ", r, " : ",
- col_params_.instance.device_names[r], "\n");
+ col_params_->instance.device_names[r], "\n");
}
for (int sd = 0;
- sd < col_params_.instance.impl_details.subdiv_permutations.size();
+ sd < col_params_->instance.impl_details.subdiv_permutations.size();
++sd) {
strings::StrAppend(&buf, "\nsubdiv ", sd, " perm: ");
- for (auto x : col_params_.instance.impl_details.subdiv_permutations[sd]) {
+ for (auto x :
+ col_params_->instance.impl_details.subdiv_permutations[sd]) {
strings::StrAppend(&buf, x, ", ");
}
}
- VLOG(1) << "RingReducer::Run for device " << device_name_
- << " default_rank " << col_params_.default_rank << "\n"
+ VLOG(1) << "RingReducer::Run for device " << col_ctx_->device_name
+ << " default_rank " << col_params_->default_rank << "\n"
<< buf;
}
- CHECK(dev_mgr_);
- Status status = dev_mgr_->LookupDevice(
- col_params_.instance.device_names[col_params_.default_rank], &device_);
- if (!status.ok()) {
- LOG(ERROR) << "Failed to find device "
- << col_params_.instance.device_names[col_params_.default_rank];
- for (auto d : dev_mgr_->ListDevices()) {
- LOG(ERROR) << "Available device " << d->name();
- }
- done_(status);
- return;
- }
- CHECK(device_);
- device_locality_ = device_->attributes().locality();
-
- VLOG(1) << this << " default_rank " << col_params_.default_rank << " cp "
- << &col_params_ << ": " << col_params_.ToString();
// Start by copying input to output if they're not already the same, i.e. if
// we're not computing in-place on the input tensor.
- if ((input_ != output_) &&
- (DMAHelper::base(input_) != DMAHelper::base(output_))) {
+ if ((col_ctx_->input != col_ctx_->output) &&
+ (DMAHelper::base(col_ctx_->input) != DMAHelper::base(col_ctx_->output))) {
// We are running in a blockable thread and the callback can't block so
// just wait here on the copy.
Notification note;
+ Status status;
CollectiveRemoteAccessLocal::MemCpyAsync(
- ctx_->input_device_context(0), ctx_->op_device_context(), device_,
- device_, ctx_->input_alloc_attr(0), ctx_->output_alloc_attr(0), input_,
- output_, 0 /*dev_to_dev_stream_index*/,
+ col_ctx_->op_ctx->input_device_context(0),
+ col_ctx_->op_ctx->op_device_context(), col_ctx_->device,
+ col_ctx_->device, col_ctx_->op_ctx->input_alloc_attr(0),
+ col_ctx_->op_ctx->output_alloc_attr(0), col_ctx_->input,
+ col_ctx_->output, 0 /*dev_to_dev_stream_index*/,
[this, &note, &status](const Status& s) {
status.Update(s);
note.Notify();
@@ -177,24 +303,43 @@ void RingReducer::Run(StatusCallback done) {
ContinueAfterInputCopy();
}
+string RingReducer::TensorDebugString(const Tensor& tensor) {
+ const DeviceBase::GpuDeviceInfo* gpu_device_info =
+ col_ctx_->op_ctx->device()->tensorflow_gpu_device_info();
+ if (gpu_device_info) {
+ Tensor cpu_tensor(tensor.dtype(), tensor.shape());
+ Notification note;
+ gpu_device_info->default_context->CopyDeviceTensorToCPU(
+ &tensor, "" /*tensor_name*/, col_ctx_->device, &cpu_tensor,
+ [&note](const Status& s) {
+ CHECK(s.ok());
+ note.Notify();
+ });
+ note.WaitForNotification();
+ return cpu_tensor.SummarizeValue(64);
+ } else {
+ return tensor.SummarizeValue(64);
+ }
+}
+
// Note that this function is blocking and must not run in any thread
// which cannot be blocked.
void RingReducer::ContinueAfterInputCopy() {
- AllocatorAttributes attr = ctx_->output_alloc_attr(0);
- ca_.reset(MakeCollectiveAdapter(output_, group_size_ * num_subdivs_,
- device_->GetAllocator(attr)));
+ AllocatorAttributes attr = col_ctx_->op_ctx->output_alloc_attr(0);
+ ca_.reset(MakeCollectiveAdapter(col_ctx_->output, group_size_ * num_subdivs_,
+ col_ctx_->device->GetAllocator(attr)));
- if (col_params_.final_op) {
+ if (col_params_->final_op) {
// Create an on-device scalar value from group_size_ that may be needed
// later.
// TODO(tucker): Cache and reuse across invocations? Or maybe the scalar
// can be provided to the kernel in host memory?
Tensor group_size_val = ca_->Scalar(group_size_);
- if (col_params_.group.device_type != "CPU") {
- group_size_tensor_ =
- ca_->Scalar(device_->GetAllocator(ctx_->input_alloc_attr(0)));
- DeviceContext* op_dev_ctx = ctx_->op_device_context();
- op_dev_ctx->CopyCPUTensorToDevice(&group_size_val, device_,
+ if (col_params_->group.device_type != "CPU") {
+ group_size_tensor_ = ca_->Scalar(col_ctx_->device->GetAllocator(
+ col_ctx_->op_ctx->input_alloc_attr(0)));
+ DeviceContext* op_dev_ctx = col_ctx_->op_ctx->op_device_context();
+ op_dev_ctx->CopyCPUTensorToDevice(&group_size_val, col_ctx_->device,
&group_size_tensor_,
[this](const Status& s) {
if (!s.ok()) {
@@ -231,14 +376,14 @@ void RingReducer::StartAbort(const Status& s) {
// cancellation on all of the outstanding CollectiveRemoteAccess
// actions.
if (abort_started) {
- col_exec_->StartAbort(s);
+ col_ctx_->col_exec->StartAbort(s);
}
}
void RingReducer::Finish(bool ok) {
if (ok) {
// Recover the output from the adaptor.
- ca_->ConsumeFinalValue(output_);
+ ca_->ConsumeFinalValue(col_ctx_->output);
}
Status s;
{
@@ -275,7 +420,7 @@ Status RingReducer::ComputeBinOp(Device* device, OpKernel* op, Tensor* output,
// TODO(tucker): Is it possible to cache and reuse these objects? They're
// mostly identical inside one device execution.
std::unique_ptr<SubContext> sub_ctx(
- new SubContext(ctx_, op_params_, op, output, input));
+ new SubContext(col_ctx_->op_ctx, col_ctx_->op_params, op, output, input));
device->Compute(op, sub_ctx->sub_ctx_);
return sub_ctx->sub_ctx_->status();
}
@@ -295,18 +440,18 @@ void RingReducer::InitRingField(RingField* rf, int chunk_idx, int subdiv_idx,
rf->chunk_idx = chunk_idx;
rf->subdiv_idx = subdiv_idx;
rf->sc_idx = field_idx;
- rf->rank = col_params_.subdiv_rank[subdiv_idx];
+ rf->rank = col_params_->subdiv_rank[subdiv_idx];
rf->second_pass = false;
rf->action = RF_INIT;
// Recv from the device with preceding rank within the subdivision.
int recv_from_rank = (rf->rank + (group_size_ - 1)) % group_size_;
int send_to_rank = (rf->rank + 1) % group_size_;
- rf->recv_dev_idx = col_params_.instance.impl_details
+ rf->recv_dev_idx = col_params_->instance.impl_details
.subdiv_permutations[subdiv_idx][recv_from_rank];
- int send_dev_idx = col_params_.instance.impl_details
+ int send_dev_idx = col_params_->instance.impl_details
.subdiv_permutations[subdiv_idx][send_to_rank];
- rf->recv_is_remote = !col_params_.task.is_local[rf->recv_dev_idx];
- rf->send_is_remote = !col_params_.task.is_local[send_dev_idx];
+ rf->recv_is_remote = !col_params_->task.is_local[rf->recv_dev_idx];
+ rf->send_is_remote = !col_params_->task.is_local[send_dev_idx];
if (ca_->ChunkBytes(rf->sc_idx) > 0) {
// In pass 0 we skip Recv when rank = chunk_idx
rf->do_recv = (rf->chunk_idx != rf->rank);
@@ -360,45 +505,47 @@ string RingReducer::RingField::DebugString() const {
void RingReducer::DispatchSend(RingField* rf, const StatusCallback& done) {
CHECK(rf->do_send);
- string send_buf_key =
- RingReduceBufKey(exec_key_, rf->second_pass, rf->sc_idx, rf->rank);
- VLOG(3) << "DispatchSend rank=" << col_params_.default_rank << " send key "
+ string send_buf_key = RingReduceBufKey(col_ctx_->exec_key, rf->second_pass,
+ rf->sc_idx, rf->rank);
+ VLOG(3) << "DispatchSend rank=" << col_params_->default_rank << " send key "
<< send_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " sc_idx "
<< rf->sc_idx;
int send_to_rank = (rf->rank + 1) % group_size_;
- int send_to_dev_idx = col_params_.instance.impl_details
+ int send_to_dev_idx = col_params_->instance.impl_details
.subdiv_permutations[rf->subdiv_idx][send_to_rank];
- col_exec_->PostToPeer(col_params_.instance.device_names[send_to_dev_idx],
- col_params_.instance.task_names[send_to_dev_idx],
- send_buf_key, device_, ctx_->op_device_context(),
- ctx_->output_alloc_attr(0), &rf->chunk,
- device_locality_, done);
+ col_ctx_->col_exec->PostToPeer(
+ col_params_->instance.device_names[send_to_dev_idx],
+ col_params_->instance.task_names[send_to_dev_idx], send_buf_key,
+ col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
+ col_ctx_->op_ctx->output_alloc_attr(0), &rf->chunk,
+ col_ctx_->device_locality, done);
}
void RingReducer::DispatchRecv(RingField* rf, const StatusCallback& done) {
CHECK(rf->do_recv);
string recv_buf_key =
- RingReduceBufKey(exec_key_, rf->second_pass, rf->sc_idx,
+ RingReduceBufKey(col_ctx_->exec_key, rf->second_pass, rf->sc_idx,
(rf->rank + (group_size_ - 1)) % group_size_);
- VLOG(3) << "DispatchRecv rank=" << col_params_.default_rank << " recv key "
+ VLOG(3) << "DispatchRecv rank=" << col_params_->default_rank << " recv key "
<< recv_buf_key << " chunk " << ca_->TBounds(rf->chunk) << " into "
- << ((col_params_.merge_op != nullptr) ? "tmp_chunk" : "chunk");
- Tensor* dst_tensor = (!rf->second_pass && (col_params_.merge_op != nullptr))
+ << ((col_params_->merge_op != nullptr) ? "tmp_chunk" : "chunk");
+ Tensor* dst_tensor = (!rf->second_pass && (col_params_->merge_op != nullptr))
? &rf->tmp_chunk
: &rf->chunk;
- col_exec_->RecvFromPeer(col_params_.instance.device_names[rf->recv_dev_idx],
- col_params_.instance.task_names[rf->recv_dev_idx],
- col_params_.task.is_local[rf->recv_dev_idx],
- recv_buf_key, device_, ctx_->op_device_context(),
- ctx_->output_alloc_attr(0), dst_tensor,
- device_locality_, rf->subdiv_idx, done);
+ col_ctx_->col_exec->RecvFromPeer(
+ col_params_->instance.device_names[rf->recv_dev_idx],
+ col_params_->instance.task_names[rf->recv_dev_idx],
+ col_params_->task.is_local[rf->recv_dev_idx], recv_buf_key,
+ col_ctx_->device, col_ctx_->op_ctx->op_device_context(),
+ col_ctx_->op_ctx->output_alloc_attr(0), dst_tensor,
+ col_ctx_->device_locality, rf->subdiv_idx, done);
}
string RingReducer::FieldState() {
- string s = strings::StrCat("RingReducer ",
- strings::Hex(reinterpret_cast<uint64>(this)),
- " exec ", exec_key_, " step_id=", step_id_,
- " state of all ", rfv_.size(), " fields:");
+ string s = strings::StrCat(
+ "RingReducer ", strings::Hex(reinterpret_cast<uint64>(this)), " exec ",
+ col_ctx_->exec_key, " step_id=", col_ctx_->step_id, " state of all ",
+ rfv_.size(), " fields:");
for (int i = 0; i < rfv_.size(); ++i) {
s.append("\n");
s.append(rfv_[i].DebugString());
@@ -415,13 +562,6 @@ bool RingReducer::RunAsyncParts() {
rfv_.clear();
rfv_.resize(group_size_ * num_subdivs_);
PCQueue ready_queue;
- int field_done_count = 0;
- int send_pending_count = 0;
- int recv_pending_count = 0;
- std::atomic<bool> aborted(false);
- field_done_count = 0;
- send_pending_count = 0;
- recv_pending_count = 0;
for (int chunk_idx = 0; chunk_idx < group_size_; ++chunk_idx) {
for (int subdiv_idx = 0; subdiv_idx < num_subdivs_; ++subdiv_idx) {
int rf_index = (chunk_idx * num_subdivs_) + subdiv_idx;
@@ -429,6 +569,30 @@ bool RingReducer::RunAsyncParts() {
ready_queue.Enqueue(&rfv_[rf_index]);
}
}
+ const DeviceBase::GpuDeviceInfo* gpu_info =
+ col_ctx_->device->tensorflow_gpu_device_info();
+ if (gpu_info) {
+ // Wait for all currently queued events on the CPU compute stream to
+ // complete before proceeding. The previous InitRingField calls allocated
+ // temp memory buffers that are not guaranteed to be valid (e.g. for RDMA
+ // write) unless we do.
+ Notification note;
+ Status s = gpu_info->default_context->ThenExecute(
+ col_ctx_->device, gpu_info->stream, [&note]() { note.Notify(); });
+ if (s.ok()) {
+ note.WaitForNotification();
+ } else {
+ mutex_lock l(status_mu_);
+ status_ =
+ errors::Internal("Failed to dispatch ThenExecute in RingReducer");
+ return false;
+ }
+ }
+
+ int field_done_count = 0;
+ int send_pending_count = 0;
+ int recv_pending_count = 0;
+ std::atomic<bool> aborted(false);
// Loop until all RingFields have advanced to completion.
while (field_done_count < rfv_.size()) {
@@ -468,8 +632,9 @@ bool RingReducer::RunAsyncParts() {
--recv_pending_count;
if (!rf->second_pass) {
rf->action = RF_REDUCE;
- Status s = ComputeBinOp(device_, col_params_.merge_op.get(),
- &rf->chunk, &rf->tmp_chunk);
+ Status s =
+ ComputeBinOp(col_ctx_->device, col_params_->merge_op.get(),
+ &rf->chunk, &rf->tmp_chunk);
if (!s.ok()) {
aborted = true;
StartAbort(s);
@@ -479,11 +644,12 @@ bool RingReducer::RunAsyncParts() {
}
break;
case RF_REDUCE:
- if (!rf->second_pass && col_params_.final_op.get() && rf->is_final) {
+ if (!rf->second_pass && col_params_->final_op.get() && rf->is_final) {
rf->action = RF_FINALIZE;
group_size_tensor_ready_.WaitForNotification();
- Status s = ComputeBinOp(device_, col_params_.final_op.get(),
- &rf->chunk, &group_size_tensor_);
+ Status s =
+ ComputeBinOp(col_ctx_->device, col_params_->final_op.get(),
+ &rf->chunk, &group_size_tensor_);
if (!s.ok()) {
aborted = true;
StartAbort(s);
@@ -544,7 +710,8 @@ bool RingReducer::RunAsyncParts() {
case RF_SEND:
--send_pending_count;
break;
- default: {} // Ignore any other actions
+ default: {
+ } // Ignore any other actions
}
}
}
@@ -552,9 +719,11 @@ bool RingReducer::RunAsyncParts() {
CHECK_EQ(send_pending_count, 0);
CHECK_EQ(recv_pending_count, 0);
- VLOG(2) << this << " rank=" << rank_ << " finish;"
+ VLOG(2) << this << " device=" << col_ctx_->device_name << " finish;"
<< " final value " << TensorDebugString(ca_->Value());
return !aborted;
}
+REGISTER_COLLECTIVE(RingReduce, RingReducer);
+
} // namespace tensorflow
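A note on the subdivision inference added above: GenerateSubdivsInCollectiveParams picks the smallest number of subdivisions whose per-chunk size stays at or below kMaxChunkSizeBytes (4 MiB), capped at kMaxSubdivsPerDevice times the average devices per task, then alternates positive and negative offsets with stride kAvgDevPerTask / num_subdivs. The standalone sketch below reproduces only the sizing loop; the function name InferNumSubdivs and the main() numbers are illustrative, chosen to match the AutomaticSubdivs test further down (a 144 MiB float tensor over 24 devices yields 2 subdivisions and offsets {0, -4}).

// Distilled from GenerateSubdivsInCollectiveParams above; standalone sketch,
// not the TensorFlow API.
#include <cstddef>
#include <cstdio>

int InferNumSubdivs(size_t tensor_bytes, int group_size, int num_tasks,
                    int max_subdivs_per_device = 2,
                    size_t max_chunk_bytes = 4 * 1024 * 1024) {
  const int avg_dev_per_task = group_size / num_tasks;
  const int max_subdivs = max_subdivs_per_device * avg_dev_per_task;
  int num_subdivs = 0;
  size_t chunk_size = 0;
  do {
    ++num_subdivs;
    chunk_size = tensor_bytes / (group_size * num_subdivs);
  } while (chunk_size > max_chunk_bytes && num_subdivs < max_subdivs);
  return num_subdivs;
}

int main() {
  // 144 MiB over 3 tasks x 8 devices: 1 subdiv -> 6 MiB chunks (> 4 MiB),
  // 2 subdivs -> 3 MiB chunks (<= 4 MiB), so 2 subdivisions are chosen.
  std::printf("%d\n", InferNumSubdivs(size_t(144) << 20, /*group_size=*/24,
                                      /*num_tasks=*/3));  // prints 2
  return 0;
}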
diff --git a/tensorflow/core/common_runtime/ring_reducer.h b/tensorflow/core/common_runtime/ring_reducer.h
index 3e1988e787..0848e37b52 100644
--- a/tensorflow/core/common_runtime/ring_reducer.h
+++ b/tensorflow/core/common_runtime/ring_reducer.h
@@ -16,25 +16,35 @@ limitations under the License.
#define TENSORFLOW_CORE_COMMON_RUNTIME_RING_REDUCER_H_
#include <deque>
+#include <memory>
+#include <string>
+#include <vector>
#include "tensorflow/core/common_runtime/base_collective_executor.h"
#include "tensorflow/core/framework/collective.h"
-#include "tensorflow/core/framework/device_attributes.pb.h"
namespace tensorflow {
-class DeviceMgr;
+class Device;
// Ring-algorithm implementation of collective all-reduce.
-class RingReducer {
+class RingReducer : public CollectiveImplementationInterface {
public:
- RingReducer(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
- OpKernelContext* ctx, OpKernelContext::Params* op_params,
- const CollectiveParams& col_params, const string& exec_key,
- int64 step_id, const Tensor* input, Tensor* output);
+ RingReducer();
+ ~RingReducer() override;
- virtual ~RingReducer();
+ // Establishes the requested number of subdivision permutations based on the
+ // ring order implicit in the device order.
+ Status InitializeCollectiveParams(CollectiveParams* col_params) override;
- void Run(StatusCallback done);
+ // Initializes members of CollectiveContext not yet initialized, i.e. device
+ // and device_locality. Also saves the CollectiveContext in this object.
+ Status InitializeCollectiveContext(CollectiveContext* col_ctx) override;
+
+ // Begins async execution of the ring reduce algorithm.
+ // Must be called in a blockable thread.
+ // TODO(b/80529858): remove the previous warning when we have a dedicated
+ // collective threadpool.
+ void Run(StatusCallback done) override;
private:
// Called when a bad status is received that implies we should terminate
@@ -101,7 +111,7 @@ class RingReducer {
// For constructing log messages for debugging.
string FieldState();
- string TensorDebugString(Tensor tensor);
+ string TensorDebugString(const Tensor& tensor);
// Producer/Consumer Queue of RingField structs.
class PCQueue {
@@ -116,30 +126,19 @@ class RingReducer {
std::deque<RingField*> deque_ GUARDED_BY(pcq_mu_);
};
- CollectiveExecutor* col_exec_; // Not owned
- const DeviceMgr* dev_mgr_; // Not owned
- OpKernelContext* ctx_; // Not owned
- OpKernelContext::Params* op_params_; // Not owned
- const CollectiveParams& col_params_;
- const string exec_key_;
- const Tensor* input_; // Not owned
- Tensor* output_; // Not owned
- const int rank_;
- const int64 step_id_;
- const int group_size_;
- const int num_subdivs_;
+ CollectiveContext* col_ctx_; // Not owned
+ const CollectiveParams* col_params_; // Not owned
+ StatusCallback done_;
+ int group_size_;
+ int num_subdivs_;
Tensor group_size_tensor_;
Notification group_size_tensor_ready_;
std::unique_ptr<CollectiveAdapter> ca_;
- StatusCallback done_;
- Device* device_; // The device for which this instance labors
- const string device_name_;
- DeviceLocality device_locality_;
-
mutex status_mu_;
Status status_ GUARDED_BY(status_mu_);
-
std::vector<RingField> rfv_;
+
+ friend class RingReducerTest;
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/ring_reducer_test.cc b/tensorflow/core/common_runtime/ring_reducer_test.cc
index fcdf9deff8..75aba43572 100644
--- a/tensorflow/core/common_runtime/ring_reducer_test.cc
+++ b/tensorflow/core/common_runtime/ring_reducer_test.cc
@@ -37,7 +37,6 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
namespace tensorflow {
-namespace {
// Wraps CollectiveRemoteAccessLocal with the ability to return an
// error status to the N'th action.
@@ -135,27 +134,28 @@ class RingReducerTest : public ::testing::Test {
protected:
RingReducerTest() : device_type_(DEVICE_CPU) {}
- void SetUp() override {
-#if GOOGLE_CUDA
+#ifdef GOOGLE_CUDA
+ void InitGPUDevices() {
auto device_factory = DeviceFactory::GetFactory("GPU");
CHECK(device_factory);
SessionOptions options;
Status s = device_factory->CreateDevices(
options, "/job:worker/replica:0/task:0", &gpu_devices_);
CHECK(s.ok());
-#endif
}
+#endif
~RingReducerTest() override {
stop_ = true;
- for (auto i : instances_) {
- delete i;
- }
+ for (auto i : instances_) delete i;
if (col_exec_) col_exec_->Unref();
}
void Init(int num_workers, int num_devices, DataType dtype,
const DeviceType& device_type, int num_subdivs, int fail_after) {
+#ifdef GOOGLE_CUDA
+ InitGPUDevices();
+#endif
device_type_ = device_type;
std::vector<Device*> local_devices;
SessionOptions sess_opts;
@@ -201,6 +201,7 @@ class RingReducerTest : public ::testing::Test {
col_params_.instance.instance_key = kInstanceKey;
col_params_.instance.impl_details.subdiv_offsets.clear();
col_params_.instance.type = REDUCTION_COLLECTIVE;
+ col_params_.instance.impl_details.collective_name = "RingReduce";
col_params_.instance.data_type = dtype;
col_params_.instance.impl_details.subdiv_permutations.resize(num_subdivs);
col_params_.subdiv_rank.resize(num_subdivs);
@@ -259,13 +260,17 @@ class RingReducerTest : public ::testing::Test {
}
}
- void Reduce() {
+ void Reduce(int fail_after) {
std::atomic<int> done(0);
for (auto di : instances_) {
SchedClosure([di, &done] {
di->DoReduce();
++done;
});
+ if (fail_after > 0) {
+ // Stagger the op execution starts.
+ Env::Default()->SleepForMicroseconds(100);
+ }
}
while (done < static_cast<int>(instances_.size())) {
if (stop_) break;
@@ -295,7 +300,7 @@ class RingReducerTest : public ::testing::Test {
}
});
}
- Reduce();
+ Reduce(fail_after);
if (fail_after > 0) {
// Confirm that every device terminated with the expected error status.
for (int di = 0; di < static_cast<int>(instances_.size()); ++di) {
@@ -373,6 +378,22 @@ class RingReducerTest : public ::testing::Test {
return GetKernel(node_def, device_type, device);
}
+ void RunSubdivPermsTest(
+ CollectiveParams* cp,
+ const std::vector<std::vector<int>>& expected_subdiv_perms,
+ const std::vector<int>& expected_subdiv_rank) {
+ col_exec_ = nullptr;
+ cp->instance.impl_details.subdiv_permutations.clear();
+ cp->subdiv_rank.clear();
+ // Create a stub ring reducer only for testing param initialization.
+ RingReducer reducer;
+ TF_CHECK_OK(reducer.InitializeCollectiveParams(cp));
+ EXPECT_EQ(expected_subdiv_perms,
+ cp->instance.impl_details.subdiv_permutations);
+ EXPECT_EQ(expected_subdiv_rank, cp->subdiv_rank);
+ reducer.group_size_tensor_ready_.Notify(); // To unblock destructor.
+ }
+
class DeviceInstance {
public:
DeviceInstance(int rank, const string& dev_name,
@@ -475,8 +496,8 @@ class RingReducerTest : public ::testing::Test {
op_params.op_kernel = op.get();
OpKernelContext ctx(&op_params, 1);
- // We never actually execute the kernel, so we need to do the
- // output allocation that it would do, ourselves.
+ // We never actually execute the kernel, so we need to do the output
+ // allocation it would do, ourselves.
Tensor* output_tensor_ptr = nullptr;
TF_CHECK_OK(ctx.forward_input_or_allocate_output({0}, 0, tensor_.shape(),
&output_tensor_ptr));
@@ -485,20 +506,17 @@ class RingReducerTest : public ::testing::Test {
// Prepare a RingReducer instance.
string exec_key =
strings::StrCat(col_params_.instance.instance_key, ":0:0");
- RingReducer rr(parent_->col_exec_, parent_->dev_mgr_.get(), &ctx,
- &op_params, col_params_, exec_key, kStepId, &tensor_,
- &tensor_);
-
- // Start execution in a threadpool then wait for completion.
- Notification notification;
- SchedClosure([this, &notification, &rr]() {
- rr.Run([this, &notification](Status s) {
- status_ = s;
- notification.Notify();
- });
- });
- notification.WaitForNotification();
- CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
+ RingReducer reducer;
+ CollectiveContext col_ctx(parent_->col_exec_, parent_->dev_mgr_.get(),
+ &ctx, &op_params, col_params_, exec_key,
+ kStepId, &tensor_, &tensor_);
+ TF_CHECK_OK(reducer.InitializeCollectiveContext(&col_ctx));
+
+ // Run the all-reduce.
+ reducer.Run([this](Status s) { status_ = s; });
+ if (status_.ok()) {
+ CHECK(tensor_.CopyFrom(*ctx.mutable_output(0), tensor_.shape()));
+ }
dev_ctx->Unref();
}
@@ -531,6 +549,108 @@ class RingReducerTest : public ::testing::Test {
int32 reduce_counter_ GUARDED_BY(mu_) = 0;
};
+CollectiveParams SetUpCollectiveParams(const int num_devs_per_task,
+ const int num_tasks) {
+ CollectiveParams cp;
+ const int kNumDevs = num_devs_per_task * num_tasks;
+ cp.group.group_key = 1;
+ cp.group.group_size = kNumDevs;
+ cp.group.device_type = DeviceType("GPU");
+ cp.group.num_tasks = num_tasks;
+ cp.instance.instance_key = 3;
+ cp.instance.type = REDUCTION_COLLECTIVE;
+ cp.instance.data_type = DataType(DT_FLOAT);
+ cp.instance.shape = TensorShape({kNumDevs});
+ cp.instance.impl_details.collective_name = "RingReduce";
+ cp.instance.impl_details.subdiv_offsets.push_back(0);
+ cp.is_source = false;
+ for (int i = 0; i < kNumDevs; ++i) {
+ int task_id = i / num_devs_per_task;
+ int dev_id = i % num_devs_per_task;
+ string task_name = strings::StrCat("/job:worker/replica:0/task:", task_id);
+ string device_name = strings::StrCat(task_name, "/device:GPU:", dev_id);
+ cp.instance.task_names.push_back(task_name);
+ cp.instance.device_names.push_back(device_name);
+ }
+ return cp;
+}
+
+TEST_F(RingReducerTest, InitializeParams) {
+ const int kNumDevsPerTask = 8;
+ const int kNumTasks = 3;
+ CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
+
+ cp.default_rank = 0;
+ cp.instance.impl_details.subdiv_offsets = {0, 4};
+ RunSubdivPermsTest(&cp,
+ {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+ {4, 5, 6, 7, 0, 1, 2, 3, 12, 13, 14, 15,
+ 8, 9, 10, 11, 20, 21, 22, 23, 16, 17, 18, 19}},
+ {0, 4});
+
+ cp.instance.impl_details.subdiv_offsets = {0, -4};
+ RunSubdivPermsTest(&cp,
+ {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+ {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8,
+ 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
+ {0, 3});
+
+ cp.default_rank = 3;
+ cp.instance.impl_details.subdiv_offsets = {3, -3};
+ RunSubdivPermsTest(&cp,
+ {{3, 4, 5, 6, 7, 0, 1, 2, 11, 12, 13, 14,
+ 15, 8, 9, 10, 19, 20, 21, 22, 23, 16, 17, 18},
+ {4, 3, 2, 1, 0, 7, 6, 5, 12, 11, 10, 9,
+ 8, 15, 14, 13, 20, 19, 18, 17, 16, 23, 22, 21}},
+ {0, 1});
+}
+
+TEST_F(RingReducerTest, AutomaticSubdivs) {
+ const int kNumDevsPerTask = 8;
+ const int kNumTasks = 3;
+ const int kNumDevs = kNumDevsPerTask * kNumTasks;
+ CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
+
+ // Test automatic generation of subdiv offsets.
+ cp.default_rank = 0;
+ cp.instance.impl_details.subdiv_offsets.clear();
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}},
+ {0});
+
+ // Set shape so that with 2 subdivs chunk_size is 3 MiB. This should cause 2
+ // offsets, {0, -4}, to be generated.
+ {
+ int num_subdivs = 2;
+ int num_chunks = kNumDevs * num_subdivs;
+ size_t chunk_size = 3 * 1048576; // 3 MiB
+ size_t tensor_size = chunk_size * num_chunks;
+ cp.instance.shape =
+ TensorShape({static_cast<int64>(tensor_size / DataTypeSize(DT_FLOAT))});
+ }
+ cp.instance.impl_details.subdiv_offsets.clear();
+ RunSubdivPermsTest(&cp,
+ {{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23},
+ {3, 2, 1, 0, 7, 6, 5, 4, 11, 10, 9, 8,
+ 15, 14, 13, 12, 19, 18, 17, 16, 23, 22, 21, 20}},
+ {0, 3});
+}
+
+TEST_F(RingReducerTest, AutomaticSubdivUpperBound) {
+ const int kNumDevsPerTask = 1;
+ const int kNumTasks = 4;
+ CollectiveParams cp = SetUpCollectiveParams(kNumDevsPerTask, kNumTasks);
+
+ cp.default_rank = 0;
+ cp.instance.impl_details.subdiv_offsets.clear();
+ cp.instance.shape = TensorShape({104857600 / DataTypeSize(DT_FLOAT)});
+ RunSubdivPermsTest(&cp, {{0, 1, 2, 3}, {0, 1, 2, 3}}, {0, 0});
+}
+
+// TODO(b/113171733): change to use TEST_P.
#define DEF_TEST(B, T, W, D, S, L, A) \
TEST_F(RingReducerTest, \
DaTy##B##_DevTy##T##_Wkr##W##_Dev##D##_Sdiv##S##_Len##L##_Abrt##A) { \
@@ -575,6 +695,7 @@ DEF_TEST(INT64, CPU, 1, 2, 1, 1001, 0)
DEF_TEST(INT64, CPU, 2, 8, 3, 4095, 0)
// Failure tests
+DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 1)
DEF_TEST(FLOAT, CPU, 2, 8, 1, 9408, 7)
DEF_TEST(FLOAT, CPU, 2, 8, 2, 9408, 11)
#endif
@@ -604,5 +725,4 @@ DEF_TEST(FLOAT, GPU, 1, 8, 1, 9408, 2)
DEF_TEST(FLOAT, GPU, 1, 8, 2, 9408, 5)
#endif
-} // namespace
} // namespace tensorflow
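The expected permutations in InitializeParams and AutomaticSubdivs above follow directly from the negative-offset rule in RingReducer::InitializeCollectiveParams: reverse the per-task device ordering, then start the ring at abs(offset); the global index just adds the per-task base (prior_dev_count). Below is a standalone arithmetic check, not TensorFlow code, for offset -4 with 8 devices per task (the second subdivision in the {0, -4} cases), reproducing the 3 2 1 0 7 6 5 4 prefix of the expected permutation.

// Standalone check of the negative-offset rule used above; not TF code.
#include <cstdio>
#include <cstdlib>

int main() {
  const int dev_per_task = 8;
  const int offset_in = -4;              // as in subdiv_offsets = {0, -4}
  const bool reverse = offset_in < 0;    // negative offset => reversed ordering
  const int offset = std::abs(offset_in);
  for (int di = 0; di < dev_per_task; ++di) {
    int di_offset = (di + offset) % dev_per_task;
    int offset_di = reverse ? dev_per_task - (di_offset + 1) : di_offset;
    std::printf("%d ", offset_di);       // prints: 3 2 1 0 7 6 5 4
  }
  std::printf("\n");
  return 0;
}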
diff --git a/tensorflow/core/common_runtime/session_factory.h b/tensorflow/core/common_runtime/session_factory.h
index 81c172c6ae..8565088afc 100644
--- a/tensorflow/core/common_runtime/session_factory.h
+++ b/tensorflow/core/common_runtime/session_factory.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_
-#define TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_FACTORY_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_FACTORY_H_
#include <string>
@@ -73,4 +73,4 @@ class SessionFactory {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_SESSION_FACTORY_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_FACTORY_H_
diff --git a/tensorflow/core/common_runtime/session_ref.cc b/tensorflow/core/common_runtime/session_ref.cc
deleted file mode 100644
index b931ef4229..0000000000
--- a/tensorflow/core/common_runtime/session_ref.cc
+++ /dev/null
@@ -1,170 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/common_runtime/session_ref.h"
-
-#include <utility>
-
-namespace tensorflow {
-
-namespace {
-
-// Scope helper to track active calls and manage session lifetime.
-struct RunCounter {
- std::shared_ptr<Session> session;
- uint64* value;
- mutex* m;
- condition_variable* cv;
-
- explicit RunCounter(std::shared_ptr<Session> s, uint64* v, mutex* m,
- condition_variable* cv)
- : session(std::move(s)), value(v), m(m), cv(cv) {
- mutex_lock l(*m);
- ++*value;
- }
-
- ~RunCounter() {
- mutex_lock l(*m);
- if (--*value == 0) {
- cv->notify_all();
- }
- }
-};
-
-} // namespace
-
-Status SessionRef::CheckNotClosed() {
- mutex_lock l(run_lock_);
- if (session_ == nullptr) return errors::Cancelled("Session has been closed.");
- return ::tensorflow::Status::OK();
-}
-
-Status SessionRef::Run(const RunOptions& run_options,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs,
- RunMetadata* run_metadata) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Run(run_options, inputs, output_tensor_names,
- target_node_names, outputs, run_metadata);
-}
-
-Status SessionRef::Create(const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Create(graph);
-}
-
-Status SessionRef::Create(const RunOptions& run_options,
- const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Create(run_options, graph);
-}
-
-Status SessionRef::Extend(const RunOptions& run_options,
- const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Extend(run_options, graph);
-}
-
-Status SessionRef::Extend(const GraphDef& graph) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Extend(graph);
-}
-
-Status SessionRef::Close(const RunOptions& run_options) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(run_lock_);
- Status status = session_->Close(run_options);
- session_.reset();
- while (run_count_ > 0) {
- run_finished_.wait(l);
- }
- return status;
-}
-
-Status SessionRef::Close() {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- mutex_lock l(run_lock_);
- Status status = session_->Close();
- session_.reset();
- while (run_count_ > 0) {
- run_finished_.wait(l);
- }
- return status;
-}
-
-Status SessionRef::Run(const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->Run(inputs, output_tensor_names, target_node_names,
- outputs);
-}
-
-Status SessionRef::ListDevices(std::vector<DeviceAttributes>* response) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->ListDevices(response);
-}
-
-Status SessionRef::PRunSetup(const std::vector<string>& input_names,
- const std::vector<string>& output_names,
- const std::vector<string>& target_nodes,
- string* handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->PRunSetup(input_names, output_names, target_nodes, handle);
-}
-
-Status SessionRef::PRun(const string& handle,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_names,
- std::vector<Tensor>* outputs) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->PRun(handle, inputs, output_names, outputs);
-}
-
-Status SessionRef::MakeCallable(const CallableOptions& callable_options,
- CallableHandle* out_handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->MakeCallable(callable_options, out_handle);
-}
-
-Status SessionRef::RunCallable(CallableHandle handle,
- const std::vector<Tensor>& feed_tensors,
- std::vector<Tensor>* fetch_tensors,
- RunMetadata* run_metadata) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->RunCallable(handle, feed_tensors, fetch_tensors,
- run_metadata);
-}
-
-Status SessionRef::ReleaseCallable(CallableHandle handle) {
- TF_RETURN_IF_ERROR(CheckNotClosed());
- RunCounter rc(session_, &run_count_, &run_lock_, &run_finished_);
- return rc.session->ReleaseCallable(handle);
-}
-
-} // namespace tensorflow
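For reference, the deleted SessionRef relied on a simple drain-on-close scheme: every forwarded call held an RAII counter guard, and Close() waited on a condition variable until the count returned to zero. Below is a generic, standalone sketch of that pattern; the names CallTracker and Scoped are invented for the example and this is not the removed TensorFlow class.

// Generic drain-on-close sketch, assuming nothing beyond the standard library.
#include <condition_variable>
#include <cstdint>
#include <mutex>

class CallTracker {
 public:
  // RAII guard: one live Scoped per in-flight call.
  class Scoped {
   public:
    explicit Scoped(CallTracker* t) : t_(t) {
      std::lock_guard<std::mutex> l(t_->mu_);
      ++t_->count_;
    }
    ~Scoped() {
      std::lock_guard<std::mutex> l(t_->mu_);
      if (--t_->count_ == 0) t_->cv_.notify_all();
    }

   private:
    CallTracker* t_;
  };

  // Blocks until every outstanding Scoped guard has been destroyed.
  void Drain() {
    std::unique_lock<std::mutex> l(mu_);
    cv_.wait(l, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  uint64_t count_ = 0;
};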
diff --git a/tensorflow/core/common_runtime/session_ref.h b/tensorflow/core/common_runtime/session_ref.h
deleted file mode 100644
index 9459e7edbe..0000000000
--- a/tensorflow/core/common_runtime/session_ref.h
+++ /dev/null
@@ -1,86 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
-#define TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
-
-#include <memory>
-
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/public/session.h"
-
-namespace tensorflow {
-
-// A `SessionRef` manages the lifetime of a wrapped `Session` pointer.
-//
-// SessionRef blocks the return of Close() until all pending operations have
-// been completed or cancelled and underlying session has been freed. Any
-// subsequent operations on the SessionRef object will return errors::Cancelled.
-class SessionRef : public Session {
- public:
- SessionRef(Session* session) : session_(session) {}
- virtual ~SessionRef() {}
-
- Status Create(const GraphDef& graph) override;
- Status Extend(const GraphDef& graph) override;
- Status Create(const RunOptions& run_options, const GraphDef& graph) override;
- Status Extend(const RunOptions& run_options, const GraphDef& graph) override;
- Status Run(const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs) override;
-
- Status ListDevices(std::vector<DeviceAttributes>* response) override;
-
- Status Close() override;
- Status Close(const RunOptions& run_options) override;
-
- Status Run(const RunOptions& run_options,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_tensor_names,
- const std::vector<string>& target_node_names,
- std::vector<Tensor>* outputs, RunMetadata* run_metadata) override;
-
- Status PRunSetup(const std::vector<string>& input_names,
- const std::vector<string>& output_names,
- const std::vector<string>& target_nodes,
- string* handle) override;
-
- Status PRun(const string& handle,
- const std::vector<std::pair<string, Tensor> >& inputs,
- const std::vector<string>& output_names,
- std::vector<Tensor>* outputs) override;
-
- Status MakeCallable(const CallableOptions& callable_options,
- CallableHandle* out_handle) override;
-
- Status RunCallable(CallableHandle handle,
- const std::vector<Tensor>& feed_tensors,
- std::vector<Tensor>* fetch_tensors,
- RunMetadata* run_metadata) override;
-
- Status ReleaseCallable(CallableHandle handle) override;
-
- private:
- mutex run_lock_;
- condition_variable run_finished_;
- uint64 run_count_ GUARDED_BY(run_lock_) = {0};
- std::shared_ptr<Session> session_;
-
- Status CheckNotClosed();
-};
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SESSION_REF_H_
diff --git a/tensorflow/core/common_runtime/session_state.cc b/tensorflow/core/common_runtime/session_state.cc
index 65ff356e73..5b1915755d 100644
--- a/tensorflow/core/common_runtime/session_state.cc
+++ b/tensorflow/core/common_runtime/session_state.cc
@@ -70,7 +70,7 @@ Status TensorStore::SaveTensors(const std::vector<string>& output_names,
// Save only the tensors in output_names in the session.
for (const string& name : output_names) {
TensorId id(ParseTensorName(name));
- const string& op_name = std::string(id.first);
+ const string op_name(id.first);
auto it = tensors_.find(op_name);
if (it != tensors_.end()) {
// Save the tensor to the session state.
diff --git a/tensorflow/core/common_runtime/single_threaded_cpu_device.h b/tensorflow/core/common_runtime/single_threaded_cpu_device.h
index 04d5af9087..22650b0d83 100644
--- a/tensorflow/core/common_runtime/single_threaded_cpu_device.h
+++ b/tensorflow/core/common_runtime/single_threaded_cpu_device.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
namespace tensorflow {
diff --git a/tensorflow/core/common_runtime/step_stats_collector.cc b/tensorflow/core/common_runtime/step_stats_collector.cc
index af6880c6b3..a70ab93d4a 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.cc
+++ b/tensorflow/core/common_runtime/step_stats_collector.cc
@@ -16,13 +16,18 @@ limitations under the License.
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/common_runtime/costmodel_manager.h"
#include "tensorflow/core/framework/allocation_description.pb.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tracking_allocator.h"
#include "tensorflow/core/graph/costmodel.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
@@ -36,10 +41,131 @@ struct AllocStats {
};
} // namespace
-NodeExecStatsWrapper::NodeExecStatsWrapper()
- : NodeExecStatsWrapper(new NodeExecStats) {}
-NodeExecStatsWrapper::NodeExecStatsWrapper(NodeExecStats* stats)
- : stats_(stats) {}
+NodeExecStatsWrapper::NodeExecStatsWrapper(
+ const Node* node, StepStatsCollector* step_stats_collector)
+ : NodeExecStatsWrapper(MakeUnique<NodeExecStats>(), node,
+ step_stats_collector) {
+ stats_->set_node_name(node->name());
+}
+
+NodeExecStatsWrapper::NodeExecStatsWrapper(
+ std::unique_ptr<NodeExecStats> stats, const Node* node,
+ StepStatsCollector* step_stats_collector)
+ : stats_(std::move(stats)),
+ node_(node),
+ step_stats_collector_(step_stats_collector) {}
+
+void NodeExecStatsWrapper::Done(const string& device) {
+ // TODO(tucker): merge with the DetailText function in session.cc in a common
+ // location.
+ DCHECK(node_);
+ string memory;
+ for (auto& all : stats_->memory()) {
+ int64 tot = all.total_bytes();
+ if (tot >= 0.1 * 1048576.0) {
+ int64 peak = all.peak_bytes();
+ if (peak > 0) {
+ memory =
+ strings::StrCat(memory, "[", all.allocator_name(),
+ strings::Printf(" %.1fMB %.1fMB] ", tot / 1048576.0,
+ peak / 1048576.0));
+ } else {
+ memory = strings::StrCat(memory, "[", all.allocator_name(),
+ strings::Printf(" %.1fMB] ", tot / 1048576.0));
+ }
+ }
+ }
+ const AttrSlice attrs = node_->attrs();
+ string text;
+ if (IsSend(node_)) {
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
+ string recv_device;
+ TF_CHECK_OK(GetNodeAttr(attrs, "recv_device", &recv_device));
+ text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(),
+ "(", tensor_name, " @", recv_device);
+ } else if (IsRecv(node_)) {
+ string tensor_name;
+ TF_CHECK_OK(GetNodeAttr(attrs, "tensor_name", &tensor_name));
+ string send_device;
+ TF_CHECK_OK(GetNodeAttr(attrs, "send_device", &send_device));
+ text = strings::StrCat(memory, node_->name(), " = ", node_->type_string(),
+ "(", tensor_name, " @", send_device);
+ } else {
+ text =
+ strings::StrCat(memory, node_->name(), " = ", node_->type_string(), "(",
+ str_util::Join(node_->requested_inputs(), ", "), ")");
+ }
+ stats_->set_timeline_label(text);
+ step_stats_collector_->Save(device, this);
+}
+
+void NodeExecStatsWrapper::RecordExecutorStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ stats_->set_all_start_micros(now_nanos / EnvTime::kMicrosToNanos);
+ stats_->set_all_start_nanos(now_nanos);
+}
+
+void NodeExecStatsWrapper::RecordComputeStarted() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_start_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_start_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::RecordComputeEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_op_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_op_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::RecordExecutorEnded() {
+ int64 now_nanos = Env::Default()->NowNanos();
+ DCHECK_NE(stats_->all_start_micros(), 0);
+ DCHECK_NE(stats_->all_start_nanos(), 0);
+ stats_->set_all_end_rel_micros(now_nanos / EnvTime::kMicrosToNanos -
+ stats_->all_start_micros());
+ stats_->set_all_end_rel_nanos(now_nanos - stats_->all_start_nanos());
+}
+
+void NodeExecStatsWrapper::SetScheduled(int64 nanos) {
+ stats_->set_scheduled_micros(nanos / EnvTime::kMicrosToNanos);
+ stats_->set_scheduled_nanos(nanos);
+}
+
+void NodeExecStatsWrapper::SetMemory(OpKernelContext* ctx) {
+ for (const auto& allocator_pair : ctx->wrapped_allocators()) {
+ AddAllocation(allocator_pair.first, allocator_pair.second);
+ }
+ auto* ms = stats_->mutable_memory_stats();
+ ms->set_temp_memory_size(ctx->temp_memory_allocated());
+ for (const auto& alloc_id : ctx->persistent_alloc_ids()) {
+ ms->mutable_persistent_tensor_alloc_ids()->Add(alloc_id);
+ }
+ ms->set_persistent_memory_size(ctx->persistent_memory_allocated());
+}
+
+void NodeExecStatsWrapper::SetOutput(int slot, const Tensor* tensor) {
+ DCHECK(tensor);
+ NodeOutput* node_output = stats_->add_output();
+ node_output->set_slot(slot);
+ tensor->FillDescription(node_output->mutable_tensor_description());
+}
+
+void NodeExecStatsWrapper::SetReferencedTensors(
+ const TensorReferenceVector& tensors) {
+ // Be careful not to increment the reference count on any tensor
+ // while recording the information.
+ for (size_t i = 0; i < tensors.size(); ++i) {
+ AllocationDescription* description = stats_->add_referenced_tensor();
+ tensors.at(i).FillDescription(description);
+ }
+}
void NodeExecStatsWrapper::AddAllocation(
Allocator* allocator, TrackingAllocator* tracking_allocator) {
@@ -68,8 +194,8 @@ void NodeExecStatsWrapper::Finalize() {
allocations_.clear();
}
-StepStatsCollector::StepStatsCollector(StepStats* ss)
- : finalized_(false), step_stats_(ss) {}
+StepStatsCollector::StepStatsCollector(StepStats* step_stats)
+ : finalized_(false), step_stats_(step_stats) {}
static int ExtractGpuWithStreamAll(string device_name) {
// Check if the device name matches the ".*gpu:(\\d+)/stream:all$" regexp,
@@ -94,7 +220,7 @@ static int ExtractGpuWithStreamAll(string device_name) {
} else {
// Convert the captured string into an integer. But first we need to put
// the digits back in order
- string ordered_capture = std::string(capture);
+ string ordered_capture(capture);
std::reverse(ordered_capture.begin(), ordered_capture.end());
int gpu_id;
CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
@@ -123,7 +249,7 @@ static int ExtractGpuWithoutStream(string device_name) {
} else {
// Convert the captured string into an integer. But first we need to put
// the digits back in order
- string ordered_capture = std::string(capture);
+ string ordered_capture(capture);
std::reverse(ordered_capture.begin(), ordered_capture.end());
int gpu_id;
CHECK(strings::safe_strto32(ordered_capture, &gpu_id));
@@ -170,7 +296,7 @@ void StepStatsCollector::BuildCostModel(
for (auto& itr : per_device_stats) {
const StringPiece device_name = itr.first;
- const int gpu_id = ExtractGpuWithoutStream(std::string(device_name));
+ const int gpu_id = ExtractGpuWithoutStream(string(device_name));
if (gpu_id >= 0) {
// Reference the gpu hardware stats in addition to the regular stats
// for this gpu device if they're available.
@@ -256,28 +382,40 @@ void StepStatsCollector::BuildCostModel(
}
}
-void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
- Save(device, new NodeExecStatsWrapper(nt));
+void StepStatsCollector::Save(const string& device,
+ NodeExecStats* node_stats_pb) {
+ Save(device,
+ new NodeExecStatsWrapper(std::unique_ptr<NodeExecStats>(node_stats_pb),
+ nullptr, this));
}
void StepStatsCollector::Save(const string& device,
- NodeExecStatsWrapper* stats) {
- if (!stats) return;
- VLOG(1) << "Save dev " << device << " nt " << stats->stats();
+ NodeExecStatsWrapper* node_stats) {
+ if (!node_stats) return;
+ VLOG(1) << "Save dev " << device << " node stats " << node_stats->stats();
{
mutex_lock l(mu_);
if (finalized_) {
LOG(WARNING) << "stats saved after finalize will not be collected.";
}
- if (!step_stats_ || collectedNodes >= kMaxCollectedNodes) {
+ if (!step_stats_ || collected_nodes_ >= kMaxCollectedNodes) {
VLOG(1) << "step_stats_ nullptr or already collected too many nodes.";
- delete stats;
+ delete node_stats;
return;
}
- auto& dss = dev_stats_[device];
- dss.push_back(std::unique_ptr<NodeExecStatsWrapper>(stats));
- collectedNodes++;
+ auto& device_stats = dev_stats_[device];
+ device_stats.push_back(std::unique_ptr<NodeExecStatsWrapper>(node_stats));
+ collected_nodes_++;
+ }
+}
+
+NodeExecStatsInterface* StepStatsCollector::CreateNodeExecStats(
+ const Node* node) {
+ // Only collect statistics for non-transfer nodes.
+ if (IsSend(node) || IsRecv(node)) {
+ return nullptr;
}
+ return new NodeExecStatsWrapper(node, this);
}
string StepStatsCollector::ReportAllocsOnResourceExhausted(const string& err) {
@@ -364,12 +502,12 @@ void StepStatsCollector::Finalize() {
FinalizeInternal();
}
-void StepStatsCollector::FinalizeAndSwap(StepStats* ss) {
+void StepStatsCollector::FinalizeAndSwap(StepStats* step_stats) {
mutex_lock l(mu_);
CHECK(step_stats_);
FinalizeInternal();
- ss->Swap(step_stats_);
- collectedNodes = 0;
+ step_stats->Swap(step_stats_);
+ collected_nodes_ = 0;
}
void StepStatsCollector::FinalizeInternal() {
diff --git a/tensorflow/core/common_runtime/step_stats_collector.h b/tensorflow/core/common_runtime/step_stats_collector.h
index 996dbb59bc..4365b11b19 100644
--- a/tensorflow/core/common_runtime/step_stats_collector.h
+++ b/tensorflow/core/common_runtime/step_stats_collector.h
@@ -12,14 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_
-#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_
#include <memory>
#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/tensor_reference.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
@@ -30,45 +32,130 @@ class Allocator;
class AllocatorMemoryUsed;
class CostModelManager;
class Graph;
+class Node;
class NodeExecStats;
+class OpKernelContext;
class StepStats;
+class StepStatsCollector;
+class Tensor;
class TrackingAllocator;
+// Statistics collection interface for individual node execution.
+//
+// See `NodeExecStatsWrapper` for a concrete implementation of this interface
+// that interfaces with the `Session` layer.
+class NodeExecStatsInterface {
+ public:
+ virtual ~NodeExecStatsInterface() {}
+
+ // Called when the statistics collection for the node has finished. Once this
+ // method is called, the caller should not make assumptions about the validity
+ // of this object.
+ virtual void Done(const string& device) = 0;
+
+ // Called immediately after this node starts being processed by the executor.
+ virtual void RecordExecutorStarted() = 0;
+
+ // Called immediately before this node's `Compute()` or `ComputeAsync()`
+ // method is called.
+ virtual void RecordComputeStarted() = 0;
+
+ // Called immediately after this node's `Compute()` method returned (or, for
+ // asynchronous operations, the callback passed to its `ComputeAsync()` method
+ // was called).
+ virtual void RecordComputeEnded() = 0;
+
+ // Called immediately after the executor finishes processing this node.
+ virtual void RecordExecutorEnded() = 0;
+
+ // Records information about the memory allocated during the execution of this
+ // node.
+ virtual void SetMemory(OpKernelContext* ctx) = 0;
+
+ // Records information about the tensor produced by this node at the given
+ // output slot.
+ virtual void SetOutput(int slot, const Tensor* tensor) = 0;
+
+ // Records information about the tensors that were accessed during the
+ // execution of this node.
+ virtual void SetReferencedTensors(const TensorReferenceVector& tensors) = 0;
+
+ // Records the absolute time in nanoseconds at which this node became
+ // runnable (i.e. was scheduled for execution).
+ virtual void SetScheduled(int64 nanos) = 0;
+};
+
// Wraps NodeExecStats and adds allocation to it.
-class NodeExecStatsWrapper {
+class NodeExecStatsWrapper : public NodeExecStatsInterface {
public:
- NodeExecStatsWrapper();
- // Owns 'stats'.
- NodeExecStatsWrapper(NodeExecStats* stats);
+ // Does not take ownership of `node` or `step_stats_collector`.
+ NodeExecStatsWrapper(const Node* node,
+ StepStatsCollector* step_stats_collector);
+
+ // Takes ownership of `stats` but not `node` or `step_stats_collector`.
+ NodeExecStatsWrapper(std::unique_ptr<NodeExecStats> stats, const Node* node,
+ StepStatsCollector* step_stats_collector);
// Destructor calls Finalize() to release the TrackingAllocators.
~NodeExecStatsWrapper() { Finalize(); }
- NodeExecStats* stats() { return stats_.get(); }
-
- // "Does not take ownership of the 'allocator'.
- // Transfers ownership of the 'tracking_allocator' to *this."
- void AddAllocation(Allocator* allocator,
- TrackingAllocator* tracking_allocator);
+ void Done(const string& device) override;
+ void RecordExecutorStarted() override;
+ void RecordComputeStarted() override;
+ void RecordComputeEnded() override;
+ void RecordExecutorEnded() override;
+ void SetMemory(OpKernelContext* ctx) override;
+ void SetOutput(int slot, const Tensor* tensor) override;
+ void SetReferencedTensors(const TensorReferenceVector& tensors) override;
+ void SetScheduled(int64 nanos) override;
private:
friend class StepStatsCollector;
+ NodeExecStats* stats() { return stats_.get(); }
+
// Populates stats_ and releases TrackingAllocator.
void Finalize();
+ // Does not take ownership of the `allocator`.
+ // Takes ownership of `tracking_allocator`.
+ void AddAllocation(Allocator* allocator,
+ TrackingAllocator* tracking_allocator);
+
gtl::InlinedVector<std::pair<AllocatorMemoryUsed*, TrackingAllocator*>, 2>
allocations_;
std::unique_ptr<NodeExecStats> stats_;
+ const Node* const node_; // Not owned.
+ StepStatsCollector* const step_stats_collector_; // Not owned.
+};
+
+// Statistics collection interface for step execution.
+//
+// See `StepStatsCollector` for a concrete implementation of this interface
+// that interfaces with the `Session` layer.
+class StepStatsCollectorInterface {
+ public:
+ virtual ~StepStatsCollectorInterface() {}
+
+ // Creates an instance of `NodeExecStatsInterface` that should be used for
+ // collecting statistics about individual node execution.
+ virtual NodeExecStatsInterface* CreateNodeExecStats(const Node* node) = 0;
+
+ // Generates a string reporting the currently used memory based on a
+ // ResourceExhausted OOM `err` message. The `err` message needs to contain
+ // the device name and allocator name, e.g.:
+ // "ResourceExhaustedError: OOM when allocating tensor ...
+ // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc"
+ virtual string ReportAllocsOnResourceExhausted(const string& err) = 0;
};
// StepStatsCollector manages the collection of a StepStats object.
// The StepStats object holds multiple DeviceStats.
// Each DeviceStats object holds multiple NodeExecStats.
-class StepStatsCollector {
+class StepStatsCollector : public StepStatsCollectorInterface {
public:
- // Does not take ownership of `ss`.
- explicit StepStatsCollector(StepStats* ss);
+ // Does not take ownership of `step_stats`.
+ explicit StepStatsCollector(StepStats* step_stats);
// BuildCostModel builds or updates a CostModel managed by cost_model_manager,
// using the currently collected DeviceStats associated with the devices in
@@ -77,39 +164,37 @@ class StepStatsCollector {
CostModelManager* cost_model_manager,
const std::unordered_map<string, const Graph*>& device_map);
- // Save saves nt to the DeviceStats object associated with device.
+ // Saves node statistics to the DeviceStats object associated with device.
// Should be called before Finalize.
- void Save(const string& device, NodeExecStats* nt);
- void Save(const string& device, NodeExecStatsWrapper* stats);
+ void Save(const string& device, NodeExecStats* node_stats_pb);
+ void Save(const string& device, NodeExecStatsWrapper* node_stats);
- // Generates a string reporting the currently used memory based
- // on ResourceExhausted OOM `err` message.
- // `err` message needs to contain device name and allocator name, E.g.:
- // "ResourceExhaustedError: OOM when allocating tensor ...
- // on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc"
- string ReportAllocsOnResourceExhausted(const string& err);
+ NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override;
+ string ReportAllocsOnResourceExhausted(const string& err) override;
// The following 2 Finalize methods populate the StepStats passed
// from the constructor. Calling it more than once won't have any effect.
// User shouldn't call Save() methods after Finalize.
void Finalize();
// swaps the content of StepStats* from constructor with 'ss'.
- void FinalizeAndSwap(StepStats* ss);
+ void FinalizeAndSwap(StepStats* step_stats);
private:
+ // TODO(suharshs): Make this configurable if it's not possible to find a value
+ // that works for all cases.
+ static const uint64 kMaxCollectedNodes = 1 << 20;
+
+ typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeStatsVector;
+
void FinalizeInternal() EXCLUSIVE_LOCKS_REQUIRED(mu_);
- typedef std::vector<std::unique_ptr<NodeExecStatsWrapper>> NodeExecStatsVec;
- // TODO(suharshs): Make this configurable if its not possible to find a value
- // that works for all cases.
- const uint64 kMaxCollectedNodes = 1 << 20;
mutex mu_;
bool finalized_ GUARDED_BY(mu_);
- std::unordered_map<string, NodeExecStatsVec> dev_stats_ GUARDED_BY(mu_);
+ std::unordered_map<string, NodeStatsVector> dev_stats_ GUARDED_BY(mu_);
StepStats* step_stats_ GUARDED_BY(mu_);
- uint64 collectedNodes GUARDED_BY(mu_) = 0;
+ uint64 collected_nodes_ GUARDED_BY(mu_) = 0;
};
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_STEP_STATS_COLLECTOR_H_
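A minimal sketch (not part of this change) of how an executor-side caller is expected to drive the new statistics interfaces declared above; `collector`, `node`, `ctx`, and `device` are assumed to be available at the call site, and the actual kernel invocation is elided:

// Assumes the declarations from step_stats_collector.h above.
void CollectNodeStats(StepStatsCollectorInterface* collector, const Node* node,
                      OpKernelContext* ctx, const string& device) {
  NodeExecStatsInterface* stats = collector->CreateNodeExecStats(node);
  if (stats == nullptr) return;  // e.g. Send/Recv nodes are not tracked.
  stats->RecordExecutorStarted();
  stats->RecordComputeStarted();
  // ... the kernel's Compute()/ComputeAsync() runs here ...
  stats->RecordComputeEnded();
  stats->SetMemory(ctx);
  stats->RecordExecutorEnded();
  stats->Done(device);  // Hands the record back; `stats` must not be used after this.
}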
diff --git a/tensorflow/core/common_runtime/sycl/sycl_allocator.h b/tensorflow/core/common_runtime/sycl/sycl_allocator.h
index 550f193332..cc5909de17 100644
--- a/tensorflow/core/common_runtime/sycl/sycl_allocator.h
+++ b/tensorflow/core/common_runtime/sycl/sycl_allocator.h
@@ -17,8 +17,8 @@ limitations under the License.
#error This file must only be included when building TensorFlow with SYCL support
#endif
-#ifndef TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
+#define TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/allocator.h"
@@ -72,4 +72,4 @@ class SYCLAllocator : public Allocator {
} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_COMMON_RUNTIME_SYCL_SYCL_ALLOCATOR_H_
diff --git a/tensorflow/core/common_runtime/threadpool_device.cc b/tensorflow/core/common_runtime/threadpool_device.cc
index 7406ecf4f8..6404d8bc6a 100644
--- a/tensorflow/core/common_runtime/threadpool_device.cc
+++ b/tensorflow/core/common_runtime/threadpool_device.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/platform/tracing.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/util.h"
#ifdef INTEL_MKL
#ifdef _OPENMP
@@ -49,6 +50,8 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
allocator_(allocator),
scoped_allocator_mgr_(new ScopedAllocatorMgr(name)) {
#ifdef INTEL_MKL
+ // Early return when MKL is disabled
+ if (DisableMKL()) return;
#ifdef _OPENMP
const char* user_omp_threads = getenv("OMP_NUM_THREADS");
if (user_omp_threads == nullptr) {
@@ -70,17 +73,6 @@ ThreadPoolDevice::ThreadPoolDevice(const SessionOptions& options,
ThreadPoolDevice::~ThreadPoolDevice() {}
-void ThreadPoolDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) {
- // When Xprof/ThreadScape profiling is off (which is the default), the
- // following code is simple enough that its overhead is negligible.
- tracing::ScopedActivity activity(op_kernel->name(), op_kernel->type_string(),
- op_kernel->IsExpensive());
- tracing::ScopedRegion region(tracing::EventCategory::kCompute,
- op_kernel->name());
-
- op_kernel->Compute(context);
-}
-
Allocator* ThreadPoolDevice::GetAllocator(AllocatorAttributes attr) {
return allocator_;
}
@@ -124,8 +116,12 @@ class MklCPUAllocatorFactory : public AllocatorFactory {
}
};
-REGISTER_MEM_ALLOCATOR("MklCPUAllocator", 200, MklCPUAllocatorFactory);
+#ifdef ENABLE_MKL
+REGISTER_MEM_ALLOCATOR("MklCPUAllocator", (DisableMKL() ? 50 : 200),
+ MklCPUAllocatorFactory);
+#endif // ENABLE_MKL
+
} // namespace
-#endif
+#endif // INTEL_MKL
} // namespace tensorflow
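For context on the registration change above: REGISTER_MEM_ALLOCATOR selects the factory with the highest priority, and the default CPU allocator factory in tensorflow/core/framework/allocator.cc registers at priority 100 as of this change. The sketch below only illustrates that comparison; the quoted registration is the existing one, not something added by this patch:

// Existing registration in allocator.cc (for comparison only):
//   REGISTER_MEM_ALLOCATOR("DefaultCPUAllocator", 100, CPUAllocatorFactory);
// MKL enabled and not disabled at runtime: priority 200 > 100, MklCPUAllocator wins.
// DisableMKL() true: priority 50 < 100, the default CPU allocator wins.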
diff --git a/tensorflow/core/common_runtime/threadpool_device.h b/tensorflow/core/common_runtime/threadpool_device.h
index afc5d15ebc..51bd038a1c 100644
--- a/tensorflow/core/common_runtime/threadpool_device.h
+++ b/tensorflow/core/common_runtime/threadpool_device.h
@@ -29,7 +29,6 @@ class ThreadPoolDevice : public LocalDevice {
Allocator* allocator);
~ThreadPoolDevice() override;
- void Compute(OpKernel* op_kernel, OpKernelContext* context) override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
Allocator* GetScopedAllocator(AllocatorAttributes attr,
int64 step_id) override;
diff --git a/tensorflow/core/common_runtime/visitable_allocator.h b/tensorflow/core/common_runtime/visitable_allocator.h
deleted file mode 100644
index 8edf922d11..0000000000
--- a/tensorflow/core/common_runtime/visitable_allocator.h
+++ /dev/null
@@ -1,79 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
-#define TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
-
-#include <functional>
-#include "tensorflow/core/framework/allocator.h"
-#include "tensorflow/core/framework/tracking_allocator.h"
-
-namespace tensorflow {
-
-// Subclass VisitableAllocator instead of Allocator when a memory
-// allocator needs to enable some kind of registration/deregistration
-// of memory areas.
-class VisitableAllocator : public Allocator {
- public:
- // Visitor gets called with a pointer to a memory area and its
- // size in bytes.
- typedef std::function<void(void*, size_t)> Visitor;
-
- // Register a visitor guaranteed to be called exactly once on each
- // chunk of memory newly allocated from the underlying device.
- // Typically, chunks will be reused and possibly sub-divided by a
- // pool manager, so the calls will happen only once per process
- // execution, not once per tensor (re)allocation.
- virtual void AddAllocVisitor(Visitor visitor) = 0;
-
- // Register a visitor guaranteed to be called on each chunk of
- // memory returned to the underlying device.
- virtual void AddFreeVisitor(Visitor visitor) = 0;
-};
-
-// Needed for cases when a VisitableAllocator gets wrapped for tracking.
-// Multiple-inheritance is considered acceptable in this case because
-// VisitableAllocator is a pure virtual interface and only TrackingAllocator
-// has default implementation.
-class TrackingVisitableAllocator : public TrackingAllocator,
- public VisitableAllocator {
- public:
- TrackingVisitableAllocator(VisitableAllocator* allocator, bool track_ids)
- : TrackingAllocator(allocator, track_ids), allocator_(allocator) {}
- ~TrackingVisitableAllocator() override {}
-
- string Name() override { return TrackingAllocator::Name(); }
-
- void* AllocateRaw(size_t alignment, size_t num_bytes) override {
- return TrackingAllocator::AllocateRaw(alignment, num_bytes);
- }
-
- void DeallocateRaw(void* ptr) override {
- TrackingAllocator::DeallocateRaw(ptr);
- }
-
- void AddAllocVisitor(Visitor visitor) override {
- allocator_->AddAllocVisitor(visitor);
- }
-
- void AddFreeVisitor(Visitor visitor) override {
- allocator_->AddFreeVisitor(visitor);
- }
-
- protected:
- VisitableAllocator* allocator_;
-};
-} // namespace tensorflow
-#endif // TENSORFLOW_COMMON_RUNTIME_VISITABLE_ALLOCATOR_H_
diff --git a/tensorflow/core/debug/debug_callback_registry.h b/tensorflow/core/debug/debug_callback_registry.h
index 8f08c656c2..bcd4ddc50c 100644
--- a/tensorflow/core/debug/debug_callback_registry.h
+++ b/tensorflow/core/debug/debug_callback_registry.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DEBUG_CALLBACK_REGISTRY_H_
-#define TENSORFLOW_DEBUG_CALLBACK_REGISTRY_H_
+#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_CALLBACK_REGISTRY_H_
+#define TENSORFLOW_CORE_DEBUG_DEBUG_CALLBACK_REGISTRY_H_
#include <functional>
#include <map>
@@ -68,4 +68,4 @@ class DebugCallbackRegistry {
} // namespace tensorflow
-#endif // TENSORFLOW_DEBUG_CALLBACK_REGISTRY_H_
+#endif // TENSORFLOW_CORE_DEBUG_DEBUG_CALLBACK_REGISTRY_H_
diff --git a/tensorflow/core/debug/debug_graph_utils.cc b/tensorflow/core/debug/debug_graph_utils.cc
index 7641edea52..5fc95a8f20 100644
--- a/tensorflow/core/debug/debug_graph_utils.cc
+++ b/tensorflow/core/debug/debug_graph_utils.cc
@@ -356,8 +356,8 @@ Status DebugNodeInserter::ParseDebugOpName(
"Malformed attributes in debug op name \"", debug_op_name, "\"");
}
- const string key = std::string(seg.substr(0, eq_index));
- const string value = std::string(
+ const string key(seg.substr(0, eq_index));
+ const string value(
seg.substr(eq_index + 1, attribute_seg.size() - eq_index - 1));
if (key.empty() || value.empty()) {
return errors::InvalidArgument(
diff --git a/tensorflow/core/debug/debug_graph_utils.h b/tensorflow/core/debug/debug_graph_utils.h
index 64deff1f00..86dc90a134 100644
--- a/tensorflow/core/debug/debug_graph_utils.h
+++ b/tensorflow/core/debug/debug_graph_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DEBUG_NODE_INSERTER_H_
-#define TENSORFLOW_DEBUG_NODE_INSERTER_H_
+#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_GRAPH_UTILS_H_
+#define TENSORFLOW_CORE_DEBUG_DEBUG_GRAPH_UTILS_H_
#include <unordered_map>
#include <vector>
@@ -123,4 +123,4 @@ class DebugNodeInserter {
};
} // namespace tensorflow
-#endif // TENSORFLOW_DEBUG_NODE_INSERTER_H_
+#endif // TENSORFLOW_CORE_DEBUG_DEBUG_GRAPH_UTILS_H_
diff --git a/tensorflow/core/debug/debug_grpc_testlib.h b/tensorflow/core/debug/debug_grpc_testlib.h
index 8d3c9ff575..93376613b6 100644
--- a/tensorflow/core/debug/debug_grpc_testlib.h
+++ b/tensorflow/core/debug/debug_grpc_testlib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DEBUG_GRPC_TESTLIB_H_
-#define TENSORFLOW_DEBUG_GRPC_TESTLIB_H_
+#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_GRPC_TESTLIB_H_
+#define TENSORFLOW_CORE_DEBUG_DEBUG_GRPC_TESTLIB_H_
#include <atomic>
#include <unordered_set>
@@ -84,4 +84,4 @@ bool PollTillFirstRequestSucceeds(const string& server_url,
} // namespace tensorflow
-#endif // TENSORFLOW_DEBUG_GRPC_TESTLIB_H_
+#endif // TENSORFLOW_CORE_DEBUG_DEBUG_GRPC_TESTLIB_H_
diff --git a/tensorflow/core/debug/debug_io_utils.cc b/tensorflow/core/debug/debug_io_utils.cc
index 9e8002d490..6994dec3b5 100644
--- a/tensorflow/core/debug/debug_io_utils.cc
+++ b/tensorflow/core/debug/debug_io_utils.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <stddef.h>
#include <string.h>
#include <cmath>
+#include <cstdlib>
+#include <cstring>
#include <limits>
#include <utility>
#include <vector>
@@ -399,8 +401,8 @@ Status DebugIO::PublishDebugMetadata(
strings::Printf("%.14lld", session_run_index))),
Env::Default()->NowMicros());
status.Update(DebugFileIO::DumpEventProtoToFile(
- event, std::string(io::Dirname(core_metadata_path)),
- std::string(io::Basename(core_metadata_path))));
+ event, string(io::Dirname(core_metadata_path)),
+ string(io::Basename(core_metadata_path))));
}
}
@@ -418,6 +420,19 @@ Status DebugIO::PublishDebugTensor(const DebugNodeKey& debug_node_key,
if (str_util::Lowercase(url).find(kFileURLScheme) == 0) {
const string dump_root_dir = url.substr(strlen(kFileURLScheme));
+ const int64 tensorBytes =
+ tensor.IsInitialized() ? tensor.TotalBytes() : 0;
+ if (!DebugFileIO::requestDiskByteUsage(tensorBytes)) {
+ return errors::ResourceExhausted(
+ "TensorFlow Debugger has exhausted file-system byte-size "
+ "allowance (",
+ DebugFileIO::globalDiskBytesLimit, "), therefore it cannot ",
+ "dump an additional ", tensorBytes, " byte(s) of tensor data ",
+ "for the debug tensor ", debug_node_key.node_name, ":",
+ debug_node_key.output_slot, ". You may use the environment ",
+ "variable TFDBG_DISK_BYTES_LIMIT to set a higher limit.");
+ }
+
Status s = DebugFileIO::DumpTensorToDir(
debug_node_key, tensor, wall_time_us, dump_root_dir, nullptr);
if (!s.ok()) {
@@ -632,8 +647,8 @@ Status DebugFileIO::DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
std::vector<Event> events;
TF_RETURN_IF_ERROR(
WrapTensorAsEvents(debug_node_key, tensor, wall_time_us, 0, &events));
- return DumpEventProtoToFile(events[0], std::string(io::Dirname(file_path)),
- std::string(io::Basename(file_path)));
+ return DumpEventProtoToFile(events[0], string(io::Dirname(file_path)),
+ string(io::Basename(file_path)));
}
Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
@@ -642,7 +657,7 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
return Status::OK();
}
- string parent_dir = std::string(io::Dirname(dir));
+ string parent_dir(io::Dirname(dir));
if (!env->FileExists(parent_dir).ok()) {
// The parent path does not exist yet, create it first.
Status s = RecursiveCreateDir(env, parent_dir); // Recursive call
@@ -670,6 +685,42 @@ Status DebugFileIO::RecursiveCreateDir(Env* env, const string& dir) {
}
}
+// Default total disk usage limit: 100 GBytes
+const uint64 DebugFileIO::defaultGlobalDiskBytesLimit = 107374182400L;
+uint64 DebugFileIO::globalDiskBytesLimit = 0;
+uint64 DebugFileIO::diskBytesUsed = 0;
+
+mutex DebugFileIO::bytes_mu(LINKER_INITIALIZED);
+
+bool DebugFileIO::requestDiskByteUsage(uint64 bytes) {
+ mutex_lock l(bytes_mu);
+ if (globalDiskBytesLimit == 0) {
+ const char* env_tfdbg_disk_bytes_limit = getenv("TFDBG_DISK_BYTES_LIMIT");
+ if (env_tfdbg_disk_bytes_limit == nullptr ||
+ strlen(env_tfdbg_disk_bytes_limit) == 0) {
+ globalDiskBytesLimit = defaultGlobalDiskBytesLimit;
+ } else {
+ strings::safe_strtou64(string(env_tfdbg_disk_bytes_limit),
+ &globalDiskBytesLimit);
+ }
+ }
+
+ if (bytes == 0) {
+ return true;
+ }
+ if (diskBytesUsed + bytes < globalDiskBytesLimit) {
+ diskBytesUsed += bytes;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+void DebugFileIO::resetDiskByteUsage() {
+ mutex_lock l(bytes_mu);
+ diskBytesUsed = 0;
+}
+
#ifndef PLATFORM_WINDOWS
DebugGrpcChannel::DebugGrpcChannel(const string& server_stream_addr)
: server_stream_addr_(server_stream_addr),
diff --git a/tensorflow/core/debug/debug_io_utils.h b/tensorflow/core/debug/debug_io_utils.h
index c974a47051..5390ce408a 100644
--- a/tensorflow/core/debug/debug_io_utils.h
+++ b/tensorflow/core/debug/debug_io_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DEBUG_IO_UTILS_H_
-#define TENSORFLOW_DEBUG_IO_UTILS_H_
+#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_
+#define TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_
#include <cstddef>
#include <functional>
@@ -193,6 +193,26 @@ class DebugFileIO {
const string& dir_name,
const string& file_name);
+ // Request additional bytes to be dumped to the file system.
+ //
+ // Does not actually dump the bytes, but instead just performs the
+ // bookkeeping necessary to prevent the total dumped amount of data from
+ // exceeding the limit (100 GBytes by default, or a custom value set through
+ // the environment variable TFDBG_DISK_BYTES_LIMIT).
+ //
+ // Args:
+ // bytes: Number of bytes to request.
+ //
+ // Returns:
+ // Whether the request is approved given the total dumping
+ // limit.
+ static bool requestDiskByteUsage(uint64 bytes);
+
+ // Reset the disk byte usage to zero.
+ static void resetDiskByteUsage();
+
+ static uint64 globalDiskBytesLimit;
+
private:
// Encapsulates the Tensor in an Event protobuf and write it to file.
static Status DumpTensorToEventFile(const DebugNodeKey& debug_node_key,
@@ -204,6 +224,15 @@ class DebugFileIO {
// TODO(cais): Replace with shared implementation once http://b/30497715 is
// fixed.
static Status RecursiveCreateDir(Env* env, const string& dir);
+
+ // Tracks how much disk has been used so far.
+ static uint64 diskBytesUsed;
+ // Mutex for thread-safe access to diskBytesUsed.
+ static mutex bytes_mu;
+ // Default limit for the disk space.
+ static const uint64 defaultGlobalDiskBytesLimit;
+
+ friend class DiskUsageLimitTest;
};
} // namespace tensorflow
@@ -398,4 +427,4 @@ class DebugGrpcIO {
} // namespace tensorflow
#endif // #ifndef(PLATFORM_WINDOWS)
-#endif // TENSORFLOW_DEBUG_IO_UTILS_H_
+#endif // TENSORFLOW_CORE_DEBUG_DEBUG_IO_UTILS_H_
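A minimal sketch of the intended call pattern for the byte-budget API above, mirroring the PublishDebugTensor change earlier in this diff; the tensor and the dump call are placeholders:

// Guard a dump with the disk-byte budget (sketch only).
const int64 num_bytes = tensor.IsInitialized() ? tensor.TotalBytes() : 0;
if (!DebugFileIO::requestDiskByteUsage(num_bytes)) {
  return errors::ResourceExhausted(
      "tfdbg disk byte limit exceeded; see TFDBG_DISK_BYTES_LIMIT");
}
// ... write the tensor to disk ...
// DebugFileIO::resetDiskByteUsage() starts a fresh budget, e.g. when
// DebugOptions.reset_disk_byte_usage is set (see debugger_state_impl.cc below).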
diff --git a/tensorflow/core/debug/debug_io_utils_test.cc b/tensorflow/core/debug/debug_io_utils_test.cc
index 0807a85b8b..82e0ae5edb 100644
--- a/tensorflow/core/debug/debug_io_utils_test.cc
+++ b/tensorflow/core/debug/debug_io_utils_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <cstdlib>
#include <unordered_set>
#include "tensorflow/core/debug/debug_io_utils.h"
@@ -454,5 +455,50 @@ TEST_F(DebugIOUtilsTest, PublishTensorConcurrentlyToPartiallyOverlappingPaths) {
}
}
+class DiskUsageLimitTest : public ::testing::Test {
+ public:
+ void Initialize() {
+ setenv("TFDBG_DISK_BYTES_LIMIT", "", 1);
+ DebugFileIO::resetDiskByteUsage();
+ DebugFileIO::globalDiskBytesLimit = 0;
+ }
+};
+
+TEST_F(DiskUsageLimitTest, RequestWithZeroByteIsOkay) {
+ Initialize();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(0L));
+}
+
+TEST_F(DiskUsageLimitTest, ExceedingLimitAfterOneCall) {
+ Initialize();
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(100L * 1024L * 1024L * 1024L));
+}
+
+TEST_F(DiskUsageLimitTest, ExceedingLimitAfterTwoCalls) {
+ Initialize();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(1024L));
+}
+
+TEST_F(DiskUsageLimitTest, ResetDiskByteUsageWorks) {
+ Initialize();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+ DebugFileIO::resetDiskByteUsage();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(50L * 1024L * 1024L * 1024L));
+}
+
+TEST_F(DiskUsageLimitTest, CustomEnvVarIsObeyed) {
+ Initialize();
+ setenv("TFDBG_DISK_BYTES_LIMIT", "1024", 1);
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(1024L));
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(1000L));
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(23L));
+ ASSERT_FALSE(DebugFileIO::requestDiskByteUsage(1L));
+ DebugFileIO::resetDiskByteUsage();
+ ASSERT_TRUE(DebugFileIO::requestDiskByteUsage(1023L));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/debug/debug_node_key.h b/tensorflow/core/debug/debug_node_key.h
index b46054c013..eaeb369790 100644
--- a/tensorflow/core/debug/debug_node_key.h
+++ b/tensorflow/core/debug/debug_node_key.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DEBUG_NODE_KEY_H_
-#define TENSORFLOW_DEBUG_NODE_KEY_H_
+#ifndef TENSORFLOW_CORE_DEBUG_DEBUG_NODE_KEY_H_
+#define TENSORFLOW_CORE_DEBUG_DEBUG_NODE_KEY_H_
#include <string>
@@ -48,4 +48,4 @@ struct DebugNodeKey {
} // namespace tensorflow
-#endif // TENSORFLOW_DEBUG_NODE_KEY_H_
+#endif // TENSORFLOW_CORE_DEBUG_DEBUG_NODE_KEY_H_
diff --git a/tensorflow/core/debug/debugger_state_impl.cc b/tensorflow/core/debug/debugger_state_impl.cc
index 2f5aaf93fa..79798f9392 100644
--- a/tensorflow/core/debug/debugger_state_impl.cc
+++ b/tensorflow/core/debug/debugger_state_impl.cc
@@ -27,6 +27,9 @@ DebuggerState::DebuggerState(const DebugOptions& debug_options) {
debug_urls_.insert(url);
}
}
+ if (debug_options.reset_disk_byte_usage()) {
+ DebugFileIO::resetDiskByteUsage();
+ }
}
DebuggerState::~DebuggerState() {
diff --git a/tensorflow/core/debug/debugger_state_impl.h b/tensorflow/core/debug/debugger_state_impl.h
index 52e2663d08..8f6e53fafe 100644
--- a/tensorflow/core/debug/debugger_state_impl.h
+++ b/tensorflow/core/debug/debugger_state_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DEBUGGER_STATE_IMPL_H_
-#define TENSORFLOW_DEBUGGER_STATE_IMPL_H_
+#ifndef TENSORFLOW_CORE_DEBUG_DEBUGGER_STATE_IMPL_H_
+#define TENSORFLOW_CORE_DEBUG_DEBUGGER_STATE_IMPL_H_
#include "tensorflow/core/common_runtime/debugger_state_interface.h"
@@ -58,4 +58,4 @@ class DebugGraphDecorator : public DebugGraphDecoratorInterface {
} // namespace tensorflow
-#endif // TENSORFLOW_DEBUGGER_STATE_IMPL_H_
+#endif // TENSORFLOW_CORE_DEBUG_DEBUGGER_STATE_IMPL_H_
diff --git a/tensorflow/core/distributed_runtime/BUILD b/tensorflow/core/distributed_runtime/BUILD
index b2192c5a80..37029f3f1a 100644
--- a/tensorflow/core/distributed_runtime/BUILD
+++ b/tensorflow/core/distributed_runtime/BUILD
@@ -562,6 +562,7 @@ cc_library(
deps = [
":worker_cache",
":worker_interface",
+ "//tensorflow/core:framework",
],
)
diff --git a/tensorflow/core/distributed_runtime/graph_mgr.cc b/tensorflow/core/distributed_runtime/graph_mgr.cc
index 6c146036ae..3361819e43 100644
--- a/tensorflow/core/distributed_runtime/graph_mgr.cc
+++ b/tensorflow/core/distributed_runtime/graph_mgr.cc
@@ -233,14 +233,11 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
params.function_library = lib;
params.create_kernel = [session, lib, opseg](const NodeDef& ndef,
OpKernel** kernel) {
- // We do not share the kernel via the OpSegment if the node is
- // stateless, or a function.
// NOTE(mrry): We must not share function kernels (implemented
// using `CallOp`) between subgraphs, because `CallOp::handle_`
// is tied to a particular subgraph. Even if the function itself
// is stateful, the `CallOp` that invokes it is not.
- if (!lib->IsStateful(ndef.op()) ||
- lib->GetFunctionLibraryDefinition()->Find(ndef.op()) != nullptr) {
+ if (!OpSegment::ShouldOwnKernel(lib, ndef.op())) {
return lib->CreateKernel(ndef, kernel);
}
auto create_fn = [lib, &ndef](OpKernel** kernel) {
@@ -252,8 +249,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
return opseg->FindOrCreate(session, ndef.name(), kernel, create_fn);
};
params.delete_kernel = [lib](OpKernel* kernel) {
- // If the node is stateful, opseg owns it. Otherwise, delete it.
- if (kernel && !lib->IsStateful(kernel->type_string())) {
+ if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
delete kernel;
}
};
@@ -479,10 +475,7 @@ void GraphMgr::StartParallelExecutors(const string& handle, int64 step_id,
delete step_container;
});
Executor::Args args;
- {
- mutex_lock l(mu_);
- args.step_id = ++next_id_;
- }
+ args.step_id = step_id;
args.rendezvous = rendezvous;
args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
args.cancellation_manager = cancellation_manager;
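The create_kernel/delete_kernel changes above funnel both decisions through one predicate. A sketch of what OpSegment::ShouldOwnKernel presumably checks, reconstructed from the inline condition it replaces (the real implementation lives in op_segment.cc):

// Reconstructed predicate, not copied from op_segment.cc: the OpSegment owns a
// kernel only if its op is stateful and is not a function, since function
// kernels (CallOp) must not be shared between subgraphs.
bool ShouldOwnKernel(FunctionLibraryRuntime* lib, const string& node_op) {
  return lib->IsStateful(node_op) &&
         lib->GetFunctionLibraryDefinition()->Find(node_op) == nullptr;
}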
diff --git a/tensorflow/core/distributed_runtime/master.cc b/tensorflow/core/distributed_runtime/master.cc
index a48f734d3e..269f620e42 100644
--- a/tensorflow/core/distributed_runtime/master.cc
+++ b/tensorflow/core/distributed_runtime/master.cc
@@ -53,6 +53,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/master.pb.h"
#include "tensorflow/core/protobuf/worker.pb.h"
#include "tensorflow/core/public/session_options.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -167,13 +168,55 @@ class DeviceFinder {
}
// Enumerates all known workers' target. A target name is a
// prefix of a device name. E.g., /job:mnist/replica:0/task:10.
- CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
- const string& local_device_name = env_->local_devices[0]->name();
- std::vector<string> workers;
- worker_cache->ListWorkers(&workers);
if (filters_.empty()) {
+ // If no filters were specified, we list all known workers in
+ // `worker_cache`.
+ std::vector<string> workers;
+ worker_cache->ListWorkers(&workers);
std::swap(workers, targets_);
} else {
+ // When applying filters, we must include the local worker, even if it
+ // does not match any of the filters.
+ CHECK_GT(env_->local_devices.size(), 0) << "No local devices provided.";
+ const string& local_device_name = env_->local_devices[0]->name();
+ DeviceNameUtils::ParsedName local_parsed_name;
+ CHECK(DeviceNameUtils::ParseFullName(local_device_name,
+ &local_parsed_name));
+ bool all_filters_have_job = true;
+ std::unordered_set<string> filter_job_names({local_parsed_name.job});
+ for (const DeviceNameUtils::ParsedName& filter : filters_) {
+ all_filters_have_job = all_filters_have_job && filter.has_job;
+ if (filter.has_job) {
+ filter_job_names.insert(filter.job);
+ }
+ }
+
+ std::vector<string> workers;
+ if (all_filters_have_job) {
+ // If all of the device filters have a job specified, then we only need
+ // to list the workers in the jobs named in the filter, because a worker
+ // in any other job would not match any filter.
+ for (const string& job_name : filter_job_names) {
+ VLOG(2) << "Selectively listing workers in job: " << job_name;
+ std::vector<string> workers_in_job;
+ worker_cache->ListWorkersInJob(job_name, &workers_in_job);
+ workers.insert(workers.end(), workers_in_job.begin(),
+ workers_in_job.end());
+ }
+ } else {
+ // If any of the device filters does not have a job specified, then we
+ // must list the workers from all jobs.
+ VLOG(2) << "Listing workers in all jobs because some device "
+ << "filter has no job specified. Filters were:";
+ if (device_filters.empty()) {
+ VLOG(2) << "- <NO FILTERS>";
+ } else {
+ for (const string& filter : device_filters) {
+ VLOG(2) << "- " << filter;
+ }
+ }
+ worker_cache->ListWorkers(&workers);
+ }
for (const string& name : workers) {
if (MatchFilters(name) ||
DeviceNameUtils::IsSameAddressSpace(name, local_device_name)) {
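A worked example of the filter handling above (values are illustrative, not from the patch):

// filters_ = {"/job:worker/replica:0/task:1", "/job:ps"}  -> every filter has a
//   job, so filter_job_names = {<local job>, "worker", "ps"} and only those
//   jobs are enumerated via ListWorkersInJob().
// filters_ = {"/replica:0/task:0"}                        -> no job specified,
//   so the code falls back to ListWorkers() over all jobs.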
diff --git a/tensorflow/core/distributed_runtime/master_env.h b/tensorflow/core/distributed_runtime/master_env.h
index da26c42aca..837ccd1dd4 100644
--- a/tensorflow/core/distributed_runtime/master_env.h
+++ b/tensorflow/core/distributed_runtime/master_env.h
@@ -99,4 +99,4 @@ struct MasterEnv {
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MASTER_ENV_H_
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index d34ca53f73..8e9eec1ed9 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -449,7 +449,7 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() =
callable_opts_.run_options().debug_options();
- c->req.set_collective_graph_key(bg_opts_.collective_graph_key);
+ c->req.set_collective_graph_key(client_graph()->collective_graph_key);
VLOG(2) << "Register " << c->req.graph_def().DebugString();
auto cb = [c, &done](const Status& s) {
c->status = s;
@@ -615,7 +615,7 @@ Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
// inadvertently slowing down the normal run path.
if (is_partial_) {
for (const auto& name_index : feeds) {
- const auto iter = part.feed_key.find(std::string(name_index.first));
+ const auto iter = part.feed_key.find(string(name_index.first));
if (iter == part.feed_key.end()) {
// The provided feed must be for a different partition.
continue;
@@ -959,7 +959,7 @@ Status MasterSession::ReffedClientGraph::CheckFetches(
// Skip if already fed.
if (input.second) continue;
TensorId id(ParseTensorName(input.first));
- const Node* n = execution_state->get_node_by_name(std::string(id.first));
+ const Node* n = execution_state->get_node_by_name(string(id.first));
if (n == nullptr) {
return errors::NotFound("Feed ", input.first, ": not found");
}
@@ -975,7 +975,7 @@ Status MasterSession::ReffedClientGraph::CheckFetches(
for (size_t i = 0; i < req.num_fetches(); ++i) {
const string& fetch = req.fetch_name(i);
const TensorId id(ParseTensorName(fetch));
- const Node* n = execution_state->get_node_by_name(std::string(id.first));
+ const Node* n = execution_state->get_node_by_name(string(id.first));
if (n == nullptr) {
return errors::NotFound("Fetch ", fetch, ": not found");
}
@@ -1111,10 +1111,6 @@ uint64 HashBuildGraphOptions(const BuildGraphOptions& opts) {
h = Hash64(watch_summary.c_str(), watch_summary.size(), h);
}
- if (opts.collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey) {
- h = Hash64Combine(opts.collective_graph_key, h);
- }
-
return h;
}
@@ -1788,10 +1784,10 @@ Status MasterSession::PostRunCleanup(MasterSession::ReffedClientGraph* rcg,
Status s = run_status;
if (s.ok()) {
pss->end_micros = Env::Default()->NowMicros();
- if (rcg->build_graph_options().collective_graph_key !=
+ if (rcg->client_graph()->collective_graph_key !=
BuildGraphOptions::kNoCollectiveGraphKey) {
env_->collective_executor_mgr->RetireStepId(
- rcg->build_graph_options().collective_graph_key, step_id);
+ rcg->client_graph()->collective_graph_key, step_id);
}
// Schedule post-processing and cleanup to be done asynchronously.
rcg->ProcessStats(step_id, pss, ph.get(), run_options, out_run_metadata);
@@ -1850,7 +1846,7 @@ Status MasterSession::DoRunWithLocalExecution(
// Keeps the highest 8 bits 0x01: we reserve some bits of the
// step_id for future use.
- uint64 step_id = NewStepId(bgopts.collective_graph_key);
+ uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
std::unique_ptr<ProfileHandler> ph;
@@ -1914,8 +1910,7 @@ Status MasterSession::DoRunCallable(CallOptions* opts, ReffedClientGraph* rcg,
// Prepare.
int64 count = rcg->get_and_increment_execution_count();
- const uint64 step_id =
- NewStepId(rcg->build_graph_options().collective_graph_key);
+ const uint64 step_id = NewStepId(rcg->client_graph()->collective_graph_key);
TRACEPRINTF("stepid %llu", step_id);
const RunOptions& run_options = rcg->callable_options().run_options();
diff --git a/tensorflow/core/distributed_runtime/message_wrappers.h b/tensorflow/core/distributed_runtime/message_wrappers.h
index 72a0c7edd8..474ac0e186 100644
--- a/tensorflow/core/distributed_runtime/message_wrappers.h
+++ b/tensorflow/core/distributed_runtime/message_wrappers.h
@@ -721,4 +721,4 @@ class NonOwnedProtoRunStepResponse : public MutableRunStepResponseWrapper {
} // namespace tensorflow
-#endif // TENSORFLOW
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_MESSAGE_WRAPPERS_H_
diff --git a/tensorflow/core/distributed_runtime/remote_device.cc b/tensorflow/core/distributed_runtime/remote_device.cc
index 15e5919c54..a043c5dee6 100644
--- a/tensorflow/core/distributed_runtime/remote_device.cc
+++ b/tensorflow/core/distributed_runtime/remote_device.cc
@@ -37,7 +37,7 @@ string GetLocalDeviceName(StringPiece fullname) {
auto pos = fullname.rfind('/');
CHECK_NE(pos, StringPiece::npos);
fullname.remove_prefix(pos + 1);
- return std::string(fullname);
+ return string(fullname);
}
class RemoteDevice : public Device {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
index b7eb3c9015..456c30ecf4 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.cc
@@ -163,6 +163,13 @@ class MultiGrpcChannelCache : public CachingGrpcChannelCache {
}
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) override {
+ for (GrpcChannelCache* cache : caches_) {
+ cache->ListWorkersInJob(job_name, workers);
+ }
+ }
+
string TranslateTask(const string& target) override {
mutex_lock l(mu_); // could use reader lock
GrpcChannelCache* cache = gtl::FindPtrOrNull(target_caches_, target);
@@ -223,6 +230,13 @@ class SparseGrpcChannelCache : public CachingGrpcChannelCache {
}
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) override {
+ if (job_name == job_id_) {
+ ListWorkers(workers);
+ }
+ }
+
string TranslateTask(const string& target) override {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
index 4861cdb691..6fa99d7b14 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel.h
@@ -66,6 +66,8 @@ class GrpcChannelCache {
// /job:<job identifier>/task:<task id>
// e.g. /job:mnist/task:2
virtual void ListWorkers(std::vector<string>* workers) = 0;
+ virtual void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) = 0;
// If found, returns a gRPC channel that is connected to the remote
// worker named by 'target'. 'target' is of the following
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
index f07a5a0974..a814ef85e2 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_channel_test.cc
@@ -89,13 +89,33 @@ TEST(GrpcChannelTest, HostPorts) {
EXPECT_NE(d_4_1.get(), e_5_2.get());
}
- std::vector<string> workers;
- cc->ListWorkers(&workers);
- EXPECT_EQ(std::vector<string>(
- {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
- "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
- "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
- workers);
+ {
+ std::vector<string> workers;
+ cc->ListWorkers(&workers);
+ EXPECT_EQ(
+ std::vector<string>(
+ {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
+ "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("mnist", &workers);
+ EXPECT_EQ(
+ std::vector<string>(
+ {"/job:mnist/replica:0/task:0", "/job:mnist/replica:0/task:1",
+ "/job:mnist/replica:0/task:2", "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4", "/job:mnist/replica:0/task:5"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("other", &workers);
+ EXPECT_TRUE(workers.empty());
+ }
}
TEST(GrpcChannelTest, SparseHostPorts) {
@@ -135,13 +155,30 @@ TEST(GrpcChannelTest, SparseHostPorts) {
EXPECT_NE(d_4_1.get(), e_5_2.get());
}
- std::vector<string> workers;
- cc->ListWorkers(&workers);
- std::sort(workers.begin(), workers.end());
- EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
- "/job:mnist/replica:0/task:3",
- "/job:mnist/replica:0/task:4"}),
- workers);
+ {
+ std::vector<string> workers;
+ cc->ListWorkers(&workers);
+ std::sort(workers.begin(), workers.end());
+ EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
+ "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("mnist", &workers);
+ EXPECT_EQ(std::vector<string>({"/job:mnist/replica:0/task:0",
+ "/job:mnist/replica:0/task:3",
+ "/job:mnist/replica:0/task:4"}),
+ workers);
+ }
+
+ {
+ std::vector<string> workers;
+ cc->ListWorkersInJob("other", &workers);
+ EXPECT_TRUE(workers.empty());
+ }
}
TEST(GrpcChannelTest, NewHostPortGrpcChannelValidation) {
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h
index 709c3833e7..b85c1dc5b4 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
-#define TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
#include <memory>
@@ -35,4 +35,4 @@ WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
} // namespace tensorflow
-#endif // TENSORFLOW_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_REMOTE_WORKER_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
index bcd46a4c06..c4f2247145 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc
@@ -190,6 +190,8 @@ Status GrpcServer::Init(
builder.SetMaxMessageSize(std::numeric_limits<int32>::max());
builder.SetOption(
std::unique_ptr<::grpc::ServerBuilderOption>(new NoReusePortOption));
+ // Allow subclasses to specify more args to pass to the gRPC server.
+ MaybeMutateBuilder(&builder);
master_impl_ = CreateMaster(&master_env_);
master_service_ = NewGrpcMasterService(master_impl_.get(), config, &builder);
worker_impl_ =
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
index 3366246afb..7979e96d3e 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_server_lib.h
@@ -59,6 +59,9 @@ typedef std::function<std::unique_ptr<GrpcWorker>(WorkerEnv*)>
class GrpcServer : public ServerInterface {
protected:
GrpcServer(const ServerDef& server_def, Env* env);
+ // Allows subclasses to override this and provide custom args to the server
+ // before it is constructed. The default behavior is to do nothing.
+ virtual void MaybeMutateBuilder(::grpc::ServerBuilder* builder) {}
public:
static Status Create(const ServerDef& server_def, Env* env,
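A minimal sketch, not from this patch, of how a subclass might use the new MaybeMutateBuilder() hook; the class name and the chosen builder option are hypothetical:

// Hypothetical subclass that tweaks the gRPC ServerBuilder before Init() builds
// the server.
class CustomGrpcServer : public GrpcServer {
 public:
  CustomGrpcServer(const ServerDef& server_def, Env* env)
      : GrpcServer(server_def, env) {}

 protected:
  void MaybeMutateBuilder(::grpc::ServerBuilder* builder) override {
    // For example, raise the maximum receive message size.
    builder->SetMaxReceiveMessageSize(64 * 1024 * 1024);
  }
};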
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h
index d5baaae353..98164e750b 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_testlib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
-#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
+#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_TESTLIB_H_
+#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_TESTLIB_H_
#include <memory>
#include <string>
@@ -71,4 +71,4 @@ class TestCluster {
} // end namespace test
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_TESTLIB_H_
+#endif // TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_RPC_GRPC_TESTLIB_H_
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
index b9f21ea211..e1541db69b 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc
@@ -54,6 +54,11 @@ class GrpcWorkerCache : public WorkerCachePartial {
channel_cache_->ListWorkers(workers);
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {
+ channel_cache_->ListWorkersInJob(job_name, workers);
+ }
+
WorkerInterface* CreateWorker(const string& target) override {
if (target == local_target_) {
return local_worker_;
diff --git a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
index 25ff6512a0..b070dd13dd 100644
--- a/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
+++ b/tensorflow/core/distributed_runtime/rpc/rpc_rendezvous_mgr_test.cc
@@ -50,6 +50,8 @@ namespace {
// Fake cache implementation for WorkerEnv.
class DummyWorkerCache : public WorkerCacheInterface {
void ListWorkers(std::vector<string>* workers) const override {}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {}
WorkerInterface* CreateWorker(const string& target) override {
return nullptr;
}
diff --git a/tensorflow/core/distributed_runtime/tensor_coding.h b/tensorflow/core/distributed_runtime/tensor_coding.h
index bae4ec794c..4c34297990 100644
--- a/tensorflow/core/distributed_runtime/tensor_coding.h
+++ b/tensorflow/core/distributed_runtime/tensor_coding.h
@@ -87,6 +87,9 @@ class TensorResponse {
// modified.
const RecvTensorResponse& metadata() const { return meta_; }
+ // Return pointer to the device hosting the tensor.
+ DeviceBase* device() const { return device_; }
+
private:
bool ParseTensorSubmessage(protobuf::io::CodedInputStream* input,
TensorProto* tensor_meta);
diff --git a/tensorflow/core/distributed_runtime/test_utils.h b/tensorflow/core/distributed_runtime/test_utils.h
index 48d83845dd..88a97da34d 100644
--- a/tensorflow/core/distributed_runtime/test_utils.h
+++ b/tensorflow/core/distributed_runtime/test_utils.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <unordered_map>
#include "tensorflow/core/distributed_runtime/worker_cache.h"
#include "tensorflow/core/distributed_runtime/worker_interface.h"
+#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
@@ -138,6 +139,19 @@ class TestWorkerCache : public WorkerCacheInterface {
}
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {
+ workers->clear();
+ for (auto it : workers_) {
+ DeviceNameUtils::ParsedName device_name;
+ CHECK(DeviceNameUtils::ParseFullName(it.first, &device_name));
+ CHECK(device_name.has_job);
+ if (job_name == device_name.job) {
+ workers->push_back(it.first);
+ }
+ }
+ }
+
WorkerInterface* CreateWorker(const string& target) override {
auto it = workers_.find(target);
if (it != workers_.end()) {
diff --git a/tensorflow/core/distributed_runtime/worker_cache.h b/tensorflow/core/distributed_runtime/worker_cache.h
index 8521f8956b..0c8575b4d5 100644
--- a/tensorflow/core/distributed_runtime/worker_cache.h
+++ b/tensorflow/core/distributed_runtime/worker_cache.h
@@ -36,6 +36,8 @@ class WorkerCacheInterface {
// Updates *workers with strings naming the remote worker tasks to
// which open channels have been established.
virtual void ListWorkers(std::vector<string>* workers) const = 0;
+ virtual void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const = 0;
// If "target" names a remote task for which an RPC channel exists
// or can be constructed, returns a pointer to a WorkerInterface object
diff --git a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
index 43c3b6285b..1f309b4361 100644
--- a/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
+++ b/tensorflow/core/distributed_runtime/worker_cache_wrapper.h
@@ -32,6 +32,10 @@ class WorkerCacheWrapper : public WorkerCacheInterface {
virtual void ListWorkers(std::vector<string>* workers) const {
return wrapped_->ListWorkers(workers);
}
+ virtual void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const {
+ return wrapped_->ListWorkersInJob(job_name, workers);
+ }
// If "target" names a remote task for which an RPC channel exists
// or can be constructed, returns a pointer to a WorkerInterface object
diff --git a/tensorflow/core/distributed_runtime/worker_session.cc b/tensorflow/core/distributed_runtime/worker_session.cc
index ca6dc1b1de..c7d0c6b7f3 100644
--- a/tensorflow/core/distributed_runtime/worker_session.cc
+++ b/tensorflow/core/distributed_runtime/worker_session.cc
@@ -35,6 +35,11 @@ class WorkerFreeListCache : public WorkerCacheInterface {
wrapped_->ListWorkers(workers);
}
+ void ListWorkersInJob(const string& job_name,
+ std::vector<string>* workers) const override {
+ wrapped_->ListWorkersInJob(job_name, workers);
+ }
+
WorkerInterface* CreateWorker(const string& target) override {
mutex_lock l(mu_);
auto p = workers_.find(target);
diff --git a/tensorflow/core/example/example.proto b/tensorflow/core/example/example.proto
index e7142a4ef9..e36e51d8d5 100644
--- a/tensorflow/core/example/example.proto
+++ b/tensorflow/core/example/example.proto
@@ -199,7 +199,13 @@ message Example {
// to determine if all features within the FeatureList must
// have the same size. The same holds for this FeatureList across multiple
// examples.
-//
+// - For sequence modeling, e.g.:
+// http://colah.github.io/posts/2015-08-Understanding-LSTMs/
+// https://github.com/tensorflow/nmt
+// the feature lists represent a sequence of frames.
+// In this scenario, all FeatureLists in a SequenceExample have the same
+// number of Feature messages, so that the ith element in each FeatureList
+// is part of the ith frame (or time step).
// Examples of conformant and non-conformant examples' FeatureLists:
//
// Conformant FeatureLists:
diff --git a/tensorflow/core/example/example_parser_configuration.h b/tensorflow/core/example/example_parser_configuration.h
index 3d06bd55e2..8bbed28471 100644
--- a/tensorflow/core/example/example_parser_configuration.h
+++ b/tensorflow/core/example/example_parser_configuration.h
@@ -53,4 +53,4 @@ Status ExampleParserConfigurationProtoToFeatureVectors(
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSE_CONFIGURATION_H_
+#endif // TENSORFLOW_CORE_EXAMPLE_EXAMPLE_PARSER_CONFIGURATION_H_
diff --git a/tensorflow/core/example/feature_util.h b/tensorflow/core/example/feature_util.h
index 2265498b5e..016d1a92c1 100644
--- a/tensorflow/core/example/feature_util.h
+++ b/tensorflow/core/example/feature_util.h
@@ -97,12 +97,13 @@ limitations under the License.
// GetFeatureValues<FeatureType>(feature) -> RepeatedField<FeatureType>
// Returns values of the feature for the FeatureType.
-#ifndef TENSORFLOW_EXAMPLE_FEATURE_H_
-#define TENSORFLOW_EXAMPLE_FEATURE_H_
+#ifndef TENSORFLOW_CORE_EXAMPLE_FEATURE_UTIL_H_
+#define TENSORFLOW_CORE_EXAMPLE_FEATURE_UTIL_H_
#include <iterator>
#include <type_traits>
+#include "absl/base/macros.h"
#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -113,10 +114,10 @@ namespace tensorflow {
namespace internal {
-// DEPRECATED: Use GetFeature instead.
// TODO(gorban): Update all clients in a followup CL.
// Returns a reference to a feature corresponding to the name.
// Note: it will create a new Feature if it is missing in the example.
+ABSL_DEPRECATED("Use GetFeature instead.")
Feature& ExampleFeature(const string& name, Example* example);
// Specializations of RepeatedFieldTrait define a type of RepeatedField
@@ -314,12 +315,12 @@ bool HasFeature(const string& key, const Example& example) {
return HasFeature<FeatureType...>(key, GetFeatures(example));
}
-// DEPRECATED: use HasFeature instead.
// TODO(gorban): update all clients in a followup CL.
template <typename... FeatureType>
+ABSL_DEPRECATED("Use HasFeature instead.")
bool ExampleHasFeature(const string& key, const Example& example) {
return HasFeature<FeatureType...>(key, example);
}
} // namespace tensorflow
-#endif // TENSORFLOW_EXAMPLE_FEATURE_H_
+#endif // TENSORFLOW_CORE_EXAMPLE_FEATURE_UTIL_H_
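
Since the ABSL_DEPRECATED annotations above only warn, a small migration sketch may help. HasScore is a hypothetical caller; the rename is mechanical because the deprecated wrappers simply forward to the new names (and, per the new deprecation message, internal::ExampleFeature is replaced by GetFeature in the same way).

#include "tensorflow/core/example/example.pb.h"
#include "tensorflow/core/example/feature_util.h"

namespace tensorflow {

// Hypothetical caller: swap ExampleHasFeature for HasFeature with unchanged
// template and call arguments.
bool HasScore(const Example& example) {
  // Before: ExampleHasFeature<float>("score", example);
  return HasFeature<float>("score", example);
}

}  // namespace tensorflow
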
diff --git a/tensorflow/core/framework/allocator.cc b/tensorflow/core/framework/allocator.cc
index 888ed0c57b..84cee5569c 100644
--- a/tensorflow/core/framework/allocator.cc
+++ b/tensorflow/core/framework/allocator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/allocator_registry.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/tracking_allocator.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
@@ -56,6 +57,14 @@ void RunResourceDtor(ResourceHandle* p, size_t n) {
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
}
+void Allocator::RunVariantCtor(Variant* p, size_t n) {
+ for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
+}
+
+void Allocator::RunVariantDtor(Variant* p, size_t n) {
+ for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
+}
+
// If true, cpu allocator collects more stats.
static bool cpu_allocator_collect_stats = false;
// If true, cpu allocator collects full stats.
@@ -187,7 +196,7 @@ class CPUAllocatorFactory : public AllocatorFactory {
class CPUSubAllocator : public SubAllocator {
public:
explicit CPUSubAllocator(CPUAllocator* cpu_allocator)
- : cpu_allocator_(cpu_allocator) {}
+ : SubAllocator({}, {}), cpu_allocator_(cpu_allocator) {}
void* Alloc(size_t alignment, size_t num_bytes) override {
return cpu_allocator_->AllocateRaw(alignment, num_bytes);
@@ -213,4 +222,22 @@ Allocator* cpu_allocator() {
}
return cpu_alloc;
}
+
+SubAllocator::SubAllocator(const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors)
+ : alloc_visitors_(alloc_visitors), free_visitors_(free_visitors) {}
+
+void SubAllocator::VisitAlloc(void* ptr, int index, size_t num_bytes) {
+ for (const auto& v : alloc_visitors_) {
+ v(ptr, index, num_bytes);
+ }
+}
+
+void SubAllocator::VisitFree(void* ptr, int index, size_t num_bytes) {
+ // Although we don't guarantee any order of visitor application, strive
+ // to apply free visitors in reverse order of alloc visitors.
+ for (int i = free_visitors_.size() - 1; i >= 0; --i) {
+ free_visitors_[i](ptr, index, num_bytes);
+ }
+}
} // namespace tensorflow
diff --git a/tensorflow/core/framework/allocator.h b/tensorflow/core/framework/allocator.h
index 774b1fe137..8c23604625 100644
--- a/tensorflow/core/framework/allocator.h
+++ b/tensorflow/core/framework/allocator.h
@@ -23,12 +23,14 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/type_traits.h"
-#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
+class Variant;
+
// Attributes for a single allocation call. Different calls to the same
// allocator could potentially have different allocation attributes.
struct AllocationAttributes {
@@ -228,13 +230,9 @@ class Allocator {
for (size_t i = 0; i < n; ++p, ++i) p->~ResourceHandle();
}
- virtual void RunVariantCtor(Variant* p, size_t n) {
- for (size_t i = 0; i < n; ++p, ++i) new (p) Variant();
- }
+ virtual void RunVariantCtor(Variant* p, size_t n);
- virtual void RunVariantDtor(Variant* p, size_t n) {
- for (size_t i = 0; i < n; ++p, ++i) p->~Variant();
- }
+ virtual void RunVariantDtor(Variant* p, size_t n);
// TODO(jeff): Maybe provide some interface to give info about
// current allocation state (total number of bytes available for
@@ -390,13 +388,36 @@ void EnableCPUAllocatorStats(bool enable);
// full statistics. By default, it's disabled.
void EnableCPUAllocatorFullStats(bool enable);
-// Abstract interface of an object that does the underlying suballoc/free of
-// memory for a higher-level allocator.
+// An object that does the underlying suballoc/free of memory for a higher-level
+// allocator. The expectation is that the higher-level allocator is doing some
+// kind of cache or pool management so that it will call SubAllocator::Alloc and
+// Free relatively infrequently, compared to the number of times its own
+// AllocateRaw and Free methods are called.
class SubAllocator {
public:
+ // Visitor gets called with a pointer to a memory area and its
+ // size in bytes. The index value will be numa_node for a CPU
+ // allocator and GPU id for a GPU allocator.
+ typedef std::function<void(void*, int index, size_t)> Visitor;
+
+ SubAllocator(const std::vector<Visitor>& alloc_visitors,
+ const std::vector<Visitor>& free_visitors);
+
virtual ~SubAllocator() {}
virtual void* Alloc(size_t alignment, size_t num_bytes) = 0;
virtual void Free(void* ptr, size_t num_bytes) = 0;
+
+ protected:
+ // Implementation of Alloc() method must call this on newly allocated
+ // value.
+ void VisitAlloc(void* ptr, int index, size_t num_bytes);
+
+ // Implementation of Free() method must call this on value to be
+ // freed immediately before deallocation.
+ void VisitFree(void* ptr, int index, size_t num_bytes);
+
+ const std::vector<Visitor> alloc_visitors_;
+ const std::vector<Visitor> free_visitors_;
};
} // namespace tensorflow
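
A sketch of what the reworked SubAllocator contract above implies for implementers. NumaSubAllocator and its use of port::AlignedMalloc are illustrative assumptions; the visitor plumbing mirrors the comments in the header.

#include <vector>
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/platform/mem.h"

namespace tensorflow {

// Hypothetical suballocator: forwards the visitor lists to the base class and
// calls VisitAlloc()/VisitFree() around the raw allocation, passing its NUMA
// node as the index value described in the header comment.
class NumaSubAllocator : public SubAllocator {
 public:
  NumaSubAllocator(int numa_node, const std::vector<Visitor>& alloc_visitors,
                   const std::vector<Visitor>& free_visitors)
      : SubAllocator(alloc_visitors, free_visitors), numa_node_(numa_node) {}

  void* Alloc(size_t alignment, size_t num_bytes) override {
    void* ptr = port::AlignedMalloc(num_bytes, static_cast<int>(alignment));
    VisitAlloc(ptr, numa_node_, num_bytes);  // required on newly allocated memory
    return ptr;
  }

  void Free(void* ptr, size_t num_bytes) override {
    VisitFree(ptr, numa_node_, num_bytes);  // immediately before deallocation
    port::AlignedFree(ptr);
  }

 private:
  const int numa_node_;
};

}  // namespace tensorflow
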
diff --git a/tensorflow/core/framework/allocator_registry.h b/tensorflow/core/framework/allocator_registry.h
index 24f282ce84..e907c52ba9 100644
--- a/tensorflow/core/framework/allocator_registry.h
+++ b/tensorflow/core/framework/allocator_registry.h
@@ -21,6 +21,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/numa.h"
namespace tensorflow {
diff --git a/tensorflow/core/framework/attr_value_util.h b/tensorflow/core/framework/attr_value_util.h
index 0da9b1081b..9fce488793 100644
--- a/tensorflow/core/framework/attr_value_util.h
+++ b/tensorflow/core/framework/attr_value_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
-#define TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_
#include <functional>
#include <string>
@@ -126,4 +126,4 @@ bool SubstitutePlaceholders(const SubstituteFunc& substitute, AttrValue* value);
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_ATTR_VALUE_UTIL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_ATTR_VALUE_UTIL_H_
diff --git a/tensorflow/core/framework/attr_value_util_test.cc b/tensorflow/core/framework/attr_value_util_test.cc
index 1a3994736c..4ffd732f8e 100644
--- a/tensorflow/core/framework/attr_value_util_test.cc
+++ b/tensorflow/core/framework/attr_value_util_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <numeric>
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/framework/bfloat16.h b/tensorflow/core/framework/bfloat16.h
index 2f79d0fa70..e9e94024f5 100644
--- a/tensorflow/core/framework/bfloat16.h
+++ b/tensorflow/core/framework/bfloat16.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_BFLOAT16_H_
-#define TENSORFLOW_FRAMEWORK_BFLOAT16_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_BFLOAT16_H_
+#define TENSORFLOW_CORE_FRAMEWORK_BFLOAT16_H_
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/platform/byte_order.h"
@@ -60,4 +60,4 @@ void BFloat16ToFloat(const bfloat16* src, float* dst, int64 size);
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_BFLOAT16_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_BFLOAT16_H_
diff --git a/tensorflow/core/framework/cancellation.cc b/tensorflow/core/framework/cancellation.cc
index 1258e40c93..af59500aee 100644
--- a/tensorflow/core/framework/cancellation.cc
+++ b/tensorflow/core/framework/cancellation.cc
@@ -89,6 +89,16 @@ bool CancellationManager::DeregisterCallback(CancellationToken token) {
}
}
+bool CancellationManager::TryDeregisterCallback(CancellationToken token) {
+ mutex_lock lock(mu_);
+ if (is_cancelled_ || is_cancelling_) {
+ return false;
+ } else {
+ callbacks_.erase(token);
+ return true;
+ }
+}
+
CancellationManager::~CancellationManager() {
if (!callbacks_.empty()) {
StartCancel();
diff --git a/tensorflow/core/framework/cancellation.h b/tensorflow/core/framework/cancellation.h
index 90074c87b2..7a5d942486 100644
--- a/tensorflow/core/framework/cancellation.h
+++ b/tensorflow/core/framework/cancellation.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_CANCELLATION_H_
-#define TENSORFLOW_FRAMEWORK_CANCELLATION_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
+#define TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
#include <atomic>
#include <functional>
@@ -122,6 +122,15 @@ class CancellationManager {
// cancellation manager.
bool DeregisterCallback(CancellationToken token);
+ // Deregister the callback that, when registered, was associated
+ // with the given cancellation token. Returns true iff the callback
+ // was deregistered and will not be invoked; otherwise returns false
+ // immediately, with no guarantee that the callback has completed.
+ //
+ // This method is guaranteed to return true if StartCancel has not been
+ // called.
+ bool TryDeregisterCallback(CancellationToken token);
+
private:
bool is_cancelling_;
std::atomic_bool is_cancelled_;
@@ -134,4 +143,4 @@ class CancellationManager {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_CANCELLATION_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_CANCELLATION_H_
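
A brief usage sketch of the new TryDeregisterCallback(); Cleanup is a hypothetical caller and only illustrates the non-blocking contract described in the comment above.

#include "tensorflow/core/framework/cancellation.h"

namespace tensorflow {

// Hypothetical caller: unlike DeregisterCallback(), this never blocks waiting
// for an in-flight callback, so it is safe on paths that must not stall.
void Cleanup(CancellationManager* cm, CancellationToken token) {
  if (!cm->TryDeregisterCallback(token)) {
    // Cancellation already started: the callback may still be running (or may
    // have already run), so any state it captures must remain valid here.
  }
}

}  // namespace tensorflow
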
diff --git a/tensorflow/core/framework/cancellation_test.cc b/tensorflow/core/framework/cancellation_test.cc
index e3f18240b5..bf7593bc5f 100644
--- a/tensorflow/core/framework/cancellation_test.cc
+++ b/tensorflow/core/framework/cancellation_test.cc
@@ -115,4 +115,56 @@ TEST(Cancellation, IsCancelled) {
delete cm;
}
+TEST(Cancellation, TryDeregisterWithoutCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_TRUE(deregistered);
+ delete manager;
+ EXPECT_FALSE(is_cancelled);
+}
+
+TEST(Cancellation, TryDeregisterAfterCancel) {
+ bool is_cancelled = false;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(
+ token, [&is_cancelled]() { is_cancelled = true; });
+ EXPECT_TRUE(registered);
+ manager->StartCancel();
+ EXPECT_TRUE(is_cancelled);
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+ delete manager;
+}
+
+TEST(Cancellation, TryDeregisterDuringCancel) {
+ Notification cancel_started, finish_callback, cancel_complete;
+ CancellationManager* manager = new CancellationManager();
+ auto token = manager->get_cancellation_token();
+ bool registered = manager->RegisterCallback(token, [&]() {
+ cancel_started.Notify();
+ finish_callback.WaitForNotification();
+ });
+ EXPECT_TRUE(registered);
+
+ thread::ThreadPool w(Env::Default(), "test", 1);
+ w.Schedule([&]() {
+ manager->StartCancel();
+ cancel_complete.Notify();
+ });
+ cancel_started.WaitForNotification();
+
+ bool deregistered = manager->TryDeregisterCallback(token);
+ EXPECT_FALSE(deregistered);
+
+ finish_callback.Notify();
+ cancel_complete.WaitForNotification();
+ delete manager;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/collective.cc b/tensorflow/core/framework/collective.cc
index d4ac50cbbe..4cb277d5a8 100644
--- a/tensorflow/core/framework/collective.cc
+++ b/tensorflow/core/framework/collective.cc
@@ -21,6 +21,31 @@ limitations under the License.
namespace tensorflow {
+namespace {
+// A RegistrationInfo object stores a collective implementation registration
+// details. `factory` is used to create instances of the collective
+// implementation.
+struct RegistrationInfo {
+ // This constructor also creates, and stores in `param_resolver_instance`,
+ // what is effectively a static instance of the collective implementation.
+ // During param resolution of collective ops we return this static instance.
+ // The actual op execution gets a fresh instance using `factory`.
+ RegistrationInfo(const string& n, CollectiveRegistry::Factory f)
+ : name(n),
+ factory(std::move(f)),
+ param_resolver_instance(this->factory()) {}
+ string name;
+ CollectiveRegistry::Factory factory;
+ CollectiveImplementationInterface* param_resolver_instance;
+};
+
+std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
+ static std::vector<RegistrationInfo>* registry =
+ new std::vector<RegistrationInfo>;
+ return registry;
+}
+} // namespace
+
string CollGroupParams::ToString() const {
return strings::StrCat("CollGroupParams {group_key=", group_key,
" group_size=", group_size,
@@ -102,7 +127,8 @@ string CollectiveParams::ToString() const {
strings::StrAppend(&v, " ", instance.ToString());
strings::StrAppend(&v, " ", task.ToString());
strings::StrAppend(&v, " default_rank=", default_rank,
- " is_source=", is_source, " subdiv_rank={");
+ " is_source=", is_source, " source_rank=", source_rank,
+ " subdiv_rank={");
for (const auto& r : subdiv_rank) {
strings::StrAppend(&v, r, ",");
}
@@ -115,7 +141,81 @@ string CollectiveParams::ToString() const {
return ctx->params_;
}
+CollectiveContext::CollectiveContext(CollectiveExecutor* col_exec,
+ const DeviceMgr* dev_mgr,
+ OpKernelContext* ctx,
+ OpKernelContext::Params* op_params,
+ const CollectiveParams& col_params,
+ const string& exec_key, int64 step_id,
+ const Tensor* input, Tensor* output)
+ : col_exec(col_exec),
+ dev_mgr(dev_mgr),
+ op_ctx(ctx),
+ op_params(op_params),
+ col_params(col_params),
+ exec_key(exec_key),
+ step_id(step_id),
+ input(input),
+ output(output),
+ device(nullptr),
+ device_name(col_params.instance.device_names[col_params.default_rank]) {}
+
/*static*/
int64 CollectiveExecutor::kInvalidId = -1;
+/*static*/
+Status CollectiveRegistry::Lookup(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation) {
+ return LookupHelper(collective_name, implementation, false);
+}
+
+/*static*/
+Status CollectiveRegistry::LookupParamResolverInstance(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation) {
+ return LookupHelper(collective_name, implementation, true);
+}
+
+/*static*/
+void CollectiveRegistry::GetAll(
+ std::vector<CollectiveImplementationInterface*>* implementations) {
+ std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
+ for (const RegistrationInfo& reg_info : *registry)
+ implementations->emplace_back(reg_info.factory());
+}
+
+/*static*/
+Status CollectiveRegistry::Register(const string& collective_name,
+ Factory factory) {
+ std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
+ for (const RegistrationInfo& reg_info : *registry) {
+ if (reg_info.name == collective_name)
+ return errors::Internal("Already registered collective ",
+ collective_name);
+ }
+ registry->emplace_back(collective_name, std::move(factory));
+ return Status::OK();
+}
+
+/*static*/
+Status CollectiveRegistry::LookupHelper(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation, bool param_resolver) {
+ std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
+ for (const RegistrationInfo& reg_info : *registry) {
+ if (reg_info.name == collective_name) {
+ if (param_resolver) {
+ *implementation = reg_info.param_resolver_instance;
+ } else {
+ *implementation = reg_info.factory();
+ }
+ return Status::OK();
+ }
+ }
+ return errors::Internal(
+ "CollectiveRegistry::Lookup did not find collective implementation ",
+ collective_name);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index c3e6388e28..e35edb09d0 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -12,12 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
-#define TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
#include <string>
#include <vector>
+#include "tensorflow/core/framework/device_attributes.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -30,7 +31,8 @@ class CompleteGroupRequest;
class CompleteGroupResponse;
class CompleteInstanceRequest;
class CompleteInstanceResponse;
-class DeviceLocality;
+class Device;
+class DeviceMgr;
class GetStepSequenceRequest;
class GetStepSequenceResponse;
class Op;
@@ -64,10 +66,10 @@ struct CollGroupParams {
// interpretation. On first execution the runtime will update this
// structure with decisions that will guide all subsequent executions.
struct CollImplDetails {
+ string collective_name;
std::vector<std::vector<int>> subdiv_permutations;
std::vector<int> subdiv_offsets;
- // broadcast only: rank of source in each subdiv
- std::vector<int> subdiv_source_rank;
+ std::vector<int> subdiv_source_rank; // rank of source in each subdiv
};
// Data common to all members of a collective instance.
@@ -104,6 +106,7 @@ struct CollectiveParams {
string name = ""; // node name used only for log or error messages
int default_rank = -1; // index of this op within device_names
bool is_source = false; // broadcast only
+ int source_rank = -1; // broadcast only
// Rank of this device in each subdivision permutation.
std::vector<int> subdiv_rank;
std::unique_ptr<OpKernel> merge_op; // reduction only
@@ -306,6 +309,110 @@ class PerStepCollectiveRemoteAccess : public CollectiveRemoteAccess {
virtual void StartAbort(const Status& s) = 0;
};
+class CollectiveContext {
+ public:
+ CollectiveContext(CollectiveExecutor* col_exec, const DeviceMgr* dev_mgr,
+ OpKernelContext* ctx, OpKernelContext::Params* op_params,
+ const CollectiveParams& col_params, const string& exec_key,
+ int64 step_id, const Tensor* input, Tensor* output);
+
+ virtual ~CollectiveContext() = default;
+
+ CollectiveExecutor* col_exec; // Not owned
+ const DeviceMgr* dev_mgr; // Not owned
+ OpKernelContext* op_ctx; // Not owned
+ OpKernelContext::Params* op_params; // Not owned
+ const CollectiveParams& col_params;
+ const string exec_key;
+ const int64 step_id;
+ const Tensor* input; // Not owned
+ Tensor* output; // Not owned
+ Device* device; // The device for which this instance labors
+ const string device_name;
+ DeviceLocality device_locality;
+};
+
+// Interface of a Collective Op implementation. Each specific CollectiveOp will
+// implement this interface and register the implementation via the
+// CollectiveRegistry detailed below. See common_runtime/ring_reducer and
+// common_runtime/hierarchical_tree_broadcaster for examples.
+class CollectiveImplementationInterface {
+ public:
+ virtual ~CollectiveImplementationInterface() = default;
+
+ // Initializes the portions of `col_params` specific to this
+ // implementation. Called exactly once for every Collective instance during
+ // the CollectiveParams resolution process when the graph is first executed.
+ // NOTE(ayushd): This is effectively a static function because it modifies the
+ // `col_params` passed in and should not manipulate any data members. However
+ // because it is virtual and needs to be implemented by every derived class we
+ // do not mark it as static.
+ virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;
+
+ // Prepares the CollectiveContext for executing this CollectiveImplementation.
+ // Called from CollectiveExecutor right before calling Run(). The
+ // CollectiveContext passed in must outlive the CollectiveImplementation
+ // object.
+ virtual Status InitializeCollectiveContext(CollectiveContext* col_ctx) = 0;
+
+ // Processes and moves data according to the logic of this Collective
+ // implementation. Relies on appropriate initialization of op-specific
+ // CollectiveParams in InitializeCollectiveParams(), as well as appropriate
+ // context initialization in InitializeCollectiveContext().
+ virtual void Run(StatusCallback done) = 0;
+};
+
+// Static-methods only class for registering and looking up collective
+// implementations.
+class CollectiveRegistry {
+ public:
+ using Factory = std::function<CollectiveImplementationInterface*()>;
+ // Looks up a previously registered CollectiveImplementation under
+ // `collective_name`. If found, creates an instance of the implementation and
+ // assigns it to `implementation`.
+ static Status Lookup(const string& collective_name,
+ CollectiveImplementationInterface** implementation);
+
+ // Looks up a previously registered CollectiveImplementation under
+ // `collective_name`. If found, returns the static instance of this
+ // implementation via `implementation`. This instance should only be used to
+ // call InitializeCollectiveParams.
+ static Status LookupParamResolverInstance(
+ const string& collective_name,
+ CollectiveImplementationInterface** implementation);
+
+ // Returns all registered collective implementations.
+ static void GetAll(
+ std::vector<CollectiveImplementationInterface*>* implementations);
+
+ private:
+ friend class CollectiveRegistration;
+ // Registers a CollectiveImplementation with name `collective_name` and
+ // factory `factory`. The latter is a function used to create instances of
+ // the CollectiveImplementation. Also creates a static instance of the
+ // implementation - this instance is used during param resolution and should
+ // only be used to call InitializeCollectiveParams.
+ static Status Register(const string& collective_name, Factory factory);
+
+ static Status LookupHelper(const string& collective_name,
+ CollectiveImplementationInterface** implementation,
+ bool param_resolver);
+};
+
+// Class used to call CollectiveRegistry::Register. This should only be used to
+// create a global static object.
+class CollectiveRegistration {
+ public:
+ CollectiveRegistration(const string& collective_name,
+ CollectiveRegistry::Factory factory) {
+ TF_CHECK_OK(CollectiveRegistry::Register(collective_name, factory));
+ }
+};
+
+#define REGISTER_COLLECTIVE(name, implementation) \
+ static CollectiveRegistration register_##name##_collective( \
+ #name, []() { return new implementation; });
+
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_COLLECTIVE_EXECUTOR_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_COLLECTIVE_H_
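
To make the registration flow above concrete, here is a skeleton of a collective implementation. MyAllGather is hypothetical; the real examples referenced in the header live under common_runtime/ring_reducer and common_runtime/hierarchical_tree_broadcaster.

#include "tensorflow/core/framework/collective.h"

namespace tensorflow {

// Hypothetical collective: the three virtual methods split work between param
// resolution, per-execution setup, and execution, and the macro registers the
// factory with CollectiveRegistry at static-initialization time.
class MyAllGather : public CollectiveImplementationInterface {
 public:
  Status InitializeCollectiveParams(CollectiveParams* col_params) override {
    // Fill in implementation-specific details of *col_params here (e.g. the
    // fields of col_params->instance.impl_details).
    return Status::OK();
  }

  Status InitializeCollectiveContext(CollectiveContext* col_ctx) override {
    col_ctx_ = col_ctx;  // must outlive this object, per the header contract
    return Status::OK();
  }

  void Run(StatusCallback done) override {
    // Move data from col_ctx_->input to col_ctx_->output, then signal
    // completion.
    done(Status::OK());
  }

 private:
  CollectiveContext* col_ctx_ = nullptr;
};

REGISTER_COLLECTIVE(MyAllGather, MyAllGather);

}  // namespace tensorflow
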
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 21c6940b62..50403b4004 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -432,9 +432,9 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
DimensionHandle batch_size_dim;
DimensionHandle input_depth_dim;
gtl::InlinedVector<DimensionHandle, 2> input_spatial_dims(2);
- TF_RETURN_IF_ERROR(DimensionsFromShape(conv_input_shape, data_format,
- &batch_size_dim, &input_spatial_dims,
- &input_depth_dim, c));
+ TF_RETURN_IF_ERROR(DimensionsFromShape(
+ conv_input_shape, data_format, &batch_size_dim,
+ absl::MakeSpan(input_spatial_dims), &input_depth_dim, c));
DimensionHandle output_depth_dim = c->Dim(
filter_shape, GetFilterDimIndex<num_spatial_dims>(filter_format, 'O'));
@@ -1306,6 +1306,113 @@ Status RandomShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
+namespace {
+
+// This SliceHelper processes the output shape of the `slice`
+// when the tensor of `sizes` is available.
+template <typename T>
+Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
+ const Tensor* sizes_value,
+ std::vector<DimensionHandle>* dims) {
+ auto sizes_vec = sizes_value->vec<T>();
+ for (int i = 0; i < sizes_value->NumElements(); ++i) {
+ DimensionHandle dim = c->Dim(c->input(0), i);
+ if (sizes_vec(i) != -1) {
+ auto dim_val = c->Value(dim);
+ if (sizes_vec(i) < 0) {
+ return errors::InvalidArgument(
+ "Out of bounds slicing on dimension ", i, " of length ", dim_val,
+ ": sizes vector cannot be < -1, but was ", sizes_vec(i));
+ }
+
+ dims->emplace_back(c->MakeDim(sizes_vec(i)));
+ } else {
+ DimensionHandle result;
+ TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
+ dims->emplace_back(result);
+ }
+ }
+
+ return Status::OK();
+}
+} // namespace
+
+Status SliceShape(InferenceContext* c) {
+ ShapeHandle input = c->input(0);
+ ShapeHandle begin_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
+ ShapeHandle sizes_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
+
+ // Merge to check compatibility of begin and sizes tensors.
+ TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
+
+ DimensionHandle ndims = c->Dim(begin_shape, 0);
+ if (c->ValueKnown(ndims)) {
+ TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
+ }
+
+ // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
+ // values, even though the `begin` value does not represent a shape.
+ ShapeHandle begin_value;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
+
+ // We check the tensor value here and will only use
+ // `MakeShapeFromShapeTensor` when `sizes_value` is null.
+ // The reason is that `sizes` might contain -1, which can't
+ // be represented (-1 in the ShapeHandle would mean "unknown").
+ const Tensor* sizes_value = c->input_tensor(2);
+
+ if (sizes_value != nullptr) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
+ std::vector<DimensionHandle> dims;
+ // If the begin and sizes tensors are available, then
+ // we can be precise about the shape of the output.
+ if (sizes_value->dtype() == DT_INT64) {
+ TF_RETURN_IF_ERROR(
+ SliceHelper<int64>(c, begin_value, sizes_value, &dims));
+ } else {
+ TF_RETURN_IF_ERROR(
+ SliceHelper<int32>(c, begin_value, sizes_value, &dims));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ } else {
+ // In case `sizes` is not available (`sizes_value` is null),
+ // we could try to use `MakeShapeFromShapeTensor` here.
+ // If sizes contain -1, we will simply consider it as `Unknown`.
+ // This is less than ideal but still an improvement of shape inference.
+ // The following is an example that returns [None, 1, None] with this
+ // code path:
+ // z = tf.zeros((1, 2, 3))
+ // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
+ // m.get_shape().as_list()
+ ShapeHandle sizes_value;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
+ if (c->RankKnown(sizes_value)) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
+ std::vector<DimensionHandle> dims;
+ dims.reserve(c->Rank(sizes_value));
+ for (int i = 0; i < c->Rank(sizes_value); ++i) {
+ dims.emplace_back(c->Dim(sizes_value, i));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ }
+ // We might know the rank of the input.
+ if (c->RankKnown(input)) {
+ c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
+ return Status::OK();
+ } else {
+ return shape_inference::UnknownShape(c);
+ }
+ }
+
+ return Status::OK();
+}
+
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
ShapeHandle values_shape, ShapeHandle shape_shape) {
// Validate ranks.
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index 2bedce1d6a..3a496e06ae 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
-#define TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
+#define TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
#include <array>
@@ -293,6 +293,9 @@ inline Status BroadcastBinaryOpShapeFn(InferenceContext* c) {
// Shape function for random operations.
Status RandomShape(shape_inference::InferenceContext* c);
+// Shape function for Slice operations.
+Status SliceShape(shape_inference::InferenceContext* c);
+
// Validates the 3 component tensors of a sparse tensor have the proper
// shapes. This mimics SparseTensor.__init__ in python/framework/ops.py.
Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
@@ -311,4 +314,4 @@ Status ExplicitShapes(InferenceContext* c);
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_OPS_COMMON_SHAPE_FNS_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_COMMON_SHAPE_FNS_H_
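
As a sketch of how the new shared shape function is consumed, here is a hypothetical op registration; the real Slice op's registration (outside this hunk) is presumably the intended consumer.

#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// Hypothetical op with Slice-style (input, begin, size) arguments: it can now
// delegate shape inference to the shared function instead of an inline lambda.
REGISTER_OP("MySliceLikeOp")
    .Input("input: T")
    .Input("begin: Index")
    .Input("size: Index")
    .Output("output: T")
    .Attr("T: type")
    .Attr("Index: {int32, int64}")
    .SetShapeFn(shape_inference::SliceShape);

}  // namespace tensorflow
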
diff --git a/tensorflow/core/framework/control_flow.h b/tensorflow/core/framework/control_flow.h
index 4dad0b4fef..4839e02e22 100644
--- a/tensorflow/core/framework/control_flow.h
+++ b/tensorflow/core/framework/control_flow.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
-#define TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_CONTROL_FLOW_H_
+#define TENSORFLOW_CORE_FRAMEWORK_CONTROL_FLOW_H_
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
@@ -55,4 +55,4 @@ struct FrameAndIterHash {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_CONTROL_FLOW_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_CONTROL_FLOW_H_
diff --git a/tensorflow/core/framework/dataset.cc b/tensorflow/core/framework/dataset.cc
index e886ef7b8e..284dafb886 100644
--- a/tensorflow/core/framework/dataset.cc
+++ b/tensorflow/core/framework/dataset.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/graph/node_builder.h"
namespace tensorflow {
-
+namespace data {
namespace {
// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
@@ -74,18 +74,18 @@ class DatasetVariantWrapper {
} // namespace
Status GraphDefBuilderWrapper::AddDataset(
- const GraphDatasetBase* dataset,
+ const DatasetBase* dataset,
const std::vector<std::pair<size_t, Node*>>& inputs,
const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
Node** output) {
- const string& op_type_name = dataset->op_name();
+ const string& name = dataset->name();
std::unique_ptr<const GraphDefBuilder::Options> opts(
new GraphDefBuilder::Options(b_->opts()));
// TODO(srbs|mrry): Not all datasets have output_types and output_shapes
// attributes defined. It will be nice to have a consistent pattern.
- bool has_output_types_attr = HasAttr(op_type_name, "output_types");
- bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes");
+ bool has_output_types_attr = HasAttr(name, "output_types");
+ bool has_output_shapes_attr = HasAttr(name, "output_shapes");
if (has_output_shapes_attr) {
opts.reset(new GraphDefBuilder::Options(
opts->WithAttr("output_shapes", dataset->output_shapes())));
@@ -102,8 +102,7 @@ Status GraphDefBuilderWrapper::AddDataset(
return errors::Internal("AddDataset: Failed to build Options with error ",
opts->StatusToString());
}
- NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name,
- opts->op_registry());
+ NodeBuilder node_builder(opts->GetNameForOp(name), name, opts->op_registry());
{
size_t total_size = inputs.size() + list_inputs.size();
auto inputs_iter = inputs.begin();
@@ -128,28 +127,31 @@ Status GraphDefBuilderWrapper::AddDataset(
}
*output = opts->FinalizeBuilder(&node_builder);
if (*output == nullptr) {
- return errors::Internal("AddDataset: Failed to build ", op_type_name,
+ return errors::Internal("AddDataset: Failed to build ", name,
" op with error ", opts->StatusToString());
}
return Status::OK();
}
-Status GraphDefBuilderWrapper::AddFunction(
- const FunctionLibraryDefinition& flib_def, const string& function_name) {
+Status GraphDefBuilderWrapper::AddFunction(SerializationContext* ctx,
+ const string& function_name) {
if (b_->HasFunction(function_name)) {
VLOG(1) << "Function with name " << function_name << "already exists in"
<< " the graph. It will not be added again.";
return Status::OK();
}
- TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(flib_def, function_name));
- const FunctionDef* f_def = flib_def.Find(function_name);
+ if (!ctx->allow_stateful_functions()) {
+ TF_RETURN_IF_ERROR(
+ EnsureFunctionIsStateless(ctx->flib_def(), function_name));
+ }
+ const FunctionDef* f_def = ctx->flib_def().Find(function_name);
if (f_def == nullptr) {
return errors::InvalidArgument("Unable to find FunctionDef for ",
function_name, " in the registry.");
}
FunctionDefLibrary def;
*def.add_function() = *f_def;
- const string gradient_func = flib_def.FindGradient(function_name);
+ const string gradient_func = ctx->flib_def().FindGradient(function_name);
if (!gradient_func.empty()) {
GradientDef* g_def = def.add_gradient();
g_def->set_function_name(function_name);
@@ -160,23 +162,30 @@ Status GraphDefBuilderWrapper::AddFunction(
// Recursively add functions in inputs of function_name.
for (const NodeDef& node_def : f_def->node_def()) {
const OpRegistrationData* op_reg_data = nullptr;
- TF_RETURN_IF_ERROR(flib_def.LookUp(node_def.op(), &op_reg_data));
+ TF_RETURN_IF_ERROR(ctx->flib_def().LookUp(node_def.op(), &op_reg_data));
if (op_reg_data->is_function_op) {
- TF_RETURN_IF_ERROR(AddFunction(flib_def, op_reg_data->op_def.name()));
+ TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
}
// Recursively add functions in attrs of this NodeDef.
for (const auto& pair : node_def.attr()) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, flib_def));
+ TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, pair.second));
}
}
// Recursively add functions in attrs of function_name.
for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, flib_def));
+ TF_RETURN_IF_ERROR(AddAttrFunctions(ctx, iter->second));
}
return Status::OK();
}
+void GraphDefBuilderWrapper::AddPlaceholderInternal(const Tensor& val,
+ Node** output) {
+ *output = ops::SourceOp(
+ "Placeholder",
+ b_->opts().WithAttr("dtype", val.dtype()).WithAttr("shape", val.shape()));
+}
+
void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
Node** output) {
*output = ops::SourceOp(
@@ -184,27 +193,32 @@ void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
}
-bool GraphDefBuilderWrapper::HasAttr(const string& op_type_name,
+bool GraphDefBuilderWrapper::HasAttr(const string& name,
const string& attr_name) const {
const OpDef* op_def = nullptr;
- Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def);
+ Status s = b_->opts().op_registry()->LookUpOpDef(name, &op_def);
if (!s.ok() || op_def == nullptr) {
return false;
}
return HasAttr(op_def, attr_name);
}
-Status GraphDatasetBase::Serialize(SerializationContext* ctx,
- string* serialized_graph_def,
- string* output_node) const {
+Status DatasetBase::Save(SerializationContext* ctx,
+ IteratorStateWriter* writer) const {
+ string serialized_graph_def;
+ string output_node;
GraphDefBuilder b;
DatasetGraphDefBuilder db(&b);
Node* node = nullptr;
TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
- *output_node = node->name();
+ output_node = node->name();
GraphDef graph_def;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
- graph_def.SerializeToString(serialized_graph_def);
+ graph_def.SerializeToString(&serialized_graph_def);
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
return Status::OK();
}
@@ -264,8 +278,8 @@ void BinaryDatasetOpKernel::MakeDataset(OpKernelContext* ctx,
MakeDataset(ctx, input, another_input, output);
}
-const char GraphDatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
-const char GraphDatasetBase::kDatasetGraphOutputNodeKey[] =
+const char DatasetBase::kDatasetGraphKey[] = "_DATASET_GRAPH";
+const char DatasetBase::kDatasetGraphOutputNodeKey[] =
"_DATASET_GRAPH_OUTPUT_NODE";
BackgroundWorker::BackgroundWorker(Env* env, const string& name) {
@@ -315,22 +329,5 @@ void BackgroundWorker::WorkerLoop() {
}
}
-namespace dataset {
-
-IteratorContext MakeIteratorContext(OpKernelContext* ctx) {
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = *(ctx->runner());
- params.lib = ctx->function_library();
- // Note: must use reinterpret_cast because function.h forward-declares Device.
- DeviceBase* device =
- reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
- params.allocator_getter = [device](AllocatorAttributes attrs) {
- return device->GetAllocator(attrs);
- };
- return IteratorContext(params);
-}
-
-} // namespace dataset
-
+} // namespace data
} // namespace tensorflow
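
With dataset::MakeIteratorContext() removed above in favor of the IteratorContext(OpKernelContext*) constructor declared in dataset.h below, call sites simplify as in this hypothetical helper (GetOneElement is illustrative only).

#include <vector>
#include "tensorflow/core/framework/dataset.h"

namespace tensorflow {
namespace data {

// Hypothetical call site: the iterator context is built straight from the
// kernel context, and the rvalue GetNext() overload accepts the temporary.
Status GetOneElement(OpKernelContext* ctx, IteratorBase* iterator,
                     std::vector<Tensor>* out_tensors, bool* end_of_sequence) {
  return iterator->GetNext(IteratorContext(ctx), out_tensors, end_of_sequence);
}

}  // namespace data
}  // namespace tensorflow
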
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 66e836f9a6..964a7d5f8c 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
@@ -40,6 +41,18 @@ limitations under the License.
namespace tensorflow {
+// Forward declarations to avoid introducing a dependency on headers in
+// "tensorflow/core/graph/...".
+class GraphDefBuilder;
+class Node;
+
+namespace data {
+// A constant that can be used to enable auto-tuning.
+constexpr int kAutoTune = -1;
+
+class DatasetBase;
+class SerializationContext;
+
// Interface for reading values from a key-value store.
// Used for restoring iterator state.
class IteratorStateReader {
@@ -63,12 +76,6 @@ class IteratorStateWriter {
virtual ~IteratorStateWriter() {}
};
-// Forward declarations to avoid introducing a dependency on headers in
-// "tensorflow/core/graph/...".
-class GraphDefBuilder;
-class GraphDatasetBase;
-class Node;
-
// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
public:
@@ -108,10 +115,11 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
- // Adds a Const node with Tensor value to the Graph.
+ // Adds a `Const` node for the given tensor value to the graph.
+ //
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
- // non-null if the method returns with an OK status.
- // The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
+ // non-null if the method returns with an OK status. The returned `Node`
+ // pointer is owned by the backing graph of `GraphDefBuilder`.
Status AddTensor(const Tensor& val, Node** output) {
AddTensorInternal(val, output);
if (*output == nullptr) {
@@ -120,7 +128,21 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
- Status AddDataset(const GraphDatasetBase* dataset,
+ // Adds a `Placeholder` node for the given tensor value to the graph.
+ //
+ // `*output` contains a pointer to the output `Node`. It is guaranteed to be
+ // non-null if the method returns with an OK status. The returned `Node`
+ // pointer is owned by the backing graph of `GraphDefBuilder`.
+ Status AddPlaceholder(const Tensor& val, Node** output) {
+ AddPlaceholderInternal(val, output);
+ if (*output == nullptr) {
+ return errors::Internal(
+ "AddPlaceholder: Failed to build Placeholder op.");
+ }
+ return Status::OK();
+ }
+
+ Status AddDataset(const DatasetBase* dataset,
const std::vector<Node*>& inputs, Node** output) {
return AddDataset(dataset, inputs, {}, output);
}
@@ -133,7 +155,7 @@ class GraphDefBuilderWrapper {
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
// non-null if the method returns with an OK status.
// The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
- Status AddDataset(const GraphDatasetBase* dataset,
+ Status AddDataset(const DatasetBase* dataset,
const std::vector<Node*>& inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
Node** output) {
@@ -145,7 +167,7 @@ class GraphDefBuilderWrapper {
}
Status AddDataset(
- const GraphDatasetBase* dataset,
+ const DatasetBase* dataset,
const std::vector<std::pair<size_t, Node*>>& inputs,
const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
@@ -154,11 +176,11 @@ class GraphDefBuilderWrapper {
// Adds a user-defined function with name `function_name` to the graph and
// recursively adds all functions it references. If a function with a matching
// name has already been added, returns with OK status. If a user-defined function with
- // name `function_name` is not found in the FunctionLibraryDefinition, returns
- // an InvalidArgumentError. If the function with name `function_name` or any
- // of its dependent functions are stateful, returns an InvalidArgument error.
- Status AddFunction(const FunctionLibraryDefinition& flib_def,
- const string& function_name);
+ // name `function_name` is not found in the context's function library,
+ // returns an InvalidArgumentError. If the function with name `function_name`
+ // or any of its dependent functions are stateful, and the context does not
+ // explicitly permit stateful functions, returns an InvalidArgument error.
+ Status AddFunction(SerializationContext* ctx, const string& function_name);
template <typename T>
void BuildAttrValue(const T& value, AttrValue* attr) {
@@ -166,6 +188,7 @@ class GraphDefBuilderWrapper {
}
private:
+ void AddPlaceholderInternal(const Tensor& val, Node** output);
void AddTensorInternal(const Tensor& val, Node** output);
Status EnsureFunctionIsStateless(const FunctionLibraryDefinition& flib_def,
@@ -204,8 +227,7 @@ class GraphDefBuilderWrapper {
return (str_util::EndsWith(op_def->name(), "Dataset") &&
op_def->output_arg_size() == 1 &&
op_def->output_arg(0).type() == DT_VARIANT) ||
- dataset::WhitelistedStatefulOpRegistry::Global()->Contains(
- op_def->name());
+ WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name());
}
bool HasAttr(const string& op_type_name, const string& attr_name) const;
@@ -219,13 +241,13 @@ class GraphDefBuilderWrapper {
return false;
}
- Status AddAttrFunctions(const AttrValue& attr_value,
- const FunctionLibraryDefinition& flib_def) {
+ Status AddAttrFunctions(SerializationContext* ctx,
+ const AttrValue& attr_value) {
if (attr_value.has_func()) {
- TF_RETURN_IF_ERROR(AddFunction(flib_def, attr_value.func().name()));
+ TF_RETURN_IF_ERROR(AddFunction(ctx, attr_value.func().name()));
} else if (attr_value.has_list()) {
for (const NameAttrList& name_attr_list : attr_value.list().func()) {
- TF_RETURN_IF_ERROR(AddFunction(flib_def, name_attr_list.name()));
+ TF_RETURN_IF_ERROR(AddFunction(ctx, name_attr_list.name()));
}
}
return Status::OK();
@@ -256,15 +278,8 @@ class IteratorContext {
// Function call support.
std::function<void(std::function<void()>)> runner = nullptr;
- // A function that returns the current `StatsAggregator` instance to be
- // used when recording statistics about the iterator.
- //
- // NOTE(mrry): This is somewhat awkward, because (i) the `StatsAggregator`
- // is a property of the `IteratorResource` (which this class does not know
- // about), and (ii) it can change after the `IteratorContext` has been
- // created. Better suggestions are welcome!
- std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter =
- nullptr;
+ // The `StatsAggregator` object to record statistics about the iterator.
+ std::shared_ptr<StatsAggregator> stats_aggregator = nullptr;
// The FunctionLibraryRuntime object to be used to make function calls.
FunctionLibraryRuntime* lib = nullptr;
@@ -272,23 +287,32 @@ class IteratorContext {
// The Allocator to be used to allocate the output of an iterator.
std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;
+
+ // If non-null, identifies the object used for performance modeling.
+ std::shared_ptr<model::Model> model = nullptr;
};
explicit IteratorContext(Params params) : params_(std::move(params)) {}
+ explicit IteratorContext(OpKernelContext* ctx) {
+ params_.env = ctx->env();
+ params_.runner = *(ctx->runner());
+ params_.lib = ctx->function_library();
+ // NOTE: must use reinterpret_cast because function.h forward-declares
+ // Device.
+ DeviceBase* device =
+ reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
+ params_.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ }
+
Env* env() const { return params_.env; }
std::function<void(std::function<void()>)>* runner() {
return &params_.runner;
}
- std::shared_ptr<StatsAggregator> stats_aggregator() {
- if (params_.stats_aggregator_getter) {
- return params_.stats_aggregator_getter();
- } else {
- return nullptr;
- }
- }
std::shared_ptr<const FunctionLibraryDefinition> function_library() {
return params_.function_library;
@@ -306,10 +330,14 @@ class IteratorContext {
return params_.allocator_getter;
}
- std::function<std::shared_ptr<StatsAggregator>()> stats_aggregator_getter() {
- return params_.stats_aggregator_getter;
+ std::shared_ptr<StatsAggregator> stats_aggregator() {
+ return params_.stats_aggregator;
}
+ std::shared_ptr<model::Model> model() { return params_.model; }
+
+ Params params() { return params_; }
+
private:
Params params_;
};
@@ -318,13 +346,21 @@ class IteratorContext {
class SerializationContext {
public:
struct Params {
- const FunctionLibraryDefinition* flib_def; // Not owned.
+ bool allow_stateful_functions = false;
+ const FunctionLibraryDefinition* flib_def = nullptr; // Not owned.
+ std::vector<std::pair<string, Tensor>>* input_list = nullptr; // Not owned.
};
explicit SerializationContext(Params params) : params_(std::move(params)) {}
+ bool allow_stateful_functions() { return params_.allow_stateful_functions; }
+
const FunctionLibraryDefinition& flib_def() { return *params_.flib_def; }
+ std::vector<std::pair<string, Tensor>>* input_list() {
+ return params_.input_list;
+ }
+
private:
Params params_;
@@ -336,7 +372,11 @@ class SerializationContext {
// defined below.
class IteratorBase {
public:
- virtual ~IteratorBase() {}
+ virtual ~IteratorBase() {
+ for (auto rit = cleanup_fns_.rbegin(); rit != cleanup_fns_.rend(); ++rit) {
+ (*rit)();
+ }
+ }
// Gets the next output from the range that this iterator is traversing.
//
@@ -355,6 +395,11 @@ class IteratorBase {
virtual Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) = 0;
+ Status GetNext(IteratorContext&& ctx, std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) {
+ return GetNext(&ctx, out_tensors, end_of_sequence);
+ }
+
// Returns a vector of DataType values, representing the respective
// element types of each tuple component in the outputs of this
// iterator.
@@ -365,6 +410,10 @@ class IteratorBase {
// in the outputs of this iterator.
virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+ // Returns a string that identifies the sequence of iterators leading up to
+ // this iterator.
+ virtual const string& prefix() const = 0;
+
// Performs initialization that needs to happen outside of a constructor to
// properly propagate errors.
virtual Status Initialize(IteratorContext* ctx) { return Status::OK(); }
@@ -404,12 +453,54 @@ class IteratorBase {
IteratorStateReader* reader) {
return errors::Unimplemented("RestoreInternal");
}
+
+ private:
+ friend class DatasetBase; // for access to `AddCleanupFunction`
+
+ // Registers a cleanup function to be called upon object destruction.
+ //
+ // Registered functions are invoked in the reverse order of registration.
+ void AddCleanupFunction(std::function<void()>&& cleanup_fn) {
+ cleanup_fns_.push_back(std::move(cleanup_fn));
+ }
+
+ std::vector<std::function<void()>> cleanup_fns_;
+};
+
+// Represents runtime information needed to construct a dataset.
+class DatasetContext {
+ public:
+ struct Params {
+ string name;
+ };
+
+ explicit DatasetContext(Params params) : params_(std::move(params)) {}
+
+ explicit DatasetContext(OpKernelContext* ctx) {
+ params_.name = ctx->op_kernel().type_string();
+ }
+
+ const string& name() const { return params_.name; }
+
+ private:
+ Params params_;
};
// Represents a (potentially infinite) range of outputs, where each
// output is a tuple of tensors.
class DatasetBase : public core::RefCounted {
public:
+ // Key for storing the Dataset graph in the serialized format.
+ TF_EXPORT static const char kDatasetGraphKey[];
+
+ // Key for storing the output node of the Dataset graph in the serialized
+ // format.
+ TF_EXPORT static const char kDatasetGraphOutputNodeKey[];
+
+ explicit DatasetBase(DatasetContext&& ctx) : name_(ctx.name()) {}
+
+ const string& name() const { return name_; }
+
// Returns a new iterator for iterating over the range of elements in
// this dataset.
//
@@ -423,9 +514,21 @@ class DatasetBase : public core::RefCounted {
Status MakeIterator(IteratorContext* ctx, const string& prefix,
std::unique_ptr<IteratorBase>* iterator) const {
*iterator = MakeIteratorInternal(prefix);
+ if (ctx->model()) {
+ ctx->model()->AddNode((*iterator)->prefix(), prefix);
+ std::shared_ptr<model::Model> model = ctx->model();
+ const string& prefix = (*iterator)->prefix();
+ (*iterator)->AddCleanupFunction(
+ [model, prefix]() { model->RemoveNode(prefix); });
+ }
return (*iterator)->Initialize(ctx);
}
+ Status MakeIterator(IteratorContext&& ctx, const string& prefix,
+ std::unique_ptr<IteratorBase>* iterator) const {
+ return MakeIterator(&ctx, prefix, iterator);
+ }
+
// Returns a vector of DataType values, representing the respective
// element types of each tuple component in the outputs of this
// dataset.
@@ -441,16 +544,11 @@ class DatasetBase : public core::RefCounted {
// Serializes the dataset and writes it to the `writer`.
virtual Status Save(SerializationContext* ctx,
- IteratorStateWriter* writer) const {
- return errors::Unimplemented("%s does not support serialization",
- DebugString());
- }
+ IteratorStateWriter* writer) const;
protected:
- // TODO(srbs): Ideally all graph related logic should reside in
- // GraphDatasetBase. However, that would require Datasets defined in all ops
- // to derive from GraphDatasetBase. Once that is done we can move
- // DatasetGraphDefBuilder and AsGraphDefInternal to GraphDatasetBase.
+ friend class DatasetToGraphOp; // For access to graph related members.
+
class DatasetGraphDefBuilder : public GraphDefBuilderWrapper {
public:
DatasetGraphDefBuilder(GraphDefBuilder* b) : GraphDefBuilderWrapper(b) {}
@@ -463,54 +561,13 @@ class DatasetBase : public core::RefCounted {
// TODO(jsimsa): Consolidate overloading into a single method.
virtual Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
- Node** node) const {
- return AsGraphDefInternal(b, node);
- }
-
- virtual Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
- Node** node) const {
- return errors::Unimplemented("%s does not support serialization",
- DebugString());
- }
+ Node** node) const = 0;
virtual std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const = 0;
- friend class DatasetToGraphOp; // For access to graph related members.
-};
-
-// Base-class for datasets that are built by ops.
-class GraphDatasetBase : public DatasetBase {
- public:
- GraphDatasetBase(OpKernelContext* ctx)
- : op_name_(ctx->op_kernel().type_string()) {}
-
- const string op_name() const { return op_name_; }
-
- Status Save(SerializationContext* ctx,
- IteratorStateWriter* writer) const override {
- string serialized_graph_def;
- string output_node;
- TF_RETURN_IF_ERROR(Serialize(ctx, &serialized_graph_def, &output_node));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(kDatasetGraphKey, serialized_graph_def));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(kDatasetGraphOutputNodeKey, output_node));
- return Status::OK();
- }
-
- // Key for storing the Dataset graph in the serialized format.
- TF_EXPORT static const char kDatasetGraphKey[];
-
- // Key for storing the output node of the Dataset graph in the serialized
- // format.
- TF_EXPORT static const char kDatasetGraphOutputNodeKey[];
-
private:
- Status Serialize(SerializationContext* ctx, string* serialized_graph_def,
- string* output_node) const;
-
- const string op_name_;
+ const string name_;
};
// Represents an iterator that is associated with a particular dataset.
@@ -531,7 +588,7 @@ class DatasetBaseIterator : public IteratorBase {
~DatasetBaseIterator() override { params_.dataset->Unref(); }
// The sequence of iterators leading up to this iterator.
- const string& prefix() const { return params_.prefix; }
+ const string& prefix() const override { return params_.prefix; }
const DataTypeVector& output_dtypes() const override {
return params_.dataset->output_dtypes();
@@ -544,7 +601,10 @@ class DatasetBaseIterator : public IteratorBase {
Status GetNext(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) final {
tracing::ScopedActivity activity(params_.prefix);
+ RecordStart(ctx, true /* stop_output */);
Status s = GetNextInternal(ctx, out_tensors, end_of_sequence);
+ if (s.ok() && !*end_of_sequence) RecordElement(ctx);
+ RecordStop(ctx, true /* start_output */);
if (TF_PREDICT_FALSE(errors::IsOutOfRange(s) && !*end_of_sequence)) {
s = errors::Internal(
"Iterator \"", params_.prefix,
@@ -571,6 +631,54 @@ class DatasetBaseIterator : public IteratorBase {
return strings::StrCat(params_.prefix, ":", name);
}
+ // When performance modeling is enabled, this method adds a constant parameter
+ // to the model node corresponding to this iterator.
+ void AddConstantParameter(IteratorContext* ctx, const string& name,
+ int64 value) {
+ if (ctx->model()) {
+ ctx->model()->AddConstantParameter(prefix(), name, value);
+ }
+ }
+
+ // When performance modeling is enabled, this method adds a tunable parameter
+ // to the model node corresponding to this iterator.
+ //
+ // The performance modeling logic may use `state` to set the value of the
+ // tunable parameter at any point during the lifetime of this iterator. When
+ // it does, it acquires `state->mu` and notifies `state->cond_var`.
+ void AddTunableParameter(IteratorContext* ctx, const string& name,
+ std::shared_ptr<model::SharedState> state, int64 min,
+ int64 max) {
+ if (ctx->model()) {
+ ctx->model()->AddTunableParameter(prefix(), name, std::move(state), min,
+ max);
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // this iterator has produced an element.
+ void RecordElement(IteratorContext* ctx) {
+ if (ctx->model()) {
+ ctx->model()->RecordElement(prefix());
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // a thread of this iterator has started work.
+ void RecordStart(IteratorContext* ctx, bool stop_output = false) {
+ if (ctx->model()) {
+ ctx->model()->RecordStart(prefix(), stop_output);
+ }
+ }
+
+ // When performance modeling is enabled, this method records the fact that
+ // a thread of this iterator has stopped work.
+ void RecordStop(IteratorContext* ctx, bool start_output = false) {
+ if (ctx->model()) {
+ ctx->model()->RecordStop(prefix(), start_output);
+ }
+ }
+
private:
BaseParams params_;
};
@@ -718,11 +826,20 @@ class BackgroundWorker {
std::deque<std::function<void()>> work_queue_ GUARDED_BY(mu_);
};
-namespace dataset {
-
-IteratorContext MakeIteratorContext(OpKernelContext* ctx);
-
-} // namespace dataset
+} // namespace data
+
+// TODO(b/114112161): Remove these aliases when all users have moved over to the
+// `tensorflow::data` namespace.
+using data::DatasetBase;
+using data::DatasetContext;
+using data::DatasetIterator;
+using data::DatasetOpKernel;
+using data::IteratorBase;
+using data::IteratorContext;
+using data::IteratorStateReader;
+using data::IteratorStateWriter;
+using data::SerializationContext;
+using data::UnaryDatasetOpKernel;
} // namespace tensorflow
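
The RecordStart()/RecordStop() helpers added to DatasetBaseIterator above start and stop the clock that feeds the performance model, and the base-class GetNext() already brackets every GetNextInternal() call with them. A minimal sketch of the intended pattern for an iterator that blocks on a condition variable follows; the iterator and its members (mu_, cond_var_, buffer_, end_of_input_) are hypothetical, only the two hooks come from this diff.

    Status GetNextInternal(IteratorContext* ctx,
                           std::vector<Tensor>* out_tensors,
                           bool* end_of_sequence) override {
      mutex_lock l(mu_);
      while (buffer_.empty() && !end_of_input_) {
        // Stop the clock before blocking so waiting time is not attributed
        // to this iterator...
        RecordStop(ctx);
        cond_var_.wait(l);
        // ...and restart it once the thread is runnable again.
        RecordStart(ctx);
      }
      if (buffer_.empty()) {
        *end_of_sequence = true;
        return Status::OK();
      }
      *out_tensors = std::move(buffer_.front());
      buffer_.pop_front();
      *end_of_sequence = false;
      return Status::OK();
    }
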
diff --git a/tensorflow/core/framework/dataset_stateful_op_whitelist.h b/tensorflow/core/framework/dataset_stateful_op_whitelist.h
index 3b48999edb..74bd39cb61 100644
--- a/tensorflow/core/framework/dataset_stateful_op_whitelist.h
+++ b/tensorflow/core/framework/dataset_stateful_op_whitelist.h
@@ -16,38 +16,38 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
+#include <unordered_set>
#include "tensorflow/core/lib/core/status.h"
namespace tensorflow {
-namespace dataset {
+namespace data {
// Registry for stateful ops that need to be used in dataset functions.
// See below macro for usage details.
class WhitelistedStatefulOpRegistry {
public:
- Status Add(StringPiece op_name) {
- op_names_.insert(op_name);
+ Status Add(string op_name) {
+ op_names_.insert(std::move(op_name));
return Status::OK();
}
- bool Contains(StringPiece op_name) {
- return op_names_.find(op_name) != op_names_.end();
- }
+ bool Contains(const string& op_name) { return op_names_.count(op_name); }
static WhitelistedStatefulOpRegistry* Global() {
- static WhitelistedStatefulOpRegistry* reg =
- new WhitelistedStatefulOpRegistry;
+ static auto* reg = new WhitelistedStatefulOpRegistry;
return reg;
}
private:
- WhitelistedStatefulOpRegistry() {}
- WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy);
+ WhitelistedStatefulOpRegistry() = default;
+ WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy) =
+ delete;
WhitelistedStatefulOpRegistry operator=(
- WhitelistedStatefulOpRegistry const& copy);
- std::set<StringPiece> op_names_;
+ WhitelistedStatefulOpRegistry const& copy) = delete;
+
+ std::unordered_set<string> op_names_;
};
-} // namespace dataset
+} // namespace data
// Use this macro to whitelist an op that is marked stateful but needs to be
// used inside a map_fn in an input pipeline. This is only needed if you wish
@@ -67,10 +67,9 @@ class WhitelistedStatefulOpRegistry {
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(__COUNTER__, name)
#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name)
-#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \
- static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \
- ::tensorflow::dataset::WhitelistedStatefulOpRegistry::Global()->Add( \
- name)
+#define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \
+ static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \
+ ::tensorflow::data::WhitelistedStatefulOpRegistry::Global()->Add(name)
} // namespace tensorflow
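
After the namespace move above, the whitelisting macro keeps its one-line usage; only the fully qualified registry behind it changed. A usage sketch, with an op name invented for illustration:

    // In the .cc file that defines or registers the stateful op:
    WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("MyStatefulLookupOp");
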
diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h
index b184fd91e1..446c31b17f 100644
--- a/tensorflow/core/framework/device_base.h
+++ b/tensorflow/core/framework/device_base.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"
@@ -89,6 +90,15 @@ class DeviceContext : public core::RefCounted {
Tensor* cpu_tensor, StatusCallback done) {
done(errors::Internal("Unrecognized device type in device-to-CPU Copy"));
}
+
+ // If possible, waits for all events on *stream to complete and then executes
+ // func; otherwise returns a non-OK Status. The stream argument should be the
+ // one provided by GpuDeviceInfo. This function is not applicable to devices
+ // that don't provide such a value.
+ virtual Status ThenExecute(Device* device, stream_executor::Stream* stream,
+ std::function<void()> func) {
+ return errors::Internal("ThenExecute not supported by device");
+ }
};
// map[i] is the DeviceContext* for the node with id i, if i < map.size().
@@ -167,9 +177,9 @@ class DeviceBase {
return nullptr;
}
- // DEPRECATED: Use `this->GetAllocator()` or `this->GetScopedAllocator()`.
// This method is provided for backwards compatibility, and will be removed
// in a future release.
+ ABSL_DEPRECATED("Use `this->GetAllocator()` or `this->GetScopedAllocator()`.")
Allocator* GetStepAllocator(AllocatorAttributes attr, ResourceMgr*) {
return GetAllocator(attr);
}
@@ -205,10 +215,12 @@ class DeviceBase {
// This is overridden by GPU devices to reinitialize the derived
// type returned by MakeGpuDevice.
- virtual void ReinitializeGpuDevice(OpKernelContext* /*context*/,
- PerOpGpuDevice* /*device*/,
- DeviceContext* /*dc*/,
- Allocator* /*allocator*/) {}
+ virtual Status ReinitializeGpuDevice(OpKernelContext* /*context*/,
+ PerOpGpuDevice* /*device*/,
+ DeviceContext* /*dc*/,
+ Allocator* /*allocator*/) {
+ return Status::OK();
+ }
// Unimplemented by default
virtual const DeviceAttributes& attributes() const;
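
Because ReinitializeGpuDevice() now returns a Status, subclasses can propagate reinitialization failures to the caller instead of dropping them. A hedged sketch of an override, with the class name and failure condition invented for illustration:

    class MyGpuDevice : public DeviceBase {
     public:
      using DeviceBase::DeviceBase;

      Status ReinitializeGpuDevice(OpKernelContext* context,
                                   PerOpGpuDevice* device, DeviceContext* dc,
                                   Allocator* allocator) override {
        if (device == nullptr) {
          return errors::Internal("PerOpGpuDevice was not created");
        }
        // ... per-op reinitialization of `device` would go here ...
        return Status::OK();
      }
    };
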
diff --git a/tensorflow/core/framework/fake_input.h b/tensorflow/core/framework/fake_input.h
index 103db47a99..c3062762ff 100644
--- a/tensorflow/core/framework/fake_input.h
+++ b/tensorflow/core/framework/fake_input.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
-#define TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_FAKE_INPUT_H_
+#define TENSORFLOW_CORE_FRAMEWORK_FAKE_INPUT_H_
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"
@@ -37,4 +37,4 @@ inline FakeInputFunctor FakeInput(std::initializer_list<DataType> dts) {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_FAKE_INPUT_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_FAKE_INPUT_H_
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index 6b92e10d76..20f957190b 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -504,7 +504,7 @@ string Print(const NodeDef& n) {
std::vector<string> dep;
for (StringPiece s : n.input()) {
if (str_util::ConsumePrefix(&s, "^")) {
- dep.push_back(std::string(s));
+ dep.emplace_back(s);
} else {
dat.push_back(s);
}
@@ -1101,6 +1101,14 @@ Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
return Status::OK();
}
+Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
+ mutex_lock l(mu_);
+ bool added;
+ TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
+ TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
+ return Status::OK();
+}
+
Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
const auto& i = function_defs_.find(func);
if (i == function_defs_.end()) {
@@ -1154,6 +1162,17 @@ Status FunctionLibraryDefinition::LookUp(
return default_registry_->LookUp(op, op_reg_data);
}
+string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const {
+ tf_shared_lock l(mu_);
+ int index = 0;
+ string name = strings::StrCat(prefix, index);
+ while (function_defs_.find(name) != function_defs_.end()) {
+ ++index;
+ name = strings::StrCat(prefix, index);
+ }
+ return name;
+}
+
const FunctionDef* FunctionLibraryDefinition::GetAttrImpl(
const NodeDef& ndef) const {
if (ndef.op() != kGradientOp) {
@@ -1283,6 +1302,18 @@ FunctionDef FunctionDefHelper::Create(
for (const auto& r : ret_def) {
fdef.mutable_ret()->insert({r.first, r.second});
}
+
+ auto* op_def_registry = OpRegistry::Global();
+ // Check if any op is stateful.
+ for (const auto& n : node_def) {
+ const OpDef* op_def = nullptr;
+ auto status = op_def_registry->LookUpOpDef(n.op, &op_def);
+ // Lookup can fail if e.g. we are calling a function that was not yet
+ // defined. If it happens, conservatively assume the op is stateful.
+ if (!status.ok() || op_def->is_stateful()) {
+ fdef.mutable_signature()->set_is_stateful(true);
+ }
+ }
return fdef;
}
@@ -1344,6 +1375,7 @@ FunctionDef FunctionDefHelper::Define(const string& name,
strings::StrCat(src.ret[0], ":", o.first, ":", i - o.second.first);
}
}
+ if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true);
}
// Returns
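
The new UniqueFunctionName() and ReplaceGradient() helpers are aimed at graph-rewriting passes that clone and specialize functions. A sketch of how such a pass might use them, assuming a Status-returning helper with a FunctionLibraryDefinition `lib_def` and a rewritten gradient FunctionDef `rewritten_grad_fdef` in scope (both names are illustrative):

    // Give the rewritten gradient body a name that is unique in the library.
    const string grad_name = lib_def.UniqueFunctionName("MyFuncGrad_specialized_");
    rewritten_grad_fdef.mutable_signature()->set_name(grad_name);
    TF_RETURN_IF_ERROR(lib_def.AddFunctionDef(rewritten_grad_fdef));

    // Re-point MyFunc's gradient at the rewritten body.
    GradientDef grad;
    grad.set_function_name("MyFunc");
    grad.set_gradient_func(grad_name);
    TF_RETURN_IF_ERROR(lib_def.ReplaceGradient(grad));
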
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index c81f4a4450..4d6d68e214 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_FUNCTION_H_
-#define TENSORFLOW_FRAMEWORK_FUNCTION_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
+#define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
#include <vector>
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -41,7 +41,7 @@ class ProcessFunctionLibraryRuntime;
class ResourceMgr;
class Rendezvous;
class ScopedStepContainer;
-class StepStatsCollector;
+class StepStatsCollectorInterface;
class Node;
// FunctionDefHelper::Create is a convenient helper to construct a
@@ -331,6 +331,11 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
// a non-OK status if "func" was not found in the library, OK otherwise.
Status ReplaceFunction(const string& func, const FunctionDef& fdef);
+ // Replaces the gradient corresponding to `grad.function_name()`. Returns
+ // a non-OK status if "grad.function_name()" was not found in the library, OK
+ // otherwise.
+ Status ReplaceGradient(const GradientDef& grad);
+
// Adds the functions and gradients in 'other' to this function library.
// Duplicate functions and gradients are ignored.
// This operation is atomic.
@@ -358,6 +363,10 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
const OpRegistrationData** op_reg_data) const override
LOCKS_EXCLUDED(mu_);
+ // Generates a new function name with the specified prefix that is unique
+ // across this library.
+ string UniqueFunctionName(StringPiece prefix) const LOCKS_EXCLUDED(mu_);
+
// Ops created for function arguments bear the name given by `kArgOp`; those
// created for return values bear the name given by `kRetOp`.
static constexpr const char* const kArgOp = "_Arg";
@@ -490,6 +499,11 @@ class FunctionLibraryRuntime {
// Instantiates the function using an executor of the given type. If empty,
// the default TensorFlow executor will be used.
string executor_type;
+
+ // If true, the runtime will attempt to create kernels for the function at
+ // instantiation time, rather than on the first run. This can be used to
+ // surface errors earlier.
+ bool create_kernels_eagerly = false;
};
typedef uint64 Handle;
virtual Status Instantiate(const string& function_name, AttrSlice attrs,
@@ -527,7 +541,7 @@ class FunctionLibraryRuntime {
CancellationManager* cancellation_manager = nullptr;
CollectiveExecutor* collective_executor = nullptr;
ScopedStepContainer* step_container = nullptr;
- StepStatsCollector* stats_collector = nullptr;
+ StepStatsCollectorInterface* stats_collector = nullptr;
std::function<void(std::function<void()>)>* runner = nullptr;
@@ -705,9 +719,10 @@ Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def,
#define REGISTER_OP_GRADIENT_UNIQ_HELPER(ctr, name, fn) \
REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn)
-#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) \
- static bool unused_grad_##ctr = SHOULD_REGISTER_OP_GRADIENT && \
- ::tensorflow::gradient::RegisterOp(name, fn)
+#define REGISTER_OP_GRADIENT_UNIQ(ctr, name, fn) \
+ static bool unused_grad_##ctr TF_ATTRIBUTE_UNUSED = \
+ SHOULD_REGISTER_OP_GRADIENT && \
+ ::tensorflow::gradient::RegisterOp(name, fn)
namespace gradient {
// Register a gradient creator for the "op".
@@ -731,4 +746,4 @@ GET_ATTR(bool)
} // end namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_FUNCTION_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_H_
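
The create_kernels_eagerly option added above moves kernel construction from the first run to Instantiate(), which is useful when callers want construction errors reported up front. A sketch, assuming a FunctionLibraryRuntime* `flr` and the Instantiate overload that takes InstantiateOptions:

    FunctionLibraryRuntime::InstantiateOptions opts;
    opts.create_kernels_eagerly = true;  // surface kernel-creation errors now
    FunctionLibraryRuntime::Handle handle;
    TF_RETURN_IF_ERROR(flr->Instantiate("MyFunc", AttrSlice(), opts, &handle));
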
diff --git a/tensorflow/core/framework/function_testlib.cc b/tensorflow/core/framework/function_testlib.cc
index 41270b8e5e..0445c242e9 100644
--- a/tensorflow/core/framework/function_testlib.cc
+++ b/tensorflow/core/framework/function_testlib.cc
@@ -49,8 +49,8 @@ NodeDef NDef(StringPiece name, StringPiece op, gtl::ArraySlice<string> inputs,
gtl::ArraySlice<std::pair<string, FDH::AttrValueWrapper>> attrs,
const string& device) {
NodeDef n;
- n.set_name(name.ToString());
- n.set_op(op.ToString());
+ n.set_name(string(name));
+ n.set_op(string(op));
for (const auto& in : inputs) n.add_input(in);
n.set_device(device);
for (auto na : attrs) n.mutable_attr()->insert({na.first, na.second.proto});
@@ -91,6 +91,31 @@ FunctionDef IsZero() {
});
}
+FunctionDef RandomUniform() {
+ const Tensor kZero = test::AsScalar<int64>(0);
+
+ return FDH::Define(
+ // Name
+ "RandomUniform",
+ // Args
+ {"x: T"},
+ // Return values
+ {"random_uniform: int64"},
+ // Attr def
+ {"T:{float, double, int32, int64, string}"},
+ {{{"random_uniform/shape"},
+ "Const",
+ {},
+ {{"value", kZero}, {"dtype", DT_INT64}}},
+ {{"random_uniform"},
+ "RandomUniform",
+ {"random_uniform/shape"},
+ {{"T", DT_INT32},
+ {"Tout", DT_FLOAT},
+ {"seed", 87654321},
+ {"seed2", 42}}}});
+}
+
FunctionDef XTimesTwo() {
const Tensor kTwo = test::AsScalar<int64>(2);
return FDH::Define(
@@ -110,6 +135,22 @@ FunctionDef XTimesTwo() {
});
}
+FunctionDef XAddX() {
+ return FDH::Define(
+ // Name
+ "XAddX",
+ // Args
+ {"x: T"},
+ // Return values
+ {"y: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"y"}, "Add", {"x", "x"}, {{"T", "$T"}}},
+ });
+}
+
FunctionDef XTimesTwoInt32() {
const Tensor kTwo = test::AsScalar<int64>(2);
return FDH::Define(
@@ -219,6 +260,62 @@ FunctionDef InvalidControlFlow() {
{{"o", "add:z"}});
}
+FunctionDef LessThanOrEqualToN(int64 N) {
+ const Tensor kN = test::AsScalar<int64>(N);
+ return FDH::Define(
+ // Name
+ "LessThanOrEqualToN",
+ // Args
+ {"x: T"},
+ // Return values
+ {"z: bool"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
+ {{"y"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"z"}, "LessEqual", {"x", "y"}, {{"T", "$T"}}},
+ });
+}
+
+FunctionDef XPlusOneXTimesY() {
+ const Tensor kOne = test::AsScalar<int64>(1);
+ return FDH::Define(
+ // Name
+ "XPlusOneXTimesY",
+ // Args
+ {"x: T", "y: T"},
+ // Return values
+ {"s: T", "t: T"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {{{"one"}, "Const", {}, {{"value", kOne}, {"dtype", DT_INT64}}},
+ {{"increment"}, "Cast", {"one"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"s"}, "Add", {"x", "increment"}, {{"T", "$T"}}},
+ {{"t"}, "Mul", {"x", "y"}, {{"T", "$T"}}}});
+}
+
+FunctionDef XYXLessThanOrEqualToN(int64 N) {
+ const Tensor kN = test::AsScalar<int64>(N);
+ return FDH::Define(
+ // Name
+ "XYXLessThanOrEqualToN",
+ // Args
+ {"x: T", "y: T"},
+ // Return values
+ {"z: bool"},
+ // Attr def
+ {"T: {float, double, int32, int64}"},
+ // Nodes
+ {
+ {{"N"}, "Const", {}, {{"value", kN}, {"dtype", DT_INT64}}},
+ {{"N1"}, "Cast", {"N"}, {{"SrcT", DT_INT64}, {"DstT", "$T"}}},
+ {{"z"}, "LessEqual", {"x", "N1"}, {{"T", "$T"}}},
+ });
+}
+
void FunctionTestSchedClosure(std::function<void()> fn) {
static thread::ThreadPool* w =
new thread::ThreadPool(Env::Default(), "Test", 8);
diff --git a/tensorflow/core/framework/function_testlib.h b/tensorflow/core/framework/function_testlib.h
index af08d296b2..a01743423b 100644
--- a/tensorflow/core/framework/function_testlib.h
+++ b/tensorflow/core/framework/function_testlib.h
@@ -63,6 +63,9 @@ GraphDef GDef(gtl::ArraySlice<NodeDef> nodes,
// x:T -> x * 2.
FunctionDef XTimesTwo();
+// x:T -> x + x.
+FunctionDef XAddX();
+
// x:T -> x * 2, where x is int32.
FunctionDef XTimesTwoInt32();
@@ -81,12 +84,24 @@ FunctionDef NonZero();
// x: T -> bool.
FunctionDef IsZero();
+// x: T -> int64
+FunctionDef RandomUniform();
+
// x:T, y:T -> y:T, x:T
FunctionDef Swap();
// Contains malformed control flow which can't be run by the executor.
FunctionDef InvalidControlFlow();
+// x:T -> x <= N.
+FunctionDef LessThanOrEqualToN(int64 N);
+
+// x:T, y:T -> x+1, x*y
+FunctionDef XPlusOneXTimesY();
+
+// x:T, y:T -> x <= N
+FunctionDef XYXLessThanOrEqualToN(int64 N);
+
void FunctionTestSchedClosure(std::function<void()> fn);
} // end namespace function
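
The new test helpers declared above are plain FunctionDef factories, so a test can drop them into a library the same way as the existing ones; the snippet below assumes the usual tensorflow::test::function namespace and is for illustration only:

    FunctionDefLibrary proto;
    *proto.add_function() = test::function::XAddX();
    *proto.add_function() = test::function::LessThanOrEqualToN(3);
    FunctionLibraryDefinition lib_def(OpRegistry::Global(), proto);
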
diff --git a/tensorflow/core/framework/graph_def_util.h b/tensorflow/core/framework/graph_def_util.h
index 525e84a989..2f8d5e8f51 100644
--- a/tensorflow/core/framework/graph_def_util.h
+++ b/tensorflow/core/framework/graph_def_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
-#define TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_
#include <set>
#include "tensorflow/core/framework/op.h"
@@ -118,4 +118,4 @@ Status StrippedOpListForGraph(const GraphDef& graph_def,
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_GRAPH_DEF_UTIL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_DEF_UTIL_H_
diff --git a/tensorflow/core/framework/kernel_def_builder.h b/tensorflow/core/framework/kernel_def_builder.h
index 2966aa58de..32dd21f94e 100644
--- a/tensorflow/core/framework/kernel_def_builder.h
+++ b/tensorflow/core/framework/kernel_def_builder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
-#define TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+#define TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
@@ -84,4 +84,4 @@ KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name) {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_KERNEL_DEF_BUILDER_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_
diff --git a/tensorflow/core/framework/log_memory.h b/tensorflow/core/framework/log_memory.h
index faef7b8e98..1b926ddaa3 100644
--- a/tensorflow/core/framework/log_memory.h
+++ b/tensorflow/core/framework/log_memory.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
-#define TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_LOG_MEMORY_H_
+#define TENSORFLOW_CORE_FRAMEWORK_LOG_MEMORY_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -108,4 +108,4 @@ class LogMemory {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_LOG_MEMORY_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_LOG_MEMORY_H_
diff --git a/tensorflow/core/framework/lookup_interface.h b/tensorflow/core/framework/lookup_interface.h
index 1381dd66a5..0622dd06cb 100644
--- a/tensorflow/core/framework/lookup_interface.h
+++ b/tensorflow/core/framework/lookup_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
-#define TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
@@ -142,4 +142,4 @@ class LookupInterface : public ResourceBase {
} // namespace lookup
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_LOOKUP_INTERFACE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_LOOKUP_INTERFACE_H_
diff --git a/tensorflow/core/framework/memory_types.h b/tensorflow/core/framework/memory_types.h
index d3918513d3..f719131bcb 100644
--- a/tensorflow/core/framework/memory_types.h
+++ b/tensorflow/core/framework/memory_types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
-#define TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_MEMORY_TYPES_H_
+#define TENSORFLOW_CORE_FRAMEWORK_MEMORY_TYPES_H_
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
@@ -35,4 +35,4 @@ Status MemoryTypesForNode(const OpRegistryInterface* op_registry,
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_MEMORY_TYPES_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_MEMORY_TYPES_H_
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
new file mode 100644
index 0000000000..bfdb3a6658
--- /dev/null
+++ b/tensorflow/core/framework/model.cc
@@ -0,0 +1,420 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/model.h"
+
+#include <memory>
+
+namespace tensorflow {
+namespace data {
+namespace model {
+
+// TODO(jsimsa): Use `Node` subclassing instead of types and node statements.
+void Model::Node::CollectTunables(
+ std::vector<std::shared_ptr<Node::Tunable>>* tunables) {
+ tf_shared_lock l(mu_);
+ for (auto input : inputs_) {
+ input->CollectTunables(tunables);
+ }
+ switch (type_) {
+ case Type::MAP_AND_BATCH:
+ case Type::PARALLEL_INTERLEAVE_V2:
+ case Type::PARALLEL_MAP: {
+ if (auto* tunable_param =
+ gtl::FindOrNull(tunable_params_, "parallelism")) {
+ tunables->push_back(*tunable_param);
+ }
+ return;
+ }
+ default:
+ return;
+ }
+}
+
+int64 Model::Node::GetParameterValue(const string& name) {
+ if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) {
+ return (*tunable_param)->value;
+ }
+ return constant_params_[name];
+}
+
+int64 Model::Node::ProcessingTimeLocked() {
+ switch (type_) {
+ case Type::BATCH:
+ case Type::MAP_AND_BATCH:
+ case Type::PADDED_BATCH: {
+ int64 batch_size = GetParameterValue("batch_size");
+ return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs();
+ }
+ case Type::FILTER: {
+ std::shared_ptr<Node> input = inputs_.front();
+ double ratio = static_cast<double>(input->num_elements()) /
+ static_cast<double>(num_elements_);
+ return NanosPerElementLocked() +
+ static_cast<int64>(ratio *
+ static_cast<double>(ProcessingTimeForInputs()));
+ }
+ case Type::FLAT_MAP:
+ case Type::INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE:
+ case Type::PARALLEL_INTERLEAVE_V2: {
+ // TODO(jsimsa): model the first input
+ // TODO(jsimsa): use processing time history as a prior for future inputs
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 processing_time =
+ ProcessingTimeForInputs() - inputs_.front()->ProcessingTime();
+ return NanosPerElementLocked() +
+ static_cast<double>(processing_time) /
+ static_cast<double>(inputs_.size() - 1);
+ }
+ case Type::CACHE:
+ case Type::CONCATENATE:
+ case Type::MAP:
+ case Type::PARALLEL_MAP:
+ case Type::PREFETCH:
+ // TODO(jsimsa): use processing time history as a prior for future inputs
+ case Type::REPEAT:
+ case Type::SHUFFLE:
+ case Type::SKIP:
+ case Type::TAKE:
+ case Type::ZIP: {
+ return NanosPerElementLocked() + ProcessingTimeForInputs();
+ }
+ default:
+ return NanosPerElementLocked();
+ }
+}
+
+int64 Model::Node::OutputTimeLocked(std::vector<int64>* input_times) {
+ switch (type_) {
+ case Type::BATCH:
+ case Type::PADDED_BATCH: {
+ double batch_size = GetParameterValue("batch_size");
+ int64 old_value = (*input_times)[input_times->size() - 1];
+ (*input_times)[input_times->size() - 1] = static_cast<int64>(
+ static_cast<double>(old_value + NanosPerElementLocked()) /
+ batch_size);
+ auto cleanup = gtl::MakeCleanup([input_times, old_value]() {
+ (*input_times)[input_times->size() - 1] = old_value;
+ });
+ return NanosPerElementLocked() +
+ batch_size * OutputTimeForInputs(input_times);
+ }
+ case Type::FILTER: {
+ std::shared_ptr<Node> input = inputs_.front();
+ int64 old_value = (*input_times)[input_times->size() - 1];
+ double ratio = static_cast<double>(input->num_elements()) /
+ static_cast<double>(num_elements_);
+ (*input_times)[input_times->size() - 1] = static_cast<int64>(
+ static_cast<double>(old_value + NanosPerElementLocked()) / ratio);
+ auto cleanup = gtl::MakeCleanup([input_times, old_value]() {
+ (*input_times)[input_times->size() - 1] = old_value;
+ });
+ return NanosPerElementLocked() +
+ static_cast<int64>(
+ static_cast<double>(OutputTimeForInputs(input_times)) * ratio);
+ }
+ case Type::FLAT_MAP:
+ case Type::INTERLEAVE: {
+ // TODO(jsimsa): model the first input
+ // TODO(jsimsa): use cycle length metadata instead of `inputs_.size() - 1`
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta =
+ static_cast<int64>(static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1));
+ (*input_times)[input_times->size() - 1] += delta;
+ auto cleanup = gtl::MakeCleanup([input_times, delta]() {
+ (*input_times)[input_times->size() - 1] -= delta;
+ });
+ int64 output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ return NanosPerElementLocked() +
+ static_cast<double>(output_time) /
+ static_cast<double>(inputs_.size() - 1);
+ }
+ case Type::MAP_AND_BATCH: {
+ double batch_size = GetParameterValue("batch_size");
+ double parallelism = GetParameterValue("parallelism");
+ int64 delta =
+ static_cast<int64>(static_cast<double>(NanosPerElementLocked()) /
+ (batch_size * parallelism));
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 output_time = static_cast<int64>(
+ static_cast<double>(NanosPerElementLocked()) / parallelism +
+ batch_size * OutputTimeForInputs(input_times));
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PARALLEL_INTERLEAVE: {
+ // TODO(jsimsa): model the first input
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta = static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1);
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 inputs_output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ double parallelism = GetParameterValue("parallelism");
+ int64 output_time =
+ NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
+ static_cast<double>(inputs_.size() - 1)) /
+ parallelism);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PARALLEL_INTERLEAVE_V2: {
+ // TODO(jsimsa): model the first input
+ if (inputs_.size() <= 1) {
+ return NanosPerElementLocked();
+ }
+ int64 delta = static_cast<double>(NanosPerElementLocked()) *
+ static_cast<double>(inputs_.size() - 1);
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 inputs_output_time = OutputTimeForInputs(input_times) -
+ inputs_.front()->OutputTime(input_times);
+ double parallelism =
+ std::min(static_cast<int>(GetParameterValue("cycle_length")),
+ static_cast<int>(GetParameterValue("parallelism")));
+ int64 output_time =
+ NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
+ static_cast<double>(inputs_.size() - 1)) /
+ parallelism);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PARALLEL_MAP: {
+ double parallelism =
+ std::min(port::NumSchedulableCPUs(),
+ static_cast<int>(GetParameterValue("parallelism")));
+ int64 delta = static_cast<int64>(
+ static_cast<double>(NanosPerElementLocked()) / parallelism);
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ int64 output_time =
+ static_cast<double>(NanosPerElementLocked()) / parallelism +
+ OutputTimeForInputs(input_times);
+ return std::max(0LL,
+ output_time - input_times->at(input_times->size() - 2));
+ }
+ case Type::PREFETCH: {
+ int64 delta = NanosPerElementLocked();
+ input_times->push_back(delta);
+ auto cleanup =
+ gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
+ return std::max(0LL, NanosPerElementLocked() +
+ OutputTimeForInputs(input_times) -
+ input_times->at(input_times->size() - 2));
+ }
+ case Type::CACHE:
+ case Type::CONCATENATE:
+ case Type::MAP:
+ case Type::REPEAT:
+ case Type::SHUFFLE:
+ case Type::SKIP:
+ case Type::TAKE:
+ case Type::ZIP: {
+ int64 delta = NanosPerElementLocked();
+ (*input_times)[input_times->size() - 1] += delta;
+ auto cleanup = gtl::MakeCleanup([input_times, delta]() {
+ (*input_times)[input_times->size() - 1] -= delta;
+ });
+ return NanosPerElementLocked() + OutputTimeForInputs(input_times);
+ }
+ default:
+ return NanosPerElementLocked();
+ }
+}
+
+void Model::AddConstantParameter(const string& node_name,
+ const string& parameter_name, int64 value) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, node_name);
+ if (node) {
+ (*node)->add_constant_param(parameter_name, value);
+ }
+}
+
+void Model::AddNode(const string& name, const string& output_name) {
+ // The name captures the sequence of iterators joined by `::`. We use the full
+ // sequence as the key in the lookup table, but only the last element of the
+ // sequence as the node name.
+ std::vector<string> tokens =
+ str_util::Split(name, ':', str_util::SkipEmpty());
+ // The output name might contain an index. We need to strip it to make it
+ // possible for the model to successfully identify the output node.
+ string sanitized_output_name = output_name;
+ if (str_util::EndsWith(output_name, "]")) {
+ sanitized_output_name = output_name.substr(0, output_name.rfind('['));
+ }
+ std::shared_ptr<Node> output;
+ mutex_lock l(mu_);
+ auto it = lookup_table_.find(sanitized_output_name);
+ if (it != lookup_table_.end()) {
+ output = it->second;
+ }
+ std::shared_ptr<Node> node(new Node(id_counter_++, tokens.back(), output));
+ if (!output_) {
+ output_ = node;
+ }
+ if (output) {
+ output->add_input(node);
+ }
+ lookup_table_.insert(std::make_pair(name, node));
+}
+
+void Model::AddProcessingTime(const string& name, int64 delta) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->add_processing_time(delta);
+ }
+}
+
+void Model::AddTunableParameter(const string& node_name,
+ const string& parameter_name,
+ std::shared_ptr<SharedState> state, int64 min,
+ int64 max) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, node_name);
+ DCHECK(node);
+ (*node)->add_tunable_param(parameter_name, std::move(state), min, max);
+}
+
+// The optimization algorithm starts by setting all tunable parallelism
+// parameters to 1. It then repeatedly identifies the parameter whose unit
+// increase in parallelism decreases the output time the most, and increases
+// that parameter by 1. This continues until all parameters reach their maximum
+// values or the projected output time drops below the per-element processing
+// time divided by the CPU budget.
+void Model::Optimize(int64 cpu_budget) {
+ std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
+ {
+ tf_shared_lock lock(mu_);
+ const int64 processing_time = ProcessingTime();
+ tunables = CollectTunables();
+ for (auto tunable : tunables) {
+ tunable->value = 1;
+ }
+ while (true) {
+ const int64 output_time = OutputTime();
+ bool all_tunables = true;
+ for (auto& tunable : tunables) {
+ if (tunable->value < tunable->max) {
+ all_tunables = false;
+ break;
+ }
+ }
+ if (output_time < processing_time / cpu_budget || all_tunables) {
+ break;
+ }
+ int64 best_delta = -1;
+ Model::Node::Tunable* best_tunable = nullptr;
+ for (auto& tunable : tunables) {
+ if (tunable->value == tunable->max) {
+ continue;
+ }
+ tunable->value++;
+ int64 delta = output_time - OutputTime();
+ if (delta > best_delta) {
+ best_delta = delta;
+ best_tunable = tunable.get();
+ }
+ tunable->value--;
+ }
+ if (!best_tunable) {
+ // NOTE: This can happen because we are performing the optimization
+ // while the model data is changing. If this becomes an issue, we should
+ // look into performing the optimization using a model snapshot.
+ break;
+ }
+ best_tunable->value++;
+ }
+ }
+ VLOG(2) << "Number of knobs: " << tunables.size();
+ for (auto& tunable : tunables) {
+ VLOG(2) << "Setting tunable parameter: " << tunable->value;
+ mutex_lock l(*tunable->state->mu);
+ tunable->state->value = tunable->value;
+ tunable->state->cond_var->notify_all();
+ }
+}
+
+void Model::RecordElement(const string& name) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->record_element();
+ }
+}
+
+void Model::RecordStart(const string& name, bool stop_output) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ if (stop_output && (*node)->output()) {
+ (*node)->output()->record_stop();
+ }
+ (*node)->record_start();
+ }
+}
+
+void Model::RecordStop(const string& name, bool start_output) {
+ tf_shared_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node) {
+ (*node)->record_stop();
+ if (start_output && (*node)->output()) {
+ (*node)->output()->record_start();
+ }
+ }
+}
+
+void Model::RemoveNode(const string& name) {
+ mutex_lock l(mu_);
+ auto node = gtl::FindOrNull(lookup_table_, name);
+ if (node && (*node)->output()) {
+ (*node)->output()->remove_input(*node);
+ }
+ lookup_table_.erase(name);
+}
+
+std::vector<std::shared_ptr<Model::Node::Tunable>> Model::CollectTunables() {
+ std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
+ output_->CollectTunables(&tunables);
+ return tunables;
+}
+
+int64 Model::OutputTime() {
+ std::vector<int64> input_times(1, 0);
+ return output_->OutputTime(&input_times);
+}
+
+int64 Model::ProcessingTime() { return output_->ProcessingTime(); }
+
+} // namespace model
+} // namespace data
+} // namespace tensorflow
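
Model::Optimize() above is a greedy hill climb: every tunable starts at 1, and each round bumps the single knob whose unit increase buys the largest drop in projected output time, stopping once every knob is maxed out or the output time falls below the per-element processing time divided by the CPU budget. A self-contained toy distillation of that loop (plain C++, not TensorFlow code) for readers who want the control flow in isolation:

    #include <cstdint>
    #include <functional>
    #include <vector>

    struct Knob {
      int64_t value = 1;  // current setting
      int64_t max = 8;    // upper bound
    };

    void GreedyOptimize(
        std::vector<Knob>* knobs, int64_t target_cost,
        const std::function<int64_t(const std::vector<Knob>&)>& cost) {
      while (true) {
        const int64_t current = cost(*knobs);
        bool all_maxed = true;
        for (const Knob& k : *knobs) all_maxed = all_maxed && k.value >= k.max;
        if (current < target_cost || all_maxed) break;
        int64_t best_delta = -1;
        Knob* best = nullptr;
        for (Knob& k : *knobs) {
          if (k.value == k.max) continue;
          ++k.value;                               // try the increment...
          int64_t delta = current - cost(*knobs);  // ...measure the payoff...
          --k.value;                               // ...and roll it back.
          if (delta > best_delta) {
            best_delta = delta;
            best = &k;
          }
        }
        if (best == nullptr) break;
        ++best->value;  // commit the single best step
      }
    }
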
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
new file mode 100644
index 0000000000..eae0fa70e8
--- /dev/null
+++ b/tensorflow/core/framework/model.h
@@ -0,0 +1,408 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
+
+#include <list>
+#include <memory>
+#include <string>
+#include <thread> // (b/114492873): move this include into core/platform
+#include <utility>
+#include <vector>
+
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+namespace data {
+namespace model {
+
+// Represents thread-safe state that can be shared between an input pipeline and
+// the performance model.
+struct SharedState {
+ public:
+ explicit SharedState(int64 value, std::shared_ptr<mutex> mu,
+ std::shared_ptr<condition_variable> cond_var)
+ : value(value), mu(std::move(mu)), cond_var(std::move(cond_var)) {}
+
+ std::shared_ptr<mutex> mu;
+ std::shared_ptr<condition_variable> cond_var;
+ int64 value;
+};
+
+// Abstract representation of a TensorFlow input pipeline that can be used for
+// collecting runtime information and optimizing performance. The collected
+// runtime information is used to build a performance model, which is in turn
+// used to identify optimal values of tunable parameters.
+//
+// Developers of tf.data transformations are not expected to interact with this
+// class directly. Boilerplate code for creating the abstract representation of
+// the input pipeline and collecting runtime information has been added to the
+// implementation of `DatasetBase` and `DatasetBaseIterator`, respectively.
+class Model {
+ public:
+ Model() = default;
+
+ // Adds a constant parameter for the given node.
+ void AddConstantParameter(const string& node_name,
+ const string& parameter_name, int64 value)
+ LOCKS_EXCLUDED(mu_);
+
+ // Adds a node with the given name and given output (identified by name).
+ void AddNode(const string& name, const string& output_name)
+ LOCKS_EXCLUDED(mu_);
+
+ // Increments the processing time for the given node.
+ void AddProcessingTime(const string& name, int64 delta) LOCKS_EXCLUDED(mu_);
+
+ // Adds a tunable parameter for the given node.
+ void AddTunableParameter(const string& node_name,
+ const string& parameter_name,
+ std::shared_ptr<SharedState> value, int64 min,
+ int64 max) LOCKS_EXCLUDED(mu_);
+
+ // Runs optimization.
+ void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);
+
+ // Records that a node has produced an element.
+ void RecordElement(const string& name) LOCKS_EXCLUDED(mu_);
+
+ // Records that the given node has started work. If `stop_output` is set, it
+ // also records that the output of the given node has stopped work.
+ void RecordStart(const string& name, bool stop_output) LOCKS_EXCLUDED(mu_);
+
+ // Records that the given node has stopped work. If `start_output` is set, it
+ // also records that the output of the given node has started work.
+ void RecordStop(const string& name, bool start_output) LOCKS_EXCLUDED(mu_);
+
+ // Removes the given node.
+ void RemoveNode(const string& name) LOCKS_EXCLUDED(mu_);
+
+ private:
+ // Abstract representation of a TensorFlow input pipeline node. It collects
+ // information about inputs to this node, processing time spent executing the
+ // node logic, number of elements produced by the node, various other
+ // information (e.g. batch size or execution parallelism).
+ //
+ // Developers of tf.data transformations are not expected to interact with
+ // this class directly. Boilerplate code for creating the abstract
+ // representation of the input pipeline and collecting common information has
+ // been added to the implementation of `DatasetBase` and `DatasetBaseIterator`,
+ // respectively.
+ //
+ // In addition, `DatasetBaseIterator` provides wrappers that can be used for
+ // transformation-specific information collection. The `AddConstantParameter`
+ // and `AddTunableParameter` wrappers can be used to pass parameter values to
+ // the modeling framework, while the `RecordStart` and `RecordStop` wrappers
+ // should be used to correctly account for processing time of multi-threaded
+ // transformations that yield the CPU; such transformations should invoke
+ // `RecordStart()` when a transformation thread starts executing (e.g. when
+ // created or woken up) and `RecordStop()` when a transformation thread stops
+ // executing (e.g. when returning or waiting).
+ //
+ // TODO(jsimsa): Create an API to capture the abstract semantics of each
+ // tf.data transformation and replace switch-case blocks with inheritance.
+ class Node {
+ public:
+ // Represents a tunable parameter.
+ struct Tunable {
+ Tunable(std::shared_ptr<SharedState> state, int64 min, int64 max)
+ : value(state->value), min(min), max(max), state(std::move(state)) {}
+
+ // Identifies the model value of the parameter. This can be different from
+ // the actual value (e.g. during optimization search).
+ int64 value;
+
+ // Identifies the minimum value of the parameter.
+ int64 min;
+
+ // Identifies the maximum value of the parameter.
+ int64 max;
+
+ // Shared state of the parameter.
+ std::shared_ptr<SharedState> state;
+ };
+
+ Node(int64 id, const string& name, std::shared_ptr<Node> output)
+ : id_(id), name_(name), type_(TypeFromName(name)), output_(output) {}
+
+ // Adds a constant parameter.
+ void add_constant_param(const string& name, int64 value)
+ LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ constant_params_[name] = value;
+ }
+
+ // Adds an input.
+ void add_input(std::shared_ptr<Node> node) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.push_back(node);
+ }
+
+ // Increments the aggregate processing time by the given delta.
+ void add_processing_time(int64 delta) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ processing_time_ += delta;
+ }
+
+ // Adds a tunable parameter.
+ void add_tunable_param(const string& name,
+ std::shared_ptr<SharedState> state, int64 min,
+ int64 max) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ tunable_params_[name] =
+ std::make_shared<Tunable>(std::move(state), min, max);
+ }
+
+ // Returns the unique node ID.
+ int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
+
+ // Returns the node inputs.
+ std::list<std::shared_ptr<Node>> inputs() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return inputs_;
+ }
+
+ // Returns the node name.
+ const string& name() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return name_;
+ }
+
+ // Returns the number of elements produced by the node.
+ int64 num_elements() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return num_elements_;
+ }
+
+ // Returns the node output.
+ std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return output_;
+ }
+
+ // Records that the node produced an element.
+ void record_element() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ num_elements_++;
+ }
+
+ // Records that a node thread has started executing.
+ void record_start() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ work_start_[std::this_thread::get_id()] = Env::Default()->NowNanos();
+ }
+
+ // Records that a node thread has stopped executing.
+ void record_stop() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ std::thread::id tid = std::this_thread::get_id();
+ auto start_time = gtl::FindOrNull(work_start_, tid);
+ DCHECK(start_time)
+ << "Encountered a stop event that was not preceded by a start event.";
+ if (start_time) {
+ processing_time_ += Env::Default()->NowNanos() - *start_time;
+ work_start_.erase(tid);
+ }
+ }
+
+ // Removes an input.
+ void remove_input(std::shared_ptr<Node> input) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ inputs_.remove(input);
+ }
+
+ // Sets the node output.
+ void set_output(std::shared_ptr<Node> output) LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ output_ = output;
+ }
+
+ // Collects tunable parameters in the subtree rooted in this node.
+ void CollectTunables(std::vector<std::shared_ptr<Tunable>>* tunables)
+ LOCKS_EXCLUDED(mu_);
+
+ // Returns the per-element output time for this node.
+ int64 OutputTime(std::vector<int64>* input_times) LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return OutputTimeLocked(input_times);
+ }
+
+ // Returns the per-element processing time spent in the subtree rooted in
+ // this node.
+ int64 ProcessingTime() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return ProcessingTimeLocked();
+ }
+
+ private:
+ enum class Type {
+ BATCH = 0,
+ CACHE,
+ CONCATENATE,
+ FILTER,
+ FLAT_MAP,
+ INTERLEAVE,
+ MAP,
+ MAP_AND_BATCH,
+ PADDED_BATCH,
+ PARALLEL_INTERLEAVE,
+ PARALLEL_INTERLEAVE_V2,
+ PARALLEL_MAP,
+ PREFETCH,
+ REPEAT,
+ SHUFFLE,
+ SKIP,
+ TAKE,
+ ZIP,
+ UNKNOWN,
+ };
+
+ // Gets a value of the given parameter (tunable or constant).
+ int64 GetParameterValue(const string& name) SHARED_LOCKS_REQUIRED(mu_);
+
+ // Returns the per-element processing time spent in this node.
+ int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
+ tf_shared_lock l(mu_);
+ return NanosPerElementLocked();
+ }
+
+ int64 NanosPerElementLocked() SHARED_LOCKS_REQUIRED(mu_) {
+ if (num_elements_ == 0) {
+ return 0;
+ }
+ return static_cast<int64>(static_cast<double>(processing_time_) /
+ static_cast<double>(num_elements_));
+ }
+
+ int64 OutputTimeLocked(std::vector<int64>* input_times)
+ SHARED_LOCKS_REQUIRED(mu_);
+
+ int64 OutputTimeForInputs(std::vector<int64>* input_times)
+ SHARED_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->OutputTime(input_times);
+ }
+ return sum;
+ }
+
+ int64 ProcessingTimeLocked() SHARED_LOCKS_REQUIRED(mu_);
+
+ // Returns the per-element processing time spent in the inputs of this node.
+ int64 ProcessingTimeForInputs() SHARED_LOCKS_REQUIRED(mu_) {
+ int64 sum = 0;
+ for (auto input : inputs_) {
+ sum += input->ProcessingTime();
+ }
+ return sum;
+ }
+
+ Type TypeFromName(const string& name) SHARED_LOCKS_REQUIRED(mu_) {
+ if (name_ == "Batch") {
+ return Type::BATCH;
+ }
+ if (str_util::EndsWith(name_, "Cache")) {
+ return Type::CACHE;
+ }
+ if (name_ == "Concatenate") {
+ return Type::CONCATENATE;
+ }
+ if (name_ == "Filter") {
+ return Type::FILTER;
+ }
+ if (name_ == "FlatMap") {
+ return Type::FLAT_MAP;
+ }
+ if (name_ == "Interleave") {
+ return Type::INTERLEAVE;
+ }
+ if (name_ == "Map") {
+ return Type::MAP;
+ }
+ if (name_ == "MapAndBatch") {
+ return Type::MAP_AND_BATCH;
+ }
+ if (name_ == "PaddedBatch") {
+ return Type::PADDED_BATCH;
+ }
+ if (name_ == "ParallelInterleave") {
+ return Type::PARALLEL_INTERLEAVE;
+ }
+ if (name_ == "ParallelInterleaveV2") {
+ return Type::PARALLEL_INTERLEAVE_V2;
+ }
+ if (name_ == "ParallelMap") {
+ return Type::PARALLEL_MAP;
+ }
+ if (name_ == "Prefetch") {
+ return Type::PREFETCH;
+ }
+ if (str_util::EndsWith(name_, "Repeat")) {
+ return Type::REPEAT;
+ }
+ if (name_ == "Shuffle") {
+ return Type::SHUFFLE;
+ }
+ if (str_util::EndsWith(name_, "Skip")) {
+ return Type::SKIP;
+ }
+ if (str_util::EndsWith(name_, "Take")) {
+ return Type::TAKE;
+ }
+ if (name_ == "Zip") {
+ return Type::ZIP;
+ }
+ return Type::UNKNOWN;
+ }
+
+ mutex mu_;
+ const int64 id_;
+ const string name_;
+ const Type type_;
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+ int64 num_elements_ GUARDED_BY(mu_) = 0;
+ std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
+ std::map<string, int64> constant_params_ GUARDED_BY(mu_);
+ // Tunables are shared with the model during optimization.
+ std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_);
+ std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
+ std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+ };
+
+ std::vector<std::shared_ptr<Node::Tunable>> CollectTunables()
+ SHARED_LOCKS_REQUIRED(mu_);
+
+ int64 OutputTime() SHARED_LOCKS_REQUIRED(mu_);
+
+ int64 ProcessingTime() SHARED_LOCKS_REQUIRED(mu_);
+
+ // Used for coordination between different input pipeline threads. Exclusive
+ // access is required only when adding or removing nodes. Concurrent access to
+ // existing nodes is protected by a node mutex.
+ mutex mu_;
+ int64 id_counter_ GUARDED_BY(mu_) = 1;
+ std::shared_ptr<Node> output_ GUARDED_BY(mu_);
+ std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_);
+};
+
+} // namespace model
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_FRAMEWORK_MODEL_H_
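
SharedState is the handshake between the autotuner and a running iterator: the model writes a new value under state->mu and notifies state->cond_var, and the iterator's worker threads pick it up. A sketch of the registration side inside a DatasetBaseIterator subclass; `ctx` is the IteratorContext* and `current_parallelism` is a hypothetical variable owned by the iterator:

    auto mu = std::make_shared<mutex>();
    auto cond_var = std::make_shared<condition_variable>();
    auto state = std::make_shared<model::SharedState>(/*value=*/1, mu, cond_var);
    AddTunableParameter(ctx, "parallelism", state, /*min=*/1,
                        /*max=*/port::NumSchedulableCPUs());

    // A worker thread later consumes autotuned updates:
    {
      mutex_lock l(*mu);
      while (state->value == current_parallelism) {
        cond_var->wait(l);
      }
      current_parallelism = state->value;
    }
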
diff --git a/tensorflow/core/framework/node_def_builder.cc b/tensorflow/core/framework/node_def_builder.cc
index 8e00bfe4f8..348a825af9 100644
--- a/tensorflow/core/framework/node_def_builder.cc
+++ b/tensorflow/core/framework/node_def_builder.cc
@@ -24,23 +24,22 @@ limitations under the License.
namespace tensorflow {
NodeDefBuilder::NodeOut::NodeOut(StringPiece n, int i, DataType dt)
- : node(std::string(n)), index(i), data_type(dt) {}
+ : node(n), index(i), data_type(dt) {}
NodeDefBuilder::NodeOut::NodeOut() {
// uninitialized, call Reset() before use.
}
void NodeDefBuilder::NodeOut::Reset(StringPiece n, int i, DataType dt) {
- node = std::string(n);
+ node = string(n);
index = i;
data_type = dt;
}
NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
const OpRegistryInterface* op_registry) {
- node_def_.set_name(std::string(name));
- const Status status =
- op_registry->LookUpOpDef(std::string(op_name), &op_def_);
+ node_def_.set_name(string(name));
+ const Status status = op_registry->LookUpOpDef(string(op_name), &op_def_);
if (status.ok()) {
Initialize();
} else {
@@ -51,7 +50,7 @@ NodeDefBuilder::NodeDefBuilder(StringPiece name, StringPiece op_name,
NodeDefBuilder::NodeDefBuilder(StringPiece name, const OpDef* op_def)
: op_def_(op_def) {
- node_def_.set_name(std::string(name));
+ node_def_.set_name(string(name));
Initialize();
}
@@ -171,7 +170,7 @@ void NodeDefBuilder::AddInput(StringPiece src_node, int src_index) {
} else if (src_index > 0) {
node_def_.add_input(strings::StrCat(src_node, ":", src_index));
} else {
- node_def_.add_input(std::string(src_node));
+ node_def_.add_input(string(src_node));
}
}
@@ -194,12 +193,12 @@ void NodeDefBuilder::VerifyInputRef(const OpDef::ArgDef* input_arg,
}
NodeDefBuilder& NodeDefBuilder::ControlInput(StringPiece src_node) {
- control_inputs_.push_back(std::string(src_node));
+ control_inputs_.emplace_back(src_node);
return *this;
}
NodeDefBuilder& NodeDefBuilder::Device(StringPiece device_spec) {
- node_def_.set_device(std::string(device_spec));
+ node_def_.set_device(string(device_spec));
return *this;
}
diff --git a/tensorflow/core/framework/node_def_builder.h b/tensorflow/core/framework/node_def_builder.h
index c138332beb..ad07ec5480 100644
--- a/tensorflow/core/framework/node_def_builder.h
+++ b/tensorflow/core/framework/node_def_builder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
-#define TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_
+#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_
#include <functional>
#include <vector>
@@ -175,4 +175,4 @@ class NodeDefBuilder {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_BUILDER_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_BUILDER_H_
diff --git a/tensorflow/core/framework/node_def_util.cc b/tensorflow/core/framework/node_def_util.cc
index 0bd79366eb..43ac1d0ada 100644
--- a/tensorflow/core/framework/node_def_util.cc
+++ b/tensorflow/core/framework/node_def_util.cc
@@ -254,7 +254,7 @@ DEFINE_GET_ATTR(NameAttrList, func, "func", emplace_back, v, ;);
#undef DEFINE_GET_ATTR
bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name) {
- return node_def.attr().find(std::string(attr_name)) != node_def.attr().end();
+ return node_def.attr().find(string(attr_name)) != node_def.attr().end();
}
static const string& kEmptyString = *new string();
@@ -372,6 +372,14 @@ Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
node_def.name());
}
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs) {
+ for (const auto& arg : op_def.input_arg()) {
+ TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
+ }
+ return Status::OK();
+}
+
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int output_port, DataType* output_type) {
DataTypeVector output_types;
@@ -397,12 +405,18 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs, DataTypeVector* outputs) {
- for (const auto& arg : op_def.input_arg()) {
- TF_RETURN_IF_ERROR(AddArgToSig(node_def, arg, inputs));
- }
+ TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, inputs));
return OutputTypesForNode(node_def, op_def, outputs);
}
+Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
+ int* num_outputs) {
+ DataTypeVector outputs;
+ TF_RETURN_IF_ERROR(OutputTypesForNode(node_def, op_def, &outputs));
+ *num_outputs = outputs.size();
+ return Status::OK();
+}
+
Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def) {
if (node_def.op() != op_def.name()) {
return errors::InvalidArgument("NodeDef op '", node_def.op(),
@@ -653,7 +667,7 @@ Status AttachDef(const Status& status, const Node& node) {
void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def) {
node_def->mutable_attr()->insert(
- AttrValueMap::value_type(std::string(name), value));
+ AttrValueMap::value_type(string(name), value));
}
#define ADD_NODE_ATTR(T) \
@@ -691,7 +705,7 @@ ADD_NODE_ATTR(gtl::ArraySlice<NameAttrList>)
#undef ADD_NODE_ATTR
void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map) {
- map->insert(AttrValueMap::value_type(std::string(name), value));
+ map->insert(AttrValueMap::value_type(string(name), value));
}
#define ADD_ATTR(T) \
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index c012b7c3d3..0ff67554eb 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -13,11 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
-#define TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
#include <string>
-#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
@@ -249,6 +248,10 @@ const string& GetNodeAttrString(const AttrSlice& attrs, StringPiece attr_name);
// REQUIRES: ValidateOpDef(op_def).ok()
Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
int input_port, DataType* input_type);
+// Computes the input types for a specific node.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
+ DataTypeVector* inputs);
// Computes the output type for a specific node output.
// REQUIRES: ValidateOpDef(op_def).ok()
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
@@ -261,6 +264,10 @@ Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
// REQUIRES: ValidateOpDef(op_def).ok()
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
DataTypeVector* inputs, DataTypeVector* outputs);
+// Computes the number of outputs for a specific node.
+// REQUIRES: ValidateOpDef(op_def).ok()
+Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
+ int* num_outputs);
// Validates that the NodeDef:
// * Defines all expected attrs from the OpDef.
@@ -312,4 +319,4 @@ Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
NodeDef* node_def);
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_NODE_DEF_UTIL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
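The two helpers added above (InputTypesForNode and NumOutputsForNode) package logic that callers previously had to reach through InOutTypesForNode. A minimal sketch of how a caller might combine them; DescribeNode and its logging are hypothetical and not part of this patch:

#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// Hypothetical helper: summarize a node's signature using the new utilities.
Status DescribeNode(const NodeDef& node_def, const OpDef& op_def) {
  DataTypeVector input_types;
  TF_RETURN_IF_ERROR(InputTypesForNode(node_def, op_def, &input_types));

  int num_outputs = 0;
  TF_RETURN_IF_ERROR(NumOutputsForNode(node_def, op_def, &num_outputs));

  LOG(INFO) << node_def.name() << ": " << input_types.size() << " inputs, "
            << num_outputs << " outputs";
  return Status::OK();
}

}  // namespace tensorflow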
diff --git a/tensorflow/core/framework/node_def_util_test.cc b/tensorflow/core/framework/node_def_util_test.cc
index 74cc594863..d9d437024a 100644
--- a/tensorflow/core/framework/node_def_util_test.cc
+++ b/tensorflow/core/framework/node_def_util_test.cc
@@ -370,6 +370,48 @@ TEST(NodeDefUtilTest, ValidSyntax) {
"Illegal op input name 'a:00");
}
+TEST(InputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(InputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_FLOAT);
+ EXPECT_EQ(types[1], DT_INT32);
+
+ DataType type;
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_FLOAT);
+ EXPECT_TRUE(InputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_INT32);
+ EXPECT_FALSE(InputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
+TEST(OutputTypesForNode, Simple) {
+ const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
+ .Input("a: float")
+ .Input("b: int32")
+ .Output("c: string")
+ .Output("d: bool"));
+ const NodeDef node_def = ToNodeDef(
+ NodeDefBuilder("simple", &op_def).Input(FakeInput()).Input(FakeInput()));
+ DataTypeVector types;
+ EXPECT_TRUE(OutputTypesForNode(node_def, op_def, &types).ok());
+ EXPECT_EQ(types[0], DT_STRING);
+ EXPECT_EQ(types[1], DT_BOOL);
+
+ DataType type;
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 0, &type).ok());
+ EXPECT_EQ(type, DT_STRING);
+ EXPECT_TRUE(OutputTypeForNode(node_def, op_def, 1, &type).ok());
+ EXPECT_EQ(type, DT_BOOL);
+ EXPECT_FALSE(OutputTypeForNode(node_def, op_def, 2, &type).ok());
+}
+
TEST(NameRangesForNodeTest, Simple) {
const OpDef op_def = ToOpDef(OpDefBuilder("Simple")
.Input("a: float")
diff --git a/tensorflow/core/framework/numeric_op.h b/tensorflow/core/framework/numeric_op.h
index 4538ff053c..0167e21f11 100644
--- a/tensorflow/core/framework/numeric_op.h
+++ b/tensorflow/core/framework/numeric_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
-#define TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_
+#define TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -110,4 +110,4 @@ class BinaryElementWiseOp : public BinaryOp<T> {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_NUMERIC_OP_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_NUMERIC_OP_H_
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index b1d0127809..3236d1897c 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_
-#define TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_NUMERIC_TYPES_H_
+#define TENSORFLOW_CORE_FRAMEWORK_NUMERIC_TYPES_H_
#include <complex>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -122,4 +122,4 @@ struct hash<Eigen::half> {
} // namespace std
#endif // _MSC_VER
-#endif // TENSORFLOW_FRAMEWORK_NUMERIC_TYPES_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_NUMERIC_TYPES_H_
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index 3ccca4090d..81ed5f95f0 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_OP_H_
-#define TENSORFLOW_FRAMEWORK_OP_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_H_
+#define TENSORFLOW_CORE_FRAMEWORK_OP_H_
#include <functional>
#include <unordered_map>
@@ -209,16 +209,16 @@ template <>
class OpDefBuilderWrapper<true> {
public:
OpDefBuilderWrapper(const char name[]) : builder_(name) {}
- OpDefBuilderWrapper<true>& Attr(StringPiece spec) {
- builder_.Attr(spec);
+ OpDefBuilderWrapper<true>& Attr(string spec) {
+ builder_.Attr(std::move(spec));
return *this;
}
- OpDefBuilderWrapper<true>& Input(StringPiece spec) {
- builder_.Input(spec);
+ OpDefBuilderWrapper<true>& Input(string spec) {
+ builder_.Input(std::move(spec));
return *this;
}
- OpDefBuilderWrapper<true>& Output(StringPiece spec) {
- builder_.Output(spec);
+ OpDefBuilderWrapper<true>& Output(string spec) {
+ builder_.Output(std::move(spec));
return *this;
}
OpDefBuilderWrapper<true>& SetIsCommutative() {
@@ -237,12 +237,12 @@ class OpDefBuilderWrapper<true> {
builder_.SetAllowsUninitializedInput();
return *this;
}
- OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) {
- builder_.Deprecated(version, explanation);
+ OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) {
+ builder_.Deprecated(version, std::move(explanation));
return *this;
}
- OpDefBuilderWrapper<true>& Doc(StringPiece text) {
- builder_.Doc(text);
+ OpDefBuilderWrapper<true>& Doc(string text) {
+ builder_.Doc(std::move(text));
return *this;
}
OpDefBuilderWrapper<true>& SetShapeFn(
@@ -309,4 +309,4 @@ struct OpDefBuilderReceiver {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_OP_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_OP_H_
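Since OpDefBuilderWrapper now takes its specs as string by value and moves them into OpDefBuilder, REGISTER_OP call sites stay unchanged: a string literal constructs a temporary string that is then moved. A representative registration for reference (the op name and attrs are made up; assumes the usual op.h / common_shape_fns.h includes):

REGISTER_OP("MyExampleOp")                        // hypothetical op
    .Input("x: T")                                // literal -> string, then moved
    .Output("y: T")
    .Attr("T: {float, double}")
    .Deprecated(42, "Use MyExampleOpV2 instead")  // explanation is moved as well
    .SetShapeFn(shape_inference::UnchangedShape);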
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 91eb6c0672..8a9bb63182 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -526,32 +526,32 @@ void FinalizeDoc(const string& text, OpDef* op_def,
} // namespace
-OpDefBuilder::OpDefBuilder(StringPiece op_name) {
- op_def()->set_name(std::string(op_name)); // NOLINT
+OpDefBuilder::OpDefBuilder(string op_name) {
+ op_def()->set_name(std::move(op_name));
}
-OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) {
- attrs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Attr(string spec) {
+ attrs_.push_back(std::move(spec));
return *this;
}
-OpDefBuilder& OpDefBuilder::Input(StringPiece spec) {
- inputs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Input(string spec) {
+ inputs_.push_back(std::move(spec));
return *this;
}
-OpDefBuilder& OpDefBuilder::Output(StringPiece spec) {
- outputs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Output(string spec) {
+ outputs_.push_back(std::move(spec));
return *this;
}
#ifndef TF_LEAN_BINARY
-OpDefBuilder& OpDefBuilder::Doc(StringPiece text) {
+OpDefBuilder& OpDefBuilder::Doc(string text) {
if (!doc_.empty()) {
errors_.push_back(
strings::StrCat("Extra call to Doc() for Op ", op_def()->name()));
} else {
- doc_.assign(text.data(), text.size());
+ doc_ = std::move(text);
}
return *this;
}
@@ -577,14 +577,14 @@ OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
return *this;
}
-OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) {
+OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) {
if (op_def()->has_deprecation()) {
errors_.push_back(
strings::StrCat("Deprecated called twice for Op ", op_def()->name()));
} else {
OpDeprecation* deprecation = op_def()->mutable_deprecation();
deprecation->set_version(version);
- deprecation->set_explanation(std::string(explanation));
+ deprecation->set_explanation(std::move(explanation));
}
return *this;
}
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h
index fbfb4018aa..8077b20598 100644
--- a/tensorflow/core/framework/op_def_builder.h
+++ b/tensorflow/core/framework/op_def_builder.h
@@ -16,8 +16,8 @@ limitations under the License.
// Class and associated machinery for specifying an Op's OpDef and shape
// inference function for Op registration.
-#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
-#define TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_
+#define TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_
#include <string>
#include <vector>
@@ -51,7 +51,7 @@ struct OpRegistrationData {
class OpDefBuilder {
public:
// Constructs an OpDef with just the name field set.
- explicit OpDefBuilder(StringPiece op_name);
+ explicit OpDefBuilder(string op_name);
// Adds an attr to this OpDefBuilder (and returns *this). The spec has
// format "<name>:<type>" or "<name>:<type>=<default>"
@@ -84,7 +84,7 @@ class OpDefBuilder {
// * Ability to restrict the type of the tensor like the existing
// restrictions for type attrs.
// Perhaps by linking the type of the tensor to a type attr?
- OpDefBuilder& Attr(StringPiece spec);
+ OpDefBuilder& Attr(string spec);
// Adds an input or output to this OpDefBuilder (and returns *this).
// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
@@ -101,8 +101,8 @@ class OpDefBuilder {
// in the spec?
// TODO(josh11b): SparseInput() and SparseOutput() matching the Python
// handling?
- OpDefBuilder& Input(StringPiece spec);
- OpDefBuilder& Output(StringPiece spec);
+ OpDefBuilder& Input(string spec);
+ OpDefBuilder& Output(string spec);
// Turns on the indicated boolean flag in this OpDefBuilder (and
// returns *this).
@@ -112,7 +112,7 @@ class OpDefBuilder {
OpDefBuilder& SetAllowsUninitializedInput();
// Deprecate the op at a certain GraphDef version.
- OpDefBuilder& Deprecated(int version, StringPiece explanation);
+ OpDefBuilder& Deprecated(int version, string explanation);
// Adds docs to this OpDefBuilder (and returns *this).
// Docs have the format:
@@ -128,9 +128,9 @@ class OpDefBuilder {
// to suppress the automatically-generated type documentation in
// generated output.
#ifndef TF_LEAN_BINARY
- OpDefBuilder& Doc(StringPiece text);
+ OpDefBuilder& Doc(string text);
#else
- OpDefBuilder& Doc(StringPiece text) { return *this; }
+ OpDefBuilder& Doc(string text) { return *this; }
#endif
// Sets the shape function to be used for shape inference.
@@ -162,4 +162,4 @@ class OpDefBuilder {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_OP_DEF_BUILDER_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_OP_DEF_BUILDER_H_
diff --git a/tensorflow/core/framework/op_def_util.cc b/tensorflow/core/framework/op_def_util.cc
index 9be0dc69d2..3597f43d51 100644
--- a/tensorflow/core/framework/op_def_util.cc
+++ b/tensorflow/core/framework/op_def_util.cc
@@ -172,6 +172,15 @@ const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def) {
return nullptr;
}
+const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def) {
+ for (int i = 0; i < api_def.in_arg_size(); ++i) {
+ if (api_def.in_arg(i).name() == name) {
+ return &api_def.in_arg(i);
+ }
+ }
+ return nullptr;
+}
+
#define VALIDATE(EXPR, ...) \
do { \
if (!(EXPR)) { \
diff --git a/tensorflow/core/framework/op_def_util.h b/tensorflow/core/framework/op_def_util.h
index 0ba1325a03..85afe2bdea 100644
--- a/tensorflow/core/framework/op_def_util.h
+++ b/tensorflow/core/framework/op_def_util.h
@@ -16,10 +16,11 @@ limitations under the License.
// TODO(josh11b): Probably not needed for OpKernel authors, so doesn't
// need to be as publicly accessible as other files in framework/.
-#ifndef TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
-#define TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_DEF_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_OP_DEF_UTIL_H_
#include <string>
+#include "tensorflow/core/framework/api_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -47,6 +48,10 @@ OpDef::AttrDef* FindAttrMutable(StringPiece name, OpDef* op_def);
// Returns nullptr if no such attr is found.
const OpDef::ArgDef* FindInputArg(StringPiece name, const OpDef& op_def);
+// Searches api_def for input argument with the indicated name.
+// Returns nullptr if no such input arg is found.
+const ApiDef::Arg* FindInputArg(StringPiece name, const ApiDef& api_def);
+
// Produce a human-readable version of an op_def that is more concise
// than a text-format proto. Excludes descriptions.
string SummarizeOpDef(const OpDef& op_def);
@@ -98,4 +103,4 @@ uint64 OpDefHash(const OpDef& o);
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_OP_DEF_UTIL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_OP_DEF_UTIL_H_
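The new ApiDef overload of FindInputArg mirrors the OpDef version, which lets API generators that already hold an ApiDef resolve input arguments (for example to honor rename_to) without going back to the OpDef. A minimal sketch under that assumption; RenamedInputOrDefault is illustrative only:

#include "tensorflow/core/framework/op_def_util.h"

namespace tensorflow {

// Hypothetical helper: prefer the ApiDef's rename_to for an input arg,
// falling back to the original name when no rename is specified.
string RenamedInputOrDefault(const ApiDef& api_def, StringPiece arg_name) {
  const ApiDef::Arg* arg = FindInputArg(arg_name, api_def);
  if (arg != nullptr && !arg->rename_to().empty()) {
    return arg->rename_to();
  }
  return string(arg_name);
}

}  // namespace tensorflow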
diff --git a/tensorflow/core/framework/op_gen_lib.cc b/tensorflow/core/framework/op_gen_lib.cc
index 4b56d807df..505ab54775 100644
--- a/tensorflow/core/framework/op_gen_lib.cc
+++ b/tensorflow/core/framework/op_gen_lib.cc
@@ -186,7 +186,7 @@ static bool FindMultiline(StringPiece line, size_t colon, string* end) {
while (str_util::ConsumePrefix(&line, " ")) {
}
if (str_util::ConsumePrefix(&line, "<<")) {
- *end = std::string(line);
+ *end = string(line);
return true;
}
return false;
diff --git a/tensorflow/core/framework/op_gen_lib.h b/tensorflow/core/framework/op_gen_lib.h
index 533dd64805..c269e2df04 100644
--- a/tensorflow/core/framework/op_gen_lib.h
+++ b/tensorflow/core/framework/op_gen_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
-#define TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_OP_GEN_LIB_H_
+#define TENSORFLOW_CORE_FRAMEWORK_OP_GEN_LIB_H_
#include <string>
#include <unordered_map>
@@ -97,4 +97,4 @@ class ApiDefMap {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_OP_GEN_LIB_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_OP_GEN_LIB_H_
diff --git a/tensorflow/core/framework/op_kernel.cc b/tensorflow/core/framework/op_kernel.cc
index b285accce7..3e34bf0418 100644
--- a/tensorflow/core/framework/op_kernel.cc
+++ b/tensorflow/core/framework/op_kernel.cc
@@ -41,6 +41,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -80,10 +81,8 @@ Status MatchSignatureHelper(const DataTypeSlice expected_inputs,
// OpKernel ------------------------------------------------------------------
-// TODO(mrry): Convert to std::make_unique when available.
OpKernel::OpKernel(OpKernelConstruction* context)
- : OpKernel(context,
- std::unique_ptr<const NodeDef>(new NodeDef(context->def()))) {}
+ : OpKernel(context, MakeUnique<const NodeDef>(context->def())) {}
OpKernel::OpKernel(OpKernelConstruction* context,
std::unique_ptr<const NodeDef> node_def)
@@ -266,9 +265,12 @@ OpKernelContext::OpKernelContext(Params* params, int num_outputs)
params_->ensure_eigen_gpu_device();
if (params_->eigen_gpu_device != nullptr) {
Allocator* eigen_gpu_allocator = get_allocator(AllocatorAttributes());
- params_->device->ReinitializeGpuDevice(this, params_->eigen_gpu_device,
- params_->op_device_context,
- eigen_gpu_allocator);
+ Status s = params_->device->ReinitializeGpuDevice(
+ this, params_->eigen_gpu_device, params_->op_device_context,
+ eigen_gpu_allocator);
+ if (!s.ok()) {
+ SetStatus(s);
+ }
}
if (params_->record_tensor_accesses) {
referenced_tensors_.Init();
@@ -525,10 +527,8 @@ std::unique_ptr<Tensor> OpKernelContext::forward_input(
return nullptr;
}
}
- // TODO(rmlarsen): Use MakeUnique here. There is already a copy in
- // tensorflow/compiler/xla/ptr_util.h. Perhaps this should be part of
- // general cleanup of ownership in this code.
- std::unique_ptr<Tensor> output_tensor(new Tensor());
+
+ auto output_tensor = MakeUnique<Tensor>();
CHECK(output_tensor->CopyFrom(*input.tensor, output_shape));
return output_tensor;
}
@@ -913,7 +913,7 @@ void OpKernelContext::clear_recorded_memory() {
struct KernelRegistration {
KernelRegistration(const KernelDef& d, StringPiece c,
kernel_factory::OpKernelRegistrar::Factory f)
- : def(d), kernel_class_name(std::string(c)), factory(f) {}
+ : def(d), kernel_class_name(c), factory(f) {}
const KernelDef def;
const string kernel_class_name;
const kernel_factory::OpKernelRegistrar::Factory factory;
diff --git a/tensorflow/core/framework/op_kernel.h b/tensorflow/core/framework/op_kernel.h
index aab95b785b..4bbd6c3d7d 100644
--- a/tensorflow/core/framework/op_kernel.h
+++ b/tensorflow/core/framework/op_kernel.h
@@ -70,7 +70,7 @@ class OpRegistryInterface;
class ResourceMgr;
class ScopedStepContainer;
class CollectiveExecutor;
-class StepStatsCollector;
+class StepStatsCollectorInterface;
class OpKernel {
public:
@@ -372,18 +372,37 @@ class OpKernelConstruction {
template <typename ListType, typename ElementType>
class OpArgIterator {
public:
- typedef OpArgIterator<ListType, ElementType> ME;
+ using iterator_category = std::forward_iterator_tag;
+ using value_type = ElementType;
+ using pointer = ElementType*;
+ using reference = ElementType&;
+ using difference_type = ptrdiff_t;
+
OpArgIterator(const ListType* list, int i) : list_(list), i_(i) {}
- bool operator==(const ME& rhs) {
+
+ bool operator==(const OpArgIterator& rhs) {
DCHECK(list_ == rhs.list_);
return i_ == rhs.i_;
}
- bool operator!=(const ME& rhs) {
+
+ bool operator!=(const OpArgIterator& rhs) {
DCHECK(list_ == rhs.list_);
return i_ != rhs.i_;
}
- void operator++() { ++i_; }
- ElementType& operator*() { return (*list_)[i_]; }
+
+ OpArgIterator operator++() { // prefix ++it
+ ++i_;
+ return *this;
+ }
+
+ OpArgIterator operator++(int) { // postfix it++
+ OpArgIterator old_value = *this;
+ ++i_;
+ return old_value;
+ }
+
+ reference operator*() { return (*list_)[i_]; }
+ pointer operator->() { return &(*list_)[i_]; }
private:
const ListType* const list_;
@@ -394,7 +413,7 @@ class OpArgIterator {
// that are passed to the op as a single named argument.
class OpInputList {
public:
- typedef OpArgIterator<OpInputList, const Tensor&> Iterator;
+ typedef OpArgIterator<OpInputList, const Tensor> Iterator;
OpInputList() : ctx_(nullptr), start_(0), stop_(0) {}
OpInputList(OpKernelContext* ctx, int start, int stop)
: ctx_(ctx), start_(start), stop_(stop) {}
@@ -569,7 +588,7 @@ class OpKernelContext {
CallFrameInterface* call_frame = nullptr;
FunctionLibraryRuntime* function_library = nullptr;
std::function<void(std::function<void()>)>* runner = nullptr;
- StepStatsCollector* stats_collector = nullptr;
+ StepStatsCollectorInterface* stats_collector = nullptr;
// TensorSliceReaderCache support.
checkpoint::TensorSliceReaderCacheWrapper* slice_reader_cache = nullptr;
@@ -984,7 +1003,7 @@ class OpKernelContext {
std::function<void(std::function<void()>)>* runner() const {
return params_->runner;
}
- StepStatsCollector* stats_collector() const {
+ StepStatsCollectorInterface* stats_collector() const {
return params_->stats_collector;
}
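With OpArgIterator now exposing the standard iterator member types and both increment forms, OpInputList's iterator satisfies the forward-iterator requirements, so range-based for loops and <algorithm> utilities work on it. A hedged sketch of a kernel body relying on that (the input name "values" is made up):

// Inside a hypothetical OpKernel::Compute(OpKernelContext* ctx):
OpInputList values;
OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));

// Range-based iteration goes through OpArgIterator's operator++/operator*.
int64 total_elements = 0;
for (const Tensor& t : values) {
  total_elements += t.NumElements();
}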
diff --git a/tensorflow/core/framework/op_segment.cc b/tensorflow/core/framework/op_segment.cc
index dfc5aa7747..75ed4a4eaf 100644
--- a/tensorflow/core/framework/op_segment.cc
+++ b/tensorflow/core/framework/op_segment.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_segment.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -99,4 +100,11 @@ void OpSegment::RemoveHold(const string& session_handle) {
delete item;
}
+bool OpSegment::ShouldOwnKernel(FunctionLibraryRuntime* lib,
+ const string& node_op) {
+ // OpSegment should not own kernel if the node is stateless, or a function.
+ return lib->IsStateful(node_op) &&
+ lib->GetFunctionLibraryDefinition()->Find(node_op) == nullptr;
+}
+
} // end namespace tensorflow
diff --git a/tensorflow/core/framework/op_segment.h b/tensorflow/core/framework/op_segment.h
index 4433a2554f..37d939ea2b 100644
--- a/tensorflow/core/framework/op_segment.h
+++ b/tensorflow/core/framework/op_segment.h
@@ -60,6 +60,10 @@ class OpSegment {
Status FindOrCreate(const string& session_handle, const string& node_name,
OpKernel** kernel, CreateKernelFn create_fn);
+ // Returns true if OpSegment should own the kernel.
+ static bool ShouldOwnKernel(FunctionLibraryRuntime* lib,
+ const string& node_op);
+
private:
// op name -> OpKernel
typedef std::unordered_map<string, OpKernel*> KernelMap;
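OpSegment::ShouldOwnKernel centralizes the ownership rule used by executors: a kernel is cached in (and owned by) the OpSegment only when its op is stateful and is not a function. A sketch of how a kernel-creation path might branch on it; lib, opseg, ndef and session_handle are assumed to be in scope and are not code from this patch:

// Hypothetical kernel-creation path inside a function returning Status.
OpKernel* kernel = nullptr;
if (OpSegment::ShouldOwnKernel(lib, ndef.op())) {
  // Stateful, non-function kernels are cached per session in the OpSegment.
  auto create_fn = [lib, &ndef](OpKernel** k) {
    return lib->CreateKernel(ndef, k);
  };
  TF_RETURN_IF_ERROR(
      opseg->FindOrCreate(session_handle, ndef.name(), &kernel, create_fn));
} else {
  // Stateless kernels and function calls are created per use; the caller
  // owns and deletes them.
  TF_RETURN_IF_ERROR(lib->CreateKernel(ndef, &kernel));
}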
diff --git a/tensorflow/core/framework/queue_interface.h b/tensorflow/core/framework/queue_interface.h
index 4aeaab3d9b..4ca4416c5a 100644
--- a/tensorflow/core/framework/queue_interface.h
+++ b/tensorflow/core/framework/queue_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_
-#define TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_QUEUE_INTERFACE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_QUEUE_INTERFACE_H_
#include <string>
#include <vector>
@@ -99,4 +99,4 @@ class QueueInterface : public ResourceBase {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_QUEUE_INTERFACE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_QUEUE_INTERFACE_H_
diff --git a/tensorflow/core/framework/reader_base.h b/tensorflow/core/framework/reader_base.h
index cb44be4dee..5b82e9181f 100644
--- a/tensorflow/core/framework/reader_base.h
+++ b/tensorflow/core/framework/reader_base.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_READER_BASE_H_
-#define TENSORFLOW_FRAMEWORK_READER_BASE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_
#include <memory>
#include <string>
@@ -135,4 +135,4 @@ class ReaderBase : public ReaderInterface {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_READER_BASE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_READER_BASE_H_
diff --git a/tensorflow/core/framework/reader_interface.h b/tensorflow/core/framework/reader_interface.h
index dac6056b5a..f894acbe1d 100644
--- a/tensorflow/core/framework/reader_interface.h
+++ b/tensorflow/core/framework/reader_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
-#define TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_READER_INTERFACE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_READER_INTERFACE_H_
#include <memory>
#include <string>
@@ -84,4 +84,4 @@ class ReaderInterface : public ResourceBase {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_READER_INTERFACE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_READER_INTERFACE_H_
diff --git a/tensorflow/core/framework/reader_op_kernel.h b/tensorflow/core/framework/reader_op_kernel.h
index ffd6a1a184..e65a8695be 100644
--- a/tensorflow/core/framework/reader_op_kernel.h
+++ b/tensorflow/core/framework/reader_op_kernel.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
-#define TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_READER_OP_KERNEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_READER_OP_KERNEL_H_
#include <functional>
#include <string>
@@ -85,4 +85,4 @@ class ReaderOpKernel : public ResourceOpKernel<ReaderInterface> {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_READER_OP_KERNEL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_READER_OP_KERNEL_H_
diff --git a/tensorflow/core/framework/register_types.h b/tensorflow/core/framework/register_types.h
index f1cd37ecda..ddb5b10c18 100644
--- a/tensorflow/core/framework/register_types.h
+++ b/tensorflow/core/framework/register_types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
-#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_
+#define TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_
// This file is used by cuda code and must remain compilable by nvcc.
#include "tensorflow/core/framework/numeric_types.h"
@@ -161,9 +161,12 @@ limitations under the License.
TF_CALL_int64(m) TF_CALL_int32(m) TF_CALL_uint16(m) TF_CALL_int16(m) \
TF_CALL_uint8(m) TF_CALL_int8(m)
+#define TF_CALL_FLOAT_TYPES(m) \
+ TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)
+
#define TF_CALL_REAL_NUMBER_TYPES(m) \
TF_CALL_INTEGRAL_TYPES(m) \
- TF_CALL_half(m) TF_CALL_bfloat16(m) TF_CALL_float(m) TF_CALL_double(m)
+ TF_CALL_FLOAT_TYPES(m)
#define TF_CALL_REAL_NUMBER_TYPES_NO_BFLOAT16(m) \
TF_CALL_INTEGRAL_TYPES(m) TF_CALL_half(m) TF_CALL_float(m) TF_CALL_double(m)
@@ -225,4 +228,4 @@ limitations under the License.
#define TF_CALL_SYCL_NUMBER_TYPES(m) TF_CALL_float(m) TF_CALL_SYCL_double(m)
#endif // __ANDROID_TYPES_SLIM__
-#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_H_
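TF_CALL_FLOAT_TYPES(m) groups the four floating-point instantiations (half, bfloat16, float, double) that kernels previously had to spell out individually. A hypothetical registration block using it might look like this (MyFloatOnlyOp is a made-up kernel template):

#define REGISTER_MY_FLOAT_ONLY(T)                                      \
  REGISTER_KERNEL_BUILDER(                                             \
      Name("MyFloatOnlyOp").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      MyFloatOnlyOp<T>);

// Expands to one registration each for half, bfloat16, float and double.
TF_CALL_FLOAT_TYPES(REGISTER_MY_FLOAT_ONLY);
#undef REGISTER_MY_FLOAT_ONLY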
diff --git a/tensorflow/core/framework/register_types_traits.h b/tensorflow/core/framework/register_types_traits.h
index ab35c2f095..d475a1972d 100644
--- a/tensorflow/core/framework/register_types_traits.h
+++ b/tensorflow/core/framework/register_types_traits.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
-#define TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
+#define TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
// This file is used by cuda code and must remain compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -102,4 +102,4 @@ struct proxy_type {
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_REGISTER_TYPES_TRAITS_H_
diff --git a/tensorflow/core/framework/resource_mgr.cc b/tensorflow/core/framework/resource_mgr.cc
index 0a19861efd..508a8d3149 100644
--- a/tensorflow/core/framework/resource_mgr.cc
+++ b/tensorflow/core/framework/resource_mgr.cc
@@ -271,7 +271,7 @@ string ContainerInfo::DebugString() const {
"]");
}
-ResourceHandle HandleFromInput(OpKernelContext* ctx, int input) {
+const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input) {
return ctx->input(input).flat<ResourceHandle>()(0);
}
@@ -288,4 +288,13 @@ Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
return ctx->resource_manager()->Delete(p);
}
+Status ResourceHandlesShape(shape_inference::InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ for (int i = 0; i < n; ++i) {
+ c->set_output(i, c->Scalar());
+ }
+ return Status::OK();
+}
+
} // end namespace tensorflow
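ResourceHandlesShape is the N-output counterpart of the usual scalar resource-handle shape function: it reads the N attr and marks each of the N outputs as a scalar. A hypothetical op registration paired with the ResourceHandlesOp kernel introduced in resource_mgr.h below might use it like this (the op name is made up):

REGISTER_OP("MyResourceHandlesOp")        // illustrative only
    .Attr("N: int >= 1")
    .Attr("containers: list(string)")
    .Attr("shared_names: list(string)")
    .Output("handles: N * resource")
    .SetShapeFn(ResourceHandlesShape);    // every output is a scalar handle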
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index 33d4cb77ff..4a531648d9 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -13,9 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
-#define TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
+#include <memory>
#include <string>
#include <typeindex>
#include <typeinfo>
@@ -61,8 +62,8 @@ namespace tensorflow {
//
// // Create a var.
// MyVar* my_var = new MyVar;
-// my_var.val = Tensor(DT_FLOAT, my_shape);
-// my_var.val.flat<float>().setZeros(); // 0 initialized.
+// my_var->val = Tensor(DT_FLOAT, my_shape);
+// my_var->val.flat<float>().setZeros(); // 0 initialized.
// ctx->SetStatus(rm.Create("my_container", "my_name", my_var));
//
// // += a variable.
@@ -79,7 +80,7 @@ class ResourceBase : public core::RefCounted {
virtual string DebugString() = 0;
// Returns memory used by this resource.
- virtual int64 MemoryUsed() const { return 0; };
+ virtual int64 MemoryUsed() const { return 0; }
};
// Container used for per-step resources.
@@ -127,6 +128,14 @@ class ResourceMgr {
Status Lookup(const string& container, const string& name,
T** resource) const TF_MUST_USE_RESULT;
+ // Similar to Lookup, but looks up multiple resources at once, with only a
+ // single lock acquisition.
+ template <typename T>
+ Status LookupMany(absl::Span<std::pair<const string*, const string*> const>
+ containers_and_names,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>*
+ resource) const TF_MUST_USE_RESULT;
+
// If "container" has a resource "name", returns it in
// "*resource". Otherwise, invokes creator() to create the resource.
// The caller takes the ownership of one ref on "*resource".
@@ -234,19 +243,36 @@ ResourceHandle MakePerStepResourceHandle(OpKernelContext* ctx,
const string& name);
// Returns a resource handle from a numbered op input.
-ResourceHandle HandleFromInput(OpKernelContext* ctx, int input);
+const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);
Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
ResourceHandle* handle);
// Create a resource pointed by a given resource handle.
+//
+// If successful, the caller transfers the ownership of one ref on `resource` to
+// `ctx->resource_mgr()`.
template <typename T>
Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
// Looks up a resource pointed by a given resource handle.
+//
+// If the lookup is successful, the caller takes the ownership of one ref on
+// `*value`, and must call its `Unref()` method when it has finished using it.
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
+// Looks up multiple resources pointed by a sequence of resource handles.
+template <typename T>
+Status LookupResources(
+ OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values);
+
// Looks up or creates a resource.
+//
+// If successful, the caller takes the ownership of one ref on `*value`, and
+// must call its `Unref()` method when it has finished using it. If the
+// `creator` is invoked, its reference on the created resource is transferred
+// to `ctx->resource_mgr()`.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator);
@@ -348,6 +374,8 @@ class ResourceHandleOp : public OpKernel {
void Compute(OpKernelContext* ctx) override;
+ bool IsExpensive() override { return false; }
+
private:
string container_;
string name_;
@@ -356,6 +384,26 @@ class ResourceHandleOp : public OpKernel {
std::atomic<bool> initialized_{false};
};
+// Utility op kernel to produce a handle to a resource of type T.
+template <typename T>
+class ResourceHandlesOp : public OpKernel {
+ public:
+ explicit ResourceHandlesOp(OpKernelConstruction* context);
+
+ void Compute(OpKernelContext* ctx) override;
+
+ bool IsExpensive() override { return false; }
+
+ private:
+ std::vector<string> containers_;
+ std::vector<string> names_;
+ mutex mutex_;
+ std::vector<Tensor> resources_;
+ std::atomic<bool> initialized_{false};
+};
+
+Status ResourceHandlesShape(shape_inference::InferenceContext* c);
+
// Registers a kernel for an op which produces a handle to a resource of the
// specified type.
#define REGISTER_RESOURCE_HANDLE_KERNEL(Type) \
@@ -388,6 +436,24 @@ Status ResourceMgr::Lookup(const string& container, const string& name,
}
template <typename T>
+Status ResourceMgr::LookupMany(
+ absl::Span<std::pair<const string*, const string*> const>
+ containers_and_names,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>* resources) const {
+ CheckDeriveFromResourceBase<T>();
+ tf_shared_lock l(mu_);
+ resources->resize(containers_and_names.size());
+ for (size_t i = 0; i < containers_and_names.size(); ++i) {
+ T* resource;
+ TF_RETURN_IF_ERROR(LookupInternal(*containers_and_names[i].first,
+ *containers_and_names[i].second,
+ &resource));
+ (*resources)[i].reset(resource);
+ }
+ return Status::OK();
+}
+
+template <typename T>
Status ResourceMgr::LookupInternal(const string& container, const string& name,
T** resource) const {
ResourceBase* found = nullptr;
@@ -497,6 +563,19 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
}
template <typename T>
+Status LookupResources(
+ OpKernelContext* ctx, absl::Span<ResourceHandle const* const> p,
+ std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values) {
+ std::vector<std::pair<const string*, const string*>> containers_and_names(
+ p.size());
+ for (size_t i = 0; i < p.size(); ++i) {
+ TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, *p[i]));
+ containers_and_names[i] = {&p[i]->container(), &p[i]->name()};
+ }
+ return ctx->resource_manager()->LookupMany(containers_and_names, values);
+}
+
+template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
@@ -553,6 +632,46 @@ void ResourceHandleOp<T>::Compute(OpKernelContext* ctx) {
ctx->set_output(0, resource_);
}
+template <typename T>
+ResourceHandlesOp<T>::ResourceHandlesOp(OpKernelConstruction* context)
+ : OpKernel(context) {
+ int n;
+ OP_REQUIRES_OK(context, context->GetAttr("N", &n));
+ OP_REQUIRES_OK(context, context->GetAttr("containers", &containers_));
+ OP_REQUIRES_OK(context, context->GetAttr("shared_names", &names_));
+ OP_REQUIRES(
+ context, containers_.size() == n,
+ errors::InvalidArgument("Number of containers (", containers_.size(),
+ ") must be equal to N (", n, ")"));
+ OP_REQUIRES(context, names_.size() == n,
+ errors::InvalidArgument("Number of names (", names_.size(),
+ ") must be equal to N (", n, ")"));
+ resources_.resize(n);
+}
+
+template <typename T>
+void ResourceHandlesOp<T>::Compute(OpKernelContext* ctx) {
+ if (!initialized_.load()) {
+ mutex_lock ml(mutex_);
+ // Checking again to see if another thread has initialized the resource.
+ if (!initialized_.load()) {
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ for (size_t i = 0; i < resources_.size(); ++i) {
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}),
+ &resources_[i], attr));
+ ResourceHandle h =
+ MakeResourceHandle<T>(ctx, containers_[i], names_[i]);
+ resources_[i].template scalar<ResourceHandle>()() = h;
+ }
+ initialized_.store(true);
+ }
+ }
+ for (size_t i = 0; i < resources_.size(); ++i) {
+ ctx->set_output(i, resources_[i]);
+ }
+}
+
} // end namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_RESOURCE_MGR_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_MGR_H_
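LookupResources and ResourceMgr::LookupMany resolve a batch of handles under a single lock acquisition instead of one LookupResource call per handle. A hedged sketch of a Compute body that consumes N handle inputs; MyResource (a ResourceBase subclass) and the input layout are assumptions, not code from this patch:

// Inside a hypothetical OpKernel::Compute(OpKernelContext* ctx) whose inputs
// are all DT_RESOURCE handles referring to MyResource objects.
std::vector<const ResourceHandle*> handles;
handles.reserve(ctx->num_inputs());
for (int i = 0; i < ctx->num_inputs(); ++i) {
  // HandleFromInput now returns a const reference, so taking its address is
  // cheap; the handle tensors stay alive for the duration of Compute.
  handles.push_back(&HandleFromInput(ctx, i));
}

std::vector<std::unique_ptr<MyResource, core::RefCountDeleter>> resources;
OP_REQUIRES_OK(ctx, LookupResources<MyResource>(ctx, handles, &resources));

// Each unique_ptr owns one ref on its resource and releases it automatically
// when `resources` goes out of scope.
for (const auto& resource : resources) {
  // ... use *resource ...
}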
diff --git a/tensorflow/core/framework/resource_op_kernel.h b/tensorflow/core/framework/resource_op_kernel.h
index 0a8da8b3bf..fbcd439dea 100644
--- a/tensorflow/core/framework/resource_op_kernel.h
+++ b/tensorflow/core/framework/resource_op_kernel.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_
-#define TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_
#include <string>
@@ -136,4 +136,4 @@ class ResourceOpKernel : public OpKernel {
};
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_RESOURCE_OP_KERNEL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_OP_KERNEL_H_
diff --git a/tensorflow/core/framework/run_handler.cc b/tensorflow/core/framework/run_handler.cc
new file mode 100644
index 0000000000..0c4007eafc
--- /dev/null
+++ b/tensorflow/core/framework/run_handler.cc
@@ -0,0 +1,249 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/run_handler.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/run_handler_util.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+
+// Contains the concrete implementation of the RunHandler.
+// Externally visible RunHandler class simply forwards the work to this one.
+class RunHandler::Impl {
+ public:
+ explicit Impl(RunHandlerPool::Impl* pool_impl) : pool_impl_(pool_impl) {
+ Reset();
+ }
+
+ ~Impl() {}
+
+ void set_inter_op_scheduling_range(std::uint_fast32_t start,
+ std::uint_fast32_t limit) {
+ inter_op_scheduling_range_.store(EncodePartition(start, limit),
+ std::memory_order_release);
+ }
+
+ std::uint_fast32_t inter_op_scheduling_range() const {
+ return inter_op_scheduling_range_.load(std::memory_order_acquire);
+ }
+
+ // Returns the time (in microseconds since the Unix epoch) at which the
+ // handler was requested via RunHandlerPool::Get().
+ uint64 start_time_us() const { return start_time_us_; }
+
+ void ScheduleInterOpClosure(std::function<void()> fn);
+
+ void Reset();
+
+ RunHandlerPool::Impl* pool_impl() { return pool_impl_; }
+
+ private:
+ // Encoding/decoding logic for storing [start, limit) into a single
+ // uint_fast32_t int. We assume that pool_num_threads < (1 << 16).
+ const int kMaxPartitionBits = 16;
+ const int kMaxThreads = 1 << kMaxPartitionBits;
+
+ std::uint_fast32_t EncodePartition(std::uint_fast32_t start,
+ std::uint_fast32_t limit) {
+ return (start << kMaxPartitionBits) | limit;
+ }
+
+ void DecodePartition(std::uint_fast32_t val, std::uint_fast32_t* start,
+ std::uint_fast32_t* limit) {
+ *limit = val & (kMaxThreads - 1);
+ val >>= kMaxPartitionBits;
+ *start = val;
+ }
+
+ std::atomic_uint_fast32_t inter_op_scheduling_range_;
+ RunHandlerPool::Impl* pool_impl_; // NOT OWNED.
+ uint64 start_time_us_;
+};
+
+// Contains shared state across all run handlers present in the pool. Also
+// responsible for pool management decisions.
+// This class is thread safe.
+class RunHandlerPool::Impl {
+ public:
+ explicit Impl(int num_inter_op_threads)
+ : max_handlers_(128),
+ inter_op_thread_pool_(new thread::ThreadPool(
+ Env::Default(), ThreadOptions(), "inter_op", num_inter_op_threads)),
+ iterations_(0) {
+ VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
+ for (int i = 0; i < max_handlers_; ++i) {
+ handlers_.emplace_back(new RunHandler::Impl(this));
+ free_handlers_.push_back(handlers_.back().get());
+ }
+ }
+
+ ~Impl() {
+ // Sanity check that all handlers have been returned to the pool before
+ // destruction.
+ DCHECK_EQ(handlers_.size(), max_handlers_);
+ DCHECK_EQ(free_handlers_.size(), handlers_.size());
+ DCHECK_EQ(sorted_active_handlers_.size(), 0);
+ }
+
+ thread::ThreadPool* inter_op_thread_pool() const {
+ return inter_op_thread_pool_.get();
+ }
+
+ std::unique_ptr<RunHandler> Get() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ while (free_handlers_.empty()) {
+ one_handler_free_.wait(l);
+ }
+ // Remove the last entry from free_handlers_ and add to the end of
+ // sorted_active_handlers_.
+ auto* handler_impl = free_handlers_.back();
+ handler_impl->Reset();
+ // Sortedness isn't violated if we simply add at the end of the list, since
+ // handlers are expected to be obtained in increasing order of time.
+ sorted_active_handlers_.push_back(handler_impl);
+ DCHECK_LE(sorted_active_handlers_.size(), max_handlers_);
+ free_handlers_.pop_back();
+
+ RecomputePoolStatsLocked();
+ return WrapUnique<RunHandler>(new RunHandler(handler_impl));
+ }
+
+ void ReleaseHandler(RunHandler::Impl* handler) LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ DCHECK_GT(sorted_active_handlers_.size(), 0);
+
+ uint64 now = tensorflow::Env::Default()->NowMicros();
+ double elapsed = (now - handler->start_time_us()) / 1000.0;
+ time_hist_.Add(elapsed);
+
+ // Erase from and update sorted_active_handlers_. Add it to the end of
+ // free_handlers_.
+ auto iter = std::find(sorted_active_handlers_.begin(),
+ sorted_active_handlers_.end(), handler);
+ DCHECK(iter != sorted_active_handlers_.end())
+ << "Unexpected handler: " << handler
+ << " is being requested for release";
+
+ // Remove this handler from this list and add it to the list of free
+ // handlers.
+ sorted_active_handlers_.erase(iter);
+ free_handlers_.push_back(handler);
+ DCHECK_LE(free_handlers_.size(), max_handlers_);
+
+ RecomputePoolStatsLocked();
+ }
+ one_handler_free_.notify_one();
+ }
+
+ private:
+ void RecomputePoolStatsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Maximum number of handlers pre-created at pool construction time. The
+ // number is chosen on the assumption that each handler will want at least
+ // one inter-op thread for execution (during compute-intensive workloads
+ // such as inference).
+ const int max_handlers_;
+
+ // Thread safe part.
+ const std::unique_ptr<thread::ThreadPool> inter_op_thread_pool_;
+
+ // Thread compatible part used only by lock under RunHandlerPool.
+ // Handlers are sorted by start time.
+ std::vector<RunHandler::Impl*> sorted_active_handlers_ GUARDED_BY(mu_);
+ std::vector<RunHandler::Impl*> free_handlers_ GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ GUARDED_BY(mu_);
+ // Histogram of elapsed runtime of every handler (in ms).
+ histogram::Histogram time_hist_ GUARDED_BY(mu_);
+ std::vector<std::uint_fast32_t> inter_op_start_ GUARDED_BY(mu_);
+ std::vector<std::uint_fast32_t> inter_op_limit_ GUARDED_BY(mu_);
+ int64 iterations_ GUARDED_BY(mu_);
+ condition_variable one_handler_free_;
+ mutex mu_;
+};
+
+void RunHandlerPool::Impl::RecomputePoolStatsLocked() {
+ int num_active_requests = sorted_active_handlers_.size();
+ if (num_active_requests == 0) return;
+
+ int num_threads = inter_op_thread_pool_->NumThreads();
+
+ inter_op_start_.resize(num_active_requests);
+ inter_op_limit_.resize(num_active_requests);
+
+ const int kMinThreadsPerRequest = 3;
+ ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
+ kMinThreadsPerRequest, &inter_op_start_,
+ &inter_op_limit_);
+
+ for (int i = 0; i < num_active_requests; ++i) {
+ sorted_active_handlers_[i]->set_inter_op_scheduling_range(
+ inter_op_start_[i], inter_op_limit_[i]);
+ }
+
+ if (iterations_++ % 5000 == 0 && VLOG_IS_ON(1)) {
+ VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
+ VLOG(1) << "Active session runs: " << num_active_requests;
+ uint64 now = tensorflow::Env::Default()->NowMicros();
+ string ranges_str = "";
+ string times_str = "";
+ for (int i = 0; i < num_active_requests; ++i) {
+ if (i > 0) {
+ times_str += " ";
+ ranges_str += " ";
+ }
+
+ times_str += strings::StrCat(
+ (now - sorted_active_handlers_[i]->start_time_us()) / 1000.0, " ms.");
+ ranges_str += strings::StrCat("[", inter_op_start_[i], ", ",
+ inter_op_limit_[i], ")");
+ }
+ VLOG(1) << "Elapsed times are: " << times_str;
+ VLOG(1) << "Ranges are: " << ranges_str;
+ }
+}
+
+void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
+ std::uint_fast32_t start = 0, limit = 0;
+ DecodePartition(inter_op_scheduling_range(), &start, &limit);
+ pool_impl_->inter_op_thread_pool()->Schedule(std::move(fn));
+}
+
+void RunHandler::Impl::Reset() {
+ set_inter_op_scheduling_range(
+ 0, pool_impl_->inter_op_thread_pool()->NumThreads());
+ start_time_us_ = tensorflow::Env::Default()->NowMicros();
+}
+
+RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
+ : impl_(new Impl(num_inter_op_threads)) {}
+
+RunHandlerPool::~RunHandlerPool() {}
+
+std::unique_ptr<RunHandler> RunHandlerPool::Get() { return impl_->Get(); }
+
+RunHandler::RunHandler(Impl* impl) : impl_(impl) {}
+
+void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
+ impl_->ScheduleInterOpClosure(std::move(fn));
+}
+
+RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/run_handler.h b/tensorflow/core/framework/run_handler.h
new file mode 100644
index 0000000000..72fa6301b4
--- /dev/null
+++ b/tensorflow/core/framework/run_handler.h
@@ -0,0 +1,95 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
+
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+
+class RunHandler;
+
+// RunHandlerPool is a fixed size pool of pre-allocated RunHandlers
+// that can be used for tracking inter-op work for a given Session::Run().
+// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes
+// 'active' when its unique_ptr is returned by Get() and is being used by a
+// client. It becomes 'inactive' once more when its unique_ptr gets destroyed.
+//
+// Expected usage:
+//
+// * Create a single RunHandlerPool (say run_handler_pool_).
+//
+// * When a Session::Run() is invoked, obtain a handler by:
+// auto handler = run_handler_pool_->Get();
+//
+// * Use handler for scheduling all inter-op work by:
+// handler->ScheduleInterOpClosure(closure);
+//
+// This class is thread safe.
+class RunHandlerPool {
+ public:
+ explicit RunHandlerPool(int num_inter_op_threads);
+ ~RunHandlerPool();
+
+ // Returns an inactive RunHandler from the pool.
+ //
+ // RunHandlers in RunHandlerPool are initially 'inactive'.
+ // A RunHandler becomes 'active' when its unique_ptr is returned by Get()
+ // and is being used by a client. It becomes 'inactive' once more when the
+ // unique_ptr is destroyed.
+ //
+ // Will block unless there is an inactive handler.
+ std::unique_ptr<RunHandler> Get();
+
+ private:
+ class Impl;
+ friend class RunHandler;
+
+ std::unique_ptr<Impl> impl_;
+};
+
+// RunHandler can be used to schedule inter-op closures to run on a global pool
+// shared across all Session::Run(s).
+//
+// It can only be created via RunHandlerPool::Get().
+//
+// This class can be used instead of directly scheduling closures on a global
+// pool since it maintains a global view across all sessions and optimizes pool
+// scheduling to improve (median and tail) latency.
+//
+// This class is thread safe.
+class RunHandler {
+ public:
+ void ScheduleInterOpClosure(std::function<void()> fn);
+
+ ~RunHandler();
+
+ private:
+ class Impl;
+ friend class RunHandlerPool::Impl;
+
+ explicit RunHandler(Impl* impl);
+
+ Impl* impl_; // NOT OWNED.
+};
+
+} // end namespace tensorflow.
+
+#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
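Tying the header comment together, a session-level integration might look roughly like the sketch below; the pool size, the static lifetime, and the closure body are placeholders rather than code from this patch:

// One pool shared across all Session::Run() calls.
static RunHandlerPool* run_handler_pool =
    new RunHandlerPool(/*num_inter_op_threads=*/16);

void RunOneStep() {
  // Blocks until one of the pre-allocated handlers becomes free.
  std::unique_ptr<RunHandler> handler = run_handler_pool->Get();

  // Route every inter-op closure of this step through the handler so the pool
  // can partition inter-op threads across concurrently running steps.
  handler->ScheduleInterOpClosure([]() { /* run one ready op */ });

  // Destroying `handler` returns it to the pool and marks it inactive again.
}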
diff --git a/tensorflow/core/framework/run_handler_util.cc b/tensorflow/core/framework/run_handler_util.cc
new file mode 100644
index 0000000000..3087998c69
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util.cc
@@ -0,0 +1,57 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/run_handler_util.h"
+
+#include <algorithm>
+#include <cmath>
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads,
+ int min_threads_per_request,
+ std::vector<std::uint_fast32_t>* start_vec,
+ std::vector<std::uint_fast32_t>* end_vec) {
+ // Each request is expected to have weight W[i] = num_active_requests - i.
+ // Therefore, total_weight = sum of all request weights.
+ float total_weight = 0.5f * num_active_requests * (num_active_requests + 1);
+ float demand_factor = static_cast<float>(num_threads) / total_weight;
+ float last_cumulative_weight = 0.0;
+ min_threads_per_request = std::max(1, min_threads_per_request);
+ for (int i = 0; i != num_active_requests; i++) {
+ float cumulative_weight =
+ static_cast<float>(i + 1) *
+ (num_active_requests - static_cast<float>(i) * 0.5f);
+ float weight = cumulative_weight - last_cumulative_weight;
+ // Quantize thread_demand by rounding up, while also satisfying the
+ // `min_threads_per_request` constraint.
+ // Note: We subtract a small epsilon (0.00001) to prevent ceil(..) from
+ // rounding weights like 4.0 to 5.
+ int demand =
+ std::max(min_threads_per_request,
+ static_cast<int>(ceil(weight * demand_factor - 0.00001f)));
+ // For the quantized range [start, end): take the floor of the real start,
+ // extend it to length `demand`, and shift it back where needed to stay
+ // within the [0, num_threads) boundaries.
+ int start = last_cumulative_weight * demand_factor;
+ int end = std::min(num_threads, start + demand);
+ start = std::max(0, std::min(start, end - demand));
+ start_vec->at(i) = start;
+ end_vec->at(i) = end;
+ last_cumulative_weight = cumulative_weight;
+ }
+}
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/run_handler_util.h b/tensorflow/core/framework/run_handler_util.h
new file mode 100644
index 0000000000..c0c36aeccb
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util.h
@@ -0,0 +1,43 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
+
+#include <cstdint>
+#include <vector>
+
+namespace tensorflow {
+
+// Assign thread ranges to requests.
+// Requests are numbered 0...num_active_requests-1, and
+// threads are numbered 0...num_threads-1.
+// On return, the range start_vec->at(i)...end_vec->at(i)-1
+// indicates the subrange of the threads available to request i.
+// The ranges given to different requests may overlap.
+// Lower numbered requests will tend to be assigned more threads.
+// Thus, a client might associate older requests with lower
+// array indices so they receive access to more threads.
+// However, the routine ensures that each request is given access
+// to at least min(min_threads_per_request, num_threads) threads.
+// Every thread will be assigned to at least one request range,
+// assuming there is at least one request.
+void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads,
+ int min_threads_per_request,
+ std::vector<std::uint_fast32_t>* start_vec,
+ std::vector<std::uint_fast32_t>* end_vec);
+
+} // end namespace tensorflow
+#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
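As a concrete illustration of the weighting scheme (worked by hand from run_handler_util.cc above, so treat the exact numbers as illustrative): with 3 active requests, 6 threads and min_threads_per_request = 1, the per-request weights are 3, 2, 1, so total_weight = 6 and demand_factor = 1.

std::vector<std::uint_fast32_t> start(3), end(3);
ComputeInterOpSchedulingRanges(/*num_active_requests=*/3, /*num_threads=*/6,
                               /*min_threads_per_request=*/1, &start, &end);
// Hand-computed expectation:
//   request 0 (weight 3) -> threads [0, 3)
//   request 1 (weight 2) -> threads [3, 5)
//   request 2 (weight 1) -> threads [5, 6)
// Older (lower-index) requests get more threads and the ranges jointly cover
// all six threads, matching the properties checked in run_handler_util_test.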
diff --git a/tensorflow/core/framework/run_handler_util_test.cc b/tensorflow/core/framework/run_handler_util_test.cc
new file mode 100644
index 0000000000..a1928c132b
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util_test.cc
@@ -0,0 +1,93 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/run_handler_util.h"
+
+#include <vector>
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+namespace tensorflow {
+namespace {
+
+void VerifyFunction(int num_active_requests, int num_threads,
+ int min_threads_per_request, bool print_stats = false) {
+ if (print_stats) {
+ LOG(INFO) << "Test case# num_active_requests: " << num_active_requests
+ << " num_threads: " << num_threads
+ << " min_threads: " << min_threads_per_request;
+ }
+ std::vector<std::uint_fast32_t> start(num_active_requests);
+ std::vector<std::uint_fast32_t> end(num_active_requests);
+
+ ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
+ min_threads_per_request, &start, &end);
+ string range_str = "";
+ for (int i = 0; i < num_active_requests; ++i) {
+ if (i > 0) range_str += " ";
+ range_str += strings::StrCat("[", start[i], ", ", end[i], ")");
+
+ ASSERT_GE(start[i], 0) << range_str;
+ ASSERT_LE(end[i], num_threads) << range_str;
+ if (i > 0) {
+ // Due to linearly decreasing demand, #threads(i - 1) >= #threads(i)
+ ASSERT_GE(end[i - 1] - start[i - 1], end[i] - start[i]) << range_str;
+ // No missing threads.
+ ASSERT_GE(end[i - 1], start[i]) << range_str;
+ }
+ // Each interval is at least of size 'min_threads_per_request'.
+ ASSERT_GE((end[i] - start[i]), min_threads_per_request) << range_str;
+    // Verify that the assigned (quantized) thread count does not
+    // overestimate the real demand by much when the demand is high
+    // (>= min_threads_per_request).
+ float entry_weight = num_active_requests - i;
+ float total_weight = 0.5f * num_active_requests * (num_active_requests + 1);
+ float thread_demand = (entry_weight * num_threads) / total_weight;
+ if (thread_demand > min_threads_per_request) {
+ // We expect some over-estimation of threads due to quantization,
+ // but we hope it's not more than 1 extra thread.
+ ASSERT_NEAR(end[i] - start[i], thread_demand, 1.0)
+ << "Ranges: " << range_str << " thread_demand: " << thread_demand
+ << " i: " << i;
+ }
+ }
+ ASSERT_EQ(end[num_active_requests - 1], num_threads);
+ ASSERT_EQ(start[0], 0);
+ if (print_stats) {
+ LOG(INFO) << "Assigned ranges: " << range_str;
+ }
+}
+
+TEST(RunHandlerUtilTest, TestComputeInterOpSchedulingRanges) {
+ const int kMinThreadsPerRequestBound = 12;
+ const int kMaxActiveRequests = 128;
+ const int kMaxThreads = 128;
+
+ for (int min_threads_per_request = 1;
+ min_threads_per_request <= kMinThreadsPerRequestBound;
+ ++min_threads_per_request) {
+ for (int num_active_requests = 1; num_active_requests <= kMaxActiveRequests;
+ ++num_active_requests) {
+ for (int num_threads = min_threads_per_request;
+ num_threads <= kMaxThreads; ++num_threads) {
+ VerifyFunction(num_active_requests, num_threads,
+ min_threads_per_request);
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace tensorflow
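
Editor's note: to make the quantization bound checked in VerifyFunction concrete, here is one worked case; the numbers are chosen for illustration and do not add a new test.

// num_active_requests = 4, num_threads = 8, i = 0 (oldest request):
//   entry_weight  = 4 - 0           = 4
//   total_weight  = 0.5 * 4 * (4+1) = 10
//   thread_demand = 4 * 8 / 10      = 3.2
// Whenever 3.2 exceeds min_threads_per_request, the test expects
// end[0] - start[0] to lie within 1.0 of 3.2, i.e. 3 or 4 threads.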
diff --git a/tensorflow/core/framework/selective_registration.h b/tensorflow/core/framework/selective_registration.h
index 503947969d..4b281a04bf 100644
--- a/tensorflow/core/framework/selective_registration.h
+++ b/tensorflow/core/framework/selective_registration.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_
-#define TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_SELECTIVE_REGISTRATION_H_
+#define TENSORFLOW_CORE_FRAMEWORK_SELECTIVE_REGISTRATION_H_
#include <string.h>
@@ -55,4 +55,4 @@ static_assert(false, "ops_to_register.h must define SHOULD_REGISTER macros");
#define SHOULD_REGISTER_OP_KERNEL(clz) true
#endif
-#endif // TENSORFLOW_FRAMEWORK_SELECTIVE_REGISTRATION_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_SELECTIVE_REGISTRATION_H_
diff --git a/tensorflow/core/framework/session_state.h b/tensorflow/core/framework/session_state.h
index 653a661dd2..63568685f2 100644
--- a/tensorflow/core/framework/session_state.h
+++ b/tensorflow/core/framework/session_state.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_SESSION_STATE_H_
-#define TENSORFLOW_FRAMEWORK_SESSION_STATE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
#include <string>
#include <unordered_map>
@@ -90,4 +90,4 @@ class TensorStore {
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_SESSION_STATE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_SESSION_STATE_H_
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 8d597e198d..3e77028a5f 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -950,8 +950,7 @@ Status InferenceContext::GetScalarFromTensor(const Tensor* t, int64* val) {
*val = t->scalar<int64>()();
return Status::OK();
} else {
- return errors::InvalidArgument(
- "Scalar input for dim size must be int32 or int64");
+ return errors::InvalidArgument("Scalar input must be int32 or int64.");
}
}
diff --git a/tensorflow/core/framework/shape_inference_testutil.h b/tensorflow/core/framework/shape_inference_testutil.h
index f6656b3b45..bb4dc25da4 100644
--- a/tensorflow/core/framework/shape_inference_testutil.h
+++ b/tensorflow/core/framework/shape_inference_testutil.h
@@ -32,7 +32,7 @@ class Tensor;
struct ShapeInferenceTestOp {
typedef std::pair<string, DataType> ShapeAndType;
- explicit ShapeInferenceTestOp(StringPiece name) : name(std::string(name)) {}
+ explicit ShapeInferenceTestOp(StringPiece name) : name(string(name)) {}
string name;
NodeDef node_def;
std::vector<const Tensor*> input_tensors;
diff --git a/tensorflow/core/framework/stats_aggregator.h b/tensorflow/core/framework/stats_aggregator.h
index 4a18efc940..af53ed0a3c 100644
--- a/tensorflow/core/framework/stats_aggregator.h
+++ b/tensorflow/core/framework/stats_aggregator.h
@@ -25,6 +25,8 @@ namespace tensorflow {
class Summary;
+namespace data {
+
// A `StatsAggregator` accumulates statistics incrementally. A
// `StatsAggregator` can accumulate multiple different statistics, distinguished
// by a string name.
@@ -87,6 +89,7 @@ class StatsAggregatorResource : public ResourceBase {
const std::shared_ptr<StatsAggregator> stats_aggregator_;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_FRAMEWORK_STATS_AGGREGATOR_H_
diff --git a/tensorflow/core/framework/tensor.cc b/tensorflow/core/framework/tensor.cc
index a82beb7e8f..1dea6da911 100644
--- a/tensorflow/core/framework/tensor.cc
+++ b/tensorflow/core/framework/tensor.cc
@@ -617,13 +617,13 @@ bool Tensor::IsInitialized() const {
}
void Tensor::CheckType(DataType expected_dtype) const {
- CHECK_EQ(dtype(), expected_dtype)
+ CHECK_EQ(dtype(), expected_dtype) << " "
<< DataTypeString(expected_dtype) << " expected, got "
<< DataTypeString(dtype());
}
void Tensor::CheckTypeAndIsAligned(DataType expected_dtype) const {
- CHECK_EQ(dtype(), expected_dtype)
+ CHECK_EQ(dtype(), expected_dtype) << " "
<< DataTypeString(expected_dtype) << " expected, got "
<< DataTypeString(dtype());
CHECK(IsAligned()) << "ptr = " << base<void>();
@@ -812,6 +812,28 @@ Tensor Tensor::Slice(int64 start, int64 limit) const {
return ret;
}
+Tensor Tensor::SubSlice(int64 index) const {
+ CHECK_GE(dims(), 1); // Crash ok.
+ CHECK_LE(0, index); // Crash ok.
+ int64 dim0_size = shape_.dim_size(0);
+ CHECK_LE(index, dim0_size); // Crash ok.
+ Tensor ret;
+ ret.shape_ = shape_;
+ ret.shape_.RemoveDim(0);
+ ret.set_dtype(dtype());
+ ret.buf_ = nullptr;
+ if (dim0_size > 0) {
+ const int64 elems_per_dim0 = NumElements() / dim0_size;
+ const int64 delta = index * elems_per_dim0;
+ const int64 num_elems = elems_per_dim0;
+ if (buf_) {
+ DataType dt = dtype();
+ CASES(dt, ret.buf_ = new SubBuffer<T>(buf_, delta, num_elems));
+ }
+ }
+ return ret;
+}
+
bool Tensor::FromProto(const TensorProto& proto) {
return FromProto(cpu_allocator(), proto);
}
@@ -948,9 +970,69 @@ void PrintOneDim(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
}
}
+// Appends the spacing between elements for a given dim onto a result string
+void PrintDimSpacing(int dim_index, int num_dims, string* result) {
+ if (dim_index == num_dims - 1) {
+ strings::StrAppend(result, " ");
+ return;
+ }
+ for (int j = 0; j < num_dims - dim_index - 1; j++) {
+ strings::StrAppend(result, "\n");
+ }
+ for (int j = 0; j <= dim_index; j++) {
+ strings::StrAppend(result, " ");
+ }
+}
+
+// Print from left dim to right dim recursively.
+template <typename T>
+void PrintOneDimV2(int dim_index, const gtl::InlinedVector<int64, 4>& shape,
+ int64 num_elts_at_ends, int num_dims, const T* data,
+ int64 data_index, string* result) {
+ // We have recursed beyond all the dimensions into a single element
+ // of the tensor.
+ if (dim_index == num_dims) {
+ strings::StrAppend(result, PrintOneElement(data[data_index]));
+ return;
+ }
+
+ strings::StrAppend(result, "[");
+ int64 element_count = shape[dim_index];
+ int64 start_of_end =
+ std::max(num_elts_at_ends, element_count - num_elts_at_ends);
+
+  // Loop over every element of this dim.
+ int64 elements_per_iter = 1;
+ for (int i = dim_index + 1; i < num_dims; i++) {
+ elements_per_iter *= shape[i];
+ }
+ for (int64 i = 0; (i < num_elts_at_ends) && (i < element_count); i++) {
+ if (i > 0) {
+ PrintDimSpacing(dim_index, num_dims, result);
+ }
+
+    // For each leading element, print its sub-dim recursively.
+ PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+ data_index + elements_per_iter * i, result);
+ }
+ if (element_count > 2 * num_elts_at_ends) {
+ PrintDimSpacing(dim_index, num_dims, result);
+ strings::StrAppend(result, "...");
+ }
+ for (int64 i = start_of_end; i < element_count; i++) {
+    // For each trailing element, print its sub-dim recursively.
+ PrintDimSpacing(dim_index, num_dims, result);
+ PrintOneDimV2(dim_index + 1, shape, num_elts_at_ends, num_dims, data,
+ data_index + elements_per_iter * i, result);
+ }
+
+ strings::StrAppend(result, "]");
+}
+
template <typename T>
string SummarizeArray(int64 limit, int64 num_elts,
- const TensorShape& tensor_shape, const char* data) {
+ const TensorShape& tensor_shape, const char* data,
+ const bool print_v2) {
string ret;
const T* array = reinterpret_cast<const T*>(data);
@@ -963,17 +1045,26 @@ string SummarizeArray(int64 limit, int64 num_elts,
if (num_elts > limit) strings::StrAppend(&ret, "...");
return ret;
}
- int64 data_index = 0;
- const int shape_size = tensor_shape.dims();
- PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
+ if (print_v2) {
+ const int num_dims = tensor_shape.dims();
+ PrintOneDimV2(0, shape, limit, num_dims, array, 0, &ret);
+ } else {
+ int64 data_index = 0;
+ const int shape_size = tensor_shape.dims();
+ PrintOneDim(0, shape, limit, shape_size, array, &data_index, &ret);
+
+ if (num_elts > limit) strings::StrAppend(&ret, "...");
+ }
- if (num_elts > limit) strings::StrAppend(&ret, "...");
return ret;
}
} // namespace
-string Tensor::SummarizeValue(int64 max_entries) const {
+string Tensor::SummarizeValue(int64 max_entries, bool print_v2) const {
const int64 num_elts = NumElements();
+ if (max_entries < 0) {
+ max_entries = num_elts;
+ }
size_t limit = std::min(max_entries, num_elts);
if ((limit > 0) && (buf_ == nullptr)) {
return strings::StrCat("uninitialized Tensor of ", num_elts,
@@ -982,50 +1073,54 @@ string Tensor::SummarizeValue(int64 max_entries) const {
const char* data = limit > 0 ? tensor_data().data() : nullptr;
switch (dtype()) {
case DT_HALF:
- return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data);
+ return SummarizeArray<Eigen::half>(limit, num_elts, shape_, data,
+ print_v2);
break;
case DT_FLOAT:
- return SummarizeArray<float>(limit, num_elts, shape_, data);
+ return SummarizeArray<float>(limit, num_elts, shape_, data, print_v2);
break;
case DT_DOUBLE:
- return SummarizeArray<double>(limit, num_elts, shape_, data);
+ return SummarizeArray<double>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT32:
- return SummarizeArray<uint32>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint32>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT32:
- return SummarizeArray<int32>(limit, num_elts, shape_, data);
+ return SummarizeArray<int32>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT8:
case DT_QUINT8:
- return SummarizeArray<uint8>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint8>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT16:
case DT_QUINT16:
- return SummarizeArray<uint16>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint16>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT16:
case DT_QINT16:
- return SummarizeArray<int16>(limit, num_elts, shape_, data);
+ return SummarizeArray<int16>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT8:
case DT_QINT8:
- return SummarizeArray<int8>(limit, num_elts, shape_, data);
+ return SummarizeArray<int8>(limit, num_elts, shape_, data, print_v2);
break;
case DT_UINT64:
- return SummarizeArray<uint64>(limit, num_elts, shape_, data);
+ return SummarizeArray<uint64>(limit, num_elts, shape_, data, print_v2);
break;
case DT_INT64:
- return SummarizeArray<int64>(limit, num_elts, shape_, data);
+ return SummarizeArray<int64>(limit, num_elts, shape_, data, print_v2);
break;
case DT_BOOL:
// TODO(tucker): Is it better to emit "True False..."? This
// will emit "1 0..." which is more compact.
- return SummarizeArray<bool>(limit, num_elts, shape_, data);
+ return SummarizeArray<bool>(limit, num_elts, shape_, data, print_v2);
break;
default: {
// All irregular cases
string ret;
+ if (print_v2) {
+ strings::StrAppend(&ret, "[");
+ }
// TODO(irving): Don't call flat every time around this
// loop.
for (size_t i = 0; i < limit; ++i) {
@@ -1045,6 +1140,9 @@ string Tensor::SummarizeValue(int64 max_entries) const {
}
}
if (max_entries < num_elts) strings::StrAppend(&ret, "...");
+ if (print_v2) {
+ strings::StrAppend(&ret, "]");
+ }
return ret;
}
}
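
Editor's note: a short sketch of what the new print_v2 path produces; the expected strings mirror the unit tests added in tensor_test.cc later in this patch, nothing beyond them.

// Illustrative only: numpy-style summaries from SummarizeValue(print_v2=true).
#include "tensorflow/core/framework/tensor.h"
using tensorflow::DT_INT32;
using tensorflow::Tensor;
using tensorflow::TensorShape;

void SummarizeValueV2Example() {
  Tensor t(DT_INT32, TensorShape({2, 2}));
  auto flat = t.flat<tensorflow::int32>();
  flat(0) = 1; flat(1) = 2; flat(2) = 3; flat(3) = 4;
  t.SummarizeValue(16, /*print_v2=*/true);   // "[[1 2]\n [3 4]]"
  t.SummarizeValue(-1, /*print_v2=*/true);   // negative limit prints everything
  t.SummarizeValue(16);                      // legacy (v1) format is unchanged
}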
diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h
index 1b19ab5da3..d0f9eb56e2 100644
--- a/tensorflow/core/framework/tensor.h
+++ b/tensorflow/core/framework/tensor.h
@@ -37,11 +37,12 @@ namespace tensorflow {
class AllocationDescription;
class Allocator;
class OpKernelContext;
+class Tensor;
class TensorBuffer;
class TensorCApi;
class TensorDescription;
class TensorProto;
-class VariantTensorData;
+
namespace batch_util {
Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index);
Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
@@ -153,7 +154,7 @@ class Tensor {
/// Returns the estimated memory usage of this tensor.
size_t TotalBytes() const;
- // Returns the size of sallocated memory for this tensor.
+ // Returns the size of allocated memory for this tensor.
size_t AllocatedBytes() const;
/// Returns true iff this tensor is aligned.
@@ -199,10 +200,29 @@ class Tensor {
/// must check the returned tensor's alignment before calling certain
/// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
///
+ /// NOTE: When fed with an N-dimensional tensor, this method returns a tensor
+ /// also with N dimensions. If you want to select a sub tensor, see SubSlice.
+ ///
/// REQUIRES: `dims()` >= 1
/// REQUIRES: `0 <= dim0_start <= dim0_limit <= dim_size(0)`
Tensor Slice(int64 dim0_start, int64 dim0_limit) const;
+ /// \brief Select a subslice from this tensor along the 1st dimension.
+ ///
+ /// When fed with an N-dimensional tensor, this method returns a tensor with
+ /// N-1 dimensions, where the returned tensor is a subslice of the input
+ /// tensor along the first dimension. The N-1 dimensions of the returned
+ /// tensor are the last N-1 dimensions of the input tensor.
+ ///
+ /// NOTE: The returned tensor may not satisfy the same alignment
+ /// requirement as this tensor depending on the shape. The caller
+ /// must check the returned tensor's alignment before calling certain
+ /// methods that have alignment requirement (e.g., `flat()`, `tensor()`).
+ ///
+ /// REQUIRES: `dims()` >= 1
+  /// REQUIRES: `0 <= index < dim_size(0)`
+ Tensor SubSlice(int64 index) const;
+
/// \brief Parse `other` and construct the tensor.
/// Returns `true` iff the parsing succeeds. If the parsing fails,
@@ -429,7 +449,7 @@ class Tensor {
int64 begin) const;
/// Render the first `max_entries` values in `*this` into a string.
- string SummarizeValue(int64 max_entries) const;
+ string SummarizeValue(int64 max_entries, bool print_v2 = false) const;
/// A human-readable summary of the tensor suitable for debugging.
string DebugString() const;
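
Editor's note: a minimal sketch of the new SubSlice accessor, following the contract documented above (including the alignment caveat) and patterned on the SubSlice_Basic test added below.

// Illustrative only: select the i-th subtensor along dim 0.
#include "tensorflow/core/framework/tensor.h"
using tensorflow::DT_FLOAT;
using tensorflow::Tensor;
using tensorflow::TensorShape;

void SubSliceExample() {
  Tensor x(DT_FLOAT, TensorShape({10, 4, 36}));
  x.flat<float>().setConstant(0.f);
  Tensor y = x.SubSlice(5);              // shape {4, 36}, shares x's buffer
  if (y.IsAligned()) {
    auto m = y.tensor<float, 2>();       // safe once alignment is checked
    m(0, 0) = 1.f;
  } else {
    y.unaligned_flat<float>()(0) = 1.f;  // fall back to unaligned access
  }
}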
diff --git a/tensorflow/core/framework/tensor_slice.h b/tensorflow/core/framework/tensor_slice.h
index 6019737342..82f21fb17e 100644
--- a/tensorflow/core/framework/tensor_slice.h
+++ b/tensorflow/core/framework/tensor_slice.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
-#define TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -221,4 +221,4 @@ void TensorSlice::FillIndicesAndSizes(
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_TENSOR_SLICE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SLICE_H_
diff --git a/tensorflow/core/framework/tensor_test.cc b/tensorflow/core/framework/tensor_test.cc
index 84a373c196..c596604143 100644
--- a/tensorflow/core/framework/tensor_test.cc
+++ b/tensorflow/core/framework/tensor_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/math/math_util.h"
@@ -1227,6 +1228,45 @@ TEST(Tensor, Slice_Basic) {
}
}
+TEST(Tensor, SubSlice_Basic) {
+ { // General
+ Tensor x(DT_FLOAT, TensorShape({10, 4, 36}));
+ // Fills in known values.
+ for (int i = 0; i < 10; ++i) {
+ x.SubSlice(i).flat<float>().setConstant(i * 1.f);
+ }
+ // A simple sub-slice along dim0.
+ Tensor y = x.SubSlice(5);
+ EXPECT_TRUE(y.shape().IsSameSize(TensorShape({4, 36})));
+ auto tx = x.tensor<float, 3>();
+ auto ty = y.tensor<float, 2>();
+ for (int j = 0; j < 4; ++j) {
+ for (int k = 0; k < 36; ++k) {
+ EXPECT_EQ(ty(j, k), 5.0);
+ EXPECT_EQ(&tx(5, j, k), &ty(j, k));
+ }
+ }
+ Tensor z = y.SubSlice(3).SubSlice(31);
+ auto tz = z.unaligned_flat<float>();
+ EXPECT_EQ(*tz.data(), 5.0);
+ }
+ {
+ // Test unaligned access via a SubSlice.
+ Tensor x(DT_FLOAT, TensorShape({30, 5}));
+ x.flat<float>().setConstant(0.0);
+
+ // Take an unaligned subslice.
+ Tensor y = x.SubSlice(1);
+#if EIGEN_MAX_ALIGN_BYTES > 0
+ EXPECT_FALSE(y.IsAligned());
+#endif
+ y.unaligned_flat<float>().setConstant(1.0);
+ for (int64 i = 0; i < y.NumElements(); ++i) {
+ EXPECT_EQ(1.0, y.unaligned_flat<float>()(i));
+ }
+ }
+}
+
template <typename T>
Tensor MkTensor(DataType dt, const TensorShape& shape,
std::vector<T> init_values) {
@@ -1294,6 +1334,63 @@ TEST(SummarizeValue, STRING) {
EXPECT_EQ("one two three four five one...", x.SummarizeValue(6));
}
+TEST(SummarizeValue, INT32_PRINT_V2) {
+ Tensor x = MkTensor<int>(DT_INT32, TensorShape({5}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+ EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]",
+ x.SummarizeValue(16, true));
+ x = MkTensor<int>(DT_INT32, TensorShape({0}), {});
+ EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, INT32Dims_PRINT_V2) {
+ Tensor x = MkTensor<int>(DT_INT32, TensorShape({3, 4}),
+ {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12});
+ EXPECT_EQ("[[1 ... 4]\n ...\n [9 ... 12]]", x.SummarizeValue(1, true));
+ EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+ x.SummarizeValue(10, true));
+ EXPECT_EQ("[[1 2 3 4]\n [5 6 7 8]\n [9 10 11 12]]",
+ x.SummarizeValue(-1, true));
+}
+
+TEST(SummarizeValue, FLOAT_PRINT_V2) {
+ Tensor x = MkTensor<float>(DT_FLOAT, TensorShape({5}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[1 2 3 4 0]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[1 2 ... 4 0]", x.SummarizeValue(2, true));
+ EXPECT_EQ("[1 ... 0]", x.SummarizeValue(1, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[1 2]\n [3 4]]", x.SummarizeValue(16, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({2, 2, 1, 1}), {1, 2, 3, 4, 0});
+ EXPECT_EQ("[[[[1]]\n\n [[2]]]\n\n\n [[[3]]\n\n [[4]]]]",
+ x.SummarizeValue(16, true));
+ x = MkTensor<float>(DT_FLOAT, TensorShape({0}), {});
+ EXPECT_EQ("[]", x.SummarizeValue(16, true));
+}
+
+TEST(SummarizeValue, BOOL_PRINT_V2) {
+ Tensor x = MkTensor<bool>(DT_BOOL, TensorShape({5}), {false, true, true});
+ EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[0 1 1 0 1]", x.SummarizeValue(-1, true));
+ EXPECT_EQ("[0 1 ... 0 1]", x.SummarizeValue(2, true));
+}
+
+TEST(SummarizeValue, STRING_PRINT_V2) {
+ Tensor x = MkTensor<string>(DT_STRING, TensorShape({5}),
+ {"one", "two", "three", "four", "five"});
+ EXPECT_EQ("[one two three four five]", x.SummarizeValue(16, true));
+ EXPECT_EQ("[one two three four five]", x.SummarizeValue(-1, true));
+ x = MkTensor<string>(DT_STRING, TensorShape({5, 1, 5}),
+ {"one", "two", "three", "four", "five"});
+ EXPECT_EQ("[one two three four five one...]", x.SummarizeValue(6, true));
+}
+
void BM_CreateAndDestroy(int iters) {
TensorShape shape({10, 20});
while (--iters) {
diff --git a/tensorflow/core/framework/tensor_types.h b/tensorflow/core/framework/tensor_types.h
index a5c1a56bfc..6f981db189 100644
--- a/tensorflow/core/framework/tensor_types.h
+++ b/tensorflow/core/framework/tensor_types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
-#define TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -123,4 +123,4 @@ To32Bit(TensorType in) {
}
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_TENSOR_TYPES_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_TYPES_H_
diff --git a/tensorflow/core/framework/tensor_util.h b/tensorflow/core/framework/tensor_util.h
index 43d2d95311..a7cf600bab 100644
--- a/tensorflow/core/framework/tensor_util.h
+++ b/tensorflow/core/framework/tensor_util.h
@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
-#define TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include <vector>
@@ -160,4 +161,4 @@ CreateTensorProto(const std::vector<Type>& values,
} // namespace tensor
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_TENSOR_UTIL_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TENSOR_UTIL_H_
diff --git a/tensorflow/core/framework/tracking_allocator.h b/tensorflow/core/framework/tracking_allocator.h
index 661c28969e..5eafce662e 100644
--- a/tensorflow/core/framework/tracking_allocator.h
+++ b/tensorflow/core/framework/tracking_allocator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
-#define TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TRACKING_ALLOCATOR_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TRACKING_ALLOCATOR_H_
#include <unordered_map>
#include "tensorflow/core/framework/allocator.h"
@@ -130,4 +130,4 @@ class TrackingAllocator : public Allocator {
} // end namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_TRACKING_ALLOCATOR_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TRACKING_ALLOCATOR_H_
diff --git a/tensorflow/core/framework/type_index.h b/tensorflow/core/framework/type_index.h
index b978d90fa8..989fc42e26 100644
--- a/tensorflow/core/framework/type_index.h
+++ b/tensorflow/core/framework/type_index.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TYPE_INDEX_H_
-#define TENSORFLOW_FRAMEWORK_TYPE_INDEX_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TYPE_INDEX_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TYPE_INDEX_H_
#include <string>
#if defined(__GXX_RTTI) || defined(_CPPRTTI)
@@ -84,4 +84,4 @@ inline TypeIndex MakeTypeIndex() {
#endif // __GXX_RTTI
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_TYPE_INDEX_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TYPE_INDEX_H_
diff --git a/tensorflow/core/framework/type_traits.h b/tensorflow/core/framework/type_traits.h
index e8351e494f..96fbf92938 100644
--- a/tensorflow/core/framework/type_traits.h
+++ b/tensorflow/core/framework/type_traits.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
-#define TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TYPE_TRAITS_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TYPE_TRAITS_H_
#include <limits>
#include <utility>
@@ -106,4 +106,4 @@ struct is_signed<tensorflow::qint32> : public is_signed<tensorflow::int32> {};
} // namespace std
-#endif // TENSORFLOW_FRAMEWORK_TYPE_TRAITS_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TYPE_TRAITS_H_
diff --git a/tensorflow/core/framework/types.h b/tensorflow/core/framework/types.h
index ff7c9855d6..2e96b05787 100644
--- a/tensorflow/core/framework/types.h
+++ b/tensorflow/core/framework/types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_TYPES_H_
-#define TENSORFLOW_FRAMEWORK_TYPES_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
+#define TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
#include <map>
#include <set>
@@ -30,7 +30,6 @@ limitations under the License.
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/resource_handle.h"
#include "tensorflow/core/framework/types.pb.h"
-#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -39,6 +38,8 @@ limitations under the License.
namespace tensorflow {
+class Variant;
+
// MemoryType is used to describe whether input or output Tensors of
// an OpKernel should reside in "Host memory" (e.g., CPU memory) or
// "Device" Memory (CPU memory for CPU devices, GPU memory for GPU
@@ -481,4 +482,4 @@ bool DataTypeAlwaysOnHost(DataType dt);
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_TYPES_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_TYPES_H_
diff --git a/tensorflow/core/framework/variant.cc b/tensorflow/core/framework/variant.cc
index 5a507804b0..d43e3c72ec 100644
--- a/tensorflow/core/framework/variant.cc
+++ b/tensorflow/core/framework/variant.cc
@@ -23,11 +23,11 @@ limitations under the License.
namespace tensorflow {
-bool Variant::TryDecode(Variant* out) const {
- const VariantTensorDataProto* p = get<VariantTensorDataProto>();
- if (p == nullptr) return false;
- VariantTensorData data(*p);
- return out->Decode(data);
+bool Variant::Decode(VariantTensorData data) {
+ if (!is_empty()) {
+ return value_->Decode(std::move(data));
+ }
+ return true;
}
template <>
@@ -54,13 +54,12 @@ string TypeNameVariant(const VariantTensorDataProto& value) {
template <>
void EncodeVariant(const VariantTensorDataProto& value,
VariantTensorData* data) {
- data->FromProto(value);
+ data->FromConstProto(value);
}
template <>
-bool DecodeVariant(const VariantTensorData& data,
- VariantTensorDataProto* value) {
- data.ToProto(value);
+bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value) {
+ data->ToProto(value);
return true;
}
@@ -70,8 +69,8 @@ void EncodeVariant(const VariantTensorDataProto& value, string* buf) {
}
template <>
-bool DecodeVariant(const string& buf, VariantTensorDataProto* value) {
- return value->ParseFromString(buf);
+bool DecodeVariant(string* buf, VariantTensorDataProto* value) {
+ return value->ParseFromString(*buf);
}
void EncodeVariantList(const Variant* variant_array, int64 n,
@@ -93,8 +92,10 @@ bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
if (variant_array[i].is_empty()) {
variant_array[i] = VariantTensorDataProto();
}
+ // TODO(ebrevdo): Replace with StringPiece? Any way to make this a
+ // zero-copy operation that keeps a reference to the data in d?
string str(d->Data(sizes[i]), sizes[i]);
- if (!variant_array[i].Decode(str)) return false;
+ if (!variant_array[i].Decode(std::move(str))) return false;
if (!DecodeUnaryVariant(&variant_array[i])) {
LOG(ERROR) << "Could not decode variant with type_name: \""
<< variant_array[i].TypeName()
diff --git a/tensorflow/core/framework/variant.h b/tensorflow/core/framework/variant.h
index c02391dae3..10eabbc85f 100644
--- a/tensorflow/core/framework/variant.h
+++ b/tensorflow/core/framework/variant.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_VARIANT_H_
-#define TENSORFLOW_FRAMEWORK_VARIANT_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
+#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
#include <functional>
#include <iostream>
@@ -23,7 +23,6 @@ limitations under the License.
#include <unordered_map>
#include <utility>
-#include "tensorflow/core/framework/tensor.pb.h" // TODO(b/62899350): Remove
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/core/status.h"
@@ -38,17 +37,19 @@ string TypeNameVariant(const T& value);
template <typename T>
string DebugStringVariant(const T& value);
+// Allows for specializations of Variant decoding. `data` may be modified
+// while being decoded into `value`.
template <typename T>
-void EncodeVariant(const T& value, VariantTensorData* data);
+bool DecodeVariant(VariantTensorData* data, T* value);
template <typename T>
-bool DecodeVariant(const VariantTensorData& data, T* value);
+bool DecodeVariant(string* buf, T* value);
template <typename T>
-void EncodeVariant(const T& value, string* buf);
+void EncodeVariant(const T& value, VariantTensorData* data);
template <typename T>
-bool DecodeVariant(const string& buf, T* value);
+void EncodeVariant(const T& value, string* buf);
// This is an implementation of a type-erased container that can store an
// object of any type. The implementation is very similar to std::any, but has
@@ -67,7 +68,7 @@ bool DecodeVariant(const string& buf, T* value);
//
// string TypeName() const;
// void Encode(VariantTensorData* data) const;
-// void Decode(const VariantTensorData& data);
+// void Decode(VariantTensorData data);
//
// Simple POD types can elide the Encode/Decode functions, they are provided by
// helper methods.
@@ -121,7 +122,7 @@ bool DecodeVariant(const string& buf, T* value);
// x.Encode(&serialized_f);
//
// Variant y = Foo(); // default constructed Foo.
-// y.Decode(&serialized_f);
+// y.Decode(std::move(serialized_f));
// EXPECT_EQ(*x.get<Foo>(), *y.get<Foo>());
//
//
@@ -145,10 +146,6 @@ bool DecodeVariant(const string& buf, T* value);
// EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo.
// EXPECT_EQ(MakeTypeIndex<VariantTensorDataProto>(),
// y_type_unknown.TypeId());
-// // Decode and get y_type_unknown; compare to value in x.
-// Foo f_decoded;
-// EXPECT_TRUE(x.MaybeDecodeAndCopy(&f_decoded));
-// EXPECT_EQ(f_decoded, f);
//
class Variant {
public:
@@ -241,12 +238,7 @@ class Variant {
}
// Deserialize `data` and update the stored object.
- bool Decode(const VariantTensorData& data) {
- if (!is_empty()) {
- return value_->Decode(data);
- }
- return true;
- }
+ bool Decode(VariantTensorData data);
// Helper methods to directly serialize/deserialize from strings.
void Encode(string* buf) const {
@@ -254,31 +246,13 @@ class Variant {
value_->Encode(buf);
}
}
- bool Decode(const string& buf) {
+ bool Decode(string buf) {
if (!is_empty()) {
- return value_->Decode(buf);
+ return value_->Decode(std::move(buf));
}
return true;
}
- template <typename T>
- bool MaybeDecodeAndCopy(T* out) const {
- const T* ret = get<T>();
- if (ret != nullptr) {
- *out = std::move(*ret);
- return true;
- };
- Variant decoded = T();
- if (!TryDecode(&decoded)) return false;
- T* decoded_ret = decoded.get<T>();
- CHECK_NOTNULL(decoded_ret);
- *out = std::move(*decoded_ret);
- return true;
- }
-
- private:
- bool TryDecode(Variant* out) const;
-
private:
struct in_place_t {};
static constexpr in_place_t in_place{};
@@ -292,9 +266,9 @@ class Variant {
virtual string TypeName() const = 0;
virtual string DebugString() const = 0;
virtual void Encode(VariantTensorData* data) const = 0;
- virtual bool Decode(const VariantTensorData& data) = 0;
+ virtual bool Decode(VariantTensorData data) = 0;
virtual void Encode(string* buf) const = 0;
- virtual bool Decode(const string& data) = 0;
+ virtual bool Decode(string data) = 0;
};
template <typename T>
@@ -325,15 +299,13 @@ class Variant {
EncodeVariant(value, data);
}
- bool Decode(const VariantTensorData& data) override {
- return DecodeVariant(data, &value);
+ bool Decode(VariantTensorData data) override {
+ return DecodeVariant(&data, &value);
}
void Encode(string* buf) const override { EncodeVariant(value, buf); }
- bool Decode(const string& buf) override {
- return DecodeVariant(buf, &value);
- }
+ bool Decode(string buf) override { return DecodeVariant(&buf, &value); }
T value;
};
@@ -351,4 +323,4 @@ const void* Variant::get() const;
} // end namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_VARIANT_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_H_
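
Editor's note: under the updated interface, a user-defined type stored in a Variant implements the method set listed in the class comment above, with Decode now taking VariantTensorData by value. A minimal sketch follows; MyMetadataValue is hypothetical and assumes VariantTensorData's set_metadata/get_metadata helpers.

// Illustrative only: a type usable with the new Variant Decode signature.
struct MyMetadataValue {
  string TypeName() const { return "MyMetadataValue"; }
  void Encode(VariantTensorData* data) const { data->set_metadata(value); }
  bool Decode(VariantTensorData data) {   // taken by value, may be consumed
    return data.get_metadata(&value);
  }
  int64 value = 0;
};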
diff --git a/tensorflow/core/framework/variant_encode_decode.h b/tensorflow/core/framework/variant_encode_decode.h
index ded04b2a30..5e08e5a7a6 100644
--- a/tensorflow/core/framework/variant_encode_decode.h
+++ b/tensorflow/core/framework/variant_encode_decode.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
-#define TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
+#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
#include <iostream>
#include <type_traits>
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/abi.h"
@@ -81,7 +82,7 @@ void EncodeVariantImpl(const T& value,
// Specialization for POD type
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, true /* is_pod */, false /* Tensor */,
false /* protobuf */>,
T* value) {
@@ -90,7 +91,7 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for tensorflow::Tensor
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, true /* Tensor */,
false /* protobuf */>,
T* value) {
@@ -100,7 +101,7 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for protobuf
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
true /* protobuf */>,
T* value) {
@@ -111,11 +112,11 @@ bool DecodeVariantImpl(const VariantTensorData& data,
// Specialization for other types
template <typename T>
-bool DecodeVariantImpl(const VariantTensorData& data,
+bool DecodeVariantImpl(VariantTensorData data,
TypeResolver<T, false /* is_pod */, false /* Tensor */,
false /* protobuf */>,
T* value) {
- return value->Decode(data);
+ return value->Decode(std::move(data));
}
template <typename C, typename = void>
@@ -224,8 +225,8 @@ void EncodeVariant(const T& value, VariantTensorData* data) {
}
template <typename T>
-bool DecodeVariant(const VariantTensorData& data, T* value) {
- return DecodeVariantImpl(data, TypeResolver<T>(), value);
+bool DecodeVariant(VariantTensorData* data, T* value) {
+ return DecodeVariantImpl(std::move(*data), TypeResolver<T>(), value);
}
template <typename T>
@@ -238,26 +239,31 @@ void EncodeVariant(const T& value, string* buf) {
}
template <typename T>
-bool DecodeVariant(const string& buf, T* value) {
+bool DecodeVariant(string* buf, T* value) {
VariantTensorData data;
- if (!data.ParseFromString(buf)) return false;
- if (!DecodeVariantImpl(data, TypeResolver<T>(), value)) return false;
+ if (!data.ParseFromString(*buf)) return false;
+ if (!DecodeVariantImpl(std::move(data), TypeResolver<T>(), value)) {
+ return false;
+ }
return true;
}
// Specializations for VariantTensorDataProto
template <>
string TypeNameVariant(const VariantTensorDataProto& value);
+
template <>
void EncodeVariant(const VariantTensorDataProto& value,
VariantTensorData* data);
+
template <>
-bool DecodeVariant(const VariantTensorData& data,
- VariantTensorDataProto* value);
+bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value);
+
template <>
void EncodeVariant(const VariantTensorDataProto& value, string* buf);
+
template <>
-bool DecodeVariant(const string& buf, VariantTensorDataProto* value);
+bool DecodeVariant(string* buf, VariantTensorDataProto* value);
// Encodes an array of Variant objects in to the given StringListEncoder.
// `variant_array` is assumed to point to an array of `n` Variant objects.
@@ -271,4 +277,4 @@ bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
} // end namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_ENCODE_DECODE_H_
diff --git a/tensorflow/core/framework/variant_op_copy_test.cc b/tensorflow/core/framework/variant_op_copy_test.cc
index 60fa7bd559..daa744e877 100644
--- a/tensorflow/core/framework/variant_op_copy_test.cc
+++ b/tensorflow/core/framework/variant_op_copy_test.cc
@@ -90,15 +90,15 @@ REGISTER_UNARY_VARIANT_DECODE_FUNCTION(StoredTensorValue, "StoredTensorValue");
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
- "StoredTensorValue", StoredTensorValue::CopyCPUToGPU);
+ StoredTensorValue::CopyCPUToGPU);
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_HOST,
- "StoredTensorValue", StoredTensorValue::CopyGPUToCPU);
+ StoredTensorValue::CopyGPUToCPU);
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
StoredTensorValue, VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
- "StoredTensorValue", StoredTensorValue::CopyGPUToGPU);
+ StoredTensorValue::CopyGPUToGPU);
REGISTER_OP("CreateTestVariant")
.Input("input: T")
diff --git a/tensorflow/core/framework/variant_op_registry.cc b/tensorflow/core/framework/variant_op_registry.cc
index ee07db1aee..ef5b240aea 100644
--- a/tensorflow/core/framework/variant_op_registry.cc
+++ b/tensorflow/core/framework/variant_op_registry.cc
@@ -38,21 +38,19 @@ UnaryVariantOpRegistry* UnaryVariantOpRegistry::Global() {
}
UnaryVariantOpRegistry::VariantShapeFn* UnaryVariantOpRegistry::GetShapeFn(
- StringPiece type_name) {
- auto found = shape_fns.find(type_name);
+ const TypeIndex& type_index) {
+ auto found = shape_fns.find(type_index);
if (found == shape_fns.end()) return nullptr;
return &found->second;
}
-void UnaryVariantOpRegistry::RegisterShapeFn(const string& type_name,
+void UnaryVariantOpRegistry::RegisterShapeFn(const TypeIndex& type_index,
const VariantShapeFn& shape_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantShape";
- VariantShapeFn* existing = GetShapeFn(type_name);
+ VariantShapeFn* existing = GetShapeFn(type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantShapeFn for type_name: " << type_name
- << " already registered";
- shape_fns.insert(std::pair<StringPiece, VariantShapeFn>(
- GetPersistentStringPiece(type_name), shape_fn));
+ << "Unary VariantShapeFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name()) << " already registered";
+ shape_fns.insert(std::pair<TypeIndex, VariantShapeFn>(type_index, shape_fn));
}
Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
@@ -60,11 +58,11 @@ Status GetUnaryVariantShape(const Tensor& variant_tensor, TensorShape* shape) {
CHECK_EQ(variant_tensor.dims(), 0);
const Variant& v = variant_tensor.scalar<Variant>()();
UnaryVariantOpRegistry::VariantShapeFn* shape_fn =
- UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeName());
+ UnaryVariantOpRegistry::Global()->GetShapeFn(v.TypeId());
if (shape_fn == nullptr) {
return errors::Internal(
- "No unary variant shape function found for Variant type_name: ",
- v.TypeName());
+ "No unary variant shape function found for Variant type_index: ",
+ port::MaybeAbiDemangle(v.TypeId().name()));
}
return (*shape_fn)(v, shape);
}
@@ -79,7 +77,7 @@ Status ScalarShape(const T&, TensorShape* shape) {
} // namespace
#define REGISTER_VARIANT_SHAPE_TYPE(T) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, TF_STR(T), ScalarShape<T>);
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, ScalarShape<T>);
// No encode/shape registered for std::complex<> and Eigen::half
// objects yet.
@@ -143,25 +141,24 @@ REGISTER_VARIANT_DECODE_TYPE(double);
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn*
UnaryVariantOpRegistry::GetDeviceCopyFn(
- const VariantDeviceCopyDirection direction, StringPiece type_name) {
- auto found = device_copy_fns.find(std::make_pair(direction, type_name));
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
+ auto found = device_copy_fns.find(std::make_pair(direction, type_index));
if (found == device_copy_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterDeviceCopyFn(
- const VariantDeviceCopyDirection direction, const string& type_name,
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
const AsyncVariantDeviceCopyFn& device_copy_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDeviceCopy";
- AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_name);
+ AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
CHECK_EQ(existing, nullptr)
<< "UnaryVariantDeviceCopy for direction: " << direction
- << " and type_name: " << type_name << " already registered";
+ << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
+ << " already registered";
device_copy_fns.insert(
- std::pair<std::pair<VariantDeviceCopyDirection, StringPiece>,
- AsyncVariantDeviceCopyFn>(
- std::make_pair(direction, GetPersistentStringPiece(type_name)),
- device_copy_fn));
+ std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
+ AsyncVariantDeviceCopyFn>(std::make_pair(direction, type_index),
+ device_copy_fn));
}
Status VariantDeviceCopy(
@@ -170,35 +167,34 @@ Status VariantDeviceCopy(
const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
- from.TypeName());
+ from.TypeId());
if (device_copy_fn == nullptr) {
return errors::Internal(
"No unary variant device copy function found for direction: ",
- direction, " and Variant type_name: ", from.TypeName());
+ direction, " and Variant type_index: ",
+ port::MaybeAbiDemangle(from.TypeId().name()));
}
return (*device_copy_fn)(from, to, copy_fn);
}
// Special casing UnaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantUnaryOpFn* UnaryVariantOpRegistry::GetUnaryOpFn(
- VariantUnaryOp op, StringPiece device, StringPiece type_name) {
- auto found = unary_op_fns.find({op, device, type_name});
+ VariantUnaryOp op, StringPiece device, const TypeIndex& type_index) {
+ auto found = unary_op_fns.find({op, device, type_index});
if (found == unary_op_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterUnaryOpFn(
- VariantUnaryOp op, const string& device, const string& type_name,
+ VariantUnaryOp op, const string& device, const TypeIndex& type_index,
const VariantUnaryOpFn& unary_op_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantUnaryOp";
- VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_name);
+ VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantUnaryOpFn for type_name: " << type_name
+ << "Unary VariantUnaryOpFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name())
<< " already registered for device type: " << device;
unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
- {op, GetPersistentStringPiece(device),
- GetPersistentStringPiece(type_name)},
- unary_op_fn));
+ {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
}
namespace {
@@ -212,7 +208,7 @@ Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T) \
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
- DEVICE_CPU, T, TF_STR(T), \
+ DEVICE_CPU, T, \
ZerosLikeVariantPrimitiveType<T>);
// No zeros_like registered for std::complex<> or Eigen::half objects yet.
@@ -226,24 +222,22 @@ REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);
// Special casing BinaryOpFn per op and per device.
UnaryVariantOpRegistry::VariantBinaryOpFn*
UnaryVariantOpRegistry::GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
- StringPiece type_name) {
- auto found = binary_op_fns.find({op, device, type_name});
+ const TypeIndex& type_index) {
+ auto found = binary_op_fns.find({op, device, type_index});
if (found == binary_op_fns.end()) return nullptr;
return &found->second;
}
void UnaryVariantOpRegistry::RegisterBinaryOpFn(
- VariantBinaryOp op, const string& device, const string& type_name,
+ VariantBinaryOp op, const string& device, const TypeIndex& type_index,
const VariantBinaryOpFn& add_fn) {
- CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantBinaryOp";
- VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_name);
+ VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
CHECK_EQ(existing, nullptr)
- << "Unary VariantBinaryOpFn for type_name: " << type_name
+ << "Unary VariantBinaryOpFn for type_index: "
+ << port::MaybeAbiDemangle(type_index.name())
<< " already registered for device type: " << device;
binary_op_fns.insert(std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
- {op, GetPersistentStringPiece(device),
- GetPersistentStringPiece(type_name)},
- add_fn));
+ {op, GetPersistentStringPiece(device), type_index}, add_fn));
}
namespace {
@@ -257,8 +251,7 @@ Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
#define REGISTER_VARIANT_ADD_TYPE(T) \
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
- T, TF_STR(T), \
- AddVariantPrimitiveType<T>);
+ T, AddVariantPrimitiveType<T>);
// No add registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_ADD_TYPE(int);
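
Editor's note: with the registry keyed on TypeIndex, the registration macros drop the explicit TF_STR(T) type-name argument. A sketch of the new call shape, patterned on REGISTER_VARIANT_ADD_TYPE above; MyPodType and AddMyPodType are hypothetical names.

// Illustrative only: the registry key is now the TypeIndex derived from the
// C++ type itself, and error messages use its demangled name.
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
                                          MyPodType, AddMyPodType);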
diff --git a/tensorflow/core/framework/variant_op_registry.h b/tensorflow/core/framework/variant_op_registry.h
index c9e8dd2217..7eb37e859f 100644
--- a/tensorflow/core/framework/variant_op_registry.h
+++ b/tensorflow/core/framework/variant_op_registry.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
-#define TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
+#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
#include <string>
#include <unordered_set>
@@ -22,10 +22,14 @@ limitations under the License.
#define EIGEN_USE_THREADS
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
+#include "tensorflow/core/platform/abi.h"
namespace tensorflow {
@@ -90,10 +94,11 @@ class UnaryVariantOpRegistry {
AsyncVariantDeviceCopyFn;
// Add a shape lookup function to the registry.
- void RegisterShapeFn(const string& type_name, const VariantShapeFn& shape_fn);
+ void RegisterShapeFn(const TypeIndex& type_index,
+ const VariantShapeFn& shape_fn);
- // Returns nullptr if no shape function was found for the given TypeName.
- VariantShapeFn* GetShapeFn(StringPiece type_name);
+ // Returns nullptr if no shape function was found for the given TypeIndex.
+ VariantShapeFn* GetShapeFn(const TypeIndex& type_index);
// Add a decode function to the registry.
void RegisterDecodeFn(const string& type_name,
@@ -104,33 +109,33 @@ class UnaryVariantOpRegistry {
// Add a copy-to-GPU function to the registry.
void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
- const string& type_name,
+ const TypeIndex& type_index,
const AsyncVariantDeviceCopyFn& device_copy_fn);
// Returns nullptr if no copy function was found for the given
// TypeName and direction.
AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
- const VariantDeviceCopyDirection direction, StringPiece type_name);
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index);
// Add a unary op function to the registry.
void RegisterUnaryOpFn(VariantUnaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const VariantUnaryOpFn& unary_op_fn);
// Returns nullptr if no unary op function was found for the given
// op, device, and TypeName.
VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
- StringPiece type_name);
+ const TypeIndex& type_index);
// Add a binary op function to the registry.
void RegisterBinaryOpFn(VariantBinaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const VariantBinaryOpFn& add_fn);
// Returns nullptr if no binary op function was found for the given
// op, device and TypeName.
VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
- StringPiece type_name);
+ const TypeIndex& type_index);
// Get a pointer to a global UnaryVariantOpRegistry object
static UnaryVariantOpRegistry* Global();
@@ -145,24 +150,26 @@ class UnaryVariantOpRegistry {
static std::unordered_set<string>* PersistentStringStorage();
private:
- std::unordered_map<StringPiece, VariantShapeFn, StringPieceHasher> shape_fns;
- std::unordered_map<StringPiece, VariantDecodeFn, StringPieceHasher>
- decode_fns;
+ struct TypeIndexHash {
+ std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
+ };
+
+ gtl::FlatMap<TypeIndex, VariantShapeFn, TypeIndexHash> shape_fns;
+ gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;
// Map std::pair<Direction, type_name> to function.
struct PairHash {
template <typename Direction>
- std::size_t operator()(const std::pair<Direction, StringPiece>& x) const {
+ std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
- ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
+ ret = Hash64Combine(ret, std::get<1>(x).hash_code());
return ret;
}
- StringPieceHasher sp_hasher_;
};
- std::unordered_map<std::pair<VariantDeviceCopyDirection, StringPiece>,
- AsyncVariantDeviceCopyFn, PairHash>
+ gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
+ AsyncVariantDeviceCopyFn, PairHash>
device_copy_fns;
// Map std::tuple<Op, device, type_name> to function.
@@ -172,10 +179,11 @@ class UnaryVariantOpRegistry {
// and references therein
template <typename Op>
struct FuncTuple {
- FuncTuple(const Op& op, const StringPiece& dev, const StringPiece& tname)
- : op_type_(op), device_(dev), typename_(tname){};
+ FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
+ : op_type_(op), device_(dev), type_index_(type_index) {}
Op op_type_;
- StringPiece device_, typename_;
+ StringPiece device_;
+ TypeIndex type_index_;
};
// friend declaration for operator==
// needed for clang
@@ -184,11 +192,11 @@ class UnaryVariantOpRegistry {
struct TupleHash {
template <typename Op>
std::size_t operator()(
- const std::tuple<Op, StringPiece, StringPiece>& x) const {
+ const std::tuple<Op, StringPiece, TypeIndex>& x) const {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
- ret = Hash64Combine(ret, sp_hasher_(std::get<2>(x)));
+ ret = Hash64Combine(ret, std::get<2>(x).hash_code());
return ret;
}
@@ -197,14 +205,14 @@ class UnaryVariantOpRegistry {
// The hash of an enum is just its value as a std::size_t.
std::size_t ret = static_cast<std::size_t>(x.op_type_);
ret = Hash64Combine(ret, sp_hasher_(x.device_));
- ret = Hash64Combine(ret, sp_hasher_(x.typename_));
+ ret = Hash64Combine(ret, x.type_index_.hash_code());
return ret;
}
StringPieceHasher sp_hasher_;
};
- std::unordered_map<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
+ gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
unary_op_fns;
- std::unordered_map<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
+ gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
binary_op_fns;
// Find or insert a string into a persistent string storage
@@ -225,7 +233,7 @@ template <typename Op>
inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
- (lhs.typename_ == rhs.typename_);
+ (lhs.type_index_ == rhs.type_index_);
}
// Gets a TensorShape from a Tensor containing a scalar Variant.
// Returns an Internal error if the Variant does not have a registered shape
@@ -276,7 +284,7 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
Variant* v_out) {
const string& device = DeviceName<Device>::value;
UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
- UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeName());
+ UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
if (unary_op_fn == nullptr) {
return errors::Internal(
"No unary variant unary_op function found for unary variant op enum: ",
@@ -297,15 +305,15 @@ Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
template <typename Device>
Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
const Variant& a, const Variant& b, Variant* out) {
- if (a.TypeName() != b.TypeName()) {
+ if (a.TypeId() != b.TypeId()) {
return errors::Internal(
"BianryOpVariants: Variants a and b have different "
- "type names: '",
+ "type ids. Type names: '",
a.TypeName(), "' vs. '", b.TypeName(), "'");
}
const string& device = DeviceName<Device>::value;
UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
- UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeName());
+ UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
if (binary_op_fn == nullptr) {
return errors::Internal(
"No unary variant binary_op function found for binary variant op "
@@ -323,16 +331,18 @@ class UnaryVariantShapeRegistration {
public:
typedef std::function<Status(const T& t, TensorShape*)> LocalVariantShapeFn;
- UnaryVariantShapeRegistration(const string& type_name,
+ UnaryVariantShapeRegistration(const TypeIndex& type_index,
const LocalVariantShapeFn& shape_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterShapeFn(
- type_name,
- [type_name, shape_fn](const Variant& v, TensorShape* s) -> Status {
+ type_index,
+ [type_index_name, shape_fn](const Variant& v,
+ TensorShape* s) -> Status {
const T* t = v.get<T>();
if (t == nullptr) {
return errors::Internal(
- "VariantShapeFn: Could not access object, type_name: ",
- type_name);
+ "VariantShapeFn: Could not access object, type_index: ",
+ type_index_name);
}
return shape_fn(*t, s);
});
@@ -355,11 +365,11 @@ class UnaryVariantDecodeRegistration {
return false;
}
Variant decoded = T();
- VariantTensorData data(*t);
- if (!decoded.Decode(data)) {
+ VariantTensorData data(std::move(*t));
+ if (!decoded.Decode(std::move(data))) {
return false;
}
- *v = std::move(decoded);
+ std::swap(decoded, *v);
return true;
});
}
@@ -372,11 +382,12 @@ class UnaryVariantDeviceCopyRegistration {
UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
LocalVariantDeviceCopyFn;
UnaryVariantDeviceCopyRegistration(
- const VariantDeviceCopyDirection direction, const string& type_name,
+ const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
const LocalVariantDeviceCopyFn& device_copy_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
- direction, type_name,
- [type_name, device_copy_fn](
+ direction, type_index,
+ [type_index_name, device_copy_fn](
const Variant& from, Variant* to,
UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
device_copy_tensor_fn) -> Status {
@@ -384,8 +395,8 @@ class UnaryVariantDeviceCopyRegistration {
*to = T();
if (from.get<T>() == nullptr) {
return errors::Internal(
- "VariantCopyToGPUFn: Could not access object, type_name: ",
- type_name);
+ "VariantCopyToGPUFn: Could not access object, type_index: ",
+ type_index_name);
}
const T& t = *from.get<T>();
T* t_out = to->get<T>();
@@ -401,18 +412,19 @@ class UnaryVariantUnaryOpRegistration {
public:
UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const LocalVariantUnaryOpFn& unary_op_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
- op, device, type_name,
- [type_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
- Variant* v_out) -> Status {
+ op, device, type_index,
+ [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
+ Variant* v_out) -> Status {
DCHECK_NE(v_out, nullptr);
*v_out = T();
if (v.get<T>() == nullptr) {
return errors::Internal(
- "VariantUnaryOpFn: Could not access object, type_name: ",
- type_name);
+ "VariantUnaryOpFn: Could not access object, type_index: ",
+ type_index_name);
}
const T& t = *v.get<T>();
T* t_out = v_out->get<T>();
@@ -429,23 +441,25 @@ class UnaryVariantBinaryOpRegistration {
public:
UnaryVariantBinaryOpRegistration(VariantBinaryOp op, const string& device,
- const string& type_name,
+ const TypeIndex& type_index,
const LocalVariantBinaryOpFn& binary_op_fn) {
+ const string type_index_name = port::MaybeAbiDemangle(type_index.name());
UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
- op, device, type_name,
- [type_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
- const Variant& b, Variant* out) -> Status {
+ op, device, type_index,
+ [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
+ const Variant& b,
+ Variant* out) -> Status {
DCHECK_NE(out, nullptr);
*out = T();
if (a.get<T>() == nullptr) {
return errors::Internal(
- "VariantBinaryOpFn: Could not access object 'a', type_name: ",
- type_name);
+ "VariantBinaryOpFn: Could not access object 'a', type_index: ",
+ type_index_name);
}
if (b.get<T>() == nullptr) {
return errors::Internal(
- "VariantBinaryOpFn: Could not access object 'b', type_name: ",
- type_name);
+ "VariantBinaryOpFn: Could not access object 'b', type_index: ",
+ type_index_name);
}
const T& t_a = *a.get<T>();
const T& t_b = *b.get<T>();
@@ -459,19 +473,19 @@ class UnaryVariantBinaryOpRegistration {
// Register a unary shape variant function with the signature:
// Status ShapeFn(const T& t, TensorShape* s);
-// to Variants having TypeName type_name.
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, type_name, shape_function) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name, \
- shape_function)
+// to Variants having TypeIndex type_index.
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(T, shape_function) \
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, T, MakeTypeIndex<T>(), shape_function)
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_name, \
- shape_function) \
- REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, shape_function)
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ_HELPER(ctr, T, type_index, \
+ shape_function) \
+ REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, shape_function)
-#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_name, \
+#define REGISTER_UNARY_VARIANT_SHAPE_FUNCTION_UNIQ(ctr, T, type_index, \
shape_function) \
static variant_op_registry_fn_registration::UnaryVariantShapeRegistration<T> \
- register_unary_variant_op_shape_registration_fn_##ctr(type_name, \
+ register_unary_variant_op_shape_registration_fn_##ctr(type_index, \
shape_function)
// Register a unary decode variant function for the given type.
@@ -519,65 +533,65 @@ class UnaryVariantBinaryOpRegistration {
// ****** NOTE ******
// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
-#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- T, direction, type_name, device_copy_fn) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, T, direction, type_name, device_copy_fn)
+#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction, \
+ device_copy_fn) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, T, direction, MakeTypeIndex<T>(), device_copy_fn)
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
- ctr, T, direction, type_name, device_copy_fn) \
+ ctr, T, direction, type_index, device_copy_fn) \
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
- ctr, T, direction, type_name, device_copy_fn)
+ ctr, T, direction, type_index, device_copy_fn)
-#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
- ctr, T, direction, type_name, device_copy_fn) \
- static variant_op_registry_fn_registration:: \
- UnaryVariantDeviceCopyRegistration<T> \
- register_unary_variant_op_device_copy_fn_##ctr(direction, type_name, \
- device_copy_fn)
+#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
+ ctr, T, direction, type_index, device_copy_fn) \
+ static variant_op_registry_fn_registration:: \
+ UnaryVariantDeviceCopyRegistration<T> \
+ register_unary_variant_op_device_copy_fn_##ctr( \
+ direction, type_index, device_copy_fn)
// Register a unary unary_op variant function with the signature:
// Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
-// to Variants having TypeName type_name, for device string device,
+// to Variants having TypeIndex type_index, for device string device,
// for UnaryVariantOp enum op.
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, type_name, \
- unary_op_function) \
- REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, op, device, T, type_name, unary_op_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, \
+ unary_op_function) \
+ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, op, device, T, MakeTypeIndex<T>(), unary_op_function)
-#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
- ctr, op, device, T, type_name, unary_op_function) \
- REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, type_name, \
- unary_op_function)
+#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
+ ctr, op, device, T, type_index, unary_op_function) \
+ REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
+ type_index, unary_op_function)
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, unary_op_function) \
+ ctr, op, device, T, type_index, unary_op_function) \
static variant_op_registry_fn_registration::UnaryVariantUnaryOpRegistration< \
T> \
- register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
+ register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \
unary_op_function)
// Register a binary_op variant function with the signature:
// Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
-// to Variants having TypeName type_name, for device string device,
+// to Variants having TypeIndex type_index, for device string device,
// for BinaryVariantOp enum OP.
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, type_name, \
- binary_op_function) \
- REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
- __COUNTER__, op, device, T, type_name, binary_op_function)
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, \
+ binary_op_function) \
+ REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
+ __COUNTER__, op, device, T, MakeTypeIndex<T>(), binary_op_function)
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
- ctr, op, device, T, type_name, binary_op_function) \
+ ctr, op, device, T, type_index, binary_op_function) \
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, binary_op_function)
+ ctr, op, device, T, type_index, binary_op_function)
-#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
- ctr, op, device, T, type_name, binary_op_function) \
- static variant_op_registry_fn_registration:: \
- UnaryVariantBinaryOpRegistration<T> \
- register_unary_variant_op_decoder_fn_##ctr(op, device, type_name, \
+#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ( \
+ ctr, op, device, T, type_index, binary_op_function) \
+ static variant_op_registry_fn_registration:: \
+ UnaryVariantBinaryOpRegistration<T> \
+ register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \
binary_op_function)
} // end namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_VARIANT_OP_REGISTRY_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
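To make the new registration surface concrete, here is a minimal sketch of a user-defined Variant payload registering against the TypeIndex-keyed macros above. MyCounter, MyCounterShape and MyCounterZerosLike are illustrative names, not part of this change, and the sketch assumes the POD metadata helpers on VariantTensorData:

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant_op_registry.h"
#include "tensorflow/core/framework/variant_tensor_data.h"

namespace tensorflow {

struct MyCounter {
  int64 value = 0;
  string TypeName() const { return "MyCounter"; }
  void Encode(VariantTensorData* data) const { data->set_metadata(value); }
  bool Decode(VariantTensorData data) { return data.get_metadata(&value); }
};

Status MyCounterShape(const MyCounter& c, TensorShape* s) {
  *s = TensorShape({});  // the payload is logically a scalar
  return Status::OK();
}

Status MyCounterZerosLike(OpKernelContext* ctx, const MyCounter& c,
                          MyCounter* out) {
  out->value = 0;
  return Status::OK();
}

// No type-name string arguments here any more; the registry keys these
// entries on MakeTypeIndex<MyCounter>() internally.
REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(MyCounter, MyCounterShape);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
                                         DEVICE_CPU, MyCounter,
                                         MyCounterZerosLike);
// Decode registration is unchanged and still keys on the TypeName string.
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(MyCounter, "MyCounter");

}  // namespace tensorflow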
diff --git a/tensorflow/core/framework/variant_op_registry_test.cc b/tensorflow/core/framework/variant_op_registry_test.cc
index 7055e62c0e..b2443e8676 100644
--- a/tensorflow/core/framework/variant_op_registry_test.cc
+++ b/tensorflow/core/framework/variant_op_registry_test.cc
@@ -89,41 +89,37 @@ struct VariantValue {
int value;
};
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, "TEST VariantValue",
- VariantValue::ShapeFn);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(VariantValue, VariantValue::ShapeFn);
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(VariantValue, "TEST VariantValue");
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
VariantValue, VariantDeviceCopyDirection::HOST_TO_DEVICE,
- "TEST VariantValue", VariantValue::CPUToGPUCopyFn);
+ VariantValue::CPUToGPUCopyFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, VariantValue,
- "TEST VariantValue",
VariantValue::CPUZerosLikeFn);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, VariantValue,
- "TEST VariantValue",
VariantValue::GPUZerosLikeFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- VariantValue, "TEST VariantValue",
- VariantValue::CPUAddFn);
+ VariantValue, VariantValue::CPUAddFn);
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- VariantValue, "TEST VariantValue",
- VariantValue::GPUAddFn);
+ VariantValue, VariantValue::GPUAddFn);
} // namespace
TEST(VariantOpShapeRegistryTest, TestBasic) {
- EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn("YOU SHALL NOT PASS"),
+ class Blah {};
+ EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetShapeFn(MakeTypeIndex<Blah>()),
nullptr);
- auto* shape_fn =
- UnaryVariantOpRegistry::Global()->GetShapeFn("TEST VariantValue");
+ auto* shape_fn = UnaryVariantOpRegistry::Global()->GetShapeFn(
+ MakeTypeIndex<VariantValue>());
EXPECT_NE(shape_fn, nullptr);
TensorShape shape;
@@ -142,10 +138,11 @@ TEST(VariantOpShapeRegistryTest, TestBasic) {
TEST(VariantOpShapeRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantShapeFn f;
- string kTypeName = "fjfjfj";
- registry.RegisterShapeFn(kTypeName, f);
- EXPECT_DEATH(registry.RegisterShapeFn(kTypeName, f),
- "fjfjfj already registered");
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
+ registry.RegisterShapeFn(kTypeIndex, f);
+ EXPECT_DEATH(registry.RegisterShapeFn(kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpDecodeRegistryTest, TestBasic) {
@@ -180,13 +177,14 @@ TEST(VariantOpDecodeRegistryTest, TestDuplicate) {
TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
// No registered copy fn for GPU<->GPU.
- EXPECT_EQ(
- UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
- VariantDeviceCopyDirection::DEVICE_TO_DEVICE, "TEST VariantValue"),
- nullptr);
+ EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
+ VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
+ MakeTypeIndex<VariantValue>()),
+ nullptr);
auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
- VariantDeviceCopyDirection::HOST_TO_DEVICE, "TEST VariantValue");
+ VariantDeviceCopyDirection::HOST_TO_DEVICE,
+ MakeTypeIndex<VariantValue>());
EXPECT_NE(copy_to_gpu_fn, nullptr);
VariantValue vv{true /* early_exit */};
@@ -208,17 +206,19 @@ TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE,
- kTypeName, f);
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterDeviceCopyFn(
- VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeName, f),
- "fjfjfj already registered");
+ VariantDeviceCopyDirection::HOST_TO_DEVICE, kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
- ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -242,8 +242,9 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
#if GOOGLE_CUDA
TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
- ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+ ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
@@ -269,25 +270,26 @@ TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantUnaryOpFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
- registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, kTypeName,
- f);
+ registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU,
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
- DEVICE_CPU, kTypeName, f),
- "fjfjfj already registered");
+ DEVICE_CPU, kTypeIndex, f),
+ "FjFjFj already registered");
- registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, kTypeName,
- f);
+ registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU,
+ kTypeIndex, f);
EXPECT_DEATH(registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP,
- DEVICE_GPU, kTypeName, f),
- "fjfjfj already registered");
+ DEVICE_GPU, kTypeIndex, f),
+ "FjFjFj already registered");
}
TEST(VariantOpAddRegistryTest, TestBasicCPU) {
- return;
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
- ADD_VARIANT_BINARY_OP, DEVICE_CPU, "YOU SHALL NOT PASS"),
+ ADD_VARIANT_BINARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -312,8 +314,9 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) {
#if GOOGLE_CUDA
TEST(VariantOpAddRegistryTest, TestBasicGPU) {
+ class Blah {};
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
- ADD_VARIANT_BINARY_OP, DEVICE_GPU, "YOU SHALL NOT PASS"),
+ ADD_VARIANT_BINARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
nullptr);
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
@@ -340,17 +343,18 @@ TEST(VariantOpAddRegistryTest, TestBasicGPU) {
TEST(VariantOpAddRegistryTest, TestDuplicate) {
UnaryVariantOpRegistry registry;
UnaryVariantOpRegistry::VariantBinaryOpFn f;
- string kTypeName = "fjfjfj";
+ class FjFjFj {};
+ const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
- registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeName, f);
+ registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- kTypeName, f),
- "fjfjfj already registered");
+ kTypeIndex, f),
+ "FjFjFj already registered");
- registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeName, f);
+ registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU, kTypeIndex, f);
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- kTypeName, f),
- "fjfjfj already registered");
+ kTypeIndex, f),
+ "FjFjFj already registered");
}
} // namespace tensorflow
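As a caller-side counterpart to the updated tests, a hedged sketch of dispatch keyed on TypeId(); ZerosLikeCPU is an illustrative helper, not part of this change:

#include "tensorflow/core/framework/variant_op_registry.h"

namespace tensorflow {

// Dispatch is now keyed on the value's TypeId() rather than its TypeName()
// string; the name is still handy for error messages.
Status ZerosLikeCPU(OpKernelContext* ctx, const Variant& v, Variant* out) {
  UnaryVariantOpRegistry::VariantUnaryOpFn* fn =
      UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
          ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, v.TypeId());
  if (fn == nullptr) {
    return errors::Internal("No CPU zeros_like registered for type ",
                            v.TypeName());
  }
  return (*fn)(ctx, v, out);
}

}  // namespace tensorflow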
diff --git a/tensorflow/core/framework/variant_tensor_data.cc b/tensorflow/core/framework/variant_tensor_data.cc
index 99712dc114..3e67e4a864 100644
--- a/tensorflow/core/framework/variant_tensor_data.cc
+++ b/tensorflow/core/framework/variant_tensor_data.cc
@@ -22,8 +22,8 @@ namespace tensorflow {
VariantTensorData::VariantTensorData() {}
-VariantTensorData::VariantTensorData(const VariantTensorDataProto& proto) {
- FromProto(proto);
+VariantTensorData::VariantTensorData(VariantTensorDataProto proto) {
+ FromProto(std::move(proto));
}
VariantTensorData::~VariantTensorData() {}
@@ -52,7 +52,19 @@ void VariantTensorData::ToProto(VariantTensorDataProto* proto) const {
}
}
-bool VariantTensorData::FromProto(const VariantTensorDataProto& proto) {
+bool VariantTensorData::FromProto(VariantTensorDataProto proto) {
+ // TODO(ebrevdo): Do this lazily.
+ set_type_name(proto.type_name());
+ set_metadata(proto.metadata());
+ for (const auto& tensor : proto.tensors()) {
+ Tensor tmp;
+ if (!tmp.FromProto(tensor)) return false;
+ tensors_.push_back(tmp);
+ }
+ return true;
+}
+
+bool VariantTensorData::FromConstProto(const VariantTensorDataProto& proto) {
set_type_name(proto.type_name());
set_metadata(proto.metadata());
for (const auto& tensor : proto.tensors()) {
@@ -75,10 +87,10 @@ bool VariantTensorData::SerializeToString(string* buf) {
return proto.SerializeToString(buf);
}
-bool VariantTensorData::ParseFromString(const string& s) {
+bool VariantTensorData::ParseFromString(string s) {
VariantTensorDataProto proto;
const bool status = proto.ParseFromString(s);
- if (status) FromProto(proto);
+ if (status) FromProto(std::move(proto));
return status;
}
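A small sketch of the move path the by-value signatures open up; DecodeFromWire is a hypothetical helper, not part of the change:

#include <string>
#include <utility>

#include "tensorflow/core/framework/variant_tensor_data.h"

namespace tensorflow {

// The serialized buffer is taken by value so callers can move it in; the
// parsed proto is then forwarded with std::move to the by-value FromProto(),
// which (per the TODO above) leaves room to drop the remaining tensor copies.
bool DecodeFromWire(string wire, VariantTensorData* out) {
  return out->ParseFromString(std::move(wire));
}

}  // namespace tensorflow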
diff --git a/tensorflow/core/framework/variant_tensor_data.h b/tensorflow/core/framework/variant_tensor_data.h
index 1d87bc341a..8a240ee1e3 100644
--- a/tensorflow/core/framework/variant_tensor_data.h
+++ b/tensorflow/core/framework/variant_tensor_data.h
@@ -13,19 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_FRAMEWORK_VARIANT_TENSOR_DATA_H
-#define TENSORFLOW_FRAMEWORK_VARIANT_TENSOR_DATA_H
+#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_TENSOR_DATA_H_
+#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_TENSOR_DATA_H_
#include <algorithm>
#include <vector>
+#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class VariantTensorDataProto;
-class Tensor;
// The serialization format for Variant objects. Objects with references to
// other Tensors can simply store those tensors in the `tensors` field, and
@@ -38,7 +38,7 @@ class Tensor;
class VariantTensorData {
public:
VariantTensorData();
- VariantTensorData(const VariantTensorDataProto& proto);
+ VariantTensorData(VariantTensorDataProto proto);
~VariantTensorData();
// Name of the type of objects being serialized.
@@ -68,12 +68,14 @@ class VariantTensorData {
// Conversion to and from VariantTensorDataProto
void ToProto(VariantTensorDataProto* proto) const;
- bool FromProto(const VariantTensorDataProto& proto);
+ // This allows optimizations via std::move.
+ bool FromProto(VariantTensorDataProto proto);
+ bool FromConstProto(const VariantTensorDataProto& proto);
// Serialization via VariantTensorDataProto
string SerializeAsString() const;
bool SerializeToString(string* buf);
- bool ParseFromString(const string& s);
+ bool ParseFromString(string s);
string DebugString() const;
@@ -112,4 +114,4 @@ string ProtoDebugString(const VariantTensorData& object);
} // namespace tensorflow
-#endif // TENSORFLOW_FRAMEWORK_VARIANT_TENSOR_DATA_H
+#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_TENSOR_DATA_H_
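A short sketch of when each conversion entry point applies; OwnedToData and BorrowedToData are illustrative helpers, and VariantTensorDataProto is assumed to come from tensor.pb.h:

#include <utility>

#include "tensorflow/core/framework/tensor.pb.h"  // VariantTensorDataProto (assumed location)
#include "tensorflow/core/framework/variant_tensor_data.h"

namespace tensorflow {

// When the caller owns the proto, pass it by value so its contents can be
// moved into the VariantTensorData.
VariantTensorData OwnedToData(VariantTensorDataProto proto) {
  return VariantTensorData(std::move(proto));
}

// When only a const reference is available, use the copying fallback.
bool BorrowedToData(const VariantTensorDataProto& proto,
                    VariantTensorData* out) {
  return out->FromConstProto(proto);
}

}  // namespace tensorflow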
diff --git a/tensorflow/core/framework/variant_test.cc b/tensorflow/core/framework/variant_test.cc
index eef5c47d15..08d09de7b8 100644
--- a/tensorflow/core/framework/variant_test.cc
+++ b/tensorflow/core/framework/variant_test.cc
@@ -144,8 +144,8 @@ TEST(VariantTest, TypeMismatch) {
struct TensorList {
void Encode(VariantTensorData* data) const { data->tensors_ = vec; }
- bool Decode(const VariantTensorData& data) {
- vec = data.tensors_;
+ bool Decode(VariantTensorData data) {
+ vec = std::move(data.tensors_);
return true;
}
@@ -186,7 +186,7 @@ TEST(VariantTest, TensorListTest) {
x.Encode(&serialized);
Variant y = TensorList();
- y.Decode(serialized);
+ y.Decode(std::move(serialized));
const TensorList& decoded_vec = *y.get<TensorList>();
for (int i = 0; i < 4; ++i) {
@@ -204,15 +204,6 @@ TEST(VariantTest, TensorListTest) {
EXPECT_EQ(y_unknown.DebugString(),
strings::StrCat(
"Variant<type: TensorList value: ", data.DebugString(), ">"));
-
- TensorList unknown_decoded_vec;
- EXPECT_TRUE(y_unknown.MaybeDecodeAndCopy(&unknown_decoded_vec));
- for (int i = 0; i < 4; ++i) {
- EXPECT_EQ(unknown_decoded_vec.vec[i].flat<int>()(0), i);
- }
- for (int i = 0; i < 4; ++i) {
- EXPECT_EQ(unknown_decoded_vec.vec[i + 4].flat<float>()(0), 2 * i);
- }
}
TEST(VariantTest, VariantArray) {
diff --git a/tensorflow/core/graph/algorithm.h b/tensorflow/core/graph/algorithm.h
index 5bbbc6f6dc..45f8a29a92 100644
--- a/tensorflow/core/graph/algorithm.h
+++ b/tensorflow/core/graph/algorithm.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_ALGORITHM_H_
-#define TENSORFLOW_GRAPH_ALGORITHM_H_
+#ifndef TENSORFLOW_CORE_GRAPH_ALGORITHM_H_
+#define TENSORFLOW_CORE_GRAPH_ALGORITHM_H_
#include <functional>
#include <unordered_set>
@@ -117,4 +117,4 @@ bool FixupSourceAndSinkEdges(Graph* g);
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_ALGORITHM_H_
+#endif // TENSORFLOW_CORE_GRAPH_ALGORITHM_H_
diff --git a/tensorflow/core/graph/colors.h b/tensorflow/core/graph/colors.h
index c1e1940cac..43d2225571 100644
--- a/tensorflow/core/graph/colors.h
+++ b/tensorflow/core/graph/colors.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_COLORS_H_
-#define TENSORFLOW_GRAPH_COLORS_H_
+#ifndef TENSORFLOW_CORE_GRAPH_COLORS_H_
+#define TENSORFLOW_CORE_GRAPH_COLORS_H_
namespace tensorflow {
@@ -26,4 +26,4 @@ const char* ColorFor(int dindex);
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_COLORS_H_
+#endif // TENSORFLOW_CORE_GRAPH_COLORS_H_
diff --git a/tensorflow/core/graph/control_flow.h b/tensorflow/core/graph/control_flow.h
index 548820720b..5abe77f5a1 100644
--- a/tensorflow/core/graph/control_flow.h
+++ b/tensorflow/core/graph/control_flow.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_CONTROL_FLOW_H_
-#define TENSORFLOW_GRAPH_CONTROL_FLOW_H_
+#ifndef TENSORFLOW_CORE_GRAPH_CONTROL_FLOW_H_
+#define TENSORFLOW_CORE_GRAPH_CONTROL_FLOW_H_
#include <vector>
@@ -48,4 +48,4 @@ Status BuildControlFlowInfo(const Graph* g, std::vector<ControlFlowInfo>* info,
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_CONTROL_FLOW_H_
+#endif // TENSORFLOW_CORE_GRAPH_CONTROL_FLOW_H_
diff --git a/tensorflow/core/graph/costmodel.h b/tensorflow/core/graph/costmodel.h
index 9b703e4693..2d94dd5cdc 100644
--- a/tensorflow/core/graph/costmodel.h
+++ b/tensorflow/core/graph/costmodel.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_COSTMODEL_H_
-#define TENSORFLOW_GRAPH_COSTMODEL_H_
+#ifndef TENSORFLOW_CORE_GRAPH_COSTMODEL_H_
+#define TENSORFLOW_CORE_GRAPH_COSTMODEL_H_
#include <unordered_map>
#include <vector>
@@ -229,4 +229,4 @@ class CostModel {
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_COSTMODEL_H_
+#endif // TENSORFLOW_CORE_GRAPH_COSTMODEL_H_
diff --git a/tensorflow/core/graph/default_device.h b/tensorflow/core/graph/default_device.h
index 68d7c8e553..f0f53c91f4 100644
--- a/tensorflow/core/graph/default_device.h
+++ b/tensorflow/core/graph/default_device.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_
-#define TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_
+#ifndef TENSORFLOW_CORE_GRAPH_DEFAULT_DEVICE_H_
+#define TENSORFLOW_CORE_GRAPH_DEFAULT_DEVICE_H_
#include <string>
@@ -38,4 +38,4 @@ inline void SetDefaultDevice(const string& device, GraphDef* graph_def) {
} // namespace graph
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_DEFAULT_DEVICE_H_
+#endif // TENSORFLOW_CORE_GRAPH_DEFAULT_DEVICE_H_
diff --git a/tensorflow/core/graph/gradients.cc b/tensorflow/core/graph/gradients.cc
index c1a8a63784..bec41712b1 100644
--- a/tensorflow/core/graph/gradients.cc
+++ b/tensorflow/core/graph/gradients.cc
@@ -65,16 +65,37 @@ struct NodeOutEq {
static Node* AddZerosLike(Graph* g, NodeOut input) {
DCHECK_LT(0, input.dtype());
DCHECK_LT(input.dtype(), DT_FLOAT_REF);
- NodeDef ndef;
- ndef.set_name(g->NewName(kNodeLabel));
- ndef.set_op("ZerosLike");
- ndef.add_input(input.name());
- AddNodeAttr("T", input.dtype(), &ndef);
- Status s;
- Node* ret = g->AddNode(ndef, &s);
- TF_CHECK_OK(s);
- g->AddEdge(input.node, input.index, ret, 0);
- return ret;
+ if (input.dtype() == DT_RESOURCE) {
+ NodeDef read_def;
+ read_def.set_name(g->NewName("Read"));
+ read_def.set_op("ReadVariableOp");
+ read_def.add_input(input.name());
+ AddNodeAttr("dtype", DT_FLOAT, &read_def);
+ Status s;
+ Node* read = g->AddNode(read_def, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, read, 0);
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("ZerosLike");
+ ndef.add_input(read_def.name());
+ AddNodeAttr("T", DT_FLOAT, &ndef);
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(read, 0, ret, 0);
+ return ret;
+ } else {
+ NodeDef ndef;
+ ndef.set_name(g->NewName(kNodeLabel));
+ ndef.set_op("ZerosLike");
+ ndef.add_input(input.name());
+ AddNodeAttr("T", input.dtype(), &ndef);
+ Status s;
+ Node* ret = g->AddNode(ndef, &s);
+ TF_CHECK_OK(s);
+ g->AddEdge(input.node, input.index, ret, 0);
+ return ret;
+ }
}
static Node* AddSymGrad(Graph* g, Node* n, gtl::ArraySlice<NodeOut> grads) {
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 568f0870c0..7a4a0096fa 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -192,6 +192,11 @@ void Node::ClearAttr(const string& name) {
(*props_->node_def.mutable_attr()).erase(name);
}
+void Node::set_name(string name) {
+ MaybeCopyOnWrite();
+ props_->node_def.set_name(std::move(name));
+}
+
void Node::set_requested_device(const string& device) {
MaybeCopyOnWrite();
props_->node_def.set_device(device);
@@ -483,7 +488,7 @@ const Edge* Graph::AddControlEdge(Node* source, Node* dest,
void Graph::RemoveControlEdge(const Edge* e) {
if (!e->src_->IsSource() && !e->dst_->IsSink()) {
e->dst_->MaybeCopyOnWrite();
- std::string e_src_name = strings::StrCat("^", e->src_->name());
+ string e_src_name = strings::StrCat("^", e->src_->name());
auto* inputs = e->dst_->props_->node_def.mutable_input();
for (auto it = inputs->begin(); it != inputs->end(); ++it) {
if (*it == e_src_name) {
@@ -495,6 +500,15 @@ void Graph::RemoveControlEdge(const Edge* e) {
RemoveEdge(e);
}
+namespace {
+const Edge* FindEdge(const Node* dst, int index) {
+ for (const Edge* e : dst->in_edges()) {
+ if (e->dst_input() == index) return e;
+ }
+ return nullptr;
+}
+} // namespace
+
Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
int dst_index) {
TF_RETURN_IF_ERROR(IsValidOutputTensor(new_src, new_src_index));
@@ -512,17 +526,6 @@ Status Graph::UpdateEdge(Node* new_src, int new_src_index, Node* dst,
return Status::OK();
}
-const Edge* Graph::FindEdge(const Node* dst, int index) {
- for (const Edge* e : edges_) {
- // edges_ will contain null edges if RemoveEdge() was called.
- if (e == nullptr) continue;
- if (e->dst() == dst && e->dst_input() == index) {
- return e;
- }
- }
- return nullptr;
-}
-
Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
// Need a new-enough consumer to support the functions we add to the graph.
if (fdef_lib.function_size() > 0 && versions_->min_consumer() < 12) {
@@ -645,7 +648,7 @@ Status Graph::IsValidNode(const Node* node) const {
Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
TF_RETURN_IF_ERROR(IsValidNode(node));
- if (idx >= node->num_outputs()) {
+ if (idx >= node->num_outputs() || idx < 0) {
return errors::OutOfRange("Node '", node->name(), "' (type: '",
node->op_def().name(),
"', num of outputs: ", node->num_outputs(),
@@ -656,7 +659,7 @@ Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
Status Graph::IsValidInputTensor(const Node* node, int idx) const {
TF_RETURN_IF_ERROR(IsValidNode(node));
- if (idx >= node->num_inputs()) {
+ if (idx >= node->num_inputs() || idx < 0) {
return errors::OutOfRange("Node '", node->name(), "' (type: '",
node->op_def().name(),
"', num of inputs: ", node->num_inputs(),
@@ -721,7 +724,7 @@ Status Graph::AddWhileContext(StringPiece frame_name,
std::vector<OutputTensor> body_outputs,
WhileContext** result) {
auto pair = while_ctxs_.insert(std::pair<string, WhileContext>(
- std::string(frame_name),
+ string(frame_name),
WhileContext(frame_name, std::move(enter_nodes), std::move(exit_nodes),
cond_output, std::move(body_inputs),
std::move(body_outputs))));
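A small sketch of what the tightened index validation buys callers; HasOutput is a hypothetical helper:

#include "tensorflow/core/graph/graph.h"

namespace tensorflow {

// Returns true iff `idx` names a real (non-control) output of `n`; the new
// `idx < 0` guard means callers no longer need their own negative-index check.
bool HasOutput(const Graph& g, const Node* n, int idx) {
  return g.IsValidOutputTensor(n, idx).ok();
}

}  // namespace tensorflow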
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index a147c94689..2944951f82 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -72,6 +72,7 @@ class Node {
int id() const { return id_; }
int cost_id() const { return cost_id_; }
const string& name() const;
+ void set_name(string name);
const string& type_string() const;
// def() provides the NodeDef the user supplied, but the specifics
@@ -590,12 +591,12 @@ class Graph {
// Returns OK if `node` is non-null and belongs to this graph
Status IsValidNode(const Node* node) const;
- // Returns OK if IsValidNode(`node`) and `idx` is less than
- // node->num_outputs()
+ // Returns OK if IsValidNode(`node`) and `idx` is a valid output. Does not
+ // accept control outputs.
Status IsValidOutputTensor(const Node* node, int idx) const;
- // Returns OK if IsValidNode(`node`) and `idx` is less than
- // node->num_inputs()
+ // Returns OK if IsValidNode(`node`) and `idx` is a valid input. Does not accept
+ // control inputs.
Status IsValidInputTensor(const Node* node, int idx) const;
// Create and return a new WhileContext owned by this graph. This is called
@@ -680,10 +681,6 @@ class Graph {
// AddWhileContext() or Node::while_ctx(), but this manages the lifetime.
std::map<string, WhileContext> while_ctxs_;
- // Searches through edges_ for the Edge whose destination node and index
- // matches dst. An edge with destination `dst` must exist in the graph.
- const Edge* FindEdge(const Node* dst, int index);
-
TF_DISALLOW_COPY_AND_ASSIGN(Graph);
};
diff --git a/tensorflow/core/graph/graph_constructor.cc b/tensorflow/core/graph/graph_constructor.cc
index 8c73f8f712..eeb5c14eaa 100644
--- a/tensorflow/core/graph/graph_constructor.cc
+++ b/tensorflow/core/graph/graph_constructor.cc
@@ -513,7 +513,7 @@ Status GraphConstructor::InitFromEdges() {
num_control_edges++;
} else {
TensorId id(ParseTensorName(input_name));
- if (next_iteration_nodes_.find(std::string(id.first)) !=
+ if (next_iteration_nodes_.find(string(id.first)) !=
next_iteration_nodes_.end()) {
has_loop_back_edge = true;
}
@@ -835,7 +835,7 @@ void GraphConstructor::UniquifyNames(
// We require that UniquifyNames() is called on all NodeDefs in topological
// order. This guarantees that node_def's inputs will already be uniquified
// if necessary.
- auto iter = uniquified_names_.find(std::string(id.first));
+ auto iter = uniquified_names_.find(string(id.first));
if (iter == uniquified_names_.end()) continue;
id.first = iter->second;
node_def->set_input(i, id.ToString());
@@ -854,7 +854,7 @@ void GraphConstructor::UpdateUniquifiedColocationNames() {
for (int i = 0; i < coloc_values.size(); ++i) {
StringPiece val(coloc_values[i]);
if (str_util::ConsumePrefix(&val, kColocationGroupPrefix)) {
- const auto& name_pair = uniquified_names_.find(std::string(val));
+ const auto& name_pair = uniquified_names_.find(string(val));
if (name_pair == uniquified_names_.end()) continue;
updated = true;
coloc_values[i] =
@@ -880,7 +880,7 @@ bool GraphConstructor::NameExistsInGraphDef(StringPiece name) {
}
string GraphConstructor::FindUniqueName(StringPiece original_name) {
- string name = std::string(original_name);
+ string name(original_name);
int count = 0;
// Check that any generated names don't collide with imported NodeDefs (as
// well as nodes in g_).
@@ -997,7 +997,7 @@ Status GraphConstructor::Convert() {
src_node->num_outputs(), " outputs");
}
- inputs.emplace_back(std::string(id.first), src_node, src_index);
+ inputs.emplace_back(string(id.first), src_node, src_index);
}
if (has_data_back_edge && !IsMerge(*node_def)) {
@@ -1042,12 +1042,12 @@ Status GraphConstructor::Convert() {
}
if (processed < node_defs_.size()) {
- LOG(WARNING) << "IN " << __func__ << (node_defs_.size() - processed)
+ LOG(WARNING) << "IN " << __func__ << " " << (node_defs_.size() - processed)
<< " NODES IN A CYCLE";
for (int64 i = 0; i < node_defs_.size(); i++) {
if (pending_count_[i] != 0) {
LOG(WARNING) << "PENDING: " << SummarizeNodeDef(*node_defs_[i])
- << "WITH PENDING COUNT = " << pending_count_[i];
+ << " WITH PENDING COUNT = " << pending_count_[i];
}
}
return errors::InvalidArgument(node_defs_.size() - processed,
@@ -1162,7 +1162,9 @@ Status GraphConstructor::PopulateMissingUnusedInputMapKeys() {
const NodeDef* node_def = node_defs_[pair->second.gdef_index];
const OpDef* op_def;
TF_RETURN_IF_ERROR(g_->op_registry()->LookUpOpDef(node_def->op(), &op_def));
- if (key.second >= op_def->output_arg_size()) {
+ int num_outputs;
+ TF_RETURN_IF_ERROR(NumOutputsForNode(*node_def, *op_def, &num_outputs));
+ if (key.second >= num_outputs) {
// key's index out of bounds
missing_unused_input_map_keys_->push_back(key);
}
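Condensing the reasoning behind replacing output_arg_size() with NumOutputsForNode() into a hedged sketch; CountOutputs is illustrative, not part of the change:

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {

// For a variadic op such as "outputs: N * int32", op_def->output_arg_size()
// is 1 no matter what N is; NumOutputsForNode() expands N from the node's
// attrs, so index 4 of a node with N = 5 is correctly treated as in range.
Status CountOutputs(const NodeDef& node_def, int* num_outputs) {
  const OpDef* op_def;
  TF_RETURN_IF_ERROR(
      OpRegistry::Global()->LookUpOpDef(node_def.op(), &op_def));
  return NumOutputsForNode(node_def, *op_def, num_outputs);
}

}  // namespace tensorflow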
diff --git a/tensorflow/core/graph/graph_constructor.h b/tensorflow/core/graph/graph_constructor.h
index 889359a68a..f6e41faf9c 100644
--- a/tensorflow/core/graph/graph_constructor.h
+++ b/tensorflow/core/graph/graph_constructor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
-#define TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
+#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_
+#define TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
@@ -186,4 +186,4 @@ extern void CopyGraph(const Graph& src, Graph* dest);
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_GRAPH_CONSTRUCTOR_H_
+#endif // TENSORFLOW_CORE_GRAPH_GRAPH_CONSTRUCTOR_H_
diff --git a/tensorflow/core/graph/graph_constructor_test.cc b/tensorflow/core/graph/graph_constructor_test.cc
index e338840eeb..3eef6bd2bd 100644
--- a/tensorflow/core/graph/graph_constructor_test.cc
+++ b/tensorflow/core/graph/graph_constructor_test.cc
@@ -156,9 +156,8 @@ class GraphConstructorTest : public ::testing::Test {
return "";
}
StringPiece loc(value[0]);
- return str_util::ConsumePrefix(&loc, kColocationGroupPrefix)
- ? std::string(loc)
- : "";
+ return str_util::ConsumePrefix(&loc, kColocationGroupPrefix) ? string(loc)
+ : "";
}
string GraphDebugString() const {
@@ -200,6 +199,10 @@ REGISTER_OP("TestOneInputOneOutput")
.Output("y: T")
.Attr("T: {float, int64}")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("TestVariadicOutput")
+ .Output("outputs: N * int32")
+ .Attr("N: int >= 0")
+ .SetShapeFn(shape_inference::UnknownShape);
REGISTER_OP("TestDefaultAttr")
.Attr("default_int: int=31415")
.SetShapeFn(shape_inference::NoOutputs);
@@ -1464,12 +1467,15 @@ TEST_F(GraphConstructorTest, ImportGraphDef_InputMapMissingUnusedKeys) {
opts.input_map[TensorId("DNE", 0)] = TensorId("input", 0);
// Unused but not missing
opts.input_map[TensorId("t1", 0)] = TensorId("W1", 0);
+ // Unused but not missing
+ opts.input_map[TensorId("variadic", 4)] = TensorId("input", 0);
ExpectOK(
R"EOF(
node { name: 'W2' op: 'TestParams' }
node { name: 'new_input' op: 'TestInput' input: [ '^W2' ] }
node { name: 't1' op: 'TestMul' input: [ 'new_input:0', 'new_input:1' ] }
- node { name: 't2' op: 'TestMul' input: [ 't1:0', 't1:0' ] }
+ node { name: 'variadic' op: 'TestVariadicOutput'
+ attr { key: "N" value { i: 5 } } }
)EOF",
opts, &refiner, &results);
diff --git a/tensorflow/core/graph/graph_def_builder.cc b/tensorflow/core/graph/graph_def_builder.cc
index dd84c4f7c7..6d5df7efba 100644
--- a/tensorflow/core/graph/graph_def_builder.cc
+++ b/tensorflow/core/graph/graph_def_builder.cc
@@ -44,12 +44,12 @@ GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputs(
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithNameImpl(
StringPiece name) {
- name_ = std::string(name);
+ name_ = string(name);
return *this;
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithDeviceImpl(
StringPiece device) {
- device_ = std::string(device);
+ device_ = string(device);
return *this;
}
GraphDefBuilder::Options GraphDefBuilder::Options::WithControlInputImpl(
diff --git a/tensorflow/core/graph/graph_def_builder.h b/tensorflow/core/graph/graph_def_builder.h
index 0d6aae4355..400d8b6c84 100644
--- a/tensorflow/core/graph/graph_def_builder.h
+++ b/tensorflow/core/graph/graph_def_builder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
-#define TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
+#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
+#define TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
#include <vector>
#include "tensorflow/core/framework/graph.pb.h"
@@ -128,7 +128,7 @@ class GraphDefBuilder {
Options WithControlInputsImpl(gtl::ArraySlice<Node*> control_inputs);
template <class T>
Options WithAttrImpl(StringPiece name, T&& value) {
- attrs_.emplace_back(std::string(name), AttrValue());
+ attrs_.emplace_back(string(name), AttrValue());
SetAttrValue(std::forward<T>(value), &attrs_.back().second);
return *this;
}
@@ -203,4 +203,4 @@ Node* BinaryOp(const string& op_name, NodeOut a, NodeOut b,
} // namespace ops
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_GRAPH_DEF_BUILDER_H_
+#endif // TENSORFLOW_CORE_GRAPH_GRAPH_DEF_BUILDER_H_
diff --git a/tensorflow/core/graph/graph_partition.cc b/tensorflow/core/graph/graph_partition.cc
index ea0a814ab8..1dbcebab59 100644
--- a/tensorflow/core/graph/graph_partition.cc
+++ b/tensorflow/core/graph/graph_partition.cc
@@ -793,7 +793,7 @@ Status TopologicalSortNodesWithTimePriority(
for (int n = 0; n < gdef->node_size(); ++n) {
const NodeDef* ndef = &gdef->node(n);
for (int i = 0; i < ndef->input_size(); ++i) {
- node_to_output_nodes[std::string(ParseTensorName(ndef->input(i)).first)]
+ node_to_output_nodes[string(ParseTensorName(ndef->input(i)).first)]
.push_back(ndef);
}
int64 start_time;
diff --git a/tensorflow/core/graph/graph_partition.h b/tensorflow/core/graph/graph_partition.h
index 67fafddd51..8020c2d247 100644
--- a/tensorflow/core/graph/graph_partition.h
+++ b/tensorflow/core/graph/graph_partition.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
-#define TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
+#ifndef TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_
+#define TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_
#include <functional>
#include <string>
@@ -95,4 +95,4 @@ Status AddControlEdges(const PartitionOptions& opts,
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_GRAPH_PARTITION_H_
+#endif // TENSORFLOW_CORE_GRAPH_GRAPH_PARTITION_H_
diff --git a/tensorflow/core/graph/mkl_graph_util.h b/tensorflow/core/graph/mkl_graph_util.h
index 333bf761b0..bab1df87a4 100644
--- a/tensorflow/core/graph/mkl_graph_util.h
+++ b/tensorflow/core/graph/mkl_graph_util.h
@@ -41,7 +41,7 @@ namespace tensorflow {
typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
// NOTE: Currently, we use contiguous ordering. If you change this, then you
// would need to change Mkl op definitions in nn_ops.cc.
-static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
+static const MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
// Get index of MetaData tensor from index 'n' of Data tensor.
inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
diff --git a/tensorflow/core/graph/mkl_layout_pass.cc b/tensorflow/core/graph/mkl_layout_pass.cc
index 5683944e46..7394b1cddf 100644
--- a/tensorflow/core/graph/mkl_layout_pass.cc
+++ b/tensorflow/core/graph/mkl_layout_pass.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
@@ -37,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/util.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/mkl_layout_pass.h"
@@ -334,6 +336,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.conv2d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
CopyAttrsConv2D, AlwaysRewrite, nullptr});
+
rinfo_.push_back({csinfo_.fused_batch_norm,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite, nullptr});
@@ -546,14 +549,14 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// If Op has been specifically assigned to a non-CPU device, then No.
if (!n->assigned_device_name().empty() &&
- !str_util::StrContains(n->assigned_device_name(),kCPUDeviceSubStr)) {
+ !str_util::StrContains(n->assigned_device_name(), kCPUDeviceSubStr)) {
result = false;
reason = "Op has been assigned a runtime device that is not CPU.";
}
// If user has specifically assigned this op to a non-CPU device, then No.
if (!n->def().device().empty() &&
- !str_util::StrContains(n->def().device(),kCPUDeviceSubStr)) {
+ !str_util::StrContains(n->def().device(), kCPUDeviceSubStr)) {
result = false;
reason = "User has assigned a device that is not CPU.";
}
@@ -975,7 +978,9 @@ std::vector<MklLayoutRewritePass::ContextInfo*> MklLayoutRewritePass::cinfo_;
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif // ENABLE_MKL
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
@@ -1042,6 +1047,7 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
// device of the original
// node.
.Finalize(&**g, out));
+ CHECK_NOTNULL(*out); // Make sure we got a valid object before using it
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
@@ -1335,6 +1341,7 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
// device of the original
// node.
.Finalize(&**g, out));
+ CHECK_NOTNULL(*out); // Make sure we got a valid object before using it
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
@@ -2408,6 +2415,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.addn = "AddN";
csinfo_.avg_pool = "AvgPool";
csinfo_.avg_pool_grad = "AvgPoolGrad";
+ csinfo_.avg_pool3d = "AvgPool3D";
+ csinfo_.avg_pool3d_grad = "AvgPool3DGrad";
csinfo_.bias_add = "BiasAdd";
csinfo_.bias_add_grad = "BiasAddGrad";
csinfo_.concat = "Concat";
@@ -2418,6 +2427,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.conv2d_grad_filter = "Conv2DBackpropFilter";
csinfo_.conv2d_grad_filter_with_bias =
"__MklDummyConv2DBackpropFilterWithBias";
+ csinfo_.conv3d = "Conv3D";
+ csinfo_.conv3d_grad_input = "Conv3DBackpropInputV2";
+ csinfo_.conv3d_grad_filter = "Conv3DBackpropFilterV2";
csinfo_.fused_batch_norm = "FusedBatchNorm";
csinfo_.fused_batch_norm_grad = "FusedBatchNormGrad";
csinfo_.identity = "Identity";
@@ -2426,6 +2438,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.matmul = "MatMul";
csinfo_.max_pool = "MaxPool";
csinfo_.max_pool_grad = "MaxPoolGrad";
+ csinfo_.max_pool3d = "MaxPool3D";
+ csinfo_.max_pool3d_grad = "MaxPool3DGrad";
csinfo_.mkl_conv2d = "_MklConv2D";
csinfo_.mkl_conv2d_grad_input = "_MklConv2DBackpropInput";
csinfo_.mkl_conv2d_grad_filter = "_MklConv2DBackpropFilter";
@@ -2437,6 +2451,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
csinfo_.tanh = "Tanh";
csinfo_.tanh_grad = "TanhGrad";
csinfo_.reshape = "Reshape";
+ csinfo_.slice = "Slice";
csinfo_.softmax = "Softmax";
csinfo_.split = "Split";
// Element-wise ops. Ensure you also add any new ops to IsOpElementWise
@@ -2460,6 +2475,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.avg_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.avg_pool_grad),
CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.avg_pool3d,
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d),
+ CopyAttrsPooling, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.avg_pool3d_grad,
+ mkl_op_registry::GetMklOpName(csinfo_.avg_pool3d_grad),
+ CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.concat,
mkl_op_registry::GetMklOpName(csinfo_.concat),
CopyAttrsConcat, AlwaysRewrite});
@@ -2468,18 +2489,27 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
CopyAttrsConcatV2, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d,
mkl_op_registry::GetMklOpName(csinfo_.conv2d),
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_with_bias, csinfo_.mkl_conv2d_with_bias,
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_filter,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_filter),
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_filter_with_bias,
- csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv2D,
+ csinfo_.mkl_conv2d_grad_filter_with_bias, CopyAttrsConv,
AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2d_grad_input,
mkl_op_registry::GetMklOpName(csinfo_.conv2d_grad_input),
- CopyAttrsConv2D, AlwaysRewrite});
+ CopyAttrsConv, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv3d,
+ mkl_op_registry::GetMklOpName(csinfo_.conv3d),
+ CopyAttrsConv, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv3d_grad_filter,
+ mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_filter),
+ CopyAttrsConv, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.conv3d_grad_input,
+ mkl_op_registry::GetMklOpName(csinfo_.conv3d_grad_input),
+ CopyAttrsConv, AlwaysRewrite});
rinfo_.push_back({csinfo_.fused_batch_norm,
mkl_op_registry::GetMklOpName(csinfo_.fused_batch_norm),
CopyAttrsFusedBatchNorm, AlwaysRewrite});
@@ -2501,7 +2531,12 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.max_pool_grad,
mkl_op_registry::GetMklOpName(csinfo_.max_pool_grad),
CopyAttrsPooling, MaxpoolGradRewrite});
-
+ rinfo_.push_back({csinfo_.max_pool3d,
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool3d),
+ CopyAttrsPooling, NonDepthBatchWisePoolRewrite});
+ rinfo_.push_back({csinfo_.max_pool3d_grad,
+ mkl_op_registry::GetMklOpName(csinfo_.max_pool3d_grad),
+ CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.maximum,
mkl_op_registry::GetMklOpName(csinfo_.maximum),
CopyAttrsDataType, AlwaysRewrite});
@@ -2524,6 +2559,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
rinfo_.push_back({csinfo_.reshape,
mkl_op_registry::GetMklOpName(csinfo_.reshape),
CopyAttrsReshape, AlwaysRewrite});
+ rinfo_.push_back({csinfo_.slice,
+ mkl_op_registry::GetMklOpName(csinfo_.slice),
+ CopyAttrsSlice, AlwaysRewrite});
rinfo_.push_back({csinfo_.softmax,
mkl_op_registry::GetMklOpName(csinfo_.softmax),
CopyAttrsDataType, AlwaysRewrite});
@@ -2538,6 +2576,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// Add info about which ops to add workspace edge to and the slots.
wsinfo_.push_back({csinfo_.lrn, csinfo_.lrn_grad, 0, 2, 1, 3});
wsinfo_.push_back({csinfo_.max_pool, csinfo_.max_pool_grad, 0, 1, 1, 3});
+ wsinfo_.push_back
+ ({csinfo_.max_pool3d, csinfo_.max_pool3d_grad, 0, 1, 1, 3});
// Add a rule for merging nodes
minfo_.push_back({csinfo_.conv2d, csinfo_.bias_add,
@@ -2605,6 +2645,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string add;
string avg_pool;
string avg_pool_grad;
+ string avg_pool3d;
+ string avg_pool3d_grad;
string bias_add;
string bias_add_grad;
string concat;
@@ -2614,6 +2656,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string conv2d_grad_input;
string conv2d_grad_filter;
string conv2d_grad_filter_with_bias;
+ string conv3d;
+ string conv3d_grad_input;
+ string conv3d_grad_filter;
string fused_batch_norm;
string fused_batch_norm_grad;
string identity;
@@ -2622,6 +2667,8 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string matmul;
string max_pool;
string max_pool_grad;
+ string max_pool3d;
+ string max_pool3d_grad;
string maximum;
string mkl_conv2d;
string mkl_conv2d_grad_input;
@@ -2634,6 +2681,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
string tanh;
string tanh_grad;
string reshape;
+ string slice;
string softmax;
string split;
string squared_difference;
@@ -3086,12 +3134,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
static void CopyAttrsBiasAddGrad(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConcat(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsConcatV2(const Node* orig_node, NodeBuilder* nb);
- static void CopyAttrsConv2D(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsConv(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsDataType(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsFusedBatchNorm(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsLRN(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsPooling(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsReshape(const Node* orig_node, NodeBuilder* nb);
+ static void CopyAttrsSlice(const Node* orig_node, NodeBuilder* nb);
static void CopyAttrsSplit(const Node* orig_node, NodeBuilder* nb);
// Generate a graph node in graph 'g' representing a dummy Mkl tensor node,
@@ -3110,7 +3159,9 @@ MklLayoutRewritePass::ConstStringsInfo MklLayoutRewritePass::csinfo_;
// nodes. Do not change the ordering of the Mkl passes.
const OptimizationPassRegistry::Grouping kMklLayoutRewritePassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklLayoutRewritePassGroup, 1, MklLayoutRewritePass);
+#endif // ENABLE_MKL
//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
@@ -3177,6 +3228,7 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
// device of the original
// node.
.Finalize(&**g, out));
+ CHECK_NOTNULL(*out); // Make sure we got a valid object before using it
// If number of inputs to the original node is > 0, then we add
// control dependency between 1st input (index 0) of the original node and
@@ -3571,14 +3623,13 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(
// Op-specific functions to copy attributes from old node to new node
//////////////////////////////////////////////////////////////////////////
-void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
- NodeBuilder* nb) {
+void MklLayoutRewritePass::CopyAttrsConv(const Node* orig_node,
+ NodeBuilder* nb) {
DataType T;
string data_format;
string padding;
std::vector<int32> strides;
std::vector<int32> dilations;
- bool use_cudnn_on_gpu;
// Get all attributes from old node.
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
@@ -3586,8 +3637,6 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "dilations", &dilations));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "padding", &padding));
TF_CHECK_OK(GetNodeAttr(orig_node->def(), "data_format", &data_format));
- TF_CHECK_OK(
- GetNodeAttr(orig_node->def(), "use_cudnn_on_gpu", &use_cudnn_on_gpu));
// Add attributes to new node.
nb->Attr("T", T);
@@ -3595,7 +3644,6 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orig_node,
nb->Attr("dilations", dilations);
nb->Attr("padding", padding);
nb->Attr("data_format", data_format);
- nb->Attr("use_cudnn_on_gpu", use_cudnn_on_gpu);
}
void MklLayoutRewritePass::CopyAttrsAddN(const Node* orig_node,
@@ -3698,6 +3746,19 @@ void MklLayoutRewritePass::CopyAttrsReshape(const Node* orig_node,
nb->Attr("Tshape", Tshape);
}
+void MklLayoutRewritePass::CopyAttrsSlice(const Node* orig_node,
+ NodeBuilder* nb) {
+ DataType T;
+ DataType Index;
+
+ // Get all attributes from old node.
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "T", &T));
+ TF_CHECK_OK(GetNodeAttr(orig_node->def(), "Index", &Index));
+ // Add attributes to new node.
+ nb->Attr("T", T);
+ nb->Attr("Index", Index);
+}
+
void MklLayoutRewritePass::CopyAttrsSplit(const Node* orig_node,
NodeBuilder* nb) {
DataType T;
@@ -3896,7 +3957,7 @@ Status MklLayoutRewritePass::MergeConv2DWithBiasAdd(std::unique_ptr<Graph>* g,
nb.Input(succ_in[1].first, succ_in[1].second); // In2 of BiasAdd
// Copy attributes from Conv2D to Conv2DWithBias.
- CopyAttrsConv2D(const_cast<const Node*>(pred), &nb);
+ CopyAttrsConv(const_cast<const Node*>(pred), &nb);
// Copy the device assigned to old node to new node.
nb.Device(succ->def().device());
@@ -4007,7 +4068,7 @@ Status MklLayoutRewritePass::MergeConv2DBackpropFilterWithBiasAddGrad(
}
// Copy attributes from Conv2DBackpropFilter.
- CopyAttrsConv2D(const_cast<const Node*>(fltr), &nb);
+ CopyAttrsConv(const_cast<const Node*>(fltr), &nb);
// Copy the device assigned to old node to new node.
nb.Device(fltr->def().device());
@@ -4451,6 +4512,10 @@ Status MklLayoutRewritePass::Run(const GraphOptimizationPassOptions& options) {
if (options.graph == nullptr && options.partition_graphs == nullptr) {
return Status::OK();
}
+ if (DisableMKL()) {
+ VLOG(2) << "TF-MKL: Disabling MKL";
+ return Status::OK();
+ }
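The early return above consults DisableMKL() from tensorflow/core/util/util.h, letting users switch the MKL graph passes off at runtime. A minimal self-contained sketch of such a switch, assuming an environment variable named TF_DISABLE_MKL (the variable name and the exact semantics are assumptions for illustration, not taken from this diff):

    // Illustrative runtime kill switch; not the actual implementation.
    #include <cstdlib>
    #include <cstring>

    static bool DisableMklSketch() {
      const char* flag = std::getenv("TF_DISABLE_MKL");
      // MKL stays enabled unless the variable is set to "1".
      return flag != nullptr && std::strcmp(flag, "1") == 0;
    }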
auto process_graph = [&](std::unique_ptr<Graph>* g) {
// Get the ownership of a graph
diff --git a/tensorflow/core/graph/mkl_layout_pass.h b/tensorflow/core/graph/mkl_layout_pass.h
index ffe5c1ecfc..e7175149df 100644
--- a/tensorflow/core/graph/mkl_layout_pass.h
+++ b/tensorflow/core/graph/mkl_layout_pass.h
@@ -15,8 +15,8 @@ limitations under the License.
// A graph pass that rewrites graph for propagating MKL layout as a tensor
-#ifndef TENSORFLOW_GRAPH_MKL_LAYOUT_PASS_H_
-#define TENSORFLOW_GRAPH_MKL_LAYOUT_PASS_H_
+#ifndef TENSORFLOW_CORE_GRAPH_MKL_LAYOUT_PASS_H_
+#define TENSORFLOW_CORE_GRAPH_MKL_LAYOUT_PASS_H_
#ifdef INTEL_MKL
@@ -33,4 +33,4 @@ extern bool RunMklLayoutRewritePass(std::unique_ptr<Graph>* g);
#endif
-#endif // TENSORFLOW_GRAPH_MKL_LAYOUT_PASS_H_
+#endif // TENSORFLOW_CORE_GRAPH_MKL_LAYOUT_PASS_H_
diff --git a/tensorflow/core/graph/mkl_layout_pass_test.cc b/tensorflow/core/graph/mkl_layout_pass_test.cc
index e8bac847e5..77640e287c 100644
--- a/tensorflow/core/graph/mkl_layout_pass_test.cc
+++ b/tensorflow/core/graph/mkl_layout_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/graph/mkl_layout_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
@@ -3510,6 +3510,26 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_DeviceTest) {
"B->C:1;C->E;D->E:1;E->Z;M->C:2;N->C:3;Y->Z:1");
}
+TEST_F(MklLayoutPassTest, NodeRewrite_Slice_DeviceTest) {
+ InitGraph(
+ "node { name: 'A' op: 'Input'}"
+ "node { name: 'B' op: 'Int32Input'}"
+ "node { name: 'C' op: 'Int32Input'}"
+ "node { name: 'D' op: 'Slice'"
+ " attr { key: 'T' value { type: DT_FLOAT } }"
+ " attr { key: 'Index' value { type: DT_INT32 } }"
+ " input: ['A', 'B', 'C'] }"
+ "node { name: 'E' op: 'Zeta' attr { key: 'T' value { type: DT_FLOAT } }"
+ " input: ['A', 'D'] }");
+ EXPECT_EQ(DoMklLayoutOptimizationPass(),
+ "A(Input);B(Int32Input);C(Int32Input);"
+ "D(_MklSlice);DMT/_0(Const);DMT/_1(Const);DMT/"
+ "_2(Const);E(Zeta)|A->D;A->E;"
+ "A:control->DMT/_0:control;A:control->DMT/"
+ "_1:control;A:control->DMT/_2:control;"
+ "B->D:1;C->D:2;D->E:1;DMT/_0->D:3;DMT/_1->D:4;DMT/_2->D:5");
+}
+
/////////////////////////////////////////////////////////////////////
// Post-rewrite fixup pass test
@@ -3586,4 +3606,4 @@ BENCHMARK(BM_MklLayoutRewritePass)->Arg(1000)->Arg(10000);
} // namespace tensorflow
-#endif /* INTEL_MKL */
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass.cc b/tensorflow/core/graph/mkl_tfconversion_pass.cc
index aa39af637f..6804ab84ce 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/util/util.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
@@ -133,7 +134,9 @@ class MklToTfConversionPass : public GraphOptimizationPass {
// complete picture of inputs and outputs of the nodes in the graphs.
const OptimizationPassRegistry::Grouping kMklTfConvPassGroup =
OptimizationPassRegistry::POST_PARTITIONING;
+#ifdef ENABLE_MKL
REGISTER_OPTIMIZATION(kMklTfConvPassGroup, 2, MklToTfConversionPass);
+#endif // ENABLE_MKL
Status MklToTfConversionPass::InsertConversionNodeOnEdge(
std::unique_ptr<Graph>* g, Edge* e) {
@@ -175,7 +178,11 @@ Status MklToTfConversionPass::InsertConversionNodeOnEdge(
.Finalize(&**g, &conversion_node));
CHECK_NOTNULL(conversion_node);
- if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK()) {
+ // TODO(Intel-tf) MklToTf accepts only NHWC or NCHW, but doesn't seem to be
+ // using data_format. This code might be redundant.
+ if (GetNodeAttr(src->def(), "data_format", &data_format) == Status::OK() &&
+ (data_format == ToString(FORMAT_NHWC) ||
+ data_format == ToString(FORMAT_NCHW))) {
conversion_node->AddAttr("data_format", data_format);
}
@@ -254,9 +261,13 @@ Status MklToTfConversionPass::InsertInputConversionNode(
}
}
+ // TODO(Intel-tf) MklInputConversion accepts only NHWC or NCHW, but doesn't
+ // seem to be using data_format. This code might be redundant.
string data_format;
if (GetNodeAttr(edges[0]->src()->def(), "data_format", &data_format) ==
- Status::OK()) {
+ Status::OK() &&
+ (data_format == ToString(FORMAT_NHWC) ||
+ data_format == ToString(FORMAT_NCHW))) {
conversion_node->AddAttr("data_format", data_format);
}
@@ -414,6 +425,10 @@ Status MklToTfConversionPass::Run(const GraphOptimizationPassOptions& options) {
if (options.graph == nullptr && options.partition_graphs == nullptr) {
return Status::OK();
}
+ if (DisableMKL()) {
+ VLOG(2) << "TF-MKL: Disabling MKL";
+ return Status::OK();
+ }
auto process_graph = [&](std::unique_ptr<Graph>* g) {
// Get the ownership of graph
diff --git a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
index ebcb6de551..319437a801 100644
--- a/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
+++ b/tensorflow/core/graph/mkl_tfconversion_pass_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#include "tensorflow/core/graph/mkl_tfconversion_pass.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
@@ -304,4 +304,4 @@ BENCHMARK(BM_RunMklToTfConversionPass)->Arg(1000)->Arg(10000);
} // namespace
} // namespace tensorflow
-#endif /* INTEL_MKL */
+#endif // INTEL_MKL && ENABLE_MKL
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index 03f3bbd663..d92874909f 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -30,7 +30,7 @@ NodeBuilder::NodeOut::NodeOut(Node* n, int32 i) // NOLINT(runtime/explicit)
dt(SafeGetOutput(node, i, &error)) {}
NodeBuilder::NodeOut::NodeOut(StringPiece n, int32 i, DataType t)
- : node(nullptr), error(false), name(std::string(n)), index(i), dt(t) {}
+ : node(nullptr), error(false), name(n), index(i), dt(t) {}
NodeBuilder::NodeOut::NodeOut()
: node(nullptr), error(true), index(0), dt(DT_FLOAT) {}
@@ -99,6 +99,11 @@ NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
return *this;
}
+NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
+ assigned_device_ = string(device);
+ return *this;
+}
+
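A brief usage sketch for the new setter, assuming an existing tensorflow::Graph* g inside a file that already has the NodeBuilder includes; it only illustrates chaining AssignedDevice() ahead of Finalize():

    // Illustrative only: build a node and pin its assigned device.
    Node* n = nullptr;
    TF_CHECK_OK(NodeBuilder(g->NewName("noop"), "NoOp")
                    .AssignedDevice("/job:localhost/replica:0/task:0/device:CPU:0")
                    .Finalize(g, &n));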
Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
// In case of error, set *created_node to nullptr.
if (created_node != nullptr) *created_node = nullptr;
@@ -115,6 +120,8 @@ Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
Node* node = graph->AddNode(node_def, &status);
if (!status.ok()) return status;
+ node->set_assigned_device_name(assigned_device_);
+
for (size_t i = 0; i < inputs_.size(); ++i) {
if (inputs_[i].node != nullptr) { // Skip back edges.
graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
index f6b7b5674b..d576985a23 100644
--- a/tensorflow/core/graph/node_builder.h
+++ b/tensorflow/core/graph/node_builder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_NODE_BUILDER_H_
-#define TENSORFLOW_GRAPH_NODE_BUILDER_H_
+#ifndef TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_
+#define TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_
#include <vector>
#include "tensorflow/core/framework/node_def_builder.h"
@@ -100,6 +100,9 @@ class NodeBuilder {
// "assigned device" in the Node).
NodeBuilder& Device(StringPiece device_spec);
+ // Sets the device name in the "assigned device" field in tensorflow::Node.
+ NodeBuilder& AssignedDevice(StringPiece device);
+
// Set the value of an attr. attr_name must match the name of one of
// attrs defined by the Op, and value must have the corresponding type
// (see SetAttrValue() in ../framework/attr_value_util.h for legal
@@ -141,6 +144,7 @@ class NodeBuilder {
std::vector<NodeOut> inputs_;
std::vector<Node*> control_inputs_;
std::vector<string> errors_;
+ string assigned_device_;
};
// IMPLEMENTATION -------------------------------------------------------------
@@ -160,4 +164,4 @@ NodeBuilder& NodeBuilder::Attr(StringPiece attr_name,
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_NODE_BUILDER_H_
+#endif // TENSORFLOW_CORE_GRAPH_NODE_BUILDER_H_
diff --git a/tensorflow/core/graph/optimizer_cse.h b/tensorflow/core/graph/optimizer_cse.h
index b8f3230c70..ef466fb788 100644
--- a/tensorflow/core/graph/optimizer_cse.h
+++ b/tensorflow/core/graph/optimizer_cse.h
@@ -15,8 +15,8 @@ limitations under the License.
// An optimization pass that performs common subexpression elimination.
-#ifndef TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_
-#define TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_
+#ifndef TENSORFLOW_CORE_GRAPH_OPTIMIZER_CSE_H_
+#define TENSORFLOW_CORE_GRAPH_OPTIMIZER_CSE_H_
#include <sys/types.h>
#include "tensorflow/core/graph/graph.h"
@@ -34,4 +34,4 @@ extern bool OptimizeCSE(Graph* g,
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_OPTIMIZER_CSE_H_
+#endif // TENSORFLOW_CORE_GRAPH_OPTIMIZER_CSE_H_
diff --git a/tensorflow/core/graph/quantize_training.h b/tensorflow/core/graph/quantize_training.h
index 2bb4ee1cf0..dc3d7e3b1f 100644
--- a/tensorflow/core/graph/quantize_training.h
+++ b/tensorflow/core/graph/quantize_training.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
-#define TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
+#ifndef TENSORFLOW_CORE_GRAPH_QUANTIZE_TRAINING_H_
+#define TENSORFLOW_CORE_GRAPH_QUANTIZE_TRAINING_H_
#include "tensorflow/core/graph/graph.h"
@@ -53,4 +53,4 @@ Status DoQuantizeTrainingOnGraphDef(const GraphDef& input_graphdef,
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_QUANTIZE_TRAINING_H_
+#endif // TENSORFLOW_CORE_GRAPH_QUANTIZE_TRAINING_H_
diff --git a/tensorflow/core/graph/subgraph.h b/tensorflow/core/graph/subgraph.h
index ba35846d93..3e99ff0c8c 100644
--- a/tensorflow/core/graph/subgraph.h
+++ b/tensorflow/core/graph/subgraph.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_SUBGRAPH_H_
-#define TENSORFLOW_GRAPH_SUBGRAPH_H_
+#ifndef TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_
+#define TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_
#include <string>
@@ -162,4 +162,4 @@ class SendFetchRewrite : public PruneRewrite {
} // namespace subgraph
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_SUBGRAPH_H_
+#endif // TENSORFLOW_CORE_GRAPH_SUBGRAPH_H_
diff --git a/tensorflow/core/graph/tensor_id.cc b/tensorflow/core/graph/tensor_id.cc
index 80c76df255..5a5b85e727 100644
--- a/tensorflow/core/graph/tensor_id.cc
+++ b/tensorflow/core/graph/tensor_id.cc
@@ -25,7 +25,7 @@ namespace tensorflow {
TensorId::TensorId(const SafeTensorId& id) : TensorId(id.first, id.second) {}
SafeTensorId::SafeTensorId(const TensorId& id)
- : SafeTensorId(id.first.ToString(), id.second) {}
+ : SafeTensorId(string(id.first), id.second) {}
TensorId ParseTensorName(const string& name) {
return ParseTensorName(StringPiece(name.data(), name.size()));
diff --git a/tensorflow/core/graph/testlib.cc b/tensorflow/core/graph/testlib.cc
index 67b252cb6c..0a38aa1c91 100644
--- a/tensorflow/core/graph/testlib.cc
+++ b/tensorflow/core/graph/testlib.cc
@@ -21,39 +21,14 @@ limitations under the License.
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
-#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/kernels/constant_op.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
-
-// HostConst: forced to generate output on the host.
-// Only used by testlib; no op is registered for this kernel
-// externally (i.e., in array_ops.cc)
-REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), HostConstantOp);
-REGISTER_KERNEL_BUILDER(
- Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), HostConstantOp);
-#ifdef TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(
- Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"), HostConstantOp);
-#endif // TENSORFLOW_USE_SYCL
-
-// Register the HostConst Op
-// Returns a constant tensor on the host. Useful for writing C++ tests
-// and benchmarks which run on GPU but require arguments pinned to the host.
-// Used by test::graph::HostConstant.
-// value: Attr `value` is the tensor to return.
-REGISTER_OP("HostConst")
- .Output("output: dtype")
- .Attr("value: tensor")
- .Attr("dtype: type")
- .SetShapeFn(shape_inference::UnknownShape);
-
namespace test {
namespace graph {
@@ -510,6 +485,33 @@ Node* DiagPart(Graph* g, Node* in, DataType type) {
return ret;
}
+Node* CheckNumerics(Graph* g, Node* in, const string& message) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "CheckNumerics")
+ .Input(in)
+ .Attr("message", message)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Arg(Graph* g, int64 index, DataType type) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Arg")
+ .Attr("T", type)
+ .Attr("index", index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
+Node* Retval(Graph* g, int64 index, Node* in) {
+ Node* ret;
+ TF_CHECK_OK(NodeBuilder(g->NewName("n"), "_Retval")
+ .Input(in)
+ .Attr("index", index)
+ .Finalize(g, &ret));
+ return ret;
+}
+
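A short usage sketch for the new test helpers, assuming the usual testlib includes and OpRegistry::Global(); it wires an _Arg directly into a _Retval inside a test graph:

    // Illustrative only.
    Graph g(OpRegistry::Global());
    Node* arg = test::graph::Arg(&g, /*index=*/0, DT_FLOAT);
    Node* ret = test::graph::Retval(&g, /*index=*/0, arg);
    GraphDef gdef;
    test::graph::ToGraphDef(&g, &gdef);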
void ToGraphDef(Graph* g, GraphDef* gdef) { g->ToGraphDef(gdef); }
} // end namespace graph
diff --git a/tensorflow/core/graph/testlib.h b/tensorflow/core/graph/testlib.h
index eb9038d619..b00196f587 100644
--- a/tensorflow/core/graph/testlib.h
+++ b/tensorflow/core/graph/testlib.h
@@ -15,8 +15,8 @@ limitations under the License.
// DEPRECATED: Use the C++ API defined in tensorflow/cc instead.
-#ifndef TENSORFLOW_GRAPH_TESTLIB_H_
-#define TENSORFLOW_GRAPH_TESTLIB_H_
+#ifndef TENSORFLOW_CORE_GRAPH_TESTLIB_H_
+#define TENSORFLOW_CORE_GRAPH_TESTLIB_H_
#include <string>
#include <vector>
@@ -32,7 +32,7 @@ namespace test {
namespace graph {
// Converts "g" into its corresponding GraphDef "def".
-// DEPRECATED: call g->ToGraphDef(def) instead.
+ABSL_DEPRECATED("Call g->ToGraphDef(def) instead.")
void ToGraphDef(Graph* g, GraphDef* def);
// A few helpers to construct a graph.
@@ -209,8 +209,17 @@ Node* Diag(Graph* g, Node* in, DataType type);
// Add a DiagPart node in "g".
Node* DiagPart(Graph* g, Node* in, DataType type);
+// Add a CheckNumerics node in "g".
+Node* CheckNumerics(Graph* g, Node* in, const string& message);
+
+// Add an _Arg node in "g".
+Node* Arg(Graph* g, int64 index, DataType type);
+
+// Add a _Retval node in "g".
+Node* Retval(Graph* g, int64 index, Node* in);
+
} // end namespace graph
} // end namespace test
} // end namespace tensorflow
-#endif // TENSORFLOW_GRAPH_TESTLIB_H_
+#endif // TENSORFLOW_CORE_GRAPH_TESTLIB_H_
diff --git a/tensorflow/core/graph/types.h b/tensorflow/core/graph/types.h
index c707809927..ac5a7f8229 100644
--- a/tensorflow/core/graph/types.h
+++ b/tensorflow/core/graph/types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_TYPES_H_
-#define TENSORFLOW_GRAPH_TYPES_H_
+#ifndef TENSORFLOW_CORE_GRAPH_TYPES_H_
+#define TENSORFLOW_CORE_GRAPH_TYPES_H_
#include "tensorflow/core/lib/gtl/int_type.h"
#include "tensorflow/core/platform/types.h"
@@ -32,4 +32,4 @@ TF_LIB_GTL_DEFINE_INT_TYPE(Bytes, int64);
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_TYPES_H_
+#endif // TENSORFLOW_CORE_GRAPH_TYPES_H_
diff --git a/tensorflow/core/graph/while_context.cc b/tensorflow/core/graph/while_context.cc
index 1b38aac35d..8e89bc4c75 100644
--- a/tensorflow/core/graph/while_context.cc
+++ b/tensorflow/core/graph/while_context.cc
@@ -23,7 +23,7 @@ WhileContext::WhileContext(StringPiece frame_name,
OutputTensor cond_output,
std::vector<OutputTensor> body_inputs,
std::vector<OutputTensor> body_outputs)
- : frame_name_(std::string(frame_name)),
+ : frame_name_(frame_name),
enter_nodes_(std::move(enter_nodes)),
exit_nodes_(std::move(exit_nodes)),
cond_output_(cond_output),
diff --git a/tensorflow/core/graph/while_context.h b/tensorflow/core/graph/while_context.h
index 2a83eb7bd8..5405e62be2 100644
--- a/tensorflow/core/graph/while_context.h
+++ b/tensorflow/core/graph/while_context.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
-#define TENSORFLOW_GRAPH_WHILE_CONTEXT_H_
+#ifndef TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_
+#define TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_
#include "tensorflow/core/graph/graph.h"
@@ -73,4 +73,4 @@ class WhileContext {
} // namespace tensorflow
-#endif // TENSORFLOW_GRAPH_GRAPH_H_
+#endif // TENSORFLOW_CORE_GRAPH_WHILE_CONTEXT_H_
diff --git a/tensorflow/core/grappler/clusters/cluster.cc b/tensorflow/core/grappler/clusters/cluster.cc
index 6ca379323e..3b1d7d8347 100644
--- a/tensorflow/core/grappler/clusters/cluster.cc
+++ b/tensorflow/core/grappler/clusters/cluster.cc
@@ -81,6 +81,9 @@ void Cluster::DisableOptimizer(bool disable) {
rewriter_config->set_dependency_optimization(RewriterConfig::OFF);
rewriter_config->set_constant_folding(RewriterConfig::OFF);
rewriter_config->set_memory_optimization(RewriterConfig::NO_MEM_OPT);
+ rewriter_config->set_shape_optimization(RewriterConfig::OFF);
+ rewriter_config->set_remapping(RewriterConfig::OFF);
+ rewriter_config->set_pin_to_host_optimization(RewriterConfig::OFF);
rewriter_config->mutable_auto_parallel()->set_enable(false);
rewriter_config->clear_optimizers();
} else {
diff --git a/tensorflow/core/grappler/clusters/single_machine.cc b/tensorflow/core/grappler/clusters/single_machine.cc
index b97603c890..e4f6bf7c86 100644
--- a/tensorflow/core/grappler/clusters/single_machine.cc
+++ b/tensorflow/core/grappler/clusters/single_machine.cc
@@ -93,13 +93,13 @@ Status SingleMachine::Provision() {
strings::StrCat("Not able to parse GPU device name: ", dev.name()));
}
TfGpuId tf_gpu_id(parsed.id);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (!s.ok()) {
return errors::Unavailable("Unknown TF GPU device with id ",
tf_gpu_id.value(), ": ", s.ToString());
}
- attr = GetLocalGPUInfo(cuda_gpu_id);
+ attr = GetLocalGPUInfo(platform_gpu_id);
} else if (dev.device_type().find("XLA") == string::npos) {
// Filter out the fake XLA devices to avoid double counting the actual
// hardware resources that are available.
diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc
index a7519725a5..567e7c075e 100644
--- a/tensorflow/core/grappler/clusters/utils.cc
+++ b/tensorflow/core/grappler/clusters/utils.cc
@@ -70,13 +70,14 @@ DeviceProperties GetLocalCPUInfo() {
return device;
}
-DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id) {
+DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id) {
DeviceProperties device;
device.set_type("GPU");
#if GOOGLE_CUDA
cudaDeviceProp properties;
- cudaError_t error = cudaGetDeviceProperties(&properties, cuda_gpu_id.value());
+ cudaError_t error =
+ cudaGetDeviceProperties(&properties, platform_gpu_id.value());
if (error != cudaSuccess) {
device.set_type("UNKNOWN");
LOG(ERROR) << "Failed to get device properties, error code: " << error;
@@ -122,15 +123,15 @@ DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device) {
} else if (device.type == "GPU") {
if (device.has_id) {
TfGpuId tf_gpu_id(device.id);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (!s.ok()) {
LOG(ERROR) << s;
return unknown;
}
- return GetLocalGPUInfo(cuda_gpu_id);
+ return GetLocalGPUInfo(platform_gpu_id);
} else {
- return GetLocalGPUInfo(CudaGpuId(0));
+ return GetLocalGPUInfo(PlatformGpuId(0));
}
}
return unknown;
diff --git a/tensorflow/core/grappler/clusters/utils.h b/tensorflow/core/grappler/clusters/utils.h
index ca15c48006..f0a342b728 100644
--- a/tensorflow/core/grappler/clusters/utils.h
+++ b/tensorflow/core/grappler/clusters/utils.h
@@ -28,7 +28,7 @@ DeviceProperties GetLocalCPUInfo();
// Returns the DeviceProperties for the specified GPU attached to the server on
// which grappler is running.
-DeviceProperties GetLocalGPUInfo(CudaGpuId cuda_gpu_id);
+DeviceProperties GetLocalGPUInfo(PlatformGpuId platform_gpu_id);
// Returns the DeviceProperties of the specified device
DeviceProperties GetDeviceInfo(const DeviceNameUtils::ParsedName& device);
diff --git a/tensorflow/core/grappler/clusters/utils_test.cc b/tensorflow/core/grappler/clusters/utils_test.cc
index 74218adbac..3863d62980 100644
--- a/tensorflow/core/grappler/clusters/utils_test.cc
+++ b/tensorflow/core/grappler/clusters/utils_test.cc
@@ -31,22 +31,22 @@ TEST(UtilsTest, GetLocalGPUInfo) {
LOG(INFO) << "CUDA is enabled.";
DeviceProperties properties;
- // Invalid CUDA GPU ID.
- properties = GetLocalGPUInfo(CudaGpuId(100));
+ // Invalid platform GPU ID.
+ properties = GetLocalGPUInfo(PlatformGpuId(100));
EXPECT_EQ("UNKNOWN", properties.type());
- // Succeed when a valid CUDA GPU id was inserted.
- properties = GetLocalGPUInfo(CudaGpuId(0));
+ // Succeed when a valid platform GPU id was inserted.
+ properties = GetLocalGPUInfo(PlatformGpuId(0));
EXPECT_EQ("GPU", properties.type());
EXPECT_EQ("NVIDIA", properties.vendor());
#else
LOG(INFO) << "CUDA is not enabled.";
DeviceProperties properties;
- properties = GetLocalGPUInfo(CudaGpuId(0));
+ properties = GetLocalGPUInfo(PlatformGpuId(0));
EXPECT_EQ("GPU", properties.type());
- properties = GetLocalGPUInfo(CudaGpuId(100));
+ properties = GetLocalGPUInfo(PlatformGpuId(100));
EXPECT_EQ("GPU", properties.type());
#endif
}
@@ -74,20 +74,20 @@ TEST(UtilsTest, GetDeviceInfo) {
EXPECT_EQ("NVIDIA", properties.vendor());
#endif
- // TF to CUDA GPU id mapping entry doesn't exist.
+ // TF to platform GPU id mapping entry doesn't exist.
device.has_id = true;
device.id = 0;
properties = GetDeviceInfo(device);
EXPECT_EQ("UNKNOWN", properties.type());
#if GOOGLE_CUDA
- // Invalid CUDA GPU id.
- GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(0), CudaGpuId(100));
+ // Invalid platform GPU id.
+ GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(0), PlatformGpuId(100));
properties = GetDeviceInfo(device);
EXPECT_EQ("UNKNOWN", properties.type());
- // Valid CUDA GPU id.
- GpuIdManager::InsertTfCudaGpuIdPair(TfGpuId(1), CudaGpuId(0));
+ // Valid platform GPU id.
+ GpuIdManager::InsertTfPlatformGpuIdPair(TfGpuId(1), PlatformGpuId(0));
device.id = 1;
properties = GetDeviceInfo(device);
EXPECT_EQ("GPU", properties.type());
diff --git a/tensorflow/core/grappler/clusters/virtual_cluster.cc b/tensorflow/core/grappler/clusters/virtual_cluster.cc
index 12e3e46f65..f543dca49e 100644
--- a/tensorflow/core/grappler/clusters/virtual_cluster.cc
+++ b/tensorflow/core/grappler/clusters/virtual_cluster.cc
@@ -45,6 +45,8 @@ VirtualCluster::VirtualCluster(const DeviceSet* device_set)
for (const auto& device : device_set_->devices()) {
DeviceProperties props = GetDeviceInfo(device->parsed_name());
if (props.type() == "UNKNOWN") continue;
+ auto attrs = device->attributes();
+ props.set_memory_size(attrs.memory_limit());
devices_[device->name()] = props;
}
}
diff --git a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
index a60e3c7a9f..0690640ffa 100644
--- a/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/analytical_cost_estimator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <limits>
#include <unordered_map>
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
diff --git a/tensorflow/core/grappler/costs/graph_memory.cc b/tensorflow/core/grappler/costs/graph_memory.cc
index a5736d40b1..b01aca610a 100644
--- a/tensorflow/core/grappler/costs/graph_memory.cc
+++ b/tensorflow/core/grappler/costs/graph_memory.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/step_stats.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index 231c7c63be..56c8339d57 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -259,13 +260,13 @@ typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
}
bool IsEnqueue(const NodeDef& n) {
- return (n.op().find("Enqueue") != std::string::npos &&
- n.op().find("EnqueueMany") == std::string::npos);
+ return (n.op().find("Enqueue") != string::npos &&
+ n.op().find("EnqueueMany") == string::npos);
}
bool IsDequeue(const NodeDef& n) {
- return (n.op().find("Dequeue") != std::string::npos &&
- n.op().find("DequeueMany") == std::string::npos);
+ return (n.op().find("Dequeue") != string::npos &&
+ n.op().find("DequeueMany") == string::npos);
}
bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
@@ -344,6 +345,56 @@ void VerboseLogUnknownDimensionSources(
}
}
+bool IsShapeFullyDefinedIntegerVectorOrScalar(
+ InferenceContext* ic, const ShapeHandle& shape,
+ const ShapeHandle& tensor_as_shape, const DataType& dtype) {
+ if (!ic->FullyDefined(shape) || ic->Rank(shape) > 1 ||
+ !ic->FullyDefined(tensor_as_shape) ||
+ (dtype != DT_INT32 && dtype != DT_INT64)) {
+ return false;
+ }
+ return true;
+}
+
+// Returned tensor's shape is like `shape`, and its values and dtype are from
+// `tensor_as_shape` and `dtype`.
+TensorProto MakeTensorProtoFromShape(InferenceContext* ic,
+ const ShapeHandle& shape,
+ const ShapeHandle& tensor_as_shape,
+ const DataType& dtype) {
+ TensorProto tensor_proto;
+ tensor_proto.set_dtype(dtype);
+ auto* shape_proto = tensor_proto.mutable_tensor_shape();
+ if (ic->Rank(shape) == 1) {
+ shape_proto->add_dim()->set_size(ic->Rank(tensor_as_shape));
+ }
+ // For a scalar tensor, tensor_shape field will be left empty; no dim.
+ for (int i = 0; i < ic->Rank(tensor_as_shape); i++) {
+ int64 value = ic->Value(ic->Dim(tensor_as_shape, i));
+ if (dtype == DT_INT32) {
+ tensor_proto.add_int_val(value);
+ } else {
+ tensor_proto.add_int64_val(value);
+ }
+ }
+ return tensor_proto;
+}
+
+// Returns a Const NodeDef with shape = `shape`, values = `tensor_as_shape`,
+// and dtype = `dtype`.
+NodeDef MakeConstNodeDefFromShape(InferenceContext* ic,
+ const ShapeHandle& shape,
+ const ShapeHandle& tensor_as_shape,
+ const DataType& dtype) {
+ NodeDef const_node;
+ const_node.set_name("const_from_shape");
+ const_node.set_op("Const");
+ auto* attr = const_node.mutable_attr();
+ (*attr)["dtype"].set_type(dtype);
+ auto* tensor = (*attr)["value"].mutable_tensor();
+ *tensor = MakeTensorProtoFromShape(ic, shape, tensor_as_shape, dtype);
+ return const_node;
+}
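As a concrete illustration of the two helpers above, for a fully defined rank-1 tensor_as_shape holding {5, 7} with dtype DT_INT32 (values chosen only for this example), the emitted Const NodeDef is equivalent to the following hand-built one:

    // Illustrative only: what MakeConstNodeDefFromShape produces for [5, 7].
    NodeDef const_node;
    const_node.set_name("const_from_shape");
    const_node.set_op("Const");
    (*const_node.mutable_attr())["dtype"].set_type(DT_INT32);
    TensorProto* t = (*const_node.mutable_attr())["value"].mutable_tensor();
    t->set_dtype(DT_INT32);
    t->mutable_tensor_shape()->add_dim()->set_size(2);  // shape [2]
    t->add_int_val(5);
    t->add_int_val(7);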
} // namespace
// Queue of nodes to process. Nodes can be enqueued in any order, but will be
@@ -428,18 +479,22 @@ class SymbolicShapeRefiner {
// perform shape inference on the function body.
//
// Propagate shape information of final function body node
- // to function node `node`.
+ // to function node `function_node`.
//
- // In the event of an error, UpdateNode will simply set `node`'s
+ // In the event of an error, UpdateNode will simply set `function_node`'s
// output shape to be Unknown.
- Status UpdateFunction(const NodeDef* node) {
- auto it = fun_to_grappler_function_item_.find(node->op());
+ Status UpdateFunction(const NodeDef* function_node) {
+ auto it = fun_to_grappler_function_item_.find(function_node->op());
if (it == fun_to_grappler_function_item_.end()) {
return errors::InvalidArgument(
- node->op(), " was not previously added to SymbolicShapeRefiner.");
+ function_node->op(),
+ " was not previously added to SymbolicShapeRefiner.");
}
- GrapplerFunctionItem& grappler_function_item = it->second;
+ // Copy (not reference) so that changes we make here (e.g., replacing
+ // Placeholder with Const) don't affect the one in
+ // fun_to_grappler_function_item_.
+ GrapplerFunctionItem grappler_function_item = it->second;
GraphView gv(&grappler_function_item.graph);
// Forward shapes from function input nodes to argument nodes.
@@ -452,7 +507,7 @@ class SymbolicShapeRefiner {
"supported.");
}
NodeDef* fun_node = gv.GetNode(fun_input.input_name);
- const string& input = node->input(i);
+ const string& input = function_node->input(i);
const string& node_name = NodeName(input);
if (IsControlInput(input)) {
@@ -477,17 +532,48 @@ class SymbolicShapeRefiner {
TensorShapeProto proto;
const auto& handle = input_inference_context->output(output_port_num);
input_inference_context->ShapeHandleToProto(handle, &proto);
+ // There may be dim.size < -1 in SymbolicShapeRefiner. Change those to -1.
+ for (int i = 0; i < proto.dim_size(); i++) {
+ if (proto.dim(i).size() < -1) {
+ proto.mutable_dim(i)->set_size(-1);
+ }
+ }
*attr_output_shape.mutable_shape() = proto;
(*fun_node->mutable_attr())["shape"] = attr_output_shape;
}
+ // Replace input Placeholders with Consts, if values are known. Note that
+ // we don't re-check error conditions here, as that is done in the loop above.
+ auto* ctx = GetNodeContext(function_node);
+ auto* ic = ctx->inference_context.get();
+ for (int i = grappler_function_item.inputs().size() - 1; i >= 0; --i) {
+ const string& input = function_node->input(i);
+ const string& node_name = NodeName(input);
+ NodeDef* input_node = graph_.GetNode(node_name);
+ if (IsConstant(*input_node)) {
+ TF_CHECK_OK(
+ ReplaceInputWithConst(*input_node, i, &grappler_function_item));
+ } else if (ic->input_tensors_as_shapes().size() > i &&
+ IsShapeFullyDefinedIntegerVectorOrScalar(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i])) {
+ // We have fully defined input_tensors_as_shapes for this input; use it
+ // as a const input to the function node.
+ NodeDef const_input_node = MakeConstNodeDefFromShape(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i]);
+ TF_CHECK_OK(ReplaceInputWithConst(const_input_node, i,
+ &grappler_function_item));
+ }
+ }
+
// Perform inference on function body.
GraphProperties gp(grappler_function_item);
TF_RETURN_IF_ERROR(gp.InferStatically(true));
// Add return nodes for output shapes.
- auto ic = GetContext(node);
int output = 0;
+ ctx->output_tensors_as_shapes.resize(grappler_function_item.output_size());
for (auto const& out_arg : grappler_function_item.outputs()) {
if (out_arg.output_tensors.size() > 1) {
// TODO(jmdecker): Handle case of multiple output tensors
@@ -504,8 +590,9 @@ class SymbolicShapeRefiner {
const NodeDef* retnode = gv.GetNode(node_name);
if (retnode == nullptr) {
- return errors::FailedPrecondition("Unable to find return node ",
- node_name, " for ", node->name());
+ return errors::FailedPrecondition(
+ "Unable to find return function_node ", node_name, " for ",
+ function_node->name());
}
auto output_properties = gp.GetOutputProperties(retnode->name());
@@ -519,6 +606,14 @@ class SymbolicShapeRefiner {
ShapeHandle out;
TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeProto(shape, &out));
ic->set_output(output, out);
+ if (outprop.has_value()) {
+ // Forward tensor value to output_tensors_as_shape.
+ Tensor tensor;
+ if (tensor.FromProto(outprop.value())) {
+ MaybeSetTensorValueToShape(ic, tensor,
+ &ctx->output_tensors_as_shapes[output]);
+ }
+ }
output++;
}
@@ -561,21 +656,9 @@ class SymbolicShapeRefiner {
if (const_values[dst_input].FromProto(
input->attr().at("value").tensor())) {
input_tensors[dst_input] = &const_values[dst_input];
- // Integer tensors of rank one can also be interpreted as a shape
- // provided all their values are >= -1.
- if (const_values[dst_input].dims() == 1 &&
- (const_values[dst_input].dtype() == DT_INT32 ||
- const_values[dst_input].dtype() == DT_INT64)) {
- ShapeHandle tensor_shape = inference_context->Vector(
- const_values[dst_input].NumElements());
- ShapeHandle shp;
- if (inference_context
- ->MakeShapeFromTensor(input_tensors[dst_input],
- tensor_shape, &shp)
- .ok()) {
- input_tensors_as_shapes[dst_input] = shp;
- }
- }
+ MaybeSetTensorValueToShape(inference_context,
+ const_values[dst_input],
+ &input_tensors_as_shapes[dst_input]);
}
} else if (IsRank(*input)) {
if (c->inference_context->RankKnown(c->inference_context->input(0))) {
@@ -670,11 +753,13 @@ class SymbolicShapeRefiner {
// true, as the updates to the call node will have changed, even if it's
// the same function being called twice with the same input shapes.
// Example: simple_function.pbtxt
- if (UpdateFunction(node).ok()) {
+ auto s = UpdateFunction(node);
+ if (s.ok()) {
return Status::OK();
} else {
VLOG(1) << "UpdateFunction failed for " << node->op()
- << ". Defaulting to ShapeUnknown.";
+ << ". Defaulting to ShapeUnknown.\n"
+ << s.ToString();
}
}
@@ -804,8 +889,9 @@ class SymbolicShapeRefiner {
CHECK_NOTNULL(function_library_.Find(function_node->op()));
GrapplerFunctionItem grappler_function_item;
- TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
- *function_def, function_library_, &grappler_function_item));
+ TF_RETURN_IF_ERROR(
+ MakeGrapplerFunctionItem(*function_def, function_library_,
+ graph_def_version_, &grappler_function_item));
if (grappler_function_item.inputs().size() > function_node->input_size()) {
return errors::FailedPrecondition(
@@ -940,13 +1026,25 @@ class SymbolicShapeRefiner {
: t->scalar<int64>()();
dims.push_back(size < 0 ? ic->UnknownDim() : ic->MakeDim(size));
} else {
- dims.push_back(ic->UnknownDim());
+ // No tensor value available; fall back to input_tensors_as_shapes if
+ // possible.
+ const ShapeHandle& shape_handle = ic->input_tensors_as_shapes()[i];
+ if (ic->RankKnown(shape_handle) && ic->Rank(shape_handle) >= 1 &&
+ ic->ValueKnown(ic->Dim(shape_handle, 0))) {
+ dims.push_back(ic->Dim(shape_handle, 0));
+ } else {
+ dims.push_back(ic->UnknownDim());
+ }
}
}
if (valid) {
c->output_tensors_as_shapes.resize(1);
c->output_tensors_as_shapes[0] = ic->MakeShape(dims);
}
+ } else if (IsIdentity(node)) {
+ // Pass input_tensors_as_shapes to output_tensors_as_shapes.
+ c->output_tensors_as_shapes.resize(1);
+ c->output_tensors_as_shapes[0] = ic->input_tensors_as_shapes()[0];
} else if (IsSlice(node)) {
ShapeHandle input = ic->input_tensors_as_shapes()[0];
bool valid = ic->RankKnown(input);
@@ -1051,6 +1149,46 @@ class SymbolicShapeRefiner {
}
private:
+ bool IsIntegerVector(const Tensor& tensor) {
+ if (tensor.dims() == 1 &&
+ (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64)) {
+ return true;
+ }
+ return false;
+ }
+
+ bool IsIntegerScalar(const Tensor& tensor) {
+ if (tensor.dims() == 0 &&
+ (tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64) &&
+ tensor.NumElements() == 1) {
+ return true;
+ }
+ return false;
+ }
+
+ void MaybeSetTensorValueToShape(InferenceContext* ic, const Tensor& tensor,
+ ShapeHandle* tensors_as_shapes) {
+ // Integer tensors of rank one can also be interpreted as a shape
+ // provided all their values are >= -1.
+ if (IsIntegerVector(tensor)) {
+ ShapeHandle tensor_shape = ic->Vector(tensor.NumElements());
+ ShapeHandle shp;
+ // Note that MakeShapeFromTensor filters out invalid values (e.g., < -1).
+ if (ic->MakeShapeFromTensor(&tensor, tensor_shape, &shp).ok()) {
+ *tensors_as_shapes = shp;
+ }
+ } else if (IsIntegerScalar(tensor)) {
+ // Scalar constant.
+ int64 value = tensor.dtype() == DT_INT32 ? tensor.flat<int32>()(0)
+ : tensor.flat<int64>()(0);
+ // Ideally, values can be < -1, but MakeDim() fails with a value < -1.
+ // It's a limitation as we use ShapeHandle as a means to pass values.
+ if (value >= -1) {
+ *tensors_as_shapes = ic->MakeShape({ic->MakeDim(value)});
+ }
+ }
+ }
+
const GraphView& graph_;
int graph_def_version_;
std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
@@ -1526,6 +1664,8 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
continue;
}
+ auto* ic = ctx->inference_context.get();
+
// Fill input properties.
{
auto& input_properties = input_properties_[node.name()];
@@ -1533,19 +1673,26 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// Should always be empty, node names in graph are supposed to be unique.
CHECK_EQ(input_properties.size(), 0);
- input_properties.resize(ctx->inference_context->num_inputs());
+ input_properties.resize(ic->num_inputs());
GraphView::InputPort input(&node, -1);
- for (int i = 0; i < ctx->inference_context->num_inputs(); ++i) {
- shape_manager.AsTensorProperties(ctx->inference_context->input(i),
- ctx->input_types[i],
+ for (int i = 0; i < ic->num_inputs(); ++i) {
+ shape_manager.AsTensorProperties(ic->input(i), ctx->input_types[i],
&input_properties[i]);
input.port_id = i;
GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
- if (!IsConstant(*fanin.node)) {
- continue;
+ // Export tensor value (either const tensor or input_tensors_as_shapes)
+ // to input_properties.value.
+ if (IsConstant(*fanin.node)) {
+ const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
+ *input_properties[i].mutable_value() = raw_val;
+ } else if (ic->input_tensors_as_shapes().size() > i &&
+ IsShapeFullyDefinedIntegerVectorOrScalar(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i])) {
+ *input_properties[i].mutable_value() = MakeTensorProtoFromShape(
+ ic, ic->input(i), ic->input_tensors_as_shapes()[i],
+ ctx->input_types[i]);
}
- const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
- *input_properties[i].mutable_value() = raw_val;
}
}
@@ -1556,11 +1703,23 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// Should always be empty, node names in graph are supposed to be unique.
CHECK_EQ(output_properties.size(), 0);
- output_properties.resize(ctx->inference_context->num_outputs());
- for (int i = 0; i < ctx->inference_context->num_outputs(); ++i) {
- shape_manager.AsTensorProperties(ctx->inference_context->output(i),
- ctx->output_types[i],
+ output_properties.resize(ic->num_outputs());
+ for (int i = 0; i < ic->num_outputs(); ++i) {
+ shape_manager.AsTensorProperties(ic->output(i), ctx->output_types[i],
&output_properties[i]);
+ // Export tensor value (either const tensor or output_tensors_as_shapes)
+ // to output_properties.value.
+ if (IsConstant(node)) {
+ const TensorProto& raw_val = node.attr().at("value").tensor();
+ *output_properties[i].mutable_value() = raw_val;
+ } else if (ctx->output_tensors_as_shapes.size() > i &&
+ IsShapeFullyDefinedIntegerVectorOrScalar(
+ ic, ic->output(i), ctx->output_tensors_as_shapes[i],
+ ctx->output_types[i])) {
+ *output_properties[i].mutable_value() = MakeTensorProtoFromShape(
+ ic, ic->output(i), ctx->output_tensors_as_shapes[i],
+ ctx->output_types[i]);
+ }
}
}
}
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index f716cd72c9..28fd7565cc 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -74,6 +74,10 @@ class GraphProperties {
// shape information.
void ClearInputProperties(const string& node_name);
void ClearOutputProperties(const string& node_name);
+ // Returns true if we have *any* properties.
+ bool has_properties() const {
+ return input_properties_.size() > 0 || output_properties_.size() > 0;
+ }
private:
// Relaxes shapes <shapes_and_types>, determined from an EnqueueV2 node, into
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 5acfb56b05..db10f586bc 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -18,8 +18,10 @@ limitations under the License.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/single_machine.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
@@ -42,6 +44,30 @@ class GraphPropertiesTest : public ::testing::Test {
// Provision a single machine with 3 cpu cores
cluster_.reset(new SingleMachine(5 * 60, 3, 0));
TF_CHECK_OK(cluster_->Provision());
+
+ // This function is simply out = Fill(shape, value), but Fill requires the
+ // values of the shape input, not just its shape, to infer the output
+ // shape.
+ auto f = FunctionDefHelper::Create(
+ // Name
+ "MyFillFunc",
+ // Inputs
+ {"shape: int32", "value: float"},
+ // Outputs
+ {"out: float"},
+ // Attrs
+ {},
+ // Nodes
+ {
+ {{"a"},
+ "Fill",
+ {"shape", "value"},
+ {{"T", DataType::DT_FLOAT}, {"index_type", DataType::DT_INT32}}},
+ },
+ // Returns
+ {{"out", "a:output:0"}});
+ function_lib_.add_function()->Swap(&f);
}
void TearDown() override {
@@ -67,7 +93,29 @@ class GraphPropertiesTest : public ::testing::Test {
return s;
}
+ // Compare values of integer (DT_INT32 or DT_INT64) tensor against expected
+ // ones.
+ void ExpectTensorValues(const std::vector<int64>& expected,
+ const TensorProto& tensor_proto_to_compare) {
+ Tensor tensor;
+ EXPECT_TRUE(tensor.FromProto(tensor_proto_to_compare));
+ EXPECT_EQ(expected.size(), tensor.NumElements());
+ // We're interested in only integer tensors as only shapes are exported as
+ // graph properties values.
+ CHECK(tensor.dtype() == DT_INT32 || tensor.dtype() == DT_INT64);
+ if (tensor.dtype() == DT_INT32) {
+ for (int i = 0; i < tensor.NumElements(); i++) {
+ EXPECT_EQ(expected[i], tensor.flat<int32>()(i));
+ }
+ } else {
+ for (int i = 0; i < tensor.NumElements(); i++) {
+ EXPECT_EQ(expected[i], tensor.flat<int64>()(i));
+ }
+ }
+ }
+
std::unique_ptr<SingleMachine> cluster_;
+ FunctionDefLibrary function_lib_;
};
TEST_F(GraphPropertiesTest, StaticProperties) {
@@ -783,6 +831,259 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
EXPECT_EQ("float: [128,256]", PropToString(prop));
}
+TEST_F(GraphPropertiesTest, TensorAsShapesPropagation) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), {5, 7}, {2});
+ Output a1 = ops::Identity(s.WithOpName("a1"), a);
+ Output b = ops::Const(s.WithOpName("b"), 99, {});
+ Output b1 = ops::Identity(s.WithOpName("b1"), b);
+ Output c = ops::Const(s.WithOpName("c"), 1, {4, 4, 4});
+ Output c1 = ops::Identity(s.WithOpName("c1"), c);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+
+ // Check output shapes.
+ EXPECT_EQ("int32: [2]", PropToString(properties.GetOutputProperties("a")[0]));
+ EXPECT_EQ("int32: [2]",
+ PropToString(properties.GetOutputProperties("a1")[0]));
+ EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b")[0]));
+ EXPECT_EQ("int32: []", PropToString(properties.GetOutputProperties("b1")[0]));
+ EXPECT_EQ("int32: [4,4,4]",
+ PropToString(properties.GetOutputProperties("c")[0]));
+ EXPECT_EQ("int32: [4,4,4]",
+ PropToString(properties.GetOutputProperties("c1")[0]));
+
+ // Check has_value.
+ EXPECT_TRUE(properties.GetOutputProperties("a")[0].has_value());
+ EXPECT_TRUE(properties.GetInputProperties("a1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("a1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("b")[0].has_value());
+ EXPECT_TRUE(properties.GetInputProperties("b1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("b1")[0].has_value());
+ EXPECT_TRUE(properties.GetOutputProperties("c")[0].has_value());
+ EXPECT_TRUE(properties.GetInputProperties("c1")[0].has_value());
+ // Note that we propagate tensor values of only 1D vectors and scalars.
+ EXPECT_FALSE(properties.GetOutputProperties("c1")[0].has_value());
+
+ // Check values.
+ ExpectTensorValues({5, 7}, properties.GetOutputProperties("a")[0].value());
+ ExpectTensorValues({5, 7}, properties.GetInputProperties("a1")[0].value());
+ ExpectTensorValues({5, 7}, properties.GetOutputProperties("a1")[0].value());
+ ExpectTensorValues({99}, properties.GetOutputProperties("b")[0].value());
+ ExpectTensorValues({99}, properties.GetInputProperties("b1")[0].value());
+ ExpectTensorValues({99}, properties.GetOutputProperties("b1")[0].value());
+ std::vector<int64> c_values;
+ for (int i = 0; i < 4 * 4 * 4; i++) {
+ c_values.push_back(1);
+ }
+ ExpectTensorValues({c_values},
+ properties.GetOutputProperties("c")[0].value());
+ ExpectTensorValues({c_values},
+ properties.GetInputProperties("c1")[0].value());
+ // No output value for c1, as it's neither a 1D vector nor a scalar.
+}
+
+TEST_F(GraphPropertiesTest, IdentityPassingShape) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 5, {2});
+ Output b = ops::Identity(s.WithOpName("b"), a);
+ Output c = ops::Const(s.WithOpName("const"), 0.1f, {});
+ // Fill needs not only b's shape but also the value of b to figure out the
+ // output shape; hence, the Identity op (b) should pass a's value as
+ // output_tensors_as_shape.
+ Output d = ops::Fill(s.WithOpName("fill"), b, c);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [5,5]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, PackWithConstInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {});
+ Output b = ops::Const(s.WithOpName("b"), 2, {});
+ Output c = ops::Const(s.WithOpName("c"), 3, {});
+ Output d = ops::Const(s.WithOpName("d"), 4, {});
+ // Note ops::Stack instantiates Pack op.
+ Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
+ // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
+ Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
+ // Fill needs not only e's shape but also its value to figure out output
+ // shape.
+ Output g = ops::Fill(s.WithOpName("fill"), e, f);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, PackWithIdentityInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ // Same as the PackWithConstInput test case, but a, b, c, and d are Identity
+ // ops fed from Const.
+ // If output_tensors_as_shape is not set for those Identity ops, or the Pack
+ // op doesn't take input_tensors_as_shape, the Fill op's input has no value;
+ // hence, its output shape becomes unknown.
+ Output a0 = ops::Const(s.WithOpName("a0"), 1, {});
+ Output b0 = ops::Const(s.WithOpName("b0"), 2, {});
+ Output c0 = ops::Const(s.WithOpName("c0"), 3, {});
+ Output d0 = ops::Const(s.WithOpName("d0"), 4, {});
+ Output a = ops::Identity(s.WithOpName("a"), a0);
+ Output b = ops::Identity(s.WithOpName("b"), b0);
+ Output c = ops::Identity(s.WithOpName("c"), c0);
+ Output d = ops::Identity(s.WithOpName("d"), d0);
+ // Note ops::Stack instantiates Pack op.
+ Output e = ops::Stack(s.WithOpName("pack"), {a, b, c, d});
+ // e is rank 1 tensor: shape = {4}, and its value is {1, 2, 3, 4}
+ Output f = ops::Const(s.WithOpName("const"), 0.1f, {});
+ // Fill needs not only e's shape but also its value to figure out output
+ // shape.
+ Output g = ops::Fill(s.WithOpName("fill"), e, f);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("fill");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithConstInput) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_));
+ Output shape = ops::Const(s.WithOpName("shape"), {1, 2, 3, 4});
+ Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
+ auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
+ s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+ auto _value = tensorflow::ops::AsNodeOut(s, value);
+ TF_CHECK_OK(
+ builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyFillFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithIdentityOfConstInput) {
+ // Same as FunctionWithConstInput, but the function inputs are Identity of
+ // Const, so the values propagated as input_tensors_as_shapes, rather than a
+ // direct Const tensor value, are used to build the Const input to the
+ // function.
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib_));
+ Output shape_ = ops::Const(s.WithOpName("shape_"), {1, 2, 3, 4});
+ Output shape = ops::Identity(s.WithOpName("shape"), shape_);
+ Output value = ops::Const(s.WithOpName("value"), 0.1f, {});
+ auto builder = tensorflow::NodeBuilder("MyFillFunc", "MyFillFunc",
+ s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+ auto _value = tensorflow::ops::AsNodeOut(s, value);
+ TF_CHECK_OK(
+ builder.Input(_shape).Input(_value).Finalize(s.graph(), &func_op));
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(false));
+ const auto out_props = properties.GetOutputProperties("MyFillFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("float: [1,2,3,4]", PropToString(out_prop0));
+}
+
+TEST_F(GraphPropertiesTest, FunctionReturnTensorValue) {
+ FunctionDefLibrary library;
+ *library.add_function() = FunctionDefHelper::Create(
+ "MyFunc", // Name
+ {"x: int32"}, // Inputs
+ {"out: int32"}, // Outputs
+ {}, // Attrs
+ {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_INT32}}}}, // Nodes
+ {{"out", "a:output:0"}}); // Returns
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
+
+ // MyFunc takes a Const (shape) and passes it through Identity. Expect the
+ // function output to have the same shape and value (via
+ // output_tensors_as_shape) as the input Const tensor.
+ Output shape = ops::Const(s.WithOpName("shape"), {5, 7}, {2});
+ auto _shape = tensorflow::ops::AsNodeOut(s, shape);
+ auto builder =
+ tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ TF_CHECK_OK(builder.Input(_shape).Finalize(s.graph(), &func_op));
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(true));
+ const auto out_props = properties.GetOutputProperties("MyFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ("int32: [2]", PropToString(out_prop0));
+ EXPECT_TRUE(out_prop0.has_value());
+ ExpectTensorValues({5, 7}, out_prop0.value());
+ ExpectTensorValues({5, 7},
+ properties.GetInputProperties("MyFunc")[0].value());
+}
+
+TEST_F(GraphPropertiesTest, FunctionWithScalarInput) {
+ // Create a graph with a function that takes a scalar value, so that a
+ // Placeholder with scalar shape is used as the input to the function shape
+ // inference.
+ // Placeholder -> Identity -> MyFunc, where MyFunc simply takes Identity of
+ // the input; all tensors are scalars.
+ FunctionDefLibrary library;
+ *library.add_function() = FunctionDefHelper::Create(
+ "MyFunc", // Name
+ {"x: float"}, // Inputs
+ {"out: float"}, // Outputs
+ {}, // Attrs
+ {{{"a"}, "Identity", {"x"}, {{"T", DataType::DT_FLOAT}}}}, // Nodes
+ {{"out", "a:output:0"}}); // Returns
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ TF_CHECK_OK(s.graph()->AddFunctionLibrary(library));
+ Output placeholder =
+ ops::Placeholder(s.WithOpName("Placeholder"), DataType::DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({})));
+ Output identity = ops::Identity(s.WithOpName("Identity"), placeholder);
+ auto _identity = tensorflow::ops::AsNodeOut(s, identity);
+ auto builder =
+ tensorflow::NodeBuilder("MyFunc", "MyFunc", s.graph()->op_registry());
+ tensorflow::Node* func_op;
+ TF_CHECK_OK(builder.Input(_identity).Finalize(s.graph(), &func_op));
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ // TensorFlow versions < 21 infer the output shape of a Placeholder with an
+ // empty shape as unknown, instead of scalar.
+ EXPECT_GT(item.graph.versions().producer(), 21);
+
+ // MyFunc output shouldn't be unknown rank.
+ GraphProperties properties(item);
+ TF_CHECK_OK(properties.InferStatically(true));
+ const auto out_props = properties.GetOutputProperties("MyFunc");
+ const OpInfo::TensorProperties out_prop0 = out_props[0];
+ EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
+ EXPECT_FALSE(out_prop0.shape().unknown_rank());
+}
+
TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
// Test graph produced in python using:
/*
@@ -814,18 +1115,10 @@ TEST_F(GraphPropertiesTest, SimpleFunctionStaticShapeInference) {
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
@@ -840,51 +1133,25 @@ TEST_F(GraphPropertiesTest, LargeFunctionStaticShapeInference) {
EXPECT_EQ(2, out_props.size());
const OpInfo::TensorProperties& out_prop0 = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop0.dtype());
- EXPECT_EQ(4, out_prop0.shape().dim_size());
- EXPECT_EQ(128, out_prop0.shape().dim(0).size());
- EXPECT_EQ(112, out_prop0.shape().dim(1).size());
- EXPECT_EQ(112, out_prop0.shape().dim(2).size());
- EXPECT_EQ(64, out_prop0.shape().dim(3).size());
+ EXPECT_EQ("float: [128,112,112,64]", PropToString(out_prop0));
const OpInfo::TensorProperties& out_prop1 = out_props[1];
- EXPECT_EQ(DT_FLOAT, out_prop1.dtype());
- EXPECT_EQ(128, out_prop1.shape().dim(0).size());
- EXPECT_EQ(112, out_prop1.shape().dim(1).size());
- EXPECT_EQ(112, out_prop1.shape().dim(2).size());
- EXPECT_EQ(24, out_prop1.shape().dim(3).size());
+ EXPECT_EQ("float: [128,112,112,24]", PropToString(out_prop1));
const auto in_props = properties.GetInputProperties("y0");
EXPECT_EQ(4, in_props.size());
const OpInfo::TensorProperties& in_prop0 = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop0.dtype());
- EXPECT_EQ(1, in_prop0.shape().dim_size());
- EXPECT_EQ(64, in_prop0.shape().dim(0).size());
+ EXPECT_EQ("float: [64]", PropToString(in_prop0));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_EQ(4, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(1, in_prop1.shape().dim(1).size());
- EXPECT_EQ(24, in_prop1.shape().dim(2).size());
- EXPECT_EQ(64, in_prop1.shape().dim(3).size());
+ EXPECT_EQ("float: [1,1,24,64]", PropToString(in_prop1));
const OpInfo::TensorProperties& in_prop2 = in_props[2];
- EXPECT_EQ(DT_FLOAT, in_prop2.dtype());
- EXPECT_EQ(4, in_prop2.shape().dim_size());
- EXPECT_EQ(128, in_prop2.shape().dim(0).size());
- EXPECT_EQ(224, in_prop2.shape().dim(1).size());
- EXPECT_EQ(224, in_prop2.shape().dim(2).size());
- EXPECT_EQ(3, in_prop2.shape().dim(3).size());
+ EXPECT_EQ("float: [128,224,224,3]", PropToString(in_prop2));
const OpInfo::TensorProperties& in_prop3 = in_props[3];
- EXPECT_EQ(DT_FLOAT, in_prop3.dtype());
- EXPECT_EQ(4, in_prop3.shape().dim_size());
- EXPECT_EQ(7, in_prop3.shape().dim(0).size());
- EXPECT_EQ(7, in_prop3.shape().dim(1).size());
- EXPECT_EQ(3, in_prop3.shape().dim(2).size());
- EXPECT_EQ(8, in_prop3.shape().dim(3).size());
+ EXPECT_EQ("float: [7,7,3,8]", PropToString(in_prop3));
}
TEST_F(GraphPropertiesTest, LargeFunctionWithMultipleOutputs) {
@@ -944,18 +1211,10 @@ TEST_F(GraphPropertiesTest, FunctionWithErrorStaticShapeInference) {
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
@@ -980,27 +1239,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitchStaticShapeInference) {
const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
const OpInfo::TensorProperties& out_prop = out_props[0];
EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_FALSE(out_prop.shape().unknown_rank());
- EXPECT_EQ(2, out_prop.shape().dim_size());
- EXPECT_EQ(1, out_prop.shape().dim(0).size());
- EXPECT_EQ(2, out_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(out_prop));
const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
@@ -1024,28 +1272,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitch2StaticShapeInference) {
TF_CHECK_OK(properties.InferStatically(false));
const auto out_props = properties.GetOutputProperties("MyAdd_MPaeanipb7o");
const OpInfo::TensorProperties& out_prop = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_FALSE(out_prop.shape().unknown_rank());
- EXPECT_EQ(2, out_prop.shape().dim_size());
- EXPECT_EQ(1, out_prop.shape().dim(0).size());
- EXPECT_EQ(2, out_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(out_prop));
const auto in_props = properties.GetInputProperties("MyAdd_MPaeanipb7o");
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(2, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
@@ -1073,28 +1309,16 @@ TEST_F(GraphPropertiesTest, FunctionSwitchShapesStaticShapeInference) {
TF_CHECK_OK(properties.InferStatically(false));
const auto out_props = properties.GetOutputProperties("MyAdd_lEKAAnIwI5I");
const OpInfo::TensorProperties& out_prop = out_props[0];
- EXPECT_EQ(DT_FLOAT, out_prop.dtype());
- EXPECT_FALSE(out_prop.shape().unknown_rank());
- EXPECT_EQ(2, out_prop.shape().dim_size());
- EXPECT_EQ(1, out_prop.shape().dim(0).size());
- EXPECT_EQ(2, out_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(out_prop));
const auto in_props = properties.GetInputProperties("MyAdd_lEKAAnIwI5I");
EXPECT_EQ(2, in_props.size());
const OpInfo::TensorProperties& in_prop = in_props[0];
- EXPECT_EQ(DT_FLOAT, in_prop.dtype());
- EXPECT_FALSE(in_prop.shape().unknown_rank());
- EXPECT_EQ(2, in_prop.shape().dim_size());
- EXPECT_EQ(1, in_prop.shape().dim(0).size());
- EXPECT_EQ(2, in_prop.shape().dim(1).size());
+ EXPECT_EQ("float: [1,2]", PropToString(in_prop));
const OpInfo::TensorProperties& in_prop1 = in_props[1];
- EXPECT_EQ(DT_FLOAT, in_prop1.dtype());
- EXPECT_FALSE(in_prop1.shape().unknown_rank());
- EXPECT_EQ(2, in_prop1.shape().dim_size());
- EXPECT_EQ(1, in_prop1.shape().dim(0).size());
- EXPECT_EQ(3, in_prop1.shape().dim(1).size());
+ EXPECT_EQ("float: [1,3]", PropToString(in_prop1));
}
TEST_F(GraphPropertiesTest, SymbolicShapes) {
@@ -1116,6 +1340,8 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
Output zero = ops::Const(s.WithOpName("zero"), 0.0f, {});
Output g = ops::Shape(s.WithOpName("g"), c);
Output h = ops::Fill(s.WithOpName("h"), g, zero);
+ Output zero_idx = ops::Const(s.WithOpName("zero_idx"), {0}, {1});
+ Output j = ops::Sum(s.WithOpName("j"), a, zero_idx);
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
@@ -1158,6 +1384,10 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
ASSERT_EQ(2, shape_f.dim_size());
EXPECT_EQ(shape_h.dim(0).size(), shape_c.dim(0).size());
EXPECT_EQ(shape_h.dim(1).size(), shape_c.dim(1).size());
+
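+ // Reducing 'a' along dimension 0 should produce a vector whose size matches
+ // the symbolic size of dimension 1 of 'a'.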
+ const auto shape_j = properties.GetOutputProperties("j").at(0).shape();
+ ASSERT_EQ(1, shape_j.dim_size());
+ EXPECT_EQ(shape_j.dim(0).size(), shape_a.dim(1).size());
}
TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
index c94ee2f227..0ec95dd684 100644
--- a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
@@ -88,6 +88,13 @@ library {
}
}
}
+ attr {
+ key: "output_shapes"
+ value {
+ list {
+ }
+ }
+ }
}
ret {
key: "while"
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
index 6406a4bdbf..71f4d9fd05 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/clusters/utils.h"
@@ -175,14 +176,24 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
TensorShapeProto MaybeGetMinimumShape(const TensorShapeProto& original_shape,
int rank, bool* found_unknown_shapes) {
auto shape = original_shape;
- if (shape.unknown_rank() || shape.dim_size() < rank) {
+ bool is_scalar = !shape.unknown_rank() && shape.dim_size() == 0;
+
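+ // Unknown-rank and too-short shapes are padded with dimensions of size 1 up
+ // to the requested rank (and flagged as unknown); scalars are expanded to
+ // the requested rank without flagging; shapes longer than the requested
+ // rank are truncated to the first `rank` dimensions (and flagged).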
+ if (shape.unknown_rank() || (!is_scalar && shape.dim_size() < rank)) {
*found_unknown_shapes = true;
- TensorShapeProto::Dim dim;
VLOG(2) << "Use minimum shape because the rank is unknown.";
// The size of each dimension is at least 1, if unknown.
- dim.set_size(1);
+ for (int i = shape.dim_size(); i < rank; i++) {
+ shape.add_dim()->set_size(1);
+ }
+ } else if (is_scalar) {
+ for (int i = 0; i < rank; i++) {
+ shape.add_dim()->set_size(1);
+ }
+ } else if (shape.dim_size() > rank) {
+ *found_unknown_shapes = true;
+ shape.clear_dim();
for (int i = 0; i < rank; i++) {
- *shape.add_dim() = dim;
+ shape.add_dim()->set_size(original_shape.dim(i).size());
}
} else {
for (int i = 0; i < shape.dim_size(); i++) {
diff --git a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
index 7271a29319..998bd59dce 100644
--- a/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
+++ b/tensorflow/core/grappler/costs/op_level_cost_estimator_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
@@ -1126,5 +1127,77 @@ TEST_F(OpLevelCostEstimatorTest, PredictFusedBatchNormGrad) {
EXPECT_EQ(0, costs.num_ops_with_unknown_shapes);
}
}
+
+TEST_F(OpLevelCostEstimatorTest, MaybeGetMinimumShape) {
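+ // MaybeGetMinimumShape pads unknown-rank and too-short shapes with 1s up to
+ // the requested rank, keeps matching shapes (replacing unknown dimensions
+ // with 1), and truncates shapes that are longer than the requested rank.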
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(true);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
+ EXPECT_TRUE(unknown_shapes);
+ ExpectTensorShape({1, 1, 1, 1}, y);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 1, &unknown_shapes);
+ EXPECT_FALSE(unknown_shapes);
+ ExpectTensorShape({1}, y);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
+ EXPECT_FALSE(unknown_shapes);
+ ExpectTensorShape({1, 1}, y);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ x.add_dim()->set_size(10);
+ x.add_dim()->set_size(20);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
+ EXPECT_FALSE(unknown_shapes);
+ ExpectTensorShape({10, 20}, y);
+
+ unknown_shapes = false;
+ TensorShapeProto z = MaybeGetMinimumShape(x, 4, &unknown_shapes);
+ EXPECT_TRUE(unknown_shapes);
+ EXPECT_EQ(4, z.dim_size());
+ ExpectTensorShape({10, 20, 1, 1}, z);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ x.add_dim()->set_size(10);
+ x.add_dim()->set_size(20);
+ x.add_dim()->set_size(-1);
+ x.add_dim()->set_size(20);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 4, &unknown_shapes);
+ EXPECT_TRUE(unknown_shapes);
+ ExpectTensorShape({10, 20, 1, 20}, y);
+ }
+
+ {
+ TensorShapeProto x;
+ x.set_unknown_rank(false);
+ x.add_dim()->set_size(10);
+ x.add_dim()->set_size(20);
+ x.add_dim()->set_size(30);
+ x.add_dim()->set_size(20);
+ bool unknown_shapes = false;
+ TensorShapeProto y = MaybeGetMinimumShape(x, 2, &unknown_shapes);
+ EXPECT_TRUE(unknown_shapes);
+ ExpectTensorShape({10, 20}, y);
+ }
+}
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index be54d98534..5415324b48 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -99,7 +99,7 @@ static void ExtractExtraProperties(
continue;
}
TensorId input_tensor_id = ParseTensorName(input_name);
- const string input_node_name = input_tensor_id.first.ToString();
+ const string input_node_name(input_tensor_id.first);
auto iter = name_to_node.find(input_node_name);
if (iter == name_to_node.end()) continue;
@@ -127,7 +127,7 @@ static void ExtractExtraProperties(
// For filename input, the file size can also be useful.
if (op_def && i < op_def->input_arg_size() &&
- op_def->input_arg(i).name().find("filename") != std::string::npos) {
+ op_def->input_arg(i).name().find("filename") != string::npos) {
Tensor tensor;
if (!tensor.FromProto(t)) {
continue;
@@ -153,7 +153,7 @@ static void ExtractExtraProperties(
// When the input is a handle (e.g. look up table handle), the information
// in the op itself is not sufficient to predict the op memory.
if (op_def && i < op_def->input_arg_size() &&
- op_def->input_arg(i).name().find("handle") != std::string::npos) {
+ op_def->input_arg(i).name().find("handle") != string::npos) {
string new_key = strings::StrCat("parent_", i, "_op");
AttrValue attr;
attr.set_s(input_node->op());
@@ -172,7 +172,7 @@ std::vector<OpInfo::TensorProperties> FindInputFeatures(
for (const auto& input_name : node.input()) {
CHECK(!input_name.empty());
TensorId input_tensor_id = ParseTensorName(input_name);
- const string input_node_name = input_tensor_id.first.ToString();
+ const string input_node_name(input_tensor_id.first);
const int output_index = input_tensor_id.second;
// Skip control inputs.
@@ -209,13 +209,13 @@ DeviceProperties GetDeviceInfo(const string& device_str) {
if (DeviceNameUtils::ParseFullName(device_str, &parsed)) {
if (parsed.type == "GPU") {
TfGpuId tf_gpu_id(parsed.id);
- CudaGpuId cuda_gpu_id;
- Status s = GpuIdManager::TfToCudaGpuId(tf_gpu_id, &cuda_gpu_id);
+ PlatformGpuId platform_gpu_id;
+ Status s = GpuIdManager::TfToPlatformGpuId(tf_gpu_id, &platform_gpu_id);
if (!s.ok()) {
// We are probably running simulation without linking cuda libraries.
- cuda_gpu_id = CudaGpuId(parsed.id);
+ platform_gpu_id = PlatformGpuId(parsed.id);
}
- return GetLocalGPUInfo(cuda_gpu_id);
+ return GetLocalGPUInfo(platform_gpu_id);
} else if (parsed.type == "CPU") {
return GetLocalCPUInfo();
}
@@ -320,8 +320,8 @@ void TensorSizeHistogram::Merge(const TensorSizeHistogram& src) {
buckets_.begin(), std::plus<uint64>());
}
-std::string TensorSizeHistogram::ToString() const {
- std::string r;
+string TensorSizeHistogram::ToString() const {
+ string r;
char buf[200];
snprintf(buf, sizeof(buf), "Count: %lld, Average: ", num_elem_);
r.append(buf);
diff --git a/tensorflow/core/grappler/costs/utils.h b/tensorflow/core/grappler/costs/utils.h
index d2c7c67666..5fd6717712 100644
--- a/tensorflow/core/grappler/costs/utils.h
+++ b/tensorflow/core/grappler/costs/utils.h
@@ -80,7 +80,7 @@ class TensorSizeHistogram {
uint64 Max() const { return max_; }
uint64 NumElem() const { return num_elem_; }
uint64 SumElem() const { return sum_elem_; }
- std::string ToString() const;
+ string ToString() const;
protected:
const int Index(const uint64 value) const;
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 6e3ebdee12..037a823096 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -880,10 +880,15 @@ Costs VirtualScheduler::Summary() const {
// Print per device summary
VLOG(1) << "Devices:";
Costs critical_path_costs = Costs::ZeroCosts();
+ std::vector<string> device_names;
+ device_names.reserve(device_.size());
+ for (auto& it : device_) {
+ device_names.push_back(it.first);
+ }
+ std::sort(device_names.begin(), device_names.end());
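+ // Report the per-device summaries in sorted device-name order so that the
+ // output order is deterministic.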
- for (const auto& device : device_) {
- const auto& name = device.first;
- const auto& state = device.second;
+ for (const auto& name : device_names) {
+ const auto& state = device_.at(name);
std::map<string, int64> op_to_memory;
// First profile only persistent memory usage.
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
index b1373d8317..80889afc86 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_description.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
@@ -1998,13 +1999,13 @@ TEST_F(VirtualSchedulerTest, InterDeviceTransfer) {
// Helper lambda to extract port num from _Send and _Recv op name.
auto get_port_num = [](const string& name) -> int {
- if (name.find("bn_0") != std::string::npos) {
+ if (name.find("bn_0") != string::npos) {
return 0;
- } else if (name.find("bn_1") != std::string::npos) {
+ } else if (name.find("bn_1") != string::npos) {
return 1;
- } else if (name.find("bn_2") != std::string::npos) {
+ } else if (name.find("bn_2") != string::npos) {
return 2;
- } else if (name.find("bn_minus1") != std::string::npos) {
+ } else if (name.find("bn_minus1") != string::npos) {
return -1;
}
return -999;
diff --git a/tensorflow/core/grappler/graph_analyzer/BUILD b/tensorflow/core/grappler/graph_analyzer/BUILD
new file mode 100644
index 0000000000..d56a08d3c8
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/BUILD
@@ -0,0 +1,139 @@
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+cc_library(
+ name = "graph_analyzer_lib",
+ srcs = [
+ "gen_node.cc",
+ "graph_analyzer.cc",
+ "sig_node.cc",
+ "subgraph.cc",
+ ],
+ hdrs = [
+ "gen_node.h",
+ "graph_analyzer.h",
+ "hash_tools.h",
+ "map_tools.h",
+ "sig_node.h",
+ "subgraph.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
+
+cc_library(
+ name = "graph_analyzer_tool",
+ srcs = ["graph_analyzer_tool.cc"],
+ hdrs = ["graph_analyzer_tool.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_analyzer_lib",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/grappler:grappler_item",
+ "@com_google_absl//absl/strings",
+ ],
+)
+
+cc_library(
+ name = "test_tools_lib",
+ testonly = 1,
+ srcs = [
+ "test_tools.cc",
+ ],
+ hdrs = [
+ "test_tools.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_analyzer_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core/grappler:op_types",
+ "@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
+ ],
+)
+
+tf_cc_test(
+ name = "hash_tools_test",
+ testonly = 1,
+ srcs = [
+ "hash_tools_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "gen_node_test",
+ testonly = 1,
+ srcs = [
+ "gen_node_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ ":test_tools_lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "sig_node_test",
+ testonly = 1,
+ srcs = [
+ "sig_node_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ ":test_tools_lib",
+ "//tensorflow/core/grappler:utils",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "graph_analyzer_test",
+ testonly = 1,
+ srcs = [
+ "graph_analyzer_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ ":test_tools_lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
+
+tf_cc_test(
+ name = "subgraph_test",
+ testonly = 1,
+ srcs = [
+ "subgraph_test.cc",
+ ],
+ deps = [
+ ":graph_analyzer_lib",
+ ":test_tools_lib",
+ "@com_google_absl//absl/memory",
+ "@com_google_absl//absl/strings:str_format",
+ "@com_google_googletest//:gtest_main",
+ ],
+)
diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node.cc b/tensorflow/core/grappler/graph_analyzer/gen_node.cc
new file mode 100644
index 0000000000..f8c15fd50e
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/gen_node.cc
@@ -0,0 +1,148 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+GenNode::GenNode(const NodeDef* node) : node_(node), op_(nullptr) {}
+
+Status GenNode::BuildGraphInMap(const GraphDef& source, GenNodeMap* map) {
+ for (const auto& n : source.node()) {
+ const string& name = n.name();
+ if (map->find(name) != map->end()) {
+ // This error code looks more meaningful than ALREADY_EXISTS.
+ return Status(error::INVALID_ARGUMENT,
+ "Duplicate node name '" + name + "'.");
+ }
+ (*map)[name] = absl::make_unique<GenNode>(&n);
+ }
+ // Now parse the links.
+ for (const auto& mapit : *map) {
+ Status st = mapit.second->ParseInputs(map);
+ if (!st.ok()) {
+ return st;
+ }
+ }
+ return Status::OK();
+}
+
+Status GenNode::ParseInputs(const GenNodeMap* map) {
+ all_inputs_or_none_ = false;
+ Status st = OpRegistry::Global()->LookUpOpDef(opcode(), &op_);
+ if (!st.ok()) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ absl::StrFormat("Node '%s' contains an undefined operation '%s': %s",
+ name(), opcode(), st.error_message()));
+ }
+
+ int n_inputs = node_->input_size();
+
+ int n_named_inputs = op_->input_arg_size();
+
+ int n_multi_inputs = 0;
+ for (const auto& inarg : op_->input_arg()) {
+ if (!inarg.number_attr().empty() || !inarg.type_list_attr().empty()) {
+ ++n_multi_inputs;
+ }
+ }
+ bool is_commutative = grappler::IsCommutative(*node_);
+
+ if (n_multi_inputs > 1 || (n_multi_inputs > 0 && n_named_inputs > 1)) {
+ // Can't handle more than one multi-input at a time.
+ // And can't handle the commutativeness of only some arguments
+ // rather than all of them.
+ is_commutative = false;
+ }
+
+ if (is_commutative) {
+ // If truly commutative, can treat all the inputs as one multi-input.
+ // It's possible to just treat the commutative nodes as AllInputsOrNone
+ // but (1) this way is a bit more efficient and (2) it preserves the code
+ // path that does all-or-none through a single input, which may be extended
+ // in the future.
+ n_named_inputs = 1;
+ all_inputs_or_none_ = false;
+ } else if (n_multi_inputs > 0) {
+ all_inputs_or_none_ = true;
+ }
+
+ for (int i = 0; i < n_inputs; ++i) {
+ int other_position;
+ string other_name = ParseNodeName(node_->input(i), &other_position);
+ auto other_it = map->find(other_name);
+ if (other_it == map->end()) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ absl::StrFormat(
+ "Node '%s' input %d refers to a non-existing node '%s'.", name(),
+ i, other_name));
+ }
+ GenNode* other_node = other_it->second.get();
+
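+ // A control input maps to the control port (-1); for commutative ops every
+ // data input collapses into the single input port 0.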
+ int this_position = other_position < 0 ? -1 : (is_commutative ? 0 : i);
+
+ if (this_position >= 0 && n_multi_inputs == 0 &&
+ this_position >= n_named_inputs) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ absl::StrFormat(
+ "Node '%s' has a non-control input from '%s' at index %d but its "
+ "operation '%s' defines only %d inputs.",
+ name(), other_name, this_position, op_->name(), n_named_inputs));
+ }
+
+ Port this_port(/*inbound=*/true, this_position);
+ Port other_port(/*inbound=*/false, other_position);
+
+ links_[this_port].emplace_back(LinkTarget(other_node, other_port));
+ other_node->links_[other_port].emplace_back(LinkTarget(this, this_port));
+ }
+ return Status::OK();
+}
+
+bool GenNode::IsMultiInput(Port port) const {
+ if (!port.IsInbound()) {
+ return false;
+ }
+ auto it = links_.find(port);
+ if (it == links_.end()) {
+ return false; // Shouldn't happen.
+ }
+ return (it->second.size() > 1);
+}
+
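+ // Produces strings such as "i0" (inbound port 0), "o1" (outbound port 1),
+ // or "iC"/"oC" for the control ports.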
+GenNode::Port::operator string() const {
+ string result = this->IsInbound() ? "i" : "o";
+ if (this->IsControl()) {
+ result.append("C");
+ } else {
+ result.append(absl::StrFormat("%d", this->Id()));
+ }
+ return result;
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node.h b/tensorflow/core/grappler/graph_analyzer/gen_node.h
new file mode 100644
index 0000000000..faec9ecad8
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/gen_node.h
@@ -0,0 +1,167 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+class GenNode;
+
+// To find nodes by name.
+using GenNodeMap = std::unordered_map<string, std::unique_ptr<GenNode>>;
+
+// One node in the graph, in a form convenient for traversal and generation of
+// subgraphs. It refers to the original NodeDef protobuf for most information
+// and adds extra enrichment on top of it.
+//
+// The graph building is 2-stage: first match a GenNode with each NodeDef and
+// collect them into a map that finds them by name, then process the map,
+// deep-parse the underlying NodeDefs and connect the GenNodes together.
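+//
+// A typical use (a sketch, assuming a GraphDef 'graph' is already available):
+//
+//   GenNodeMap map;
+//   TF_CHECK_OK(GenNode::BuildGraphInMap(graph, &map));
+//   // Then walk map["some_node"]->links() to inspect the connections.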
+class GenNode {
+ public:
+ // Will keep the pointer, so the underlying object must not be deleted while
+ // GenNode is alive.
+ explicit GenNode(const NodeDef* node);
+
+ // Access wrappers.
+ const string& name() const { return node_->name(); }
+ const string& opcode() const { return node_->op(); }
+ const NodeDef* node_def() const { return node_; }
+
+ // Parse the inputs of this node, creating the links (i.e. edges,
+ // connections between nodes) both in this node and in the nodes it's
+ // linked to (the map itself is unchanged, only the nodes in it are
+ // updated).
+ Status ParseInputs(const GenNodeMap* map);
+
+ // Does the full 2-stage build of the graph. The map should be initially
+ // empty. The map keeps pointers to the nodes in source, so the source must
+ // not be destroyed before the map.
+ static Status BuildGraphInMap(const GraphDef& source, GenNodeMap* map);
+
+ // The enrichment that constitutes the point of this class.
+
+ // Representation of a connection on a node.
+ class Port {
+ public:
+ // A port may be inbound or outbound.
+ // Negative ids (canonically -1) mean a control port.
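+ // The id and the inbound flag are packed into one integer as
+ // value = (id << 1) | inbound, so e.g. an inbound data port 3 encodes
+ // as 7 and an outbound control port (id -1) encodes as -2.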
+ Port(bool inbound, int32_t id) : value_(id << 1) {
+ if (inbound) {
+ value_ |= 1;
+ }
+ }
+ Port(const Port&) = default;
+ Port& operator=(const Port&) = default;
+
+ bool IsInbound() const { return (value_ & 0x1); }
+
+ bool IsControl() const { return (value_ < 0); }
+
+ int32_t Id() const {
+ // Arithmetic shift preserves the sign.
+ return (value_ >> 1);
+ }
+
+ // Integer type used to represent the encoded port value.
+ using IntPort = int32_t;
+
+ // Returns the encoded form of this port, so that it can be used
+ // as various map indexes.
+ IntPort Encoded() const { return value_; }
+
+ static Port Decode(IntPort encoded) { return Port(encoded); }
+
+ bool operator==(const Port& other) const { return value_ == other.value_; }
+ bool operator<(const Port& other) const { return value_ < other.value_; }
+
+ struct Hasher {
+ size_t operator()(const Port& port) const noexcept {
+ return hasher(port.Encoded());
+ }
+ std::hash<int32_t> hasher;
+ };
+
+ // Convenient for printing. I really wanted it to be implicit, but
+ // ClangTidy insists on making it explicit.
+ explicit operator string() const;
+
+ private:
+ explicit Port(IntPort value) : value_(value) {}
+
+ IntPort value_;
+ };
+
+ struct LinkTarget {
+ GenNode* node; // Node where this link points.
+ Port port; // Port on the remote side of this link.
+
+ LinkTarget(GenNode* a_node, Port a_port) : node(a_node), port(a_port) {}
+ };
+ // All the links that are connected to the same port of this node
+ // are collected in one vector. A link is an edge of the graph that connects
+ // 2 nodes. Each of the connected nodes has its own perspective on the link,
+ // seeing its local port, remote port and the remote node. The direction of
+ // the link is encoded in the ports, one port is always incoming and another
+ // one outgoing.
+ using LinkTargetVector = std::vector<LinkTarget>;
+ // Both inputs and outputs are stored in the same map.
+ using LinkMap = std::unordered_map<Port, LinkTargetVector, Port::Hasher>;
+
+ // Access to the link map.
+ const LinkMap& links() const { return links_; }
+
+ // Check whether the port is an input (including the controls) with multiple
+ // connections. Such inputs get handled in a special way when building the
+ // subgraphs, in an "all or nothing" fashion.
+ bool IsMultiInput(Port port) const;
+
+ // When building the subgraphs, must include either all non-control inputs of
+ // this node into the subgraph or none of them. This happens when at least one
+ // of the inputs is a multi-input (or if the opcode is commutative, thus
+ // treating all the inputs as one multi-input).
+ bool AllInputsOrNone() const { return all_inputs_or_none_; }
+
+ private:
+ const NodeDef* node_;
+ // Becomes valid only after ParseInputs().
+ const OpDef* op_;
+
+ // The opcode has a complicated structure of input args, with multi-input args
+ // that are not commutative. This means that to make sense, the subgraphs that
+ // include this node must also include either all its inputs or none of them.
+ bool all_inputs_or_none_ = false;
+
+ LinkMap links_;
+};
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GEN_NODE_H_
diff --git a/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc b/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc
new file mode 100644
index 0000000000..d77daf7849
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/gen_node_test.cc
@@ -0,0 +1,491 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Ne;
+
+TEST(GenNodeTest, Port) {
+ {
+ GenNode::Port p(true, 100);
+ EXPECT_THAT(p.IsInbound(), Eq(true));
+ EXPECT_THAT(p.IsControl(), Eq(false));
+ EXPECT_THAT(p.Id(), Eq(100));
+ GenNode::Port p2 = GenNode::Port::Decode(p.Encoded());
+ EXPECT_THAT(p2.IsInbound(), Eq(true));
+ EXPECT_THAT(p2.IsControl(), Eq(false));
+ EXPECT_THAT(p2.Id(), Eq(100));
+ }
+ {
+ GenNode::Port p(false, 0);
+ EXPECT_THAT(p.IsInbound(), Eq(false));
+ EXPECT_THAT(p.IsControl(), Eq(false));
+ EXPECT_THAT(p.Id(), Eq(0));
+ GenNode::Port p2 = GenNode::Port::Decode(p.Encoded());
+ EXPECT_THAT(p2.IsInbound(), Eq(false));
+ EXPECT_THAT(p2.IsControl(), Eq(false));
+ EXPECT_THAT(p2.Id(), Eq(0));
+ }
+ {
+ GenNode::Port p(true, -100);
+ EXPECT_THAT(p.IsInbound(), Eq(true));
+ EXPECT_THAT(p.IsControl(), Eq(true));
+ EXPECT_THAT(p.Id(), Eq(-100));
+ GenNode::Port p2 = GenNode::Port::Decode(p.Encoded());
+ EXPECT_THAT(p2.IsInbound(), Eq(true));
+ EXPECT_THAT(p2.IsControl(), Eq(true));
+ EXPECT_THAT(p2.Id(), Eq(-100));
+ }
+ {
+ GenNode::Port p(false, -1);
+ EXPECT_THAT(p.IsInbound(), Eq(false));
+ EXPECT_THAT(p.IsControl(), Eq(true));
+ EXPECT_THAT(p.Id(), Eq(-1));
+ GenNode::Port p2 = GenNode::Port::Decode(p.Encoded());
+ EXPECT_THAT(p2.IsInbound(), Eq(false));
+ EXPECT_THAT(p2.IsControl(), Eq(true));
+ EXPECT_THAT(p2.Id(), Eq(-1));
+ }
+}
+
+TEST(GenNodeTest, ParseNodeNoInputs) {
+ GenNodeMap map;
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ auto gn1 = map["node1"].get();
+ ASSERT_THAT(gn1->ParseInputs(&map), Eq(Status::OK()));
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre());
+}
+
+// A general operation, and a control link.
+TEST(GenNodeTest, ParseNodeWithControl) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeSub("node3", "node1", "node2");
+ node3.add_input("^node1"); // The control link.
+ node3.add_input("^node2"); // The control link.
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]",
+ "oC: node3[iC]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i1]",
+ "oC: node3[iC]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]",
+ "iC: node1[oC], node2[oC]"
+ ));
+ // clang-format on
+
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(false));
+
+ // This is a multi-control-input.
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, -1)), Eq(true));
+
+ EXPECT_FALSE(gn1->AllInputsOrNone());
+ EXPECT_FALSE(gn2->AllInputsOrNone());
+ EXPECT_FALSE(gn3->AllInputsOrNone());
+}
+
+// Commutative nodes are treated as having a single input,
+// because their inputs are equivalent.
+TEST(GenNodeTest, ParseNodeCommutative) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ // TODO(babkin): grappler::IsCommutative() should return true for Add but
+ // apparently doesn't. So use Mul in the meantime.
+ NodeDef node3 = MakeNodeMul("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0], node2[o0]"
+ ));
+ // clang-format on
+
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(true));
+
+ EXPECT_FALSE(gn3->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiInputCommutative) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeAddN("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0], node2[o0]"
+ ));
+ // clang-format on
+
+ // This is a multi-output.
+ EXPECT_THAT(gn2->IsMultiInput(GenNode::Port(false, 0)), Eq(false));
+ // This is a multi-input.
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(true));
+
+ EXPECT_FALSE(gn3->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiInputNotCommutative) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeShapeN("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i1]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]"
+ ));
+ // clang-format on
+
+ // Non-commutative multi-input doesn't count.
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(false));
+ EXPECT_TRUE(gn3->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiInputList) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeIdentityN("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ ASSERT_THAT(gn3->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node3[i1]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]"
+ ));
+ // clang-format on
+
+ // Non-commutative multi-input doesn't count.
+ EXPECT_THAT(gn3->IsMultiInput(GenNode::Port(true, 0)), Eq(false));
+ EXPECT_TRUE(gn3->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiMultiInput) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeConst("node3");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ NodeDef node4 = MakeNodeConst("node4");
+ map["node4"] = absl::make_unique<GenNode>(&node4);
+
+ NodeDef node5 =
+ MakeNodeQuantizedConcat("node5", "node1", "node2", "node3", "node4");
+ map["node5"] = absl::make_unique<GenNode>(&node5);
+
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ auto gn3 = map["node3"].get();
+ auto gn4 = map["node4"].get();
+ auto gn5 = map["node5"].get();
+ ASSERT_THAT(gn5->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "o0: node5[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn2->links()), ElementsAre(
+ "o0: node5[i1]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn3->links()), ElementsAre(
+ "o0: node5[i2]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn4->links()), ElementsAre(
+ "o0: node5[i3]"
+ ));
+ EXPECT_THAT(DumpLinkMap(gn5->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]",
+ "i2: node3[o0]",
+ "i3: node4[o0]"
+ ));
+ // clang-format on
+
+ // Non-commutative multi-input doesn't count.
+ EXPECT_THAT(gn5->IsMultiInput(GenNode::Port(true, 1)), Eq(false));
+ EXPECT_THAT(gn5->IsMultiInput(GenNode::Port(true, 2)), Eq(false));
+ EXPECT_TRUE(gn5->AllInputsOrNone());
+}
+
+TEST(GenNodeTest, ParseNodeMultiOutput) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+
+ NodeDef node4 = MakeNodeSub("node4", "node3:1", "node3:0");
+ map["node4"] = absl::make_unique<GenNode>(&node4);
+
+ auto gn4 = map["node4"].get();
+ ASSERT_THAT(gn4->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn4->links()), ElementsAre(
+ "i0: node3[o1]",
+ "i1: node3[o0]"
+ ));
+ // clang-format on
+}
+
+TEST(GenNodeTest, ParseNodeUndefinedOp) {
+ GenNodeMap map;
+ NodeDef node1;
+ node1.set_name("node1");
+ node1.set_op("Zzzx");
+
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+
+ const OpDef* opdef;
+ Status nested_error = OpRegistry::Global()->LookUpOpDef("Zzzx", &opdef);
+
+ auto gn = map["node1"].get();
+ ASSERT_THAT(
+ gn->ParseInputs(&map),
+ Eq(Status(error::INVALID_ARGUMENT,
+ "Node 'node1' contains an undefined operation 'Zzzx': " +
+ nested_error.error_message())));
+}
+
+TEST(GenNodeTest, ParseNodeUnexpectedInputs) {
+ GenNodeMap map;
+
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+ node1.add_input("node1");
+
+ auto gn1 = map["node1"].get();
+ EXPECT_THAT(gn1->ParseInputs(&map),
+ Eq(Status(error::INVALID_ARGUMENT,
+ "Node 'node1' has a non-control "
+ "input from 'node1' at index 0 but its operation "
+ "'Const' defines only 0 inputs.")));
+
+ NodeDef node2 = MakeNodeConst("node2");
+ map["node2"] = absl::make_unique<GenNode>(&node2);
+
+ NodeDef node3 = MakeNodeSub("node3", "node1", "node2");
+ map["node3"] = absl::make_unique<GenNode>(&node3);
+ node3.add_input("node1");
+
+ auto gn3 = map["node3"].get();
+ EXPECT_THAT(gn3->ParseInputs(&map),
+ Eq(Status(error::INVALID_ARGUMENT,
+ "Node 'node3' has a non-control "
+ "input from 'node1' at index 2 but its operation "
+ "'Sub' defines only 2 inputs.")));
+}
+
+// Even if an opcode defines no inputs, the node may still accept the control
+// inputs.
+TEST(GenNodeTest, ParseNodeControlInputsAlwaysOk) {
+ GenNodeMap map;
+ NodeDef node1 = MakeNodeConst("node1");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+ node1.add_input("^node1");
+ auto gn1 = map["node1"].get();
+ ASSERT_THAT(gn1->ParseInputs(&map), Eq(Status::OK()));
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(gn1->links()), ElementsAre(
+ "iC: node1[oC]",
+ "oC: node1[iC]"
+ ));
+ // clang-format on
+}
+
+TEST(GenNodeTest, ParseNodeInvalidInput) {
+ GenNodeMap map;
+ NodeDef node1 = MakeNodeAddN("node1", "node2", "node3");
+ map["node1"] = absl::make_unique<GenNode>(&node1);
+ node1.add_input("node1");
+ auto gn1 = map["node1"].get();
+ ASSERT_THAT(
+ gn1->ParseInputs(&map),
+ Eq(Status(
+ error::INVALID_ARGUMENT,
+ "Node 'node1' input 0 refers to a non-existing node 'node2'.")));
+}
+
+TEST(GenNodeTest, BuildGraphInMap) {
+ GraphDef graph;
+ // A topology with a loop.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ (*graph.add_node()) =
+ MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node1"), Ne(map.end()));
+ ASSERT_THAT(map.find("node2"), Ne(map.end()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ EXPECT_THAT(map["node1"]->name(), Eq("node1"));
+ EXPECT_THAT(map["node2"]->name(), Eq("node2"));
+ EXPECT_THAT(map["node3"]->name(), Eq("node3"));
+
+ // clang-format off
+ EXPECT_THAT(DumpLinkMap(map["node1"]->links()), ElementsAre(
+ "o0: node3[i0]"
+ ));
+ EXPECT_THAT(DumpLinkMap(map["node2"]->links()), ElementsAre(
+ "i0: node3[o1]",
+ "i1: node3[o0]",
+ "o0: node3[i1]"
+ ));
+ EXPECT_THAT(DumpLinkMap(map["node3"]->links()), ElementsAre(
+ "i0: node1[o0]",
+ "i1: node2[o0]",
+ "o0: node2[i1]",
+ "o1: node2[i0]"
+ ));
+ // clang-format on
+}
+
+TEST(GenNodeTest, BuildGraphInMapDuplicateNode) {
+ GraphDef graph;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node1");
+ GenNodeMap map;
+ ASSERT_THAT(
+ GenNode::BuildGraphInMap(graph, &map),
+ Eq(Status(error::INVALID_ARGUMENT, "Duplicate node name 'node1'.")));
+}
+
+TEST(GenNodeTest, BuildGraphInMapParseError) {
+ GraphDef graph;
+ // A topology with a loop.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+
+ GenNodeMap map;
+ ASSERT_THAT(
+ GenNode::BuildGraphInMap(graph, &map),
+ Eq(Status(
+ error::INVALID_ARGUMENT,
+ "Node 'node2' input 0 refers to a non-existing node 'node3'.")));
+}
+
+} // end namespace
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc
new file mode 100644
index 0000000000..f3796fcf86
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.cc
@@ -0,0 +1,341 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <deque>
+#include <iostream>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+GraphAnalyzer::GraphAnalyzer(const GraphDef& graph, int subgraph_size)
+ : graph_(graph), subgraph_size_(subgraph_size) {}
+
+GraphAnalyzer::~GraphAnalyzer() {}
+
+Status GraphAnalyzer::Run() {
+ // The signature computation code would detect this too, but better
+ // to report it up front than spend time computing all the graphs first.
+ if (subgraph_size_ > Signature::kMaxGraphSize) {
+ return Status(error::INVALID_ARGUMENT,
+ absl::StrFormat("Subgraphs of %d nodes are not supported, "
+ "the maximal supported node count is %d.",
+ subgraph_size_, Signature::kMaxGraphSize));
+ }
+
+ Status st = BuildMap();
+ if (!st.ok()) {
+ return st;
+ }
+
+ FindSubgraphs();
+ DropInvalidSubgraphs();
+ st = CollateResult();
+ if (!st.ok()) {
+ return st;
+ }
+
+ return Status::OK();
+}
+
+Status GraphAnalyzer::BuildMap() {
+ nodes_.clear();
+ return GenNode::BuildGraphInMap(graph_, &nodes_);
+}
+
+void GraphAnalyzer::FindSubgraphs() {
+ result_.clear();
+
+ if (subgraph_size_ < 1) {
+ return;
+ }
+
+ partial_.clear();
+ todo_.clear(); // Just in case.
+
+ // Start with all subgraphs of size 1.
+ const Subgraph::Identity empty_parent;
+ for (const auto& node : nodes_) {
+ if (subgraph_size_ == 1) {
+ result_.ExtendParent(empty_parent, node.second.get());
+ } else {
+ // At this point ExtendParent() is guaranteed to not return nullptr.
+ todo_.push_back(partial_.ExtendParent(empty_parent, node.second.get()));
+ }
+ }
+
+ // Then extend the subgraphs until no more extensions are possible.
+ while (!todo_.empty()) {
+ ExtendSubgraph(todo_.front());
+ todo_.pop_front();
+ }
+
+ partial_.clear();
+}
+
+void GraphAnalyzer::ExtendSubgraph(Subgraph* parent) {
+ bool will_complete = (parent->id().size() + 1 == subgraph_size_);
+ SubgraphPtrSet& sg_set = will_complete ? result_ : partial_;
+
+ const GenNode* last_all_or_none_node = nullptr;
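+ // Walk every link of every node already in the subgraph, growing it either
+ // by a single neighbor at a time or by a whole group of nodes at once for
+ // the all-or-none nodes and multi-input ports.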
+ for (SubgraphIterator sit(parent); !sit.AtEnd(); sit.Next()) {
+ const GenNode* node = sit.GetNode();
+ GenNode::Port port = sit.GetPort();
+ const GenNode::LinkTarget& neighbor = sit.GetNeighbor();
+
+ if (node->AllInputsOrNone() && port.IsInbound() && !port.IsControl()) {
+ if (node != last_all_or_none_node) {
+ ExtendSubgraphAllOrNone(parent, node);
+ last_all_or_none_node = node;
+ }
+ sit.SkipPort();
+ } else if (neighbor.node->AllInputsOrNone() && !port.IsInbound() &&
+ !port.IsControl()) {
+ if (parent->id().find(neighbor.node) == parent->id().end()) {
+ // Not added yet.
+ ExtendSubgraphAllOrNone(parent, neighbor.node);
+ }
+ } else if (node->IsMultiInput(port)) {
+ ExtendSubgraphPortAllOrNone(parent, node, port);
+ sit.SkipPort();
+ } else if (neighbor.node->IsMultiInput(neighbor.port)) {
+ // Would need to add all inputs of the neighbor node at this port at
+ // once.
+ if (parent->id().find(neighbor.node) != parent->id().end()) {
+ continue; // Already added.
+ }
+ ExtendSubgraphPortAllOrNone(parent, neighbor.node, neighbor.port);
+ } else {
+ Subgraph* sg = sg_set.ExtendParent(parent->id(), neighbor.node);
+ if (!will_complete && sg != nullptr) {
+ todo_.push_back(sg);
+ }
+ }
+ }
+}
+
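+// Tries to extend the subgraph by adding the given all-or-none node together
+// with all the nodes feeding its non-control inputs.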
+void GraphAnalyzer::ExtendSubgraphAllOrNone(Subgraph* parent,
+ const GenNode* node) {
+ Subgraph::Identity id = parent->id();
+ id.insert(node);
+
+ auto range_end = node->links().end();
+
+ for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) {
+ auto port = nbit->first;
+ if (!port.IsInbound() || port.IsControl()) {
+ continue;
+ }
+
+ // Since there might be multiple links to the same nodes,
+ // have to add all links one-by-one to check whether the subgraph
+ // would grow too large. But if it does grow too large, there is no
+ // point in growing it more, can just skip over the rest of the links.
+ for (const auto& link : nbit->second) {
+ id.insert(link.node);
+ if (id.size() > subgraph_size_) {
+ return; // Too big.
+ }
+ }
+ }
+
+ AddExtendedSubgraph(parent, id);
+}
+
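+// Tries to extend the subgraph by adding the given node together with all the
+// nodes connected to one particular multi-input port of it.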
+void GraphAnalyzer::ExtendSubgraphPortAllOrNone(Subgraph* parent,
+ const GenNode* node,
+ GenNode::Port port) {
+ auto nbit = node->links().find(port);
+ if (nbit == node->links().end()) {
+ return; // Should never happen.
+ }
+
+ Subgraph::Identity id = parent->id();
+ id.insert(node);
+
+ // Since there might be multiple links to the same nodes,
+ // have to add all links one-by-one to check whether the subgraph
+ // would grow too large. But if it does grow too large, there is no
+ // point in growing it more, can just skip over the rest of the links.
+ for (const auto& link : nbit->second) {
+ id.insert(link.node);
+ if (id.size() > subgraph_size_) {
+ return; // Too big.
+ }
+ }
+
+ AddExtendedSubgraph(parent, id);
+}
+
+void GraphAnalyzer::AddExtendedSubgraph(Subgraph* parent,
+ const Subgraph::Identity& id) {
+ if (id.size() == parent->id().size()) {
+ return; // Nothing new was added.
+ }
+
+ auto sg = absl::make_unique<Subgraph>(id);
+ SubgraphPtrSet& spec_sg_set =
+ (id.size() == subgraph_size_) ? result_ : partial_;
+ if (spec_sg_set.find(sg) != spec_sg_set.end()) {
+ // This subgraph was already found by extending from a different path.
+ return;
+ }
+
+ if (id.size() != subgraph_size_) {
+ todo_.push_back(sg.get());
+ }
+ spec_sg_set.insert(std::move(sg));
+}
+
+void GraphAnalyzer::DropInvalidSubgraphs() {
+ auto resit = result_.begin();
+ while (resit != result_.end()) {
+ if (HasInvalidMultiInputs(resit->get())) {
+ auto delit = resit;
+ ++resit;
+ result_.erase(delit);
+ } else {
+ ++resit;
+ }
+ }
+}
+
+bool GraphAnalyzer::HasInvalidMultiInputs(Subgraph* sg) {
+ // Do the all-or-none-input nodes.
+ for (auto const& node : sg->id()) {
+ if (!node->AllInputsOrNone()) {
+ continue;
+ }
+
+ bool anyIn = false;
+ bool anyOut = false;
+
+ auto range_end = node->links().end();
+ for (auto nbit = node->links().begin(); nbit != range_end; ++nbit) {
+ auto port = nbit->first;
+ if (!port.IsInbound() || port.IsControl()) {
+ continue;
+ }
+
+ // Check whether the nodes feeding this input are located inside or
+ // outside the subgraph; an all-or-none node whose non-control inputs
+ // are only partially included in the subgraph is what makes the
+ // subgraph invalid.
+ for (const auto& link : nbit->second) {
+ if (sg->id().find(link.node) == sg->id().end()) {
+ anyOut = true;
+ } else {
+ anyIn = true;
+ }
+ }
+ }
+
+ if (anyIn && anyOut) {
+ return true;
+ }
+ }
+
+ // Do the multi-input ports.
+ for (SubgraphIterator sit(sg); !sit.AtEnd(); sit.Next()) {
+ if (sit.GetNode()->IsMultiInput(sit.GetPort())) {
+ bool anyIn = false;
+ bool anyOut = false;
+ do {
+ GenNode* peer = sit.GetNeighbor().node;
+ if (sg->id().find(peer) == sg->id().end()) {
+ anyOut = true;
+ } else {
+ anyIn = true;
+ }
+ } while (sit.NextIfSamePort());
+
+ if (anyIn && anyOut) {
+ return true;
+ }
+ }
+ }
+ return false;
+}
+
+Status GraphAnalyzer::CollateResult() {
+ ordered_collation_.clear();
+ collation_map_.clear();
+
+ // Collate by the signatures of the graphs.
+ for (const auto& it : result_) {
+ auto sig = absl::make_unique<Signature>();
+ it->ExtractForSignature(&sig->map);
+ Status status = sig->Compute();
+ if (!status.ok()) {
+ return status;
+ }
+
+ auto& coll_entry = collation_map_[sig.get()];
+ if (coll_entry.sig == nullptr) {
+ coll_entry.sig = std::move(sig);
+ }
+ ++coll_entry.count;
+ }
+
+ // Then order them by the count.
+ for (auto& entry : collation_map_) {
+ ordered_collation_.insert(&entry.second);
+ }
+
+ result_.clear(); // Not needed after collation.
+
+ return Status::OK();
+}
+
+std::vector<string> GraphAnalyzer::DumpRawSubgraphs() {
+ std::vector<string> result;
+ for (const auto& it : result_) {
+ result.emplace_back(it->Dump());
+ }
+ return result;
+}
+
+std::vector<string> GraphAnalyzer::DumpSubgraphs() {
+ std::vector<string> result;
+ for (auto ptr : ordered_collation_) {
+ result.emplace_back(
+ absl::StrFormat("%d %s", ptr->count, ptr->sig->ToString()));
+ }
+ return result;
+}
+
+Status GraphAnalyzer::OutputSubgraphs() {
+ size_t total = 0;
+ for (auto ptr : ordered_collation_) {
+ std::cout << ptr->count << ' ' << ptr->sig->ToString() << '\n';
+ total += ptr->count;
+ }
+ std::cout << "Total: " << total << '\n';
+ if (std::cout.fail()) {
+ return Status(error::DATA_LOSS, "Failed to write to stdout");
+ } else {
+ return Status::OK();
+ }
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
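The collation performed by CollateResult() and printed by DumpSubgraphs() boils down to two steps: count equivalent signatures in a map, then list the entries by descending count. Below is a standalone toy sketch of that scheme (not part of the patch), with plain strings standing in for the Signature objects and all names invented for the illustration.

#include <iostream>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

int main() {
  std::vector<std::string> sigs = {"A", "B", "A", "C", "A", "B"};

  // Step 1: collate equal signatures, counting the duplicates.
  std::map<std::string, size_t> counts;
  for (const auto& s : sigs) ++counts[s];

  // Step 2: order by descending count, like CollationOrderByCount does.
  auto by_count_desc = [](const std::pair<std::string, size_t>& l,
                          const std::pair<std::string, size_t>& r) {
    return l.second > r.second || (l.second == r.second && l.first < r.first);
  };
  std::multiset<std::pair<std::string, size_t>, decltype(by_count_desc)>
      ordered(by_count_desc);
  for (const auto& entry : counts) ordered.insert(entry);

  // Prints "3 A", "2 B", "1 C", analogous to OutputSubgraphs().
  for (const auto& entry : ordered)
    std::cout << entry.second << ' ' << entry.first << '\n';
  return 0;
}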
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h
new file mode 100644
index 0000000000..97626346c7
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer.h
@@ -0,0 +1,154 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_
+
+#include <deque>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/grappler/graph_analyzer/map_tools.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+namespace test {
+class GraphAnalyzerTest;
+} // end namespace test
+
+// Finds all the subgraphs of a given size and groups them by equivalence.
+class GraphAnalyzer {
+ public:
+ // Makes a copy of the graph.
+ GraphAnalyzer(const GraphDef& graph, int subgraph_size);
+
+ virtual ~GraphAnalyzer();
+
+ // Performs the analysis and collects the subgraphs.
+ Status Run();
+
+  // Returns the subgraphs found by Run(), printed in text form.
+ std::vector<string> DumpSubgraphs();
+
+ // Prints the subgraphs found in Run() to stdout.
+ Status OutputSubgraphs();
+
+ // TODO(babkin): add a way to extract the subgraphs as direct data
+ // structures and as protobufs, and to write protobufs to a RecordIO.
+
+ private:
+ GraphAnalyzer() = delete;
+ GraphAnalyzer(const GraphAnalyzer&) = delete;
+ void operator=(const GraphAnalyzer&) = delete;
+
+ friend class tensorflow::grappler::graph_analyzer::test::GraphAnalyzerTest;
+
+ // Builds the map of nodes from the original graph definition.
+ Status BuildMap();
+
+ // Using nodes_, finds all the subgraphs of size subgraph_size_ and places
+ // them into result_.
+ void FindSubgraphs();
+
+ // Deletes from result_ the unacceptable subgraphs. Those include the
+ // subgraphs where not all the inputs at a multi-input port are included (this
+ // could happen if some of these inputs were reached and included through
+ // different paths).
+ void DropInvalidSubgraphs();
+
+ // Deletes from result_ duplicate entries of equivalent topology.
+ Status CollateResult();
+
+  // Returns the raw subgraphs found by FindSubgraphs(), printed in text form.
+ std::vector<string> DumpRawSubgraphs();
+
+ // Finds and adds appropriately to either partial_ or result_ all the
+ // subgraphs that can be created by extending the parent subgraph by one node.
+ // Ignores the duplicates.
+ void ExtendSubgraph(Subgraph* parent);
+
+ // Extends the parent subgraph by adding another node (if it wasn't already
+ // added) and all its non-control inputs in the link map range at once.
+ // If the subgraph would grow over subgraph_size_, it gets ignored.
+ void ExtendSubgraphAllOrNone(Subgraph* parent, const GenNode* node);
+ // Same but adds one specific inbound port (even control) all-or-none.
+ void ExtendSubgraphPortAllOrNone(Subgraph* parent, const GenNode* node,
+ GenNode::Port port);
+ // The common final step called by ExtendSubgraph*AllOrNone() methods.
+ void AddExtendedSubgraph(Subgraph* parent, const Subgraph::Identity& id);
+
+ // Returns true if this subgraph has any multi-inputs that aren't all-in or
+ // all-out.
+ bool HasInvalidMultiInputs(Subgraph* sg);
+
+ // Graph to run the analysis on.
+ GraphDef graph_;
+ int subgraph_size_;
+
+ // The enriched graph of parsed nodes and connections.
+ GenNodeMap nodes_;
+ // The resulting set of subgraphs.
+ SubgraphPtrSet result_;
+ // The subgraphs of partial size, stored while finding the result.
+ SubgraphPtrSet partial_;
+ // The subgraphs of partial size (stored in partial_) that are still waiting
+ // to be extended.
+ //
+  // TODO(babkin): This is rather simple-minded: each subgraph is examined from
+  // scratch, which means that all its internal links get iterated over too.
+  // That is OK for small subgraphs. It could be improved by keeping not just
+  // the subgraphs but iterators on the list, each holding the list of
+  // not-yet-examined nodes (and the position of the next link to be examined
+  // for the first node). This would add extra constant overhead, so the
+  // break-even subgraph size is not clear yet.
+ std::deque<Subgraph*> todo_;
+
+  // The collation map by signature is designed to allow the removal of entries
+  // and the moving of the signature references from the keys of this map to
+  // the outside world. Be careful when inserting and removing: make sure that
+  // when a new entry is inserted, its signature reference gets populated with
+  // the same data as the key of the map, and that if a reference is moved out,
+  // the map entry gets removed before that reference gets destroyed.
+ struct CollationEntry {
+ std::shared_ptr<Signature> sig;
+ size_t count = 0;
+ };
+ using CollationMap =
+ std::unordered_map<Signature*, CollationEntry, HashAtPtr<Signature*>,
+ EqAtPtr<Signature*> >;
+ CollationMap collation_map_;
+
+ // The entries are owned by collation_map_, so must be removed from
+ // ordered_collation_ before removing them from collation_map_.
+ struct ReverseLessByCount {
+ bool operator()(CollationEntry* left, CollationEntry* right) const {
+ return left->count > right->count; // Reverse order.
+ }
+ };
+ using CollationOrderByCount =
+ std::multiset<CollationEntry*, ReverseLessByCount>;
+ CollationOrderByCount ordered_collation_;
+};
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_H_
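A minimal usage sketch of the public interface declared above. It relies only on the constructor, Run(), DumpSubgraphs() and OutputSubgraphs() from this header plus the usual TF_RETURN_IF_ERROR/LOG helpers; the helper function name and the subgraph size of 3 are invented for the example.

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"

namespace {

// Hypothetical helper: finds all 3-node subgraphs and reports them.
tensorflow::Status AnalyzeGraph(const tensorflow::GraphDef& graph_def) {
  tensorflow::grappler::graph_analyzer::GraphAnalyzer analyzer(graph_def, 3);
  TF_RETURN_IF_ERROR(analyzer.Run());
  // Print the collated subgraphs to stdout...
  TF_RETURN_IF_ERROR(analyzer.OutputSubgraphs());
  // ...and/or collect the text form for further processing.
  const auto dump = analyzer.DumpSubgraphs();
  LOG(INFO) << "Found " << dump.size() << " equivalence classes of subgraphs";
  return tensorflow::Status::OK();
}

}  // namespace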
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc
new file mode 100644
index 0000000000..e94c472056
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_test.cc
@@ -0,0 +1,569 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
+
+#include <algorithm>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Ne;
+using ::testing::SizeIs;
+using ::testing::UnorderedElementsAre;
+
+class GraphAnalyzerTest : public ::testing::Test, protected TestGraphs {
+ protected:
+ Status BuildMap() { return gran_->BuildMap(); }
+
+ void FindSubgraphs() { gran_->FindSubgraphs(); }
+
+ void DropInvalidSubgraphs() { gran_->DropInvalidSubgraphs(); }
+
+ Status CollateResult() { return gran_->CollateResult(); }
+
+ void ExtendSubgraph(Subgraph* parent) { gran_->ExtendSubgraph(parent); }
+
+ void ExtendSubgraphPortAllOrNone(Subgraph* parent, GenNode* node,
+ GenNode::Port port) {
+ gran_->ExtendSubgraphPortAllOrNone(parent, node, port);
+ }
+
+ void ExtendSubgraphAllOrNone(Subgraph* parent, GenNode* node) {
+ gran_->ExtendSubgraphAllOrNone(parent, node);
+ }
+
+ std::vector<string> DumpRawSubgraphs() { return gran_->DumpRawSubgraphs(); }
+
+ std::vector<string> DumpPartials() {
+ std::vector<string> result;
+ for (const auto& it : gran_->partial_) {
+ result.emplace_back(it->Dump());
+ }
+ return result;
+ }
+
+ const GenNodeMap& GetNodes() { return gran_->nodes_; }
+
+ GenNode* GetNode(const string& name) { return gran_->nodes_.at(name).get(); }
+
+ SubgraphPtrSet& GetResult() { return gran_->result_; }
+ SubgraphPtrSet& GetPartial() { return gran_->partial_; }
+ std::deque<Subgraph*>& GetTodo() { return gran_->todo_; }
+
+ // Gets initialized by a particular test from a suitable GraphDef.
+ std::unique_ptr<GraphAnalyzer> gran_;
+};
+
+TEST_F(GraphAnalyzerTest, BuildMap) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 1);
+ Status st = BuildMap();
+ EXPECT_THAT(st, Eq(Status::OK()));
+
+ auto& map = GetNodes();
+ EXPECT_THAT(map.find("node1"), Ne(map.end()));
+ EXPECT_THAT(map.find("node2"), Ne(map.end()));
+ EXPECT_THAT(map.find("node3"), Ne(map.end()));
+}
+
+TEST_F(GraphAnalyzerTest, BuildMapError) {
+ // A duplicate node.
+ (*graph_3n_self_control_.add_node()) = MakeNodeConst("node1");
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 1);
+ Status st = BuildMap();
+ ASSERT_THAT(
+ st, Eq(Status(error::INVALID_ARGUMENT, "Duplicate node name 'node1'.")));
+}
+
+TEST_F(GraphAnalyzerTest, FindSubgraphs0) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 0);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ FindSubgraphs();
+ auto& subgraphs = GetResult();
+ EXPECT_THAT(subgraphs, SizeIs(0));
+ EXPECT_THAT(DumpRawSubgraphs(), ElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+TEST_F(GraphAnalyzerTest, FindSubgraphs1) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 1);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ FindSubgraphs();
+ auto& subgraphs = GetResult();
+ EXPECT_THAT(subgraphs, SizeIs(3));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: BroadcastGradientArgs(node3)",
+ "1: Const(node1)",
+ "1: Sub(node2)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// The required subgraphs are larger than the graph.
+TEST_F(GraphAnalyzerTest, FindSubgraphsTooLarge) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_3n_self_control_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ FindSubgraphs();
+ EXPECT_THAT(DumpRawSubgraphs(), ElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+//===
+
+// Successfully propagate backwards through a multi-input link,
+// with the base (currently-extending) node already in the graph.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsBaseIn) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards through a multi-input link,
+// with the base (currently-extending) node not in the graph yet.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsBaseOut) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto parent = absl::make_unique<Subgraph>(Subgraph::Identity());
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraphPortAllOrNone(parent.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards through a multi-input link,
+// where the target subgraph size is larger.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsIncomplete) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 5);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ // clang-format off
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// Propagate backwards through a multi-input link, finding that the
+// resulting subgraph would be too large.
+TEST_F(GraphAnalyzerTest, MultiInputTooLargeBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 3);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Propagate backwards through a multi-input link, finding that nothing
+// would be added to the parent subgraph.
+TEST_F(GraphAnalyzerTest, MultiInputNothingAddedBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root = absl::make_unique<Subgraph>(
+ Subgraph::Identity({GetNode("add2"), GetNode("const2_1"),
+ GetNode("const2_2"), GetNode("const2_3")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate forwards through a multi-input link,
+// with the base (currently-extending) node not in the subgraph yet.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessForwardsBaseOut) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("add2"),
+ GenNode::Port(true, 0));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards through a multi-input link.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessBackwardsFull) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("add2")}));
+
+ ExtendSubgraph(root.get());
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: AddN(add2), Sub(sub)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// Successfully propagate forwards through a multi-input link.
+TEST_F(GraphAnalyzerTest, MultiInputSuccessForwardsFull) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")}));
+
+ ExtendSubgraph(root.get());
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add2), Const(const2_1), Const(const2_2), Const(const2_3)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+TEST_F(GraphAnalyzerTest, DropInvalidSubgraphsMulti) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_multi_input_, 3);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ // A good one, multi-input is all-in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("const1_1"),
+ GetNode("const1_2"),
+ GetNode("add1"),
+ })));
+  // A good one, multi-input is all-out.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("add1"),
+ GetNode("add2"),
+ GetNode("sub"),
+ })));
+ // A bad one, multi-input is partially in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("const1_1"),
+ GetNode("add1"),
+ GetNode("sub"),
+ })));
+ // A bad one, multi-input is partially in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("add2"),
+ GetNode("const2_1"),
+ GetNode("const2_2"),
+ })));
+
+ DropInvalidSubgraphs();
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: AddN(add1), AddN(add2), Sub(sub)",
+ "1: AddN(add1), Const(const1_1), Const(const1_2)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+//===
+
+// Successfully propagate backwards through an all-or-none-input node,
+// with the base (currently-extending) node already in the graph.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass2")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass2"));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards through an all-or-none-input node,
+// but no control links propagate. It also tests the situation
+// where the target subgraph size is larger.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwardsNoControl) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 5);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass1")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass1"));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: Const(const1_1), Const(const1_2), IdentityN(pass1)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// The control links propagate separately as all-or-none, even on the nodes
+// that are all-or-none for the normal inputs.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSeparateControl) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 5);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass1")}));
+
+ ExtendSubgraphPortAllOrNone(root.get(), GetNode("pass1"),
+ GenNode::Port(true, -1));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass1)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// Propagate backwards from all-or-none-input node, finding that the
+// resulting subgraph would be too large.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputTooLargeBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 3);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass2")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass2"));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Propagate backwards from all-or-none-input node, finding that nothing
+// would be added to the parent subgraph.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputNothingAddedBackwards) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root = absl::make_unique<Subgraph>(
+ Subgraph::Identity({GetNode("pass2"), GetNode("const2_1"),
+ GetNode("const2_2"), GetNode("const2_3")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass2"));
+
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre());
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate forwards to all-or-none-input node,
+// with the base (currently-extending) node not in the subgraph yet.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessForwardsBaseOut) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")}));
+
+ ExtendSubgraphAllOrNone(root.get(), GetNode("pass2"));
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+// Successfully propagate backwards from all-or-none-input node.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessBackwardsFull) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("pass2")}));
+
+ ExtendSubgraph(root.get());
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre(
+ "1: IdentityN(pass2), Sub(sub)"
+ ));
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(1));
+}
+
+// Successfully propagate forwards to all-or-none-input node. This includes
+// both all-or-none-input for the normal inputs, and multi-input by the
+// control path.
+TEST_F(GraphAnalyzerTest, AllOrNoneInputSuccessForwardsFull) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 4);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ auto root =
+ absl::make_unique<Subgraph>(Subgraph::Identity({GetNode("const2_1")}));
+
+ ExtendSubgraph(root.get());
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass2)",
+ "1: Const(const2_1), Const(const2_2), Const(const2_3), IdentityN(pass1)"
+ ));
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ // clang-format on
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+TEST_F(GraphAnalyzerTest, DropInvalidSubgraphsAllOrNone) {
+ gran_ = absl::make_unique<GraphAnalyzer>(graph_all_or_none_, 3);
+ Status st = BuildMap();
+ ASSERT_THAT(st, Eq(Status::OK()));
+
+ // A good one, all-or-none is all-in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("const1_1"),
+ GetNode("const1_2"),
+ GetNode("pass1"),
+ })));
+  // A good one, all-or-none is all-out.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("pass1"),
+ GetNode("pass2"),
+ GetNode("sub"),
+ })));
+ // A bad one, all-or-none is partially in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("const1_1"),
+ GetNode("pass1"),
+ GetNode("sub"),
+ })));
+ // A bad one, all-or-none is partially in.
+ GetResult().insert(absl::make_unique<Subgraph>(Subgraph::Identity({
+ GetNode("pass2"),
+ GetNode("const2_1"),
+ GetNode("const2_2"),
+ })));
+
+ DropInvalidSubgraphs();
+
+ // clang-format off
+ EXPECT_THAT(DumpRawSubgraphs(), UnorderedElementsAre(
+ "1: IdentityN(pass1), IdentityN(pass2), Sub(sub)",
+ "1: Const(const1_1), Const(const1_2), IdentityN(pass1)"
+ ));
+ // clang-format on
+ EXPECT_THAT(DumpPartials(), UnorderedElementsAre());
+ EXPECT_THAT(GetTodo(), SizeIs(0));
+}
+
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc
new file mode 100644
index 0000000000..924ca11e61
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.cc
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/init_main.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+// Dies on failure.
+static void LoadModel(const string& filename,
+ tensorflow::MetaGraphDef* metagraph) {
+ LOG(INFO) << "Loading model from " << filename;
+ Status st;
+ st = ReadBinaryProto(Env::Default(), filename, metagraph);
+ if (!st.ok()) {
+ LOG(WARNING) << "Failed to read a binary metagraph: " << st;
+ st = ReadTextProto(Env::Default(), filename, metagraph);
+ if (!st.ok()) {
+ LOG(FATAL) << "Failed to read a text metagraph: " << st;
+ }
+ }
+}
+
+// Prune the graph to only keep the transitive fanin part with respect to a set
+// of train ops (if provided).
+void MaybePruneGraph(const tensorflow::MetaGraphDef& metagraph,
+ tensorflow::GraphDef* graph) {
+  std::vector<string> fetch_nodes;
+  // The "train_op" collection may be absent; Map::at() would die on it.
+  auto train_op_it = metagraph.collection_def().find("train_op");
+  if (train_op_it != metagraph.collection_def().end()) {
+    for (const auto& fetch : train_op_it->second.node_list().value()) {
+      LOG(INFO) << "Fetch node: " << fetch;
+      fetch_nodes.push_back(fetch);
+    }
+  }
+ if (fetch_nodes.empty()) {
+ *graph = metagraph.graph_def();
+ } else {
+ std::vector<const tensorflow::NodeDef*> fanin_nodes =
+ tensorflow::grappler::ComputeTransitiveFanin(metagraph.graph_def(),
+ fetch_nodes);
+ for (const tensorflow::NodeDef* node : fanin_nodes) {
+ *(graph->add_node()) = *node;
+ }
+ LOG(INFO) << "Pruned "
+ << metagraph.graph_def().node_size() - graph->node_size()
+ << " nodes. Original graph size: "
+ << metagraph.graph_def().node_size()
+ << ". New graph size: " << graph->node_size() << ".";
+ }
+}
+
+void GraphAnalyzerTool(const string& file_name, int n) {
+ if (n < 1) {
+ LOG(FATAL) << "Invalid subgraph size " << n << ", must be at least 1";
+ }
+
+ tensorflow::MetaGraphDef metagraph;
+ LoadModel(file_name, &metagraph);
+ tensorflow::GraphDef graph;
+ MaybePruneGraph(metagraph, &graph);
+ tensorflow::grappler::graph_analyzer::GraphAnalyzer analyzer(graph, n);
+ LOG(INFO) << "Running the analysis";
+ tensorflow::Status st = analyzer.Run();
+ if (!st.ok()) {
+ LOG(FATAL) << "Analysis failed: " << st;
+ }
+
+ LOG(INFO) << "Printing the result";
+ st = analyzer.OutputSubgraphs();
+ if (!st.ok()) {
+ LOG(FATAL) << "Failed to print the result: " << st;
+ }
+
+ LOG(INFO) << "Completed";
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h
new file mode 100644
index 0000000000..5a91fe7dc8
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h
@@ -0,0 +1,31 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_
+
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+void GraphAnalyzerTool(const string& file_name, int n);
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_GRAPH_ANALYZER_TOOL_H_
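The binary that wires GraphAnalyzerTool() into main() is not part of this excerpt; a hypothetical driver, with the argument handling invented for the sketch, could look like this.

#include <cstdio>
#include <cstdlib>

#include "tensorflow/core/grappler/graph_analyzer/graph_analyzer_tool.h"
#include "tensorflow/core/platform/init_main.h"

int main(int argc, char** argv) {
  tensorflow::port::InitMain(argv[0], &argc, &argv);
  if (argc != 3) {
    std::fprintf(stderr, "Usage: %s <metagraph file> <subgraph size>\n",
                 argv[0]);
    return 1;
  }
  // GraphAnalyzerTool() itself LOG(FATAL)s on bad input or analysis failure.
  tensorflow::grappler::graph_analyzer::GraphAnalyzerTool(argv[1],
                                                          std::atoi(argv[2]));
  return 0;
}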
diff --git a/tensorflow/core/grappler/graph_analyzer/hash_tools.h b/tensorflow/core/grappler/graph_analyzer/hash_tools.h
new file mode 100644
index 0000000000..b0e79f9a68
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/hash_tools.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_
+
+#include <cstddef>
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+// Unfortunately, std::hash provides no way to combine hashes, so everyone
+// is copying boost::hash_combine. This is a version that follows Google's
+// guidelines on the arguments, and contains only the combination, without
+// hashing.
+inline void CombineHash(size_t from, size_t* to) {
+ *to ^= from + 0x9e3779b9 + (*to << 6) + (*to >> 2);
+}
+
+// Combines two hashes in such a way that the order of combination doesn't
+// matter (so it's really both commutative and associative). The result is not
+// a very high-quality hash but can be used in cases where the order of
+// sub-elements must not matter in the following comparison. An alternative
+// would be to sort the hashes of the sub-elements and then combine them
+// normally in the sorted order.
+inline void CombineHashCommutative(size_t from, size_t* to) {
+ *to = *to + from + 0x9e3779b9;
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_HASH_TOOLS_H_
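A quick standalone check (not part of the patch) of the difference between the two combiners: CombineHash() depends on the order in which the values are folded in, while CombineHashCommutative() does not. The input values are arbitrary.

#include <cassert>
#include <cstddef>

#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"

int main() {
  using tensorflow::grappler::graph_analyzer::CombineHash;
  using tensorflow::grappler::graph_analyzer::CombineHashCommutative;

  size_t a = 0x1234, b = 0x5678;

  size_t ab = a;
  CombineHash(b, &ab);  // Hash of "a then b".
  size_t ba = b;
  CombineHash(a, &ba);  // Hash of "b then a".
  assert(ab != ba);     // Order matters (for these particular values).

  size_t cab = a;
  CombineHashCommutative(b, &cab);
  size_t cba = b;
  CombineHashCommutative(a, &cba);
  assert(cab == cba);   // Order does not matter.
  return 0;
}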
diff --git a/tensorflow/core/util/status_util_test.cc b/tensorflow/core/grappler/graph_analyzer/hash_tools_test.cc
index 1f06004db2..b5e9ce6b8e 100644
--- a/tensorflow/core/util/status_util_test.cc
+++ b/tensorflow/core/grappler/graph_analyzer/hash_tools_test.cc
@@ -13,24 +13,34 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/util/status_util.h"
+#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
-#include "tensorflow/core/graph/graph_constructor.h"
-#include "tensorflow/core/graph/node_builder.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
namespace {
-TEST(TestStatusUtil, ErrorFormatTagForNode) {
- Graph graph(OpRegistry::Global());
- Node* node;
- TF_CHECK_OK(NodeBuilder("Foo", "NoOp").Finalize(&graph, &node));
- EXPECT_EQ(error_format_tag(*node, "${line}"), "^^node:Foo:${line}^^");
- EXPECT_EQ(error_format_tag(*node, "${file}:${line}"),
- "^^node:Foo:${file}:${line}^^");
+using ::testing::Eq;
+
+TEST(HashToolsTest, CombineHashCommutative) {
+ size_t a = 0;
+ size_t b = 999;
+
+ size_t c = a;
+ CombineHashCommutative(b, &c);
+
+ size_t d = b;
+ CombineHashCommutative(a, &d);
+
+ EXPECT_THAT(c, Eq(d));
}
} // namespace
-} // namespace tensorflow
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/map_tools.h b/tensorflow/core/grappler/graph_analyzer/map_tools.h
new file mode 100644
index 0000000000..584062c5f2
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/map_tools.h
@@ -0,0 +1,46 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_
+
+#include <functional>
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+// Helpers for building maps of pointers.
+
+template <typename Ptr>
+struct LessAtPtr {
+ bool operator()(const Ptr& x, const Ptr& y) const { return *x < *y; }
+};
+
+template <typename Ptr>
+struct EqAtPtr {
+ bool operator()(const Ptr& x, const Ptr& y) const { return *x == *y; }
+};
+
+template <typename Ptr>
+struct HashAtPtr {
+ size_t operator()(const Ptr& x) const { return x->Hash(); }
+};
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_MAP_TOOLS_H_
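These functors let hashed and ordered containers be keyed by the pointed-to values rather than by pointer identity, which is how the CollationMap in graph_analyzer.h uses them. A standalone sketch (not part of the patch) with a made-up Key type; anything providing Hash() and operator==() would work the same way.

#include <cassert>
#include <unordered_map>

#include "tensorflow/core/grappler/graph_analyzer/map_tools.h"

namespace ga = tensorflow::grappler::graph_analyzer;

// A stand-in for Signature in this example.
struct Key {
  int value;
  size_t Hash() const { return static_cast<size_t>(value); }
  bool operator==(const Key& other) const { return value == other.value; }
};

int main() {
  Key k1{42};
  Key k2{42};  // A distinct object holding an equal value.

  std::unordered_map<const Key*, int, ga::HashAtPtr<const Key*>,
                     ga::EqAtPtr<const Key*>>
      by_value;
  by_value[&k1] = 1;

  // Looking up through a different pointer to an equal value finds the entry.
  assert(by_value.find(&k2) != by_value.end());
  return 0;
}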
diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node.cc b/tensorflow/core/grappler/graph_analyzer/sig_node.cc
new file mode 100644
index 0000000000..b5cca6a512
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/sig_node.cc
@@ -0,0 +1,453 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+
+#include <algorithm>
+
+#include "absl/strings/str_format.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+static constexpr bool debug = false;
+
+//=== SigNode
+
+SigNode::SigNode(const NodeDef* node) : node_(node) {}
+
+void SigNode::CopyLinks(const GenNode& from, const TranslationMap& map) {
+ hash_to_link_.clear();
+ hashed_peers_.clear();
+
+ std::map<LinkTag, Link> link_map;
+ CopyLinksPass1(from, map, &link_map);
+ CopyLinksPass2(&link_map);
+}
+
+void SigNode::CopyLinksPass1(const GenNode& from, const TranslationMap& map,
+ std::map<LinkTag, Link>* link_map) {
+ LinkTag::Hasher link_hasher;
+
+ for (const auto& entry : from.links()) {
+ for (const auto& target : entry.second) {
+ auto nodeit = map.find(target.node);
+ if (nodeit == map.end()) {
+ // Node is not in the subgraph, ignore.
+ continue;
+ }
+
+ LinkTag tag(entry.first, target.port);
+ size_t hval = link_hasher(tag);
+
+ // This instantiates the entry if it was not present.
+ Link& map_entry = (*link_map)[tag];
+ if (map_entry.peers.empty()) {
+ map_entry.tag = tag;
+ map_entry.unique_hash = hval;
+ }
+ map_entry.peers.push_back(nodeit->second);
+ }
+ }
+}
+
+void SigNode::CopyLinksPass2(std::map<LinkTag, Link>* link_map) {
+ for (auto& entry : *link_map) {
+ Link* hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
+ // In case of a conflict, rehash. This should almost never happen.
+ // Because the order of iteration is predictable, the rehashed values
+ // will also be predictable.
+ while (!hl_entry_ptr->peers.empty()) {
+ CombineHash(1, &entry.second.unique_hash);
+ hl_entry_ptr = &hash_to_link_[entry.second.unique_hash];
+ }
+
+ for (const auto& peer : entry.second.peers) {
+ hashed_peers_.emplace_back(HashedPeer(entry.second.unique_hash, peer));
+ }
+
+ hl_entry_ptr->tag = entry.second.tag;
+ hl_entry_ptr->unique_hash = entry.second.unique_hash;
+ hl_entry_ptr->peers.swap(entry.second.peers);
+ }
+}
+
+void SigNode::ComputeTopoHash0() {
+ topo_hash_.clear();
+ last_hashed_nodes_ = next_hashed_nodes_ = node_mask_;
+
+  // TODO(babkin): include the attributes too, as an option.
+ size_t hval = std::hash<string>()(opcode());
+
+  // Getting the topology of the links into the hash early should help resolve
+  // more conflicts early.
+ for (const auto& entry : hashed_peers_) {
+ CombineHash(entry.link_hash, &hval);
+ }
+
+ topo_hash_.push_back(hval);
+}
+
+void SigNode::ComputeTopoHash(int distance) {
+ // The new starting point.
+ next_hashed_nodes_ = last_hashed_nodes_;
+ if (debug) {
+ LOG(INFO) << "DEBUG node " << name() << " mask=" << std::hex
+ << next_hashed_nodes_;
+ }
+
+ if (hash_is_final_) {
+ return;
+ }
+
+ CHECK(topo_hash_.size() == distance);
+
+ int prev = distance - 1;
+
+  // Start with the node's own local topology hash. This value is stable, so
+  // if the hashes of the surrounding nodes don't change at the following
+  // distances, the hash of this node won't change either.
+ size_t hval = topo_hash_[0];
+
+ if (!hashed_peers_.empty()) {
+ size_t last_link_hash = hashed_peers_[0].link_hash;
+ size_t comm_hash = 0;
+
+ for (const auto& entry : hashed_peers_) {
+ if (entry.link_hash != last_link_hash) {
+ CombineHash(last_link_hash, &hval);
+ CombineHash(comm_hash, &hval);
+ comm_hash = 0;
+ last_link_hash = entry.link_hash;
+ }
+
+ // The links in the same vector are commutative, so combine their
+ // hashes in a commutative way.
+ CombineHashCommutative(entry.peer->GetTopoHash(prev), &comm_hash);
+ next_hashed_nodes_ |= entry.peer->last_hashed_nodes_;
+ if (debug) {
+ LOG(INFO) << "DEBUG node " << name() << " += " << entry.peer->name()
+ << " mask=" << std::hex << next_hashed_nodes_;
+ }
+ }
+
+ // The last commutative group.
+ CombineHash(last_link_hash, &hval);
+ CombineHash(comm_hash, &hval);
+ }
+
+ topo_hash_.push_back(hval);
+}
+
+size_t SigNode::GetTopoHash(int distance) const {
+ CHECK(!topo_hash_.empty());
+ if (distance >= topo_hash_.size()) {
+ CHECK(hash_is_final_);
+ return topo_hash_.back();
+ } else {
+ return topo_hash_[distance];
+ }
+}
+
+bool SigNode::operator==(const SigNode& other) const {
+ // TODO(babkin): add attributes too.
+ if (opcode() != other.opcode()) {
+ return false;
+ }
+
+ // Normally the caller is expected to compare the nodes
+ // at the same rank in different graphs, but just in case...
+ if (unique_rank_ != other.unique_rank_) {
+ return false;
+ }
+
+ if (hashed_peers_.size() != other.hashed_peers_.size()) {
+ return false;
+ }
+
+ for (auto it1 = hashed_peers_.begin(), it2 = other.hashed_peers_.begin();
+ it1 != hashed_peers_.end(); ++it1, ++it2) {
+ // TODO(babkin): might compare the actual values too
+ // but the hash is probably just as good.
+ if (it1->link_hash != it2->link_hash) {
+ return false;
+ }
+ if (it1->peer->unique_rank_ != it2->peer->unique_rank_) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+//=== Signature
+
+constexpr int Signature::kMaxGraphSize;
+
+string Signature::ToString() const {
+ string result;
+ for (size_t n = 0; n < nodes.size(); ++n) {
+ // TODO(babkin): add attributes too.
+ result += absl::StrFormat("%d:%s", n, nodes[n]->opcode());
+ for (const auto& entry : nodes[n]->hashed_peers_) {
+ const auto& link = nodes[n]->hash_to_link_[entry.link_hash];
+
+ // The link entries are already sorted, by tags and then by the
+ // node ranks.
+ if (link.tag.local.IsInbound()) {
+ result +=
+ absl::StrFormat("[%s:%s:%d]", string(link.tag.local),
+ string(link.tag.remote), entry.peer->unique_rank_);
+ }
+ }
+ result.push_back(',');
+ }
+ return result;
+}
+
+Status Signature::Compute() {
+ if (map.size() > kMaxGraphSize) {
+ return Status(
+ error::INVALID_ARGUMENT,
+ absl::StrFormat(
+ "A graph of %d nodes is too big for signature computation, "
+ "the maximal supported node count is %d.",
+ map.size(), kMaxGraphSize));
+ }
+
+ // The value that will be assigned next as the unique node id.
+ // This also means that all the entries in nodes at indexes less than this
+ // have been finalized and don't need to be touched any more.
+ size_t next_node_id = 0;
+
+ sig_short = 0;
+ sig_full.resize(0); // Keep the storage.
+
+ // The main signature generation.
+ PrepareNodes();
+ FindUniqueHashes(&next_node_id);
+ while (next_node_id < map.size()) {
+ ComputeOneRound(next_node_id);
+ FindUniqueHashes(&next_node_id);
+ }
+
+ OrderLinks();
+
+ return Status::OK();
+}
+
+void Signature::PrepareNodes() {
+ nodes.resize(0); // Keep the storage.
+
+ // Initialize the nodes.
+  uint64_t mask = 1;  // Unsigned: shifting into the top bit is well-defined.
+ for (const auto& entry : map) {
+ SigNode* node = entry.second.get();
+ node->last_hashed_nodes_ = node->node_mask_ = mask;
+ mask <<= 1;
+ node->unique_rank_ = ~0;
+ node->hash_is_final_ = false;
+ node->ComputeTopoHash0();
+ if (node->GetHighTopoHash() <= map.size()) {
+ // Would conflict with one of the reserved values.
+ node->ReHighTopoHash();
+ }
+
+ // The initial order is random.
+ nodes.emplace_back(node);
+ }
+}
+
+void Signature::FindUniqueHashes(size_t* next_node_id_p) {
+ // Start by sorting by the hash value.
+ std::sort(nodes.begin() + *next_node_id_p, nodes.end(),
+ SigNode::NodeOrderLess());
+
+ // At each call, if no nodes have unique hashes, one node that has a
+ // non-unique (shared) hash can be made unique by assigning a unique id.
+ // This node gets picked predictably by taking the last node.
+ // TODO(babkin): Technically, more than one node can be unshared,
+  // as long as their last_hashed_nodes_ overlap only in the nodes that
+  // already had the assigned ids before the current round. But it's not clear
+  // yet how often this would be beneficial, because it looks like for many
+ // subgraphs unsharing one node should be enough to untangle them. This
+ // would need more measurement before implementing.
+ bool found_unique = false;
+ for (size_t n = *next_node_id_p; n < nodes.size(); ++n) {
+ size_t cur_hash = nodes[n]->GetHighTopoHash();
+ if (n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash) {
+ // A sequence of nodes sharing the same hash. Skip over it.
+ // TODO(babkin): check here for the arbitrary hash conflicts and resolve
+ // them.
+ for (++n;
+ n + 1 < nodes.size() && nodes[n + 1]->GetHighTopoHash() == cur_hash;
+ ++n) {
+ }
+ if (found_unique || n != nodes.size() - 1) {
+ // Either some unique nodes have already been found, or this is
+ // not the last chance, keep trying to find the unique nodes.
+ continue;
+ }
+ // Here we're at the last node and haven't found any unique ones.
+ // So fall through and make this last node unique.
+ }
+
+ found_unique = true;
+ size_t id = (*next_node_id_p)++;
+ nodes[n]->unique_rank_ = id;
+
+ size_t last_hash = nodes[n]->GetHighTopoHash();
+ CombineHash(last_hash, &sig_short);
+ sig_full.push_back(last_hash);
+
+    // Replace the hash at distance 0 with the unique rank. After that it will
+    // stay fixed.
+ nodes[n]->topo_hash_.resize(1);
+ nodes[n]->topo_hash_[0] = id + 1; // Avoid the value of 0.
+
+ nodes[n]->hash_is_final_ = true;
+ nodes[n]->last_hashed_nodes_ = nodes[n]->node_mask_;
+ if (n != id) {
+ std::swap(nodes[id], nodes[n]);
+ }
+ }
+}
+
+void Signature::ComputeOneRound(size_t next_node_id) {
+ // Reset the state of the nodes.
+ int debug_i = 0;
+ for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
+ auto node = *it;
+ // The hash at distance 0 never changes, so preserve it.
+ node->topo_hash_.resize(1);
+ node->last_hashed_nodes_ = node->node_mask_;
+ node->hash_is_final_ = false;
+ if (debug) {
+ LOG(INFO) << "DEBUG distance=" << 0 << " node " << debug_i++ << " "
+ << node->name() << " mask=" << std::hex
+ << node->last_hashed_nodes_;
+ }
+ }
+
+ bool stop = false;
+ // The distance can reach up to nodes.size()+1, to include not only all the
+ // nodes but also all the redundant paths.
+ for (int distance = 1; !stop; ++distance) {
+ for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
+ auto node = *it;
+ if (node->hash_is_final_) {
+ continue;
+ }
+ node->ComputeTopoHash(distance);
+ if (node->GetHighTopoHash() <= nodes.size()) {
+ // Would conflict with one of the reserved values.
+ node->ReHighTopoHash();
+ }
+ }
+
+    // Look for indications that the refinement should not stop yet.
+ stop = true;
+
+ debug_i = 0;
+ // The bitmasks get moved after all the hash computations are done.
+ for (auto it = nodes.begin() + next_node_id; it != nodes.end(); ++it) {
+ auto node = *it;
+ if (debug) {
+ LOG(INFO) << "DEBUG distance=" << distance << " node " << debug_i++
+ << " " << node->name() << " oldmask=" << std::hex
+ << node->last_hashed_nodes_ << " mask=" << std::hex
+ << node->next_hashed_nodes_;
+ }
+ if (node->last_hashed_nodes_ == node->next_hashed_nodes_) {
+ // Stopped growing, this part of the graph must be fully
+ // surrounded by nodes that already have the unique ids.
+ node->hash_is_final_ = true;
+ } else {
+ node->last_hashed_nodes_ = node->next_hashed_nodes_;
+ stop = false;
+ }
+ }
+ }
+}
+
+void Signature::OrderLinks() {
+ for (const auto& node : nodes) {
+ if (node->hashed_peers_.empty()) {
+ continue;
+ }
+
+ size_t cur_link_hash = node->hashed_peers_[0].link_hash + 1;
+ int first_idx = -1;
+
+ int idx;
+ for (idx = 0; idx < node->hashed_peers_.size(); ++idx) {
+ auto& entry = node->hashed_peers_[idx];
+ if (entry.link_hash == cur_link_hash) {
+ continue;
+ }
+ if (idx - first_idx > 1) {
+ // Need to sort.
+ std::sort(node->hashed_peers_.begin() + first_idx,
+ node->hashed_peers_.begin() + idx,
+ SigNode::HashedPeer::LessByRank());
+ }
+
+ cur_link_hash = entry.link_hash;
+ first_idx = idx;
+ }
+ if (idx - first_idx > 1) {
+ // Sort the last bunch.
+ std::sort(node->hashed_peers_.begin() + first_idx,
+ node->hashed_peers_.begin() + idx,
+ SigNode::HashedPeer::LessByRank());
+ }
+ }
+}
+
+bool Signature::operator==(const Signature& other) const {
+ // Tries to find the differences as early as possible by
+ // comparing the hashes first.
+
+ if (sig_short != other.sig_short) {
+ return false;
+ }
+ if (sig_full.size() != other.sig_full.size()) {
+ return false;
+ }
+
+ for (auto it1 = sig_full.begin(), it2 = other.sig_full.begin();
+ it1 != sig_full.end(); ++it1, ++it2) {
+ if (*it1 != *it2) {
+ return false;
+ }
+ }
+
+ if (nodes.size() != other.nodes.size()) {
+ return false;
+ }
+ for (auto it1 = nodes.begin(), it2 = other.nodes.begin(); it1 != nodes.end();
+ ++it1, ++it2) {
+ if (**it1 != **it2) {
+ return false;
+ }
+ }
+
+ return true;
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
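The refinement loop in ComputeTopoHash()/ComputeOneRound() is easiest to see on a toy example: every node starts from a hash of its own label, then repeatedly folds in the previous-round hashes of its neighbors, so nodes in structurally different positions drift apart while symmetric nodes stay equal. Below is a simplified standalone sketch of that idea (not part of the patch), ignoring the ports, bitmasks and rank assignment of the real algorithm.

#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"

using tensorflow::grappler::graph_analyzer::CombineHash;
using tensorflow::grappler::graph_analyzer::CombineHashCommutative;

int main() {
  // A 4-node path graph: 0 - 1 - 2 - 3, all nodes with the same label.
  std::vector<std::vector<int>> adj = {{1}, {0, 2}, {1, 3}, {2}};
  std::vector<size_t> hash(adj.size(), std::hash<std::string>()("Const"));

  for (int round = 0; round < 2; ++round) {
    std::vector<size_t> next = hash;
    for (size_t n = 0; n < adj.size(); ++n) {
      // Neighbor hashes are combined commutatively, like the peers at one
      // link in ComputeTopoHash(), then folded into the node's own hash.
      size_t comm = 0;
      for (int peer : adj[n]) CombineHashCommutative(hash[peer], &comm);
      CombineHash(comm, &next[n]);
    }
    hash.swap(next);
  }

  // The endpoints (0, 3) end up equal to each other but different from the
  // middle nodes (1, 2), reflecting their different positions in the graph.
  for (size_t n = 0; n < hash.size(); ++n)
    std::cout << n << ": " << std::hex << hash[n] << '\n';
  return 0;
}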
diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node.h b/tensorflow/core/grappler/graph_analyzer/sig_node.h
new file mode 100644
index 0000000000..45c0ed3162
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/sig_node.h
@@ -0,0 +1,304 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_
+
+#include <map>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+namespace test {
+class SigBaseTest;
+} // end namespace test
+
+class SigNode;
+
+// To find nodes by name. Having the map ordered makes the tests easier,
+// and it isn't used in production code often enough to get any win from
+// using an unordered map.
+using SigNodeMap = std::map<string, std::unique_ptr<SigNode>>;
+
+// One node in the graph, in the form convenient for generation of the signature
+// of the graph, and comparison of two (sub)graphs for equivalence. It refers to
+// the original NodeDef protobuf for most information and adds the extra
+// enrichment.
+//
+// The graph building is 2-stage: first match a SigNode with each NodeDef and
+// collect them into a map that finds them by name, then process the map,
+// deep-parse the underlying NodeDefs and connect the SigNodes together.
+class SigNode {
+ public:
+ friend struct Signature;
+
+  // Keeps a pointer to the underlying NodeDef, so the underlying object
+  // must not be deleted while the SigNode is alive.
+ explicit SigNode(const NodeDef* node);
+
+ // Access wrappers.
+ const string& name() const { return node_->name(); }
+ const string& opcode() const { return node_->op(); }
+ const NodeDef* node_def() const { return node_; }
+
+ // For extraction of subgraphs into a separate SigNodeMap, copies the links
+ // that point inside the subgraph from a full-graph SigNode to a subgraph
+ // SigNode. The translation map defines the subgraph and gives the mapping
+ // from the nodes in the full graph to the matching nodes in subgraph.
+ using TranslationMap =
+ std::unordered_map<const GenNode* /*full_graph*/, SigNode* /*subgraph*/>;
+ void CopyLinks(const GenNode& from, const TranslationMap& map);
+
+ // A link is an edge of the graph that connects 2 nodes. Each of the connected
+ // nodes has its own perspective on the link, seeing its local port, remote
+ // port and the remote node. The direction of the link is encoded in the
+ // ports, one port is always incoming and another one outgoing.
+ //
+ // The link tag here contains both ports of the link viewed from the
+ // perspective of this node; consisting of both the local port (i.e. at this
+ // node) and remote port (i.e. on the other node), the local one going first.
+ struct LinkTag {
+ struct Hasher {
+ size_t operator()(const LinkTag& tag) const noexcept {
+ size_t hval = port_hasher(tag.local);
+ CombineHash(port_hasher(tag.remote), &hval);
+ return hval;
+ }
+ GenNode::Port::Hasher port_hasher;
+ };
+
+ LinkTag(GenNode::Port a_local, GenNode::Port a_remote)
+ : local(a_local), remote(a_remote) {}
+
+ // The default constructor is used for the default values in maps.
+ // (false, 99) is an arbitrary value that makes the uninitialized
+ // links easy to tell when debugging (they should never happen).
+ LinkTag() : local(false, 99), remote(false, 99) {}
+
+ // Port of the link on the local node.
+ GenNode::Port local;
+ // Port of the link on the remote node.
+ GenNode::Port remote;
+
+ bool operator==(const LinkTag& other) const {
+ return local == other.local && remote == other.remote;
+ }
+ bool operator<(const LinkTag& other) const {
+ return local < other.local ||
+ (local == other.local && remote < other.remote);
+ }
+ };
+
+ // Since the signature logic doesn't differentiate between the links
+ // with the same tag (other than by the "peer" nodes on their other ends),
+ // all the links with the same tag are grouped into a single structure.
+ struct Link {
+ LinkTag tag;
+ size_t unique_hash; // Hash of the tag after conflict resolution.
+ // The remote node(s) on the other side of the link(s).
+ using PeerVector = std::vector<SigNode*>;
+ PeerVector peers;
+ };
+
+ // A way to look up the link description by its hash.
+ using LinkHashMap = std::map<size_t, Link>;
+ const LinkHashMap& hash_to_link() const { return hash_to_link_; }
+
+ // The enumeration of all the peer nodes in a predictable order.
+ // Before the signature generation, only the link values determine the
+ // order; after the signature generation, the entries with the same
+ // links get further sorted by their peer node ranks.
+ struct HashedPeer {
+ HashedPeer(size_t l, SigNode* p) : link_hash(l), peer(p) {}
+
+ struct LessByRank {
+ bool operator()(const SigNode::HashedPeer& left,
+ const SigNode::HashedPeer& right) {
+ return left.peer->unique_rank_ < right.peer->unique_rank_;
+ }
+ };
+
+ size_t link_hash;
+ SigNode* peer;
+ };
+ using HashedPeerVector = std::vector<HashedPeer>;
+ const HashedPeerVector& hashed_peers() const { return hashed_peers_; }
+
+ // Compares two nodes in two different graphs for equivalence (two nodes in
+ // the same graph would never be equivalent). Expects that the signatures of
+ // the graphs have already been computed, so unique_rank_ is filled in and
+ // hashed_peers_ is properly ordered.
+ bool operator==(const SigNode& other) const;
+
+ bool operator!=(const SigNode& other) const { return !(*this == other); }
+
+ private:
+ friend class test::SigBaseTest;
+
+ // The CopyLinks code is split into 2 parts for testability.
+ // The first pass builds a map ordered by LinkTag for predictability.
+ void CopyLinksPass1(const GenNode& from, const TranslationMap& map,
+ std::map<LinkTag, Link>* link_map);
+ // The second pass converts it to the map keyed by hash value,
+ // resolves any hash conflicts, and builds the hashed peer vector.
+ void CopyLinksPass2(std::map<LinkTag, Link>* link_map);
+
+ // Computes the topological hash at distance 0. Resets the topo_hash_ vector
+ // and the hashed node sets (last_hashed_nodes_ and next_hashed_nodes_).
+ void ComputeTopoHash0();
+
+ // Computes the topological hash at the given distance. The hashes for all the
+ // lower distances must be already computed for all the nodes in the graph.
+ // Also computes next_hashed_nodes_ from last_hashed_nodes_.
+ void ComputeTopoHash(int distance);
+
+ // Get the hash value for a particular distance. It must be previously
+ // computed.
+ size_t GetTopoHash(int distance) const;
+
+ // Returns the hash value for the highest computed distance. It must have been
+ // previously computed.
+ size_t GetHighTopoHash() const {
+ CHECK(!topo_hash_.empty());
+ return topo_hash_.back();
+ }
+
+ // Rehash the topmost hash, to avoid conflicts.
+ void ReHighTopoHash() {
+ CHECK(!topo_hash_.empty());
+ CombineHash(1, &topo_hash_.back());
+ }
+
+ // Ordering of the nodes by the highest available hash (which must be
+ // previously computed).
+ struct NodeOrderLess {
+ bool operator()(const SigNode* left, const SigNode* right) {
+ return left->topo_hash_.back() < right->topo_hash_.back();
+ }
+ };
+
+ private:
+ const NodeDef* node_;
+
+ // The bitmap mask with 1 bit set that represents this node in the set
+ // during the computation of the signature.
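+ // (For illustration: the i-th node prepared for the signature computation
+ // gets the mask 1 << i, as exercised by the PrepareNodes test in
+ // sig_node_test.cc.)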
+ uint64_t node_mask_ = 0;
+
+ // The code that populates this map makes sure that there are no hash
+ // conflicts, rehashing if necessary.
+ LinkHashMap hash_to_link_;
+
+ // The enumeration of all the direct peers in a predictable order (which
+ // happens to be the order of their link tags, but the order of the hashes
+ // would do too). It is used for the quick enumeration during the signature
+ // computation. After the signature building is completed, the entries that
+ // have the same link tag get further sorted in the order of the ranks of
+ // their nodes.
+ HashedPeerVector hashed_peers_;
+
+ // The unique rank represents the order in which the node will be included
+ // into the signature. It gets assigned in order, either when the topo_hash_ of
+ // this node becomes unique in the graph or, when some nodes are completely
+ // equivalent, by picking one of them at random to receive the next rank; the
+ // rest of the nodes then attempt to disambiguate based on that information.
+ size_t unique_rank_ = ~0;
+ // When hash_is_final_ is set, the topo_hash_ vector stops growing, and the
+ // last value from it is used for all the further hashes.
+ bool hash_is_final_ = false;
+ // The hashes that include the topology of the nodes up to the distance N. The
+ // hash for distance 0 is produced from the attributes of this node itself and
+ // its general connectivity properties but no information about the
+ // neighboring nodes. The hash for distance D+1 is built from the hashes at
+ // level D of this node and of all its immediate neighbors. The neighbors that
+ // are connected by equivalent links are included in a commutative way.
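+ // For illustration (mirroring the ComputeTopoHashNotFinal test in
+ // sig_node_test.cc): if this node's hashed peers are {(10, B), (10, C),
+ // (20, B)}, then for link hash 10 the level-D hashes of B and C are first
+ // combined commutatively and the pair (10, combined) is folded into the
+ // node's hash, and then (20, hash of B) is folded in the same way.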
+ std::vector<size_t> topo_hash_;
+ // The set of nodes that got included into the computation of the
+ // last topo_hash_ entry.
+ uint64_t last_hashed_nodes_ = 0;
+ // The next set of nodes that gets used for the current topo_hash entry.
+ uint64_t next_hashed_nodes_ = 0;
+};
+
+// Signature of a graph. The computation is intertwined with the private methods
+// of SigNode, so keeping both in the same file looks more convenient.
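+//
+// Typical usage, as an illustrative sketch based on the tests in
+// sig_node_test.cc (subgraph and other_sig are placeholders; Subgraph is
+// defined in subgraph.h):
+//   Signature sig;
+//   subgraph.ExtractForSignature(&sig.map);
+//   CHECK(sig.Compute().ok());
+//   string printable = sig.ToString();
+//   bool same = (sig == other_sig);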
+struct Signature {
+ friend class test::SigBaseTest;
+
+ // Maximal size of the graphs for which the signature can be computed.
+ // Changing this constant won't magically add the support for a larger size;
+ // the rest of the implementation would have to be extended. The value of 64
+ // is driven by the number of bits in a uint64_t, and should be enough for our
+ // purposes, while allowing a highly efficient implementation.
+ static constexpr int kMaxGraphSize = 64;
+
+ // Using the map, computes the rest of the fields of a signature.
+ // Returns an error if the graph is too big.
+ Status Compute();
+
+ // Convert the computed signature to a string representation.
+ string ToString() const;
+
+ SigNodeMap map; // The nodes in the graph, accessible by name.
+ size_t sig_short = 0; // Hash of the signature, for the quick equality check.
+ // The full signature: hashes of the nodes in a predictable order.
+ std::vector<size_t> sig_full;
+ // The nodes in the same order as they go in the signature.
+ std::vector<SigNode*> nodes;
+
+ // For building the unordered maps.
+ size_t Hash() const { return sig_short; }
+
+ // Returns true if the graphs are equivalent. Both signatures must be already
+ // computed.
+ bool operator==(const Signature& other) const;
+
+ private:
+ // Populates the nodes vector from the map and initializes the state of the
+ // nodes for the signature computation.
+ void PrepareNodes();
+
+ // Finds the nodes with the hashes that are unique and assigns the unique ids
+ // to them. If there are nodes with non-unique hashes, exactly one node from
+ // the first such sequence (in the order of hash values) will be picked and
+ // assigned a unique id. Assumes that the nodes[0...(next_node_id-1)] have
+ // been already assigned the unique ids. Advances next_node_id by at least 1.
+ void FindUniqueHashes(size_t* next_node_id_p);
+
+ // One round of the signature computation. Assumes that the
+ // nodes[0...(next_node_id-1)] have been already assigned the fixed
+ // positions, and thus computes the hashes only for the remaining nodes.
+ void ComputeOneRound(size_t next_node_id);
+
+ // Additional ordering of the hashed_peers_ links in the nodes, so that they
+ // can be compared and printed in a predictable order.
+ void OrderLinks();
+};
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SIG_NODE_H_
diff --git a/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc b/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc
new file mode 100644
index 0000000000..4c6a9ba9e0
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/sig_node_test.cc
@@ -0,0 +1,1235 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Gt;
+using ::testing::Ne;
+using ::testing::SizeIs;
+
+//===
+
+TEST(SigNodeLinkTag, Compare) {
+ SigNode::LinkTag a(GenNode::Port(false, 1), GenNode::Port(false, 2));
+ SigNode::LinkTag b(GenNode::Port(false, 1), GenNode::Port(false, 2));
+ SigNode::LinkTag c(GenNode::Port(false, 2), GenNode::Port(false, 1));
+ SigNode::LinkTag d(GenNode::Port(false, 1), GenNode::Port(false, 3));
+ SigNode::LinkTag e(GenNode::Port(false, 2), GenNode::Port(false, 2));
+
+ EXPECT_TRUE(a == b);
+ EXPECT_FALSE(a == c);
+ EXPECT_FALSE(a == e);
+
+ EXPECT_FALSE(a < b);
+ EXPECT_FALSE(b < a);
+
+ EXPECT_TRUE(a < c);
+ EXPECT_FALSE(c < a);
+
+ EXPECT_TRUE(a < d);
+ EXPECT_FALSE(d < a);
+}
+
+//===
+
+class SigBaseTest : public ::testing::Test, protected TestGraphs {
+ protected:
+ void BuildSigMap(const GraphDef& graph) {
+ gen_map_.clear();
+ sig_.map.clear();
+ CHECK(GenNode::BuildGraphInMap(graph, &gen_map_).ok());
+ Subgraph::Identity id;
+ for (const auto& entry : gen_map_) {
+ id.insert(entry.second.get());
+ }
+ Subgraph sg(id);
+ sg.ExtractForSignature(&sig_.map);
+ }
+
+ static void CopyLinksPass2(
+ std::map<SigNode::LinkTag, SigNode::Link>* link_map, SigNode* node) {
+ node->CopyLinksPass2(link_map);
+ }
+
+ static void ComputeTopoHash0(SigNode* node) { node->ComputeTopoHash0(); }
+
+ static void ComputeTopoHash(int distance, SigNode* node) {
+ node->ComputeTopoHash(distance);
+ }
+
+ static size_t GetTopoHash(int distance, SigNode* node) {
+ return node->GetTopoHash(distance);
+ }
+
+ static size_t GetHighTopoHash(SigNode* node) {
+ return node->GetHighTopoHash();
+ }
+
+ static void ReHighTopoHash(SigNode* node) { node->ReHighTopoHash(); }
+
+ static SigNode::HashedPeerVector& RefHashedPeers(SigNode* node) {
+ return node->hashed_peers_;
+ }
+ static size_t& RefUniqueRank(SigNode* node) { return node->unique_rank_; }
+ static bool& RefHashIsFinal(SigNode* node) { return node->hash_is_final_; }
+ static std::vector<size_t>& RefTopoHash(SigNode* node) {
+ return node->topo_hash_;
+ }
+ static uint64_t& RefNodeMask(SigNode* node) { return node->node_mask_; }
+ static uint64_t& RefLastHashedNodes(SigNode* node) {
+ return node->last_hashed_nodes_;
+ }
+ static uint64_t& RefNextHashedNodes(SigNode* node) {
+ return node->next_hashed_nodes_;
+ }
+
+ static void PrepareNodes(Signature* signature) { signature->PrepareNodes(); }
+
+ static void FindUniqueHashes(size_t* next_node_id_p, Signature* signature) {
+ signature->FindUniqueHashes(next_node_id_p);
+ }
+
+ static void ComputeOneRound(size_t next_node_id, Signature* signature) {
+ signature->ComputeOneRound(next_node_id);
+ }
+
+ static void OrderLinks(Signature* signature) { signature->OrderLinks(); }
+
+ // These get initialized in BuildSigMap().
+ GenNodeMap gen_map_;
+ Signature sig_;
+};
+
+//===
+
+class SigNodeTest : public SigBaseTest {};
+
+// Tests that the duplicate hashes get resolved by rehashing.
+TEST_F(SigNodeTest, DuplicateHash) {
+ NodeDef node1 = MakeNodeConst("node1");
+ NodeDef node2 = MakeNodeConst("node2");
+ NodeDef node3 = MakeNodeShapeN("node3", "node1", "node2");
+
+ SigNode sn1(&node1);
+ SigNode sn2(&node2);
+ SigNode sn3(&node3);
+
+ constexpr size_t kSameHash = 999;
+
+ SigNode::Link link1;
+ link1.tag = SigNode::LinkTag(GenNode::Port(true, 0), GenNode::Port(false, 0));
+ link1.unique_hash = kSameHash;
+ link1.peers.emplace_back(&sn1);
+
+ SigNode::Link link2;
+ link2.tag = SigNode::LinkTag(GenNode::Port(true, 1), GenNode::Port(false, 0));
+ link2.unique_hash = kSameHash;
+ link2.peers.emplace_back(&sn2);
+
+ SigNode::Link link3;
+ link3.tag = SigNode::LinkTag(GenNode::Port(true, 2), GenNode::Port(false, 0));
+ link3.unique_hash = kSameHash;
+ link3.peers.emplace_back(&sn3);
+
+ std::map<SigNode::LinkTag, SigNode::Link> link_map;
+ link_map[link1.tag] = link1;
+ link_map[link2.tag] = link2;
+ link_map[link3.tag] = link3;
+
+ CopyLinksPass2(&link_map, &sn3);
+ auto& hl = sn3.hash_to_link();
+ EXPECT_THAT(hl, SizeIs(3));
+
+ // Check that the hashes are self-consistent, and put the entries into
+ // another map with a known order.
+ std::map<SigNode::LinkTag, SigNode::Link> rehashed;
+ auto hlit = hl.begin();
+ ASSERT_THAT(hlit, Ne(hl.end()));
+ EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
+ rehashed[hlit->second.tag] = hlit->second;
+ ++hlit;
+ ASSERT_THAT(hlit, Ne(hl.end()));
+ EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
+ rehashed[hlit->second.tag] = hlit->second;
+ ++hlit;
+ ASSERT_THAT(hlit, Ne(hl.end()));
+ EXPECT_THAT(hlit->second.unique_hash, Eq(hlit->first));
+ rehashed[hlit->second.tag] = hlit->second;
+
+ // Just in case.
+ ASSERT_THAT(rehashed, SizeIs(3));
+
+ auto rhit = rehashed.begin();
+ ASSERT_THAT(rhit, Ne(rehashed.end()));
+ EXPECT_TRUE(rhit->second.tag == link1.tag);
+ EXPECT_THAT(rhit->second.unique_hash, Eq(kSameHash));
+ EXPECT_THAT(rhit->second.peers, ElementsAre(&sn1));
+
+ ++rhit;
+ ASSERT_THAT(rhit, Ne(rehashed.end()));
+ EXPECT_TRUE(rhit->second.tag == link2.tag);
+ // This hash must be rehashed.
+ EXPECT_THAT(rhit->second.unique_hash, Ne(kSameHash));
+ size_t hash2 = rhit->second.unique_hash;
+ EXPECT_THAT(rhit->second.peers, ElementsAre(&sn2));
+
+ ++rhit;
+ ASSERT_THAT(rhit, Ne(rehashed.end()));
+ EXPECT_TRUE(rhit->second.tag == link3.tag);
+ // This hash must be rehashed.
+ EXPECT_THAT(rhit->second.unique_hash, Ne(kSameHash));
+ EXPECT_THAT(rhit->second.unique_hash, Ne(hash2));
+ size_t hash3 = rhit->second.unique_hash;
+ EXPECT_THAT(rhit->second.peers, ElementsAre(&sn3));
+
+ auto& peers = sn3.hashed_peers();
+ EXPECT_THAT(peers, SizeIs(3));
+
+ auto peerit = peers.begin();
+ ASSERT_THAT(peerit, Ne(peers.end()));
+ EXPECT_THAT(peerit->link_hash, Eq(kSameHash));
+ EXPECT_THAT(peerit->peer, Eq(&sn1));
+
+ ++peerit;
+ ASSERT_THAT(peerit, Ne(peers.end()));
+ EXPECT_THAT(peerit->link_hash, Eq(hash2));
+ EXPECT_THAT(peerit->peer, Eq(&sn2));
+
+ ++peerit;
+ ASSERT_THAT(peerit, Ne(peers.end()));
+ EXPECT_THAT(peerit->link_hash, Eq(hash3));
+ EXPECT_THAT(peerit->peer, Eq(&sn3));
+}
+
+// The full CopyLinks() is tested in (SubgraphTest, ExtractForSignature).
+
+TEST_F(SigNodeTest, GetTopoHash) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ // Fake some hash values.
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(456);
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123));
+ EXPECT_THAT(GetTopoHash(1, &sn1), Eq(456));
+
+ RefHashIsFinal(&sn1) = true;
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123));
+ EXPECT_THAT(GetTopoHash(1, &sn1), Eq(456));
+ EXPECT_THAT(GetTopoHash(2, &sn1), Eq(456));
+
+ EXPECT_THAT(GetHighTopoHash(&sn1), Eq(456));
+}
+
+TEST_F(SigNodeTest, ReTopoHash) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ // Fake some hash values.
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(456);
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123));
+ EXPECT_THAT(GetTopoHash(1, &sn1), Eq(456));
+
+ ReHighTopoHash(&sn1);
+
+ size_t expected_hash = 456;
+ CombineHash(1, &expected_hash);
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(123));
+ EXPECT_THAT(GetTopoHash(1, &sn1), Eq(expected_hash));
+}
+
+TEST_F(SigNodeTest, ComputeTopoHash0) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ // Fake a topology.
+ RefUniqueRank(&sn1) = 10;
+ RefNodeMask(&sn1) = 0x02;
+
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(456);
+
+ // Fake a state.
+ RefLastHashedNodes(&sn1) = 0xFF;
+ RefNextHashedNodes(&sn1) = 0xFF;
+
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(1, nullptr));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(1, nullptr));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(2, nullptr));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(3, nullptr));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(3, nullptr));
+
+ // Run the test.
+ ComputeTopoHash0(&sn1);
+
+ EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x02));
+ EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x02));
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(1));
+
+ size_t exp_hval = std::hash<string>()(sn1.opcode());
+ CombineHash(1, &exp_hval);
+ CombineHash(1, &exp_hval);
+ CombineHash(2, &exp_hval);
+ CombineHash(3, &exp_hval);
+ CombineHash(3, &exp_hval);
+
+ EXPECT_THAT(GetTopoHash(0, &sn1), Eq(exp_hval));
+}
+
+TEST_F(SigNodeTest, ComputeTopoHashNotFinal) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+
+ // Fake a topology.
+ RefUniqueRank(&sn1) = 0;
+ RefNodeMask(&sn1) = 0x01;
+ RefUniqueRank(&sn2) = 0;
+ RefNodeMask(&sn2) = 0x02;
+ RefUniqueRank(&sn3) = 0;
+ RefNodeMask(&sn3) = 0x04;
+
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn2));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn3));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(20, &sn2));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn3));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn2));
+
+ // Fake a state.
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(321);
+
+ RefTopoHash(&sn2).emplace_back(456);
+ RefTopoHash(&sn2).emplace_back(654);
+
+ RefTopoHash(&sn3).emplace_back(789);
+ RefTopoHash(&sn3).emplace_back(987);
+
+ // These values are not realistic in that they don't include the bits from
+ // the masks of the nodes themselves, but that's the point of this test: only
+ // the previous nodes' node sets are used in the computation, not their own
+ // masks directly.
+ RefLastHashedNodes(&sn1) = 0x8;
+ RefLastHashedNodes(&sn2) = 0x10;
+ RefLastHashedNodes(&sn3) = 0x20;
+
+ // A scratch value to get overwritten.
+ RefNextHashedNodes(&sn1) = 0x100;
+
+ ComputeTopoHash(2, &sn1);
+
+ EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x8)); // Unchanged.
+ EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x38));
+
+ // This computes the hash from the explicit numbers above.
+ size_t exp_hash = 123; // The 0th hash is the starting point.
+ size_t comm_hash;
+
+ comm_hash = 0;
+ CombineHashCommutative(654, &comm_hash);
+ CombineHashCommutative(987, &comm_hash);
+
+ CombineHash(10, &exp_hash);
+ CombineHash(comm_hash, &exp_hash);
+
+ comm_hash = 0;
+ CombineHashCommutative(654, &comm_hash);
+
+ CombineHash(20, &exp_hash);
+ CombineHash(comm_hash, &exp_hash);
+
+ comm_hash = 0;
+ CombineHashCommutative(654, &comm_hash);
+ CombineHashCommutative(987, &comm_hash);
+
+ CombineHash(30, &exp_hash);
+ CombineHash(comm_hash, &exp_hash);
+
+ EXPECT_THAT(GetTopoHash(2, &sn1), Eq(exp_hash));
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(3));
+}
+
+TEST_F(SigNodeTest, ComputeTopoHashFinal) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+
+ // Fake a topology - same as for ComputeTopoHashNotFinal.
+ RefUniqueRank(&sn1) = 0;
+ RefNodeMask(&sn1) = 0x01;
+ RefUniqueRank(&sn2) = 0;
+ RefNodeMask(&sn2) = 0x02;
+ RefUniqueRank(&sn3) = 0;
+ RefNodeMask(&sn3) = 0x04;
+
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn2));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(10, &sn3));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(20, &sn2));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn3));
+ RefHashedPeers(&sn1).emplace_back(SigNode::HashedPeer(30, &sn2));
+
+ // Fake a state - mostly same as for ComputeTopoHashNotFinal.
+ RefTopoHash(&sn1).emplace_back(123);
+ RefTopoHash(&sn1).emplace_back(321);
+
+ RefTopoHash(&sn2).emplace_back(456);
+ RefTopoHash(&sn2).emplace_back(654);
+
+ RefTopoHash(&sn3).emplace_back(789);
+ RefTopoHash(&sn3).emplace_back(987);
+
+ // These values are not realistic in that they don't include the bits from
+ // the masks of the nodes themselves, but that's the point of this test: only
+ // the previous nodes' node sets are used in the computation, not their own
+ // masks directly.
+ RefLastHashedNodes(&sn1) = 0x8;
+ RefLastHashedNodes(&sn2) = 0x10;
+ RefLastHashedNodes(&sn3) = 0x20;
+
+ // A scratch value to get overwritten.
+ RefNextHashedNodes(&sn1) = 0x100;
+
+ // This is the difference in configuration.
+ RefHashIsFinal(&sn1) = true;
+
+ ComputeTopoHash(2, &sn1);
+
+ EXPECT_THAT(RefLastHashedNodes(&sn1), Eq(0x8)); // Unchanged.
+ EXPECT_THAT(RefNextHashedNodes(&sn1), Eq(0x8));
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
+ EXPECT_THAT(GetTopoHash(2, &sn1), Eq(321));
+}
+
+TEST_F(SigNodeTest, EqualsOpcode) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+
+ EXPECT_TRUE(sn1 == sn2);
+ EXPECT_FALSE(sn1 != sn2);
+
+ node2.set_op("Mul");
+
+ EXPECT_TRUE(sn1 != sn2);
+ EXPECT_FALSE(sn1 == sn2);
+}
+
+TEST_F(SigNodeTest, EqualsRank) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+
+ EXPECT_TRUE(sn1 == sn2);
+ EXPECT_FALSE(sn1 != sn2);
+
+ RefUniqueRank(&sn1) = 1;
+ RefUniqueRank(&sn2) = 2;
+
+ EXPECT_TRUE(sn1 != sn2);
+ EXPECT_FALSE(sn1 == sn2);
+}
+
+// Checks that if the nodes have a different number of links,
+// they will be considered unequal.
+TEST_F(SigNodeTest, EqualsLinkSize) {
+ GraphDef graph1;
+ (*graph1.add_node()) = MakeNodeConst("node1");
+ (*graph1.add_node()) = MakeNodeMul("node2", "node1", "node1");
+
+ GenNodeMap gen_map1;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), Eq(Status::OK()));
+
+ Subgraph::Identity id1;
+ id1.insert(gen_map1["node1"].get());
+ id1.insert(gen_map1["node2"].get());
+ Subgraph sg1(id1);
+
+ SigNodeMap sig_map1;
+ sg1.ExtractForSignature(&sig_map1);
+
+ GraphDef graph2;
+ (*graph2.add_node()) = MakeNodeConst("node1");
+ // The difference between graph1 and graph2: one more input.
+ auto node22 = graph2.add_node();
+ *node22 = MakeNodeMul("node2", "node1", "node1");
+ node22->add_input("node2");
+
+ GenNodeMap gen_map2;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph2, &gen_map2), Eq(Status::OK()));
+
+ Subgraph::Identity id2;
+ id2.insert(gen_map2["node1"].get());
+ id2.insert(gen_map2["node2"].get());
+ Subgraph sg2(id2);
+
+ SigNodeMap sig_map2;
+ sg2.ExtractForSignature(&sig_map2);
+
+ EXPECT_TRUE(*sig_map1["node1"] == *sig_map2["node1"]);
+ EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]);
+ EXPECT_FALSE(*sig_map2["node2"] == *sig_map1["node2"]);
+}
+
+TEST_F(SigNodeTest, EqualsLinks) {
+ // Start with 2 copies of the same graph.
+ GraphDef graph1;
+ (*graph1.add_node()) = MakeNodeConst("node1");
+ (*graph1.add_node()) = MakeNodeMul("node2", "node1", "node1");
+
+ GenNodeMap gen_map1;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map1), Eq(Status::OK()));
+
+ Subgraph::Identity id1;
+ id1.insert(gen_map1["node1"].get());
+ id1.insert(gen_map1["node2"].get());
+ Subgraph sg1(id1);
+
+ SigNodeMap sig_map1;
+ sg1.ExtractForSignature(&sig_map1);
+
+ GenNodeMap gen_map2;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph1, &gen_map2), Eq(Status::OK()));
+
+ Subgraph::Identity id2;
+ id2.insert(gen_map2["node1"].get());
+ id2.insert(gen_map2["node2"].get());
+ Subgraph sg2(id2);
+
+ SigNodeMap sig_map2;
+ sg2.ExtractForSignature(&sig_map2);
+
+ EXPECT_TRUE(*sig_map1["node1"] == *sig_map2["node1"]);
+ EXPECT_TRUE(*sig_map1["node2"] == *sig_map2["node2"]);
+
+ // Alter the link hash of one of the nodes.
+ SigNode* sn2 = sig_map2["node2"].get();
+ ++RefHashedPeers(sn2)[0].link_hash;
+
+ EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]);
+
+ // Restore back.
+ --RefHashedPeers(sn2)[0].link_hash;
+ EXPECT_TRUE(*sig_map1["node2"] == *sig_map2["node2"]);
+
+ // Alter the unique rank of a referenced node.
+ ++RefUniqueRank(sig_map2["node1"].get());
+
+ EXPECT_FALSE(*sig_map1["node2"] == *sig_map2["node2"]);
+}
+
+//===
+
+class SignatureTest : public SigBaseTest {
+ protected:
+ // Initializes the state used to generate the permutations of a given size.
+ static void InitPermutation(size_t size,
+ std::vector<size_t>* plain_permutation,
+ std::vector<size_t>* countdown) {
+ plain_permutation->clear();
+ countdown->clear();
+ for (size_t i = 0; i < size; ++i) {
+ plain_permutation->emplace_back(i);
+ countdown->emplace_back(size - 1 - i);
+ }
+ }
+
+ // Builds a permutation guided by the count-down value.
+ static void BuildPermutation(const std::vector<size_t>& plain_permutation,
+ const std::vector<size_t>& countdown,
+ std::vector<size_t>* result) {
+ *result = plain_permutation;
+ for (int i = 0; i < result->size(); ++i) {
+ std::swap((*result)[i], (*result)[i + countdown[i]]);
+ }
+ }
+
+ // Returns false when the count-down is finished.
+ static bool CountDown(std::vector<size_t>* countdown) {
+ // The last position always contains 0, so skip it.
+ int pos;
+ for (pos = countdown->size() - 2; pos >= 0; --pos) {
+ if ((*countdown)[pos] > 0) {
+ --(*countdown)[pos];
+ break;
+ }
+ (*countdown)[pos] = (countdown->size() - 1 - pos);
+ }
+
+ return pos >= 0;
+ }
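+
+ // For illustration (a worked example, not from the original code): the
+ // countdown acts like a factorial-number-system counter where countdown[i]
+ // ranges over [0, size-1-i], and each state yields a distinct permutation by
+ // swapping position i with position i + countdown[i]. For size 3 the states
+ // are {2,1,0}, {2,0,0}, {1,1,0}, {1,0,0}, {0,1,0}, {0,0,0}, producing all
+ // 3! = 6 permutations (cf. the Permutation test below).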
+
+ // Permutes the nodes every which way and checks that all the signatures
+ // produced are the same. This is reasonable for graphs up to size 5, maybe
+ // 6 at a stretch. After that the number of permutations grows huge and the
+ // test becomes very slow.
+ void TestGraphEveryWay(const GraphDef& graph) {
+ size_t graph_size = graph.node_size();
+
+ gen_map_.clear();
+ sig_.map.clear();
+ Status result = GenNode::BuildGraphInMap(graph, &gen_map_);
+ ASSERT_THAT(result, Eq(Status::OK()));
+ Subgraph::Identity id;
+ for (const auto& entry : gen_map_) {
+ id.insert(entry.second.get());
+ }
+ Subgraph sg(id);
+ sg.ExtractForSignature(&sig_.map);
+
+ std::vector<size_t> plain_permutation;
+ std::vector<size_t> countdown;
+ InitPermutation(graph_size, &plain_permutation, &countdown);
+
+ std::set<string> signatures;
+ std::vector<size_t> permutation;
+ do {
+ BuildPermutation(plain_permutation, countdown, &permutation);
+
+ constexpr bool kDebugPermutation = false;
+ if (kDebugPermutation) {
+ string p;
+ for (int i = 0; i < permutation.size(); ++i) {
+ p.push_back('0' + permutation[i]);
+ }
+ LOG(INFO) << "Permutation: " << p;
+ }
+
+ std::vector<std::unique_ptr<SigNode>> hold(graph_size);
+ int idx;
+
+ // Permute the nodes.
+ sig_.nodes.clear();
+ idx = 0;
+ if (kDebugPermutation) {
+ LOG(INFO) << " nodes before permutation:";
+ }
+ for (auto& entry : sig_.map) {
+ if (kDebugPermutation) {
+ LOG(INFO) << " " << entry.second.get();
+ }
+ hold[idx++] = std::move(entry.second);
+ }
+ idx = 0;
+ if (kDebugPermutation) {
+ LOG(INFO) << " nodes after permutation:";
+ }
+ for (auto& entry : sig_.map) {
+ entry.second = std::move(hold[permutation[idx++]]);
+ if (kDebugPermutation) {
+ LOG(INFO) << " " << entry.second.get();
+ }
+ // This is used to order the links per permutation.
+ sig_.nodes.emplace_back(entry.second.get());
+ RefUniqueRank(entry.second.get()) = idx;
+ }
+ // Order the links with the same tags per permutation.
+ OrderLinks(&sig_);
+
+ // The test as such.
+ ASSERT_THAT(sig_.Compute(), Eq(Status::OK()));
+
+ signatures.insert(sig_.ToString());
+
+ EXPECT_THAT(sig_.sig_full, SizeIs(graph_size));
+ size_t hval = 0;
+ for (size_t ih : sig_.sig_full) {
+ // The space 1..graph_size is reserved.
+ EXPECT_THAT(ih, Gt(graph_size));
+ CombineHash(ih, &hval);
+ }
+ EXPECT_THAT(sig_.sig_short, Eq(hval));
+
+ // Un-permute the nodes for the next iteration.
+ idx = 0;
+ for (auto& entry : sig_.map) {
+ hold[permutation[idx++]] = std::move(entry.second);
+ }
+ idx = 0;
+ if (kDebugPermutation) {
+ LOG(INFO) << " nodes after un-permutation:";
+ }
+ for (auto& entry : sig_.map) {
+ entry.second = std::move(hold[idx++]);
+ if (kDebugPermutation) {
+ LOG(INFO) << " " << entry.second.get();
+ }
+ }
+ } while (CountDown(&countdown));
+
+ for (const auto& s : signatures) {
+ LOG(INFO) << "Signature: " << s;
+ }
+
+ // All the permutations should produce the same signature.
+ EXPECT_THAT(signatures, SizeIs(1));
+ }
+};
+
+TEST_F(SignatureTest, PrepareNodes) {
+ NodeDef node1 = MakeNodeConst("node1");
+ sig_.map["node1"] = absl::make_unique<SigNode>(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ sig_.map["node2"] = absl::make_unique<SigNode>(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ sig_.map["node3"] = absl::make_unique<SigNode>(&node3);
+
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(3));
+
+ int idx = 0;
+ for (const auto& entry : sig_.map) {
+ EXPECT_THAT(RefNodeMask(entry.second.get()), Eq(1 << idx))
+ << " at index " << idx;
+ EXPECT_THAT(RefUniqueRank(entry.second.get()), Eq(static_cast<size_t>(~0)))
+ << " at index " << idx;
+ EXPECT_THAT(RefHashIsFinal(entry.second.get()), false)
+ << " at index " << idx;
+ EXPECT_THAT(RefTopoHash(entry.second.get()), SizeIs(1))
+ << " at index " << idx;
+ ++idx;
+ }
+}
+
+TEST_F(SignatureTest, FindUniqueHashesAllDifferent) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+ NodeDef node4 = MakeNodeConst("node4");
+ SigNode sn4(&node4);
+
+ // The last values in the arrays go in backwards order.
+ RefTopoHash(&sn1).emplace_back(100);
+ RefTopoHash(&sn1).emplace_back(900);
+
+ RefTopoHash(&sn2).emplace_back(200);
+ RefTopoHash(&sn2).emplace_back(800);
+
+ RefTopoHash(&sn3).emplace_back(300);
+ RefTopoHash(&sn3).emplace_back(700);
+
+ RefTopoHash(&sn4).emplace_back(400);
+ RefTopoHash(&sn4).emplace_back(600);
+
+ sig_.nodes.emplace_back(&sn1);
+ sig_.nodes.emplace_back(&sn2);
+ sig_.nodes.emplace_back(&sn3);
+ sig_.nodes.emplace_back(&sn4);
+
+ size_t next = 1; // Skips over sn1.
+
+ FindUniqueHashes(&next, &sig_);
+ EXPECT_THAT(next, Eq(4));
+
+ EXPECT_THAT(sig_.nodes[0], Eq(&sn1));
+ // The nodes after the first one get sorted by the high hash.
+ EXPECT_THAT(sig_.nodes[1], Eq(&sn4));
+ EXPECT_THAT(sig_.nodes[2], Eq(&sn3));
+ EXPECT_THAT(sig_.nodes[3], Eq(&sn2));
+
+ EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false));
+ // Nodes that get finalized are marked as such.
+ EXPECT_THAT(RefHashIsFinal(&sn2), Eq(true));
+ EXPECT_THAT(RefHashIsFinal(&sn3), Eq(true));
+ EXPECT_THAT(RefHashIsFinal(&sn4), Eq(true));
+
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
+ ASSERT_THAT(RefTopoHash(&sn2), SizeIs(1));
+ ASSERT_THAT(RefTopoHash(&sn3), SizeIs(1));
+ ASSERT_THAT(RefTopoHash(&sn4), SizeIs(1));
+
+ EXPECT_THAT(RefTopoHash(&sn2)[0], Eq(4));
+ EXPECT_THAT(RefTopoHash(&sn3)[0], Eq(3));
+ EXPECT_THAT(RefTopoHash(&sn4)[0], Eq(2));
+
+ EXPECT_THAT(sig_.sig_full, ElementsAre(600, 700, 800));
+
+ size_t exp_short_hash = 0;
+ CombineHash(600, &exp_short_hash);
+ CombineHash(700, &exp_short_hash);
+ CombineHash(800, &exp_short_hash);
+ EXPECT_THAT(sig_.sig_short, Eq(exp_short_hash));
+}
+
+TEST_F(SignatureTest, FindUniqueHashesDuplicatesExceptOne) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+ NodeDef node4 = MakeNodeConst("node4");
+ SigNode sn4(&node4);
+ NodeDef node5 = MakeNodeConst("node5");
+ SigNode sn5(&node5);
+
+ RefTopoHash(&sn1).emplace_back(100);
+ RefTopoHash(&sn1).emplace_back(600);
+
+ RefTopoHash(&sn2).emplace_back(200);
+ RefTopoHash(&sn2).emplace_back(600);
+
+ RefTopoHash(&sn3).emplace_back(300);
+ RefTopoHash(&sn3).emplace_back(700);
+
+ RefTopoHash(&sn4).emplace_back(400);
+ RefTopoHash(&sn4).emplace_back(800);
+
+ RefTopoHash(&sn5).emplace_back(500);
+ RefTopoHash(&sn5).emplace_back(800);
+
+ sig_.nodes.emplace_back(&sn1);
+ sig_.nodes.emplace_back(&sn2);
+ sig_.nodes.emplace_back(&sn3);
+ sig_.nodes.emplace_back(&sn4);
+ sig_.nodes.emplace_back(&sn5);
+
+ size_t next = 0;
+
+ FindUniqueHashes(&next, &sig_);
+ EXPECT_THAT(next, Eq(1));
+
+ // The unique node goes first.
+ EXPECT_THAT(sig_.nodes[0], Eq(&sn3));
+
+ // The rest of the nodes are assumed to be sorted in a stable order.
+ EXPECT_THAT(sig_.nodes[1], Eq(&sn2));
+ // Node 1 gets swapped with node 3.
+ EXPECT_THAT(sig_.nodes[2], Eq(&sn1));
+ EXPECT_THAT(sig_.nodes[3], Eq(&sn4));
+ EXPECT_THAT(sig_.nodes[4], Eq(&sn5));
+
+ EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn2), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn3), Eq(true));
+ EXPECT_THAT(RefHashIsFinal(&sn4), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn5), Eq(false));
+
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn2), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn3), SizeIs(1));
+ EXPECT_THAT(RefTopoHash(&sn4), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn5), SizeIs(2));
+
+ EXPECT_THAT(RefTopoHash(&sn3)[0], Eq(1));
+}
+
+TEST_F(SignatureTest, FindUniqueHashesDuplicates) {
+ NodeDef node1 = MakeNodeConst("node1");
+ SigNode sn1(&node1);
+ NodeDef node2 = MakeNodeConst("node2");
+ SigNode sn2(&node2);
+ NodeDef node3 = MakeNodeConst("node3");
+ SigNode sn3(&node3);
+ NodeDef node4 = MakeNodeConst("node4");
+ SigNode sn4(&node4);
+ NodeDef node5 = MakeNodeConst("node5");
+ SigNode sn5(&node5);
+
+ RefTopoHash(&sn1).emplace_back(100);
+ RefTopoHash(&sn1).emplace_back(600);
+
+ RefTopoHash(&sn2).emplace_back(200);
+ RefTopoHash(&sn2).emplace_back(600);
+
+ RefTopoHash(&sn3).emplace_back(300);
+ RefTopoHash(&sn3).emplace_back(700);
+
+ RefTopoHash(&sn4).emplace_back(400);
+ RefTopoHash(&sn4).emplace_back(700);
+
+ RefTopoHash(&sn5).emplace_back(500);
+ RefTopoHash(&sn5).emplace_back(700);
+
+ sig_.nodes.emplace_back(&sn1);
+ sig_.nodes.emplace_back(&sn2);
+ sig_.nodes.emplace_back(&sn3);
+ sig_.nodes.emplace_back(&sn4);
+ sig_.nodes.emplace_back(&sn5);
+
+ size_t next = 0;
+
+ FindUniqueHashes(&next, &sig_);
+ EXPECT_THAT(next, Eq(1));
+
+ // The last copy of the last duplicate wins.
+ EXPECT_THAT(sig_.nodes[0], Eq(&sn5));
+
+ // The rest of the nodes are assumed to be sorted in a stable order.
+ // Node 1 gets swapped.
+ EXPECT_THAT(sig_.nodes[1], Eq(&sn2));
+ EXPECT_THAT(sig_.nodes[2], Eq(&sn3));
+ EXPECT_THAT(sig_.nodes[3], Eq(&sn4));
+ EXPECT_THAT(sig_.nodes[4], Eq(&sn1));
+
+ EXPECT_THAT(RefHashIsFinal(&sn1), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn2), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn3), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn4), Eq(false));
+ EXPECT_THAT(RefHashIsFinal(&sn5), Eq(true));
+
+ EXPECT_THAT(RefTopoHash(&sn1), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn2), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn3), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn4), SizeIs(2));
+ EXPECT_THAT(RefTopoHash(&sn5), SizeIs(1));
+
+ EXPECT_THAT(RefTopoHash(&sn5)[0], Eq(1));
+}
+
+// On a circular topology.
+TEST_F(SignatureTest, ComputeOneRoundCircular) {
+ BuildSigMap(graph_circular_onedir_);
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(5));
+
+ // This skips FindUniqueHashes() which would pick one node, so that
+ // all the nodes are equivalent for ComputeOneRound().
+
+ ComputeOneRound(0, &sig_);
+
+ // All the nodes are the same, so the computed hashes will also be the same.
+ size_t hval = GetHighTopoHash(sig_.nodes[0]);
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(GetHighTopoHash(sig_.nodes[i]), Eq(hval)) << " at index " << i;
+ EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i;
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[i]), Eq(0x1F))
+ << " at index " << i;
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[i]), Eq(0x1F))
+ << " at index " << i;
+ // The sets of hashed nodes go like this:
+ // Step 0: self.
+ // Step 1: self, previous (-1) and next (+1) node.
+ // Step 2: self, (-1), (-2), (+1), (+2): all 5 nodes in the graph
+ // Step 3: still all 5 nodes in the graph
+ EXPECT_THAT(RefTopoHash(sig_.nodes[i]), SizeIs(4)) << " at index " << i;
+ }
+}
+
+// On a linear topology.
+TEST_F(SignatureTest, ComputeOneRoundLinear) {
+ BuildSigMap(graph_linear_);
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(5));
+
+ // This skips FindUniqueHashes() which would pick one node, so that
+ // all the nodes are equivalent for ComputeOneRound().
+
+ ComputeOneRound(0, &sig_);
+
+ std::vector<size_t> hash_size;
+ for (int i = 0; i < 5; ++i) {
+ EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i;
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[i]), Eq(0x1F))
+ << " at index " << i;
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[i]), Eq(0x1F))
+ << " at index " << i;
+ hash_size.emplace_back(RefTopoHash(sig_.nodes[i]).size());
+ }
+
+ // The sets of hashed nodes for the central node go like this:
+ // Step 0: self.
+ // Step 1: self, previous (-1) and next (+1) node.
+ // Step 2: self, (-1), (-2), (+1), (+2): all 5 nodes in the graph
+ // Step 3: still all 5 nodes in the graph
+ //
+ // The nodes one step closer to the ends require one more step. The end nodes
+ // require one more step yet.
+ std::sort(hash_size.begin(), hash_size.end());
+ EXPECT_THAT(hash_size, ElementsAre(4, 5, 5, 6, 6));
+}
+
+// On a linear topology where the central node has already been marked as unique
+// (yeah, not a very realistic case, but it tests the situations where
+// disconnected subgraphs get created).
+TEST_F(SignatureTest, ComputeOneRoundSplitLinear) {
+ BuildSigMap(graph_linear_);
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(5));
+
+ // This test relies on the order of SigNodeMap imposed on sig_.nodes.
+
+ // The middle node gets separated by moving it to the front.
+ std::swap(sig_.nodes[0], sig_.nodes[2]);
+ ASSERT_THAT(RefNodeMask(sig_.nodes[0]), Eq(0x04));
+ ASSERT_THAT(RefLastHashedNodes(sig_.nodes[0]), Eq(0x04));
+ ASSERT_THAT(RefNextHashedNodes(sig_.nodes[0]), Eq(0x04));
+ RefHashIsFinal(sig_.nodes[0]) = true;
+
+ ComputeOneRound(1, &sig_);
+
+ // These should stay unchanged.
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[0]), Eq(0x04));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[0]), Eq(0x04));
+
+ std::vector<size_t> hash_size;
+ for (int i = 1; i < 5; ++i) {
+ EXPECT_THAT(RefHashIsFinal(sig_.nodes[i]), Eq(true)) << " at index " << i;
+ hash_size.emplace_back(RefTopoHash(sig_.nodes[i]).size());
+ }
+
+ std::sort(hash_size.begin(), hash_size.end());
+ // The end nodes take 4 steps, closer to the center 3 steps.
+ EXPECT_THAT(hash_size, ElementsAre(3, 3, 4, 4));
+
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[1]), Eq(0x07));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[1]), Eq(0x07));
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[2]), Eq(0x07));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[2]), Eq(0x07));
+
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[3]), Eq(0x1C));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[3]), Eq(0x1C));
+ EXPECT_THAT(RefLastHashedNodes(sig_.nodes[4]), Eq(0x1C));
+ EXPECT_THAT(RefNextHashedNodes(sig_.nodes[4]), Eq(0x1C));
+}
+
+TEST_F(SignatureTest, OrderLinks) {
+ gen_map_.clear();
+ sig_.map.clear();
+ Status result = GenNode::BuildGraphInMap(graph_for_link_order_, &gen_map_);
+ ASSERT_THAT(result, Eq(Status::OK()));
+ Subgraph::Identity id;
+ for (const auto& entry : gen_map_) {
+ id.insert(entry.second.get());
+ }
+ Subgraph sg(id);
+ sg.ExtractForSignature(&sig_.map);
+
+ // Populate the fake signature and assign the ranks in the backwards order.
+ for (auto it = sig_.map.rbegin(); it != sig_.map.rend(); ++it) {
+ auto& entry = *it;
+ RefUniqueRank(entry.second.get()) = sig_.nodes.size();
+ sig_.nodes.emplace_back(entry.second.get());
+ }
+
+ // How it was ordered in the original graph.
+ string before = sig_.ToString();
+ // clang-format off
+ EXPECT_THAT(before, Eq(
+ "0:Mul[i0:o0:5][i0:o0:4][i0:o1:4][i0:o2:3][i0:o2:2][i0:o3:2],"
+ "1:Mul[i0:o0:5][i0:o0:4][i0:o0:3][i0:o0:2],"
+ "2:Const,"
+ "3:Const,"
+ "4:Const,"
+ "5:Const,"
+ ));
+ // clang-format on
+
+ OrderLinks(&sig_);
+
+ string after = sig_.ToString();
+ // clang-format off
+ EXPECT_THAT(after, Eq(
+ "0:Mul[i0:o0:4][i0:o0:5][i0:o1:4][i0:o2:2][i0:o2:3][i0:o3:2],"
+ "1:Mul[i0:o0:2][i0:o0:3][i0:o0:4][i0:o0:5],"
+ "2:Const,"
+ "3:Const,"
+ "4:Const,"
+ "5:Const,"
+ ));
+ // clang-format on
+}
+
+TEST_F(SignatureTest, GraphTooBig) {
+ GraphDef graph;
+ for (int i = 0; i <= Signature::kMaxGraphSize; ++i) {
+ (*graph.add_node()) = MakeNodeConst(absl::StrFormat("node%d", i));
+ }
+
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &gen_map_), Eq(Status::OK()));
+
+ Subgraph::Identity id;
+ for (const auto& entry : gen_map_) {
+ id.insert(entry.second.get());
+ }
+ Subgraph sg(id);
+ sg.ExtractForSignature(&sig_.map);
+
+ ASSERT_THAT(sig_.Compute(),
+ Eq(Status(error::INVALID_ARGUMENT,
+ "A graph of 65 nodes is too big for signature "
+ "computation, the maximal supported node count is "
+ "64.")));
+}
+
+TEST_F(SignatureTest, ToString) {
+ BuildSigMap(graph_circular_onedir_);
+ PrepareNodes(&sig_);
+
+ ASSERT_THAT(sig_.nodes, SizeIs(5));
+
+ // Fake the works by assigning unique ranks as they go in the initial order.
+ for (int i = 0; i < 5; ++i) {
+ RefUniqueRank(sig_.nodes[i]) = i;
+ RefHashIsFinal(sig_.nodes[i]) = true;
+ }
+
+ string result = sig_.ToString();
+
+ // clang-format off
+ ASSERT_THAT(result, Eq(
+ "0:Mul[i0:o0:4][i0:o0:4],"
+ "1:Mul[i0:o0:0][i0:o0:0],"
+ "2:Mul[i0:o0:1][i0:o0:1],"
+ "3:Mul[i0:o0:2][i0:o0:2],"
+ "4:Mul[i0:o0:3][i0:o0:3],"
+ ));
+ // clang-format on
+}
+
+// This is a test of the permutation logic itself.
+TEST_F(SignatureTest, Permutation) {
+ std::vector<size_t> plain_permutation;
+ std::vector<size_t> countdown;
+ InitPermutation(5, &plain_permutation, &countdown);
+
+ std::set<string> results;
+
+ std::vector<size_t> permutation;
+ do {
+ BuildPermutation(plain_permutation, countdown, &permutation);
+ EXPECT_THAT(permutation, SizeIs(5));
+
+ string p;
+ for (int i = 0; i < permutation.size(); ++i) {
+ p.push_back('0' + permutation[i]);
+ }
+ LOG(INFO) << "Permutation: " << p;
+ results.insert(p);
+ } while (CountDown(&countdown));
+
+ EXPECT_THAT(results, SizeIs(5 * 4 * 3 * 2 * 1));
+}
+
+TEST_F(SignatureTest, ComputeCircularOneDir) {
+ TestGraphEveryWay(graph_circular_onedir_);
+}
+
+TEST_F(SignatureTest, ComputeCircularBiDir) {
+ TestGraphEveryWay(graph_circular_bidir_);
+}
+
+TEST_F(SignatureTest, ComputeLinear) { TestGraphEveryWay(graph_linear_); }
+
+TEST_F(SignatureTest, ComputeMultiInput) {
+ TestGraphEveryWay(graph_multi_input_);
+}
+
+TEST_F(SignatureTest, ComputeAllOrNone) {
+ TestGraphEveryWay(graph_all_or_none_);
+}
+
+TEST_F(SignatureTest, ComputeCross) { TestGraphEveryWay(graph_small_cross_); }
+
+TEST_F(SignatureTest, Equals) {
+ // Start with 2 copies of the same graph.
+ GenNodeMap gen_map1;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &gen_map1),
+ Eq(Status::OK()));
+
+ Subgraph::Identity id1;
+ id1.insert(gen_map1["node1"].get());
+ id1.insert(gen_map1["node2"].get());
+ Subgraph sg1(id1);
+
+ Signature sig1;
+ sg1.ExtractForSignature(&sig1.map);
+ ASSERT_THAT(sig1.Compute(), Eq(Status::OK()));
+
+ GenNodeMap gen_map2;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph_circular_bidir_, &gen_map2),
+ Eq(Status::OK()));
+
+ Subgraph::Identity id2;
+ id2.insert(gen_map2["node1"].get());
+ id2.insert(gen_map2["node2"].get());
+ Subgraph sg2(id2);
+
+ Signature sig2;
+ sg2.ExtractForSignature(&sig2.map);
+ ASSERT_THAT(sig2.Compute(), Eq(Status::OK()));
+
+ EXPECT_TRUE(sig1 == sig2);
+
+ // Change the short hash.
+ ++sig2.sig_short;
+ EXPECT_FALSE(sig1 == sig2);
+
+ // Restore back.
+ --sig2.sig_short;
+ EXPECT_TRUE(sig1 == sig2);
+
+ // Change the full hash.
+ ++sig2.sig_full[0];
+ EXPECT_FALSE(sig1 == sig2);
+
+ // Restore back.
+ --sig2.sig_full[0];
+ EXPECT_TRUE(sig1 == sig2);
+
+ // Make the nodes different.
+ std::swap(sig2.nodes[0], sig2.nodes[1]);
+ EXPECT_FALSE(sig1 == sig2);
+
+ // Restore back.
+ std::swap(sig2.nodes[0], sig2.nodes[1]);
+ EXPECT_TRUE(sig1 == sig2);
+
+ // Different number of nodes.
+ sig2.nodes.emplace_back(sig2.nodes[0]);
+ EXPECT_FALSE(sig1 == sig2);
+ EXPECT_FALSE(sig2 == sig1);
+}
+
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph.cc b/tensorflow/core/grappler/graph_analyzer/subgraph.cc
new file mode 100644
index 0000000000..28a91e0f84
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/subgraph.cc
@@ -0,0 +1,235 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
+
+#include <functional>
+
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+#include "tensorflow/core/grappler/graph_analyzer/hash_tools.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+//=== Subgraph::Identity
+
+Subgraph::Identity::Identity(InitializerList init) {
+ for (auto element : init) {
+ insert(element);
+ }
+}
+
+bool Subgraph::Identity::operator<(const Identity& other) const {
+ // Shorter sets go first.
+ if (this->size() < other.size()) {
+ return true;
+ }
+ if (this->size() > other.size()) {
+ return false;
+ }
+ for (auto lit = this->begin(), rit = other.begin(); lit != this->end();
+ ++lit, ++rit) {
+ if (*lit < *rit) {
+ return true;
+ }
+ if (*lit > *rit) {
+ return false;
+ }
+ }
+ return false; // Equal.
+}
+
+bool Subgraph::Identity::operator==(const Identity& other) const {
+ if (this->size() != other.size()) {
+ return false;
+ }
+ for (auto lit = this->begin(), rit = other.begin(); lit != this->end();
+ ++lit, ++rit) {
+ if (*lit != *rit) {
+ return false;
+ }
+ }
+ return true; // Equal.
+}
+
+size_t Subgraph::Identity::Hash() const {
+ std::hash<const GenNode*> hasher;
+ size_t result = 0;
+ for (auto ptr : *this) {
+ CombineHash(hasher(ptr), &result);
+ }
+ return result;
+}
+
+string Subgraph::Dump() {
+ // TODO(babkin): this is simplified for now.
+ std::vector<string> nodes;
+ for (const auto& n : id_) {
+ if (specific_) {
+ nodes.emplace_back(absl::StrFormat("%s(%s)", n->opcode(), n->name()));
+ } else {
+ nodes.emplace_back(n->opcode());
+ }
+ }
+ std::sort(nodes.begin(), nodes.end());
+
+ return absl::StrFormat("%d: ", collation_count_) + absl::StrJoin(nodes, ", ");
+}
+
+void Subgraph::ExtractForSignature(SigNodeMap* result) {
+ // Mapping of nodes from the original graph to the new one.
+ SigNode::TranslationMap full_to_new;
+
+ for (auto node : id_) {
+ auto newnode_ref = absl::make_unique<SigNode>(node->node_def());
+ auto newnode = newnode_ref.get();
+ (*result)[node->name()] = std::move(newnode_ref);
+ full_to_new[node] = newnode;
+ }
+
+ for (const auto& mapping : full_to_new) {
+ mapping.second->CopyLinks(*mapping.first, full_to_new);
+ }
+}
+
+//=== Subgraph
+
+Subgraph::Subgraph(const Identity& parent_id, GenNode* add_node)
+ : id_(parent_id) {
+ id_.insert(add_node);
+ hash_ = id_.Hash();
+}
+
+//=== SubgraphIterator
+
+SubgraphIterator::SubgraphIterator(const Subgraph::Identity* id)
+ : id_(id), id_it_(id_->begin()) {
+ if (!id_->empty()) {
+ link_map_it_ = (*id_it_)->links().begin();
+ // In case the node has no links.
+ while (link_map_it_ == (*id_it_)->links().end()) {
+ if (++id_it_ == id_->end()) {
+ return;
+ }
+ link_map_it_ = (*id_it_)->links().begin();
+ }
+ link_idx_ = 0;
+ // The LinkTargetVector should never be empty, but safeguard against that
+ // too, just in case.
+ PropagateNext();
+ }
+}
+
+bool SubgraphIterator::Next() {
+ if (AtEnd()) {
+ return false;
+ }
+ ++link_idx_;
+ return PropagateNext();
+}
+
+bool SubgraphIterator::NextIfSamePort() {
+ if (AtEnd()) {
+ return false;
+ }
+ if (link_idx_ + 1 < link_map_it_->second.size()) {
+ ++link_idx_;
+ return true;
+ } else {
+ return false;
+ }
+}
+
+void SubgraphIterator::SkipPort() {
+ if (AtEnd()) {
+ return;
+ }
+ link_idx_ = link_map_it_->second.size() - 1;
+}
+
+void SubgraphIterator::SkipNode() {
+ if (AtEnd()) {
+ return;
+ }
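+ // Walk link_map_it_ forward to the last link entry of the current node.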
+ for (auto next = link_map_it_; next != (*id_it_)->links().end(); ++next) {
+ link_map_it_ = next;
+ }
+ link_idx_ = link_map_it_->second.size() - 1;
+}
+
+bool SubgraphIterator::PropagateNext() {
+ // Loops are used to skip over the empty entries.
+ while (link_idx_ >= link_map_it_->second.size()) {
+ ++link_map_it_;
+ while (link_map_it_ == (*id_it_)->links().end()) {
+ if (++id_it_ == id_->end()) {
+ return false;
+ }
+ link_map_it_ = (*id_it_)->links().begin();
+ }
+ link_idx_ = 0;
+ }
+ return true;
+}
+
+bool SubgraphIterator::operator==(const SubgraphIterator& other) const {
+ if (id_ != other.id_) {
+ return false;
+ }
+ if (id_it_ != other.id_it_) {
+ return false;
+ }
+ // When AtEnd(), the rest of the fields are not valid.
+ if (AtEnd()) {
+ return true;
+ }
+ if (link_map_it_ != other.link_map_it_) {
+ return false;
+ }
+ if (link_idx_ != other.link_idx_) {
+ return false;
+ }
+ return true;
+}
+
+//=== SubgraphPtrSet
+
+Subgraph* SubgraphPtrSet::ExtendParent(const Subgraph::Identity& parent_id,
+ GenNode* node) {
+ if (parent_id.find(node) != parent_id.end()) {
+ // This was another link to the node that is already in the parent.
+ return nullptr;
+ }
+
+ // Constructing an object just to check that an equivalent one is already
+ // present is kind of ugly, but storing the references rather than the objects
+ // in the set avoids the need to make the object copyable.
+ auto sg = absl::make_unique<Subgraph>(parent_id, node);
+ if (find(sg) != end()) {
+ // This subgraph was already found by extending from a different path.
+ return nullptr;
+ }
+
+ Subgraph* ptr = sg.get();
+ insert(std::move(sg));
+ return ptr;
+}
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph.h b/tensorflow/core/grappler/graph_analyzer/subgraph.h
new file mode 100644
index 0000000000..4de31d5dfa
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/subgraph.h
@@ -0,0 +1,189 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_
+
+#include <initializer_list>
+#include <set>
+
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/map_tools.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+
+// The description of a single subgraph for processing.
+class Subgraph {
+ public:
+ // Identity of a single subgraph as a set of nodes.
+ class Identity : public gtl::FlatSet<const GenNode*> {
+ public:
+ using InitializerList = std::initializer_list<GenNode*>;
+
+ Identity() = default;
+ Identity(InitializerList init);
+ bool operator<(const Identity& other) const;
+ bool operator==(const Identity& other) const;
+
+ // Compute the hash.
+ size_t Hash() const;
+ };
+
+ explicit Subgraph(Identity id) : id_(std::move(id)), hash_(id_.Hash()) {}
+
+ // Construct by extending the parent identity with an extra node.
+ Subgraph(const Identity& parent_id, GenNode* add_node);
+
+ Subgraph() = delete;
+ Subgraph(const Subgraph& other) = delete;
+ void operator=(const Subgraph& other) = delete;
+
+ // Order for building sets of subgraphs.
+ bool operator<(const Subgraph& other) const { return this->id_ < other.id_; }
+ // Support for hashed sets.
+ bool operator==(const Subgraph& other) const {
+ return this->id_ == other.id_;
+ }
+ size_t Hash() const { return hash_; }
+
+ // Dump the subgraph information to a string.
+ string Dump();
+
+ // Extract this subgraph into a separate graph representation for signature
+ // building that includes only the links between the nodes in the subgraph
+ // and drops all the external links. The result map should be empty before the
+ // call.
+ void ExtractForSignature(SigNodeMap* result);
+
+ const Identity& id() const { return id_; }
+ bool specific() const { return specific_; }
+ void SetSpecific(bool value) { specific_ = value; }
+ int32_t collation_count() const { return collation_count_; }
+ void AddCollation(int32_t n = 1) { collation_count_ += n; }
+ void ResetCollation() { collation_count_ = 1; }
+ void MergeCollation(const Subgraph& other) {
+ collation_count_ += other.collation_count_;
+ }
+
+ private:
+ // Identity also serves as the list of nodes. It never changes throughout the
+ // life of the subgraph.
+ Identity id_;
+ size_t hash_; // Cached from the identity.
+ // Whether the dump should include the specific names of the nodes. The
+ // non-specific (i.e. generic) subgraphs represent a collation of multiple
+ // subgraphs.
+ bool specific_ = true;
+ // How many collated subgraphs are represented by this subgraph.
+ int32_t collation_count_ = 1;
+};
+
+// Iteration of all links in a subgraph. This is more like a Java iterator than
+// a normal C++ iterator. It's simpler this way, and there seems to be no
+// major reason to make it a proper C++ iterator.
+class SubgraphIterator {
+ public:
+ // Obviously an iterator is valid only until the original object
+ // gets destroyed.
+ explicit SubgraphIterator(const Subgraph::Identity* id);
+ explicit SubgraphIterator(const Subgraph* sg) : SubgraphIterator(&sg->id()) {}
+
+ // Check whether the built-in iterator is at the end.
+ bool AtEnd() const { return id_it_ == id_->end(); }
+
+ // Get the neighbor at the current iterator.
+ // MUST NOT be called when AtEnd().
+ const GenNode::LinkTarget& GetNeighbor() const {
+ return link_map_it_->second[link_idx_];
+ }
+
+ // Get the node at the current iterator.
+ // MUST NOT be called when AtEnd().
+ const GenNode* GetNode() const { return *id_it_; }
+
+ // Get the port leading to the neighbor at the current iterator.
+ // MUST NOT be called when AtEnd().
+ GenNode::Port GetPort() const { return link_map_it_->first; }
+
+ // Advances the iterator.
+ // Returns true if NOT AtEnd() after advancing the iterator.
+ // Safe to call if already AtEnd().
+ bool Next();
+
+ // If there are more links at the same port, advances the iterator and
+ // returns true. Otherwise leaves the iterator unchanged and returns false.
+ bool NextIfSamePort();
+
+ // Advances the iterator directly to the last position on the current port
+ // (does nothing if already there). Equivalent to calling NextIfSamePort()
+ // while it returns true, but faster.
+ // Safe to call if already AtEnd().
+ void SkipPort();
+
+ // Advances the iterator directly to the last position on the current node.
+ // Safe to call if already AtEnd().
+ void SkipNode();
+
+ // Returns true if the iterators are exactly the same.
+ bool operator==(const SubgraphIterator& other) const;
+ bool operator!=(const SubgraphIterator& other) const {
+ return !(*this == other);
+ }
+
+ private:
+ // After link_idx_ has been incremented, makes sure that it points to the
+ // next valid element (or the end) by advancing the higher levels of
+ // iteration if needed.
+ // Returns true if NOT AtEnd() after advancing the iterator.
+ // NOT safe to call if already AtEnd().
+ bool PropagateNext();
+
+ // Identity of the subgraph being iterated over.
+ const Subgraph::Identity* id_;
+
+ // The current position, which allows iterating through the links (see the
+ // reasoning for this layout in the class comment).
+ //
+ // (1) Iterator of the nodes in the subgraph.
+ Subgraph::Identity::const_iterator id_it_;
+ // (2) Iterator in the link map of the node.
+ GenNode::LinkMap::const_iterator link_map_it_;
+ // (3) Index in the vector of the links.
+ int32_t link_idx_;
+};
+
+// A convenient way to store subgraphs: in a set of unique_ptrs. This way the
+// addresses of subgraph objects will stay stable, and the objects themselves
+// won't be copied.
+class SubgraphPtrSet
+ : public std::unordered_set<std::unique_ptr<Subgraph>,
+ HashAtPtr<std::unique_ptr<Subgraph>>,
+ EqAtPtr<std::unique_ptr<Subgraph>>> {
+ public:
+ // Attempts to extend the set by adding a new subgraph created by adding one
+ // node to the parent subgraph. Returns nullptr if such a subgraph already
+ // exists, otherwise returns a pointer to the new subgraph.
+ Subgraph* ExtendParent(const Subgraph::Identity& parent_id, GenNode* node);
+};
+
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_SUBGRAPH_H_
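For orientation, a minimal usage sketch of the API declared above (illustration only, not part of the patch); it assumes a GenNodeMap `map` already populated via GenNode::BuildGraphInMap(), as in the tests below:

  Subgraph::Identity id;
  id.insert(map["node3"].get());
  Subgraph sg(id);

  // Walk every link of every node in the subgraph, Java-iterator style.
  for (SubgraphIterator sit(&sg); !sit.AtEnd(); sit.Next()) {
    const GenNode::LinkTarget& target = sit.GetNeighbor();
    LOG(INFO) << string(sit.GetPort()) << " -> " << target.node->name();
  }

  // Grow the subgraph by one node, de-duplicating through a SubgraphPtrSet.
  SubgraphPtrSet set;
  Subgraph* extended = set.ExtendParent(sg.id(), map["node2"].get());
  if (extended != nullptr) {
    LOG(INFO) << extended->Dump();
  }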
diff --git a/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc b/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc
new file mode 100644
index 0000000000..0f90dc8f0d
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/subgraph_test.cc
@@ -0,0 +1,348 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/subgraph.h"
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include "absl/memory/memory.h"
+#include "absl/strings/str_format.h"
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+namespace {
+
+using ::testing::ElementsAre;
+using ::testing::Eq;
+using ::testing::Ne;
+
+TEST(SubgraphTest, Comparison) {
+ GraphDef graph;
+ // Two disconnected nodes.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node2");
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ auto gn1 = map["node1"].get();
+ auto gn2 = map["node2"].get();
+ ASSERT_THAT(gn1, Ne(nullptr));
+ ASSERT_THAT(gn2, Ne(nullptr));
+
+ Subgraph::Identity id1;
+ Subgraph::Identity id2;
+
+ id1.insert(gn1);
+ id2.insert(gn2);
+
+ Subgraph sg1(id1);
+ Subgraph sg2(id2);
+
+ EXPECT_TRUE(id1 == sg1.id());
+ EXPECT_TRUE(id2 == sg2.id());
+
+ EXPECT_THAT(sg1 < sg2, Eq(id1 < id2));
+}
+
+TEST(SubgraphTest, EmptyIteration) {
+ NodeDef node1 = MakeNodeConst("node1");
+ auto gn1 = absl::make_unique<GenNode>(&node1);
+ Subgraph::Identity id1;
+ id1.insert(gn1.get());
+ Subgraph sg1(id1);
+ SubgraphIterator sit(&sg1);
+
+ EXPECT_TRUE(sit.AtEnd());
+ EXPECT_FALSE(sit.Next());
+ EXPECT_TRUE(sit.AtEnd());
+
+ SubgraphIterator sit2(&sg1);
+ EXPECT_TRUE(sit == sit2);
+}
+
+TEST(SubgraphTest, Iteration) {
+ GraphDef graph;
+ // A topology with a loop.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ auto node3 = graph.add_node();
+ *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ node3->add_input("^node3"); // The control link goes back to self.
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id;
+ id.insert(map["node3"].get());
+ Subgraph sg(id);
+
+ // node3 has 2 incoming data links, 2 outgoing data links, 1 incoming control
+ // link and 1 outgoing control link = 6 links in total.
+ {
+ SubgraphIterator sit(&sg);
+ EXPECT_FALSE(sit.AtEnd()); // 1
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 2
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 3
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 4
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 5
+ EXPECT_TRUE(sit.Next());
+ EXPECT_FALSE(sit.AtEnd()); // 6
+ EXPECT_FALSE(sit.Next());
+ EXPECT_TRUE(sit.AtEnd());
+ }
+
+ // Now get the values out. And more equality testing along the way.
+ {
+ SubgraphIterator sit(&sg);
+ SubgraphIterator sit2(&sg);
+ std::vector<string> links;
+ for (; !sit.AtEnd(); sit.Next()) {
+ EXPECT_TRUE(sit == sit2);
+ sit2.Next();
+ EXPECT_FALSE(sit == sit2);
+
+ links.push_back(absl::StrFormat("[%s,%s,%s]", string(sit.GetPort()),
+ sit.GetNeighbor().node->name(),
+ string(sit.GetNeighbor().port)));
+ }
+ EXPECT_TRUE(sit == sit2);
+
+ std::sort(links.begin(), links.end());
+ // clang-format off
+ EXPECT_THAT(links, ElementsAre(
+ "[i0,node1,o0]",
+ "[i1,node2,o0]",
+ "[iC,node3,oC]",
+ "[o0,node2,i1]",
+ "[o1,node2,i0]",
+ "[oC,node3,iC]"
+ ));
+ // clang-format on
+ }
+}
+
+TEST(SubgraphTest, IterationSamePort) {
+ GraphDef graph;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3", "node3");
+ (*graph.add_node()) = MakeNodeAddN("node3", "node1", "node2");
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id;
+ id.insert(map["node3"].get());
+ Subgraph sg(id);
+
+ int total_links = 0;
+ for (SubgraphIterator sit(&sg); !sit.AtEnd(); sit.Next()) {
+ ++total_links;
+ }
+
+ // Initialize the port as control, which doesn't occur in this graph.
+ GenNode::Port last_port(false, -1);
+ int steps_total_same_port = 0;
+ int steps_with_same_port = 0;
+ for (SubgraphIterator sit(&sg); !sit.AtEnd(); sit.Next()) {
+ GenNode::Port new_port = sit.GetPort();
+ EXPECT_THAT(last_port.Encoded(), Ne(new_port.Encoded()))
+ << "At step " << steps_total_same_port;
+ last_port = new_port;
+
+ ++steps_total_same_port;
+
+ SubgraphIterator sit2(sit);
+ sit2.SkipPort();
+
+ while (sit.NextIfSamePort()) {
+ new_port = sit.GetPort();
+ EXPECT_THAT(last_port.Encoded(), Eq(new_port.Encoded()))
+ << "At step " << steps_total_same_port;
+ ++steps_total_same_port;
+ ++steps_with_same_port;
+ }
+
+ EXPECT_TRUE(sit == sit2);
+ }
+
+ EXPECT_THAT(steps_total_same_port, Eq(total_links));
+ // There is one 2-way input and one 2-way output.
+ EXPECT_THAT(steps_with_same_port, Eq(2));
+}
+
+TEST(SubgraphTest, IterationSameNode) {
+ GraphDef graph;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3", "node3");
+ (*graph.add_node()) = MakeNodeAddN("node3", "node1", "node2");
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id;
+ id.insert(map["node3"].get());
+ Subgraph sg(id);
+
+ const GenNode* last_node = nullptr;
+ SubgraphIterator sit(&sg);
+ while (!sit.AtEnd()) {
+ const GenNode* new_node = sit.GetNode();
+
+ EXPECT_THAT(new_node, Ne(last_node)) << "At node " << new_node->name();
+
+ SubgraphIterator sit2(sit);
+ sit2.SkipNode();
+
+ ASSERT_FALSE(sit2.AtEnd());
+ EXPECT_THAT(sit2.GetNode(), Eq(new_node))
+ << "At expected node " << new_node->name() << ", got "
+ << sit2.GetNode()->name();
+
+ while (sit != sit2 && !sit.AtEnd()) {
+ sit.Next();
+ }
+
+ ASSERT_FALSE(sit.AtEnd());
+ EXPECT_THAT(sit.GetNode(), Eq(new_node))
+ << "At expected node " << new_node->name() << ", got "
+ << sit2.GetNode()->name();
+
+ sit.Next();
+
+ last_node = new_node;
+ }
+
+ // Check that it doesn't fail if already at end.
+ sit.SkipNode();
+ EXPECT_TRUE(sit.AtEnd());
+}
+
+TEST(SubgraphTest, ExtendSet) {
+ GraphDef graph;
+ // A topology with a loop.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ auto node3 = graph.add_node();
+ *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ node3->add_input("^node3"); // The control link goes back to self.
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node2"), Ne(map.end()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id_empty;
+
+ Subgraph::Identity id3;
+ id3.insert(map["node3"].get());
+
+ Subgraph::Identity id23 = id3;
+ id23.insert(map["node2"].get());
+
+ Subgraph* sg;
+ SubgraphPtrSet set;
+
+ // Extend an empty identity.
+ sg = set.ExtendParent(id_empty, map["node3"].get());
+ EXPECT_THAT(set.size(), Eq(1));
+ ASSERT_THAT(sg, Ne(nullptr));
+ EXPECT_TRUE(sg->id() == id3);
+
+ // Extend with a node that is already in the parent.
+ sg = set.ExtendParent(id3, map["node3"].get());
+ EXPECT_THAT(set.size(), Eq(1));
+ EXPECT_THAT(sg, Eq(nullptr));
+
+ // Extend to a 2-node subgraph.
+ sg = set.ExtendParent(id3, map["node2"].get());
+ EXPECT_THAT(set.size(), Eq(2));
+ ASSERT_THAT(sg, Ne(nullptr));
+ EXPECT_TRUE(sg->id() == id23);
+
+ // The second insert of the same node gets ignored.
+ sg = set.ExtendParent(id3, map["node2"].get());
+ EXPECT_THAT(set.size(), Eq(2));
+ EXPECT_THAT(sg, Eq(nullptr));
+}
+
+TEST(SubgraphTest, ExtractForSignature) {
+ GraphDef graph;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ auto node3 = graph.add_node();
+ *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ node3->add_input("^node1");
+ node3->add_input("^node2");
+ node3->add_input("^node3"); // The control link goes back to self.
+
+ GenNodeMap map;
+ ASSERT_THAT(GenNode::BuildGraphInMap(graph, &map), Eq(Status::OK()));
+ ASSERT_THAT(map.find("node1"), Ne(map.end()));
+ ASSERT_THAT(map.find("node2"), Ne(map.end()));
+ ASSERT_THAT(map.find("node3"), Ne(map.end()));
+
+ Subgraph::Identity id;
+ id.insert(map["node1"].get());
+ id.insert(map["node3"].get());
+
+ Subgraph sg(id);
+
+ SigNodeMap map2;
+ sg.ExtractForSignature(&map2);
+ ASSERT_THAT(map2.find("node1"), Ne(map2.end()));
+ ASSERT_THAT(map2.find("node2"), Eq(map2.end()));
+ ASSERT_THAT(map2.find("node3"), Ne(map2.end()));
+
+ // clang-format off
+ EXPECT_THAT(DumpLinkHashMap(map2["node1"]->hash_to_link()), ElementsAre(
+ "oC:iC: node3",
+ "o0:i0: node3"
+ ));
+ EXPECT_THAT(DumpHashedPeerVector(map2["node1"]->hashed_peers()), ElementsAre(
+ "node3",
+ "node3"
+ ));
+ EXPECT_THAT(DumpLinkHashMap(map2["node3"]->hash_to_link()), ElementsAre(
+ "oC:iC: node3",
+ "iC:oC: node1, node3",
+ "i0:o0: node1"
+ ));
+ EXPECT_THAT(DumpHashedPeerVector(map2["node3"]->hashed_peers()), ElementsAre(
+ "node3",
+ "node1",
+ "node3",
+ "node1"
+ ));
+ // clang-format on
+}
+
+} // end namespace
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/test_tools.cc b/tensorflow/core/grappler/graph_analyzer/test_tools.cc
new file mode 100644
index 0000000000..fc9495bc7d
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/test_tools.cc
@@ -0,0 +1,296 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/graph_analyzer/test_tools.h"
+
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+
+//=== Helper methods to construct the nodes.
+
+NodeDef MakeNodeConst(const string& name) {
+ NodeDef n;
+ n.set_name(name);
+ n.set_op("Const");
+ return n;
+}
+
+NodeDef MakeNode2Arg(const string& name, const string& opcode,
+ const string& arg1, const string& arg2) {
+ NodeDef n;
+ n.set_name(name);
+ n.set_op(opcode);
+ n.add_input(arg1);
+ n.add_input(arg2);
+ return n;
+}
+
+NodeDef MakeNode4Arg(const string& name, const string& opcode,
+ const string& arg1, const string& arg2, const string& arg3,
+ const string& arg4) {
+ NodeDef n;
+ n.set_name(name);
+ n.set_op(opcode);
+ n.add_input(arg1);
+ n.add_input(arg2);
+ n.add_input(arg3);
+ n.add_input(arg4);
+ return n;
+}
+
+// Not really a 2-argument op, but convenient to construct this way.
+NodeDef MakeNodeShapeN(const string& name, const string& arg1,
+ const string& arg2) {
+ // This opcode is multi-input but not commutative.
+ return MakeNode2Arg(name, "ShapeN", arg1, arg2);
+}
+
+// Not really a 2-argument op, but convenient to construct this way.
+NodeDef MakeNodeIdentityN(const string& name, const string& arg1,
+ const string& arg2) {
+ // The argument is of a list type.
+ return MakeNode2Arg(name, "IdentityN", arg1, arg2);
+}
+
+NodeDef MakeNodeQuantizedConcat(const string& name, const string& arg1,
+ const string& arg2, const string& arg3,
+ const string& arg4) {
+ // This opcode has multiple multi-inputs.
+ return MakeNode4Arg(name, "QuantizedConcat", arg1, arg2, arg3, arg4);
+}
+
+//=== Helper methods for analysing the structures.
+
+std::vector<string> DumpLinkMap(const GenNode::LinkMap& link_map) {
+ // This will order the entries first.
+ std::map<string, string> ordered;
+ for (const auto& link : link_map) {
+ string key = string(link.first);
+
+ // Order the other sides too. They may repeat, so store them
+ // in a multiset.
+ std::multiset<string> others;
+ for (const auto& other : link.second) {
+ others.emplace(
+ absl::StrFormat("%s[%s]", other.node->name(), string(other.port)));
+ }
+ ordered[key] = absl::StrJoin(others, ", ");
+ }
+ // Now dump the result in a predictable order.
+ std::vector<string> result;
+ result.reserve(ordered.size());
+ for (const auto& link : ordered) {
+ result.emplace_back(link.first + ": " + link.second);
+ }
+ return result;
+}
+
+std::vector<string> DumpLinkHashMap(const SigNode::LinkHashMap& link_hash_map) {
+ // The entries in this map are ordered by hash value which might change
+ // at any point. Re-order them by the link tag.
+ std::map<SigNode::LinkTag, size_t> tags;
+ for (const auto& entry : link_hash_map) {
+ tags[entry.second.tag] = entry.first;
+ }
+
+ std::vector<string> result;
+ for (const auto& id : tags) {
+ // For predictability, the nodes need to be sorted.
+ std::vector<string> nodes;
+ for (const auto& peer : link_hash_map.at(id.second).peers) {
+ nodes.emplace_back(peer->name());
+ }
+ std::sort(nodes.begin(), nodes.end());
+ result.emplace_back(string(id.first.local) + ":" + string(id.first.remote) +
+ ": " + absl::StrJoin(nodes, ", "));
+ }
+ return result;
+}
+
+std::vector<string> DumpHashedPeerVector(
+ const SigNode::HashedPeerVector& hashed_peers) {
+ std::vector<string> result;
+
+ // Each subset of nodes with the same hash has to be sorted by name.
+ // Other than that, the vector is already ordered by full tags.
+ size_t last_hash = 0;
+ // Index, since iterators may get invalidated on append.
+ size_t subset_start = 0;
+
+ for (const auto& entry : hashed_peers) {
+ if (entry.link_hash != last_hash) {
+ std::sort(result.begin() + subset_start, result.end());
+ subset_start = result.size();
+ }
+ result.emplace_back(entry.peer->name());
+ }
+ std::sort(result.begin() + subset_start, result.end());
+
+ return result;
+}
+
+TestGraphs::TestGraphs() {
+ {
+ GraphDef& graph = graph_3n_self_control_;
+ // The topology includes a loop and a link to self.
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeSub("node2", "node3:1", "node3:0");
+ auto node3 = graph.add_node();
+ *node3 = MakeNodeBroadcastGradientArgs("node3", "node1", "node2");
+ node3->add_input("^node3"); // The control link goes back to self.
+ }
+ {
+ GraphDef& graph = graph_multi_input_;
+ // A topology with multi-input links (AddN nodes).
+ (*graph.add_node()) = MakeNodeConst("const1_1");
+ (*graph.add_node()) = MakeNodeConst("const1_2");
+ (*graph.add_node()) = MakeNodeAddN("add1", "const1_1", "const1_2");
+
+ (*graph.add_node()) = MakeNodeConst("const2_1");
+ (*graph.add_node()) = MakeNodeConst("const2_2");
+ (*graph.add_node()) = MakeNodeConst("const2_3");
+
+ auto add2 = graph.add_node();
+ *add2 = MakeNodeAddN("add2", "const2_1", "const2_2");
+ // The 3rd constant is connected twice, for 4 input links in total.
+ add2->add_input("const2_3");
+ add2->add_input("const2_3");
+
+ (*graph.add_node()) = MakeNodeSub("sub", "add1", "add2");
+ }
+ {
+ GraphDef& graph = graph_all_or_none_;
+ // A topology with all-or-none (list-typed) inputs (IdentityN nodes).
+ (*graph.add_node()) = MakeNodeConst("const1_1");
+ (*graph.add_node()) = MakeNodeConst("const1_2");
+ auto pass1 = graph.add_node();
+ *pass1 = MakeNodeIdentityN("pass1", "const1_1", "const1_2");
+
+ (*graph.add_node()) = MakeNodeConst("const2_1");
+ (*graph.add_node()) = MakeNodeConst("const2_2");
+ (*graph.add_node()) = MakeNodeConst("const2_3");
+
+ auto pass2 = graph.add_node();
+ *pass2 = MakeNodeIdentityN("pass2", "const2_1", "const2_2");
+ // The 3rd constant is connected twice, for 4 input links in total.
+ pass2->add_input("const2_3");
+ pass2->add_input("const2_3");
+
+ // Add the control links; they are handled separately from the normal
+ // links.
+ pass1->add_input("^const2_1");
+ pass1->add_input("^const2_2");
+ pass1->add_input("^const2_3");
+
+ (*graph.add_node()) = MakeNodeSub("sub", "pass1", "pass2");
+ }
+ {
+ GraphDef& graph = graph_circular_onedir_;
+ (*graph.add_node()) = MakeNodeMul("node1", "node5", "node5");
+ (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
+ (*graph.add_node()) = MakeNodeMul("node3", "node2", "node2");
+ (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
+ (*graph.add_node()) = MakeNodeMul("node5", "node4", "node4");
+ }
+ {
+ GraphDef& graph = graph_circular_bidir_;
+ // The left and right links are intentionally mixed up.
+ (*graph.add_node()) = MakeNodeMul("node1", "node5", "node2");
+ (*graph.add_node()) = MakeNodeMul("node2", "node3", "node1");
+ (*graph.add_node()) = MakeNodeMul("node3", "node2", "node4");
+ (*graph.add_node()) = MakeNodeMul("node4", "node5", "node3");
+ (*graph.add_node()) = MakeNodeMul("node5", "node4", "node1");
+ }
+ {
+ GraphDef& graph = graph_linear_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
+ (*graph.add_node()) = MakeNodeMul("node3", "node2", "node2");
+ (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
+ (*graph.add_node()) = MakeNodeMul("node5", "node4", "node4");
+ }
+ {
+ GraphDef& graph = graph_cross_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeMul("node2", "node1", "node1");
+ (*graph.add_node()) = MakeNodeConst("node3");
+ (*graph.add_node()) = MakeNodeMul("node4", "node3", "node3");
+ (*graph.add_node()) = MakeNodeConst("node5");
+ (*graph.add_node()) = MakeNodeMul("node6", "node5", "node5");
+ (*graph.add_node()) = MakeNodeConst("node7");
+ (*graph.add_node()) = MakeNodeMul("node8", "node7", "node7");
+
+ auto center = graph.add_node();
+ *center = MakeNodeMul("node9", "node2", "node4");
+ center->add_input("node6");
+ center->add_input("node8");
+ }
+ {
+ GraphDef& graph = graph_small_cross_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node2");
+ (*graph.add_node()) = MakeNodeConst("node3");
+ (*graph.add_node()) = MakeNodeConst("node4");
+
+ auto center = graph.add_node();
+ *center = MakeNodeMul("node5", "node1", "node2");
+ center->add_input("node3");
+ center->add_input("node4");
+ }
+ {
+ GraphDef& graph = graph_for_link_order_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node2");
+ (*graph.add_node()) = MakeNodeConst("node3");
+ (*graph.add_node()) = MakeNodeConst("node4");
+
+ // One group of equivalent links.
+ auto center = graph.add_node();
+ *center = MakeNodeMul("node5", "node1", "node2");
+ center->add_input("node3");
+ center->add_input("node4");
+
+ // Multiple groups, separated by unique links.
+ auto center2 = graph.add_node();
+ *center2 = MakeNodeMul("node6", "node1", "node2");
+ center2->add_input("node2:1");
+ center2->add_input("node3:2");
+ center2->add_input("node4:2");
+ center2->add_input("node4:3");
+ }
+ {
+ GraphDef& graph = graph_sun_;
+ (*graph.add_node()) = MakeNodeConst("node1");
+ (*graph.add_node()) = MakeNodeConst("node2");
+ (*graph.add_node()) = MakeNodeConst("node3");
+ (*graph.add_node()) = MakeNodeConst("node4");
+ (*graph.add_node()) = MakeNodeConst("node5");
+ (*graph.add_node()) = MakeNodeSub("node6", "node1", "node10");
+ (*graph.add_node()) = MakeNodeSub("node7", "node2", "node6");
+ (*graph.add_node()) = MakeNodeSub("node8", "node3", "node7");
+ (*graph.add_node()) = MakeNodeSub("node9", "node4", "node8");
+ (*graph.add_node()) = MakeNodeSub("node10", "node5", "node9");
+ }
+}
+
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_analyzer/test_tools.h b/tensorflow/core/grappler/graph_analyzer/test_tools.h
new file mode 100644
index 0000000000..98e269d57e
--- /dev/null
+++ b/tensorflow/core/grappler/graph_analyzer/test_tools.h
@@ -0,0 +1,120 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
+#define TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
+
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/graph_analyzer/gen_node.h"
+#include "tensorflow/core/grappler/graph_analyzer/sig_node.h"
+#include "tensorflow/core/grappler/op_types.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_analyzer {
+namespace test {
+
+//=== Helper methods to construct the nodes.
+
+NodeDef MakeNodeConst(const string& name);
+
+NodeDef MakeNode2Arg(const string& name, const string& opcode,
+ const string& arg1, const string& arg2);
+
+NodeDef MakeNode4Arg(const string& name, const string& opcode,
+ const string& arg1, const string& arg2, const string& arg3,
+ const string& arg4);
+
+inline NodeDef MakeNodeMul(const string& name, const string& arg1,
+ const string& arg2) {
+ return MakeNode2Arg(name, "Mul", arg1, arg2);
+}
+
+// Not really a 2-argument op, but convenient to construct this way.
+inline NodeDef MakeNodeAddN(const string& name, const string& arg1,
+ const string& arg2) {
+ return MakeNode2Arg(name, "AddN", arg1, arg2);
+}
+
+inline NodeDef MakeNodeSub(const string& name, const string& arg1,
+ const string& arg2) {
+ return MakeNode2Arg(name, "Sub", arg1, arg2);
+}
+
+// Has 2 honest outputs.
+inline NodeDef MakeNodeBroadcastGradientArgs(const string& name,
+ const string& arg1,
+ const string& arg2) {
+ return MakeNode2Arg(name, "BroadcastGradientArgs", arg1, arg2);
+}
+
+NodeDef MakeNodeShapeN(const string& name, const string& arg1,
+ const string& arg2);
+
+NodeDef MakeNodeIdentityN(const string& name, const string& arg1,
+ const string& arg2);
+
+NodeDef MakeNodeQuantizedConcat(const string& name, const string& arg1,
+ const string& arg2, const string& arg3,
+ const string& arg4);
+
+//=== A container of pre-constructed graphs.
+
+class TestGraphs {
+ public:
+ TestGraphs();
+
+ // Graph with 3 nodes and a control link to self (which is not valid in
+ // reality but adds excitement to the tests).
+ GraphDef graph_3n_self_control_;
+ // Graph that has the multi-input links.
+ GraphDef graph_multi_input_;
+ // Graph that has the all-or-none nodes.
+ GraphDef graph_all_or_none_;
+ // All the nodes are connected in a circle that goes in one direction.
+ GraphDef graph_circular_onedir_;
+ // All the nodes are connected in a circle that goes in both directions.
+ GraphDef graph_circular_bidir_;
+ // The nodes are connected in a line.
+ GraphDef graph_linear_;
+ // The nodes are connected in a cross shape.
+ GraphDef graph_cross_;
+ GraphDef graph_small_cross_;
+ // For testing the ordering of links at the end of signature generation,
+ // a variation of a cross.
+ GraphDef graph_for_link_order_;
+ // Sun-shaped, a ring with "rays".
+ GraphDef graph_sun_;
+};
+
+//=== Helper methods for analysing the structures.
+
+std::vector<string> DumpLinkMap(const GenNode::LinkMap& link_map);
+
+// Also checks for the consistency of hash values.
+std::vector<string> DumpLinkHashMap(const SigNode::LinkHashMap& link_hash_map);
+
+std::vector<string> DumpHashedPeerVector(
+ const SigNode::HashedPeerVector& hashed_peers);
+
+} // end namespace test
+} // end namespace graph_analyzer
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_GRAPH_ANALYZER_TEST_TOOLS_H_
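A sketch (illustration only, not part of the patch) of how these fixtures are typically consumed by the analyzer tests: instantiate TestGraphs and build a GenNodeMap from one of its pre-constructed graphs.

  TestGraphs graphs;
  GenNodeMap map;
  TF_CHECK_OK(GenNode::BuildGraphInMap(graphs.graph_multi_input_, &map));
  // `map` now contains GenNodes for const1_1, const1_2, add1, ..., sub.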
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index a6b6b6f8b2..de0a63fc4e 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -14,11 +14,55 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
+namespace {
+int OpPortIdToArgId(const NodeDef& node,
+ const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
+ int port_id) {
+ for (int arg_id = 0; arg_id < args.size(); ++arg_id) {
+ if (port_id < 0) {
+ return -1;
+ } else if (port_id == 0) {
+ return arg_id;
+ }
+
+ // Default is 1 port per arg.
+ int n = 1;
+
+ const auto& arg = args.Get(arg_id);
+ if (!arg.number_attr().empty()) {
+ n = node.attr().at(arg.number_attr()).i();
+ } else if (!arg.type_list_attr().empty()) {
+ n = node.attr().at(arg.type_list_attr()).list().type_size();
+ }
+
+ if (n < 0) {
+ // This should never happen.
+ DCHECK_GE(n, 0);
+ return -1;
+ } else if (port_id < n) {
+ return arg_id;
+ }
+ port_id -= n;
+ }
+
+ return -1;
+}
+} // end namespace
+
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+ return OpPortIdToArgId(node, op.output_arg(), port_id);
+}
+
+int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id) {
+ return OpPortIdToArgId(node, op.input_arg(), port_id);
+}
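To make the mapping arithmetic concrete, a worked trace (illustration only, matching the SparseSplit test added in graph_view_test.cc): with two splits each of the three list-typed output args expands to 2 tensors, so the loop subtracts 2 from port_id per arg until it lands inside one.

  // port_id 0,1 -> arg 0
  // port_id 2,3 -> arg 1
  // port_id 4,5 -> arg 2
  // port_id 6   -> -1 (past the last output arg)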
+
GraphView::GraphView(GraphDef* graph) : graph_(graph) {
for (int i = 0; i < graph_->node_size(); i++) {
auto node = graph_->mutable_node(i);
@@ -39,7 +83,7 @@ void GraphView::AddUniqueNodeOrDie(NodeDef* node) {
void GraphView::AddFanouts(NodeDef* node) {
for (int i = 0; i < node->input_size(); ++i) {
OutputPort fanin;
- string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
+ const string fanin_name = ParseNodeName(node->input(i), &fanin.port_id);
fanin.node = nodes_[fanin_name];
InputPort input;
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index ac260f85a0..09c36a1368 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -20,11 +20,22 @@ limitations under the License.
#include <unordered_set>
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace grappler {
+// Map a node/op's input/output port_id to arg_id.
+//
+// The port_id refers to the n-th tensor of the node, while the arg_id refers to
+// the n-th arg of the op. These two can be different if an op's arg is a list
+// of tensors.
+//
+// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
+int OpOutputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
+int OpInputPortIdToArgId(const NodeDef& node, const OpDef& op, int port_id);
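For example (illustration only, mirroring the ShapeN test added in graph_view_test.cc): a ShapeN node built from three inputs has one list-typed input arg and one list-typed output arg, so ports 0-2 all collapse to arg 0 and port 3 is out of range.

  // OpInputPortIdToArgId(b_node_def, *b_op_def, 2)   == 0
  // OpOutputPortIdToArgId(b_node_def, *b_op_def, 2)  == 0
  // OpOutputPortIdToArgId(b_node_def, *b_op_def, 3)  == -1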
+
// A utility class to simplify the traversal of a GraphDef.
class GraphView {
public:
diff --git a/tensorflow/core/grappler/graph_view_test.cc b/tensorflow/core/grappler/graph_view_test.cc
index 958eb921fb..f90e2c8cfc 100644
--- a/tensorflow/core/grappler/graph_view_test.cc
+++ b/tensorflow/core/grappler/graph_view_test.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/cc/ops/parsing_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
@@ -25,6 +26,102 @@ namespace {
class GraphViewTest : public ::testing::Test {};
+TEST_F(GraphViewTest, OpPortIdToArgIdShapeN) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
+ ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& a_node_def = *graph_view.GetNode("a");
+ const NodeDef& b_node_def = *graph_view.GetNode("b");
+
+ const OpDef* a_op_def = nullptr;
+ const OpDef* b_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(a_node_def.op(), &a_op_def).ok());
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+ // Const has 0 inputs, 1 output.
+ EXPECT_EQ(-1, OpInputPortIdToArgId(a_node_def, *a_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(a_node_def, *a_op_def, 0));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(a_node_def, *a_op_def, 1));
+
+ // ShapeN has N=3 inputs and outputs.
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0));
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 1));
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 2));
+ EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 3));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 1));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(b_node_def, *b_op_def, 2));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 3));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(b_node_def, *b_op_def, 4));
+}
+
+TEST_F(GraphViewTest, OpPortIdToArgIdSparseSplit) {
+ for (int num_splits : {1, 2}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const<int64>(s.WithOpName("a"), 1, {10, 10});
+ ops::SparseSplit b(s.WithOpName("b"), a, a, a, a, num_splits);
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& b_node_def = *graph_view.GetNode("b");
+ const OpDef* b_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(b_node_def.op(), &b_op_def).ok());
+
+ // We have 4 inputs.
+ EXPECT_EQ(0, OpInputPortIdToArgId(b_node_def, *b_op_def, 0));
+ EXPECT_EQ(1, OpInputPortIdToArgId(b_node_def, *b_op_def, 1));
+ EXPECT_EQ(2, OpInputPortIdToArgId(b_node_def, *b_op_def, 2));
+ EXPECT_EQ(3, OpInputPortIdToArgId(b_node_def, *b_op_def, 3));
+ EXPECT_EQ(-1, OpInputPortIdToArgId(b_node_def, *b_op_def, 4));
+
+ for (int port_id = 0; port_id <= num_splits * 3; ++port_id) {
+ int arg_id = -1;
+ if (port_id < num_splits * 3) {
+ arg_id = port_id / num_splits;
+ }
+ EXPECT_EQ(arg_id, OpOutputPortIdToArgId(b_node_def, *b_op_def, port_id));
+ }
+ }
+}
+
+TEST_F(GraphViewTest, ParseSingleExample) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const<string>(s.WithOpName("a"), "", {});
+ Output b = ops::Const<int64>(s.WithOpName("b"), 1, {1, 1});
+ ops::ParseSingleExample c(s.WithOpName("c"), a, {b, b}, 2, {"w", "x"},
+ {"y", "z"}, {DT_INT64, DT_INT64}, {{1}, {1}});
+
+ GraphDef graph_def;
+ TF_CHECK_OK(s.ToGraphDef(&graph_def));
+ GraphView graph_view(&graph_def);
+
+ const NodeDef& c_node_def = *graph_view.GetNode("c");
+
+ const OpDef* c_op_def = nullptr;
+ EXPECT_TRUE(
+ OpRegistry::Global()->LookUpOpDef(c_node_def.op(), &c_op_def).ok());
+
+ EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 0));
+ EXPECT_EQ(0, OpOutputPortIdToArgId(c_node_def, *c_op_def, 1));
+ EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 2));
+ EXPECT_EQ(1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 3));
+ EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 4));
+ EXPECT_EQ(2, OpOutputPortIdToArgId(c_node_def, *c_op_def, 5));
+ EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 6));
+ EXPECT_EQ(3, OpOutputPortIdToArgId(c_node_def, *c_op_def, 7));
+ EXPECT_EQ(-1, OpOutputPortIdToArgId(c_node_def, *c_op_def, 8));
+}
+
TEST_F(GraphViewTest, BasicGraph) {
TrivialTestGraphInputYielder fake_input(4, 2, 2, false, {"/CPU:0", "/GPU:0"});
GrapplerItem item;
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index bbc0fedd22..2c490f3966 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -38,6 +38,7 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) {
restore_op = other.restore_op;
save_restore_loc_tensor = other.save_restore_loc_tensor;
queue_runners = other.queue_runners;
+ allowed_optimizations = other.allowed_optimizations;
graph.Swap(graph_def);
}
diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h
index 939e5fa046..a0748abfe6 100644
--- a/tensorflow/core/grappler/grappler_item.h
+++ b/tensorflow/core/grappler/grappler_item.h
@@ -77,6 +77,15 @@ struct GrapplerItem {
// Return a set of node names that must be preserved. This includes feed and
// fetch nodes, keep_ops, init_ops.
std::unordered_set<string> NodesToPreserve() const;
+
+ // Restrict types of optimizations that are allowed for this GrapplerItem.
+ struct AllowedOptimizations {
+ // Is it allowed to add nodes to the graph that do not have registered
+ // gradient function.
+ bool non_differentiable_rewrites = true;
+ };
+
+ AllowedOptimizations allowed_optimizations;
};
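A minimal sketch (illustration only, not part of the patch) of how a caller might use the new field to forbid rewrites that introduce non-differentiable ops:

  GrapplerItem item;
  item.graph = graph_def;  // some GraphDef to optimize
  item.allowed_optimizations.non_differentiable_rewrites = false;
  // An optimizer can then consult item.allowed_optimizations before adding
  // nodes that lack a registered gradient function.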
// Return the transitive fanin of a set of terminal nodes.
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 288587ce9b..369046666d 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/framework/variable.pb.h"
@@ -191,9 +192,13 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
const string feed_name = NodeName(feed_node);
new_item->feed.emplace_back(feed_name, Tensor());
}
+ for (const auto& fetch_node : cfg.fetch_nodes) {
+ new_item->fetch.emplace_back(NodeName(fetch_node));
+ }
- // Attempt to detect the fetch node(s).
- if (meta_graph.collection_def().count("train_op") > 0) {
+ // Attempt to detect the fetch node(s) if they were not set explicitly.
+ if (new_item->fetch.empty() &&
+ meta_graph.collection_def().count("train_op") > 0) {
const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
if (nodes.has_node_list()) {
for (const auto& node : nodes.node_list().value()) {
diff --git a/tensorflow/core/grappler/grappler_item_builder.h b/tensorflow/core/grappler/grappler_item_builder.h
index aafd2fdcda..1698587f8c 100644
--- a/tensorflow/core/grappler/grappler_item_builder.h
+++ b/tensorflow/core/grappler/grappler_item_builder.h
@@ -49,6 +49,8 @@ struct ItemConfig {
bool prune_graph = false;
// Override feed nodes list.
std::set<string> feed_nodes;
+ // Override fetch nodes list.
+ std::set<string> fetch_nodes;
};
// Factory method for creating a GrapplerItem from a MetaGraphDef.
diff --git a/tensorflow/core/grappler/grappler_item_builder_test.cc b/tensorflow/core/grappler/grappler_item_builder_test.cc
index 4b90bf3038..d00981f174 100644
--- a/tensorflow/core/grappler/grappler_item_builder_test.cc
+++ b/tensorflow/core/grappler/grappler_item_builder_test.cc
@@ -313,6 +313,29 @@ TEST_F(GrapplerItemBuilderTest, FromGraphWithUnknownDimInSignatureInput) {
EXPECT_EQ(item2->feed[0].second.NumElements(), 1);
}
+TEST_F(GrapplerItemBuilderTest, ExplicitFeedAndFetch) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), 0);
+ auto y = ops::Const(s.WithOpName("y"), 1);
+ auto z = ops::Add(s.WithOpName("z"), x, y);
+
+ MetaGraphDef meta_graph;
+ TF_CHECK_OK(s.ToGraphDef(meta_graph.mutable_graph_def()));
+
+ ItemConfig config;
+ config.feed_nodes.insert("x");
+ config.fetch_nodes.insert("z");
+
+ std::unique_ptr<GrapplerItem> item =
+ GrapplerItemFromMetaGraphDef("0", meta_graph, config);
+ ASSERT_TRUE(item != nullptr);
+
+ EXPECT_EQ(item->feed.size(), 1);
+ EXPECT_EQ(item->fetch.size(), 1);
+ EXPECT_EQ(item->feed[0].first, "x");
+ EXPECT_EQ(item->fetch[0], "z");
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/inputs/utils.cc b/tensorflow/core/grappler/inputs/utils.cc
index 5029dff877..def9198a69 100644
--- a/tensorflow/core/grappler/inputs/utils.cc
+++ b/tensorflow/core/grappler/inputs/utils.cc
@@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/inputs/utils.h"
-#include "tensorflow/core/platform/env.h"
#include <vector>
+#include "tensorflow/core/platform/env.h"
+
namespace tensorflow {
namespace grappler {
@@ -29,12 +30,12 @@ bool FilesExist(const std::set<string>& files) {
return FilesExist(std::vector<string>(files.begin(), files.end()), nullptr);
}
-bool FileExists(const std::string& file, Status* status) {
+bool FileExists(const string& file, Status* status) {
*status = Env::Default()->FileExists(file);
return status->ok();
}
-Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path,
+Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
GraphDef* result) {
Status status;
if (FileExists(graph_def_pbtxt_path, &status)) {
diff --git a/tensorflow/core/grappler/inputs/utils.h b/tensorflow/core/grappler/inputs/utils.h
index 627dd5359f..4b9cb0a9ad 100644
--- a/tensorflow/core/grappler/inputs/utils.h
+++ b/tensorflow/core/grappler/inputs/utils.h
@@ -29,9 +29,9 @@ bool FilesExist(const std::vector<string>& files,
std::vector<Status>* status = nullptr);
bool FilesExist(const std::set<string>& files);
-bool FileExists(const std::string& file, Status* status);
+bool FileExists(const string& file, Status* status);
-Status ReadGraphDefFromFile(const std::string& graph_def_pbtxt_path,
+Status ReadGraphDefFromFile(const string& graph_def_pbtxt_path,
GraphDef* result);
} // end namespace grappler
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 653b088b1d..1b5a215987 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -13,14 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <unordered_set>
-
+#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
@@ -102,6 +101,18 @@ bool IsConjugateTranspose(const NodeDef& node) {
return node.op() == "ConjugateTranspose";
}
+bool IsControlFlow(const NodeDef& node) {
+ // clang-format off
+ return node.op() == "ControlTrigger" ||
+ node.op() == "Enter" ||
+ node.op() == "Exit" ||
+ node.op() == "LoopCond" ||
+ node.op() == "Merge" ||
+ node.op() == "NextIteration" ||
+ node.op() == "Switch";
+ // clang-format on
+}
+
bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
bool IsConv2DBackpropFilter(const NodeDef& node) {
@@ -135,16 +146,37 @@ bool IsDequeueOp(const NodeDef& node) {
bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
-bool IsElementWiseMonotonic(const NodeDef& node) {
- static const std::unordered_set<string>* element_wise_monotonic_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
- "Relu",
- "Relu6",
- "Sigmoid",
- "Sqrt",
- "Tanh",
+// Returns true if node represents a unary elementwise function that is
+// monotonic. If *is_non_decreasing is set to true, the function is
+// non-decreasing, e.g. sqrt, exp. If *is_non_decreasing is set to false, the
+// function is non-increasing, e.g. inv.
+bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
+ static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
+ "Asinh", "Atanh", "Ceil", "Elu", "Erf", "Exp", "Expm1",
+ "Floor", "Log", "Log1p", "Relu", "Relu", "Relu6", "Rint",
+ "Selu", "Sigmoid", "Sign", "Sinh", "Sqrt", "Tanh",
+ }));
+ static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
+ "Inv",
+ "Reciprocal",
+ "Erfc",
+ "Rsqrt",
+ "Neg",
}));
- return element_wise_monotonic_ops->count(node.op()) > 0;
+ if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
+ if (is_non_decreasing) {
+ *is_non_decreasing = true;
+ }
+ return true;
+ } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) {
+ if (is_non_decreasing) {
+ *is_non_decreasing = false;
+ }
+ return true;
+ }
+ return false;
}
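The out-parameter is optional; a sketch of the intended calling pattern (illustration only):

  bool is_non_decreasing = false;
  if (IsElementWiseMonotonic(node, &is_non_decreasing)) {
    // `node` is a unary monotonic elementwise op; is_non_decreasing gives the
    // direction. Passing nullptr is also fine when only the boolean result
    // matters: IsElementWiseMonotonic(node, nullptr).
  }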
bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
@@ -404,8 +436,44 @@ bool IsSwitch(const NodeDef& node) {
return op == "Switch" || op == "RefSwitch";
}
+bool IsSymbolicGradient(const NodeDef& node) {
+ return node.op() == "SymbolicGradient";
+}
+
bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
+bool IsTensorArray(const NodeDef& node) {
+ static const gtl::FlatSet<string>* const kTensorArrayOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
+ "TensorArray",
+ "TensorArrayV2",
+ "TensorArrayV3",
+ "TensorArrayGrad",
+ "TensorArrayGradV2",
+ "TensorArrayGradV3",
+ "TensorArrayGradWithShape",
+ "TensorArrayWrite",
+ "TensorArrayWriteV2",
+ "TensorArrayWriteV3",
+ "TensorArrayRead",
+ "TensorArrayReadV2",
+ "TensorArrayReadV3",
+ "TensorArrayConcat",
+ "TensorArrayConcatV2",
+ "TensorArrayConcatV3",
+ "TensorArraySplit",
+ "TensorArraySplitV2",
+ "TensorArraySplitV3",
+ "TensorArraySize",
+ "TensorArraySizeV2",
+ "TensorArraySizeV3",
+ "TensorArrayClose",
+ "TensorArrayCloseV2",
+ "TensorArrayCloseV3",
+ }));
+ return kTensorArrayOps->count(node.op()) > 0;
+}
+
bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
@@ -470,7 +538,7 @@ bool IsFreeOfSideEffect(const NodeDef& node) {
}
}
// Queue ops modify the queue which is a side effect.
- if (node.op().find("Queue") != std::string::npos) {
+ if (node.op().find("Queue") != string::npos) {
return false;
}
return !ModifiesInputsInPlace(node);
@@ -517,30 +585,29 @@ OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
OPDEF_PROPERTY_HELPER(Commutative, commutative)
bool IsInvolution(const NodeDef& node) {
- static const std::unordered_set<string>* involution_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
- "Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"}));
- return involution_ops->count(node.op()) > 0;
+ static const gtl::FlatSet<string>* const kInvolutionOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert",
+ "Neg", "LogicalNot"}));
+ return kInvolutionOps->count(node.op()) > 0;
}
bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
return true;
}
- static const std::unordered_set<string>*
- value_and_order_and_shape_preserving_ops =
- CHECK_NOTNULL((new const std::unordered_set<string>{
- "CheckNumerics",
- "DebugGradientIdentity",
- "DeepCopy"
- "Enter",
- "Exit",
- "PreventGradient",
- "Print",
- "Snapshot",
- "StopGradient",
- }));
- return value_and_order_and_shape_preserving_ops->count(node.op()) > 0 ||
+ static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps =
+ CHECK_NOTNULL((new const gtl::FlatSet<string>{
+ "CheckNumerics",
+ "DebugGradientIdentity",
+ "DeepCopy"
+ "Enter",
+ "Exit",
+ "PreventGradient",
+ "Print",
+ "Snapshot",
+ "StopGradient",
+ }));
+ return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 ||
IsIdentity(node);
}
@@ -548,31 +615,31 @@ bool IsValueAndOrderPreserving(const NodeDef& node) {
if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
return true;
}
- static const std::unordered_set<string>* value_and_order_preserving_ops =
- CHECK_NOTNULL((new const std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps =
+ CHECK_NOTNULL((new const gtl::FlatSet<string>{
"ExpandDims",
"Reshape",
"Squeeze",
}));
- return value_and_order_preserving_ops->count(node.op()) > 0 ||
+ return kValueAndOrderPreservingOps->count(node.op()) > 0 ||
IsValueAndOrderAndShapePreserving(node);
}
bool IsValuePreserving(const NodeDef& node) {
- static const std::unordered_set<string>* value_preserving_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kValuePreservingOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"InvertPermutation",
"Reverse",
"Roll",
"Transpose",
}));
return IsValueAndOrderPreserving(node) ||
- value_preserving_ops->count(node.op()) > 0;
+ kValuePreservingOps->count(node.op()) > 0;
}
bool IsUnaryElementWise(const NodeDef& node) {
- static const std::unordered_set<string>* element_wise_ops =
- CHECK_NOTNULL((new std::unordered_set<string>{
+ static const gtl::FlatSet<string>* const kElementWiseOps =
+ CHECK_NOTNULL((new gtl::FlatSet<string>{
"Abs",
"Acos",
"Acosh",
@@ -621,7 +688,7 @@ bool IsUnaryElementWise(const NodeDef& node) {
"Tan"
"Tanh",
}));
- return element_wise_ops->count(node.op()) > 0 ||
+ return kElementWiseOps->count(node.op()) > 0 ||
IsValueAndOrderAndShapePreserving(node);
}
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 94439265c9..d4e0159e81 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -46,6 +46,7 @@ bool IsConjugateTranspose(const NodeDef& node);
bool IsConcat(const NodeDef& node);
bool IsConcatOffset(const NodeDef& node);
bool IsConstant(const NodeDef& node);
+bool IsControlFlow(const NodeDef& node);
bool IsConv2D(const NodeDef& node);
bool IsConv2DBackpropFilter(const NodeDef& node);
bool IsConv2DBackpropInput(const NodeDef& node);
@@ -55,7 +56,7 @@ bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node);
bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node);
bool IsDequeueOp(const NodeDef& node);
bool IsDiv(const NodeDef& node);
-bool IsElementWiseMonotonic(const NodeDef& node);
+bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing);
bool IsEluGrad(const NodeDef& node);
bool IsEnter(const NodeDef& node);
bool IsEqual(const NodeDef& node);
@@ -149,7 +150,9 @@ bool IsStridedSliceGrad(const NodeDef& node);
bool IsSub(const NodeDef& node);
bool IsSum(const NodeDef& node);
bool IsSwitch(const NodeDef& node);
+bool IsSymbolicGradient(const NodeDef& node);
bool IsTanhGrad(const NodeDef& node);
+bool IsTensorArray(const NodeDef& node);
bool IsTile(const NodeDef& node);
bool IsTranspose(const NodeDef& node);
bool IsTruncateDiv(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index caaa5ac8db..c708f84948 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -8,10 +8,6 @@ load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
# Platform specific build config
load(
- "//tensorflow/core:platform/default/build_config.bzl",
- "tf_protos_grappler",
-)
-load(
"//tensorflow/core:platform/default/build_config_root.bzl",
"if_static",
)
@@ -97,7 +93,6 @@ cc_library(
deps = [
":evaluation_utils",
":graph_optimizer",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -107,6 +102,7 @@ cc_library(
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
],
)
@@ -116,6 +112,7 @@ tf_cc_test(
shard_count = 5,
deps = [
":constant_folding",
+ ":dependency_optimizer",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/core:all_kernels",
@@ -260,7 +257,6 @@ cc_library(
":constant_folding",
":graph_optimizer",
":graph_optimizer_stage",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -269,6 +265,7 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
"//tensorflow/core/grappler/utils:topological_sort",
],
)
@@ -514,18 +511,21 @@ cc_library(
":custom_graph_optimizer_registry",
":debug_stripper",
":dependency_optimizer",
+ ":experimental_implementation_selector",
":function_optimizer",
":graph_optimizer",
":layout_optimizer",
":loop_optimizer",
":memory_optimizer",
":model_pruner",
+ ":pin_to_host_optimizer",
":remapper",
":scoped_allocator_optimizer",
":shape_optimizer",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/utils:colocation",
@@ -542,6 +542,7 @@ tf_cuda_cc_test(
":custom_graph_optimizer_registry",
":meta_optimizer",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
@@ -646,7 +647,6 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":graph_optimizer",
- ":symbolic_shapes",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -656,6 +656,7 @@ cc_library(
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/costs:graph_properties",
"//tensorflow/core/grappler/utils:frame",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
],
)
@@ -713,31 +714,6 @@ tf_cuda_cc_test(
)
cc_library(
- name = "symbolic_shapes",
- srcs = ["symbolic_shapes.cc"],
- hdrs = ["symbolic_shapes.h"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- ] + tf_protos_grappler(),
-)
-
-tf_cc_test(
- name = "symbolic_shapes_test",
- srcs = ["symbolic_shapes_test.cc"],
- deps = [
- ":symbolic_shapes",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- ],
-)
-
-cc_library(
name = "debug_stripper",
srcs = ["debug_stripper.cc"],
hdrs = [
@@ -827,11 +803,6 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core/grappler:grappler_item",
- "//tensorflow/core/grappler:op_types",
- "//tensorflow/core/grappler:utils",
- "//tensorflow/core/grappler/clusters:cluster",
- "//tensorflow/core/grappler/costs:graph_properties",
],
)
@@ -850,3 +821,106 @@ tf_cc_test(
"//third_party/eigen3",
],
)
+
+cc_library(
+ name = "function_api_info",
+ srcs = ["function_api_info.cc"],
+ hdrs = ["function_api_info.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ ],
+)
+
+tf_cc_test(
+ name = "function_api_info_test",
+ size = "small",
+ srcs = ["function_api_info_test.cc"],
+ deps = [
+ ":function_api_info",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+cc_library(
+ name = "experimental_implementation_selector",
+ srcs = ["experimental_implementation_selector.cc"],
+ hdrs = ["experimental_implementation_selector.h"],
+ deps = [
+ ":custom_graph_optimizer",
+ ":custom_graph_optimizer_registry",
+ ":function_api_info",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ ],
+)
+
+tf_cc_test(
+ name = "experimental_implementation_selector_test",
+ size = "small",
+ srcs = ["experimental_implementation_selector_test.cc"],
+ deps = [
+ ":custom_graph_optimizer",
+ ":custom_graph_optimizer_registry",
+ ":experimental_implementation_selector",
+ ":function_api_info",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
+ "//tensorflow/core/grappler/utils:grappler_test",
+ ],
+)
+
+cc_library(
+ name = "pin_to_host_optimizer",
+ srcs = ["pin_to_host_optimizer.cc"],
+ hdrs = [
+ "pin_to_host_optimizer.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_optimizer",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/costs:graph_properties",
+ "//tensorflow/core/grappler/utils:frame",
+ "//tensorflow/core/grappler/utils:symbolic_shapes",
+ "//tensorflow/core/grappler/utils:topological_sort",
+ ],
+)
+
+tf_cuda_cc_test(
+ name = "pin_to_host_optimizer_test",
+ srcs = ["pin_to_host_optimizer_test.cc"],
+ deps = [
+ ":pin_to_host_optimizer",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/utils:grappler_test",
+ ],
+)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 889445bbd6..7d5014ee0a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -34,8 +35,8 @@ limitations under the License.
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -275,7 +276,7 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
for (const NodeDef* output : ctx().node_map->GetOutputs(node.name())) {
for (int i = 0; i < output->input_size(); ++i) {
auto input = output->input(i);
- string name = ParseNodeName(input, &position);
+ StringPiece name = ParseNodeNameAsStringPiece(input, &position);
if (name == node.name() && /*control input*/ position < 0) {
return true;
}
@@ -1120,11 +1121,8 @@ class RemoveIdentityTranspose : public ArithmeticOptimizerStage {
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
NodeDef* tail = node;
- // TODO(rmlarsen): Enable after debugging breakage in Bayesflow.
- if (ctx().opt_level == RewriterConfig::AGGRESSIVE) {
- tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
- *ctx().nodes_to_preserve);
- }
+ tail = GetTailOfIdempotentChain(*tail, *ctx().node_map,
+ *ctx().nodes_to_preserve);
NodeDef* first_transpose;
TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &first_transpose));
@@ -1327,38 +1325,26 @@ class RemoveNegationStage : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- const string node_name = node->name();
NodeDef* x;
NodeDef* y;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
bool updated = false;
- if (IsAdd(*node)) {
- if (IsNeg(*x)) {
- // (-a) + b = b - a
- node->set_op("Sub");
- node->mutable_input()->SwapElements(0, 1);
- node->set_input(1, x->input(0));
- node->add_input(AsControlDependency(x->name()));
- ctx().node_map->AddOutput(NodeName(x->input(0)), node_name);
- updated = true;
- } else if (IsNeg(*y)) {
- // a + (-b) = a - b
- node->set_op("Sub");
- node->set_input(1, y->input(0));
- node->add_input(AsControlDependency(y->name()));
- ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
- updated = true;
- }
- } else if (IsSub(*node)) {
- if (IsNeg(*y)) {
- // a - (-b) = a + b
- node->set_op("Add");
- node->set_input(1, y->input(0));
- node->add_input(AsControlDependency(y->name()));
- ctx().node_map->AddOutput(NodeName(y->input(0)), node_name);
- updated = true;
- }
+ if (IsNeg(*y)) {
+ // a - (-b) = a + b or a + (-b) = a - b
+ ForwardControlDependencies(node, {y});
+ ctx().node_map->UpdateInput(node->name(), node->input(1), y->input(0));
+ node->set_op(IsAdd(*node) ? "Sub" : "Add");
+ node->set_input(1, y->input(0));
+ updated = true;
+ } else if (IsAdd(*node) && IsNeg(*x)) {
+ // (-a) + b = b - a
+ ForwardControlDependencies(node, {x});
+ ctx().node_map->UpdateInput(node->name(), node->input(0), x->input(0));
+ node->set_op("Sub");
+ node->mutable_input()->SwapElements(0, 1);
+ node->set_input(1, x->input(0));
+ updated = true;
}
if (updated) {
AddToOptimizationQueue(node);
@@ -1582,7 +1568,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
for (NodeDef* output : outputs) {
if (IsControlInput(output->input(0))) continue;
int port;
- const string node_name = ParseNodeName(output->input(0), &port);
+ const StringPiece node_name =
+ ParseNodeNameAsStringPiece(output->input(0), &port);
if (node_name == node.name()) {
tails->insert(ChainLink(output, port));
} else {
@@ -1632,7 +1619,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
} else {
for (NodeDef* new_tail : ctx().node_map->GetOutputs(tail->name())) {
int port;
- const string node_name = ParseNodeName(new_tail->input(0), &port);
+ const StringPiece node_name =
+ ParseNodeNameAsStringPiece(new_tail->input(0), &port);
if (node_name != tail->name()) {
return Status::OK();
}
@@ -2381,26 +2369,24 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
- const auto& p = ctx().graph_properties->GetInputProperties(node->name())[1];
- for (int i = 0; i < p.shape().dim_size(); ++i) {
- if (p.shape().dim(i).size() < 0) {
+ const auto& pow_props =
+ ctx().graph_properties->GetInputProperties(node->name())[1];
+ for (int i = 0; i < pow_props.shape().dim_size(); ++i) {
+ if (pow_props.shape().dim(i).size() < 0) {
// skip if p is not fully defined.
return Status::OK();
}
}
- if (TensorShape::IsValid(p.shape()) && p.has_value()) {
- Tensor pow(p.dtype(), p.shape());
- if (!pow.FromProto(p.value())) {
+ if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
+ Tensor pow(pow_props.dtype(), pow_props.shape());
+ if (!pow.FromProto(pow_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
- p.value().DebugString());
+ pow_props.value().DebugString());
}
complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) {
- if (!GetElementUnexhaustive(pow, i,
- {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_COMPLEX128},
- &curr)) {
+ if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) {
// input data type is not supported by Pow. Skip.
return Status::OK();
}
@@ -2413,12 +2399,19 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
NodeDef *x, *y;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
+ const auto& value_props =
+ ctx().graph_properties->GetInputProperties(node->name())[0];
+ const TensorShapeProto& output_shape =
+ ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
if (curr == complex128(2, 0)) {
node->set_op("Square");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
- } else if (curr == complex128(1, 0)) {
+ } else if (curr == complex128(1, 0) &&
+ ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
+ // Pow could be used to broadcast, so make sure the shapes of the two
+ // arguments are identical before replacing Pow with Identity.
node->set_op("Identity");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
@@ -2428,20 +2421,20 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
- } else if (curr == complex128(0, 0)) {
- const auto& b =
- ctx().graph_properties->GetInputProperties(node->name())[0];
- for (int i = 0; i < b.shape().dim_size(); ++i) {
- if (b.shape().dim(i).size() < 0) {
+ } else if (curr == complex128(0, 0) &&
+ ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
+ for (int i = 0; i < value_props.shape().dim_size(); ++i) {
+ if (value_props.shape().dim(i).size() < 0) {
// skip if b is not fully defined.
return Status::OK();
}
}
- if (TensorShape::IsValid(b.shape()) && b.has_value()) {
- Tensor base(b.dtype(), b.shape());
- if (!base.FromProto(b.value())) {
+ if (TensorShape::IsValid(value_props.shape()) &&
+ value_props.has_value()) {
+ Tensor base(value_props.dtype(), value_props.shape());
+ if (!base.FromProto(value_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
- b.value().DebugString());
+ value_props.value().DebugString());
}
node->set_op("Const");
Tensor c(base.dtype(), base.shape());
@@ -2599,12 +2592,10 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
~ConvertExpm1Stage() override = default;
bool IsSupported(const NodeDef* node) const override {
- if (!IsSub(*node))
- return false;
+ if (!IsSub(*node)) return false;
NodeDef* input;
- if (!GetInputNode(node->input(0), &input).ok())
- return false;
+ if (!GetInputNode(node->input(0), &input).ok()) return false;
return IsExp(*input);
}
@@ -2624,10 +2615,8 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
return Status::OK();
}
- const auto& t =
- ctx().graph_properties->GetInputProperties(exp->name())[0];
- const auto& c =
- ctx().graph_properties->GetInputProperties(node->name())[1];
+ const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
+ const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
for (int k = 0; k < c.shape().dim_size(); ++k) {
// Skip if c shape is not fully determined.
if (c.shape().dim(k).size() < 0) {
@@ -2702,22 +2691,37 @@ class OptimizeMaxOrMinOfMonotonicStage : public ArithmeticOptimizerStage {
NodeDef* inner_function;
TF_RETURN_IF_ERROR(GetInputNode(reduction_node->input(0), &inner_function));
// Optimize only if:
+ // 0. inner_function is not in the preserve set,
// 1. inner_function's Op is element-wise monotonic
// 2. inner_function's output is not being consumed elsewhere.
- if (IsElementWiseMonotonic(*inner_function) &&
- (NumNonControlOutputs(*inner_function, *ctx().node_map) == 1)) {
+ bool is_non_decreasing = false;
+ if (!IsInPreserveSet(*inner_function) &&
+ IsElementWiseMonotonic(*inner_function, &is_non_decreasing) &&
+ ctx().node_map->GetOutputs(inner_function->name()).size() == 1) {
// Swap the first inputs of the inner function Op & the reduction Op.
NodeDef* inner_input;
TF_RETURN_IF_ERROR(GetInputNode(inner_function->input(0), &inner_input));
- inner_function->set_input(0, reduction_node->name());
- UpdateConsumersAvoidingLoop(inner_function, reduction_node->name());
reduction_node->set_input(0, inner_input->name());
- UpdateConsumersAvoidingLoop(reduction_node, inner_function->name());
+ ctx().node_map->UpdateInput(reduction_node->name(),
+ inner_function->name(), inner_input->name());
+ inner_function->set_input(0, reduction_node->name());
+ UpdateConsumers(reduction_node, inner_function->name());
+ ctx().node_map->UpdateInput(inner_function->name(), inner_input->name(),
+ reduction_node->name());
+ if (!is_non_decreasing) {
+ // Flip Min<->Max if the function is non-increasing, e.g.
+ // Max(Neg(x)) = Neg(Min(x)).
+ const string opposite = IsMax(*reduction_node) ? "Min" : "Max";
+ reduction_node->set_op(opposite);
+ }
+ AddToOptimizationQueue(reduction_node);
+ AddToOptimizationQueue(inner_function);
+ AddToOptimizationQueue(inner_input);
}
return Status::OK();
}
- void UpdateConsumersAvoidingLoop(NodeDef* node, const string& new_input) {
+ void UpdateConsumers(NodeDef* node, const string& new_input) {
const string& node_name = node->name();
const std::set<NodeDef*> consumers = ctx().node_map->GetOutputs(node_name);
for (NodeDef* consumer : consumers) {
@@ -2927,8 +2931,8 @@ uint64 UniqueNodes::ComputeSignature(const NodeDef& node) const {
for (const auto& input : node.input()) {
int pos;
- string node_name = ParseNodeName(input, &pos);
- h = Hash64CombineUnordered(Hash64(node_name), h);
+ const StringPiece node_name = ParseNodeNameAsStringPiece(input, &pos);
+ h = Hash64CombineUnordered(Hash64(node_name.data(), node_name.size()), h);
h = Hash64CombineUnordered(std::hash<int>()(pos), h);
}
for (const auto& attr : node.attr()) {
@@ -3040,6 +3044,13 @@ void ArithmeticOptimizer::DedupComputations() {
return;
}
std::set<int> duplicates;
+ // Populate feeds_inplace_op.
+ std::unordered_set<NodeDef*> feeds_inplace_op;
+ for (int i = 0; i < optimized_graph_->node_size(); ++i) {
+ if (FeedsInPlaceOp(graph_view, optimized_graph_->node(i))) {
+ feeds_inplace_op.insert(optimized_graph_->mutable_node(i));
+ }
+ }
do {
stop = true;
UniqueNodes nodes;
@@ -3048,19 +3059,19 @@ void ArithmeticOptimizer::DedupComputations() {
continue;
}
NodeDef* node = optimized_graph_->mutable_node(i);
- if (!CanDedup(*node)) {
+ if (!CanDedup(*node) ||
+ feeds_inplace_op.find(node) != feeds_inplace_op.end()) {
continue;
}
NodeDef* rep = nodes.FindOrAddRepresentative(node);
if (rep == node) {
continue;
}
- // If either node feeds an inplace op, deduping them may cause data races.
- // For example: If we dedup nodes initializing two independent inplace
- // accumulations, they will write to the same buffer, clobbering each
- // other's results.
- if (FeedsInPlaceOp(graph_view, *rep) ||
- FeedsInPlaceOp(graph_view, *node)) {
+ // If either node or rep feeds an inplace op, deduping them may cause data
+ // races. For example: If we dedup nodes initializing two independent
+ // inplace accumulations, they will write to the same buffer, clobbering
+ // each other's results.
+ if (feeds_inplace_op.find(rep) != feeds_inplace_op.end()) {
continue;
}
VLOG(3) << "Remove duplicated node: node=" << node->name()
@@ -3068,20 +3079,20 @@ void ArithmeticOptimizer::DedupComputations() {
const std::set<NodeDef*>& fanouts = node_map_->GetOutputs(node->name());
for (NodeDef* fanout : fanouts) {
for (int i = 0; i < fanout->input_size(); ++i) {
- string* name = fanout->mutable_input(i);
- int position;
- const string nodename = ParseNodeName(*name, &position);
- if (nodename == node->name()) {
- // Update name in-place.
- if (position > 0) {
- *name = StrCat(rep->name(), ":", position);
- } else if (position == 0) {
- *name = rep->name();
- } else {
- *name = StrCat("^", rep->name());
- }
- node_map_->AddOutput(rep->name(), fanout->name());
+ string* fanout_input = fanout->mutable_input(i);
+ const int position =
+ NodePositionIfSameNode(*fanout_input, node->name());
+ // Update name in-place.
+ if (position < -1) {
+ continue;
+ } else if (position > 0) {
+ *fanout_input = StrCat(rep->name(), ":", position);
+ } else if (position == 0) {
+ *fanout_input = rep->name();
+ } else {
+ *fanout_input = StrCat("^", rep->name());
}
+ node_map_->AddOutput(rep->name(), fanout->name());
}
}
duplicates.insert(i);
@@ -3238,6 +3249,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
optimized_graph_ = &optimized_item.graph;
node_map_.reset(new NodeMap(optimized_graph_));
+ // Disable restricted graph rewrites.
+ options_.unary_ops_composition &=
+ item.allowed_optimizations.non_differentiable_rewrites;
+
if (options_.dedup_computations) {
DedupComputations();
}
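
The RemoveNegation and OptimizeMaxOrMinOfMonotonic changes above both rest on simple identities. As a quick illustration — a standalone C++ sketch, not grappler code, with purely illustrative names — the rewrites preserve values exactly:

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  // Algebraic identities behind RemoveNegationStage.
  const double a = 3.5, b = 1.25;
  assert(a + (-b) == a - b);   // Add(x, Neg(y)) -> Sub(x, y)
  assert(a - (-b) == a + b);   // Sub(x, Neg(y)) -> Add(x, y)
  assert((-a) + b == b - a);   // Add(Neg(x), y) -> Sub(y, x)

  // Identity behind the Min<->Max flip in OptimizeMaxOrMinOfMonotonicStage:
  // for an element-wise non-increasing f, max(f(x)) == f(min(x)).
  const std::vector<double> x = {1.0, 2.0, 5.0};
  auto f = [](double v) { return -v; };  // element-wise, non-increasing
  std::vector<double> fx;
  for (double v : x) fx.push_back(f(v));
  assert(*std::max_element(fx.begin(), fx.end()) ==
         f(*std::min_element(x.begin(), x.end())));
  return 0;
}

The non-increasing case is why the stage now flips Max to Min (and vice versa) after swapping the reduction with the inner function.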
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index 551c3652bf..d457eb6d21 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -61,7 +61,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
bool fold_multiply_into_conv = true;
bool fold_transpose_into_matmul = true;
bool hoist_common_factor_out_of_aggregation = true;
- bool hoist_cwise_unary_chains = false;
+ bool hoist_cwise_unary_chains = true;
bool minimize_broadcasts = true;
bool optimize_max_or_min_of_monotonic = true;
bool remove_idempotent = true;
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 685b5379af..77f3c64c65 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -581,7 +581,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimple) {
const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
- EXPECT_EQ(std::string("\0\0\0@", 4),
+ EXPECT_EQ(string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
@@ -625,7 +625,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsSimpleWithControlDep) {
const NodeDef* new_const = node_map.GetNode(optimized_const_name);
ASSERT_NE(new_const, nullptr);
EXPECT_EQ("^x", new_const->input(0));
- EXPECT_EQ(std::string("\0\0\0@", 4),
+ EXPECT_EQ(string("\0\0\0@", 4),
new_const->attr().at("value").tensor().tensor_content());
const NodeDef* new_mul = node_map.GetNode(optimized_mul_name);
@@ -2353,9 +2353,14 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
Output sub_negx_y = ops::Sub(s.WithOpName("Sub_negx_y"), neg_x, y);
Output sub_x_negy = ops::Sub(s.WithOpName("Sub_x_negy"), x, neg_y);
Output sub_negx_negy = ops::Sub(s.WithOpName("Sub_negx_negy"), neg_x, neg_y);
- auto add_all = ops::AddN(s.WithOpName("add_all"),
- {add_x_y, add_negx_y, add_x_negy, add_negx_negy,
- sub_x_y, sub_negx_y, sub_x_negy, sub_negx_negy});
+ Output neg_x_with_dep = ops::Neg(
+ s.WithOpName("Neg_x_with_dep").WithControlDependencies({add_x_y}), x);
+ Output add_negx_with_dep_y =
+ ops::Add(s.WithOpName("Add_negx_with_dep_y"), neg_x_with_dep, y);
+ auto add_all =
+ ops::AddN(s.WithOpName("add_all"),
+ {add_x_y, add_negx_y, add_x_negy, add_negx_negy, sub_x_y,
+ sub_negx_y, sub_x_negy, sub_negx_negy, add_negx_with_dep_y});
GrapplerItem item;
item.fetch = {"add_all"};
@@ -2370,7 +2375,7 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
GraphDef output;
ArithmeticOptimizer optimizer;
EnableOnlyRemoveNegation(&optimizer);
- OptimizeAndPrune(&optimizer, &item, &output);
+ OptimizeTwice(&optimizer, &item, &output);
EXPECT_EQ(item.graph.node_size(), output.node_size());
int found = 0;
@@ -2379,42 +2384,43 @@ TEST_F(ArithmeticOptimizerTest, RemoveNegation) {
if (node.name() == "Add_negx_y") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_x", node.input(2));
} else if (node.name() == "Add_x_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
} else if (node.name() == "Add_negx_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(3, node.input_size());
- EXPECT_EQ("Neg_y", node.input(0));
- EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_x", node.input(2));
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("Neg_x", node.input(0));
+ EXPECT_EQ("y", node.input(1));
} else if (node.name() == "Sub_x_negy") {
++found;
EXPECT_EQ("Add", node.op());
- EXPECT_EQ(3, node.input_size());
+ EXPECT_EQ(2, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("y", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
} else if (node.name() == "Sub_negx_negy") {
++found;
EXPECT_EQ("Sub", node.op());
- EXPECT_EQ(4, node.input_size());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("y", node.input(0));
+ EXPECT_EQ("x", node.input(1));
+ } else if (node.name() == "Add_negx_with_dep_y") {
+ ++found;
+ EXPECT_EQ("Sub", node.op());
+ EXPECT_EQ(3, node.input_size());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("x", node.input(1));
- EXPECT_EQ("^Neg_y", node.input(2));
- EXPECT_EQ("^Neg_x", node.input(3));
+ EXPECT_EQ("^Add_x_y", node.input(2));
}
}
- EXPECT_EQ(5, found);
+ EXPECT_EQ(6, found);
auto tensors = EvaluateNodes(output, item.fetch, feed);
EXPECT_EQ(1, tensors.size());
@@ -2468,6 +2474,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
auto y_Point5 = ops::Const(s.WithOpName("y_.5"), {-0.5f, -0.5f}, {1, 2});
auto y_1 = ops::Const(s.WithOpName("y_1"), {-1.0f, -1.0f}, {1, 2});
auto y = ops::Const(s.WithOpName("y"), {3.0f, 4.0f}, {1, 2});
+ auto z = ops::Const(s.WithOpName("z"), {42.0f}, {});
+ auto ones = ops::Const(s.WithOpName("ones"), {1.0f, 1.0f, 1.0f}, {1, 3});
+ auto zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f, 0.0f}, {1, 3});
Output out2 = ops::Pow(s.WithOpName("out2"), x, y2);
Output out1 = ops::Pow(s.WithOpName("out1"), x, y1);
Output outPoint5 = ops::Pow(s.WithOpName("out.5"), x, yPoint5);
@@ -2475,21 +2484,24 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
Output out_Point5 = ops::Pow(s.WithOpName("out_.5"), x, y_Point5);
Output out_1 = ops::Pow(s.WithOpName("out_1"), x, y_1);
Output out = ops::Pow(s.WithOpName("out"), x, y);
+ Output out_bcast1 = ops::Pow(s.WithOpName("out_bcast1"), z, ones);
+ Output out_bcast2 = ops::Pow(s.WithOpName("out_bcast2"), z, zeros);
GrapplerItem item;
- item.fetch = {"out2", "out1", "out.5", "out0", "out_.5", "out_1", "out"};
+ item.fetch = {"out2", "out1", "out.5", "out0", "out_.5",
+ "out_1", "out", "out_bcast1", "out_bcast2"};
TF_CHECK_OK(s.ToGraphDef(&item.graph));
auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
- EXPECT_EQ(7, tensors_expected.size());
+ EXPECT_EQ(9, tensors_expected.size());
GraphDef got;
ArithmeticOptimizer optimizer;
EnableOnlyConvertPow(&optimizer);
OptimizeAndPrune(&optimizer, &item, &got);
auto tensors = EvaluateNodes(got, item.fetch);
- EXPECT_EQ(7, tensors.size());
+ EXPECT_EQ(9, tensors.size());
- for (int i = 0; i < 7; ++i) {
+ for (int i = 0; i < tensors.size(); ++i) {
EXPECT_EQ(tensors[i].NumElements(), tensors_expected[i].NumElements());
test::ExpectTensorNear<float>(tensors[i], tensors_expected[i], 1e-6);
}
@@ -2503,6 +2515,9 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
AddNode("y_.5", "Const", {}, {}, &want);
AddNode("y_1", "Const", {}, {}, &want);
AddNode("y", "Const", {}, {}, &want);
+ AddNode("z", "Const", {}, {}, &want);
+ AddNode("ones", "Const", {}, {}, &want);
+ AddNode("zeros", "Const", {}, {}, &want);
AddNode("out2", "Square", {"x", AsControlDependency("y2")}, {}, &want);
AddNode("out1", "Identity", {"x", AsControlDependency("y1")}, {}, &want);
AddNode("out.5", "Sqrt", {"x", AsControlDependency("y.5")}, {}, &want);
@@ -2511,6 +2526,8 @@ TEST_F(ArithmeticOptimizerTest, ConvertPow) {
AddNode("out_.5", "Rsqrt", {"x", AsControlDependency("y_.5")}, {}, &want);
AddNode("out_1", "Reciprocal", {"x", AsControlDependency("y_1")}, {}, &want);
AddNode("out", "Pow", {"x", "y"}, {}, &want);
+ AddNode("out_bcast1", "Pow", {"z", "ones"}, {}, &want);
+ AddNode("out_bcast2", "Pow", {"z", "zeros"}, {}, &want);
CompareGraphs(want, got);
}
@@ -3224,6 +3241,72 @@ TEST_F(ArithmeticOptimizerTest, OptimizeMaxOrMinOfMonotonicElementWise) {
EXPECT_EQ(2, required_node_count);
}
+TEST_F(ArithmeticOptimizerTest,
+ OptimizeMaxOrMinOfMonotonicElementWise_DoNotChangeFetchNode) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output sqrt = ops::Sqrt(s.WithOpName("sqrt"), x);
+ Output reduce_max = ops::Max(s.WithOpName("reduce_max"), sqrt, {0});
+ Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+ GrapplerItem item;
+ item.fetch = {"sqrt", "final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(2, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeTwice(&optimizer, &item, &output);
+
+ // Should be a NoOp since we are not allowed to change the output of fetch
+ // nodes.
+ VerifyGraphsMatch(item.graph, output, __LINE__);
+}
+
+TEST_F(ArithmeticOptimizerTest,
+ OptimizeMaxOrMinOfMonotonicElementWiseNonIncreasing) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output neg = ops::Neg(s.WithOpName("neg"), x);
+ Output reduce_max = ops::Max(s.WithOpName("reduce_max"), neg, {0});
+ Output final_out = ops::Identity(s.WithOpName("final_out"), reduce_max);
+
+ GrapplerItem item;
+ item.fetch = {"final_out"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyOptimizeMaxOrMinOfMonotonic(&optimizer);
+ OptimizeAndPrune(&optimizer, &item, &output);
+ auto tensors = EvaluateNodes(output, item.fetch);
+ EXPECT_EQ(1, tensors.size());
+
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-6);
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+ // Check if the inputs are switched
+ int required_node_count = 0;
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ if (node.name() == "neg") {
+ EXPECT_EQ("Neg", node.op());
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("reduce_max", node.input(0));
+ ++required_node_count;
+ } else if (node.name() == "reduce_max") {
+ EXPECT_EQ("Min", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ ++required_node_count;
+ }
+ }
+ EXPECT_EQ(2, required_node_count);
+}
+
TEST_F(ArithmeticOptimizerTest, UnaryOpsComposition) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
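
The new out_bcast1/out_bcast2 test cases above pin down the broadcast guard added to ConvertPowStage. A minimal standalone sketch (plain C++, not TensorFlow) of why the shape check is needed before rewriting Pow(x, 1) to Identity(x):

#include <cassert>
#include <cmath>
#include <vector>

// Pow with a scalar base and a vector exponent broadcasts the base.
std::vector<float> PowScalarBase(float base, const std::vector<float>& exponents) {
  std::vector<float> out;
  for (float e : exponents) out.push_back(std::pow(base, e));
  return out;
}

int main() {
  const float z = 42.0f;
  const std::vector<float> ones = {1.0f, 1.0f, 1.0f};
  // Pow(z, ones) has three elements, while Identity(z) would stay a scalar,
  // so the Identity (and Const-zero) rewrite is only safe when the input and
  // output shapes already match.
  assert(PowScalarBase(z, ones).size() == 3);
  return 0;
}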
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index f2ac3a44c0..3d0d95bba7 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -32,8 +32,8 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
@@ -136,6 +136,27 @@ bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
return removed_input;
}
+bool GetConcatAxis(const GraphProperties& properties, NodeDef* node,
+ int* axis) {
+ if (node->op() != "ConcatV2" ||
+ properties.GetInputProperties(node->name()).empty()) {
+ return false;
+ }
+ const auto& axis_input = properties.GetInputProperties(node->name()).back();
+ if (!TensorShape::IsValid(axis_input.shape()) || !axis_input.has_value()) {
+ return false;
+ }
+
+ Tensor axis_tensor(axis_input.dtype(), axis_input.shape());
+ if (!axis_tensor.FromProto(axis_input.value())) {
+ return false;
+ }
+ *axis = axis_input.dtype() == DT_INT64
+ ? static_cast<int>(axis_tensor.scalar<int64>()())
+ : axis_tensor.scalar<int32>()();
+ return true;
+}
+
} // namespace
ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
@@ -416,25 +437,6 @@ Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
}
namespace {
-bool ShapesEqual(const TensorShapeProto& shape1,
- const TensorShapeProto& shape2) {
- if (shape1.unknown_rank() || shape2.unknown_rank()) {
- return false;
- }
- if (shape1.dim_size() != shape2.dim_size()) {
- return false;
- }
- for (int i = 0; i < shape1.dim_size(); ++i) {
- if (shape1.dim(i).size() != shape2.dim(i).size()) {
- return false;
- }
- if (shape1.dim(i).size() == -1 || shape2.dim(i).size() == -1) {
- return false;
- }
- }
- return true;
-}
-
bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
BCast::Vec* shape, int64* min_id) {
if (shape_node.op() == "Shape") {
@@ -614,28 +616,37 @@ Status ConstantFolding::MaterializeReductionIndices(
// We can't do anything if we don't know the rank of the input.
return Status::OK();
}
- const int rank = input_prop.shape().dim_size();
- if (rank == 0) {
+ const int input_rank = input_prop.shape().dim_size();
+ if (input_rank < 1) {
// Unexpected graph, don't try to change it.
return Status::OK();
}
+ const OpInfo::TensorProperties& reduction_indices_prop = input_props[1];
+ DataType dtype = reduction_indices_prop.dtype();
+ if (dtype != DT_INT32 && dtype != DT_INT64) {
+ return Status::OK();
+ }
+ PartialTensorShape reduction_indices_shape(reduction_indices_prop.shape());
+ const int num_reduction_indices = reduction_indices_shape.num_elements();
+
const std::vector<OpInfo::TensorProperties>& output_props =
properties.GetOutputProperties(node->name());
if (output_props.size() != 1) {
return Status::OK();
}
- const bool keep_dims =
- node->attr().count("keep_dims") && node->attr().at("keep_dims").b();
const OpInfo::TensorProperties& output_prop = output_props[0];
- PartialTensorShape output_shape(output_prop.shape());
- if (output_shape.num_elements() != 1) {
- bool full_reduction = false;
+ const int output_rank =
+ output_prop.shape().unknown_rank() ? -1 : output_prop.shape().dim_size();
+
+ bool full_reduction = output_rank == 0 || num_reduction_indices == input_rank;
+ if (!full_reduction) {
+ // A full reduction will generate a tensor of one of the shapes
+ // [], [1], [1, 1], [1, 1, ...]. Even if we do not know the number of
+ // elements in the output of the reduction, we may deduce it from reshape
+ // nodes following it.
for (const NodeDef* fanout : node_map_->GetOutputs(node->name())) {
- if (!IsReshape(*fanout) && !keep_dims) {
- // Depending on how it's setup, a full reduction will generate a tensor
- // of shape [], [1], [1, 1], [1, 1, ...]. If keep_dims isn't true, we
- // rely on the existence of a reshape node following the reduction to
- // ensure that the fanout is fed a scalar of the right shape.
+ full_reduction = false;
+ if (!IsReshape(*fanout)) {
return Status::OK();
}
const std::vector<OpInfo::TensorProperties>& reshape_props =
@@ -656,20 +667,15 @@ Status ConstantFolding::MaterializeReductionIndices(
}
}
- const OpInfo::TensorProperties& reduction_prop = input_props[1];
- DataType dtype = reduction_prop.dtype();
- if (dtype != DT_INT32 && dtype != DT_INT64) {
- return Status::OK();
- }
- // We know it's a full reduction. We can generate the set of indices to
- // reduce.
+ // We know it's a full reduction. We can generate the full set of indices to
+ // reduce as a constant node.
string const_name = OptimizedNodeName(*node, "-reduction_indices");
if (node_map_->GetNode(const_name)) {
return Status::OK();
}
NodeDef* reduction_indices = graph_->add_node();
- Tensor value(dtype, TensorShape({rank}));
- for (int i = 0; i < rank; ++i) {
+ Tensor value(dtype, TensorShape({input_rank}));
+ for (int i = 0; i < input_rank; ++i) {
if (dtype == DT_INT32) {
value.vec<int32>()(i) = i;
} else {
@@ -678,6 +684,7 @@ Status ConstantFolding::MaterializeReductionIndices(
}
TF_RETURN_IF_ERROR(
CreateNodeDef(const_name, TensorValue(&value), reduction_indices));
+
reduction_indices->set_device(node->device());
string ctrl_dep =
AddControlDependency(node->input(1), graph_, node_map_.get());
@@ -1719,6 +1726,11 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
return Status::OK();
}
+ if (MergeConcat(*properties, use_shape_info, optimized_graph, node)) {
+ graph_modified_ = true;
+ return Status::OK();
+ }
+
return Status::OK();
}
@@ -2099,7 +2111,8 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
Tensor axis_t(DT_INT32, TensorShape({}));
NodeDef* axis_node = optimized_graph->add_node();
axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
- const int axis = node->attr().at("axis").i();
+ const int axis =
+ node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
!CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
.ok()) {
@@ -2322,7 +2335,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
properties.GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
- const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
+ const bool y_matches_output_shape =
+ ShapesSymbolicallyEqual(output_shape, y_shape);
if (y_matches_output_shape &&
((is_mul && x_is_one) || (is_add && x_is_zero))) {
// 1 * y = y or 0 + y = y.
@@ -2352,7 +2366,8 @@ Status ConstantFolding::SimplifyArithmeticOperations(
properties.GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = y_is_zero ? false : IsOnes(*y);
- const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
+ const bool x_matches_output_shape =
+ ShapesSymbolicallyEqual(output_shape, x_shape);
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero))) {
// x * 1 = x or x / 1 = x or x +/- 0 = x
@@ -2910,6 +2925,55 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
return false;
}
+bool ConstantFolding::MergeConcat(const GraphProperties& properties,
+ bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node) {
+ // We only optimize for ConcatV2.
+ int axis;
+ if (!use_shape_info || !GetConcatAxis(properties, node, &axis) ||
+ nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
+ node_map_->GetOutputs(node->name()).size() != 1) {
+ return false;
+ }
+
+ NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
+ int parent_axis;
+ if (!GetConcatAxis(properties, parent, &parent_axis) || axis != parent_axis) {
+ return false;
+ }
+
+ const int index = NumNonControlInputs(*node) - 1;
+ auto inputs = parent->input();
+ parent->clear_input();
+ for (int i = 0; i < inputs.size(); ++i) {
+ if (IsSameInput(inputs.Get(i), node->name())) {
+ for (int j = 0; j < node->input_size(); ++j) {
+ if (j < index) {
+ // Input tensors (non-axis): add to the parent's input list.
+ parent->add_input(node->input(j));
+ node_map_->RemoveOutput(node->input(j), node->name());
+ node_map_->AddOutput(node->input(j), parent->name());
+ }
+ // Skip j == index, which means axis tensor.
+ if (j > index) {
+ // Control dependencies: push back to inputs so they can be
+ // forwarded to the parent.
+ *inputs.Add() = node->input(j);
+ }
+ }
+ } else {
+ parent->add_input(inputs.Get(i));
+ }
+ }
+ node->clear_input();
+ node->set_op("NoOp");
+ node->clear_attr();
+ node_map_->RemoveNode(node->name());
+ (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
+
+ return true;
+}
+
Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
const GrapplerItem& item,
GraphDef* optimized_graph) {
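
MergeConcat above folds a ConcatV2 into its single ConcatV2 consumer when both use the same axis. A minimal sketch of the underlying associativity, using std::vector as a stand-in for 1-D tensors concatenated on axis 0 (illustrative only, not the grappler implementation):

#include <cassert>
#include <vector>

// Concatenation of 1-D "tensors" as a stand-in for ConcatV2 on axis 0.
std::vector<int> Concat(const std::vector<std::vector<int>>& parts) {
  std::vector<int> out;
  for (const auto& p : parts) out.insert(out.end(), p.begin(), p.end());
  return out;
}

int main() {
  const std::vector<int> a = {1, 2}, b = {3}, c = {4, 5};
  // Concat(Concat(a, b), c) == Concat(a, b, c), which is why a ConcatV2
  // feeding a single ConcatV2 consumer on the same axis can be merged
  // into its parent.
  assert(Concat({Concat({a, b}), c}) == Concat({a, b, c}));
  return 0;
}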
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index b42d5f201e..8593b3e0b8 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -209,6 +209,10 @@ class ConstantFolding : public GraphOptimizer {
// Removes Split or SplitV node if possible.
bool RemoveSplitOrSplitV(const GraphProperties& properties,
GraphDef* optimized_graph, NodeDef* node);
+
+ bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
+ GraphDef* optimized_graph, NodeDef* node);
+
// Points to an externally provided device or to owned_device_;
RewriterConfig::Toggle opt_level_;
DeviceBase* cpu_device_;
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index b9765b9292..fab01edfed 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -2030,6 +2030,130 @@ TEST_F(ConstantFoldingTest, TileWithMultipliesBeingOne) {
CompareGraphs(want, got);
}
+TEST_F(ConstantFoldingTest, MergeConcat) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis", "Const", {}, {}, &want);
+ AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, MergeConcat_SameInput) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {4, 6}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {4, 6}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3, Output(c1)}, axis);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis", "Const", {}, {}, &want);
+ AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "in1", "in2", "axis"}, {},
+ &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, MergeConcat_ConcatWithConst) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 6}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis = ops::Const(scope.WithOpName("axis"), 0, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis", "Const", {}, {}, &want);
+ AddNode("c2", "ConcatV2", {"in1", "in2", "in3", "axis"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
+TEST_F(ConstantFoldingTest, MergeConcat_AxisMismatch) {
+ tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
+
+ Output in1 = ops::Variable(scope.WithOpName("in1"), {2, 5}, DT_FLOAT);
+ Output in2 = ops::Variable(scope.WithOpName("in2"), {}, DT_FLOAT);
+ Output in3 = ops::Variable(scope.WithOpName("in3"), {4, 6}, DT_FLOAT);
+ Output axis1 = ops::Const(scope.WithOpName("axis1"), 0, {});
+ Output axis2 = ops::Const(scope.WithOpName("axis2"), 1, {});
+
+ ops::Concat c1(scope.WithOpName("c1"), {in1, in2}, axis2);
+ ops::Concat c2(scope.WithOpName("c2"), {Output(c1), in3}, axis1);
+
+ GrapplerItem item;
+ item.fetch = {"c2"};
+ TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef got;
+ Status status = optimizer.Optimize(nullptr, item, &got);
+ TF_EXPECT_OK(status);
+
+ GraphDef want;
+ AddNode("in1", "VariableV2", {}, {}, &want);
+ AddNode("in2", "VariableV2", {}, {}, &want);
+ AddNode("in3", "VariableV2", {}, {}, &want);
+ AddNode("axis1", "Const", {}, {}, &want);
+ AddNode("axis2", "Const", {}, {}, &want);
+ AddNode("c1", "ConcatV2", {"in1", "in2", "axis2"}, {}, &want);
+ AddNode("c2", "ConcatV2", {"c1", "in3", "axis1"}, {}, &want);
+
+ CompareGraphs(want, got);
+}
+
TEST_F(ConstantFoldingTest, PaddingWithZeroSize) {
tensorflow::Scope scope = tensorflow::Scope::NewRootScope();
@@ -2467,58 +2591,100 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
}
TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output input =
- ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
- ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
- Output indices = ops::Placeholder(s.WithOpName("indices"), DT_INT32);
- Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
- Output size = ops::Const(s.WithOpName("size"), 1, {1});
- Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
+ for (bool use_reshape : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input =
+ ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape(PartialTensorShape({-1, -1})));
+ // If use_reshape is false, we need to know the number of indices to apply
+ // the rewrite.
+ Output indices = ops::Placeholder(
+ s.WithOpName("indices"), DT_INT32,
+ ops::Placeholder::Shape(PartialTensorShape({use_reshape ? -1 : 2})));
+ Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
+ if (use_reshape) {
+ Output size = ops::Const(s.WithOpName("size"), 1, {1});
+ Output reshape = ops::Reshape(s.WithOpName("reshape"), sum, size);
+ }
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
- item.fetch.push_back("reshape");
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back(use_reshape ? "reshape" : "sum");
- auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
- Tensor indices_t(DT_INT32, TensorShape({2}));
- indices_t.flat<int>()(0) = 0;
- indices_t.flat<int>()(1) = 1;
- auto tensors_expected = EvaluateNodes(
- item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
- EXPECT_EQ(1, tensors_expected.size());
+ auto input_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
+ Tensor indices_t(DT_INT32, TensorShape({2}));
+ indices_t.flat<int>()(0) = 0;
+ indices_t.flat<int>()(1) = 1;
+ auto tensors_expected = EvaluateNodes(
+ item.graph, item.fetch, {{"input", input_t}, {"indices", indices_t}});
+ EXPECT_EQ(1, tensors_expected.size());
- ConstantFolding optimizer(nullptr /* cpu_device */);
- GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ // Use aggressive mode to force the shape inference to propagate placeholder
+ // shapes.
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
- // Run a second time to make sure the optimization is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
+ // Run a second time to make sure the optimization is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
- int found = 0;
- for (const auto& node : output.node()) {
- if (node.name() == "ConstantFolding/sum-reduction_indices") {
- ++found;
- EXPECT_EQ("Const", node.op());
- EXPECT_EQ("^indices", node.input(0));
- EXPECT_EQ(2, TensorShape(node.attr().at("value").tensor().tensor_shape())
- .num_elements());
- } else if (node.name() == "sum") {
- ++found;
- EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
- } else if (node.name() == "indices") {
- ++found;
+ int found = 0;
+ for (const auto& node : output.node()) {
+ if (node.name() == "ConstantFolding/sum-reduction_indices") {
+ ++found;
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^indices", node.input(0));
+ EXPECT_EQ(2,
+ TensorShape(node.attr().at("value").tensor().tensor_shape())
+ .num_elements());
+ } else if (node.name() == "sum") {
+ ++found;
+ EXPECT_EQ("ConstantFolding/sum-reduction_indices", node.input(1));
+ } else if (node.name() == "indices") {
+ ++found;
+ }
}
+ EXPECT_EQ(3, found);
+
+ auto tensors = EvaluateNodes(output, item.fetch,
+ {{"input", input_t}, {"indices", indices_t}});
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
}
- EXPECT_EQ(3, found);
+}
- auto tensors = EvaluateNodes(output, item.fetch,
- {{"input", input_t}, {"indices", indices_t}});
- EXPECT_EQ(1, tensors.size());
- test::ExpectTensorNear<float>(tensors_expected[0], tensors[0], 1e-5);
+TEST_F(ConstantFoldingTest, MaterializeReductionIndices_NotFullReduction) {
+ for (bool input_rank_known : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input =
+ (input_rank_known ? ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape(
+ PartialTensorShape({-1, -1})))
+ : ops::Placeholder(s.WithOpName("input"), DT_FLOAT));
+ Output indices =
+ ops::Placeholder(s.WithOpName("indices"), DT_INT32,
+ ops::Placeholder::Shape(
+ PartialTensorShape({input_rank_known ? 1 : 2})));
+ Output sum = ops::Sum(s.WithOpName("sum"), input, indices);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ item.fetch.push_back("sum");
+
+ // Use aggressive mode to force the shape inference to propagate placeholder
+ // shapes.
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ CompareGraphs(item.graph, output);
+ }
}
TEST_F(ConstantFoldingTest, LargeConstant) {
@@ -2891,37 +3057,48 @@ TEST_F(ConstantFoldingTest, TrivialPack) {
auto stack =
ops::Stack(scope.WithOpName("stack").WithControlDependencies({y}), {x},
ops::Stack::Axis(1));
+ auto stack_no_axis = ops::Stack(scope.WithOpName("stack_no_axis"), {x});
GrapplerItem item;
TF_CHECK_OK(scope.ToGraphDef(&item.graph));
- item.fetch.push_back("stack");
+ item.fetch = {"stack", "stack_no_axis"};
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
- EXPECT_EQ(5, output.node_size());
+ EXPECT_EQ(7, output.node_size());
+ int found = 0;
for (const auto& node : output.node()) {
if (node.name() == "stack") {
- EXPECT_EQ("stack", node.name());
EXPECT_EQ("ExpandDims", node.op());
EXPECT_EQ(3, node.input_size());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("ConstantFolding/stack_const_axis", node.input(1));
EXPECT_EQ("^y", node.input(2));
+ ++found;
+ } else if (node.name() == "stack_no_axis") {
+ EXPECT_EQ("ExpandDims", node.op());
+ EXPECT_EQ(2, node.input_size());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("ConstantFolding/stack_no_axis_const_axis", node.input(1));
+ ++found;
} else if (node.name() == "ConstantFolding/stack_const_axis") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ(1, node.input_size());
EXPECT_EQ("^x", node.input(0));
+ ++found;
}
}
+ EXPECT_EQ(found, 3);
- std::vector<string> fetch = {"stack"};
+ std::vector<string> fetch = {"stack", "stack_no_axis"};
auto tensors_expected = EvaluateNodes(item.graph, fetch);
auto tensors = EvaluateNodes(output, fetch);
- EXPECT_EQ(1, tensors_expected.size());
- EXPECT_EQ(1, tensors.size());
+ EXPECT_EQ(2, tensors_expected.size());
+ EXPECT_EQ(2, tensors.size());
EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
+ EXPECT_EQ(tensors_expected[1].shape(), tensors[1].shape());
}
// The test does not evaluate the optimized and original graphs to check if their
@@ -3047,6 +3224,39 @@ TEST_F(ConstantFoldingTest, TensorArraySize) {
test::ExpectTensorEqual<int32>(tensors_expected[1], tensors_actual[1]);
}
+TEST_F(ConstantFoldingTest, FoldingPreservesDenormalFlushing) {
+ // Multiplying min() with 0.1 gives a denormal without FTZ and zero with FTZ.
+ // Make sure constant folding behaves the same way as TensorFlow.
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ Output a =
+ ops::Const(s.WithOpName("a"), std::numeric_limits<float>::min(), {1});
+ Output b = ops::Const(s.WithOpName("b"), 0.1f, {1});
+ Output c = ops::Mul(s.WithOpName("c"), a, b);
+
+ GrapplerItem item;
+ item.fetch.push_back("c");
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(1, output.node_size());
+
+ const NodeDef& node_d = output.node(0);
+ EXPECT_EQ("c", node_d.name());
+ EXPECT_EQ("Const", node_d.op());
+
+ std::vector<string> fetch = {"c"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch);
+ auto tensors = EvaluateNodes(output, fetch);
+ EXPECT_EQ(1, tensors_expected.size());
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
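
The FoldingPreservesDenormalFlushing test above exists because the folded constant must match what the kernel would compute at runtime. A standalone sketch of the value in question (illustrative; the observed result depends on the CPU's flush-to-zero setting):

#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  // Multiplying the smallest normal float by 0.1 produces a subnormal value
  // unless the CPU flushes denormals to zero (FTZ); constant folding has to
  // reproduce whatever the executing kernel would do.
  const float x = std::numeric_limits<float>::min() * 0.1f;
  std::printf("result = %g, subnormal = %d\n",
              static_cast<double>(x), std::fpclassify(x) == FP_SUBNORMAL);
  return 0;
}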
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index b8e69787e3..ee7c14e3ab 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -4,36 +4,42 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
cc_library(
- name = "function_rename",
- srcs = ["function_rename.cc"],
+ name = "filter_fusion",
+ srcs = ["filter_fusion.cc"],
hdrs = [
- "function_rename.h",
+ "filter_fusion.h",
],
visibility = ["//visibility:public"],
deps = [
":graph_utils",
+ ":fusion_utils",
+ "//tensorflow/core/grappler:mutable_graph_view",
"//tensorflow/core:lib",
- "//tensorflow/core/grappler:graph_view",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
tf_cc_test(
- name = "function_rename_test",
- srcs = ["function_rename_test.cc"],
+ name = "filter_fusion_test",
+ srcs = ["filter_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
- ":function_rename",
+ ":filter_fusion",
+ ":graph_test_utils",
+ ":graph_utils",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
- ] + tf_protos_all(),
+ ],
)
cc_library(
@@ -45,12 +51,15 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":graph_utils",
+ ":function_utils",
"//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
- "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/core/kernels:functional_ops",
+ "//tensorflow/core/kernels:control_flow_ops",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
"//tensorflow/core:lib_internal",
] + tf_protos_all(),
@@ -61,6 +70,7 @@ tf_cc_test(
srcs = ["fusion_utils_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":function_utils",
":fusion_utils",
":graph_utils",
"//tensorflow/core:framework",
@@ -72,6 +82,41 @@ tf_cc_test(
)
cc_library(
+ name = "function_utils",
+ srcs = ["function_utils.cc"],
+ hdrs = [
+ "function_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core:lib_internal",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "function_utils_test",
+ srcs = ["function_utils_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ],
+)
+
+cc_library(
name = "graph_utils",
srcs = ["graph_utils.cc"],
hdrs = [
@@ -79,11 +124,12 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
- "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -93,6 +139,7 @@ tf_cc_test(
visibility = ["//visibility:public"],
deps = [
":graph_utils",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
@@ -100,11 +147,66 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
- "//tensorflow/core/kernels:cast_op",
],
)
cc_library(
+ name = "graph_test_utils",
+ testonly = 1,
+ srcs = ["graph_test_utils.cc"],
+ hdrs = [
+ "graph_test_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core:testlib",
+ ] + tf_protos_all(),
+)
+
+cc_library(
+ name = "hoist_random_uniform",
+ srcs = ["hoist_random_uniform.cc"],
+ hdrs = [
+ "hoist_random_uniform.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":graph_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "hoist_random_uniform_test",
+ srcs = ["hoist_random_uniform_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_test_utils",
+ ":graph_utils",
+ ":hoist_random_uniform",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ ] + tf_protos_all(),
+)
+
+cc_library(
name = "latency_all_edges",
srcs = ["latency_all_edges.cc"],
hdrs = [
@@ -125,6 +227,44 @@ cc_library(
)
cc_library(
+ name = "map_vectorization",
+ srcs = ["map_vectorization.cc"],
+ hdrs = [
+ "map_vectorization.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":graph_utils",
+ ":vectorization_utils",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "map_vectorization_test",
+ srcs = ["map_vectorization_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ ":map_vectorization",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ ],
+)
+
+cc_library(
name = "map_and_batch_fusion",
srcs = ["map_and_batch_fusion.cc"],
hdrs = [
@@ -177,7 +317,7 @@ cc_library(
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
- "//tensorflow/core:ptr_util",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -186,6 +326,7 @@ tf_cc_test(
srcs = ["map_and_filter_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_and_filter_fusion",
"//tensorflow/core:framework",
@@ -212,10 +353,10 @@ cc_library(
"//tensorflow/core/grappler:op_types",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
- "//tensorflow/core/kernels:cast_op",
"//tensorflow/core/grappler/utils:topological_sort",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ "//tensorflow/core:lib_internal",
] + tf_protos_all(),
)
@@ -224,6 +365,7 @@ tf_cc_test(
srcs = ["map_fusion_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_test_utils",
":graph_utils",
":map_fusion",
"//tensorflow/core:framework",
@@ -231,6 +373,44 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/kernels:control_flow_ops",
+ ],
+)
+
+cc_library(
+ name = "map_parallelization",
+ srcs = ["map_parallelization.cc"],
+ hdrs = [
+ "map_parallelization.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:op_types",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/clusters:cluster",
+ "//tensorflow/core/grappler/utils:topological_sort",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "map_parallelization_test",
+ srcs = ["map_parallelization_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_test_utils",
+ ":graph_utils",
+ ":map_parallelization",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
],
)
@@ -306,11 +486,14 @@ cc_library(
name = "data",
visibility = ["//visibility:public"],
deps = [
- ":function_rename",
+ ":filter_fusion",
+ ":hoist_random_uniform",
":latency_all_edges",
":map_and_batch_fusion",
":map_and_filter_fusion",
":map_fusion",
+ ":map_parallelization",
+ ":map_vectorization",
":noop_elimination",
":shuffle_and_repeat_fusion",
],
@@ -330,3 +513,49 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
],
)
+
+cc_library(
+ name = "vectorization_utils",
+ srcs = ["vectorization_utils.cc"],
+ hdrs = [
+ "vectorization_utils.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":function_utils",
+ ":graph_utils",
+ "//tensorflow/cc:ops",
+ "@com_google_absl//absl/strings",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/grappler:mutable_graph_view",
+ "//tensorflow/core/grappler:utils",
+ "//tensorflow/core/grappler/optimizers/data/vectorization",
+ "//tensorflow/core/grappler/utils:functions",
+ ] + tf_protos_all(),
+)
+
+tf_cc_test(
+ name = "vectorization_utils_test",
+ srcs = ["vectorization_utils_test.cc"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":graph_utils",
+ ":function_utils",
+ ":vectorization_utils",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
+ "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ # For ops that the test needs registered.
+ "//tensorflow/core/kernels/data:dataset_ops",
+ "//tensorflow/core/kernels:cast_op",
+ "//tensorflow/core/kernels:logging_ops",
+ "//tensorflow/tools/graph_transforms:transform_utils",
+ ] + tf_protos_all(),
+)
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
new file mode 100644
index 0000000000..1ad495bbad
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.cc
@@ -0,0 +1,136 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/filter_fusion.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
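+// Builds the replacement FilterDataset node: it reads from the first filter's
+// input, uses `fused_function` as its predicate, and copies the remaining
+// attrs from the original filter nodes.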
+NodeDef MakeFusedFilterNode(const NodeDef& first_filter_node,
+ const NodeDef& second_filter_node,
+ const FunctionDef& fused_function,
+ MutableGraphView* graph) {
+ NodeDef fused_node;
+ graph_utils::SetUniqueGraphNodeName("fused_filter", graph->GetGraph(),
+ &fused_node);
+
+ fused_node.set_op("FilterDataset");
+ fused_node.add_input(first_filter_node.input(0));
+
+ auto attr = first_filter_node.attr().at("predicate");
+ *attr.mutable_func()->mutable_name() = fused_function.signature().name();
+ (*fused_node.mutable_attr())["predicate"] = std::move(attr);
+
+ graph_utils::CopyAttribute("Targuments", first_filter_node, &fused_node);
+
+ for (auto key : {"output_shapes", "output_types"})
+ graph_utils::CopyAttribute(key, second_filter_node, &fused_node);
+
+ return fused_node;
+}
+
+} // namespace
+
+Status FilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ GraphDef sorted_old_graph = item.graph;
+ TF_RETURN_IF_ERROR(TopologicalSort(&sorted_old_graph));
+ *output = sorted_old_graph;
+
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ output->library());
+
+ auto get_filter_node = [](const NodeDef& node) -> const NodeDef* {
+ if (node.op() == "FilterDataset") return &node;
+ return nullptr;
+ };
+
+ auto get_fused_predicate =
+ [&](const NodeDef* first_filter_node,
+ const NodeDef* second_filter_node) -> FunctionDef* {
+ const auto& parent_fun = first_filter_node->attr().at("predicate");
+ const FunctionDef* first_func =
+ function_library.Find(parent_fun.func().name());
+ const auto& fun = second_filter_node->attr().at("predicate");
+ const FunctionDef* second_func = function_library.Find(fun.func().name());
+
+ if (!fusion_utils::HasSameSignature(first_func->signature(),
+ second_func->signature())) {
+ VLOG(1) << "Can't fuse Filters because they have different signature\n";
+ return nullptr;
+ }
+
+ return fusion_utils::FuseFunctions(
+ *first_func, *second_func, "fused_predicate",
+ fusion_utils::SameSignature, fusion_utils::SameInput,
+ fusion_utils::LazyConjunctionOutput, fusion_utils::LazyConjunctionNodes,
+ output->mutable_library());
+ };
+
+ for (const NodeDef& node : sorted_old_graph.node()) {
+ const NodeDef* second_filter_node = get_filter_node(node);
+ if (!second_filter_node) continue;
+
+ const NodeDef* first_filter_node =
+ get_filter_node(*graph_utils::GetInputNode(*second_filter_node, graph));
+ if (!first_filter_node) continue;
+
+ const auto* fused_predicate =
+ get_fused_predicate(first_filter_node, second_filter_node);
+ if (!fused_predicate) continue;
+ const auto* fused_filter_node = graph.AddNode(MakeFusedFilterNode(
+ *first_filter_node, *second_filter_node, *fused_predicate, &graph));
+
+ graph.ReplaceInput(*second_filter_node, *fused_filter_node);
+
+ // TODO(prazek): we should run some optimizations on the fused filter
+ // functions, or make sure that optimization passes run after filter
+ // fusion.
+ TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_predicate));
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
+ nodes_to_delete.insert(first_filter_node->name());
+ nodes_to_delete.insert(second_filter_node->name());
+ }
+
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void FilterFusion::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(FilterFusion, "filter_fusion");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion.h b/tensorflow/core/grappler/optimizers/data/filter_fusion.h
new file mode 100644
index 0000000000..91a0364a46
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimization fuses filter transformations.
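+// For example, two chained FilterDataset ops, filter(p1).filter(p2), are
+// rewritten into a single FilterDataset whose fused predicate evaluates p2
+// only when p1 returns true (see fusion_utils::LazyConjunctionNodes).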
+class FilterFusion : public CustomGraphOptimizer {
+ public:
+ FilterFusion() = default;
+ ~FilterFusion() override = default;
+
+ string name() const override { return "filter_fusion"; }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FILTER_FUSION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
new file mode 100644
index 0000000000..c8becc5cc0
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/filter_fusion_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/filter_fusion.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+using graph_tests_utils::MakeFilterNode;
+
+TEST(FilterFusionTest, FuseTwoFilterIntoOne) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeFilterNode("filter1", "range"),
+ MakeFilterNode("filter2", "filter1")},
+ // FunctionLib
+ {
+ test::function::IsZero(),
+ });
+
+ FilterFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("FilterDataset", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter2", output));
+}
+
+TEST(FilterFusionTest, FuseThreeNodesIntoOne) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeFilterNode("filter1", "range"), MakeFilterNode("filter2", "filter1"),
+ MakeFilterNode("filter3", "filter2"),
+ NDef("cache", "CacheDataset", {"filter3", "filename"}, {})},
+ // FunctionLib
+ {
+ test::function::IsZero(),
+ });
+
+ FilterFusion optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("FilterDataset", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter2", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("filter3", output));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename.cc b/tensorflow/core/grappler/optimizers/data/function_rename.cc
deleted file mode 100644
index 8cf044d1bd..0000000000
--- a/tensorflow/core/grappler/optimizers/data/function_rename.cc
+++ /dev/null
@@ -1,51 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/function_rename.h"
-
-#include "tensorflow/core/grappler/clusters/cluster.h"
-#include "tensorflow/core/grappler/graph_view.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/op_types.h"
-#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
-#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
-#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/platform/protobuf.h"
-
-namespace tensorflow {
-namespace grappler {
-
-Status FunctionRename::Optimize(Cluster* cluster, const GrapplerItem& item,
- GraphDef* output) {
- *output = item.graph;
- GraphView graph(output);
- int n = output->mutable_library()->function_size();
- for (int i = 0; i < n; ++i) {
- FunctionDef* fn = output->mutable_library()->mutable_function(i);
- fn->mutable_signature()->set_name(fn->signature().name() + "world");
- }
-
- return Status::OK();
-}
-
-void FunctionRename::Feedback(Cluster* cluster, const GrapplerItem& item,
- const GraphDef& optimize_output, double result) {
- // no-op
-}
-
-REGISTER_GRAPH_OPTIMIZER_AS(FunctionRename, "_test_only_function_rename");
-
-} // end namespace grappler
-} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename_test.cc b/tensorflow/core/grappler/optimizers/data/function_rename_test.cc
deleted file mode 100644
index 56b8a960a7..0000000000
--- a/tensorflow/core/grappler/optimizers/data/function_rename_test.cc
+++ /dev/null
@@ -1,42 +0,0 @@
-/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/grappler/optimizers/data/function_rename.h"
-
-#include "tensorflow/core/framework/function.pb.h"
-#include "tensorflow/core/framework/op_def.pb.h"
-#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/platform/test.h"
-
-namespace tensorflow {
-namespace grappler {
-namespace {
-
-TEST(FunctionRenameTest, RenameFunction) {
- GrapplerItem item;
- GraphDef *graph = &item.graph;
- FunctionDef *fn = graph->mutable_library()->add_function();
- fn->mutable_signature()->set_name("hello");
-
- FunctionRename optimizer;
- GraphDef output;
- TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
- EXPECT_EQ(output.library().function(0).signature().name(), "helloworld");
-}
-
-} // namespace
-} // namespace grappler
-} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.cc b/tensorflow/core/grappler/optimizers/data/function_utils.cc
new file mode 100644
index 0000000000..311df15bc2
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.cc
@@ -0,0 +1,176 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/strings/scanner.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& node_name,
+ const string& output, int position)
+ : node_name(node_name), node_output(output), position(position) {
+ full_str = strings::StrCat(node_name, ":", node_output, ":", position);
+}
+
+FunctionDefTensorDesc::FunctionDefTensorDesc(const string& input) {
+ // Parses node_name:node_output:position string into its components.
+ full_str = input;
+ StringPiece capture;
+ StringPiece remaining;
+
+ // Parse "node_name"
+ if (strings::Scanner(input)
+ .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
+ .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE)
+ .GetResult(&remaining, &capture)) {
+ node_name = string(capture.data(), capture.size());
+ }
+
+ // Parse "node_output" if it exists
+ if (strings::Scanner(remaining)
+ .OneLiteral(":")
+ .RestartCapture()
+ .One(strings::Scanner::LETTER)
+ .Any(strings::Scanner::LETTER_DIGIT_UNDERSCORE)
+ .GetResult(&remaining, &capture)) {
+ node_output = string(capture.data(), capture.size());
+ }
+
+ // Parse "position" if it exists
+ if (strings::Scanner(remaining)
+ .OneLiteral(":")
+ .RestartCapture()
+ .Many(strings::Scanner::DIGIT)
+ .GetResult(nullptr, &capture)) {
+ CHECK(strings::safe_strto32(capture, &position));
+ }
+}
+
+// TODO(rachelim): Create a utility class similar to MutableGraphView for
+// FunctionDefs, and use that to manipulate functions. It would be more
+// performant if we kept mappings of nodes to their inputs/outputs, so that we
+// don't have to search over all nodes each time.
+// Note that we're not using GrapplerFunctionItem because it doesn't cover
+// some of our desired uses (e.g. changing the outputs of a function), and the
+// FunctionDef -> GraphDef conversion isn't really necessary in this case.
+void ReplaceReferences(const string& from, const string& to,
+ FunctionDef* func) {
+ for (NodeDef& n : *func->mutable_node_def()) {
+ std::replace(n.mutable_input()->begin(), n.mutable_input()->end(), from,
+ to);
+ }
+
+ for (auto& p : *func->mutable_ret()) {
+ if (p.second == from) {
+ p.second = to;
+ }
+ }
+}
+
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+ StringPiece output_tensor_name,
+ FunctionDef* function, DataType dt) {
+ string name = string(prefix);
+ int id = function->signature().output_arg_size();
+ while (ContainsFunctionOutputWithName(name, *function)) {
+ name = strings::StrCat(prefix, "/_", id);
+ ++id;
+ }
+ auto* output = function->mutable_signature()->mutable_output_arg()->Add();
+ output->set_name(name);
+ output->set_type(dt);
+
+ (*function->mutable_ret())[name] = string(output_tensor_name);
+}
+
+NodeDef* AddNode(StringPiece name, StringPiece op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ FunctionDef* fd) {
+ NodeDef* node = fd->add_node_def();
+ if (!name.empty()) {
+ node->set_name(string(name));
+ } else {
+ SetUniqueFunctionNodeName(op, fd, node);
+ }
+ node->set_op(string(op));
+ for (const string& input : inputs) {
+ node->add_input(input);
+ }
+ for (auto attr : attributes) {
+ (*node->mutable_attr())[attr.first] = attr.second;
+ }
+ return node;
+}
+
+bool ContainsFunctionNodeWithName(StringPiece name,
+ const FunctionDef& function) {
+ return FindFunctionNodeWithName(name, function) != -1;
+}
+
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+ return FindFunctionNodeWithOp(op, function) != -1;
+}
+
+bool ContainsFunctionOutputWithName(StringPiece name,
+ const FunctionDef& function) {
+ return FindFunctionOutputWithName(name, function) != -1;
+}
+
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function) {
+ return graph_utils::GetFirstElementIndexWithPredicate(
+ [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
+ function.signature().input_arg());
+}
+
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function) {
+ return graph_utils::GetFirstElementIndexWithPredicate(
+ [&name](const OpDef_ArgDef& arg) { return arg.name() == name; },
+ function.signature().output_arg());
+}
+
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
+ return graph_utils::GetFirstElementIndexWithPredicate(
+ [&name](const NodeDef& node) { return node.name() == name; },
+ function.node_def());
+}
+
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
+ return graph_utils::GetFirstElementIndexWithPredicate(
+ [&op](const NodeDef& node) { return node.op() == op; },
+ function.node_def());
+}
+
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+ NodeDef* node) {
+ string name = string(prefix);
+ int id = function->node_def_size();
+ while (ContainsFunctionNodeWithName(name, *function)) {
+ name = strings::StrCat(prefix, "/_", id);
+ ++id;
+ }
+ node->set_name(std::move(name));
+}
+
+} // end namespace function_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils.h b/tensorflow/core/grappler/optimizers/data/function_utils.h
new file mode 100644
index 0000000000..d4ce824652
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils.h
@@ -0,0 +1,108 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+// This namespace contains utility functions for querying and modifying
+// FunctionDefs.
+
+// Describes a FunctionDef input tensor. In FunctionDefs, input tensor strings
+// have the format node_name:node_output:position (if they derive from nodes),
+// or input_name (if they derive from an argument).
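+// For example, "Cast:y:0" parses to node_name = "Cast", node_output = "y" and
+// position = 0, while a plain argument name such as "Arg0" parses to
+// node_name = "Arg0" with an empty node_output and position = -1.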
+struct FunctionDefTensorDesc {
+ FunctionDefTensorDesc() = default;
+
+ FunctionDefTensorDesc(const string& node_name, const string& output,
+ int position);
+
+ // Parses node_name:node_output:position string into its components.
+ explicit FunctionDefTensorDesc(const string& input);
+
+ // TODO(rachelim): Add provisions to deal with special formats, like how
+ // GrapplerFunctionItem expands node output range if position is not defined
+ string full_str;
+ string node_name;
+ string node_output;
+ int position = -1;
+};
+
+// Replaces all references to the `from` tensor in func's node inputs and
+// retvals with the `to` tensor. This is similar to
+// `MutableGraphView::ReplaceInputs`.
+void ReplaceReferences(const string& from, const string& to, FunctionDef* func);
+
+// Adds a function output to the function def, ensuring that the output key
+// is unique, and maps to output_tensor_name in the ret dict.
+void AddFunctionOutputWithUniqueName(StringPiece prefix,
+ StringPiece output_tensor_name,
+ FunctionDef* function, DataType dt);
+
+// Adds a node to a FunctionDef.
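+// If `name` is empty, a unique name derived from `op` is generated instead.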
+NodeDef* AddNode(StringPiece name, StringPiece op,
+ const std::vector<string>& inputs,
+ const std::vector<std::pair<string, AttrValue>>& attributes,
+ FunctionDef* fd);
+
+// Checks whether the function contains a node with the given name.
+bool ContainsFunctionNodeWithName(StringPiece name,
+ const FunctionDef& function);
+
+// Checks whether the function contains a node with the given op.
+bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Checks whether the function contains an output with the given name.
+bool ContainsFunctionOutputWithName(StringPiece name,
+ const FunctionDef& function);
+
+// Returns the index of the function input with the given name or -1 if the
+// function input does not exist.
+int FindFunctionInputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function output with the given name or -1 if the
+// function output does not exist.
+int FindFunctionOutputWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given name or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
+
+// Returns the index of the function node with the given op or -1 if the
+// function node does not exist.
+int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
+
+// Sets the function node name using `prefix` as a prefix while guaranteeing
+// that the name is unique across the function's nodes.
+void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
+ NodeDef* node);
+
+} // end namespace function_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/function_utils_test.cc b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
new file mode 100644
index 0000000000..3739e20eb1
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/function_utils_test.cc
@@ -0,0 +1,164 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace function_utils {
+namespace {
+
+TEST(FunctionDefTensorDesc, Parsing) {
+ FunctionDefTensorDesc f("Cast:y:0");
+ EXPECT_EQ(f.full_str, "Cast:y:0");
+ EXPECT_EQ(f.node_name, "Cast");
+ EXPECT_EQ(f.node_output, "y");
+ EXPECT_EQ(f.position, 0);
+
+ FunctionDefTensorDesc f2("Arg0");
+ EXPECT_EQ(f2.full_str, "Arg0");
+ EXPECT_EQ(f2.node_name, "Arg0");
+ EXPECT_EQ(f2.node_output, "");
+ EXPECT_EQ(f2.position, -1);
+}
+
+TEST(ReplaceReferencesTest, ReplaceReferencesTest) {
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer", {"arg0: int32"}, {"out: int32", "out2: int64"}, {}, {},
+ {{"out", "MapDefun:output:0"}, {"out2", "Cast:y:0"}});
+ NodeDef* derive_node =
+ AddNode("X", "Some_Op", {"MapDefun:output:0"}, {}, &outer);
+ // Check that both the input to "X" and retval of "outer" are replaced.
+ ReplaceReferences("MapDefun:output:0", "arg0", &outer);
+ EXPECT_EQ(outer.ret().at("out"), "arg0");
+ EXPECT_EQ(derive_node->input(0), "arg0");
+}
+
+TEST(FunctionUtilsTest, AddFunctionOutputWithUniqueName) {
+ FunctionDef function = test::function::XTimesTwo();
+ AddFunctionOutputWithUniqueName("y", "two", &function, DT_INT64);
+ EXPECT_TRUE(ContainsFunctionOutputWithName("y/_1", function));
+ EXPECT_EQ(function.ret().at("y/_1"), "two");
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_FALSE(ContainsFunctionNodeWithName(
+ "weird_name_that_should_not_be_there", function));
+ EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionNodeWithOp) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
+ function));
+ EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
+}
+
+TEST(FunctionUtilsTest, ContainsFunctionOutputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_TRUE(ContainsFunctionOutputWithName("y", function));
+ EXPECT_FALSE(ContainsFunctionOutputWithName("Add:z:0", function));
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(
+ FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
+ -1);
+ EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionNodeWithOp) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(
+ FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function),
+ -1);
+ EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionInputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(FindFunctionInputWithName("x", function), 0);
+ EXPECT_EQ(FindFunctionInputWithName("not_a_name", function), -1);
+}
+
+TEST(FunctionUtilsTest, FindFunctionOutputWithName) {
+ FunctionDef function = test::function::XTimesTwo();
+ EXPECT_EQ(FindFunctionOutputWithName("y", function), 0);
+ EXPECT_EQ(FindFunctionOutputWithName("Add:z:0", function), -1);
+}
+
+TEST(FunctionUtilsTest, SetUniqueFunctionNodeName) {
+ FunctionDef function = test::function::XTimesTwo();
+ NodeDef node;
+ SetUniqueFunctionNodeName("abc", &function, &node);
+ for (const NodeDef& function_node : function.node_def()) {
+ EXPECT_NE(node.name(), function_node.name());
+ }
+ auto* new_node = function.add_node_def();
+ *new_node = node;
+
+ NodeDef other;
+ SetUniqueFunctionNodeName("abc", &function, &other);
+ EXPECT_NE(other.name(), new_node->name());
+}
+
+TEST(FunctionUtilsTest, AddNodeToFunctionDef) {
+ FunctionDef func;
+ const char* op_name = "xxx";
+ AddNode(op_name, op_name, {}, {}, &func);
+
+ const NodeDef& node1 = func.node_def(FindFunctionNodeWithName("xxx", func));
+ EXPECT_EQ(node1.op(), op_name);
+ EXPECT_EQ(node1.input_size(), 0);
+ EXPECT_EQ(node1.attr_size(), 0);
+
+ const std::vector<string> inputs({"input1", "input2"});
+ AddNode("", op_name, inputs, {}, &func);
+ const NodeDef& node2 =
+ func.node_def(FindFunctionNodeWithName("xxx/_2", func));
+ EXPECT_EQ(node2.op(), op_name);
+ EXPECT_EQ(node2.attr_size(), 0);
+ EXPECT_EQ(node2.input_size(), inputs.size());
+ for (size_t i = 0; i < inputs.size(); ++i) {
+ EXPECT_EQ(node2.input(i), inputs[i]);
+ }
+
+ AttrValue a1, a2;
+ a1.set_type(DT_INT32);
+ a2.set_type(DT_INT64);
+ const std::vector<std::pair<string, AttrValue>> attrs(
+ {{"attr1", a1}, {"attr2", a2}});
+ AddNode("", op_name, {}, attrs, &func);
+ const NodeDef& node3 =
+ func.node_def(FindFunctionNodeWithName("xxx/_3", func));
+ EXPECT_EQ(node3.op(), op_name);
+ EXPECT_EQ(node3.input_size(), 0);
+ EXPECT_EQ(node3.attr_size(), attrs.size());
+ for (size_t i = 0; i < attrs.size(); ++i) {
+ EXPECT_EQ(attrs[i].second.type(), node3.attr().at(attrs[i].first).type());
+ }
+}
+
+} // namespace
+} // namespace function_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
index f84f109af6..b3bfee138f 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.cc
@@ -16,12 +16,13 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/fusion_utils.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/op_def.pb.h"
-
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
@@ -52,6 +53,12 @@ string GetOutputNode(const FunctionDef& function, int output_idx) {
return function.ret().at(ret_output_name);
}
+string& GetMutableOutputNode(FunctionDef* function, int output_idx) {
+ const auto& ret_output_name =
+ function->signature().output_arg(output_idx).name();
+ return function->mutable_ret()->at(ret_output_name);
+}
+
template <typename Iterable>
StringCollection GetNames(const Iterable& iterable, int allocate_size) {
StringCollection names;
@@ -106,7 +113,6 @@ gtl::FlatMap<string, string> GetUniqueNames(const Iterable& first_iterable,
// Nodes that will be added to the function can have the same name as the nodes
// from parent function.
void RenameFunctionNodes(const FunctionDef& first_function,
- FunctionDef* fused_function,
protobuf::RepeatedPtrField<NodeDef>* nodes_to_fuse,
protobuf::Map<string, string>* rets_to_fuse) {
const gtl::FlatMap<string, string> changed_node_names =
@@ -149,6 +155,7 @@ OpDef GetUniqueSignature(const OpDef& first_signature,
const gtl::FlatMap<string, string> changed_input_names =
GetUniqueNames(first_signature.input_arg(), second_signature.input_arg());
OpDef signature;
+ signature.set_name(second_signature.name());
for (const auto& input_arg : second_signature.input_arg()) {
auto& input = *signature.add_input_arg();
@@ -221,12 +228,13 @@ void FuseFunctionNodes(const StringCollection& first_inputs,
}
// This function looks for direct edges from input to return and rewrites
-// them to the coresponding input of the return of `first_function`.
+// them to the corresponding input of the return of `first_function`.
void FuseReturns(const StringCollection& first_inputs,
const StringCollection& second_inputs,
const StringCollection& first_outputs,
- const SetInputFn& set_input, FunctionDef* fused_function) {
- for (auto& ret : *fused_function->mutable_ret()) {
+ const SetInputFn& set_input,
+ protobuf::Map<string, string>* fused_ret) {
+ for (auto& ret : *fused_ret) {
auto return_input = ParseNodeConnection(ret.second);
auto input_it =
std::find(second_inputs.begin(), second_inputs.end(), return_input);
@@ -249,6 +257,33 @@ StringCollection GetFunctionOutputs(const FunctionDef& function) {
return outputs;
}
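+// Builds a predicate that accepts inputs matching `fake_args`, ignores them,
+// and always returns false. It is used as the else branch of the
+// lazy-conjunction `If` node built in LazyConjunctionNodes.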
+FunctionDef* CreateFalsePredicate(
+ const protobuf::RepeatedPtrField<OpDef_ArgDef>& fake_args,
+ FunctionDefLibrary* library) {
+ GraphDef graph;
+ MutableGraphView graph_view(&graph);
+ auto* node = graph_utils::AddScalarConstNode(false, &graph_view);
+ auto* false_predicate = library->add_function();
+ graph_utils::SetUniqueGraphFunctionName("false_predicate", library,
+ false_predicate);
+
+ int num = 0;
+ for (const auto& fake_arg : fake_args) {
+ auto* arg = false_predicate->mutable_signature()->add_input_arg();
+ arg->set_type(fake_arg.type());
+ arg->set_name(strings::StrCat("fake_arg", num));
+ num++;
+ }
+
+ auto* output = false_predicate->mutable_signature()->add_output_arg();
+ output->set_name("false_out");
+ output->set_type(DT_BOOL);
+
+ (*false_predicate->mutable_ret())["false_out"] = node->name() + ":output:0";
+ *false_predicate->mutable_node_def() = std::move(*graph.mutable_node());
+ return false_predicate;
+}
+
void CheckIfCanCompose(const OpDef& first_signature,
const OpDef& second_signature) {
CHECK(CanCompose(first_signature, second_signature))
@@ -259,6 +294,15 @@ void CheckIfCanCompose(const OpDef& first_signature,
} // namespace
+void MergeNodes(const FunctionDef& first_function,
+ const FunctionDef& second_function, FunctionDef* fused_function,
+ FunctionDefLibrary* library) {
+ // Copy all nodes from first_function.
+ fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
+ // Copy transformed nodes from the second function.
+ fused_function->mutable_node_def()->MergeFrom(second_function.node_def());
+}
+
bool CanCompose(const OpDef& first_signature, const OpDef& second_signature) {
// TODO(prazek): Functions can have additional inputs being placeholders
// for a values used in function. We should be able to also fuse these
@@ -285,8 +329,8 @@ void ComposeSignature(const OpDef& first_signature,
void ComposeOutput(const protobuf::Map<string, string>& first_ret,
const protobuf::Map<string, string>& second_ret,
- FunctionDef* fused_function) {
- *fused_function->mutable_ret() = second_ret;
+ protobuf::Map<string, string>* fused_ret) {
+ *fused_ret = second_ret;
}
void CombineSignature(const OpDef& first_signature,
@@ -302,41 +346,110 @@ void CombineSignature(const OpDef& first_signature,
void CombineOutput(const protobuf::Map<string, string>& first_ret,
const protobuf::Map<string, string>& second_ret,
- FunctionDef* fused_function) {
- *fused_function->mutable_ret() = first_ret;
- fused_function->mutable_ret()->insert(second_ret.begin(), second_ret.end());
+ protobuf::Map<string, string>* fused_ret) {
+ *fused_ret = first_ret;
+ fused_ret->insert(second_ret.begin(), second_ret.end());
+}
+
+string SameInput(const StringCollection& first_inputs,
+ const StringCollection& second_inputs,
+ const StringCollection& first_outputs, int arg_num) {
+ return first_inputs.at(arg_num);
+}
+
+bool HasSameSignature(const OpDef& first_signature,
+ const OpDef& second_signature) {
+ return first_signature.input_arg_size() ==
+ second_signature.input_arg_size() &&
+ first_signature.output_arg_size() ==
+ second_signature.output_arg_size();
+}
+
+void SameSignature(const OpDef& first_signature, const OpDef& second_signature,
+ OpDef* fused_signature) {
+ CHECK(HasSameSignature(first_signature, second_signature))
+ << "Functions do not have the same signature";
+ // Copy signature from first function.
+ *fused_signature = first_signature;
+}
+
+void LazyConjunctionNodes(const FunctionDef& first_function,
+ const FunctionDef& second_function,
+ FunctionDef* fused_function,
+ FunctionDefLibrary* library) {
+ fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
+
+ NodeDefBuilder if_builder("", "If");
+ if_builder.Input(GetOutputNode(first_function, 0), 0, DT_BOOL);
+ DataTypeVector in_arg_types;
+ std::vector<NodeDefBuilder::NodeOut> inputs;
+ for (const auto& input_arg : first_function.signature().input_arg()) {
+ inputs.push_back({input_arg.name(), 0, input_arg.type()});
+ in_arg_types.push_back(input_arg.type());
+ }
+ if_builder.Attr("Tin", in_arg_types);
+
+ if_builder.Attr("Tcond", DT_BOOL);
+ if_builder.Attr("Tout", DataTypeVector{DT_BOOL});
+ if_builder.Attr("_lower_using_switch_merge", true);
+
+ NameAttrList then_branch;
+ then_branch.set_name(second_function.signature().name());
+ if_builder.Attr("then_branch", then_branch);
+
+ auto* false_predicate =
+ CreateFalsePredicate(first_function.signature().input_arg(), library);
+
+ NameAttrList else_branch;
+ else_branch.set_name(false_predicate->signature().name());
+ if_builder.Attr("else_branch", else_branch);
+ if_builder.Input(inputs);
+
+ auto* if_node = fused_function->add_node_def();
+ // This is guaranteed to succeed.
+ TF_CHECK_OK(if_builder.Finalize(if_node));
+ function_utils::SetUniqueFunctionNodeName("cond", fused_function, if_node);
+
+ GetMutableOutputNode(fused_function, 0) = if_node->name() + ":output:0";
+}
+
+void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret,
+ const protobuf::Map<string, string>& second_ret,
+ protobuf::Map<string, string>* fused_ret) {
+ CHECK_EQ(first_ret.size(), 1);
+ CHECK_EQ(second_ret.size(), 1);
+ // Temporarily copy returns from first_ret. We are going to change the
+ // output node after creating it.
+ *fused_ret = first_ret;
}
-FunctionDef* FuseFunctions(const FunctionDef& first_function,
- const FunctionDef& function,
- StringPiece fused_name_prefix,
- const SetFunctionSignatureFn& set_signature,
- const SetInputFn& set_input,
- const SetOutputFn& set_output,
- FunctionDefLibrary* library) {
- if (first_function.attr_size() != 0 || function.attr_size() != 0)
+FunctionDef* FuseFunctions(
+ const FunctionDef& first_function, const FunctionDef& second_function,
+ StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature,
+ const SetInputFn& set_input, const SetOutputFn& set_output,
+ const SetNodesFn& set_nodes, FunctionDefLibrary* library) {
+ if (first_function.attr_size() != 0 || second_function.attr_size() != 0)
return nullptr; // Functions with attributes are currently not supported
// This function will be used as a clone of second function, having unique
// names.
- FunctionDef setup_function = function;
+ FunctionDef setup_function = second_function;
*setup_function.mutable_signature() = GetUniqueSignature(
first_function.signature(), setup_function.signature(),
setup_function.mutable_ret(), setup_function.mutable_node_def());
FunctionDef* fused_function = library->add_function();
- // Copy all nodes from first_function.
- fused_function->mutable_node_def()->CopyFrom(first_function.node_def());
+
set_signature(first_function.signature(), setup_function.signature(),
fused_function->mutable_signature());
graph_utils::SetUniqueGraphFunctionName(fused_name_prefix, library,
fused_function);
- RenameFunctionNodes(first_function, fused_function,
- setup_function.mutable_node_def(),
+ RenameFunctionNodes(first_function, setup_function.mutable_node_def(),
setup_function.mutable_ret());
- set_output(first_function.ret(), setup_function.ret(), fused_function);
+ set_output(first_function.ret(), setup_function.ret(),
+ fused_function->mutable_ret());
CHECK(fused_function->signature().output_arg_size() ==
fused_function->ret_size())
@@ -351,10 +464,10 @@ FunctionDef* FuseFunctions(const FunctionDef& first_function,
FuseFunctionNodes(first_inputs, second_inputs, first_outputs, set_input,
setup_function.mutable_node_def());
FuseReturns(first_inputs, second_inputs, first_outputs, set_input,
- fused_function);
+ fused_function->mutable_ret());
+
+ set_nodes(first_function, setup_function, fused_function, library);
- // Copy transformed nodes from the second function.
- fused_function->mutable_node_def()->MergeFrom(setup_function.node_def());
return fused_function;
}
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils.h b/tensorflow/core/grappler/optimizers/data/fusion_utils.h
index 41f13f6cb8..19b7002dcd 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils.h
@@ -48,14 +48,20 @@ using SetInputFn =
const StringCollection& second_function_inputs,
const StringCollection& parent_outputs, int arg_num)>;
-// This function is invoked with first function ret. It is used to set up
-// returns of fused function. If you need to combine outputs
-// of first and second function, then this is a right place to create a new
-// nodes.
+// This function is invoked with the rets of the first and second functions.
+// It is used to set up the returns of the fused function.
using SetOutputFn =
std::function<void(const protobuf::Map<string, string>& parent_ret,
const protobuf::Map<string, string>& second_function_ret,
- FunctionDef* fused_function)>;
+ protobuf::Map<string, string>* fused_ret)>;
+
+using SetNodesFn = std::function<void(
+ const FunctionDef& first_function, const FunctionDef& second_function,
+ FunctionDef* fused_function, FunctionDefLibrary* library)>;
+
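+// A SetNodesFn that copies the nodes of `first_function` and
+// `second_function` into `fused_function` without modification.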
+void MergeNodes(const FunctionDef& first_function,
+ const FunctionDef& second_function, FunctionDef* fused_function,
+ FunctionDefLibrary* library);
// Returns true if functions can be composed.
bool CanCompose(const OpDef& first_signature, const OpDef& second_signature);
@@ -71,7 +77,7 @@ string ComposeInput(const StringCollection& first_inputs,
// second_function(first_function(args...)).
void ComposeOutput(const protobuf::Map<string, string>& first_ret,
const protobuf::Map<string, string>& second_ret,
- FunctionDef* fused_function);
+ protobuf::Map<string, string>* fused_ret);
// Set input signature to `first_function_signature` and output signature
// to `first_function_signature` + `second_function_signature`
@@ -83,7 +89,32 @@ void CombineSignature(const OpDef& first_signature,
// return *first_function(...), *second_function(...)
void CombineOutput(const protobuf::Map<string, string>& first_ret,
const protobuf::Map<string, string>& second_ret,
- FunctionDef* fused_function);
+ protobuf::Map<string, string>* fused_ret);
+
+// Returns true if both signatures have the same number of input and output
+// args.
+bool HasSameSignature(const OpDef& first_signature,
+ const OpDef& second_signature);
+
+// Checks that both signatures are the same and copies the signature from
+// `first_signature`.
+void SameSignature(const OpDef& first_signature, const OpDef& second_signature,
+ OpDef* fused_signature);
+
+// Takes the same input as the first function.
+string SameInput(const StringCollection& first_inputs,
+ const StringCollection& second_inputs,
+ const StringCollection& first_outputs, int arg_num);
+
+// Creates a fused function that computes the short-circuit logical AND of the
+// result of the first function and the result of the second function.
+void LazyConjunctionOutput(const protobuf::Map<string, string>& first_ret,
+ const protobuf::Map<string, string>& second_ret,
+ protobuf::Map<string, string>* fused_ret);
+
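+// A SetNodesFn that copies the nodes of `first_function` and adds an `If`
+// node that calls `second_function` only when the first predicate returns
+// true, and a constant-false predicate otherwise.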
+void LazyConjunctionNodes(const FunctionDef& first_function,
+ const FunctionDef& second_function,
+ FunctionDef* fused_function,
+ FunctionDefLibrary* library);
// Fuse `first_function` with `second_function`, setting `fused_name_prefix` as
// a name prefix. The nodes from `first_function` are copied unmodified. All
@@ -91,13 +122,11 @@ void CombineOutput(const protobuf::Map<string, string>& first_ret,
// that do not conflict with the first function, so copied nodes from the
// second function can end up with different names. For an explanation of the
// set-up functions, see the documentation of the function types.
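+//
+// Example (from fusion_utils_test.cc), fusing two functions by composition:
+//   FuseFunctions(*parent_function, *function, "fused_maps",
+//                 fusion_utils::ComposeSignature, fusion_utils::ComposeInput,
+//                 fusion_utils::ComposeOutput, fusion_utils::MergeNodes,
+//                 graph.mutable_library());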
-FunctionDef* FuseFunctions(const FunctionDef& first_function,
- const FunctionDef& second_function,
- StringPiece fused_name_prefix,
- const SetFunctionSignatureFn& set_signature,
- const SetInputFn& set_input,
- const SetOutputFn& set_output,
- FunctionDefLibrary* library);
+FunctionDef* FuseFunctions(
+ const FunctionDef& first_function, const FunctionDef& second_function,
+ StringPiece fused_name_prefix, const SetFunctionSignatureFn& set_signature,
+ const SetInputFn& set_input, const SetOutputFn& set_output,
+ const SetNodesFn& set_nodes, FunctionDefLibrary* library);
} // namespace fusion_utils
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
index 7ad5d63bf6..e667affeea 100644
--- a/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/fusion_utils_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -57,10 +58,10 @@ TEST(FusionUtilsTest, FuseFunctionsByComposition) {
auto *function = graph.mutable_library()->add_function();
*function = test::function::XTimesTwo();
- auto *fused_function =
- FuseFunctions(*parent_function, *function, "fused_maps",
- fusion_utils::ComposeSignature, fusion_utils::ComposeInput,
- fusion_utils::ComposeOutput, graph.mutable_library());
+ auto *fused_function = FuseFunctions(
+ *parent_function, *function, "fused_maps", fusion_utils::ComposeSignature,
+ fusion_utils::ComposeInput, fusion_utils::ComposeOutput,
+ fusion_utils::MergeNodes, graph.mutable_library());
EXPECT_EQ(fused_function->signature().name(), "fused_maps");
EXPECT_EQ(fused_function->signature().input_arg_size(), 1);
@@ -98,7 +99,8 @@ TEST(FusionUtilsTest, FuseFunctionWithPredicate) {
auto *fused_function =
FuseFunctions(*xtimes_two, *is_zero, "fused_map_and_filter_function",
fusion_utils::CombineSignature, fusion_utils::ComposeInput,
- fusion_utils::CombineOutput, graph.mutable_library());
+ fusion_utils::CombineOutput, fusion_utils::MergeNodes,
+ graph.mutable_library());
EXPECT_EQ(fused_function->signature().name(),
"fused_map_and_filter_function");
@@ -109,9 +111,9 @@ TEST(FusionUtilsTest, FuseFunctionWithPredicate) {
CheckUniqueNames(*fused_function);
ASSERT_TRUE(
- graph_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
+ function_utils::ContainsFunctionNodeWithOp("Equal", *fused_function));
const auto &equal_node = fused_function->node_def(
- graph_utils::FindFunctionNodeWithOp("Equal", *fused_function));
+ function_utils::FindFunctionNodeWithOp("Equal", *fused_function));
EXPECT_EQ(xtimes_two->signature().output_arg(0).name(),
fused_function->signature().output_arg(0).name());
@@ -134,10 +136,10 @@ TEST(FusionUtilsTest, FuseSameFunctionWithExtraOutput) {
auto *function = graph.mutable_library()->add_function();
*function = test::function::XTimesTwo();
- auto *fused_function =
- FuseFunctions(*parent_function, *function, "fused_maps",
- fusion_utils::CombineSignature, fusion_utils::ComposeInput,
- fusion_utils::CombineOutput, graph.mutable_library());
+ auto *fused_function = FuseFunctions(
+ *parent_function, *function, "fused_maps", fusion_utils::CombineSignature,
+ fusion_utils::ComposeInput, fusion_utils::CombineOutput,
+ fusion_utils::MergeNodes, graph.mutable_library());
EXPECT_EQ(fused_function->signature().input_arg_size(), 1);
EXPECT_EQ(fused_function->signature().output_arg_size(), 2);
@@ -169,7 +171,8 @@ TEST(FusionUtilsTest, ZipFusion) {
auto *fused_function =
FuseFunctions(*function, *function, "zip_maps", zip_signature, zip_input,
- fusion_utils::CombineOutput, graph.mutable_library());
+ fusion_utils::CombineOutput, fusion_utils::MergeNodes,
+ graph.mutable_library());
EXPECT_EQ(fused_function->signature().input_arg_size(), 2);
EXPECT_EQ(fused_function->signature().output_arg_size(), 2);
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
new file mode 100644
index 0000000000..b2eec7220e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_tests_utils {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "MapDataset", {string(input_node_name)},
+ {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<DataType>{}}});
+}
+
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name) {
+ return test::function::NDef(
+ name, "FilterDataset", {string(input_node_name)},
+ {{"predicate", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", gtl::ArraySlice<TensorShape>{}},
+       {"output_types", gtl::ArraySlice<DataType>{}}});
+}
+
+} // end namespace graph_tests_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_test_utils.h b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
new file mode 100644
index 0000000000..ca0fde997d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/graph_test_utils.h
@@ -0,0 +1,36 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace graph_tests_utils {
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name = "XTimesTwo");
+
+NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name = "IsZero");
+
+} // end namespace graph_tests_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_GRAPH_TEST_UTILS_H_
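Reviewer note: a minimal usage sketch (not part of the patch; names are hypothetical) of the shared test helpers declared above. The default function names mean callers only have to spell out the dataset wiring.

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"

namespace tensorflow {
namespace grappler {

NodeDef FilterNodeSketch() {
  // Builds a FilterDataset node named "filter1" reading from "range"; its
  // "predicate" attr references the default "IsZero" test function.
  return graph_tests_utils::MakeFilterNode("filter1", "range");
}

}  // namespace grappler
}  // namespace tensorflow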
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.cc b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
index 0eceaf4017..b863a25dc5 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -88,17 +89,27 @@ NodeDef* AddScalarConstNodeHelper(
} // namespace
+NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph) {
+ NodeDef node;
+ node.set_op("Placeholder");
+ SetUniqueGraphNodeName(node.op(), graph->GetGraph(), &node);
+ (*node.mutable_attr())["dtype"].set_type(dtype);
+ TensorShapeProto* shape = (*node.mutable_attr())["shape"].mutable_shape();
+ shape->set_unknown_rank(false);
+ return graph->AddNode(std::move(node));
+}
+
NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<string>& inputs,
const std::vector<std::pair<string, AttrValue>>& attributes,
MutableGraphView* graph) {
NodeDef node;
if (!name.empty()) {
- node.set_name(name.ToString());
+ node.set_name(string(name));
} else {
SetUniqueGraphNodeName(op, graph->GetGraph(), &node);
}
- node.set_op(op.ToString());
+ node.set_op(string(op));
for (const string& input : inputs) {
node.add_input(input);
}
@@ -176,39 +187,37 @@ bool Compare(const GraphDef& g1, const GraphDef& g2) {
return true;
}
-bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
- return FindGraphNodeWithName(name, graph) != -1;
-}
-
-bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
- return FindNodeWithOp(op, graph) != -1;
-}
-
bool ContainsGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library) {
return FindGraphFunctionWithName(name, library) != -1;
}
-bool ContainsFunctionNodeWithName(StringPiece name,
- const FunctionDef& function) {
- return FindFunctionNodeWithName(name, function) != -1;
+bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph) {
+ return FindGraphNodeWithName(name, graph) != -1;
+}
+
+bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph) {
+ return FindGraphNodeWithOp(op, graph) != -1;
}
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- return FindFunctionNodeWithOp(op, function) != -1;
+int FindGraphFunctionWithName(StringPiece name,
+ const FunctionDefLibrary& library) {
+ return GetFirstElementIndexWithPredicate(
+ [&name](const FunctionDef& function) {
+ return function.signature().name() == name;
+ },
+ library.function());
}
int FindGraphNodeWithName(StringPiece name, const GraphDef& graph) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+ return GetFirstElementIndexWithPredicate(
[&name](const NodeDef& node) { return node.name() == name; },
graph.node());
- return indices.empty() ? -1 : indices.front();
}
-int FindNodeWithOp(StringPiece op, const GraphDef& graph) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
+int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph) {
+ return GetFirstElementIndexWithPredicate(
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
- return indices.empty() ? -1 : indices.front();
}
std::vector<int> FindAllGraphNodesWithOp(const string& op,
@@ -217,37 +226,18 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
[&op](const NodeDef& node) { return node.op() == op; }, graph.node());
}
-int FindGraphFunctionWithName(StringPiece name,
- const FunctionDefLibrary& library) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&name](const FunctionDef& function) {
- return function.signature().name() == name;
- },
- library.function());
- return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&name](const NodeDef& node) { return node.name() == name; },
- function.node_def());
- return indices.empty() ? -1 : indices.front();
-}
-
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function) {
- std::vector<int> indices = GetElementIndicesWithPredicate(
- [&op](const NodeDef& node) { return node.op() == op; },
- function.node_def());
-
- return indices.empty() ? -1 : indices.front();
+NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph) {
+ if (node.input_size() == 0) return nullptr;
+ GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
+ return graph.GetRegularFanin(input_port).node;
}
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
NodeDef* node) {
- string name = prefix.ToString();
+ string name = string(prefix);
int id = graph->node_size();
while (ContainsGraphNodeWithName(name, *graph)) {
- if (name.rfind("_generated") != std::string::npos &&
+ if (name.rfind("_generated") != string::npos &&
(name.rfind("_generated") == (name.size() - strlen("_generated")))) {
name.insert(name.rfind("_generated"), strings::StrCat("/_", id));
} else {
@@ -258,20 +248,9 @@ void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph,
node->set_name(std::move(name));
}
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
- NodeDef* node) {
- string name = prefix.ToString();
- int id = function->node_def_size();
- while (ContainsFunctionNodeWithName(name, *function)) {
- name = strings::StrCat(prefix, "/_", id);
- ++id;
- }
- node->set_name(std::move(name));
-}
-
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function) {
- string name = prefix.ToString();
+ string name = string(prefix);
int id = library->function_size();
while (ContainsGraphFunctionWithName(name, *library)) {
name = strings::StrCat(prefix, "/_", id);
@@ -280,6 +259,40 @@ void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
function->mutable_signature()->set_name(std::move(name));
}
+void CopyAttribute(const string& attribute_name, const NodeDef& from,
+ NodeDef* to_node) {
+ (*to_node->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
+}
+
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
+ const NodeDef& second, NodeDef* to_node) {
+ CopyAttribute(attribute_name, first, to_node);
+ (*to_node->mutable_attr())
+ .at(attribute_name)
+ .mutable_list()
+ ->MergeFrom(second.attr().at(attribute_name).list());
+}
+
+Status EnsureNodeNamesUnique(Graph* g) {
+ // Modeled after Scope::Impl::GetUniqueName
+ std::unordered_map<string, int> name_map;
+
+ for (auto node : g->op_nodes()) {
+ const string& prefix = node->name();
+ if (auto entry = gtl::FindOrNull(name_map, prefix)) {
+ string unique_name;
+ do {
+ unique_name = strings::StrCat(prefix, "_", ++(*entry));
+ } while (name_map.find(unique_name) != name_map.end());
+ name_map.insert({unique_name, 0});
+ node->set_name(std::move(unique_name));
+ } else {
+ name_map.insert({node->name(), 0});
+ }
+ }
+
+ return Status::OK();
+}
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 28a1aff877..d130fee204 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/mutable_graph_view.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -31,12 +32,30 @@ namespace tensorflow {
namespace grappler {
namespace graph_utils {
+// Returns the index of the first element in collection that fulfills predicate.
+// If no such element exists, returns -1.
+template <typename Predicate, typename Collection>
+int GetFirstElementIndexWithPredicate(const Predicate& predicate,
+ const Collection& collection) {
+ unsigned idx = 0;
+ for (auto&& element : collection) {
+ if (predicate(element)) {
+ return idx;
+ }
+ idx++;
+ }
+ return -1;
+}
+
// Adds a node to the graph.
NodeDef* AddNode(StringPiece name, StringPiece op,
const std::vector<string>& inputs,
const std::vector<std::pair<string, AttrValue>>& attributes,
MutableGraphView* graph);
+// Adds a Placeholder node of the given dtype to the graph.
+NodeDef* AddScalarPlaceholder(DataType dtype, MutableGraphView* graph);
+
// Adds a Const node with the given value to the graph.
template <typename T>
NodeDef* AddScalarConstNode(T v, MutableGraphView* graph) {
@@ -70,13 +89,6 @@ bool ContainsGraphNodeWithName(StringPiece name, const GraphDef& graph);
bool ContainsGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library);
-// Checks whether the function contains a node with the given name.
-bool ContainsFunctionNodeWithName(StringPiece name,
- const FunctionDef& function);
-
-// Checks whether the function contains a node with the given op.
-bool ContainsFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
// Checks whether the graph contains a node with the given op.
bool ContainsNodeWithOp(StringPiece op, const GraphDef& graph);
@@ -89,17 +101,12 @@ int FindGraphNodeWithName(StringPiece name, const GraphDef& graph);
int FindGraphFunctionWithName(StringPiece name,
const FunctionDefLibrary& library);
-// Returns the index of the function node with the given name or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithName(StringPiece name, const FunctionDef& function);
-
-// Returns the index of the function node with the given op or -1 if the
-// function node does not exist.
-int FindFunctionNodeWithOp(StringPiece op, const FunctionDef& function);
-
// Returns the index of the first node with the given op or -1 if no such node
// exists.
-int FindNodeWithOp(StringPiece op, const GraphDef& graph);
+int FindGraphNodeWithOp(StringPiece op, const GraphDef& graph);
+
+// Returns the node that is the 0th input of `node` in `graph`, or nullptr if
+// `node` has no inputs.
+NodeDef* GetInputNode(const NodeDef& node, const MutableGraphView& graph);
// Returns the list of indices of all nodes with the given op or empty list if
// no such node exists.
@@ -110,16 +117,29 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
// is unique across the graph.
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
-// Sets the function node name using the `prefix` as a prefix while guaranteeing
-// the name is unique across the functions nodes.
-void SetUniqueFunctionNodeName(StringPiece prefix, FunctionDef* function,
- NodeDef* node);
-
-// Sets the node name using the `prefix` name as a prefix while guaranteeing the
-// name is unique across the graph.
+// Sets the function name using the `prefix` name as a prefix while guaranteeing
+// the name is unique across the function library.
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function);
+// Copies the attribute named `attribute_name` from node `from` to node
+// `to_node`.
+void CopyAttribute(const string& attribute_name, const NodeDef& from,
+ NodeDef* to_node);
+
+// Concatenates the list attribute named `attribute_name` from nodes `first`
+// and `second`, and sets the result as the attribute on `to_node`.
+void ConcatAttributeList(const string& attribute_name, const NodeDef& first,
+ const NodeDef& second, NodeDef* to_node);
+
+// Ensures that all nodes in the graph have unique names, renaming them if
+// necessary. This is needed because Graph itself does not deduplicate names;
+// in TensorFlow that happens in other layers (for example, in the Scope class
+// of the C++ API). Note that nodes in the graph are identified by their id,
+// so renaming a node does not mutate any edges.
+Status EnsureNodeNamesUnique(Graph* g);
+
} // end namespace graph_utils
} // end namespace grappler
} // end namespace tensorflow
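Reviewer note: a minimal sketch (not part of the patch; names are hypothetical) of what the new ConcatAttributeList helper produces. MakeZipNode in hoist_random_uniform.cc later in this patch relies on exactly this merge of the "output_shapes"/"output_types" lists of the two zipped datasets.

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"

namespace tensorflow {
namespace grappler {

void ConcatAttributeListSketch() {
  NodeDef first, second, zip;
  (*first.mutable_attr())["output_types"].mutable_list()->add_type(DT_INT64);
  (*second.mutable_attr())["output_types"].mutable_list()->add_type(DT_STRING);
  // Copies "output_types" from `first`, then appends the list from `second`,
  // so `zip` ends up with {DT_INT64, DT_STRING}.
  graph_utils::ConcatAttributeList("output_types", first, second, &zip);
}

}  // namespace grappler
}  // namespace tensorflow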
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
index 0a3af1a914..4ab6d71532 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -24,6 +25,18 @@ namespace grappler {
namespace graph_utils {
namespace {
+TEST(GraphUtilsTest, GetFirstElementIndexWithPredicate) {
+ std::vector<int> vec({1, 2, 3, 4, 5, 6});
+ auto result = GetFirstElementIndexWithPredicate(
+ [](int elem) { return elem % 3 == 0; }, vec);
+
+ EXPECT_EQ(result, 2);
+
+ result = GetFirstElementIndexWithPredicate(
+ [](int elem) { return elem % 7 == 0; }, vec);
+ EXPECT_EQ(result, -1);
+}
+
TEST(GraphUtilsTest, AddScalarConstNodeBool) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
@@ -112,20 +125,6 @@ TEST(GraphUtilsTest, ContainsGraphFunctionWithName) {
ContainsGraphFunctionWithName(new_function->signature().name(), library));
}
-TEST(GraphUtilsTest, ContainsFunctionNodeWithName) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_FALSE(ContainsFunctionNodeWithName(
- "weird_name_that_should_not_be_there", function));
- EXPECT_TRUE(ContainsFunctionNodeWithName("two", function));
-}
-
-TEST(GraphUtilsTest, ContainsFunctionNodeWithOp) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_FALSE(ContainsFunctionNodeWithOp("weird_op_that_should_not_be_there",
- function));
- EXPECT_TRUE(ContainsFunctionNodeWithOp("Mul", function));
-}
-
TEST(GraphUtilsTest, ContainsNodeWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
@@ -150,22 +149,6 @@ TEST(GraphUtilsTest, FindGraphNodeWithName) {
EXPECT_EQ(FindGraphNodeWithName("A", *graph.GetGraph()), -1);
}
-TEST(GraphUtilsTest, FindFunctionNodeWithName) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_EQ(
- FindFunctionNodeWithName("weird_name_that_should_not_be_there", function),
- -1);
- EXPECT_NE(FindFunctionNodeWithName("two", function), -1);
-}
-
-TEST(GraphUtilsTest, FindFunctionNodeWithOp) {
- FunctionDef function = test::function::XTimesTwo();
- EXPECT_EQ(
- FindFunctionNodeWithOp("weird_op_that_should_not_be_there", function),
- -1);
- EXPECT_NE(FindFunctionNodeWithOp("Mul", function), -1);
-}
-
TEST(GraphUtilsTest, FindGraphFunctionWithName) {
FunctionDefLibrary library;
EXPECT_EQ(FindGraphFunctionWithName("new_function", library), -1);
@@ -176,25 +159,25 @@ TEST(GraphUtilsTest, FindGraphFunctionWithName) {
FindGraphFunctionWithName(new_function->signature().name(), library), -1);
}
-TEST(GraphUtilsTest, FindNodeWithOp) {
+TEST(GraphUtilsTest, FindGraphNodeWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
- EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
+ EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
AddNode("A", "OpA", {}, {}, &graph);
AddNode("B", "OpB", {"A"}, {}, &graph);
AddNode("A2", "OpA", {"B"}, {}, &graph);
- EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), 0);
+ EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), 0);
graph.DeleteNodes({"B"});
- EXPECT_EQ(FindNodeWithOp("OpB", *graph.GetGraph()), -1);
+ EXPECT_EQ(FindGraphNodeWithOp("OpB", *graph.GetGraph()), -1);
EXPECT_EQ(FindGraphNodeWithName("A2", *graph.GetGraph()), 1);
}
TEST(GraphUtilsTest, FindAllGraphNodesWithOp) {
GraphDef graph_def;
MutableGraphView graph(&graph_def);
- EXPECT_EQ(FindNodeWithOp("OpA", *graph.GetGraph()), -1);
+ EXPECT_EQ(FindGraphNodeWithOp("OpA", *graph.GetGraph()), -1);
AddNode("A", "OpA", {}, {}, &graph);
AddNode("B", "OpB", {"A"}, {}, &graph);
@@ -225,21 +208,6 @@ TEST(GraphUtilsTest, SetUniqueGraphNodeName) {
EXPECT_NE(node2->name(), node3->name());
}
-TEST(GraphUtilsTest, SetUniqueFunctionNodeName) {
- FunctionDef function = test::function::XTimesTwo();
- NodeDef node;
- SetUniqueFunctionNodeName("abc", &function, &node);
- for (const NodeDef& function_node : function.node_def()) {
- EXPECT_NE(node.name(), function_node.name());
- }
- auto* new_node = function.add_node_def();
- *new_node = node;
-
- NodeDef other;
- SetUniqueFunctionNodeName("abc", &function, &other);
- EXPECT_NE(other.name(), new_node->name());
-}
-
TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
FunctionDefLibrary library;
FunctionDef* new_function = library.add_function();
@@ -251,6 +219,44 @@ TEST(GraphUtilsTest, SetUniqueGraphFunctionName) {
other_function->signature().name());
}
+TEST(GraphUtilsTest, GetInputNode) {
+ GraphDef graph_def;
+ MutableGraphView graph(&graph_def);
+
+ NodeDef* node1 = AddNode("", "A", {}, {}, &graph);
+ NodeDef* node2 = AddNode("", "A", {node1->name()}, {}, &graph);
+
+ EXPECT_EQ(GetInputNode(*node2, graph), node1);
+ EXPECT_EQ(GetInputNode(*node1, graph), nullptr);
+}
+
+TEST(GraphUtilsTest, EnsureNodeNamesUnique) {
+ Graph g(OpRegistry::Global());
+
+ Node *const_0, *const_1, *const_2;
+
+ // Arbitrary const
+ Tensor tensor(DT_INT32, {});
+ tensor.scalar<int32>()() = 5;
+
+ for (auto node : {&const_0, &const_1}) {
+ TF_EXPECT_OK(NodeBuilder("Const", "Const")
+ .Attr("value", tensor)
+ .Attr("dtype", DT_INT32)
+ .Finalize(&g, node));
+ }
+ // Make sure generated name doesn't clash with existing name either
+ TF_EXPECT_OK(NodeBuilder("Const_1", "Const")
+ .Attr("value", tensor)
+ .Attr("dtype", DT_INT32)
+ .Finalize(&g, &const_2));
+
+ TF_EXPECT_OK(EnsureNodeNamesUnique(&g));
+ EXPECT_NE(const_0->name(), const_1->name());
+ EXPECT_NE(const_1->name(), const_2->name());
+ EXPECT_NE(const_0->name(), const_2->name());
+}
+
} // namespace
} // namespace graph_utils
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
new file mode 100644
index 0000000000..ce0b2db039
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.cc
@@ -0,0 +1,289 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+NodeDef MakeStatelessMap(const NodeDef& map_node, const NodeDef& zip_node,
+ const FunctionDef& stateless_function,
+ MutableGraphView* graph) {
+ NodeDef stateless_map;
+ graph_utils::SetUniqueGraphNodeName("stateless_map", graph->GetGraph(),
+ &stateless_map);
+
+ stateless_map.set_op("MapDataset");
+ stateless_map.add_input(zip_node.name());
+ // Add placeholders.
+ for (int i = 1; i < map_node.input_size(); i++)
+ stateless_map.add_input(map_node.input(i));
+
+ auto attr = map_node.attr().at("f");
+ *attr.mutable_func()->mutable_name() = stateless_function.signature().name();
+ *attr.mutable_func()->mutable_attr() = stateless_function.attr();
+ (*stateless_map.mutable_attr())["f"] = std::move(attr);
+
+ graph_utils::CopyAttribute("Targuments", map_node, &stateless_map);
+ for (auto key : {"output_shapes", "output_types"})
+ graph_utils::CopyAttribute(key, map_node, &stateless_map);
+
+ if (const auto* attr =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"))
+ (*stateless_map.mutable_attr())["use_inter_op_parallelism"] = *attr;
+
+ return stateless_map;
+}
+
+NodeDef MakeRandomDataset(const NodeDef& random_uniform_node,
+ MutableGraphView* graph) {
+ NodeDef random_dataset;
+ random_dataset.set_op("RandomDataset");
+ graph_utils::SetUniqueGraphNodeName("RandomDataset", graph->GetGraph(),
+ &random_dataset);
+
+ const auto* seed = graph_utils::AddScalarConstNode<int64>(
+ random_uniform_node.attr().at("seed").i(), graph);
+ const auto* seed2 = graph_utils::AddScalarConstNode<int64>(
+ random_uniform_node.attr().at("seed2").i(), graph);
+
+ random_dataset.add_input(seed->name());
+ random_dataset.add_input(seed2->name());
+
+ (*random_dataset.mutable_attr())["output_shapes"].mutable_list()->add_shape();
+ (*random_dataset.mutable_attr())["output_types"].mutable_list()->add_type(
+ DT_INT64);
+
+ return random_dataset;
+}
+
+NodeDef MakeBatchTwo(const NodeDef& random_dataset, MutableGraphView* graph) {
+ NodeDef batch_dataset;
+ batch_dataset.set_op("BatchDatasetV2");
+ graph_utils::SetUniqueGraphNodeName("pair_of_random", graph->GetGraph(),
+ &batch_dataset);
+ const auto* batch_size = graph_utils::AddScalarConstNode<int64>(2, graph);
+  const auto* drop_remainder = graph_utils::AddScalarConstNode(false, graph);
+ batch_dataset.add_input(random_dataset.name());
+ batch_dataset.add_input(batch_size->name());
+  batch_dataset.add_input(drop_remainder->name());
+
+ (*batch_dataset.mutable_attr())["output_shapes"]
+ .mutable_list()
+ ->add_shape()
+ ->mutable_dim()
+ ->Add()
+ ->set_size(-1);
+ (*batch_dataset.mutable_attr())["output_types"].mutable_list()->add_type(
+ DT_INT64);
+
+ return batch_dataset;
+}
+
+NodeDef MakeZipNode(const NodeDef& first_node, const NodeDef& second_node,
+ MutableGraphView* graph) {
+ NodeDef zip_node;
+ graph_utils::SetUniqueGraphNodeName("zip_with_random", graph->GetGraph(),
+ &zip_node);
+
+ zip_node.set_op("ZipDataset");
+ zip_node.add_input(first_node.name());
+ zip_node.add_input(second_node.name());
+
+ for (auto key : {"output_shapes", "output_types"})
+ graph_utils::ConcatAttributeList(key, first_node, second_node, &zip_node);
+
+ (*zip_node.mutable_attr())["N"].set_i(2);
+
+ return zip_node;
+}
+
+// We need to insert our argument before the placeholders, which are the last
+// arguments.
+OpDef_ArgDef* InsertSeedArgument(OpDef* signature, int num_placeholders) {
+ int new_argument_idx = signature->input_arg_size() - num_placeholders;
+ signature->add_input_arg();
+ for (int i = signature->input_arg_size() - 1; i > new_argument_idx; i--) {
+ signature->mutable_input_arg()->SwapElements(i - 1, i);
+ }
+ auto* seed_arg = signature->mutable_input_arg(new_argument_idx);
+ seed_arg->set_name(strings::StrCat("seed_arg", new_argument_idx));
+ seed_arg->set_type(DT_INT64);
+
+ return seed_arg;
+}
+
+// Makes a function that uses `StatelessRandomUniform` instead of
+// `RandomUniform` to make it less stateful. The resulting function can still
+// be stateful, but when the only other stateful ops are e.g. `Assert`, it can
+// still be parallelized.
+const FunctionDef* MakeLessStatefulFunction(const FunctionDef& map_function,
+ bool is_stateful,
+ int num_placeholders,
+ FunctionDefLibrary* library) {
+ FunctionDef* stateless_function = library->add_function();
+ *stateless_function = map_function;
+  if (!is_stateful)
+    stateless_function->mutable_signature()->set_is_stateful(is_stateful);
+ graph_utils::SetUniqueGraphFunctionName("stateless_function", library,
+ stateless_function);
+
+ auto* seed_arg = InsertSeedArgument(stateless_function->mutable_signature(),
+ num_placeholders);
+
+ auto* const random_uniform = stateless_function->mutable_node_def(
+ function_utils::FindFunctionNodeWithOp("RandomUniform",
+ *stateless_function));
+
+ // Replace RandomUniform node with StatelessRandomUniform.
+ random_uniform->set_op("StatelessRandomUniform");
+ random_uniform->add_input(seed_arg->name());
+ (*random_uniform->mutable_attr())["Tseed"].set_type(DT_INT64);
+ random_uniform->mutable_attr()->erase("seed");
+ random_uniform->mutable_attr()->erase("seed2");
+
+ return stateless_function;
+}
+
+// Returns true if the function is stateful, contains exactly one
+// RandomUniform op, and has no other stateful ops except Assert, i.e. if the
+// RandomUniform can be hoisted out. `is_stateful_after_hoisting` is set to
+// false when RandomUniform is the only stateful op, since the function
+// becomes stateless once it is hoisted.
+bool CanHoistRandomUniform(const FunctionDef& map_function,
+ const FunctionLibraryDefinition& library,
+ bool* is_stateful_after_hoisting,
+ const NodeDef** random_uniform_op) {
+ if (!map_function.signature().is_stateful()) return false;
+ *is_stateful_after_hoisting = true;
+
+ bool have_other_stateful_ops = false;
+
+ for (const auto& node : map_function.node_def()) {
+ const OpDef* op_def;
+ TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def));
+    // Skip stateless nodes and Assert, which does not actually carry state.
+ if (!op_def->is_stateful()) continue;
+
+ if (op_def->name() == "Assert") {
+ have_other_stateful_ops = true;
+ continue;
+ }
+
+ // TODO(prazek): For now we only handle RandomUniform, we should handle
+ // RandomUniformInt as well.
+ if (op_def->name() != "RandomUniform") return false;
+
+ // TODO(prazek): For now we can only hoist single RandomUniform.
+ if (*random_uniform_op != nullptr) return false;
+
+ *random_uniform_op = &node;
+ }
+
+ if (!have_other_stateful_ops) *is_stateful_after_hoisting = false;
+
+  // Have we found a single RandomUniform?
+ return *random_uniform_op != nullptr;
+}
+
+int NumberOfPlaceholders(const NodeDef& map_node) {
+  // The first input of MapDataset is the input dataset; the remaining inputs
+  // are placeholders.
+ return map_node.input_size() - 1;
+}
+
+} // namespace
+
+Status HoistRandomUniform::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ item.graph.library());
+
+ auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
+ // TODO(prazek): we could also handle ParallelMapDataset and
+ // MapAndBatchDataset.
+ if (node.op() == "MapDataset") return &node;
+ return nullptr;
+ };
+
+ for (const NodeDef& node : item.graph.node()) {
+ const NodeDef* map_node = get_map_node(node);
+ if (!map_node) continue;
+
+ const auto& fun = map_node->attr().at("f");
+ const FunctionDef* func = function_library.Find(fun.func().name());
+
+ const NodeDef* random_uniform_op = nullptr;
+ bool is_stateful_after_hoisting = true;
+ if (!CanHoistRandomUniform(*func, function_library,
+ &is_stateful_after_hoisting, &random_uniform_op))
+ continue;
+ const auto* random_seed_dataset =
+ graph.AddNode(MakeRandomDataset(*random_uniform_op, &graph));
+
+ const auto* batch_dataset =
+ graph.AddNode(MakeBatchTwo(*random_seed_dataset, &graph));
+
+ const NodeDef& parent_node = *graph_utils::GetInputNode(*map_node, graph);
+
+ const auto* zip_node =
+ graph.AddNode(MakeZipNode(parent_node, *batch_dataset, &graph));
+
+ const auto* stateless_func = MakeLessStatefulFunction(
+ *func, is_stateful_after_hoisting, NumberOfPlaceholders(*map_node),
+ output->mutable_library());
+
+ const auto* stateless_map = graph.AddNode(
+ MakeStatelessMap(*map_node, *zip_node, *stateless_func, &graph));
+
+ graph.ReplaceInput(*map_node, *stateless_map);
+
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
+ nodes_to_delete.insert(map_node->name());
+ }
+
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void HoistRandomUniform::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(HoistRandomUniform, "hoist_random_uniform");
+
+} // end namespace grappler
+} // end namespace tensorflow
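Reviewer note: an illustrative sketch (not part of the patch) isolating the node-level rewrite performed by MakeLessStatefulFunction above; `seed_arg_name` stands in for the seed argument added by InsertSeedArgument.

#include <string>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/types.pb.h"

namespace tensorflow {
namespace grappler {

void RewriteToStatelessSketch(NodeDef* random_uniform,
                              const std::string& seed_arg_name) {
  // RandomUniform(shape) with "seed"/"seed2" attrs becomes
  // StatelessRandomUniform(shape, seed), where the seed pair is now an input
  // fed from the zipped RandomDataset instead of being baked into attributes.
  random_uniform->set_op("StatelessRandomUniform");
  random_uniform->add_input(seed_arg_name);
  (*random_uniform->mutable_attr())["Tseed"].set_type(DT_INT64);
  random_uniform->mutable_attr()->erase("seed");
  random_uniform->mutable_attr()->erase("seed2");
}

}  // namespace grappler
}  // namespace tensorflow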
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
new file mode 100644
index 0000000000..d1bcf6782d
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h
@@ -0,0 +1,55 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimization hoists instances of `random_uniform` out of a function
+// with the aim of making it stateless. It creates a new function that takes a
+// random seed as an extra argument and uses `stateless_random_uniform` instead
+// of `random_uniform` to make it stateless.
+// It also creates RandomDataset(seed).batch(2), which is zipped with the old
+// input to the map. The RandomDataset is batched by 2 because we need a pair
+// of seeds for `stateless_random_uniform`.
+// TODO(prazek): for now only `RandomUniform` is handled, but we could handle
+// `RandomUniformInt` similarly.
+class HoistRandomUniform : public CustomGraphOptimizer {
+ public:
+ HoistRandomUniform() = default;
+ ~HoistRandomUniform() override = default;
+
+ string name() const override { return "hoist_random_uniform"; };
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_HOIST_RANDOM_UNIFORM_H_
diff --git a/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
new file mode 100644
index 0000000000..455459e3f6
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/hoist_random_uniform_test.cc
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/hoist_random_uniform.h"
+
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+TEST(HoistRandomUniform, SimpleHoisting) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"},
+ {{"output_shapes", gtl::ArraySlice<TensorShape>{}},
+ {"output_types", gtl::ArraySlice<DataType>{}}}),
+ graph_tests_utils::MakeMapNode("map1", "range", "RandomUniform"),
+ NDef("cache", "CacheDataset", {"map1", "filename"}, {})},
+ // FunctionLib
+ {
+ test::function::RandomUniform(),
+ });
+
+ HoistRandomUniform optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ const int new_map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output);
+ const int zip_dataset_id =
+ graph_utils::FindGraphNodeWithOp("ZipDataset", output);
+ const int random_dataset_id =
+ graph_utils::FindGraphNodeWithOp("RandomDataset", output);
+ const int batch_random_id =
+ graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output);
+ ASSERT_NE(random_dataset_id, -1);
+ ASSERT_NE(zip_dataset_id, -1);
+ ASSERT_NE(new_map_id, -1);
+ ASSERT_NE(batch_random_id, -1);
+
+ const auto& new_map = output.node(new_map_id);
+ const auto& zip = output.node(zip_dataset_id);
+ const auto& random = output.node(random_dataset_id);
+ const auto& batch = output.node(batch_random_id);
+
+ ASSERT_EQ(new_map.input_size(), 1);
+ EXPECT_EQ(new_map.input(0), zip.name());
+
+ ASSERT_EQ(zip.input_size(), 2);
+ EXPECT_EQ(zip.input(0), "range");
+ EXPECT_EQ(zip.input(1), batch.name());
+
+ ASSERT_EQ(batch.input_size(), 3);
+ EXPECT_EQ(batch.input(0), random.name());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
index 0b25b1ea9d..9e382aeef9 100644
--- a/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
+++ b/tensorflow/core/grappler/optimizers/data/latency_all_edges.cc
@@ -33,7 +33,7 @@ namespace {
constexpr char kInsertOpName[] = "LatencyStatsDataset";
-NodeDef make_latency_node(const NodeDef& node, MutableGraphView* graph) {
+NodeDef MakeLatencyNode(const NodeDef& node, MutableGraphView* graph) {
NodeDef new_node;
new_node.set_op(kInsertOpName);
graph_utils::SetUniqueGraphNodeName(
@@ -96,7 +96,7 @@ Status LatencyAllEdges::Optimize(Cluster* cluster, const GrapplerItem& item,
}
}
- graph.InsertNode(node, make_latency_node(node, &graph));
+ graph.InsertNode(node, MakeLatencyNode(node, &graph));
}
return Status::OK();
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
index 3ce238a30a..e66766eb23 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion.cc
@@ -32,9 +32,8 @@ namespace {
constexpr char kFusedOpName[] = "MapAndBatchDatasetV2";
-NodeDef make_map_and_batch_node(const NodeDef& map_node,
- const NodeDef& batch_node,
- MutableGraphView* graph) {
+NodeDef MakeMapAndBatchNode(const NodeDef& map_node, const NodeDef& batch_node,
+ MutableGraphView* graph) {
NodeDef new_node;
new_node.set_op(kFusedOpName);
graph_utils::SetUniqueGraphNodeName(kFusedOpName, graph->GetGraph(),
@@ -81,11 +80,12 @@ NodeDef make_map_and_batch_node(const NodeDef& map_node,
// Set `f` and `Targuments` attributes.
for (auto key : {"f", "Targuments"}) {
- (*new_node.mutable_attr())[key] = map_node.attr().at(key);
+ graph_utils::CopyAttribute(key, map_node, &new_node);
}
+
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
- (*new_node.mutable_attr())[key] = batch_node.attr().at(key);
+ graph_utils::CopyAttribute(key, batch_node, &new_node);
}
return new_node;
}
@@ -104,8 +104,8 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// Use a more descriptive variable name now that we know the node type.
const NodeDef& batch_node = node;
- GraphView::InputPort input_port = graph.GetInputPort(batch_node.name(), 0);
- NodeDef* node2 = graph.GetRegularFanin(input_port).node;
+ NodeDef* node2 = graph_utils::GetInputNode(batch_node, graph);
+
if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
continue;
}
@@ -113,7 +113,7 @@ Status MapAndBatchFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
NodeDef* map_node = node2;
auto* new_node =
- graph.AddNode(make_map_and_batch_node(*map_node, batch_node, &graph));
+ graph.AddNode(MakeMapAndBatchNode(*map_node, batch_node, &graph));
graph.ReplaceInput(batch_node, *new_node);
// Mark the `Map` and `Batch` nodes for removal.
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
index a46c504ac4..b676246b31 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_batch_fusion_test.cc
@@ -85,8 +85,8 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchNodesIntoOne) {
EXPECT_FALSE(
graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
- NodeDef map_and_batch_node =
- output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+ NodeDef map_and_batch_node = output.node(
+ graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
EXPECT_EQ(map_and_batch_node.input_size(), 5);
EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
@@ -170,8 +170,8 @@ TEST(MapAndBatchFusionTest, FuseMapAndBatchV2NodesIntoOne) {
EXPECT_FALSE(
graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
- NodeDef map_and_batch_node =
- output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+ NodeDef map_and_batch_node = output.node(
+ graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
EXPECT_EQ(map_and_batch_node.input_size(), 5);
EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
@@ -253,8 +253,8 @@ TEST(MapAndBatchFusionTest, FuseParallelMapAndBatchNodesIntoOne) {
EXPECT_FALSE(
graph_utils::ContainsGraphNodeWithName(batch_node->name(), output));
EXPECT_TRUE(graph_utils::ContainsNodeWithOp("MapAndBatchDatasetV2", output));
- NodeDef map_and_batch_node =
- output.node(graph_utils::FindNodeWithOp("MapAndBatchDatasetV2", output));
+ NodeDef map_and_batch_node = output.node(
+ graph_utils::FindGraphNodeWithOp("MapAndBatchDatasetV2", output));
EXPECT_EQ(map_and_batch_node.input_size(), 5);
EXPECT_EQ(map_and_batch_node.input(0), map_node->input(0));
EXPECT_EQ(map_and_batch_node.input(1), map_node->input(1));
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
index 5e76c9f819..c4868eacbb 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -41,19 +42,18 @@ NodeDef MakeFusedNode(const NodeDef& map_node,
fused_node.set_op("MapDataset");
fused_node.add_input(map_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = map_node.attr().at("f");
attr.mutable_func()->set_name(fused_function.signature().name());
(*fused_node.mutable_attr())["f"] = std::move(attr);
- copy_attribute("Targuments", map_node, &fused_node);
+ graph_utils::CopyAttribute("Targuments", map_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, map_node, &fused_node);
+ graph_utils::CopyAttribute(key, map_node, &fused_node);
+
+ if (const auto* attr =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism"))
+ (*fused_node.mutable_attr())["use_inter_op_parallelism"] = *attr;
// Add the predicate output attributes.
(*fused_node.mutable_attr())["output_types"]
@@ -116,22 +116,25 @@ Status MapAndFilterFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
const auto& fun = filter_node->attr().at("predicate");
const FunctionDef* filter_func = function_library.Find(fun.func().name());
if (!fusion_utils::CanCompose(map_func->signature(),
- filter_func->signature()))
+ filter_func->signature())) {
+ VLOG(1) << "Can't fuse map and filter because the output signature of "
+ "the map function does not match the input signature of the "
+ "filter function\n";
return nullptr;
+ }
return fusion_utils::FuseFunctions(
*map_func, *filter_func, "fused_map_and_filter_function",
fusion_utils::CombineSignature, fusion_utils::ComposeInput,
- fusion_utils::CombineOutput, output->mutable_library());
+ fusion_utils::CombineOutput, fusion_utils::MergeNodes,
+ output->mutable_library());
};
for (const NodeDef& node : sorted_old_graph.node()) {
const NodeDef* filter_node = get_filter_node(node);
if (!filter_node) continue;
- GraphView::InputPort input_port =
- graph.GetInputPort(filter_node->name(), 0);
const NodeDef* map_node =
- get_map_node(*graph.GetRegularFanin(input_port).node);
+ get_map_node(*graph_utils::GetInputNode(*filter_node, graph));
if (!map_node) continue;
const auto* fused_function = make_fused_function(map_node, filter_node);
diff --git a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
index 027e0c1590..6e6da37d7c 100644
--- a/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_and_filter_fusion_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -27,24 +28,8 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
namespace {
-
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "MapDataset", {input_node_name.ToString()},
- {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
-
-NodeDef MakeFilterNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "FilterDataset", {input_node_name.ToString()},
- {{"predicate", FunctionDefHelper::FunctionRef("IsZero")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeFilterNode;
+using graph_tests_utils::MakeMapNode;
TEST(MapAndFilterFusionTest, FuseMapAndFilter) {
using test::function::NDef;
@@ -101,18 +86,18 @@ TEST(MapAndFilterFusionTest, FuseMapAndFilterWithExtraChild) {
graph_utils::ContainsNodeWithOp("FilterByLastComponentDataset", output));
ASSERT_TRUE(graph_utils::ContainsNodeWithOp("CacheDataset", output));
- int map_id = graph_utils::FindNodeWithOp("MapDataset", output);
+ int map_id = graph_utils::FindGraphNodeWithOp("MapDataset", output);
auto& map_node = output.node(map_id);
ASSERT_EQ(map_node.input_size(), 1);
EXPECT_EQ(map_node.input(0), "range");
int filter_by_component_id =
- graph_utils::FindNodeWithOp("FilterByLastComponentDataset", output);
+ graph_utils::FindGraphNodeWithOp("FilterByLastComponentDataset", output);
auto& filter_by_component = output.node(filter_by_component_id);
ASSERT_EQ(filter_by_component.input_size(), 1);
EXPECT_EQ(filter_by_component.input(0), map_node.name());
- int cache_id = graph_utils::FindNodeWithOp("CacheDataset", output);
+ int cache_id = graph_utils::FindGraphNodeWithOp("CacheDataset", output);
auto& cache_node = output.node(cache_id);
ASSERT_EQ(cache_node.input_size(), 2);
EXPECT_EQ(cache_node.input(0), filter_by_component.name());
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion.cc b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
index feb370eb9d..bd943342e8 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/protobuf.h"
namespace tensorflow {
@@ -40,24 +41,31 @@ NodeDef MakeFusedNode(const NodeDef& parent_map_node, const NodeDef& map_node,
NodeDef fused_node;
graph_utils::SetUniqueGraphNodeName("fused_map", graph->GetGraph(),
&fused_node);
-
fused_node.set_op("MapDataset");
fused_node.add_input(parent_map_node.input(0));
- auto copy_attribute = [](const string& attribute_name, const NodeDef& from,
- NodeDef* to) {
- (*to->mutable_attr())[attribute_name] = from.attr().at(attribute_name);
- };
-
auto attr = parent_map_node.attr().at("f");
*attr.mutable_func()->mutable_name() = fused_function.signature().name();
(*fused_node.mutable_attr())["f"] = std::move(attr);
- copy_attribute("Targuments", parent_map_node, &fused_node);
-
+ graph_utils::CopyAttribute("Targuments", parent_map_node, &fused_node);
for (auto key : {"output_shapes", "output_types"})
- copy_attribute(key, map_node, &fused_node);
+ graph_utils::CopyAttribute(key, map_node, &fused_node);
+ auto value_or_false = [](const AttrValue* attr) {
+ if (!attr) return false;
+ return attr->b();
+ };
+
+ const auto* first_parallelism =
+ gtl::FindOrNull(parent_map_node.attr(), "use_inter_op_parallelism");
+ const auto* second_parallelism =
+ gtl::FindOrNull(map_node.attr(), "use_inter_op_parallelism");
+  // Some graphs cannot execute with use_inter_op_parallelism=false, so we need
+  // to set it to true if either of the fused ops has it set to true.
+ if (value_or_false(first_parallelism) || value_or_false(second_parallelism)) {
+ (*fused_node.mutable_attr())["use_inter_op_parallelism"].set_b(true);
+ }
return fused_node;
}
@@ -90,21 +98,25 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
const auto& fun = map_node->attr().at("f");
const FunctionDef* func = function_library.Find(fun.func().name());
- if (!fusion_utils::CanCompose(parent_func->signature(), func->signature()))
+ if (!fusion_utils::CanCompose(parent_func->signature(),
+ func->signature())) {
+ VLOG(1) << "Can't fuse two maps because the output signature of the "
+ "first map function does not match the input signature of the "
+ "second function\n";
return nullptr;
+ }
return fusion_utils::FuseFunctions(
*parent_func, *func, "fused_map", fusion_utils::ComposeSignature,
fusion_utils::ComposeInput, fusion_utils::ComposeOutput,
- output->mutable_library());
+ fusion_utils::MergeNodes, output->mutable_library());
};
for (const NodeDef& node : sorted_old_graph.node()) {
const NodeDef* map_node = get_map_node(node);
if (!map_node) continue;
- GraphView::InputPort input_port = graph.GetInputPort(map_node->name(), 0);
const NodeDef* parent_map_node =
- get_map_node(*graph.GetRegularFanin(input_port).node);
+ get_map_node(*graph_utils::GetInputNode(*map_node, graph));
if (!parent_map_node) continue;
const auto* fused_function = get_fused_function(parent_map_node, map_node);
@@ -119,8 +131,8 @@ Status MapFusion::Optimize(Cluster* cluster, const GrapplerItem& item,
// fusion.
TF_RETURN_IF_ERROR(function_library.AddFunctionDef(*fused_function));
- // TODO(prazek): we could also remove map functions from library if they
- // are not used anymore.
+ // TODO(b/116285210): we could also remove map functions from library if
+ // they are not used anymore.
nodes_to_delete.insert(parent_map_node->name());
nodes_to_delete.insert(map_node->name());
}
diff --git a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
index df6c19dc7c..8889f9dddd 100644
--- a/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_fusion_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -28,14 +29,7 @@ namespace tensorflow {
namespace grappler {
namespace {
-NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name) {
- return test::function::NDef(
- name, "MapDataset", {input_node_name.ToString()},
- {{"f", FunctionDefHelper::FunctionRef("XTimesTwo")},
- {"Targuments", {}},
- {"output_shapes", {}},
- {"output_types", {}}});
-}
+using graph_tests_utils::MakeMapNode;
TEST(MapFusionTest, FuseTwoMapNodesIntoOne) {
using test::function::NDef;
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
new file mode 100644
index 0000000000..782c9f48b7
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.cc
@@ -0,0 +1,103 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h"
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+bool CanParallelize(const FunctionDef& function,
+ const FunctionLibraryDefinition& library) {
+ if (!function.signature().is_stateful()) return true;
+
+ for (const auto& node : function.node_def()) {
+ const OpDef* op_def;
+ TF_CHECK_OK(library.LookUpOpDef(node.op(), &op_def));
+    // Assert is marked as stateful, but it does not actually carry any state
+    // (beyond its I/O side effect). Similarly to CUDA asserts, we do not
+    // guarantee that the first failing Assert is the one reported, which lets
+    // us parallelize the map anyway.
+ if (op_def->is_stateful() && op_def->name() != "Assert") return false;
+ }
+
+ return true;
+}
+
+NodeDef MakeParallelMap(const NodeDef& map_node, MutableGraphView* graph) {
+ NodeDef parallel_map = map_node;
+ graph_utils::SetUniqueGraphNodeName("parallel_map", graph->GetGraph(),
+ &parallel_map);
+ parallel_map.set_op("ParallelMapDataset");
+ // TODO(b/114475558): We want to set `num_parallel_calls` to a special value,
+  // so that dynamic tuning will pick the optimal value at runtime. Because
+ // this feature is not yet implemented, we set it to 2, which is the smallest
+ // value that introduces parallelism.
+ auto* num_parallel_calls = graph_utils::AddScalarConstNode(2, graph);
+ parallel_map.add_input(num_parallel_calls->name());
+
+ return parallel_map;
+}
+
+} // namespace
+
+Status MapParallelization::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+ FunctionLibraryDefinition function_library(OpRegistry::Global(),
+ item.graph.library());
+ auto get_map_node = [](const NodeDef& node) -> const NodeDef* {
+ if (node.op() == "MapDataset") return &node;
+ return nullptr;
+ };
+
+ for (const NodeDef& node : item.graph.node()) {
+ const NodeDef* map_node = get_map_node(node);
+ if (!map_node) continue;
+
+ auto* function =
+ function_library.Find(map_node->attr().at("f").func().name());
+ if (!CanParallelize(*function, function_library)) continue;
+
+ auto* parallel_map = graph.AddNode(MakeParallelMap(*map_node, &graph));
+ graph.ReplaceInput(*map_node, *parallel_map);
+ nodes_to_delete.insert(map_node->name());
+ }
+
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void MapParallelization::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(MapParallelization, "map_parallelization");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization.h b/tensorflow/core/grappler/optimizers/data/map_parallelization.h
new file mode 100644
index 0000000000..ac9cf7e12a
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization.h
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
+
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// This optimization parallelizes MapDataset when the function it maps is
+// stateless.
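+//
+// As a rough sketch (node names are illustrative, not taken from this change),
+// a pipeline of the form
+//   range -> MapDataset(f)
+// is rewritten into
+//   range -> ParallelMapDataset(f, num_parallel_calls)
+// where `num_parallel_calls` is a newly added Const node with value 2 until
+// dynamic tuning is available.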
+class MapParallelization : public CustomGraphOptimizer {
+ public:
+ MapParallelization() = default;
+ ~MapParallelization() override = default;
+
+ string name() const override { return "map_parallelization"; }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output, double result) override;
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_PARALLELIZATION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
new file mode 100644
index 0000000000..9fdfe8af30
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_parallelization_test.cc
@@ -0,0 +1,85 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_parallelization.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_test_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+using graph_tests_utils::MakeMapNode;
+const char stateless_fun_name[] = "XTimesTwo";
+const char stateful_fun_name[] = "RandomUniform";
+
+TEST(MapParallelizationTest, ParallelizeSimpleMap) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeMapNode("map1", "range", stateless_fun_name)},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+
+ MapParallelization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
+}
+
+TEST(MapParallelizationTest, ParallelizeOnlyStatelessMaps) {
+ using test::function::NDef;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("filename", "Const", {}, {{"value", ""}, {"dtype", DT_STRING}}),
+ NDef("range", "RangeDataset", {"start", "stop", "step"}, {}),
+ MakeMapNode("map1", "range", stateful_fun_name),
+ MakeMapNode("map2", "map1", stateless_fun_name),
+ NDef("cache", "CacheDataset", {"map2", "filename"}, {})},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ test::function::RandomUniform(),
+ });
+
+ MapParallelization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ EXPECT_TRUE(graph_utils::ContainsNodeWithOp("ParallelMapDataset", output));
+ EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName("map1", output));
+ EXPECT_FALSE(graph_utils::ContainsGraphNodeWithName("map2", output));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
new file mode 100644
index 0000000000..a9254ed58b
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -0,0 +1,282 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/grappler/clusters/cluster.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+// Returns a FunctionDef containing a MapDefun op that wraps the original
+// function.
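+//
+// For example (a sketch with illustrative names and abbreviated attrs), an
+// original function f: (x: T) -> (y: U) is wrapped as a new function whose
+// single node is
+//   MapDefun(f, Targuments = [T], output_types = [U], output_shapes = [...])
+// i.e. the wrapper applies f across dimension 0 of its now-batched inputs.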
+FunctionDef* CreateMapDefunWrapper(const NodeDef& map_node,
+ const FunctionDef& orig_func,
+ FunctionDefLibrary* library) {
+ FunctionDef* vectorized_func = library->add_function();
+ // Function inputs and outputs are the same as the original's, just
+ // with different (batched) shapes.
+ *vectorized_func->mutable_signature() = orig_func.signature();
+ graph_utils::SetUniqueGraphFunctionName("naively_vectorized_fn", library,
+ vectorized_func);
+
+ // Add MapDefun node
+ NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Add();
+ map_defun_node->set_op("MapDefun");
+ function_utils::SetUniqueFunctionNodeName(map_defun_node->op(),
+ vectorized_func, map_defun_node);
+
+ // Set attrs and inputs
+ for (const string& k : {"f", "output_types", "output_shapes"}) {
+ // Function, output types and (unbatched) shapes are the same as the
+ // original map node.
+ graph_utils::CopyAttribute(k, map_node, map_defun_node);
+ }
+
+ // Get types of input arguments from original map function
+ AttrValue t_args;
+ for (const auto& input : vectorized_func->signature().input_arg()) {
+ t_args.mutable_list()->add_type(input.type());
+ map_defun_node->add_input(input.name());
+ }
+ (*map_defun_node->mutable_attr())["Targuments"] = t_args;
+ AddNodeAttr("Tcaptured", DataTypeVector(), map_defun_node);
+
+ // Set return values to match output names
+ string output_prefix = strings::StrCat(map_defun_node->name(), ":output:");
+ for (size_t i = 0; i < vectorized_func->signature().output_arg_size(); ++i) {
+ const auto& output_arg = vectorized_func->signature().output_arg(i);
+ (*vectorized_func->mutable_ret())[output_arg.name()] =
+ strings::StrCat(output_prefix, i);
+ }
+
+ return vectorized_func;
+}
+
+FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
+ const FunctionDef& orig_func,
+ FunctionDefLibrary* library) {
+ // Vectorizes orig_func naively by wrapping in a MapDefun op, then performing
+ // efficient vectorization with VectorizeMapDefun.
+ FunctionDef* vectorized_func =
+ CreateMapDefunWrapper(map_node, orig_func, library);
+ const NodeDef& map_defun_node = vectorized_func->node_def(0);
+ DCHECK_EQ(map_defun_node.op(), "MapDefun");
+
+ // TODO(b/116285210): Unreferenced functions should get cleaned up later
+ FunctionDef* result;
+ Status s = vectorization_utils::VectorizeMapDefun(
+ *vectorized_func, map_defun_node, library, &result);
+
+ if (!s.ok()) {
+ LOG(ERROR) << "VectorizeMapDefun failed: " << s;
+ return vectorized_func;
+ }
+ return result;
+}
+
+bool IsOutputShapesFullyDefined(const NodeDef& node) {
+ auto* shapes_attr = gtl::FindOrNull(node.attr(), "output_shapes");
+ if (shapes_attr == nullptr) return false;
+ const auto& shapes = shapes_attr->list().shape();
+
+ for (const TensorShapeProto& shape : shapes) {
+ for (const auto& dim : shape.dim()) {
+ if (dim.size() == -1) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+bool IsStatefulFn(const FunctionLibraryDefinition& library,
+ const FunctionDef& function_def) {
+ for (const NodeDef& node_def : function_def.node_def()) {
+ const OpDef* op_def;
+ Status s = library.LookUpOpDef(node_def.op(), &op_def);
+ if (!s.ok() || op_def->is_stateful()) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool HasCapturedInputs(const NodeDef& map_node) {
+ return map_node.attr().at("Targuments").list().type_size() > 0;
+}
+
+NodeDef MakeNewBatchNode(const NodeDef& old_batch_node,
+ const NodeDef& input_node,
+ const FunctionDef& vectorized_func,
+ MutableGraphView* graph) {
+ NodeDef batch_node;
+ batch_node.set_op(old_batch_node.op());
+ graph_utils::SetUniqueGraphNodeName(batch_node.op(), graph->GetGraph(),
+ &batch_node);
+
+ // Set the `input_dataset` input argument
+ batch_node.add_input(input_node.name());
+ // Set the `batch_size` input_argument
+ batch_node.add_input(old_batch_node.input(1));
+ if (batch_node.op() == "BatchDatasetV2") {
+ // Set the `drop_remainder` input argument
+ batch_node.add_input(old_batch_node.input(2));
+ }
+
+ // Set attrs
+ AttrValue output_types;
+ for (const auto& input : vectorized_func.signature().input_arg()) {
+ output_types.mutable_list()->add_type(input.type());
+ }
+ (*batch_node.mutable_attr())["output_types"] = output_types;
+
+ auto& output_shapes_attr = (*batch_node.mutable_attr())["output_shapes"];
+ const auto& input_shapes =
+ input_node.attr().at("output_shapes").list().shape();
+ int64 batch_size =
+ old_batch_node.attr().at("output_shapes").list().shape()[0].dim(0).size();
+ for (size_t i = 0; i < input_shapes.size(); ++i) {
+ TensorShapeProto* shape = output_shapes_attr.mutable_list()->add_shape();
+ TensorShapeProto_Dim* dim = shape->add_dim();
+ dim->set_size(batch_size);
+ shape->MergeFrom(input_shapes.Get(i));
+ }
+ return batch_node;
+}
+
+NodeDef MakeNewMapNode(const NodeDef& old_map_node,
+ const NodeDef& old_batch_node,
+ const NodeDef& new_batch_node,
+ const FunctionDef& vectorized_func,
+ MutableGraphView* graph) {
+ NodeDef map_node;
+ map_node.set_op(old_map_node.op());
+ graph_utils::SetUniqueGraphNodeName(map_node.op(), graph->GetGraph(),
+ &map_node);
+
+ // Set the `input_dataset` input argument
+ map_node.add_input(new_batch_node.name());
+ for (int i = 1; i < old_map_node.input_size(); i++) {
+ // Set the `other_arguments` and `num_parallel_calls` input arguments
+ map_node.add_input(old_map_node.input(i));
+ }
+
+ // Set attrs
+ graph_utils::CopyAttribute("Targuments", old_map_node, &map_node);
+ auto& func_attr = (*map_node.mutable_attr())["f"];
+ func_attr.mutable_func()->set_name(vectorized_func.signature().name());
+
+ for (auto key : {"output_shapes", "output_types"}) {
+ graph_utils::CopyAttribute(key, old_batch_node, &map_node);
+ }
+
+ (*map_node.mutable_attr())["use_inter_op_parallelism"].set_b(true);
+
+ return map_node;
+}
+
+} // namespace
+
+Status MapVectorization::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* output) {
+ *output = item.graph;
+ MutableGraphView graph(output);
+ std::set<string> nodes_to_delete;
+
+ for (const NodeDef& node : item.graph.node()) {
+ // Find Map->Batch nodes.
+ // TODO(rachelim): Optimize MapAndBatchDataset[V2] as well.
+ if (node.op() != "BatchDataset" && node.op() != "BatchDatasetV2") {
+ continue;
+ }
+
+ const NodeDef& batch_node(node);
+ NodeDef* node2 = graph_utils::GetInputNode(batch_node, graph);
+ if (node2->op() != "MapDataset" && node2->op() != "ParallelMapDataset") {
+ continue;
+ }
+
+ // Use a more descriptive variable name now that we know the node type.
+ NodeDef* map_node = node2;
+ // Input to the map node
+ NodeDef* input_node = graph_utils::GetInputNode(*map_node, graph);
+ CHECK_NOTNULL(input_node);
+
+ FunctionDefLibrary* library = output->mutable_library();
+
+ FunctionLibraryDefinition function_library(OpRegistry::Global(), *library);
+ const FunctionDef* orig_func =
+ function_library.Find(map_node->attr().at("f").func().name());
+
+ // Check that this is a valid optimization.
+ if (!IsOutputShapesFullyDefined(*input_node) ||
+ !IsOutputShapesFullyDefined(*map_node) ||
+ IsStatefulFn(function_library, *orig_func) ||
+ HasCapturedInputs(*map_node)) {
+ // 1. If any of the inputs have an unknown shape, don't optimize, since
+ // inputs might not be batchable.
+ // 2. If any of the map func outputs have an unknown shape, don't
+ // optimize, so that batching errors surface as before.
+ // 3. If the function is stateful, don't vectorize it.
+ // 4. TODO(rachelim): Make this work for MapDataset with captured inputs
+ // by tiling inputs or modifying the signature of MapDefun.
+ continue;
+ }
+
+ FunctionDef* vectorized_func =
+ AddVectorizedFunction(*map_node, *orig_func, library);
+ CHECK_NOTNULL(vectorized_func);
+
+ auto* new_batch_node = graph.AddNode(
+ MakeNewBatchNode(batch_node, *input_node, *vectorized_func, &graph));
+
+ auto* new_map_node = graph.AddNode(MakeNewMapNode(
+ *map_node, batch_node, *new_batch_node, *vectorized_func, &graph));
+ graph.ReplaceInput(batch_node, *new_map_node);
+
+ // Mark the `Map` and `Batch` nodes for removal.
+ nodes_to_delete.insert(map_node->name());
+ nodes_to_delete.insert(batch_node.name());
+ }
+ graph.DeleteNodes(nodes_to_delete);
+ return Status::OK();
+}
+
+void MapVectorization::Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimize_output,
+ double result) {
+ // no-op
+}
+
+REGISTER_GRAPH_OPTIMIZER_AS(MapVectorization, "map_vectorization");
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/function_rename.h b/tensorflow/core/grappler/optimizers/data/map_vectorization.h
index 23ad9470ff..cc56a8ee5e 100644
--- a/tensorflow/core/grappler/optimizers/data/function_rename.h
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.h
@@ -13,20 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
namespace tensorflow {
namespace grappler {
-class FunctionRename : public CustomGraphOptimizer {
+class MapVectorization : public CustomGraphOptimizer {
public:
- FunctionRename() = default;
- ~FunctionRename() override = default;
+ MapVectorization() = default;
+ ~MapVectorization() override = default;
- string name() const override { return "_test_only_function_rename"; };
+ string name() const override { return "map_vectorization"; }
Status Init(
const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
@@ -43,4 +43,4 @@ class FunctionRename : public CustomGraphOptimizer {
} // end namespace grappler
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_FUNCTION_RENAME_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_MAP_VECTORIZATION_H_
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
new file mode 100644
index 0000000000..f4faf41549
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
@@ -0,0 +1,211 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/map_vectorization.h"
+
+#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+using test::function::GDef;
+using test::function::NDef;
+
+NodeDef MakeMapNodeHelper(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name, StringPiece map_op_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return test::function::NDef(
+ name, map_op_name, {string(input_node_name)},
+ {{"f", FunctionDefHelper::FunctionRef(string(function_name))},
+ {"Targuments", {}},
+ {"output_shapes", output_shapes},
+ {"output_types", output_types}});
+}
+
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return MakeMapNodeHelper(name, input_node_name, function_name, "MapDataset",
+ output_shapes, output_types);
+}
+
+NodeDef MakeBatchNode(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDataset",
+ {string(input_node_name), string(input_batch_size_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
+}
+
+NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ StringPiece input_drop_remainder_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDatasetV2",
+ {string(input_node_name), string(input_batch_size_name),
+ string(input_drop_remainder_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
+}
+
+NodeDef MakeRangeNode(StringPiece name, gtl::ArraySlice<string> inputs) {
+ return NDef(name, "RangeDataset", inputs,
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})},
+ {"output_types", gtl::ArraySlice<DataType>({DT_INT64})}});
+}
+
+TEST(MapVectorizationTest, VectorizeMapWithBatch) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ MakeRangeNode("range", {"start", "stop", "step"}),
+ MakeMapNode("map", "range", "XTimesTwo", {{}}, {DT_INT32}),
+ MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
+ 1);
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(),
+ 1);
+ const NodeDef& map_node =
+ output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
+ const NodeDef& batch_node =
+ output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output));
+ EXPECT_EQ(map_node.input(0), batch_node.name());
+ EXPECT_EQ(batch_node.input(0), "range");
+}
+
+TEST(MapVectorizationTest, VectorizeMapWithBatchV2) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("drop_remainder", "Const", {},
+ {{"value", false}, {"dtype", DT_BOOL}}),
+ MakeRangeNode("range", {"start", "stop", "step"}),
+ MakeMapNode("map", "range", "XTimesTwo", {{}}, {DT_INT32}),
+ MakeBatchV2Node("batch", "map", "batch_size", "drop_remainder", {{-1}},
+ {DT_INT32})},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
+ 1);
+ EXPECT_EQ(
+ graph_utils::FindAllGraphNodesWithOp("BatchDatasetV2", output).size(), 1);
+ const NodeDef& map_node =
+ output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
+ const NodeDef& batch_node =
+ output.node(graph_utils::FindGraphNodeWithOp("BatchDatasetV2", output));
+ EXPECT_EQ(map_node.input(0), batch_node.name());
+ EXPECT_EQ(batch_node.input(0), "range");
+}
+
+TEST(MapVectorizationTest, VectorizeWithUndefinedOutputShape) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("input", "InputDataset", {},
+ {{"output_types", gtl::ArraySlice<DataType>({DT_INT32})}}),
+ MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
+ MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+}
+
+TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("input", "InputDataset", {},
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})}}),
+ MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
+ MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
+ // FunctionLib
+ {
+ test::function::XTimesTwo(),
+ });
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+}
+
+TEST(MapVectorizationTest, VectorizeWithFullyDefinedFunction) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ MakeRangeNode("range", {"start", "stop", "step"}),
+ MakeMapNode("map", "range", "Func", {{}}, {DT_INT32}),
+ MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
+ // FunctionLib
+ {FunctionDefHelper::Create(
+ "Func", {"x: int64", "y: int64"}, {"res: int64", "res2: int64"}, {},
+ {{{"o"}, "Mul", {"x", "x"}, {{"T", DT_INT64}}}},
+ {{"res", "o:z"}, {"res2", "o:z"}})});
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
+ 1);
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(),
+ 1);
+ const NodeDef& map_node =
+ output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
+ const NodeDef& batch_node =
+ output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output));
+ EXPECT_EQ(map_node.input(0), batch_node.name());
+ EXPECT_EQ(batch_node.input(0), "range");
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
index 55d57b3b97..cf5a19bab1 100644
--- a/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination.cc
@@ -33,25 +33,27 @@ namespace {
bool IsTakeAll(const NodeDef& take_node, const GraphView& graph) {
if (take_node.op() != "TakeDataset") return false;
- const NodeDef& count_node = *graph.GetNode(take_node.input(1));
+ const auto& count_node = *graph.GetNode(take_node.input(1));
+ if (count_node.op() != "Const") return false;
// We are looking only for 'take' with negative count.
return count_node.attr().at("value").tensor().int64_val(0) < 0;
}
+bool IsConstNodeWithValue(const NodeDef& node, int value) {
+ if (node.op() != "Const") return false;
+ return node.attr().at("value").tensor().int64_val(0) == value;
+}
+
bool IsSkipNone(const NodeDef& skip_node, const GraphView& graph) {
if (skip_node.op() != "SkipDataset") return false;
-
- const NodeDef& count_node = *graph.GetNode(skip_node.input(1));
// We are looking only for skip(0) nodes.
- return count_node.attr().at("value").tensor().int64_val(0) == 0;
+ return IsConstNodeWithValue(*graph.GetNode(skip_node.input(1)), 0);
}
bool IsRepeatOne(const NodeDef& repeat_node, const GraphView& graph) {
if (repeat_node.op() != "RepeatDataset") return false;
-
- const NodeDef& count_node = *graph.GetNode(repeat_node.input(1));
// We are looking only for repeat(1) nodes.
- return count_node.attr().at("value").tensor().int64_val(0) == 1;
+ return IsConstNodeWithValue(*graph.GetNode(repeat_node.input(1)), 1);
}
bool IsNoOp(const NodeDef& node, const GraphView& graph) {
@@ -69,8 +71,7 @@ Status NoOpElimination::Optimize(Cluster* cluster, const GrapplerItem& item,
for (const NodeDef& node : item.graph.node()) {
if (!IsNoOp(node, graph)) continue;
- GraphView::InputPort input_port = graph.GetInputPort(node.name(), 0);
- NodeDef* const parent = graph.GetRegularFanin(input_port).node;
+ NodeDef* const parent = graph_utils::GetInputNode(node, graph);
graph.ReplaceInput(node, *parent);
nodes_to_delete.insert(node.name());
diff --git a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
index f445e75aa7..be1a66df75 100644
--- a/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/noop_elimination_test.cc
@@ -43,6 +43,14 @@ NodeDef *MakeUnaryNode(StringPiece node_type, int count, string input_node,
GetCommonAttributes(), graph);
}
+NodeDef *MakeUnaryNonConstNode(StringPiece node_type, string input_node,
+ MutableGraphView *graph) {
+ NodeDef *node_count = graph_utils::AddScalarPlaceholder(DT_INT32, graph);
+ return graph_utils::AddNode("", node_type,
+ {std::move(input_node), node_count->name()},
+ GetCommonAttributes(), graph);
+}
+
NodeDef *MakeCacheNode(string input_node, MutableGraphView *graph) {
NodeDef *node_filename =
graph_utils::AddScalarConstNode<StringPiece>("", graph);
@@ -205,6 +213,41 @@ INSTANTIATE_TEST_CASE_P(
::testing::Values(*kTakeNode, *kSkipNode,
*kRepeatNode)));
+struct NoOpPlaceholdersTest
+ : ::testing::TestWithParam<std::tuple<string, string>> {};
+
+TEST_P(NoOpPlaceholdersTest, NonConstNoOpNode) {
+ GrapplerItem item;
+ MutableGraphView graph(&item.graph);
+
+ static_assert(std::tuple_size<NodesTypes>::value == 2,
+ "Make sure to include everything in the test");
+ const std::vector<string> noop_nodes = {std::get<0>(GetParam()),
+ std::get<1>(GetParam())};
+ NodeDef *range_node = MakeRangeNode(&graph);
+ std::vector<string> nodes_to_keep;
+ nodes_to_keep.reserve(noop_nodes.size());
+ NodeDef *previous = range_node;
+
+ for (const auto &noop_node : noop_nodes) {
+ NodeDef *node = MakeUnaryNonConstNode(noop_node, previous->name(), &graph);
+ nodes_to_keep.push_back(node->name());
+ previous = node;
+ }
+
+ NoOpElimination optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+ for (const auto &noop_node_name : nodes_to_keep)
+ EXPECT_TRUE(graph_utils::ContainsGraphNodeWithName(noop_node_name, output));
+}
+
+INSTANTIATE_TEST_CASE_P(
+ DoNotRemovePlaceholders, NoOpPlaceholdersTest,
+ ::testing::Combine(
+ ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset"),
+ ::testing::Values("TakeDataset", "SkipDataset", "RepeatDataset")));
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
index 7c7161c5b2..99c4afa634 100644
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion.cc
@@ -64,7 +64,7 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
// Set `output_types` and `output_shapes` attributes.
for (auto key : {"output_shapes", "output_types"}) {
- (*new_node.mutable_attr())[key] = repeat_node.attr().at(key);
+ graph_utils::CopyAttribute(key, repeat_node, &new_node);
}
return new_node;
};
@@ -76,8 +76,8 @@ Status ShuffleAndRepeatFusion::Optimize(Cluster* cluster,
// Use a more descriptive variable name now that we know the node type.
const NodeDef& repeat_node = node;
- GraphView::InputPort input_port = graph.GetInputPort(repeat_node.name(), 0);
- NodeDef* node2 = graph.GetRegularFanin(input_port).node;
+ NodeDef* node2 = graph_utils::GetInputNode(repeat_node, graph);
+
if (node2->op() != "ShuffleDataset") {
continue;
}
diff --git a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc
index a2e470e511..f0696eb76d 100644
--- a/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/shuffle_and_repeat_fusion_test.cc
@@ -78,7 +78,7 @@ TEST(ShuffleAndRepeatFusionTest, FuseShuffleAndRepeatNodesIntoOne) {
EXPECT_TRUE(
graph_utils::ContainsNodeWithOp("ShuffleAndRepeatDataset", output));
NodeDef shuffle_and_repeat_node = output.node(
- graph_utils::FindNodeWithOp("ShuffleAndRepeatDataset", output));
+ graph_utils::FindGraphNodeWithOp("ShuffleAndRepeatDataset", output));
EXPECT_EQ(shuffle_and_repeat_node.input_size(), 5);
EXPECT_EQ(shuffle_and_repeat_node.input(0), shuffle_node->input(0));
EXPECT_EQ(shuffle_and_repeat_node.input(1), shuffle_node->input(1));
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
new file mode 100644
index 0000000000..37aa24b947
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -0,0 +1,70 @@
+package(
+ default_visibility = ["//visibility:private"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
+
+VECTORIZER_DEPS = [
+ ":vectorizer_registry",
+ "//tensorflow/core/grappler/optimizers/data:graph_utils",
+] + tf_protos_all()
+
+cc_library(
+ name = "vectorizer",
+ hdrs = ["vectorizer.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:lib",
+ ] + tf_protos_all(),
+)
+
+cc_library(
+ name = "vectorizer_registry",
+ srcs = ["vectorizer_registry.cc"],
+ hdrs = ["vectorizer_registry.h"],
+ deps = [
+ ":vectorizer",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
+ name = "cast_vectorizer",
+ srcs = ["cast_vectorizer.cc"],
+ deps = VECTORIZER_DEPS,
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "unpack_vectorizer",
+ srcs = ["unpack_vectorizer.cc"],
+ deps = VECTORIZER_DEPS,
+ alwayslink = 1,
+)
+
+cc_library(
+ name = "vectorization",
+ hdrs = ["vectorizer_registry.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ ":cast_vectorizer",
+ ":unpack_vectorizer",
+ ":vectorizer",
+ ":vectorizer_registry",
+ ],
+)
+
+tf_cc_test(
+ name = "vectorizer_registry_test",
+ srcs = ["vectorizer_registry_test.cc"],
+ deps = [
+ ":vectorizer_registry",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ] + tf_protos_all(),
+)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
new file mode 100644
index 0000000000..3af6bab409
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -0,0 +1,49 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+class CastVectorizer : public Vectorizer {
+ public:
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) override {
+ Status s;
+ if (node.num_inputs() != 1) {
+ return errors::Internal("Cast op should only have one input.");
+ }
+
+ // Add new Cast node with the same op and attrs as the original node
+ auto new_cast_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
+
+ // Add input and output mappings
+ input_ports->push_back({new_cast_node, 0});
+ output_ports->push_back({new_cast_node, 0});
+ return Status::OK();
+ }
+};
+
+REGISTER_VECTORIZER("Cast", CastVectorizer);
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
new file mode 100644
index 0000000000..74ce520ce1
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+class UnpackVectorizer : public Vectorizer {
+ public:
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) override {
+ Status s;
+ if (node.num_inputs() != 1) {
+ return errors::Internal("Unpack op should only have one input.");
+ }
+
+ // Add new Unpack node with the same op and attrs as the original node
+ auto new_unpack_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
+
+ // Increment "axis" attr by 1:
+ int new_axis = node.def().attr().at("axis").i() + 1;
+ new_unpack_node->AddAttr("axis", new_axis);
+
+ // Add the input mappings
+ input_ports->push_back({new_unpack_node, 0});
+
+ // Add the output mappings
+ int num = node.def().attr().at("num").i();
+ for (int i = 0; i < num; ++i) {
+ output_ports->push_back({new_unpack_node, i});
+ }
+
+ return Status::OK();
+ }
+};
+
+REGISTER_VECTORIZER("Unpack", UnpackVectorizer);
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
new file mode 100644
index 0000000000..56eb88c95e
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Describes a tensor with its operation Node and output position
+typedef std::pair<Node*, int> Port;
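+// For example, {some_node, 0} refers to the first (0th) output of `some_node`.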
+
+// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
+// for an example.
+class Vectorizer {
+ public:
+ virtual ~Vectorizer() {}
+
+ // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope`
+ // that produce the same vector output(s) as executing `node`'s op
+ // on elements of the vector inputs. The new Node(s) collectively have the
+ // same number of input and output ports as the node being converted.
+ // Adds mappings for the new nodes' input and output ports to `input_ports`
+ // and `output_ports` respectively, where the i'th Port in
+ // input_ports/output_ports corresponds to the i'th input/output port of the
+ // node to be converted.
+ virtual Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) = 0;
+};
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
new file mode 100644
index 0000000000..a6551e36ac
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.cc
@@ -0,0 +1,47 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+VectorizerRegistry* VectorizerRegistry::Global() {
+ static VectorizerRegistry* registry = new VectorizerRegistry;
+ return registry;
+}
+
+Vectorizer* VectorizerRegistry::Get(const string& op_type) {
+ auto found = vectorizers_.find(op_type);
+ if (found == vectorizers_.end()) {
+ return nullptr;
+ }
+ return found->second.get();
+}
+
+void VectorizerRegistry::Register(const string& op_type,
+ std::unique_ptr<Vectorizer> vectorizer) {
+ auto existing = Get(op_type);
+ CHECK_EQ(existing, nullptr)
+ << "Vectorizer for op type: " << op_type << " already registered";
+ vectorizers_.emplace(op_type, std::move(vectorizer));
+}
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
new file mode 100644
index 0000000000..16159d47ca
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h
@@ -0,0 +1,75 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
+
+#include <functional>
+#include <map>
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// A global VectorizerRegistry is used to hold all the vectorizers.
+class VectorizerRegistry {
+ public:
+ // Returns a pointer to a global VectorizerRegistry object.
+ static VectorizerRegistry* Global();
+
+ // Returns the vectorizer registered for `op_type`, or nullptr if none exists.
+ Vectorizer* Get(const string& op_type);
+
+ // Registers a vectorizer that can vectorize an op for the given op type.
+ void Register(const string& op_type, std::unique_ptr<Vectorizer> vectorizer);
+
+ private:
+ std::map<string, std::unique_ptr<Vectorizer>> vectorizers_;
+};
+
+namespace vectorizer_registration {
+
+class VectorizerRegistration {
+ public:
+ VectorizerRegistration(const string& op_type,
+ std::unique_ptr<Vectorizer> vectorizer) {
+ VectorizerRegistry::Global()->Register(op_type, std::move(vectorizer));
+ }
+};
+
+} // namespace vectorizer_registration
+
+#define REGISTER_VECTORIZER(op_type, vectorizer) \
+ REGISTER_VECTORIZER_UNIQ_HELPER(__COUNTER__, op_type, vectorizer)
+
+#define REGISTER_VECTORIZER_UNIQ_HELPER(ctr, op_type, vectorizer) \
+ REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer)
+
+#define REGISTER_VECTORIZER_UNIQ(ctr, op_type, vectorizer) \
+ static ::tensorflow::grappler::vectorization_utils:: \
+ vectorizer_registration::VectorizerRegistration \
+ vectorizer_registration_##ctr( \
+ op_type, \
+ ::std::unique_ptr< \
+ ::tensorflow::grappler::vectorization_utils::Vectorizer>( \
+ new vectorizer()))
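+
+// Example usage (a sketch; `MyOpVectorizer` stands for a user-defined
+// Vectorizer subclass such as the CastVectorizer added in this change):
+//
+//   REGISTER_VECTORIZER("MyOp", MyOpVectorizer);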
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_REGISTRY_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
new file mode 100644
index 0000000000..663ceba027
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+class TestVectorizer : public Vectorizer {
+ public:
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* inputs,
+ std::vector<Port>* outputs) override {
+ return Status::OK();
+ }
+};
+
+REGISTER_VECTORIZER("test_op", TestVectorizer);
+
+TEST(TestVectorizer, TestTestVectorizer) {
+ EXPECT_EQ(VectorizerRegistry::Global()->Get("nonexistent"), nullptr);
+
+ auto vectorizer = VectorizerRegistry::Global()->Get("test_op");
+ EXPECT_NE(vectorizer, nullptr);
+
+ Graph g(OpRegistry::Global());
+ // Use a registered op ("NoOp") so that adding the node to the graph succeeds.
+ NodeDef node_def;
+ node_def.set_name("test_node");
+ node_def.set_op("NoOp");
+ Status s;
+ Node* node = g.AddNode(node_def, &s);
+ ASSERT_TRUE(s.ok());
+ std::vector<Port> inputs, outputs;
+ EXPECT_TRUE(vectorizer->Vectorize(*node, &g, &inputs, &outputs).ok());
+}
+
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
new file mode 100644
index 0000000000..344c420902
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -0,0 +1,612 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
+
+#include "absl/strings/str_join.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/grappler/mutable_graph_view.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/functions.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+namespace {
+
+// Describes a tensor with its operation Node and output position
+typedef std::pair<Node*, int> TensorDesc;
+
+// Equivalent to python Pfor's WrappedTensor struct
+struct WrappedTensor {
+ TensorDesc tensor;
+
+ // Whether the tensor is stacked, i.e. represents the results of applying
+ // the operation on all slices of the input, where each row i of the
+ // tensor corresponds to the op's output on slice i of the input. False
+ // if the tensor is not stacked, i.e. represents the result of the op on
+ // a single slice of the input, where the result does not vary between
+ // slices.
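+ // For example (shapes are illustrative): when the MapDefun maps over n
+ // slices, a stacked tensor has shape [n, d1, d2, ...], while the
+ // corresponding unstacked tensor has shape [d1, d2, ...].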
+ bool stacked;
+
+ WrappedTensor(TensorDesc&& tensor, bool stacked)
+ : tensor(std::move(tensor)), stacked(stacked) {}
+};
+
+const char* const kRetValOp = "_Retval";
+
+void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
+ Graph* graph) {
+ // NOTE: We need two for loops here because we can't mutate the set of output
+ // edges as we iterate over them.
+ std::vector<const Edge*> edges_to_replace;
+ for (auto edge : old_src.first->out_edges()) {
+ if (edge->src_output() == old_src.second) {
+ edges_to_replace.push_back(edge);
+ }
+ }
+ for (auto edge : edges_to_replace) {
+ graph->AddEdge(new_src.first, new_src.second, edge->dst(),
+ edge->dst_input());
+ graph->RemoveEdge(edge);
+ }
+}
+
+Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
+ const TensorDesc& output) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DataType type = output.first->output_type(output.second);
+ int index = map_defun_fn->ret_nodes.size();
+
+ NodeDef ret_node_def;
+ ret_node_def.set_name("map_out");
+ ret_node_def.set_op(kRetValOp);
+ AddNodeAttr("T", type, &ret_node_def);
+ AddNodeAttr("index", index, &ret_node_def);
+
+ Status s;
+ Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s);
+ TF_RETURN_IF_ERROR(s);
+
+ map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
+ map_defun_fn->ret_nodes.push_back(ret_node);
+ map_defun_fn->ret_types.push_back(type);
+
+ return s;
+}
+
+void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
+ FunctionBody* map_defun_fn, Node* map_defun_node) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
+ << "Trying to remove output that doesn't exist. Output number: "
+ << output_position;
+
+ int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1;
+
+ // Modify map_defun_fn's signature and remove the output node from its graph
+ map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]);
+ map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() +
+ output_position);
+ map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
+ output_position);
+
+ // Renumber the nodes and edges that come after
+ for (int i = 0; i < num_later_outputs; ++i) {
+ ReplaceEdgeSources({map_defun_node, output_position + i + 1},
+ {map_defun_node, output_position + i}, outer_scope);
+ // Each ret node has an "index" attr that has to be updated
+ map_defun_fn->ret_nodes[output_position + i]->AddAttr("index",
+ output_position + i);
+ }
+}
+
+// Helper class that vectorizes the body of a MapDefun node, adding new
+// operations to the graph that collectively compute the same value as what
+// running the MapDefun function on slices of the input would produce.
+// This class transforms the input FunctionDefs into their corresponding
+// Graph objects and works on the graphs directly, then converts them back
+// to FunctionDefs when GetResult is called.
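+//
+// Typical usage (a sketch based on the interface below; error handling
+// elided):
+//
+//   Vectorization vectorization(lib);
+//   FunctionDef* result;
+//   Status s = vectorization.Vectorize(outer_fn, map_defun_node, &result);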
+class Vectorization {
+ public:
+ explicit Vectorization(FunctionDefLibrary* lib)
+ : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {}
+
+ // Adds the vectorized function and the new map_defun_fn to `lib_`, and
+ // points `*result` to the former. Returns an error status if the
+ // conversion between FunctionDef -> Graph -> FunctionDef failed anywhere
+ // along the way.
+ Status Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDef** result);
+
+ private:
+ // Converts FunctionDefs to Graphs and adds mappings from
+ // arg nodes and unstacked nodes to the corresponding nodes in outer_scope_.
+ Status Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node);
+
+ // Converts Graphs back to FunctionDefs and adds them to `lib_`.
+ Status GetResult(FunctionDef** vectorized_function);
+
+ // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in
+ // `outer_scope_`, until there are no convertible outputs remaining.
+ void VectorizeHelper();
+
+ // Vectorizes map_defun_fn's output at output_position.
+ Status ConvertOutput(int output_position);
+
+ // Adds mappings from `op_node`'s output tensors to converted output tensors,
+ // creating the necessary new node(s). Generally, the steps to convert an op
+ // are:
+ // 1) Create new node(s) in `outer_scope_` that act on batched input tensors.
+ // These operations collectively compute the same value as what running
+ // the original operation on slices of the input tensors would produce.
+ // For example, a Cast op in MapDefun translates to a Cast op in
+ // `outer_scope_`, since the vectorized version of Cast is Cast itself.
+ // 2) Promote the inputs of the op to outputs of `map_defun_node_` and
+ // `map_defun_fn_`.
+ // 3) Add edges between the promoted inputs (that are now outputs of
+ // `map_defun_node_`) and the input ports of the new node(s).
+ // 4) For each output of the old node, add the mapping of output tensors to
+ // the conversion map.
+
+ // Given a tensor t in `unstacked`, stacks it by doing the equivalent of
+ // tf.tile(tf.expand_dims(t, 0), [n, 1, 1, ...]) where n is dimension 0 of
+ // inputs to `map_defun_node_`. This stacked tensor will be compatible with
+ // the expected output shape of `map_defun_node_`.
+ // This is equivalent to the _stack function in python Pfor.
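+ // For example (illustrative), stacking a scalar constant with value 5 when
+ // the loop length is 4 yields the vector [5, 5, 5, 5]; the actual shapes
+ // depend on the rank of the unstacked tensor.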
+ Status StackTensor(WrappedTensor* unstacked, TensorDesc* result);
+
+ // Recursively looks for unstacked nodes in the `map_defun_fn_` graph by
+ // doing a depth-first search from the ret nodes. Lifts nodes that are
+ // unstacked (i.e. don't derive from arg nodes) into `outer_scope_` directly
+ // and adds mappings to `conversion_map_`.
+ Status AddUnstackedNodeMappings();
+
+ // Recursive helper for `AddUnstackedNodeMappings`, returns true if tensor
+ // is unstacked.
+ bool AddUnstackedNodeMappingsHelper(TensorDesc&& tensor, Status* status);
+
+ // Adds mappings to `conversion_map_` from `map_defun_fn_` arg nodes to the
+ // corresponding `map_defun_node_` input nodes.
+ Status AddArgNodeMappings();
+
+ // Maps a tensor to the corresponding WrappedTensor. For example,
+ // {"Cast" Node*, 0} -> WrappedTensor({"Vectorize/Cast" Node*, 0}, true)
+ std::map<TensorDesc, WrappedTensor> conversion_map_;
+
+ // Unconvertible ret nodes
+ std::set<Node*> unconvertible_;
+
+ FunctionDefLibrary* lib_; // Not owned
+ FunctionLibraryDefinition lib_def_;
+ // Note that FunctionBody has a pointer to a Graph object that corresponds
+ // to the function's subgraph, with additional kArgOp and kRetValOp nodes
+ // that denote the function arguments and return values. These nodes have the
+ // attrs "T" for the type, and "index" for the argument / retval index,
+ // respectively. FunctionBody also keeps track of arg/ret_nodes and
+ // arg/ret_types, which are ordered according to argument/output indices.
+ std::unique_ptr<Graph> outer_scope_;
+ std::unique_ptr<FunctionBody> map_defun_fn_;
+ Node* map_defun_node_ = nullptr; // Owned by `outer_scope_`
+
+ // Caches the loop_len_node_ needed for tiling unstacked output. This
+ // corresponds to a vector with one element.
+ Node* loop_len_node_ = nullptr; // Owned by `outer_scope_`
+ Status status_;
+};
+
+Status Vectorization::AddConversionMapping(Node* op_node) {
+ for (auto edge : op_node->in_edges()) {
+ if (edge->IsControlEdge()) {
+ return errors::InvalidArgument(
+ "Vectorizing outputs with control inputs is currently not "
+ "supported.");
+ }
+ }
+
+ auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string());
+ if (vectorizer == nullptr) {
+ return errors::Unimplemented("No vectorizer registered for op: ",
+ op_node->type_string());
+ }
+ std::vector<Port> input_ports, output_ports;
+ input_ports.reserve(op_node->num_inputs());
+ output_ports.reserve(op_node->num_outputs());
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ &input_ports, &output_ports));
+
+ std::vector<const Edge*> input_edges;
+ TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
+
+ if (op_node->num_outputs() != output_ports.size() ||
+ op_node->num_inputs() != input_ports.size() ||
+ input_edges.size() != input_ports.size()) {
+ return errors::Internal("Vectorizer inputs/outputs don't match.");
+ }
+
+ // Promote the inputs of the op to MapDefun outputs and connect the edges
+ // accordingly.
+ for (size_t i = 0; i < op_node->num_inputs(); ++i) {
+ auto edge = input_edges[i];
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1,
+ input_ports[i].first, input_ports[i].second);
+ }
+
+ // Add output mappings.
+ for (size_t i = 0; i < op_node->num_outputs(); ++i) {
+ conversion_map_.insert({{op_node, i}, {std::move(output_ports[i]), true}});
+ }
+
+ return Status::OK();
+}
+
+Status Vectorization::ConvertOutput(int output_position) {
+ // ret_edge->src() is the actual op that generated the retval, and
+ // ret_edge->dst() is the retval node whose op is "_Retval"
+ const Edge* ret_edge;
+ TF_RETURN_IF_ERROR(
+ map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge));
+
+ TensorDesc output({ret_edge->src(), ret_edge->src_output()});
+ TensorDesc converted_output;
+ if (auto found = gtl::FindOrNull(conversion_map_, output)) {
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ if (found->stacked) {
+ converted_output = found->tensor;
+ } else {
+ // Some outputs may be unstacked if they don't derive from arg nodes
+ // (for example, if a function returns a constant). For these, we
+ // have to add extra nodes to tile them in the 0th dimension.
+ TF_RETURN_IF_ERROR(StackTensor(found, &converted_output));
+ }
+ } else {
+ // Note: All unstacked nodes are converted ahead of time in `Initialize`,
+ // and here we assume that all op vectorizers create only stacked outputs.
+ // This may not hold in the future, as more vectorizers are added that
+ // may actually create unstacked outputs. For example, see the `Shape`
+ // converter in third_party/tensorflow/python/ops/parallel_for/pfor.py
+ TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
+ converted_output = conversion_map_.at(output).tensor;
+ }
+
+ ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
+ outer_scope_.get());
+ RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(),
+ map_defun_node_);
+
+ return Status::OK();
+}
+
+Status Vectorization::Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node,
+ FunctionDef** result) {
+ TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node));
+ VectorizeHelper();
+ return GetResult(result);
+}
+
+void Vectorization::VectorizeHelper() {
+ while (true) {
+ int output_position = graph_utils::GetFirstElementIndexWithPredicate(
+ [this](Node* n) {
+ return this->unconvertible_.find(n) == this->unconvertible_.end();
+ },
+ map_defun_fn_->ret_nodes);
+
+ // No outputs left to convert
+ if (output_position == -1) break;
+
+ Status s = ConvertOutput(output_position);
+ if (!s.ok()) {
+ Node* output_node = map_defun_fn_->ret_nodes.at(output_position);
+ VLOG(2) << "Could not convert the output at node: "
+ << output_node->DebugString() << "\nError: " << s;
+ unconvertible_.insert(output_node);
+ }
+ }
+
+ // If we've converted all the outputs of the MapDefun function, we no longer
+ // need the MapDefun node and can delete it.
+ if (map_defun_fn_->ret_nodes.empty()) {
+ outer_scope_->RemoveNode(map_defun_node_);
+ } else {
+ // Update MapDefun node attrs accordingly
+ DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
+ map_defun_node_->AddAttr(
+ "output_shapes",
+ std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
+ map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
+ }
+}
+
+Status Vectorization::Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node) {
+ // Convert outer_scope and map_defun_fn to FunctionBody objects so we can
+ // work on Graphs directly.
+ const FunctionDef* map_defun_fn =
+ lib_def_.Find(map_defun_node.attr().at("f").func().name());
+
+ if (map_defun_fn == nullptr) {
+ return errors::NotFound("Could not find function with name ",
+ map_defun_node.attr().at("f").func().name(),
+ " in function library.");
+ }
+
+ auto get_func_sig = [this](const string& op, const OpDef** sig) {
+ return this->lib_def_.LookUpOpDef(op, sig);
+ };
+
+ FunctionBody* outer_fn;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_,
+ get_func_sig, &outer_fn));
+ // We don't need outer_fn, just the graph
+ outer_scope_.reset(outer_fn->graph);
+ outer_fn->graph = nullptr;
+ delete outer_fn;
+
+ FunctionBody* tmp;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_,
+ get_func_sig, &tmp));
+ map_defun_fn_.reset(tmp);
+
+ // Find the MapDefun node in outer_scope_
+ int node_id = graph_utils::GetFirstElementIndexWithPredicate(
+ [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); },
+ outer_scope_->nodes());
+ if (node_id == -1) {
+ return errors::NotFound("Could not find node with name ",
+ map_defun_node.name(), " in outer_scope.");
+ }
+ map_defun_node_ = outer_scope_->FindNodeId(node_id);
+
+ TF_RETURN_IF_ERROR(AddArgNodeMappings());
+
+ TF_RETURN_IF_ERROR(AddUnstackedNodeMappings());
+ loop_len_node_ = nullptr;
+
+ return Status::OK();
+}
+
+// TODO(rachelim): It might be profitable to use the C++ API for this instead of
+// NodeBuilder
+Status Vectorization::StackTensor(WrappedTensor* unstacked,
+ TensorDesc* result) {
+ // Note that all these nodes are necessary as the size of the batch may not be
+ // constant.
+ if (unstacked->stacked) {
+ return errors::Internal("Can only stack unstacked tensor.");
+ }
+
+ Graph* g = outer_scope_.get();
+ auto node_builder = [](StringPiece op) {
+ return NodeBuilder(strings::StrCat("vectorized/stack/", op), op);
+ };
+
+ auto make_const = [&node_builder](const Input::Initializer& val, Graph* graph,
+ Node** result) {
+ TF_RETURN_IF_ERROR(val.status);
+ return node_builder("Const")
+ .Attr("value", val.tensor)
+ .Attr("dtype", val.tensor.dtype())
+ .Finalize(graph, result);
+ };
+
+ // If loop_len_node_ hasn't been created yet, add the node and cache it.
+ if (loop_len_node_ == nullptr) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(0, &input_node));
+
+ Node* shape_node;
+ TF_RETURN_IF_ERROR(
+ node_builder("Shape").Input(input_node).Finalize(g, &shape_node));
+
+ Node* const_vec_0;
+ TF_RETURN_IF_ERROR(make_const({0}, g, &const_vec_0));
+ Node* const_vec_1;
+ TF_RETURN_IF_ERROR(make_const({1}, g, &const_vec_1));
+
+ Node* strided_slice_node;
+ TF_RETURN_IF_ERROR(node_builder("StridedSlice")
+ .Input(shape_node) // input
+ .Input(const_vec_0) // begin
+ .Input(const_vec_1) // end
+ .Input(const_vec_1) // strides
+ .Finalize(g, &strided_slice_node));
+
+ // Produces a vector of length 1
+ TF_RETURN_IF_ERROR(node_builder("Reshape")
+ .Input(strided_slice_node) // tensor
+ .Input(const_vec_1) // shape
+ .Finalize(g, &loop_len_node_));
+ }
+
+ Node* ones_shape;
+ TF_RETURN_IF_ERROR(node_builder("Shape")
+ .Input(unstacked->tensor.first) // input
+ .Finalize(g, &ones_shape));
+
+ Node* ones;
+ TF_RETURN_IF_ERROR(
+ node_builder("OnesLike").Input(ones_shape).Finalize(g, &ones));
+
+ Node* const_0;
+ TF_RETURN_IF_ERROR(make_const(0, g, &const_0));
+
+ Node* multiples;
+ TF_RETURN_IF_ERROR(node_builder("Concat")
+ .Input(const_0) // concat_dim
+ .Input({{loop_len_node_, 0}, {ones, 0}}) // values
+ .Finalize(g, &multiples));
+
+ Node* expand_dims;
+ TF_RETURN_IF_ERROR(node_builder("ExpandDims")
+ .Input(unstacked->tensor.first) // input
+ .Input(const_0) // dim
+ .Finalize(g, &expand_dims));
+
+ TF_RETURN_IF_ERROR(node_builder("Tile")
+ .Input(expand_dims) // input
+ .Input(multiples) // multiples
+ .Finalize(g, &result->first));
+ result->second = 0;
+ return Status::OK();
+}
+
+Status Vectorization::AddArgNodeMappings() {
+ for (auto arg_node : map_defun_fn_->arg_nodes) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(
+ arg_node->attrs().Find("index")->i(), &input_node));
+
+ conversion_map_.insert({{arg_node, 0}, {{input_node, 0}, true}});
+
+ // Control inputs
+ conversion_map_.insert({{arg_node, Graph::kControlSlot},
+ {{input_node, Graph::kControlSlot}, true}});
+ }
+ return Status::OK();
+}
+
+bool Vectorization::AddUnstackedNodeMappingsHelper(TensorDesc&& tensor,
+ Status* status) {
+ if (auto found = gtl::FindOrNull(conversion_map_, tensor)) {
+ return !found->stacked;
+ }
+
+ if (tensor.first->op_def().is_stateful()) {
+ // We don't lift stateful nodes directly out of the MapDefun, since they may
+ // have to be executed N times.
+ return false;
+ }
+
+ bool is_unstacked = true;
+ for (auto edge : tensor.first->in_edges()) {
+ // Ignore Source nodes. Note that these are also ignored in the
+ // GraphToFunctionDef conversion.
+ if (edge->src()->IsSource()) continue;
+
+ // A node is unstacked if all of its inputs are unstacked
+ is_unstacked &= AddUnstackedNodeMappingsHelper(
+ {edge->src(), edge->src_output()}, status);
+ }
+
+ if (!is_unstacked) {
+ return false;
+ }
+
+ // If the node is unstacked, we copy it into outer_scope_ and
+ // add it to the map. Note that we don't clean up the nodes that are copied
+ // in map_defun_fn_, and rely on them being pruned out later.
+ Node* node = outer_scope_->AddNode(tensor.first->def(), status);
+ if (!status->ok()) return true;
+
+ // Add input edges to nodes that should already have been lifted.
+ for (auto edge : tensor.first->in_edges()) {
+ // Ignore Source nodes. Note that these are also ignored in the
+ // GraphToFunctionDef conversion.
+ if (edge->src()->IsSource()) continue;
+
+ if (auto found = gtl::FindOrNull(conversion_map_,
+ {edge->src(), edge->src_output()})) {
+ outer_scope_->AddEdge(found->tensor.first, found->tensor.second, node,
+ edge->dst_input());
+ } else {
+ status->Update(errors::Internal(
+ "Could not find input conversion even though we did depth first "
+ "conversion."));
+ }
+ }
+
+ // Add output mappings
+ for (int i = 0; i < tensor.first->num_outputs(); ++i) {
+ conversion_map_.insert(
+ {{tensor.first, i}, WrappedTensor({node, i}, false)});
+ }
+ conversion_map_.insert({{tensor.first, Graph::kControlSlot},
+ WrappedTensor({node, Graph::kControlSlot}, false)});
+
+ return true;
+}
+
+Status Vectorization::AddUnstackedNodeMappings() {
+ SetVector<Node*> unstacked_nodes;
+ Status s;
+ for (const auto& ret_node : map_defun_fn_->ret_nodes) {
+ const Edge* in_edge = nullptr;
+ TF_RETURN_IF_ERROR(ret_node->input_edge(0, &in_edge));
+ AddUnstackedNodeMappingsHelper({in_edge->src(), in_edge->src_output()}, &s);
+ TF_RETURN_IF_ERROR(s);
+ }
+ return Status::OK();
+}
+
+Status Vectorization::GetResult(FunctionDef** vectorized_function) {
+ TF_RETURN_IF_ERROR(status_);
+ TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(outer_scope_.get()));
+ TF_RETURN_IF_ERROR(graph_utils::EnsureNodeNamesUnique(map_defun_fn_->graph));
+
+ if (!map_defun_fn_->ret_nodes.empty()) {
+ FunctionDef* map_defun_fn = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn));
+
+ AttrValue func_attr;
+ func_attr.mutable_func()->set_name(map_defun_fn->signature().name());
+ map_defun_node_->AddAttr("f", func_attr);
+ }
+
+ *vectorized_function = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_,
+ *vectorized_function);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *outer_scope_, (*vectorized_function)->signature().name(),
+ *vectorized_function));
+ return Status::OK();
+}
+
+} // namespace
+
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result) {
+ *result = nullptr;
+ return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result);
+}
+
+} // end namespace vectorization_utils
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
new file mode 100644
index 0000000000..bd7d390900
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -0,0 +1,97 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+
+// Given a MapDefun node (`map_defun_node`) in a FunctionDef (`outer_scope`)
+// that maps a function in lib across some input vector elements,
+// `VectorizeMapDefun` attempts to create a vectorized version of `outer_scope`
+// by "lifting" operations from the MapDefun function to the new function
+// (`result`); that is, replacing operations in the MapDefun function with
+// operations that produce the same vector output(s) as executing the original
+// operations on elements of vector input(s) would. If all operations in the
+// MapDefun function are successfully lifted, `result` has no MapDefun node
+// at all. However, if some operations cannot be lifted, and this
+// vectorization only succeeds partially, a MapDefun node remains in `result` to
+// be used for operations that were not lifted, and the modified MapDefun
+// function is added to `lib`. The newly vectorized function `result` is also
+// added to `lib`.
+//
+// Returns Status::OK() if the vectorization is completely or partially
+// successful. Otherwise, returns an error, and sets `result` to nullptr.
+//
+// Example:
+// If the input to the `VectorizeMapDefun` function is a MapDefun
+// whose `map_defun_fn` performs the Cast operation, the vectorization will
+// eliminate the MapDefun. This is because the Cast operation supports
+// any tensor shape and can thus be lifted to `result`.
+//
+// Before:
+//
+//
+// outer_scope +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | map_defun_fn +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// result +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
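+// A minimal call-site sketch (illustrative only; assumes `outer`, `map_defun`,
+// and `lib` have been populated by the caller, as in the unit tests):
+//
+//   FunctionDef* vectorized = nullptr;
+//   Status s = VectorizeMapDefun(outer, map_defun, &lib, &vectorized);
+//   if (s.ok()) {
+//     // `vectorized` now points to the new function that was added to `lib`.
+//   }
+//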
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result);
+
+} // end namespace vectorization_utils
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_UTILS_H_
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
new file mode 100644
index 0000000000..a6020e36bb
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -0,0 +1,939 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/tools/graph_transforms/transform_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace vectorization_utils {
+namespace {
+
+NodeDef* AddCastNode(const string& name, const std::vector<string>& inputs,
+ DataType src, DataType dst, bool truncate,
+ FunctionDef* fn) {
+ NodeDef* node = function_utils::AddNode(name, "Cast", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("SrcT", src, node);
+ graph_transforms::SetNodeAttr("DstT", dst, node);
+ graph_transforms::SetNodeAttr("Truncate", truncate, node);
+ return node;
+}
+
+NodeDef* AddUnstackNode(const string& name, const std::vector<string>& inputs,
+ DataType t, int axis, int num, FunctionDef* fn) {
+ NodeDef* node = function_utils::AddNode(name, "Unpack", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("T", t, node);
+ graph_transforms::SetNodeAttr("axis", axis, node);
+ graph_transforms::SetNodeAttr("num", num, node);
+ return node;
+}
+
+NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
+ const std::vector<DataType>& t_arguments,
+ const std::vector<DataType>& output_types,
+ const std::vector<TensorShape>& output_shapes,
+ const string& function_name, FunctionDef* fn) {
+ NameAttrList func;
+ func.set_name(function_name);
+ NodeDef* node = function_utils::AddNode(name, "MapDefun", inputs, {}, fn);
+ graph_transforms::SetNodeAttr("Targuments", t_arguments, node);
+ graph_transforms::SetNodeAttr("Tcaptured", DataTypeVector(), node);
+ graph_transforms::SetNodeAttr("output_types", output_types, node);
+ graph_transforms::SetNodeAttr("output_shapes", output_shapes, node);
+ graph_transforms::SetNodeAttr("f", func, node);
+ return node;
+}
+
+string GetRetval(const FunctionDef& function_def, int index) {
+ return function_def.ret().at(
+ function_def.signature().output_arg(index).name());
+}
+
+// TODO(rachelim): Use FunctionDefHelper::Create instead
+FunctionDef CreateFunction(
+ StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
+ const std::vector<std::pair<string, DataType>>& outputs,
+ const std::map<string, string>& rets) {
+ FunctionDef func;
+ auto* signature = func.mutable_signature();
+ signature->set_name(string(name));
+ for (const auto& x : inputs) {
+ auto* arg_def = signature->add_input_arg();
+ arg_def->set_name(x.first);
+ arg_def->set_type(x.second);
+ }
+ for (const auto& x : outputs) {
+ auto* arg_def = signature->add_output_arg();
+ arg_def->set_name(x.first);
+ arg_def->set_type(x.second);
+ }
+ for (const auto& x : rets) {
+ (*func.mutable_ret())[x.first] = x.second;
+ }
+
+ return func;
+}
+
+
+// Before:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// | +-----------+ Arg0 +---+ Arg1 +----+ |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+//
+// After:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | | | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"ret0", "arg0"}, {"ret1", "arg1"}});
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"ret0", "ret1"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_EQ(GetRetval(*vectorized, 0), "ret0");
+ EXPECT_EQ(GetRetval(*vectorized, 1), "ret1");
+}
+
+// Before:
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// | +-----------+ Arg0 +---+ Arg1 +----+ |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
+// | | | | | | |
+// | | | +---v--+ +---v--+ | |
+// | | +---| XOp1 | | Cast | | |
+// | | +---+--+ +---+--+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+// where XOp1 is not convertible.
+//
+// After:
+//
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ | |
+// | +-----------+ Arg0 +-+ | |
+// | | +---+--+ | | |
+// | | | | | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
+// | | | | | | |
+// | | | +---v--+ | +---v--+ |
+// | | +---| XOp1 | | | Cast | |
+// | | +---+--+ | +---+--+ |
+// | | | | | |
+// | | MapDefun +---v--+ | | |
+// | +-----------+ Ret0 +-+ | |
+// | +---+--+ | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}},
+ {{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}});
+ // TODO(rachelim): If we ever write a converter for MatMul, we have to
+ // change this test.
+ NodeDef* x_op1 =
+ function_utils::AddNode("MatMul", "MatMul", {"arg0", "arg0"}, {}, &inner);
+ CHECK_NOTNULL(x_op1);
+ graph_transforms::SetNodeAttr("T", DT_INT32, x_op1);
+
+ NodeDef* cast_node =
+ AddCastNode("Cast", {"arg1"}, DT_INT32, DT_INT32, false, &inner);
+ CHECK_NOTNULL(cast_node);
+
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
+ {{"mapdefun", DT_INT32}, {"mapdefun_0", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x", "y"}, {DT_INT32, DT_INT32}, {DT_INT32, DT_INT32},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto map_defun_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized));
+ // The Cast node should be converted just fine.
+ EXPECT_EQ(GetRetval(*vectorized, 1), "Cast:y:0");
+
+ // The inner function should only have one retval.
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+ EXPECT_EQ(map_defun_fn->signature().output_arg_size(), 1);
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0), "x");
+ EXPECT_EQ(GetRetval(*vectorized, 0),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(vectorized->node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +---------------+ Arg0 +-------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +---------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +----------+ | |
+// | | | | | |
+// | | MapDefun +---v--+ +---v--+ | |
+// | +-----------+ Ret0 +---+ Ret1 +----+ |
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +-------------------+
+// | +---+--+ |
+// | | |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +----------+ |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
+ // Tests that behavior is correct when an output is used more than once.
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}, {"ret1", DT_INT64}},
+ {{"ret0", "Cast:y:0"}, {"ret1", "Cast:y:0"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction(
+ "outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}, {"mapdefun_0", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}, {"mapdefun_0", "MapDefun:output:1"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64, DT_INT64},
+ {{}, {}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0), "x");
+ EXPECT_EQ(GetRetval(*vectorized, 0),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(GetRetval(*vectorized, 1),
+ strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(vectorized->node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +------------------+ Arg0 +------------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | | | |
+// | | +---v---+ num=3 | |
+// | | |Unstack| axis=0 | |
+// | | ++--+--++ | |
+// | | | | | | |
+// | | +----+ | +-------+ | |
+// | | | | | | |
+// | | MapDefun +---v--+ +-v----+ +--v---+ | |
+// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ |
+// | +---+--+ +--+---+ +--+---+ |
+// | | | | |
+// | +---v--+ +--v---+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v---+ num=3 |
+// | |Unstack| axis=1 |
+// | ++--+--++ |
+// | | | | |
+// | +----+ | +-------+ |
+// | | | | |
+// | | | | |
+// | +---v--+ +-v----+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
+ FunctionDef inner = CreateFunction(
+ "inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}},
+ {{"ret0", "MyUnstack:output:0"},
+ {"ret1", "MyUnstack:output:1"},
+ {"ret2", "MyUnstack:output:2"}});
+ NodeDef* unstack_op =
+ AddUnstackNode("MyUnstack", {"arg0"}, DT_INT32, 0, 3, &inner);
+ CHECK_NOTNULL(unstack_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT32},
+ {"mapdefun_0", DT_INT32},
+ {"mapdefun_1", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"},
+ {"mapdefun_0", "MapDefun:output:1"},
+ {"mapdefun_1", "MapDefun:output:2"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32},
+ {{1}, {1}, {1}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
+ EXPECT_EQ(unpack_node.input(0), "x");
+ EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+ EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
+ EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+ EXPECT_EQ(GetRetval(*vectorized, 0),
+ strings::StrCat(unpack_node.name(), ":output:0"));
+ EXPECT_EQ(GetRetval(*vectorized, 1),
+ strings::StrCat(unpack_node.name(), ":output:1"));
+ EXPECT_EQ(GetRetval(*vectorized, 2),
+ strings::StrCat(unpack_node.name(), ":output:2"));
+ EXPECT_EQ(vectorized->node_def_size(), 1);
+}
+
+// Before:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +------------------+ Arg0 +------------------+ |
+// | | +---+--+ | |
+// | | | | |
+// | | +---+--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v---+ num=3 | |
+// | | |Unstack| axis=0 | |
+// | | ++--+--++ | |
+// | | | | | | |
+// | | +----+ | +-------+ | |
+// | | | | | | |
+// | | MapDefun +---v--+ +-v----+ +--v---+ | |
+// | +----------+ Ret0 +--+ Ret1 +--+ Ret2 +------+ |
+// | +---+--+ +--+---+ +--+---+ |
+// | | | | |
+// | +---v--+ +--v---+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+//
+// After:
+//
+// +------+
+// +----------------------+ Arg0 +----------------------+
+// | +---+--+ |
+// | | |
+// | +---+--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v---+ num=3 |
+// | |Unstack| axis=1 |
+// | ++--+--++ |
+// | | | | |
+// | +----+ | +-------+ |
+// | | | | |
+// | | | | |
+// | +---v--+ +-v----+ +--v---+ |
+// +--------------+ Ret0 +--+ Ret1 +--+ Ret2 +----------+
+// +------+ +------+ +------+
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
+ FunctionDef inner = CreateFunction(
+ "inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT32}, {"ret1", DT_INT32}, {"ret2", DT_INT32}},
+ {{"ret0", "MyUnstack:output:0"},
+ {"ret1", "MyUnstack:output:1"},
+ {"ret2", "MyUnstack:output:2"}});
+ NodeDef* cast_op =
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT32, false, &inner);
+ CHECK_NOTNULL(cast_op);
+ NodeDef* unstack_op =
+ AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner);
+ CHECK_NOTNULL(unstack_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT32},
+ {"mapdefun_0", DT_INT32},
+ {"mapdefun_1", DT_INT32}},
+ {{"mapdefun", "MapDefun:output:0"},
+ {"mapdefun_0", "MapDefun:output:1"},
+ {"mapdefun_1", "MapDefun:output:2"}});
+
+ NodeDef* map_defun = AddMapDefunNode(
+ "MapDefun", {"x"}, {DT_INT32}, {DT_INT32, DT_INT32, DT_INT32},
+ {{1}, {1}, {1}}, inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0), "x");
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
+ EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
+ EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
+ EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
+ EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
+
+ EXPECT_EQ(GetRetval(*vectorized, 0),
+ strings::StrCat(unpack_node.name(), ":output:0"));
+ EXPECT_EQ(GetRetval(*vectorized, 1),
+ strings::StrCat(unpack_node.name(), ":output:1"));
+ EXPECT_EQ(GetRetval(*vectorized, 2),
+ strings::StrCat(unpack_node.name(), ":output:2"));
+ EXPECT_EQ(vectorized->node_def_size(), 2);
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +---+--+ | |
+// | | +---------+ | |
+// | | +---v--+ | | |
+// | | |Print | | | |
+// | | +---+--+ | | |
+// | | : +---v--+ | |
+// | | ::::::> Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// No change because we don't deal with control inputs for now.
+//
+TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
+ FunctionDef inner =
+ CreateFunction("inner_function", {{"arg0", DT_INT32}},
+ {{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
+ NodeDef* print_op = function_utils::AddNode(
+ "Print", "Print", {"arg0", "arg0"}, {/*attrs*/}, &inner);
+ graph_transforms::SetNodeAttr("T", DT_INT32, print_op);
+ graph_transforms::SetNodeAttr("U", gtl::ArraySlice<DataType>({DT_INT32}),
+ print_op);
+ CHECK_NOTNULL(print_op);
+ NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
+ false, &inner);
+ CHECK_NOTNULL(cast_op);
+
+ FunctionDef outer = CreateFunction("outer_function", {{"x", DT_INT32}},
+ {{"mapdefun", DT_INT64}},
+ {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"x"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+ CHECK_NOTNULL(map_defun);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ // The function should be unchanged. We check this somewhat manually, as the
+ // names of nodes may have changed.
+ EXPECT_EQ(vectorized->node_def_size(), 1);
+ const NodeDef& map_defun_node = vectorized->node_def(0);
+ EXPECT_EQ(map_defun_node.op(), "MapDefun");
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+
+ const NodeDef& print_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Print", *map_defun_fn));
+ const NodeDef& cast_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *map_defun_fn));
+ string control_input = strings::StrCat("^", print_node.name());
+ EXPECT_TRUE(cast_node.input(0) == control_input ||
+ cast_node.input(1) == control_input);
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | | |
+// | | +------+ | |
+// | | |Const | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | +------+ |
+// | |Const | |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// | |Stack*| |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeConst) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int32"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2)},
+ {{"ret0", "Const:output:0"}});
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int32"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT32}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_TRUE(function_utils::ContainsFunctionNodeWithOp("Const", *vectorized));
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | | |
+// | | +------+ | |
+// | | |Const | | |
+// | | +---+--+ | |
+// | | | | |
+// | | +---v--+ | |
+// | | | Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | +------+ |
+// | |Const | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | | Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | |Stack*| |
+// | +---+--+ |
+// | | |
+// | | |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeUnstackedOutput) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2)},
+ {{"ret0", "Cast:y:0"}});
+ AddCastNode("Cast", {"Const:output:0"}, DT_INT32, DT_INT64, false, &inner);
+
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ auto const_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Const", *vectorized));
+ auto cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')),
+ const_node.name());
+}
+
+// Before:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +-----------+ Arg0 +-----+ |
+// | | +------+ | |
+// | | | |
+// | | +------+ +------+ | |
+// | | |Const | |Const | | |
+// | | +---+--+ +---+--+ | |
+// | | : +---v--+ | |
+// | | ::::::> Cast | | |
+// | | +---+--+ | |
+// | | | | |
+// | | MapDefun +---v--+ | |
+// | +-----------+ Ret0 +-----+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+//
+//
+// After:
+//
+//
+// +------+
+// +---------------+ Arg0 +---------+
+// | +------+ |
+// | |
+// | |
+// | +------+ |
+// | +------+ |Const | |
+// | |Const | +---+--+ |
+// | +---+--+ | |
+// | : +---v--+ |
+// | ::::::> Cast | |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// | +Stack*+ |
+// | +---+--+ |
+// | | |
+// | +---v--+ |
+// +---------------+ Ret0 +---------+
+// +------+
+// *Not actually a Stack node, but does the equivalent.
+//
+TEST(VectorizeMapDefunTest, VectorizeUnstackedControl) {
+ FunctionDef inner = FunctionDefHelper::Create(
+ "inner_function", {"arg0: int32"}, {"ret0: int64"}, {/* attrs */},
+ {/* nodes */ FunctionDefHelper::Const("Const", 2),
+ FunctionDefHelper::Const("ConstDep", 3)},
+ {{"ret0", "Cast:y:0"}});
+ AddCastNode("Cast", {"Const:output:0", "^ConstDep"}, DT_INT32, DT_INT64,
+ false, &inner);
+
+ FunctionDef outer = FunctionDefHelper::Create(
+ "outer_function", {"outer_arg0: int32"}, {"mapdefun: int64"},
+ {/* attrs */}, {/* nodes */}, {{"mapdefun", "MapDefun:output:0"}});
+
+ NodeDef* map_defun =
+ AddMapDefunNode("MapDefun", {"outer_arg0"}, {DT_INT32}, {DT_INT64}, {{}},
+ inner.signature().name(), &outer);
+
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto find_const = [vectorized](int val) -> const NodeDef* {
+ for (const auto& n : vectorized->node_def()) {
+ if (n.attr().at("value").tensor().int_val(0) == val) {
+ return &n;
+ }
+ }
+ return nullptr;
+ };
+
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ auto const_node = find_const(2);
+ auto const_dep_node = find_const(3);
+ auto cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
+ EXPECT_EQ(cast_node.input(0).substr(0, cast_node.input(0).find(':')),
+ const_node->name());
+ EXPECT_EQ(cast_node.input(1), strings::StrCat("^", const_dep_node->name()));
+}
+
+// TODO(rachelim): More test cases when we get around to implementing them:
+// [] A badly defined converter, e.g. doesn't produce nodes that have the
+// same number of outputs/inputs as the nodes to be converted
+// [] Converter where the 'converted' form has multiple nodes.
+// [] Case with dependent nodes, e.g. ops with const inputs that are
+// broadcasted.
+// [] Python-side tests to actually run the functions to make sure
+// they work.
+
+} // namespace
+} // namespace vectorization_utils
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper.cc b/tensorflow/core/grappler/optimizers/debug_stripper.cc
index 9701a038d0..800160e649 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper.cc
@@ -38,7 +38,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
// be optimized away by dependency optimizer.
for (string& inp : *node.mutable_input()) {
if (!IsControlInput(inp)) {
- inp = AsControlDependency(inp);
+ inp = AsControlDependency(NodeName(inp));
}
}
} else if (IsCheckNumerics(node) || IsPrint(node)) {
@@ -54,7 +54,7 @@ Status DebugStripper::Optimize(Cluster* cluster, const GrapplerItem& item,
// input.
for (size_t i = 1; i < node.input_size(); ++i) {
if (!IsControlInput(node.input(i))) {
- *node.mutable_input(i) = AsControlDependency(node.input(i));
+ *node.mutable_input(i) = AsControlDependency(NodeName(node.input(i)));
}
}
}
diff --git a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
index 96ceee791f..affd2d51c2 100644
--- a/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
+++ b/tensorflow/core/grappler/optimizers/debug_stripper_test.cc
@@ -43,6 +43,35 @@ TEST_F(DebugStripperTest, OutputEqualToInput) {
CompareGraphs(item.graph, output);
}
+TEST_F(DebugStripperTest, StripAssertOnTwoOutputs) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT,
+ ops::Placeholder::Shape({6}));
+ auto split =
+ ops::Split(s.WithOpName("split"), /*axis=*/0, input, /*num_split=*/2);
+ Output x = split[0];
+ Output y = split[1];
+ Output ge = ops::GreaterEqual(s.WithOpName("GreaterEqual"), x, y);
+ auto assert = ops::Assert(s.WithOpName("Assert"), ge, {x, y});
+ Output add = ops::Add(
+ s.WithOpName("add").WithControlDependencies({assert.operation}), x, y);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ DebugStripper optimizer;
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ for (const NodeDef& node : output.node()) {
+ for (const string& input : node.input()) {
+ if (IsControlInput(input)) {
+ EXPECT_EQ(input.find(':'), string::npos);
+ }
+ }
+ }
+}
+
TEST_F(DebugStripperTest, StripAssertFromGraph) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.cc b/tensorflow/core/grappler/optimizers/evaluation_utils.cc
index 00ad7494f4..79d9ea1608 100644
--- a/tensorflow/core/grappler/optimizers/evaluation_utils.cc
+++ b/tensorflow/core/grappler/optimizers/evaluation_utils.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/evaluation_utils.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/denormal.h"
diff --git a/tensorflow/core/grappler/optimizers/evaluation_utils.h b/tensorflow/core/grappler/optimizers/evaluation_utils.h
index 8414b5b8ca..c9dfb6dc0b 100644
--- a/tensorflow/core/grappler/optimizers/evaluation_utils.h
+++ b/tensorflow/core/grappler/optimizers/evaluation_utils.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace Eigen {
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
new file mode 100644
index 0000000000..2c36c9b7b3
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.cc
@@ -0,0 +1,111 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
+
+#include <string>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+Status ExperimentalImplementationSelector::LoadFunctions(
+ const GraphDef& graph) {
+ lib_info_.reset(new FunctionLibraryApiInfo);
+ TF_RETURN_IF_ERROR(lib_info_->Init(graph.library()));
+ return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::MaybeOptimizeFunctionCall(
+ NodeDef* node_def) const {
+ // There are two ways of calling functions:
+ // 1. By specifying an op name as a function name, or
+ // 2. Via the @defun functional interface, where the real function name
+ // appears as the attribute with type func.
+ std::vector<string> function_attribute_names;
+ for (const auto& attr : node_def->attr()) {
+ if (attr.second.has_func() &&
+ lib_info_->GetApiInfo(attr.second.func().name()) != nullptr) {
+ function_attribute_names.emplace_back(attr.first);
+ }
+ }
+
+ if (function_attribute_names.empty() &&
+ lib_info_->GetApiInfo(node_def->op()) == nullptr) {
+ // A regular op, or a function which has no interface.
+ return Status::OK();
+ }
+
+ string task, device;
+ if (!DeviceNameUtils::SplitDeviceName(node_def->device(), &task, &device)) {
+ return errors::Internal("Could not split device name: ", node_def->device());
+ }
+ VLOG(2) << "Op " << node_def->name() << " runs on " << node_def->device()
+ << " = (" << task << ", " << device << ")";
+ DeviceNameUtils::ParsedName parsed_name;
+ DeviceNameUtils::ParseLocalName(device, &parsed_name);
+
+ for (const auto& attr_name : function_attribute_names) {
+ string function_name = node_def->attr().at(attr_name).func().name();
+ string best_function_name;
+ lib_info_->GetBestImplementation(function_name, parsed_name.type,
+ &best_function_name);
+ if (function_name != best_function_name) {
+ node_def->mutable_attr()
+ ->find(attr_name)
+ ->second.mutable_func()
+ ->set_name(best_function_name);
+ }
+ }
+ if (lib_info_->GetApiInfo(node_def->op()) != nullptr) {
+ string best_function_name;
+ lib_info_->GetBestImplementation(node_def->op(), parsed_name.type,
+ &best_function_name);
+ if (node_def->op() != best_function_name) {
+ node_def->set_op(best_function_name);
+ }
+ }
+ return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::SelectImplementation(
+ GraphDef* graph) const {
+ for (int k = 0; k < graph->node_size(); ++k)
+ TF_RETURN_IF_ERROR(MaybeOptimizeFunctionCall(graph->mutable_node(k)));
+
+ return Status::OK();
+}
+
+Status ExperimentalImplementationSelector::Optimize(Cluster* cluster,
+ const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+ TF_RETURN_IF_ERROR(LoadFunctions(*optimized_graph));
+ return SelectImplementation(optimized_graph);
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h
new file mode 100644
index 0000000000..82f7473a14
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector.h
@@ -0,0 +1,115 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
+
+#include <string>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace grappler {
+
+// -- EXPERIMENTAL --
+// This transformation replaces function calls with the appropriate function
+// definition based on properties of the runtime system. For instance,
+// we may choose one implementation over another if we have a GPU with
+// enough memory available.
+//
+// It is a way for the programmer to specify alternative implementations
+// of the same functionality in the graph, and let TensorFlow pick the
+// most appropriate one at runtime.
+//
+// For instance, the python code might specify:
+// @Defun(tf.float32,
+// experimental_api_implements='plus_one',
+// experimental_api_preferred_device='GPU')
+// def plus_one_gpu(x): return x + 1.0
+//
+// @Defun(tf.float32,
+// experimental_api_implements='plus_one')
+// def plus_one_reference_implementation(x): return x + 1.0
+// input = tf.constant(2.0, dtype=tf.float32)
+//
+// z = plus_one_reference_implementation(input)
+// z = plus_one_gpu(input)
+// print(sess.run(z))
+//
+// At runtime, we will trim either `plus_one_gpu` or
+// `plus_one_reference_implementation` based on the availability of the GPU.
+//
+// Available annotations:
+// - experimental_api_implements(string): all functions mapping to the same
+// string can be interchanged. For now, all functions must have the same
+// signature and overloads are not allowed. Defuns within defuns are
+// allowed.
+// - experimental_api_preferred_device(string): sets which device is preferred.
+class ExperimentalImplementationSelector : public CustomGraphOptimizer {
+ public:
+ ExperimentalImplementationSelector() = default;
+ ~ExperimentalImplementationSelector() override = default;
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+ string name() const override {
+ return "experimental_implementation_selector";
+ }
+
+ // This call is not thread-safe.
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override;
+
+ // Does not take any feedback.
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ Status LoadFunctions(const GraphDef& graph);
+ Status MaybeOptimizeFunctionCall(NodeDef* node_def) const;
+
+ // Finds all call sites for functions, then replaces them with the
+ // appropriate implementation.
+ // There are two ways of calling functions:
+ // 1. By specifying an op name as a function name, and
+ // 2. Via the functional interface, where the function name appears as an
+ // Attr.
+ //
+ // There may be multiple call sites for a given function. The function body
+ // may call into another function, so a function might have to be duplicated.
+ // For simplicity, we do not change function bodies. Also, we do not change
+ // gradients.
+ Status SelectImplementation(GraphDef* graph) const;
+
+ std::unique_ptr<FunctionLibraryApiInfo> lib_info_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExperimentalImplementationSelector);
+};
+
+} // namespace grappler
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_EXPERIMENTAL_IMPLEMENTATION_SELECTOR_H_
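A minimal sketch of driving this optimizer directly, assuming a populated GrapplerItem; it mirrors the flow in the tests that follow rather than introducing any new API:

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"

namespace tensorflow {
namespace grappler {

Status SwapImplementationsSketch(const GrapplerItem& item, GraphDef* output) {
  ExperimentalImplementationSelector optimizer;
  // No custom-optimizer config is needed; Init simply returns OK.
  TF_RETURN_IF_ERROR(optimizer.Init(nullptr));
  // Loads the function library from item.graph, then rewrites call sites of
  // annotated functions to the implementation preferred for their device.
  return optimizer.Optimize(/*cluster=*/nullptr, item, output);
}

}  // namespace grappler
}  // namespace tensorflow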
diff --git a/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
new file mode 100644
index 0000000000..3f1ebefac6
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/experimental_implementation_selector_test.cc
@@ -0,0 +1,138 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
+
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+constexpr char CpuDevice[] = "/device:CPU:0";
+constexpr char GpuDevice[] = "/device:GPU:0";
+
+class ExperimentalImplementationSelectorTest : public GrapplerTest {};
+
+TEST_F(ExperimentalImplementationSelectorTest, NoUpdate) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {CpuDevice});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ std::unique_ptr<CustomGraphOptimizer> optimizer(
+ new ExperimentalImplementationSelector);
+ ASSERT_NE(nullptr, optimizer);
+ TF_ASSERT_OK(optimizer->Init());
+
+ GraphDef output;
+ const Status status = optimizer->Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ // This is a trivial graph so there is nothing to update.
+ EXPECT_EQ(item.graph.node_size(), output.node_size());
+}
+
+TEST_F(ExperimentalImplementationSelectorTest, SwapImplementation) {
+ using test::function::NDef;
+ auto cpu_def = test::function::XTimesTwo();
+ auto* func_attr = cpu_def.mutable_attr();
+ (*func_attr)["experimental_api_implements"].set_s("times_two");
+ (*func_attr)["experimental_api_preferred_device"].set_s("CPU");
+
+ auto gpu_def = test::function::XAddX();
+ auto* func2_attr = gpu_def.mutable_attr();
+ (*func2_attr)["experimental_api_implements"].set_s("times_two");
+ (*func2_attr)["experimental_api_preferred_device"].set_s("GPU");
+
+ ExperimentalImplementationSelector optimizer;
+ GraphDef output;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, GpuDevice),
+ NDef("y1", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, GpuDevice),
+ NDef("z1", "Identity", {"y1"}, {{"T", DT_FLOAT}}, GpuDevice),
+ NDef("y2", "XTimesTwo", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
+ NDef("z2", "Identity", {"y2"}, {{"T", DT_FLOAT}}, CpuDevice)},
+ // FunctionLib
+ {cpu_def, gpu_def});
+
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(output.node_size(), 5);
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "y1") {
+ // Make sure the implementation has been swapped to use the GPU version.
+ EXPECT_EQ("XAddX", node.op());
+ } else if (node.name() == "y2") {
+ // Make sure the implementation is not changed.
+ EXPECT_EQ("XTimesTwo", node.op());
+ }
+ }
+}
+
+TEST_F(ExperimentalImplementationSelectorTest, SwapImplementationEval) {
+ using test::function::NDef;
+ auto cpu_def = test::function::XTimesTwo();
+ auto* func_attr = cpu_def.mutable_attr();
+ (*func_attr)["experimental_api_implements"].set_s("random_boost");
+ (*func_attr)["experimental_api_preferred_device"].set_s("CPU");
+
+ auto gpu_def = test::function::XTimesFour();
+ auto* func2_attr = gpu_def.mutable_attr();
+ (*func2_attr)["experimental_api_implements"].set_s("random_boost");
+ (*func2_attr)["experimental_api_preferred_device"].set_s("GPU");
+
+ ExperimentalImplementationSelector optimizer;
+ GraphDef output;
+ GrapplerItem item;
+ item.graph = test::function::GDef(
+ {NDef("x", "Placeholder", {}, {{"dtype", DT_FLOAT}}, CpuDevice),
+ NDef("y", "XTimesFour", {"x"}, {{"T", DT_FLOAT}}, CpuDevice),
+ NDef("z", "Identity", {"y"}, {{"T", DT_FLOAT}}, CpuDevice)},
+ // FunctionLib
+ {cpu_def, gpu_def});
+
+ const Tensor input = test::AsScalar<float>(1.0f);
+ item.fetch = {"z"};
+ item.feed.emplace_back("x", input);
+
+ const auto four_times_boosted_tensor = EvaluateFetchNodes(item);
+ test::ExpectTensorEqual<float>(four_times_boosted_tensor[0],
+ test::AsScalar<float>(4.0f));
+
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+ GrapplerItem optimized(item, std::move(output));
+ const auto twice_boosted_tensor = EvaluateFetchNodes(optimized);
+ test::ExpectTensorEqual<float>(twice_boosted_tensor[0],
+ test::AsScalar<float>(2.0f));
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_api_info.cc b/tensorflow/core/grappler/optimizers/function_api_info.cc
new file mode 100644
index 0000000000..798e0f6fd5
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info.cc
@@ -0,0 +1,167 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+
+#include <string>
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+FunctionApiInfo::FunctionApiInfo() {}
+FunctionApiInfo::~FunctionApiInfo() {}
+
+Status FunctionApiInfo::Init(const FunctionDef& function_def) {
+ for (const auto& attr : function_def.attr()) {
+ if (attr.first == "experimental_api_preferred_device") {
+ preferred_device_ = attr.second.s();
+ }
+ if (attr.first == "experimental_api_implements") {
+ interface_name_ = attr.second.s();
+ }
+ }
+ if (interface_name_.empty() && !preferred_device_.empty()) {
+ return errors::InvalidArgument(
+ "Function '", function_def.signature().name(),
+ "' has a preferred device, but does not implement an interface");
+ }
+ return Status::OK();
+}
+
+const string& FunctionApiInfo::preferred_device() const {
+ return preferred_device_;
+}
+
+const string& FunctionApiInfo::interface_name() const {
+ return interface_name_;
+}
+
+FunctionLibraryApiInfo::FunctionLibraryApiInfo() {}
+FunctionLibraryApiInfo::~FunctionLibraryApiInfo() {}
+
+namespace {
+bool IsSameSignature(const FunctionDef& f1, const FunctionDef& f2) {
+ if (f1.ret().size() != f2.ret().size()) return false;
+ const auto& sig1 = f1.signature();
+ const auto& sig2 = f2.signature();
+ // Functions have positional semantics, so we don't check for names.
+ if (sig1.input_arg_size() != sig2.input_arg_size()) return false;
+ for (int k = 0; k < sig1.input_arg_size(); ++k) {
+ const OpDef::ArgDef& arg1 = sig1.input_arg(k);
+ const OpDef::ArgDef& arg2 = sig2.input_arg(k);
+ if (arg1.type() != arg2.type()) return false;
+ if (arg1.type_attr() != arg2.type_attr()) return false;
+ if (arg1.number_attr() != arg2.number_attr()) return false;
+ if (arg1.type_list_attr() != arg2.type_list_attr()) return false;
+ if (arg1.is_ref() != arg2.is_ref()) return false;
+ }
+ return true;
+}
+
+Status ValidateSignature(const string& interface_name,
+ const std::vector<const FunctionDef*>& equiv_funcs) {
+ if (equiv_funcs.size() < 2) return Status::OK();
+ for (size_t k = 1; k < equiv_funcs.size(); ++k) {
+ if (!IsSameSignature(*equiv_funcs[0], *equiv_funcs[k]))
+ return errors::InvalidArgument(
+ "Functions '", equiv_funcs[0]->signature().name(), "' and '",
+ equiv_funcs[k]->signature().name(), "' both implement '",
+ interface_name, "' but their signatures do not match.");
+ }
+ return Status::OK();
+}
+
+Status ValidateSignatures(
+ const std::unordered_map<string, std::vector<const FunctionDef*>>&
+ intf_to_func) {
+ for (const auto& item : intf_to_func)
+ TF_RETURN_IF_ERROR(ValidateSignature(item.first, item.second));
+ return Status::OK();
+}
+} // namespace
+
+Status FunctionLibraryApiInfo::Init(
+ const FunctionDefLibrary& function_library) {
+ std::unordered_map<string, std::vector<const FunctionDef*>> intf_to_func;
+ for (const auto& function : function_library.function()) {
+ std::unique_ptr<FunctionApiInfo> func_info(new FunctionApiInfo);
+ TF_RETURN_IF_ERROR(func_info->Init(function));
+ // Ignore the function if it does not implement any interface.
+ if (func_info->interface_name().empty()) continue;
+
+ const string& function_name = function.signature().name();
+ const string& interface_name = func_info->interface_name();
+ func_to_intf_[function_name] = interface_name;
+ intf_to_funcs_[interface_name].emplace_back(function_name);
+ intf_to_func[interface_name].emplace_back(&function);
+ func_info_[function_name] = std::move(func_info);
+ }
+ TF_RETURN_IF_ERROR(ValidateSignatures(intf_to_func));
+ return Status::OK();
+}
+
+void FunctionLibraryApiInfo::GetEquivalentImplementations(
+ const string& function_name, std::vector<string>* other_names) const {
+ const auto intf_it = func_to_intf_.find(function_name);
+ // The function does not implement any interface.
+ if (intf_it == func_to_intf_.end()) return;
+ CHECK(!intf_it->second.empty()) << "Function " << function_name
+ << " should implement at least one interface.";
+ const auto it = intf_to_funcs_.find(intf_it->second);
+ CHECK(it != intf_to_funcs_.end())
+ << "Function " << function_name << " maps to " << intf_it->second
+ << " but no reverse mapping was found";
+ CHECK_GE(it->second.size(), 1) << "Class " << it->first << " is empty";
+ other_names->reserve(it->second.size() - 1);
+ for (const auto& other_name : it->second) {
+ if (other_name == function_name) continue;
+ other_names->emplace_back(other_name);
+ }
+}
+
+void FunctionLibraryApiInfo::GetBestImplementation(
+ const string& function_name, const string& device,
+ string* best_func_name) const {
+ CHECK(best_func_name != nullptr);
+ const auto func_it = func_to_intf_.find(function_name);
+ if (func_it == func_to_intf_.end()) return;
+
+ const auto it = intf_to_funcs_.find(func_it->second);
+ // No function found for the given interface.
+ if (it == intf_to_funcs_.end()) return;
+ for (const auto& func_name : it->second) {
+ const auto func_api_info = func_info_.find(func_name)->second.get();
+ if (func_api_info->preferred_device() == device) {
+ best_func_name->assign(func_name);
+ return;
+ }
+ }
+ // Didn't find a function with a matching device name; choose the first one
+ // among all the available functions.
+ best_func_name->assign(it->second.front());
+}
+
+const FunctionApiInfo* FunctionLibraryApiInfo::GetApiInfo(
+ const string& function_name) const {
+ const auto it = func_info_.find(function_name);
+ if (it == func_info_.end()) return nullptr;
+ return it->second.get();
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
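A minimal sketch of querying FunctionLibraryApiInfo, assuming a FunctionDefLibrary whose functions carry the experimental_api_* attributes; the function name used here is hypothetical and the calls mirror function_api_info_test.cc below:

#include <string>
#include <vector>

#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/grappler/optimizers/function_api_info.h"
#include "tensorflow/core/lib/core/errors.h"

namespace tensorflow {
namespace grappler {

Status QueryApiInfoSketch(const FunctionDefLibrary& library) {
  FunctionLibraryApiInfo api_info;
  TF_RETURN_IF_ERROR(api_info.Init(library));

  const string func_name = "DoStuffCpu";  // Hypothetical annotated function.

  // All other functions implementing the same interface.
  std::vector<string> equivalents;
  api_info.GetEquivalentImplementations(func_name, &equivalents);

  // The implementation whose preferred device matches "GPU"; falls back to
  // the first registered implementation when no preferred device matches.
  string best;
  api_info.GetBestImplementation(func_name, "GPU", &best);
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow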
diff --git a/tensorflow/core/grappler/optimizers/function_api_info.h b/tensorflow/core/grappler/optimizers/function_api_info.h
new file mode 100644
index 0000000000..412687c58c
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info.h
@@ -0,0 +1,80 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
+
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace grappler {
+class FunctionApiInfo {
+ public:
+ FunctionApiInfo();
+ virtual ~FunctionApiInfo();
+
+ Status Init(const FunctionDef& function_def);
+
+ const string& interface_name() const;
+ const string& preferred_device() const;
+
+ private:
+ string interface_name_;
+ string preferred_device_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionApiInfo);
+};
+
+// A collection of information about a function and the interface it
+// implements. An interface is a well-defined math operation, e.g.
+// I1 = 2 * x + y. Multiple functions could implement the same interface with
+// different behavior based on different hardware conditions and limits, e.g.
+// F1 = math_ops.add(math_ops.add(x, x), y), or
+// F2 = math_ops.add(math_ops.matmul(x, 2), y).
+class FunctionLibraryApiInfo {
+ public:
+ FunctionLibraryApiInfo();
+ virtual ~FunctionLibraryApiInfo();
+ // Populates the internal fields for the functions within function_library.
+ Status Init(const FunctionDefLibrary& function_library);
+
+ void GetEquivalentImplementations(const string& function_name,
+ std::vector<string>* other_names) const;
+
+ void GetBestImplementation(const string& function_name, const string& device,
+ string* best_func_name) const;
+
+ const FunctionApiInfo* GetApiInfo(const string& function_name) const;
+
+ private:
+ // Map from function name to function details.
+ std::unordered_map<string, std::unique_ptr<FunctionApiInfo>> func_info_;
+ // Map from function name to interface name.
+ std::unordered_map<string, string> func_to_intf_;
+ // Map from interface name to function names.
+ std::unordered_map<string, std::vector<string>> intf_to_funcs_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(FunctionLibraryApiInfo);
+};
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_FUNCTION_API_INFO_H_
diff --git a/tensorflow/core/grappler/optimizers/function_api_info_test.cc b/tensorflow/core/grappler/optimizers/function_api_info_test.cc
new file mode 100644
index 0000000000..582890d3e3
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/function_api_info_test.cc
@@ -0,0 +1,160 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/function_api_info.h"
+
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/framework/op_def.pb.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+void SetArg(const string& name, const string& type_name,
+ OpDef::ArgDef* arg_def) {
+ arg_def->set_name(name);
+ arg_def->set_type_attr(type_name);
+}
+
+typedef std::pair<string, string> ArgSpec; // name, type.
+
+void SetArgs(const std::vector<ArgSpec>& args_spec, OpDef* sig) {
+ for (const auto& arg_spec : args_spec)
+ SetArg(arg_spec.first, arg_spec.second, sig->add_input_arg());
+ SetArg("output", "float32", sig->add_output_arg());
+}
+
+void PopulateFunction(const string& name, const string& api_interface_name,
+ const string& preferred_device,
+ const std::vector<ArgSpec>& input_args,
+ FunctionDef* func_def) {
+ OpDef* sig = func_def->mutable_signature();
+ sig->set_name(name);
+
+ SetArgs(input_args, sig);
+
+ if (!api_interface_name.empty() || !preferred_device.empty()) {
+ auto* func_attr = func_def->mutable_attr();
+ if (!api_interface_name.empty())
+ (*func_attr)["experimental_api_implements"].set_s(api_interface_name);
+ if (!preferred_device.empty())
+ (*func_attr)["experimental_api_preferred_device"].set_s(preferred_device);
+ }
+}
+
+void PopulateSampleLibrary(const bool mismatch_args,
+ FunctionDefLibrary* func_lib) {
+ const std::vector<ArgSpec> func_args{{"in1", "float32"}, {"in2", "int32"}};
+ const std::vector<ArgSpec> func_wrong_args{{"in1", "int32"},
+ {"in2", "int32"}};
+ PopulateFunction("DoStuffCpu", "DoStuff", "CPU", func_args,
+ func_lib->add_function());
+ PopulateFunction("DoStuffGpu", "DoStuff", "GPU",
+ mismatch_args ? func_wrong_args : func_args,
+ func_lib->add_function());
+ PopulateFunction("DoThings", "DoThings", "", func_args,
+ func_lib->add_function());
+ PopulateFunction("OneOff", "", "", func_args, func_lib->add_function());
+ PopulateFunction("AnotherOneOff", "", "", func_args,
+ func_lib->add_function());
+}
+
+bool CheckEquivImpl(const FunctionLibraryApiInfo& lib_api_info,
+ const string& func_name,
+ const std::vector<string>& expected_other) {
+ std::vector<string> other_impl;
+ lib_api_info.GetEquivalentImplementations(func_name, &other_impl);
+ const std::unordered_set<string> actual(other_impl.begin(), other_impl.end());
+ const std::unordered_set<string> expected(expected_other.begin(),
+ expected_other.end());
+ return actual == expected;
+}
+
+bool CheckGetBestImpl(const FunctionLibraryApiInfo& lib_api_info,
+ const string& function_name, const string& device,
+ const string& expected_function_name) {
+ string best_function_name;
+ lib_api_info.GetBestImplementation(function_name, device,
+ &best_function_name);
+
+ return best_function_name == expected_function_name;
+}
+
+string GetInterfaceName(const FunctionLibraryApiInfo& lib_api_info,
+ const string& func_name) {
+ auto* info = lib_api_info.GetApiInfo(func_name);
+ CHECK_NOTNULL(info);
+ return info->interface_name();
+}
+
+string GetPreferredDevice(const FunctionLibraryApiInfo& lib_api_info,
+ const string& func_name) {
+ auto* info = lib_api_info.GetApiInfo(func_name);
+ CHECK_NOTNULL(info);
+ return info->preferred_device();
+}
+
+TEST(FunctionApiInfoTest, ParseTags) {
+ FunctionDefLibrary func_lib;
+ PopulateSampleLibrary(/* mismatch_args */ false, &func_lib);
+ FunctionLibraryApiInfo lib_api_info;
+ TF_ASSERT_OK(lib_api_info.Init(func_lib));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffCpu", {"DoStuffGpu"}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoStuffGpu", {"DoStuffCpu"}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "Undefined", {}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "OneOff", {}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "AnotherOneOff", {}));
+ EXPECT_TRUE(CheckEquivImpl(lib_api_info, "DoThings", {}));
+
+ EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffCpu"));
+ EXPECT_EQ("DoStuff", GetInterfaceName(lib_api_info, "DoStuffGpu"));
+ EXPECT_EQ("DoThings", GetInterfaceName(lib_api_info, "DoThings"));
+
+ EXPECT_EQ("CPU", GetPreferredDevice(lib_api_info, "DoStuffCpu"));
+ EXPECT_EQ("GPU", GetPreferredDevice(lib_api_info, "DoStuffGpu"));
+ EXPECT_EQ("", GetPreferredDevice(lib_api_info, "DoThings"));
+
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffCpu", "CPU", "DoStuffCpu"));
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffCpu", "GPU", "DoStuffGpu"));
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffGpu", "CPU", "DoStuffCpu"));
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffGpu", "GPU", "DoStuffGpu"));
+
+ EXPECT_TRUE(CheckGetBestImpl(lib_api_info, "DoThings", "GPU", "DoThings"));
+ // TPU impl is not available; choose the first available one, which is the CPU impl.
+ EXPECT_TRUE(
+ CheckGetBestImpl(lib_api_info, "DoStuffGpu", "TPU", "DoStuffCpu"));
+}
+
+TEST(FunctionApiInfoTest, MismatchedArguments) {
+ FunctionDefLibrary func_lib;
+ PopulateSampleLibrary(/* mismatch_args */ true, &func_lib);
+ FunctionLibraryApiInfo lib_api_info;
+ const Status ret = lib_api_info.Init(func_lib);
+ EXPECT_FALSE(ret.ok());
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index 645e4c2087..56364f0095 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -453,6 +453,7 @@ Status InitializeFunctionSpecializationSignature(
}
Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
+ const int graph_def_version,
FunctionOptimizerContext* ctx,
GraphDef* optimized_graph) {
VLOG(2) << "Specialize function instantiation: "
@@ -492,7 +493,8 @@ Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
// Make a GrapplerFunctionItem and convert it back to FunctionDef after
// pushing all constant inputs into the function body.
GrapplerFunctionItem item;
- TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, func_attr, flib,
+ graph_def_version, &item));
// Push const inputs into the function body, and keep track of their control
// dependencies.
@@ -576,15 +578,15 @@ NodeDef InlinedFunctionOutputsNode(const NodeDef& func_node,
Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
const FunctionOptimizerContext& ctx,
- GraphDef* optimized_graph) {
+ const int graph_def_version, GraphDef* optimized_graph) {
VLOG(2) << "Inline function instantiation: " << SummarizeNodeDef(func_node);
const std::unordered_map<string, AttrValue> func_attr(
func_node.attr().begin(), func_node.attr().end());
GrapplerFunctionItem item;
- Status item_status =
- MakeGrapplerFunctionItem(func, func_attr, ctx.function_library(), &item);
+ Status item_status = MakeGrapplerFunctionItem(
+ func, func_attr, ctx.function_library(), graph_def_version, &item);
if (!item_status.ok()) {
return errors::InvalidArgument("Failed to inline function ", func_node.op(),
@@ -645,7 +647,8 @@ Status InlineFunction(const NodeDef& func_node, const FunctionDef& func,
if (func_body_node_func != nullptr) {
// Recursively inline function calls.
TF_RETURN_IF_ERROR(InlineFunction(func_body_node, *func_body_node_func,
- ctx, optimized_graph));
+ ctx, graph_def_version,
+ optimized_graph));
} else {
// Annotate the node with the function attributes.
for (const auto& attr : func.attr()) {
@@ -824,7 +827,8 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (inline_func && ctx.IsInlinedFunction(func_name)) {
// Inline function body into the optimized graph}
TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
- InlineFunction(node, *func, ctx, optimized_graph));
+ InlineFunction(node, *func, ctx, item.graph.versions().producer(),
+ optimized_graph));
continue;
}
@@ -837,7 +841,8 @@ Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// TODO(ezhulenev): Specialize function call if input has a known shape.
// Specialize function body for its instantiation attributes and inputs.
TF_SKIP_ERROR_IF_GRAPH_UNMODIFIED(
- SpecializeFunction(node, *func, &ctx, optimized_graph));
+ SpecializeFunction(node, *func, item.graph.versions().producer(),
+ &ctx, optimized_graph));
continue;
}
}
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 1be5f8dcc2..c775a26914 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/costs/graph_memory.h"
@@ -1070,11 +1071,13 @@ static bool IdentifySwappingCandidates(
// ensure that swapping the tensor back in won't recreate the memory
// bottleneck. Last but not least, we want the tensor to have as few
// remaining uses as possible.
+ //
+ // Note that we must perform the arithmetic inexactly as "double", since
+ // the values do not fit into any integral type.
mem_info.fitness =
- MathUtil::IPow((earliest_use - peak_time).count(), 2);
- mem_info.fitness /= MathUtil::IPow(mem_info.uses_left.size(), 2);
- mem_info.fitness +=
- MathUtil::IPow((allocation_time - peak_time).count(), 2);
+ MathUtil::IPow<double>((earliest_use - peak_time).count(), 2) /
+ MathUtil::IPow<double>(mem_info.uses_left.size(), 2) +
+ MathUtil::IPow<double>((allocation_time - peak_time).count(), 2);
mem_info.fitness = -mem_info.fitness;
mem_state.push_back(mem_info);
}
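A brief aside on the IPow change above, with assumed magnitudes: int64 tops out near 9.2e18, so once a time delta in the swapping timeline reaches roughly 3e9 of its units, squaring it already sits at the overflow boundary and anything larger wraps around. Computing each MathUtil::IPow<double> term in floating point trades exactness for a result that stays representable; the final negation used for the fitness ordering is unaffected.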
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index e778b7879d..c3d70a1fdf 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -23,11 +23,13 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
+#include "tensorflow/core/grappler/optimizers/experimental_implementation_selector.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
@@ -35,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -72,6 +75,16 @@ bool IsRunOnceOptimizer(const string& name) {
name == "loop_optimizer";
}
+// Check if the graphdef contains nodes that indicate TPU execution.
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (auto node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+
} // namespace
#define MK_OPT(NAME, VALUE) \
@@ -94,6 +107,7 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("scoped_allocator",
new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
cfg_.scoped_allocator_opts()));
+ MK_OPT("small_op", new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
return std::unique_ptr<GraphOptimizer>();
}
@@ -102,6 +116,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
Status MetaOptimizer::InitializeOptimizers(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
+ if (cfg_.disable_meta_optimizer()) {
+ return Status::OK();
+ }
if (!cfg_.disable_model_pruning()) {
optimizers->push_back(MakeUnique<ModelPruner>());
}
@@ -122,6 +139,9 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
+ if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
+ optimizers->push_back(MakeUnique<PinToHostOptimizer>());
+ }
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
optimizers->push_back(
MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
@@ -156,11 +176,12 @@ Status MetaOptimizer::InitializeOptimizers(
optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
}
- return Status::OK();
+ return InitializeCustomGraphOptimizers(std::set<string>(), optimizers);
}
Status MetaOptimizer::InitializeOptimizersByName(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
+ std::set<string> initialized_custom_optimizers;
for (const string& optimizer_name : cfg_.optimizers()) {
auto optimizer = MakeNewOptimizer(optimizer_name);
if (optimizer) {
@@ -174,21 +195,54 @@ Status MetaOptimizer::InitializeOptimizersByName(
if (custom_optimizer) {
VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
- TF_RETURN_IF_ERROR(custom_optimizer->Init());
+ TF_RETURN_IF_ERROR(custom_optimizer->Init(
+ GetCustomGraphOptimizerConfig(optimizer_name)));
optimizers->push_back(std::move(custom_optimizer));
+ initialized_custom_optimizers.insert(optimizer_name);
} else {
VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
}
}
+ return InitializeCustomGraphOptimizers(initialized_custom_optimizers,
+ optimizers);
+}
+
+Status MetaOptimizer::InitializeCustomGraphOptimizers(
+ const std::set<string>& pre_initialized_optimizers,
+ std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
for (const auto& optimizer_config : cfg_.custom_optimizers()) {
- auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
- optimizer_config.name());
+ if (pre_initialized_optimizers.find(optimizer_config.name()) !=
+ pre_initialized_optimizers.end()) {
+ continue;
+ }
+ // Initialize the ExperimentalImplementationSelector here instead of through
+ // the CustomGraphOptimizer registry, due to a static linking issue in
+ // TensorRT that causes double registration.
+ // TODO(laigd): Remove this hack and change it back to use the registry once
+ // the duplicate static import issue is fixed.
+ std::unique_ptr<CustomGraphOptimizer> custom_optimizer;
+ if (optimizer_config.name() == "ExperimentalImplementationSelector") {
+ custom_optimizer.reset(new ExperimentalImplementationSelector());
+ } else {
+ custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
+ optimizer_config.name());
+ }
if (custom_optimizer) {
VLOG(2) << "Registered custom configurable graph optimizer: "
<< optimizer_config.name();
TF_RETURN_IF_ERROR(custom_optimizer->Init(&optimizer_config));
optimizers->push_back(std::move(custom_optimizer));
} else {
+ // If there is no custom optimizer with the given name, try to initialize a
+ // default optimizer. This way, custom configurable optimizers can be
+ // mixed with default optimizers in any order.
+ auto optimizer = MakeNewOptimizer(optimizer_config.name());
+ if (optimizer) {
+ VLOG(2) << "Registered default graph optimizer: "
+ << optimizer_config.name();
+ optimizers->push_back(std::move(optimizer));
+ continue;
+ }
VLOG(2) << "Can't register an optimizer by name: "
<< optimizer_config.name();
}
@@ -196,6 +250,16 @@ Status MetaOptimizer::InitializeOptimizersByName(
return Status::OK();
}
+const RewriterConfig::CustomGraphOptimizer*
+MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const {
+ for (const auto& config : cfg_.custom_optimizers()) {
+ if (config.name() == name) {
+ return &config;
+ }
+ }
+ return nullptr;
+}
+
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
@@ -208,7 +272,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
}
std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
- if (cfg_.optimizers().empty() && cfg_.custom_optimizers().empty()) {
+ if (cfg_.optimizers().empty()) {
TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
} else {
TF_RETURN_IF_ERROR(InitializeOptimizersByName(&optimizers));
@@ -326,15 +390,39 @@ Status MetaOptimizer::RunOptimizer(
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
+ VLOG(1) << "Starting optimization for grappler item: " << item.id;
optimization_results_.clear();
// 1. Optimize main graph
TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
+ VLOG(1) << "Optimized main graph.";
+
+ // Skip optimizing functions if this is a TPU graph. Currently, Grappler
+ // passes do not handle TPU functions correctly in a variety of ways (Note
+ // that due to the pre-placement TPU graph rewriting passes, the TPU-related
+ // ops are encapsulated away into functions). For example, TPU graphs contain
+ // TPUReplicateMetadata node that carries relevant TPU metadata and Grappler
+ // passes could prune that away. Grappler passes could also cause issues
+ // around shape inference. Since the desired and existing behavior is to not
+ // optimize TPU functions with Grappler, this check preserves that.
+ if (IsTPUGraphDef(*optimized_graph)) {
+ VLOG(2) << "Skipping optimizing funcs for TPU graphs";
+ return Status::OK();
+ }
// 2. Optimize function library
FunctionLibraryDefinition flib(OpRegistry::Global(),
optimized_graph->library());
+ // Find functions for which we might need to compute a gradient at runtime.
+ gtl::FlatSet<string> differentiable_functions;
+ for (const NodeDef& node : optimized_graph->node()) {
+ if (IsSymbolicGradient(node)) {
+ const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
+ if (f_attr) differentiable_functions.insert(f_attr->func().name());
+ }
+ }
+
// Optimize each function only once.
std::unordered_set<string> optimized_funcs;
bool optimize_function_library = true;
@@ -350,6 +438,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Skip parametrized functions (function type or body is defined only at
// function call time by caller node attributes).
+ // They should be specialized to their instantiation type parameters by
+ // the function optimizer before we can optimize the function body.
if (IsParametrized(func)) continue;
VLOG(3) << "Optimize function: function=" << func_name;
@@ -361,7 +451,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Make a GrapplerItem from a FunctionDef.
GrapplerFunctionItem func_item;
- TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, flib, &func_item));
+ TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
+ func, flib, item.graph.versions().producer(), &func_item));
+
+ // If we need to compute the gradient of optimized function at runtime, we
+ // can't perform non-differentiable rewrites.
+ if (differentiable_functions.find(func_name) !=
+ differentiable_functions.end()) {
+ func_item.allowed_optimizations.non_differentiable_rewrites = false;
+ }
// Optimize function body graph.
GraphDef optimized_func_graph;
@@ -392,7 +490,7 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
}
- VLOG(3) << "Optimized " << optimized_funcs.size()
+ VLOG(1) << "Optimized " << optimized_funcs.size()
<< " functions: " << str_util::Join(optimized_funcs, ", ");
return Status::OK();
@@ -413,6 +511,9 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
}
bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
+ if (cfg.disable_meta_optimizer()) {
+ return false;
+ }
return !cfg.disable_model_pruning() ||
cfg.layout_optimizer() != RewriterConfig::OFF ||
cfg.function_optimization() != RewriterConfig::OFF ||
@@ -426,6 +527,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
cfg.debug_stripper() == RewriterConfig::ON ||
cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
+ cfg.pin_to_host_optimization() == RewriterConfig::ON ||
!cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
}
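A minimal sketch of opting in to the passes wired up above through RewriterConfig, assuming a populated GrapplerItem; the optimizer name string is the one special-cased in InitializeCustomGraphOptimizers, and the rest mirrors meta_optimizer_test.cc:

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/core/protobuf/rewriter_config.pb.h"

namespace tensorflow {
namespace grappler {

Status RunMetaOptimizerSketch(const GrapplerItem& item, GraphDef* output) {
  RewriterConfig rewriter_config;
  // Opt in to the new host-pinning pass (it only runs when set to ON).
  rewriter_config.set_pin_to_host_optimization(RewriterConfig::ON);
  // Enable the implementation selector by name; this string is matched
  // explicitly to work around the TensorRT static linking issue noted above.
  rewriter_config.add_custom_optimizers()->set_name(
      "ExperimentalImplementationSelector");

  MetaOptimizer optimizer(nullptr, rewriter_config);
  return optimizer.Optimize(/*cluster=*/nullptr, item, output);
}

}  // namespace grappler
}  // namespace tensorflow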
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 151a54cbdf..99a0a33ffa 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -52,6 +52,13 @@ class MetaOptimizer : public GraphOptimizer {
// Initialize active optimizers from RewriterConfig optimizer names.
Status InitializeOptimizersByName(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
+ // Initialize active optimizers from RewriterConfig.custom_optimizers.
+ Status InitializeCustomGraphOptimizers(
+ const std::set<string>& pre_initialized_optimizers,
+ std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const;
+ // Returns the config for a custom graph optimizer. Null if none was found.
+ const RewriterConfig::CustomGraphOptimizer* GetCustomGraphOptimizerConfig(
+ const string& name) const;
// Run optimization pass over a single GrapplerItem. Meta optimizer might run
// multiple such passes: 1) for the main graph 2) for the function library
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 9a03c7dfef..3f3f43382f 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -64,6 +65,66 @@ bool TestOptimizer::optimized_;
REGISTER_GRAPH_OPTIMIZER(TestOptimizer);
+class TestGraphOptimizer : public TestOptimizer {
+ public:
+ string name() const override { return "test_graph_optimizer"; }
+};
+
+REGISTER_GRAPH_OPTIMIZER(TestGraphOptimizer);
+
+class TestOptimizerWithParams : public TestOptimizer {
+ public:
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ CHECK(config != nullptr);
+ return Status::OK();
+ }
+};
+
+REGISTER_GRAPH_OPTIMIZER(TestOptimizerWithParams);
+
+// Record various properties of the GrapplerItems passed for optimization.
+class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer {
+ public:
+ static void SetAllowedOptimizations(
+ gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ allowed_optimizations) {
+ allowed_optimizations_ = allowed_optimizations;
+ }
+ static void ResetAllowedOptimizations() { allowed_optimizations_ = nullptr; }
+
+ GrapplerItemPropertiesAccumulator() {}
+ string name() const override {
+ return "grappler_item_properties_accumulator";
+ }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override {
+ *optimized_graph = item.graph;
+ if (allowed_optimizations_) {
+ allowed_optimizations_->insert({item.id, item.allowed_optimizations});
+ }
+ return Status::OK();
+ }
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ static gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ allowed_optimizations_;
+};
+
+gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ GrapplerItemPropertiesAccumulator::allowed_optimizations_;
+
+REGISTER_GRAPH_OPTIMIZER(GrapplerItemPropertiesAccumulator);
+
class MetaOptimizerTest : public GrapplerTest {};
TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
@@ -83,6 +144,46 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
EXPECT_TRUE(TestOptimizer::IsOptimized());
}
+TEST_F(MetaOptimizerTest, RunsCustomOptimizerWithParams) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ TestOptimizer::SetOptimized(false);
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("TestOptimizerWithParams");
+ auto* custom_config = rewriter_config.add_custom_optimizers();
+ custom_config->set_name("TestOptimizerWithParams");
+ (*custom_config->mutable_parameter_map())["foo"] = AttrValue();
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestOptimizer::IsOptimized());
+}
+
+TEST_F(MetaOptimizerTest, RunsCustomOptimizerAndCustomGraphOptimizer) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ TestOptimizer::SetOptimized(false);
+ TestGraphOptimizer::SetOptimized(false);
+ RewriterConfig rewriter_config;
+ rewriter_config.add_optimizers("TestOptimizer");
+ auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
+ customGraphOptimizer->set_name("TestGraphOptimizer");
+ rewriter_config.set_min_graph_nodes(-1);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestOptimizer::IsOptimized());
+ EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
+}
+
TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
GrapplerItem item;
@@ -98,6 +199,24 @@ TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
TF_EXPECT_OK(status);
}
+TEST_F(MetaOptimizerTest, RunToggleOptimizersAndCustomGraphOptimizerTwice) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ RewriterConfig rewriter_config;
+ auto customGraphOptimizer = rewriter_config.add_custom_optimizers();
+ customGraphOptimizer->set_name("TestGraphOptimizer");
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+ rewriter_config.set_min_graph_nodes(-1);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+ GraphDef output;
+ const Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ EXPECT_TRUE(TestGraphOptimizer::IsOptimized());
+}
+
TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
using test::function::NDef;
@@ -259,6 +378,89 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
test::ExpectTensorEqual<int>(tensors_expected[1], tensors[1]);
}
+TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) {
+ using test::function::NDef;
+ using FDH = FunctionDefHelper;
+
+ // We will record what type of optimizations meta optimizer allows for each
+ // GrapplerItem (main graph and graphs for each function).
+ gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>
+ allowed_optimizations;
+ GrapplerItemPropertiesAccumulator::SetAllowedOptimizations(
+ &allowed_optimizations);
+
+ // Just record properties of optimized Grappler items.
+ RewriterConfig rewriter_config;
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+ rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator");
+ rewriter_config.set_min_graph_nodes(-1);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+
+ // Define simple function library with two identical mul functions.
+ FunctionDef mul_func_1 = FunctionDefHelper::Create(
+ "MyMul1", {"x:float", "y:float"}, {"z:float"}, {},
+ {{{"mul"}, "Mul", {"x", "y"}, {}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "mul:z:0"}});
+
+ FunctionDef mul_func_2 = FunctionDefHelper::Create(
+ "MyMul2", {"x:float", "y:float"}, {"z:float"}, {},
+ {{{"mul"}, "Mul", {"x", "y"}, {}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "mul:z:0"}});
+
+ // Tensorflow graph:
+ //
+ // x0 = tf.Placeholder(tf.float);
+ // x1 = tf.Placeholder(tf.float);
+ // dy = tf.Placeholder(tf.float);
+ //
+ // mul_1 = MyMul1(x0, x1);
+ // mul_2 = MyMul2(x0, x1);
+ // dx = SymbolicGradient({x0, x1, dy}, f=MyMul2)
+ GrapplerItem item;
+ item.id = "main";
+ item.graph = test::function::GDef(
+ {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ // Calls into function library
+ NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice),
+ NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice),
+ // Symbolic gradient of a MyMul2
+ NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
+ {{"f", FDH::FunctionRef("MyMul2", {})},
+ {"Tin", DataTypeSlice{DT_FLOAT}},
+ {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
+ kDevice)},
+ // FunctionLib
+ {mul_func_1, mul_func_2});
+ item.fetch = {"mul_1", "mul_2", "dx"};
+
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ // Our custom optimizer must be called for the main graph and for the two
+ // functions.
+ ASSERT_EQ(allowed_optimizations.size(), 3);
+
+ auto allowed_optimizations_main =
+ gtl::FindOrNull(allowed_optimizations, "main");
+ ASSERT_NE(allowed_optimizations_main, nullptr);
+ EXPECT_TRUE(allowed_optimizations_main->non_differentiable_rewrites);
+
+ auto allowed_optimizations_my_mul_1 =
+ gtl::FindOrNull(allowed_optimizations, "MyMul1");
+ ASSERT_NE(allowed_optimizations_my_mul_1, nullptr);
+ EXPECT_TRUE(allowed_optimizations_my_mul_1->non_differentiable_rewrites);
+
+ auto allowed_optimizations_my_mul_2 =
+ gtl::FindOrNull(allowed_optimizations, "MyMul2");
+ ASSERT_NE(allowed_optimizations_my_mul_2, nullptr);
+ EXPECT_FALSE(allowed_optimizations_my_mul_2->non_differentiable_rewrites);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
new file mode 100644
index 0000000000..8ed4271fa4
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -0,0 +1,363 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/tensor_shape.pb.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/topological_sort.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+
+// TODO(williamchan): Change this constant to be something smarter, maybe
+// dynamically determined.
+constexpr int64 kTensorMaxSize = 64;
+
+// All the nodes that should be blacklisted and not swapped.
+bool IsBlacklisted(const NodeDef& node) {
+ return
+ // Collective ops should not be swapped.
+ IsCollective(node) ||
+ // ControlFlow ops should not be swapped.
+ IsControlFlow(node) ||
+ // NoOp ops should not be swapped (due to group dependencies).
+ IsNoOp(node);
+}
+
+// Check if the tensor has an integer dtype and a small size.
+bool IsTensorIntegerAndSmall(const OpInfo::TensorProperties& prop) {
+ // Check type to be int32 or int64.
+ if (prop.dtype() != DataType::DT_INT32 &&
+ prop.dtype() != DataType::DT_INT64) {
+ return false;
+ }
+
+ // Check size known and small.
+ const int64 size = NumCoefficients(prop.shape());
+ if (size < 0 || size > kTensorMaxSize) {
+ return false;
+ }
+
+ return true;
+}
+
+// Find a KernelDef for `node`, greedily returning the first match in `devices`.
+Status TryFindKernelDef(const std::vector<DeviceType>& devices,
+ const NodeDef& node, const KernelDef** kdef) {
+ for (const DeviceType& device : devices) {
+ const KernelDef* kernel = nullptr;
+ Status s = FindKernelDef(device, node, &kernel, nullptr);
+ if (s.ok()) {
+ if (kdef) {
+ *kdef = kernel;
+ }
+ return Status::OK();
+ }
+ }
+
+ return errors::NotFound("Could not find KernelDef for op: ", node.op());
+}
+
+// Checks if a node's output port is host friendly.
+// Roughly this means checking if the output port is on Host memory.
+Status IsNodeOutputPortHostFriendly(const GraphView& graph,
+ GraphProperties* properties,
+ const NodeDef& node, int port_id,
+ bool* is_candidate) {
+ *is_candidate = false;
+
+ // Make sure we are not a blacklisted op.
+ if (IsBlacklisted(node)) {
+ return Status::OK();
+ }
+
+ // Check to make sure we have the right properties (i.e., statically shaped).
+ if (!properties->has_properties()) {
+ // This is an expensive call, call it lazily.
+ TF_RETURN_IF_ERROR(properties->InferStatically(
+ /*assume_valid_feeds=*/false));
+ }
+ const auto& output_properties = properties->GetOutputProperties(node.name());
+ if (port_id >= output_properties.size()) {
+ LOG(WARNING) << "port_id=" << port_id
+ << " but output_properties.size()=" << output_properties.size()
+ << "\n"
+ << node.DebugString();
+ return Status::OK();
+ }
+ if (!IsTensorIntegerAndSmall(output_properties[port_id])) {
+ return Status::OK();
+ }
+
+ // These nodes may be optimized away downstream (even if pinned to Host), so
+ // we should (recursively) check their source.
+ if (IsIdentity(node)) {
+ for (const auto& fanin : graph.GetFanins(node, false)) {
+ bool fanin_candidate = false;
+ TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+ graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
+ if (!fanin_candidate) {
+ return Status::OK();
+ }
+ }
+ *is_candidate = true;
+ return Status::OK();
+ }
+
+ // Check if op's device is on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ *is_candidate = true;
+ return Status::OK();
+ }
+
+ // Check if op's output port is pinned to HostMemory.
+ const OpDef* op = nullptr;
+ Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
+ if (!s.ok()) {
+ LOG(WARNING) << "Could not find OpDef for: " << node.op();
+ return Status::OK();
+ }
+
+ // Map the port_id to output_arg_id.
+ const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id);
+ if (output_arg_id < 0) {
+ LOG(WARNING) << "Invalid port: " << port_id << "!\n"
+ << node.DebugString() << "\n"
+ << op->DebugString();
+ return Status::OK();
+ }
+
+ // Find the kernel.
+ const KernelDef* kernel = nullptr;
+ s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node,
+ &kernel);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ return Status::OK();
+ }
+
+ // Check if the output_arg is pinned to Host.
+ for (const string& host_memory_arg : kernel->host_memory_arg()) {
+ if (op->output_arg(output_arg_id).name() == host_memory_arg) {
+ *is_candidate = true;
+ break;
+ }
+ }
+
+ return Status::OK();
+}
+
+// Checks if a node's input port is Host friendly.
+// Roughly this means checking if the input port is on Host memory.
+bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
+ // If node is on Host, assume its inputs are Host friendly.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ return true;
+ }
+
+ // Check if op's input port is pinned to HostMemory.
+ const OpDef* op = nullptr;
+ Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
+ if (!s.ok()) {
+ LOG(WARNING) << "Could not find OpDef for: " << node.op();
+ return false;
+ }
+ const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
+
+ // Find the kernel.
+ const KernelDef* kernel = nullptr;
+ s = internal::TryFindKernelDef(
+ {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
+ if (!s.ok()) {
+ LOG(INFO) << "Could not find KernelDef for: " << node.op();
+ return false;
+ }
+
+ // Check if the input_arg is pinned to Host.
+ for (const string& host_memory_arg : kernel->host_memory_arg()) {
+ if (op->input_arg(input_arg_id).name() == host_memory_arg) {
+ return true;
+ }
+ }
+
+ return false;
+}
+
+// Checks if a node is a candidate to pin to Host.
+// The rough algorithm is as follows:
+// 1] Check if node is blacklisted.
+// 2] Check if node can run on Host.
+// 3] Check all inputs/outputs are Host "friendly" (currently this means small
+// integer tensors that are pinned to Host memory).
+Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
+ const NodeDef& node, bool* is_candidate) {
+ *is_candidate = false;
+
+ // Check if node already on CPU.
+ if (str_util::StrContains(node.device(), DEVICE_CPU)) {
+ *is_candidate = true;
+ return Status::OK();
+ }
+
+ // Skip these node types.
+ if (IsBlacklisted(node)) {
+ return Status::OK();
+ }
+
+ // Check the node can be run on CPU.
+ Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr);
+ if (!s.ok()) {
+ return Status::OK();
+ }
+
+ // Check all inputs are Host friendly.
+ for (const GraphView::OutputPort& fanin :
+ graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
+ bool fanin_candidate = false;
+ TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
+ graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
+ if (!fanin_candidate) {
+ return Status::OK();
+ }
+ }
+
+ // Check all outputs are Host friendly.
+ if (!properties->has_properties()) {
+ // This is an expensive call, call it lazily.
+ TF_RETURN_IF_ERROR(properties->InferStatically(
+ /*assume_valid_feeds=*/false));
+ }
+ for (const auto& prop : properties->GetOutputProperties(node.name())) {
+ if (!IsTensorIntegerAndSmall(prop)) {
+ return Status::OK();
+ }
+ }
+
+ *is_candidate = true;
+ return Status::OK();
+}
+
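+// Attempts to map `device` to an equivalent Host (CPU) device present in
+// `devices`. For illustration: with devices = {"/device:CPU:0"}, both
+// TryFindHostDevice(devices, true, "") and
+// TryFindHostDevice(devices, true, "/device:GPU:0") return "/device:CPU:0".
+// If no suitable Host device is found, `device` is returned unchanged.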
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, const string& device) {
+ // Force this node onto the CPU.
+ if (device.empty() && has_device_cpu) {
+ return "/device:CPU:0";
+ } else if (str_util::StrContains(device, DEVICE_GPU)) {
+ // Sometimes the cluster can have:
+ // devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
+ // and we need to handle them properly.
+ for (const auto& device_match :
+ {std::pair<string, string>("GPU", "CPU:0"),
+ std::pair<string, string>("/device", "/device:CPU:0")}) {
+ const string device_host =
+ strings::StrCat(device.substr(0, device.rfind(device_match.first)),
+ device_match.second);
+ if (devices.find(device_host) != devices.end()) {
+ return device_host;
+ }
+ }
+ }
+
+ // We couldn't find an appropriate Host device, so return the original device.
+ return device;
+}
+
+bool IsTPUGraphDef(const GraphDef& def) {
+ for (const auto& node : def.node()) {
+ if (node.op() == "TPUCompile" || node.op() == "TPUExecute" ||
+ node.op() == "TPUPartitionedCall") {
+ return true;
+ }
+ }
+ return false;
+}
+} // end namespace internal
+
+Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) {
+ *optimized_graph = item.graph;
+
+ // Skip all TPU graphs.
+ if (internal::IsTPUGraphDef(*optimized_graph)) {
+ return Status::OK();
+ }
+
+ GraphProperties properties(item);
+ GraphView graph(optimized_graph);
+
+ gtl::FlatSet<string> devices;
+ if (cluster) {
+ const std::vector<string> device_names = cluster->GetDeviceNames();
+ devices.insert(device_names.begin(), device_names.end());
+ } else {
+ devices = {"/device:CPU:0"};
+ }
+
+ const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();
+
+ // Topologically sort the graph, so that we traverse the nodes in order. This
+ // will help us discover producer->consumer chains of Host ops.
+ TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
+
+ // All the Const nodes, and their original devices in topological order.
+ std::vector<std::pair<NodeDef*, string>> const_nodes;
+
+ for (auto& node : *optimized_graph->mutable_node()) {
+ bool is_candidate = false;
+ TF_RETURN_IF_ERROR(
+ internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
+ if (!is_candidate) {
+ continue;
+ }
+
+ if (IsConstant(node)) {
+ const_nodes.emplace_back(&node, node.device());
+ }
+ // Try to swap the device to Host.
+ node.set_device(
+ internal::TryFindHostDevice(devices, has_device_cpu, node.device()));
+ }
+
+ // Traverse all `const_nodes` and greedily map them back to their original
+ // device where needed.
+ for (auto& it : const_nodes) {
+ NodeDef* node = it.first;
+ const string& device = it.second;
+
+ // Check all the consumers of this node; if any of them cannot take the
+ // input from Host memory, swap this node back onto its original device.
+ for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
+ // This consumer is not Host friendly; swap the Const back to its
+ // original device.
+ if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
+ fanout.port_id)) {
+ node->set_device(device);
+ break;
+ }
+ }
+ }
+ return Status::OK();
+}
+
+} // end namespace grappler
+} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
new file mode 100644
index 0000000000..d557a03463
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
+
+#include <unordered_set>
+#include "tensorflow/core/grappler/costs/graph_properties.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
+#include "tensorflow/core/protobuf/rewriter_config.pb.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace internal {
+// Attempts to find an appropriate Host device in `devices` given `device`.
+string TryFindHostDevice(const gtl::FlatSet<string>& devices,
+ bool has_device_cpu, const string& device);
+} // end namespace internal
+
+// Optimizes the graph by swapping onto the CPU (Host) those ops that should
+// run there, to avoid excessive cpu<->gpu memcpy/sync overhead.
+//
+// TODO(williamchan): The current heuristic will swap any small integer Const to
+// CPU. This may introduce a cpu->cpu->gpu transfer pattern where the original
+// gpu->gpu->gpu placement may have been better/faster. We should probably fix
+// this.
+class PinToHostOptimizer : public GraphOptimizer {
+ public:
+ PinToHostOptimizer() : opt_level_(RewriterConfig::DEFAULT) {}
+ explicit PinToHostOptimizer(RewriterConfig::Toggle opt_level)
+ : opt_level_(opt_level) {}
+
+ ~PinToHostOptimizer() override {}
+
+ string name() const override { return "pin_to_host_optimizer"; }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override;
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ RewriterConfig::Toggle opt_level_;
+};
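+
+// A minimal usage sketch (mirroring the unit tests), given a GrapplerItem
+// `item` whose `graph` field holds the GraphDef to optimize:
+//
+//   PinToHostOptimizer optimizer(RewriterConfig::ON);
+//   GraphDef output;
+//   TF_RETURN_IF_ERROR(optimizer.Optimize(/*cluster=*/nullptr, item, &output));
+//
+// The optimizer can also be toggled through the pin_to_host_optimization
+// setting in RewriterConfig.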
+
+} // end namespace grappler
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_PIN_TO_HOST_OPTIMIZER_H_
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
new file mode 100644
index 0000000000..7c64529441
--- /dev/null
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer_test.cc
@@ -0,0 +1,236 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/grappler/grappler_item.h"
+#include "tensorflow/core/grappler/utils/grappler_test.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace grappler {
+namespace {
+
+class PinToHostOptimizerTest : public GrapplerTest {};
+
+TEST_F(PinToHostOptimizerTest, TryFindHostDevice) {
+ gtl::FlatSet<string> devices = {};
+ EXPECT_EQ("ABC", internal::TryFindHostDevice(devices, false, "ABC"));
+
+ devices = {"/device:CPU:0", "/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, ""), "/device:CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:0"),
+ "/device:CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, true, "/device:XLA_GPU:*"),
+ "/device:CPU:0");
+
+ devices = {"/device:XLA_CPU:0", "/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+ "/device:XLA_CPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+ "/device:XLA_CPU:0");
+
+ devices = {"/device:XLA_GPU:0"};
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, ""), "");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:0"),
+ "/device:XLA_GPU:0");
+ EXPECT_EQ(internal::TryFindHostDevice(devices, false, "/device:XLA_GPU:*"),
+ "/device:XLA_GPU:*");
+}
+
+TEST_F(PinToHostOptimizerTest, OptimizeSmallOpsToHost) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+ Output c = ops::Shape(s.WithOpName("c"), a);
+ Output d = ops::Const(s.WithOpName("d"), 0, {1});
+ Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+ GrapplerItem item;
+ item.fetch = {"a", "c", "d", "e"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_TRUE(node.device().empty());
+ } else if (node.name() == "d" || node.name() == "e") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, TopologicalSort) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1024, 1024});
+ Output c = ops::Shape(s.WithOpName("c"), a);
+ Output d = ops::Const(s.WithOpName("d"), 0, {1});
+ Output e = ops::ReduceProd(s.WithOpName("e"), c, d);
+
+ GrapplerItem item;
+ item.fetch = {"a", "c", "d", "e"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ // Reverse the node order so that the optimizer has to topologically sort the
+ // graph itself.
+ std::reverse(item.graph.mutable_node()->begin(),
+ item.graph.mutable_node()->end());
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_TRUE(node.device().empty());
+ } else if (node.name() == "d" || node.name() == "e") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 4);
+}
+
+TEST_F(PinToHostOptimizerTest, NoSwap) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ // `b` should be too big to swap; consequently `c` should not be swapped.
+ // PinToHostOptimizer should then detect that `a` should not be swapped either.
+ Output a = ops::Const(s.WithOpName("a"), 1, {1, 1});
+ Output b = ops::Const(s.WithOpName("b"), 1, {1, 1024 * 1024});
+ Output c = ops::MatMul(s.WithOpName("c"), a, b);
+
+ GrapplerItem item;
+ item.fetch = {"a", "b", "c"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ EXPECT_TRUE(node.device().empty());
+ ++found;
+ }
+ EXPECT_EQ(found, 3);
+}
+
+TEST_F(PinToHostOptimizerTest, Identity) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ // `a` and `c` are on GPU, `d` is pinned to CPU; consequently `e` should not
+ // be swapped. `b` should be placed onto Host since `c` pins that input to
+ // Host memory.
+ Output a =
+ ops::Const(s.WithOpName("a").WithDevice("/device:GPU:0"), 1, {64, 64});
+ Output b = ops::Const(s.WithOpName("b"), {0, 1}, {2});
+ Output c =
+ ops::ReduceProd(s.WithOpName("c").WithDevice("/device:GPU:0"), a, b);
+ Output d = ops::Identity(s.WithDevice("/device:CPU:0").WithOpName("d"), c);
+ Output e = ops::Multiply(s.WithOpName("e"), d, d);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "a" || node.name() == "c") {
+ EXPECT_EQ(node.device(), "/device:GPU:0");
+ } else if (node.name() == "b") {
+ // If built with CUDA, there is a GPU kernel registration that pins this
+ // input to Host memory. Consequently, `b` will be mapped to Host only if
+ // such a GPU kernel is registered.
+#if GOOGLE_CUDA
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+#else
+ EXPECT_TRUE(node.device().empty());
+#endif
+ } else if (node.name() == "d") {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ } else if (node.name() == "e") {
+ EXPECT_TRUE(node.device().empty());
+ }
+ ++found;
+ }
+ EXPECT_EQ(found, 5);
+}
+
+TEST_F(PinToHostOptimizerTest, PortIdToArgId) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Const(s.WithOpName("a"), 1, {1, 2, 3});
+ ops::ShapeN b(s.WithOpName("b"), {a, a, a});
+
+ GrapplerItem item;
+ item.fetch = {"a", "b"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ auto tensors_expected = EvaluateNodes(item.graph, item.fetch);
+
+ GraphDef output;
+ PinToHostOptimizer optimizer(RewriterConfig::ON);
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ auto tensors = EvaluateNodes(item.graph, item.fetch);
+ EXPECT_EQ(tensors_expected.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); ++i) {
+ test::ExpectTensorEqual<int32>(tensors[i], tensors_expected[i]);
+ }
+
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ EXPECT_EQ(node.device(), "/device:CPU:0");
+ ++found;
+ }
+ EXPECT_EQ(found, 2);
+}
+
+} // namespace
+} // namespace grappler
+} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index 03e36a7b9c..9ada8b7ff9 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -168,11 +168,12 @@ void AddBatchNormNodes(GraphDef* optimized_graph, const NodeDef& fused_node) {
Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
GraphDef* optimized_graph) {
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ bool inferred_properties = false;
GraphView graph(const_cast<GraphDef*>(&item.graph));
// During inference, most of the inputs to FusedBatchNorm are constant, and we
// can therefore replace the op with a much cheaper set of primitives.
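+ // Pre-allocate space for the copied nodes to avoid repeated reallocations
+ // in the loop below.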
+ optimized_graph->mutable_node()->Reserve(item.graph.node_size());
for (const NodeDef& node : item.graph.node()) {
if (node.op() == "FusedBatchNorm" || node.op() == "FusedBatchNormV2") {
bool optimizable = (node.attr().count("T") == 0 ||
@@ -181,6 +182,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
!node.attr().at("is_training").b());
if (optimizable) {
int const_inputs = 0;
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& props = properties.GetInputProperties(node.name());
for (const auto& prop : props) {
if (prop.has_value()) {
@@ -218,7 +224,7 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
void Remapper::Feedback(Cluster* /*cluster*/, const GrapplerItem& /*item*/,
const GraphDef& /*optimized_graph*/,
double /*result*/) {
- // Nothing to do for ArithmeticOptimizer.
+ // Nothing to do for RemapperOptimizer.
}
} // namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
index 275568e464..0d4aaf6462 100644
--- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
@@ -203,7 +203,7 @@ void ScopedAllocatorOptimizer::ExtendNodeAttr(StringPiece name,
NodeDef* node_def) {
if (HasNodeAttr(*node_def, name)) {
VLOG(2) << "extending";
- AttrValue* existing = &(*node_def->mutable_attr())[name.ToString()];
+ AttrValue* existing = &(*node_def->mutable_attr())[string(name)];
for (int32 i : values) {
existing->mutable_list()->add_i(i);
}
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
index 89847f83d4..b033cff8e6 100644
--- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/testlib.h"
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index 26c54df56b..6ccb1cd783 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -15,14 +15,14 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
-
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -33,7 +33,7 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
*optimized_graph = item.graph;
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ bool inferred_properties = false;
GraphView graph(optimized_graph);
// The product of all the dimensions in a tensor shape can be expressed more
@@ -55,6 +55,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
const GraphView::OutputPort reduce_indices =
graph.GetRegularFanin(GraphView::InputPort(fanout.node, 1));
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& prop =
properties.GetOutputProperties(reduce_indices.node->name());
if (prop.size() < reduce_indices.port_id) {
@@ -92,6 +97,11 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
if (!IsSize(*input1.node) || !IsSize(*input2.node)) {
continue;
}
+ if (!inferred_properties) {
+ // Infer properties lazily in case they are not needed.
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
+ inferred_properties = true;
+ }
const auto& prop1 = properties.GetInputProperties(input1.node->name());
const auto& prop2 = properties.GetInputProperties(input2.node->name());
if (prop1.size() != 1 || prop2.size() != 1) {
diff --git a/tensorflow/core/grappler/utils.cc b/tensorflow/core/grappler/utils.cc
index 153785d3b4..5867d01324 100644
--- a/tensorflow/core/grappler/utils.cc
+++ b/tensorflow/core/grappler/utils.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
+#include <iterator>
#include <memory>
#include <queue>
#include <vector>
@@ -24,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/scanner.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -154,17 +156,6 @@ bool IsControlInput(const string& name) {
return !name.empty() && name[0] == '^';
}
-string NodeName(const string& name) {
- int position;
- return ParseNodeName(name, &position);
-}
-
-int NodePosition(const string& name) {
- int position;
- ParseNodeNameAsStringPiece(name, &position);
- return position;
-}
-
string AddPrefixToNodeName(const string& name, const string& prefix,
const string& delimiter) {
if (!name.empty()) {
diff --git a/tensorflow/core/grappler/utils.h b/tensorflow/core/grappler/utils.h
index a9c34b6d08..95126d470c 100644
--- a/tensorflow/core/grappler/utils.h
+++ b/tensorflow/core/grappler/utils.h
@@ -29,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/strings/scanner.h"
namespace tensorflow {
namespace grappler {
@@ -102,44 +101,103 @@ bool IsControlInput(const string& name);
// True iff 'name1' and 'name2' refer to the same input.
bool IsSameInput(const string& name1, const string& name2);
+// Returns the trailing position number (or zero if no number is present) if
+// NodeName(input_name) is equal to node_name. Returns -1 for control inputs.
+// Returns -2 if NodeName(input_name) is not equal to node_name.
+// Note: This function is used very heavily, and this hand-optimized
+// version is 3-4x faster than the version using Scanner, which it replaced.
+// This is worth the reduction in readability.
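+// For illustration (see the corresponding unit tests):
+//   NodePositionIfSameNode("foo:2", "foo") == 2
+//   NodePositionIfSameNode("^foo", "foo")  == -1
+//   NodePositionIfSameNode("bar", "foo")   == -2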
+inline int NodePositionIfSameNode(const string& input_name,
+ const string& node_name) {
+ if (input_name.empty()) return -2;
+ const bool is_ctrl = input_name[0] == '^';
+ auto input_it = is_ctrl ? input_name.begin() + 1 : input_name.begin();
+ auto node_it = node_name.begin();
+ if (node_name.empty() ||
+ std::distance(input_it, input_name.end()) < node_name.size()) {
+ return -2;
+ }
+ while (node_it != node_name.end()) {
+ if (*input_it++ != *node_it++) {
+ return -2;
+ }
+ }
+ if (input_it == input_name.end()) {
+ return is_ctrl ? -1 : 0;
+ } else if (*input_it++ == ':') {
+ StringPiece remaining(&(*input_it),
+ std::distance(input_it, input_name.end()));
+ int position;
+ if (!strings::safe_strto32(remaining, &position)) {
+ return -2;
+ }
+ return is_ctrl ? -1 : position;
+ } else {
+ return -2;
+ }
+}
+
// Return the node name corresponding to 'name' if name is valid, or the empty
// string otherwise.
-string NodeName(const string& name);
+inline StringPiece NodeNameAsStringPiece(const string& name) {
+ static const string empty;
+ if (name.empty()) return StringPiece(empty);
+ const auto begin_it = name[0] == '^' ? name.begin() + 1 : name.begin();
+ auto end_it = begin_it;
+ while (end_it != name.end() && *end_it != ':') {
+ ++end_it;
+ }
+ if (end_it != name.end() && *end_it != ':') {
+ return StringPiece(empty);
+ }
+ return StringPiece(&(*begin_it), std::distance(begin_it, end_it));
+}
-// Get the trailing position number ":{digits}" (if any) of a node name.
-int NodePosition(const string& name);
+// Return the node name corresponding to 'name' if name is valid, or the empty
+// string otherwise.
+inline string NodeName(const string& name) {
+ return string(NodeNameAsStringPiece(name));
+}
+// Returns the node name and position in a single call.
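+// For illustration: ParseNodeNameAsStringPiece("foo:2", &pos) returns "foo"
+// and sets pos to 2; ParseNodeNameAsStringPiece("^foo", &pos) returns "foo"
+// and sets pos to -1.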
inline StringPiece ParseNodeNameAsStringPiece(const string& name,
int* position) {
- // Strip the prefix '^' (if any), and strip the trailing ":{digits} (if any)
- // to get a node name.
- strings::Scanner scan(name);
- scan.ZeroOrOneLiteral("^")
- .RestartCapture()
- .One(strings::Scanner::LETTER_DIGIT_DOT_UNDERSCORE)
- .Any(strings::Scanner::LETTER_DIGIT_DASH_DOT_SLASH_UNDERSCORE);
- StringPiece capture;
- StringPiece remaining;
- if (scan.Peek(':') != ':' || !scan.GetResult(&remaining, &capture)) {
+ static const string empty;
+ if (name.empty()) {
*position = 0;
- static const string empty;
return StringPiece(empty);
- } else {
- if (name[0] == '^') {
- *position = -1;
- } else if (remaining.empty()) {
- *position = 0;
- } else {
- // Skip the first ':' character.
- CHECK(strings::safe_strto32(remaining.substr(1), position));
+ }
+ const bool is_ctrl = name[0] == '^';
+ const auto begin_it = is_ctrl ? name.begin() + 1 : name.begin();
+ *position = is_ctrl ? -1 : 0;
+ auto end_it = begin_it;
+ while (end_it != name.end() && *end_it != ':') {
+ ++end_it;
+ }
+ const StringPiece node_name(&(*begin_it), std::distance(begin_it, end_it));
+ if (end_it != name.end()) {
+ if (*end_it != ':') {
+ return StringPiece(empty);
+ } else if (!is_ctrl) {
+ ++end_it;
+ StringPiece remaining(&(*end_it), std::distance(end_it, name.end()));
+ if (!strings::safe_strto32(remaining, position)) {
+ return StringPiece(empty);
+ }
}
- return capture;
}
+ return node_name;
}
// Returns the node name and position in a single call.
inline string ParseNodeName(const string& name, int* position) {
- return std::string(ParseNodeNameAsStringPiece(name, position));
+ return string(ParseNodeNameAsStringPiece(name, position));
+}
+
+inline int NodePosition(const string& name) {
+ int position;
+ ParseNodeNameAsStringPiece(name, &position);
+ return position;
}
// Add a prefix to a node name with a custom delimiter.
diff --git a/tensorflow/core/grappler/utils/BUILD b/tensorflow/core/grappler/utils/BUILD
index e540cc0476..bdbb8836e1 100644
--- a/tensorflow/core/grappler/utils/BUILD
+++ b/tensorflow/core/grappler/utils/BUILD
@@ -1,6 +1,10 @@
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+load(
+ "//tensorflow/core:platform/default/build_config.bzl",
+ "tf_protos_grappler",
+)
cc_library(
name = "scc",
@@ -210,3 +214,28 @@ tf_cc_test(
"//tensorflow/core:testlib",
],
)
+
+cc_library(
+ name = "symbolic_shapes",
+ srcs = ["symbolic_shapes.cc"],
+ hdrs = ["symbolic_shapes.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ ] + tf_protos_grappler(),
+)
+
+tf_cc_test(
+ name = "symbolic_shapes_test",
+ srcs = ["symbolic_shapes_test.cc"],
+ deps = [
+ ":symbolic_shapes",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index 462b752316..6861fb423c 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -40,7 +41,8 @@ Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration,
tensorflow::NameRangeMap outputs_range_map;
TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
node, registration.op_def, nullptr, &outputs_range_map));
- connectivity->RegisterFunctionBodyOutputs(node.name(), outputs_range_map);
+ connectivity->RegisterFunctionBodyOutputs(node.name(),
+ std::move(outputs_range_map));
return Status::OK();
}
@@ -74,20 +76,22 @@ Status ResolveFunctionBodyNodeAttrPlaceholders(
} // namespace
void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
- const InputArgExpansion& input_arg_expansion) {
- const auto& input_name = input_arg_expansion.input_name;
+ InputArgExpansion input_arg_expansion) {
+ string input_name = input_arg_expansion.input_name;
const auto& placeholders = input_arg_expansion.placeholders;
- input_arg_expansions_.emplace(input_name, input_arg_expansion);
+
for (int i = 0; i < placeholders.size(); ++i) {
const string& placeholder = input_arg_expansion.placeholders[i];
- input_arg_placeholders_.emplace(
- placeholder, InputArgPlaceholder{input_name, /*position=*/i});
+ input_arg_placeholders_.insert(
+ {placeholder, InputArgPlaceholder{input_name, /*position=*/i}});
}
+ input_arg_expansions_.insert(
+ {std::move(input_name), std::move(input_arg_expansion)});
}
void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs(
- const string& node_name, const tensorflow::NameRangeMap& outputs) {
- function_body_outputs_[node_name] = outputs;
+ const string& node_name, tensorflow::NameRangeMap&& outputs) {
+ function_body_outputs_[node_name] = std::move(outputs);
}
Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
@@ -173,11 +177,12 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
const auto& output_range = output->second;
if (position == -1) {
+ graph_def_inputs->reserve(graph_def_inputs->size() +
+ output_range.second - output_range.first);
// If position is not defined expand node output range
for (int i = output_range.first; i < output_range.second; ++i) {
- i == 0 ? graph_def_inputs->push_back(node_name)
- : graph_def_inputs->push_back(
- strings::StrCat(node_name, ":", i));
+ graph_def_inputs->push_back(
+ i == 0 ? node_name : strings::StrCat(node_name, ":", i));
}
} else {
if (position > (output_range.second - output_range.first)) {
@@ -186,9 +191,8 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
" position: ", position, " (out of range)");
}
int pos = output_range.first + position;
- pos == 0 ? graph_def_inputs->push_back(node_name)
- : graph_def_inputs->push_back(
- strings::StrCat(node_name, ":", pos));
+ graph_def_inputs->push_back(
+ pos == 0 ? node_name : strings::StrCat(node_name, ":", pos));
}
return Status::OK();
@@ -210,8 +214,8 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs(
}
function_body_node->clear_input();
- for (const string& expanded_input : expanded_inputs)
- function_body_node->add_input(expanded_input);
+ for (string& expanded_input : expanded_inputs)
+ function_body_node->add_input(std::move(expanded_input));
return Status::OK();
}
@@ -303,25 +307,26 @@ Status GrapplerFunctionItemInstantiation::GetArgType(
}
GrapplerFunctionItem::GrapplerFunctionItem(
- const string& func_name, const string& description,
- const AttrValueMap& func_attr,
- const std::vector<InputArgExpansion>& input_arg_expansions,
- const std::vector<OutputArgExpansion>& output_arg_expansions,
- const std::vector<string>& keep_nodes, bool is_stateful,
- GraphDef&& function_body)
- : description_(description),
- func_attr_(func_attr),
- input_arg_expansions_(input_arg_expansions),
- output_arg_expansions_(output_arg_expansions),
+ string func_name, string description, AttrValueMap func_attr,
+ std::vector<InputArgExpansion> input_arg_expansions,
+ std::vector<OutputArgExpansion> output_arg_expansions,
+ std::vector<string> keep_nodes, const int graph_def_version,
+ const bool is_stateful, GraphDef&& function_body)
+ : description_(std::move(description)),
+ func_attr_(std::move(func_attr)),
+ input_arg_expansions_(std::move(input_arg_expansions)),
+ output_arg_expansions_(std::move(output_arg_expansions)),
is_stateful_(is_stateful) {
- id = func_name;
- keep_ops = keep_nodes;
- // Swap the graph body.
- graph.Swap(&function_body);
+ // Move assign GrapplerItem members.
+ keep_ops = std::move(keep_nodes);
+ id = std::move(func_name);
+ graph = std::move(function_body);
+
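+ // Record the instantiating graph's version as the producer version of the
+ // function body.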
+ graph.mutable_versions()->set_producer(graph_def_version);
// Fill the feed nodes with input placeholders.
for (const InputArgExpansion& input_arg : input_arg_expansions_) {
for (const string& placeholder : input_arg.placeholders) {
- feed.emplace_back(placeholder, Tensor());
+ feed.push_back({placeholder, Tensor()});
input_arg_placeholders_.insert(placeholder);
}
}
@@ -458,7 +463,7 @@ Status InstantiationBodyParameters(
auto it = func_instantiation_attr.find(placeholder);
if (it != func_instantiation_attr.end()) {
- body_parameters->emplace(placeholder, it->second);
+ body_parameters->insert({placeholder, it->second});
} else {
return errors::InvalidArgument("Can't resolve placeholder: ",
placeholder);
@@ -472,6 +477,7 @@ Status InstantiationBodyParameters(
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const AttrValueMap& func_instantiation_attr,
const FunctionLibraryDefinition& flib,
+ const int graph_def_version,
GrapplerFunctionItem* item) {
const OpDef& signature = func.signature();
@@ -495,10 +501,6 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
// GraphDef input format (name[:position])
GrapplerFunctionConnectivity connectivity;
- std::vector<InputArgExpansion> inputs;
- std::vector<OutputArgExpansion> outputs;
- std::vector<string> keep_nodes;
-
// Function body shares the library with the graph that instantiated it.
GraphDef function_body;
*function_body.mutable_library() = flib.ToProto();
@@ -515,6 +517,9 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
}
}
+ std::vector<InputArgExpansion> inputs;
+ inputs.reserve(signature.input_arg_size());
+
// For each input argument create a placeholder in function body.
for (const OpDef::ArgDef& input : signature.input_arg()) {
if (!input.type_list_attr().empty() || !input.number_attr().empty()) {
@@ -539,9 +544,10 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
/*is_ref*/ input.is_ref(),
/*placeholders=*/{input.name()}};
connectivity.RegisterInputArgExpansion(input_expansion);
- inputs.push_back(input_expansion);
+ inputs.push_back(std::move(input_expansion));
}
+ std::vector<string> keep_nodes;
// Add all function nodes to the function body
for (const NodeDef& func_def_node : func.node_def()) {
NodeDef* new_node = function_body.add_node();
@@ -569,6 +575,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node));
}
+ std::vector<OutputArgExpansion> outputs;
+ outputs.reserve(signature.output_arg_size());
// Add function outputs
for (const OpDef::ArgDef& out : signature.output_arg()) {
std::vector<string> output_tensors;
@@ -586,8 +594,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
OutputArgExpansion output{/*output_name=*/out.name(),
/*data_type=*/output_data_type,
/*is_ref=*/out.is_ref(),
- /*output_tensors=*/output_tensors};
- outputs.push_back(output);
+ /*output_tensors=*/std::move(output_tensors)};
+ outputs.push_back(std::move(output));
}
bool is_stateful = signature.is_stateful();
@@ -595,14 +603,17 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
*item = GrapplerFunctionItem(
/*func_name=*/signature.name(), /*description=*/signature.description(),
/*func_attr=*/AttrValueMap(func.attr().begin(), func.attr().end()),
- inputs, outputs, keep_nodes, is_stateful, std::move(function_body));
+ std::move(inputs), std::move(outputs), std::move(keep_nodes),
+ graph_def_version, is_stateful, std::move(function_body));
return Status::OK();
}
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const FunctionLibraryDefinition& flib,
+ const int graph_def_version,
GrapplerFunctionItem* item) {
- return MakeGrapplerFunctionItem(func, AttrValueMap(), flib, item);
+ return MakeGrapplerFunctionItem(func, AttrValueMap(), flib, graph_def_version,
+ item);
}
// Register GrapplerFunctionItem input arg expansion and function body outputs
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index 9f607dc2ee..ef944ced09 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include <unordered_map>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -70,9 +71,9 @@ struct OutputArgExpansion {
// and fold it back when doing backward conversion.
class GrapplerFunctionConnectivity {
public:
- void RegisterInputArgExpansion(const InputArgExpansion& input_arg_expansion);
+ void RegisterInputArgExpansion(InputArgExpansion input_arg_expansion);
void RegisterFunctionBodyOutputs(const string& node_name,
- const tensorflow::NameRangeMap& outputs);
+ tensorflow::NameRangeMap&& outputs);
// Expand input encoded in FunctionDef format (name[:output][:position]) into
// multiple inputs in GraphDef format (name[:position]).
@@ -136,13 +137,12 @@ class GrapplerFunctionItemInstantiation {
class GrapplerFunctionItem : public GrapplerItem {
public:
GrapplerFunctionItem() = default;
- GrapplerFunctionItem(
- const string& func_name, const string& description,
- const AttrValueMap& func_attr,
- const std::vector<InputArgExpansion>& input_arg_expansions,
- const std::vector<OutputArgExpansion>& output_arg_expansions,
- const std::vector<string>& keep_nodes, bool is_stateful,
- GraphDef&& function_body);
+ GrapplerFunctionItem(string func_name, string description,
+ AttrValueMap func_attr,
+ std::vector<InputArgExpansion> input_arg_expansions,
+ std::vector<OutputArgExpansion> output_arg_expansions,
+ std::vector<string> keep_nodes, int graph_def_version,
+ bool is_stateful, GraphDef&& function_body);
const string& description() const;
@@ -222,6 +222,7 @@ Status ReplaceInputWithConst(const NodeDef& input_const, int input_position,
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const AttrValueMap& func_instantiation_attr,
const FunctionLibraryDefinition& flib,
+ const int graph_def_version,
GrapplerFunctionItem* item);
// Make a GrapplerFunction item from the function definition. Function must be
@@ -231,6 +232,7 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
// without specializing it to it's instantiation attributes (at least types)?
Status MakeGrapplerFunctionItem(const FunctionDef& func,
const FunctionLibraryDefinition& flib,
+ const int graph_def_version,
GrapplerFunctionItem* item);
// Make a FunctionDef from the GrapplerFunctionItem. Use function library
diff --git a/tensorflow/core/grappler/utils/functions_test.cc b/tensorflow/core/grappler/utils/functions_test.cc
index b2d059e0ac..b51f2781b8 100644
--- a/tensorflow/core/grappler/utils/functions_test.cc
+++ b/tensorflow/core/grappler/utils/functions_test.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/version.h"
namespace tensorflow {
namespace grappler {
@@ -239,7 +240,8 @@ TEST_F(FunctionsTest, FromSimpleFunctionDef) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ("XTimesTwo", item.id);
EXPECT_EQ(4, item.function_body().node_size());
@@ -314,7 +316,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithMultiOutputNodes) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ("SubGrad", item.id);
EXPECT_EQ(12, item.function_body().node_size());
@@ -395,7 +398,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithNestedFuncs) {
func_attr["T"].set_type(DT_FLOAT);
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
int count = 0;
for (const NodeDef &node : item.function_body().node()) {
@@ -456,7 +460,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithOutputMappings) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ(1, item.output_size());
EXPECT_EQ("Exp", item.output(0).output_tensors[0]);
@@ -499,7 +504,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithInputForwarding) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ("ForwardInputs", item.id);
EXPECT_EQ(5, item.function_body().node_size());
@@ -545,7 +551,8 @@ TEST_F(FunctionsTest, FromFunctionDefWithoutInput) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ(0, item.input_size());
EXPECT_EQ(1, item.output_size());
@@ -584,7 +591,8 @@ TEST_F(FunctionsTest, MakeFunctionDef) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
FunctionDef specialized;
TF_EXPECT_OK(MakeFunctionDef(item, flib, &specialized));
@@ -622,7 +630,8 @@ TEST_F(FunctionsTest, ReplaceInputWithConst) {
FunctionLibraryDefinition flib(OpRegistry::Global(), FunctionDefLibrary());
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
EXPECT_EQ(2, item.input_size());
EXPECT_EQ(1, item.output_size());
@@ -713,7 +722,8 @@ TEST_F(FunctionsTest, SwapFunctionBodyAndMakeFunctionDef) {
FunctionLibraryDefinition flib(OpRegistry::Global(), lib_def);
GrapplerFunctionItem item;
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
// Replace function body with identity function
item.SwapFunctionBody(std::move(id_func_body));
@@ -754,7 +764,8 @@ TEST_F(FunctionsTest, FunctionDefGrapplerFunctionItemRoundTrip) {
GrapplerFunctionItem item;
std::unordered_map<string, AttrValue> func_attr;
func_attr["T"].set_type(DT_INT32);
- TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib, &item));
+ TF_EXPECT_OK(MakeGrapplerFunctionItem(func, func_attr, flib,
+ TF_GRAPH_DEF_VERSION, &item));
FunctionDef func2;
TF_EXPECT_OK(MakeFunctionDef(item, flib, &func2));
diff --git a/tensorflow/core/grappler/utils/grappler_test.cc b/tensorflow/core/grappler/utils/grappler_test.cc
index 910b0acaef..6266733f3e 100644
--- a/tensorflow/core/grappler/utils/grappler_test.cc
+++ b/tensorflow/core/grappler/utils/grappler_test.cc
@@ -30,13 +30,16 @@ GrapplerTest::GrapplerTest() {
// optimizations interfering in the comparison.
RewriterConfig* cfg =
options_.config.mutable_graph_options()->mutable_rewrite_options();
- cfg->set_constant_folding(RewriterConfig::OFF);
+ // TODO(rmlarsen): Add utility to generate config w/ all optimizers turned
+ // off.
cfg->set_arithmetic_optimization(RewriterConfig::OFF);
+ cfg->set_constant_folding(RewriterConfig::OFF);
+ cfg->set_debug_stripper(RewriterConfig::OFF);
cfg->set_dependency_optimization(RewriterConfig::OFF);
- cfg->set_loop_optimization(RewriterConfig::OFF);
cfg->set_function_optimization(RewriterConfig::OFF);
cfg->set_layout_optimizer(RewriterConfig::OFF);
- cfg->set_debug_stripper(RewriterConfig::OFF);
+ cfg->set_loop_optimization(RewriterConfig::OFF);
+ cfg->set_pin_to_host_optimization(RewriterConfig::OFF);
}
std::vector<Tensor> GrapplerTest::EvaluateNodes(
diff --git a/tensorflow/core/grappler/utils/scc.h b/tensorflow/core/grappler/utils/scc.h
index 4fb7aab647..ceb9f5dbf2 100644
--- a/tensorflow/core/grappler/utils/scc.h
+++ b/tensorflow/core/grappler/utils/scc.h
@@ -24,15 +24,16 @@ limitations under the License.
namespace tensorflow {
namespace grappler {
-// Compute modified strongly connected components:
+// Computes modified strongly connected components:
// All nodes that are not part of a loop are assigned the special -1 id
// All nodes that are part of at least one loop are assigned a positive
// component id: if 2 nodes v and w are reachable from one another (i.e. if they
// belong to the same scc), they'll be assigned the same id, otherwise they'll
-// be assigned distinct ids. Returns the number of distinct ids.
+// be assigned distinct ids. *num_components is set to the number of distinct
+// ids.
void StronglyConnectedComponents(
const GraphDef& graph, std::unordered_map<const NodeDef*, int>* components,
- int* num_ids);
+ int* num_components);
// Returns the number of individual loops present in the graph, and populate the
// 'loops' argument with the collection of loops (denoted by their loop ids) a
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc b/tensorflow/core/grappler/utils/symbolic_shapes.cc
index 155843a744..1666de4b80 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.cc
+++ b/tensorflow/core/grappler/utils/symbolic_shapes.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes.h b/tensorflow/core/grappler/utils/symbolic_shapes.h
index ace7bd1fe7..0a7d8ac82b 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes.h
+++ b/tensorflow/core/grappler/utils/symbolic_shapes.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
-#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
+#ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
+#define TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
@@ -74,4 +74,4 @@ int64 ComputeSizeRatio(const TensorShapeProto& numerator,
} // namespace grappler
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_SYMBOLIC_SHAPES_H_
+#endif // TENSORFLOW_CORE_GRAPPLER_UTILS_SYMBOLIC_SHAPES_H_
diff --git a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc
index 7ce995d1c5..6ac644cdb1 100644
--- a/tensorflow/core/grappler/optimizers/symbolic_shapes_test.cc
+++ b/tensorflow/core/grappler/utils/symbolic_shapes_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
+#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/core/grappler/utils_test.cc b/tensorflow/core/grappler/utils_test.cc
index c6e035834c..9b6c1f690b 100644
--- a/tensorflow/core/grappler/utils_test.cc
+++ b/tensorflow/core/grappler/utils_test.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/notification.h"
#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace grappler {
@@ -147,6 +148,21 @@ TEST_F(UtilsTest, NodePosition) {
EXPECT_EQ(0, NodePosition(""));
}
+TEST_F(UtilsTest, NodePositionIfSameNode) {
+ EXPECT_EQ(-2, NodePositionIfSameNode(":123", ""));
+ EXPECT_EQ(-2, NodePositionIfSameNode(":", ""));
+ EXPECT_EQ(-2, NodePositionIfSameNode("", ""));
+ EXPECT_EQ(123, NodePositionIfSameNode("abc:123", "abc"));
+ EXPECT_EQ(-1, NodePositionIfSameNode("^abc", "abc"));
+ EXPECT_EQ(-1, NodePositionIfSameNode("^abc:123", "abc"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc", "xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc", "abc/xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc/xyz", "abc"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("abc:123", "xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("^abc", "xyz"));
+ EXPECT_EQ(-2, NodePositionIfSameNode("^abc:123", "xyz"));
+}
+
TEST_F(UtilsTest, AddNodeNamePrefix) {
EXPECT_EQ("OPTIMIZED/abc", AddPrefixToNodeName("abc", "OPTIMIZED"));
EXPECT_EQ("^OPTIMIZED/abc", AddPrefixToNodeName("^abc", "OPTIMIZED"));
@@ -209,7 +225,6 @@ TEST_F(UtilsTest, GetTailOfChain) {
auto noop = ops::NoOp(s.WithControlDependencies(neg0).WithOpName("noop"));
GraphDef graph;
TF_CHECK_OK(s.ToGraphDef(&graph));
- LOG(INFO) << graph.DebugString();
ASSERT_EQ("c0", graph.node(0).name());
ASSERT_EQ("c1", graph.node(1).name());
@@ -336,9 +351,45 @@ TEST_F(UtilsTest, NumNonControlOutputs) {
}
TEST_F(UtilsTest, DeleteNodes) {
- // TODO(rmlarsen): write forgtten test.
+ // TODO(rmlarsen): write forgotten test.
}
+#define BM_NodePositionIfSameNode(I, N, NAME) \
+ static void BM_NodePositionIfSameNode_##NAME(int iters) { \
+ string input = I; \
+ string node = N; \
+ for (int i = 0; i < iters; ++i) { \
+ const int pos = NodePositionIfSameNode(input, node); \
+ CHECK_GT(pos, -3); \
+ } \
+ } \
+ BENCHMARK(BM_NodePositionIfSameNode_##NAME)
+
+BM_NodePositionIfSameNode("foo/bar/baz:7", "foo/bar/baz", Match_7);
+BM_NodePositionIfSameNode("foo/bar/baz", "foo/bar/baz", Match_0);
+BM_NodePositionIfSameNode("^foo/bar/baz", "foo/bar/baz", Match_Ctrl);
+BM_NodePositionIfSameNode("blah", "foo/bar/baz", NoMatch_0);
+BM_NodePositionIfSameNode("foo/bar/baz/gnu", "foo/bar/baz", NoMatch_end);
+
+#define BM_ParseNodeNameAsStringPiece(I, NAME) \
+ static void BM_ParseNodeNameAsStringPiece_##NAME(int iters) { \
+ string input = I; \
+ for (int i = 0; i < iters; ++i) { \
+ int position; \
+ const StringPiece name = ParseNodeNameAsStringPiece(input, &position); \
+ CHECK_GE(position, -1); \
+ CHECK(!name.empty()); \
+ } \
+ } \
+ BENCHMARK(BM_ParseNodeNameAsStringPiece_##NAME)
+
+BM_ParseNodeNameAsStringPiece("foo", foo);
+BM_ParseNodeNameAsStringPiece("foo/bar/baz", foo_bar_baz);
+BM_ParseNodeNameAsStringPiece("^foo/bar/baz", foo_bar_baz_ctrl);
+BM_ParseNodeNameAsStringPiece("foo:123", foo123);
+BM_ParseNodeNameAsStringPiece("foo/bar/baz:123", foo_bar_baz_123);
+BM_ParseNodeNameAsStringPiece("^foo/bar/baz:123", foo_bar_baz_123_ctrl);
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index a30916d8b9..3a920f26f3 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -30,6 +30,7 @@ load(
"//tensorflow:tensorflow.bzl",
"if_android",
"tf_cc_test",
+ "tf_cc_test_mkl",
"tf_cc_tests",
"tf_cc_binary",
"tf_copts",
@@ -50,8 +51,14 @@ load(
"tf_kernel_tests_linkstatic",
)
load(
+ "//tensorflow/core:platform/default/build_config_root.bzl",
+ "tf_cuda_tests_tags",
+)
+load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
+ "if_mkl_ml",
+ "mkl_deps",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
@@ -210,6 +217,19 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "extract_volume_patches_op",
+ prefix = "extract_volume_patches_op",
+ deps = [
+ ":bounds_check",
+ ":eigen_helpers",
+ ":ops_util",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
cc_library(
name = "conv_3d",
hdrs = ["conv_3d.h"],
@@ -493,16 +513,6 @@ cc_library(
],
)
-cc_library(
- name = "warn_about_ints",
- srcs = ["warn_about_ints.cc"],
- hdrs = ["warn_about_ints.h"],
- deps = [
- "//tensorflow/core:framework",
- "//tensorflow/core:protos_all_cc",
- ],
-)
-
# Private support libraries ---------------------------------------------------
cc_header_only_library(
@@ -625,9 +635,11 @@ cc_library(
":diag_op",
":edit_distance_op",
":extract_image_patches_op",
+ ":extract_volume_patches_op",
":gather_nd_op",
":gather_op",
":guarantee_const_op",
+ ":host_constant_op",
":identity_n_op",
":identity_op",
":inplace_ops",
@@ -643,6 +655,7 @@ cc_library(
":reshape_op",
":reverse_op",
":reverse_sequence_op",
+ ":searchsorted_op",
":shape_ops",
":slice_op",
":snapshot_op",
@@ -650,14 +663,7 @@ cc_library(
":split_v_op",
":strided_slice_op",
":tile_ops",
- ] + if_mkl(
- [
- ":mkl_transpose_op",
- ],
- [
- ":transpose_op",
- ],
- ) + [
+ ":transpose_op",
":unique_op",
":unpack_op",
":unravel_index_op",
@@ -702,6 +708,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "host_constant_op",
+ prefix = "host_constant_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
name = "diag_op",
prefix = "diag_op",
deps = ARRAY_DEPS,
@@ -877,6 +889,12 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "searchsorted_op",
+ prefix = "searchsorted_op",
+ deps = ARRAY_DEPS,
+)
+
+tf_kernel_library(
name = "inplace_ops",
prefix = "inplace_ops",
deps = ARRAY_DEPS,
@@ -894,27 +912,13 @@ tf_kernel_library(
deps = ARRAY_DEPS,
)
-if_mkl(
- [tf_mkl_kernel_library(
- name = "mkl_transpose_op",
- srcs = [
- "mkl_transpose_op.cc",
- "transpose_op.cc",
- ],
- hdrs = ["transpose_op.h"],
- deps = ARRAY_DEPS + if_mkl([
- "//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ]),
- )],
- [tf_kernel_library(
- name = "transpose_op",
- srcs = [
- "transpose_op.cc",
- ],
- hdrs = ["transpose_op.h"],
- deps = ARRAY_DEPS,
- )],
+tf_kernel_library(
+ name = "transpose_op",
+ srcs = [
+ "transpose_op.cc",
+ ],
+ hdrs = ["transpose_op.h"],
+ deps = ARRAY_DEPS + if_mkl([":mkl_transpose_op"]),
)
tf_kernel_library(
@@ -1127,7 +1131,7 @@ tf_cuda_cc_test(
name = "depthwise_conv_ops_test",
size = "small",
srcs = ["depthwise_conv_ops_test.cc"],
- tags = ["requires-gpu-sm35"],
+ tags = tf_cuda_tests_tags(),
deps = [
":conv_ops",
":image",
@@ -1193,8 +1197,10 @@ tf_cc_test(
tf_cc_test(
name = "example_parsing_ops_test",
- size = "large",
+ size = "medium",
srcs = ["example_parsing_ops_test.cc"],
+ shard_count = 4,
+ tags = ["optonly"],
deps = [
":example_parsing_ops",
":ops_testutil",
@@ -1284,6 +1290,7 @@ tf_cuda_cc_test(
srcs = ["gather_op_test.cc"],
deps = [
":gather_op",
+ ":host_constant_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -1302,6 +1309,7 @@ tf_cuda_cc_test(
srcs = ["gather_nd_op_test.cc"],
deps = [
":gather_nd_op",
+ ":host_constant_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -2022,8 +2030,8 @@ tf_kernel_library(
":variable_ops",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:resource_variable_ops_op_lib",
- "//third_party/eigen3",
],
)
@@ -2298,6 +2306,31 @@ tf_cc_tests(
],
)
+cc_library(
+ name = "eigen_benchmark",
+ testonly = 1,
+ hdrs = [
+ "eigen_benchmark.h",
+ ":eigen_helpers",
+ ],
+ deps = [
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_cc_test(
+ name = "eigen_benchmark_cpu_test",
+ srcs = ["eigen_benchmark_cpu_test.cc"],
+ deps = [
+ ":eigen_benchmark",
+ ":eigen_helpers",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//third_party/eigen3",
+ ],
+)
+
tf_cc_tests(
name = "basic_ops_benchmark_test",
size = "small",
@@ -2558,6 +2591,7 @@ tf_kernel_library(
# allow multiple definitions when linking this.
linkopts = select({
"//tensorflow:darwin": [],
+ "//tensorflow:windows": [],
"//conditions:default": ["-Wl,-z,muldefs"],
}),
visibility = [":friends"],
@@ -2696,6 +2730,7 @@ cc_library(
)
LOGGING_DEPS = [
+ "@com_google_absl//absl/strings",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -2753,6 +2788,7 @@ tf_cc_tests(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/strings",
],
)
@@ -2867,7 +2903,7 @@ tf_kernel_library(
tf_kernel_library(
name = "batch_matmul_op",
- srcs = [] + if_mkl([
+ srcs = if_mkl_ml([
"mkl_batch_matmul_op.cc",
]),
# <prefix>*impl.h are excluded by default from the CPU build, add explicitly.
@@ -2876,7 +2912,7 @@ tf_kernel_library(
# to avoid long compiling time. See https://github.com/tensorflow/tensorflow/issues/10521
copts = if_override_eigen_strong_inline(["/DEIGEN_STRONG_INLINE=inline"]),
prefix = "batch_matmul_op",
- deps = MATH_DEPS + if_mkl([
+ deps = MATH_DEPS + if_mkl_ml([
"//third_party/mkl:intel_binary_blob",
]),
)
@@ -2959,10 +2995,7 @@ tf_kernel_library(
"@libxsmm_archive//:xsmm_avx",
],
"//conditions:default": [],
- }) + if_mkl([
- "//third_party/mkl:intel_binary_blob",
- "@mkl_dnn",
- ]) + if_cuda([
+ }) + mkl_deps() + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin",
]),
)
@@ -3172,6 +3205,7 @@ tf_cuda_cc_test(
"//conditions:default": [],
}),
deps = [
+ ":host_constant_op",
":ops_testutil",
":ops_util",
":reduction_ops",
@@ -3307,6 +3341,7 @@ tf_cuda_cc_test(
srcs = ["diag_op_test.cc"],
deps = [
":diag_op",
+ ":host_constant_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -3527,13 +3562,13 @@ tf_kernel_library(
tf_kernel_library(
name = "softplus_op",
prefix = "softplus_op",
- deps = NN_DEPS + [":warn_about_ints"],
+ deps = NN_DEPS,
)
tf_kernel_library(
name = "softsign_op",
prefix = "softsign_op",
- deps = NN_DEPS + [":warn_about_ints"],
+ deps = NN_DEPS,
)
tf_kernel_library(
@@ -3634,6 +3669,7 @@ tf_cuda_cc_test(
name = "nn_ops_test",
srcs = ["nn_ops_test.cc"],
deps = [
+ ":host_constant_op",
":nn",
":ops_testutil",
":ops_util",
@@ -3767,7 +3803,7 @@ tf_kernel_library(
"spacetobatch_functor.h",
"spacetobatch_functor_gpu.cu.cc",
],
- visibility = ["//visibility:private"],
+ visibility = [":friends"],
deps = [
":bounds_check",
"//tensorflow/core:framework",
@@ -3781,6 +3817,7 @@ tf_cuda_cc_test(
srcs = ["spacetobatch_benchmark_test.cc"],
deps = [
":batch_space_ops",
+ ":host_constant_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -3920,6 +3957,7 @@ tf_cuda_cc_test(
size = "small",
srcs = ["random_op_test.cc"],
deps = [
+ ":host_constant_op",
":random_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
@@ -4013,11 +4051,6 @@ cc_library(
)
SPARSE_DEPS = [
- ":bounds_check",
- ":cwise_op",
- ":fill_functor",
- ":scatter_functor",
- "//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:sparse_ops_op_lib",
@@ -4050,7 +4083,9 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_cross_op",
prefix = "sparse_cross_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4062,13 +4097,19 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_dense_binary_op_shared",
prefix = "sparse_dense_binary_op_shared",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":cwise_op",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_sparse_binary_op_shared",
prefix = "sparse_sparse_binary_op_shared",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":cwise_op",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4100,7 +4141,9 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_softmax",
prefix = "sparse_softmax",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4112,25 +4155,37 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_tensor_dense_add_op",
prefix = "sparse_tensor_dense_add_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":scatter_functor",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_tensor_dense_matmul_op",
prefix = "sparse_tensor_dense_matmul_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":bounds_check",
+ ":fill_functor",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_to_dense_op",
prefix = "sparse_to_dense_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_xent_op",
prefix = "sparse_xent_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":bounds_check",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4174,6 +4229,7 @@ tf_cuda_cc_tests(
"sparse_xent_op_test.cc",
],
deps = [
+ ":host_constant_op",
":ops_testutil",
":ops_util",
":sparse",
@@ -4194,6 +4250,7 @@ cc_library(
"hinge-loss.h",
"logistic-loss.h",
"loss.h",
+ "poisson-loss.h",
"smooth-hinge-loss.h",
"squared-loss.h",
],
@@ -4386,16 +4443,32 @@ cc_library(
":reduce_join_op",
":regex_full_match_op",
":regex_replace_op",
+ ":string_format_op",
":string_join_op",
+ ":string_length_op",
":string_split_op",
":string_strip_op",
":string_to_hash_bucket_op",
":substr_op",
+ ":unicode_script_op",
+ ],
+)
+
+cc_library(
+ name = "string_util",
+ srcs = ["string_util.cc"],
+ hdrs = ["string_util.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "@icu//:common",
],
)
STRING_DEPS = [
":bounds_check",
+ ":string_util",
"//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -4416,12 +4489,42 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "string_format_op",
+ prefix = "string_format_op",
+ deps = STRING_DEPS + ["@com_google_absl//absl/strings"],
+)
+
+tf_cc_test(
+ name = "string_format_op_test",
+ size = "small",
+ srcs = ["string_format_op_test.cc"],
+ deps = [
+ ":string_format_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_kernel_library(
name = "string_join_op",
prefix = "string_join_op",
deps = STRING_DEPS,
)
tf_kernel_library(
+ name = "string_length_op",
+ prefix = "string_length_op",
+ deps = STRING_DEPS,
+)
+
+tf_kernel_library(
name = "regex_full_match_op",
prefix = "regex_full_match_op",
deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
@@ -4433,12 +4536,48 @@ tf_kernel_library(
deps = STRING_DEPS + ["@com_googlesource_code_re2//:re2"],
)
+tf_cc_test(
+ name = "regex_replace_op_test",
+ size = "small",
+ srcs = ["regex_replace_op_test.cc"],
+ deps = [
+ ":regex_replace_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
tf_kernel_library(
name = "string_split_op",
prefix = "string_split_op",
deps = STRING_DEPS,
)
+tf_cc_test(
+ name = "string_split_op_test",
+ size = "small",
+ srcs = ["string_split_op_test.cc"],
+ deps = [
+ ":string_split_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
tf_kernel_library(
name = "string_strip_op",
prefix = "string_strip_op",
@@ -4451,6 +4590,25 @@ tf_kernel_library(
deps = STRING_DEPS,
)
+tf_cc_test(
+ name = "substr_op_test",
+ size = "small",
+ srcs = ["substr_op_test.cc"],
+ deps = [
+ ":substr_op",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:ops_testutil",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
tf_kernel_library(
name = "as_string_op",
prefix = "as_string_op",
@@ -4512,6 +4670,7 @@ tf_cuda_cc_test(
size = "small",
srcs = ["multinomial_op_test.cc"],
deps = [
+ ":host_constant_op",
":multinomial_op",
":ops_util",
"//tensorflow/core:core_cpu",
@@ -4539,6 +4698,7 @@ tf_cuda_cc_test(
size = "small",
srcs = ["parameterized_truncated_normal_op_test.cc"],
deps = [
+ ":host_constant_op",
":ops_util",
":parameterized_truncated_normal_op",
"//tensorflow/core:core_cpu",
@@ -5039,6 +5199,7 @@ filegroup(
"spacetobatch_functor.h",
"spacetodepth_op.h",
"spectrogram.h",
+ "string_util.h",
"tensor_array.h",
"tile_functor.h",
"tile_ops_cpu_impl.h",
@@ -5048,7 +5209,6 @@ filegroup(
"training_ops.h",
"transpose_functor.h",
"transpose_op.h",
- "warn_about_ints.h",
"where_op.h",
"xent_op.h",
],
@@ -5119,6 +5279,8 @@ filegroup(
"cwise_op_squared_difference.cc",
"cwise_op_sub.cc",
"cwise_op_tanh.cc",
+ "cwise_op_xlogy.cc",
+ "cwise_op_xdivy.cc",
"data_format_ops.cc",
"decode_wav_op.cc",
"deep_conv2d.cc",
@@ -5130,6 +5292,7 @@ filegroup(
"fifo_queue.cc",
"fifo_queue_op.cc",
"fused_batch_norm_op.cc",
+ "listdiff_op.cc",
"population_count_op.cc",
"population_count_op.h",
"winograd_transform.h",
@@ -5207,6 +5370,7 @@ filegroup(
"spectrogram_op.cc",
"stack_ops.cc",
"string_join_op.cc",
+ "string_util.cc",
"summary_op.cc",
"tensor_array.cc",
"tensor_array_ops.cc",
@@ -5225,7 +5389,6 @@ filegroup(
"transpose_functor_cpu.cc",
"transpose_op.cc",
"unique_op.cc",
- "warn_about_ints.cc",
"where_op.cc",
"xent_op.cc",
":android_extended_ops_headers",
@@ -5333,6 +5496,7 @@ filegroup(
"batch_kernels.*",
"regex_full_match_op.cc",
"regex_replace_op.cc",
+ "unicode_script_op.cc",
# Ops that are inherently incompatible with Android (e.g. tied to x86 platform).
"mkl_*",
"xsmm_*",
@@ -6152,8 +6316,27 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ ] + mkl_deps(),
+)
+
+tf_cc_test_mkl(
+ name = "mkl_conv_ops_test",
+ size = "small",
+ srcs = ["mkl_conv_ops_test.cc"],
+ deps = [
+ ":ops_testutil",
+ ":ops_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
)
tf_mkl_kernel_library(
@@ -6167,8 +6350,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ ] + mkl_deps(),
)
tf_mkl_kernel_library(
@@ -6183,8 +6365,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ ] + mkl_deps(),
)
tf_mkl_kernel_library(
@@ -6203,8 +6384,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ ] + mkl_deps(),
)
tf_mkl_kernel_library(
@@ -6219,8 +6399,7 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
"//third_party/eigen3",
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ ] + mkl_deps(),
)
tf_mkl_kernel_library(
@@ -6235,56 +6414,49 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib",
"//third_party/eigen3",
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ ] + mkl_deps(),
)
tf_mkl_kernel_library(
name = "mkl_fused_batch_norm_op",
srcs = ["mkl_fused_batch_norm_op.cc"],
- deps = NN_DEPS + [
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ deps = NN_DEPS + mkl_deps(),
)
tf_mkl_kernel_library(
name = "mkl_aggregate_ops",
prefix = "mkl_aggregate_ops",
- deps = MATH_DEPS + [
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ deps = MATH_DEPS + mkl_deps(),
)
tf_mkl_kernel_library(
name = "mkl_concat_op",
prefix = "mkl_concat_op",
- deps = ARRAY_DEPS + [
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ deps = ARRAY_DEPS + mkl_deps(),
)
tf_mkl_kernel_library(
name = "mkl_reshape_op",
prefix = "mkl_reshape_op",
- deps = ARRAY_DEPS + [
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ deps = ARRAY_DEPS + mkl_deps(),
+)
+
+tf_mkl_kernel_library(
+ name = "mkl_slice_op",
+ prefix = "mkl_slice_op",
+ deps = ARRAY_DEPS + mkl_deps(),
)
tf_mkl_kernel_library(
name = "mkl_identity_op",
prefix = "mkl_identity_op",
- deps = ARRAY_DEPS + [
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ deps = ARRAY_DEPS + mkl_deps(),
)
tf_mkl_kernel_library(
name = "mkl_lrn_op",
prefix = "mkl_lrn_op",
- deps = NN_DEPS + [
- "//third_party/mkl:intel_binary_blob",
- ] + if_mkl(["@mkl_dnn"]),
+ deps = NN_DEPS + mkl_deps(),
)
tf_mkl_kernel_library(
@@ -6295,10 +6467,16 @@ tf_mkl_kernel_library(
"cwise_ops_gradients.h",
],
prefix = "mkl_cwise_ops_common",
- deps = NN_DEPS + [
- "cwise_op",
- "//third_party/mkl:intel_binary_blob",
+ deps = NN_DEPS + mkl_deps() + [":cwise_op"],
+)
+
+tf_mkl_kernel_library(
+ name = "mkl_transpose_op",
+ srcs = [
+ "mkl_transpose_op.cc",
],
+ hdrs = ["transpose_op.h"],
+ deps = ARRAY_DEPS + mkl_deps(),
)
# NOTE(lespeholt): This rule is deprecated, please use:
@@ -6413,6 +6591,16 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "unicode_script_op",
+ srcs = ["unicode_script_op.cc"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:string_ops_op_lib",
+ "@icu//:common",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/core/kernels/adjust_contrast_op.h b/tensorflow/core/kernels/adjust_contrast_op.h
index 7689c04214..f4a53c2ef9 100644
--- a/tensorflow/core/kernels/adjust_contrast_op.h
+++ b/tensorflow/core/kernels/adjust_contrast_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_
-#define TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_CONTRAST_OP_H_
+#define TENSORFLOW_CORE_KERNELS_ADJUST_CONTRAST_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -153,4 +153,4 @@ struct AdjustContrastv2 {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_ADJUST_CONTRAST_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_ADJUST_CONTRAST_OP_H_
diff --git a/tensorflow/core/kernels/adjust_hue_op.h b/tensorflow/core/kernels/adjust_hue_op.h
index 03d52a9e77..983a4072bf 100644
--- a/tensorflow/core/kernels/adjust_hue_op.h
+++ b/tensorflow/core/kernels/adjust_hue_op.h
@@ -11,8 +11,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef _TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H
-#define _TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H
+#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
@@ -37,4 +37,4 @@ struct AdjustHueGPU {
} // namespace tensorflow
#endif // GOOGLE_CUDA
-#endif // _TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H
+#endif // TENSORFLOW_CORE_KERNELS_ADJUST_HUE_OP_H_
diff --git a/tensorflow/core/kernels/adjust_saturation_op.h b/tensorflow/core/kernels/adjust_saturation_op.h
index 05c45c07c3..fd28ba536f 100644
--- a/tensorflow/core/kernels/adjust_saturation_op.h
+++ b/tensorflow/core/kernels/adjust_saturation_op.h
@@ -11,8 +11,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef _TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H
-#define _TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H
+#ifndef TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
+#define TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
@@ -37,4 +37,4 @@ struct AdjustSaturationGPU {
} // namespace tensorflow
#endif // GOOGLE_CUDA
-#endif // _TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H
+#endif // TENSORFLOW_CORE_KERNELS_ADJUST_SATURATION_OP_H_
diff --git a/tensorflow/core/kernels/aggregate_ops.h b/tensorflow/core/kernels/aggregate_ops.h
index 9ea49fc34b..e074d0c2d9 100644
--- a/tensorflow/core/kernels/aggregate_ops.h
+++ b/tensorflow/core/kernels/aggregate_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_AGGREGATE_OPS_H_
-#define TENSORFLOW_KERNELS_AGGREGATE_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
// Functor definitions for Aggregate ops, must be compilable by nvcc.
@@ -223,4 +223,4 @@ struct Add9EigenImpl {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_AGGREGATE_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_H_
diff --git a/tensorflow/core/kernels/aggregate_ops_cpu.h b/tensorflow/core/kernels/aggregate_ops_cpu.h
index aa1cead928..3e87917b64 100644
--- a/tensorflow/core/kernels/aggregate_ops_cpu.h
+++ b/tensorflow/core/kernels/aggregate_ops_cpu.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_
-#define TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_CPU_H_
+#define TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_CPU_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -250,4 +250,4 @@ struct Add9Functor<SYCLDevice, T> {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_AGGREGATE_OPS_CPU_H_
+#endif // TENSORFLOW_CORE_KERNELS_AGGREGATE_OPS_CPU_H_
diff --git a/tensorflow/core/kernels/argmax_op.h b/tensorflow/core/kernels/argmax_op.h
index b8bc41e089..224aa4654d 100644
--- a/tensorflow/core/kernels/argmax_op.h
+++ b/tensorflow/core/kernels/argmax_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_ARGMAX_OP_H_
-#define TENSORFLOW_KERNELS_ARGMAX_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_ARGMAX_OP_H_
+#define TENSORFLOW_CORE_KERNELS_ARGMAX_OP_H_
// Generator definition for ArgMaxOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -65,4 +65,4 @@ struct ArgMin {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_ARGMAX_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_ARGMAX_OP_H_
diff --git a/tensorflow/core/kernels/assign_op.h b/tensorflow/core/kernels/assign_op.h
index a450b1d1ee..74f926bdc8 100644
--- a/tensorflow/core/kernels/assign_op.h
+++ b/tensorflow/core/kernels/assign_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_ASSIGN_OP_H_
-#define TENSORFLOW_KERNELS_ASSIGN_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_ASSIGN_OP_H_
+#define TENSORFLOW_CORE_KERNELS_ASSIGN_OP_H_
#define EIGEN_USE_THREADS
@@ -143,4 +143,4 @@ class AssignOp : public OpKernel {
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_ASSIGN_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_ASSIGN_OP_H_
diff --git a/tensorflow/core/kernels/avgpooling_op.h b/tensorflow/core/kernels/avgpooling_op.h
index f5e81dbc09..1e49a66af9 100644
--- a/tensorflow/core/kernels/avgpooling_op.h
+++ b/tensorflow/core/kernels/avgpooling_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_AVGPOOLING_OP_H_
-#define TENSORFLOW_KERNELS_AVGPOOLING_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_AVGPOOLING_OP_H_
+#define TENSORFLOW_CORE_KERNELS_AVGPOOLING_OP_H_
// Functor definition for AvgPoolingOp, must be compilable by nvcc.
#include "tensorflow/core/framework/tensor_types.h"
@@ -76,4 +76,4 @@ bool RunAvePoolBackwardNHWC(const T* const top_diff, const int num,
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_AVGPOOLING_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_AVGPOOLING_OP_H_
diff --git a/tensorflow/core/kernels/batch_matmul_op_complex.cc b/tensorflow/core/kernels/batch_matmul_op_complex.cc
index 54c45bfe63..f48bd0c318 100644
--- a/tensorflow/core/kernels/batch_matmul_op_complex.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_complex.cc
@@ -17,14 +17,18 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own complex64/128 kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
#if GOOGLE_CUDA
TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_GPU);
-#endif
+#endif // GOOGLE_CUDA
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 475bda848d..766713a338 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -15,6 +15,9 @@ limitations under the License.
// See docs in ../ops/math_ops.cc.
+#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
+
#define EIGEN_USE_THREADS
#include <vector>
@@ -613,3 +616,5 @@ class BatchMatMul : public OpKernel {
BatchMatMul<SYCLDevice, TYPE>)
#endif // TENSORFLOW_USE_SYCL
} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/batch_matmul_op_real.cc b/tensorflow/core/kernels/batch_matmul_op_real.cc
index 584b507c70..25ae795d8e 100644
--- a/tensorflow/core/kernels/batch_matmul_op_real.cc
+++ b/tensorflow/core/kernels/batch_matmul_op_real.cc
@@ -21,10 +21,15 @@ limitations under the License.
namespace tensorflow {
-#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY)
+// MKL_ML registers its own float and double kernels in mkl_batch_matmul_op.cc
+// if defined(INTEL_MKL) && !defined(INTEL_MKL_DNN_ONLY) && defined(ENABLE_MKL).
+// Anything else (the complement) should register the TF ones.
+// (MKL-DNN doesn't implement these kernels either.)
+#if !defined(INTEL_MKL) || defined(INTEL_MKL_DNN_ONLY) || !defined(ENABLE_MKL)
TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
-#endif
+#endif // !INTEL_MKL || INTEL_MKL_DNN_ONLY || !ENABLE_MKL
+
TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
diff --git a/tensorflow/core/kernels/batch_norm_op.h b/tensorflow/core/kernels/batch_norm_op.h
index 48e73c8757..76b156f8fd 100644
--- a/tensorflow/core/kernels/batch_norm_op.h
+++ b/tensorflow/core/kernels/batch_norm_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
-#define TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_
+#define TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_
// Functor definition for BatchNormOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -153,4 +153,4 @@ struct BatchNormGrad {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_BATCH_NORM_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_BATCH_NORM_OP_H_
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index 792eb74e31..0d53240330 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -1,7 +1,7 @@
# Description: Utilities.
package(
- default_visibility = ["//tensorflow:internal"],
+ default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
@@ -12,7 +12,6 @@ cc_library(
name = "periodic_function_dynamic",
srcs = ["periodic_function.cc"],
hdrs = ["periodic_function.h"],
- visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:protos_all_cc",
@@ -21,7 +20,6 @@ cc_library(
cc_library(
name = "periodic_function",
- visibility = ["//visibility:public"],
deps = [
":periodic_function_dynamic",
"//tensorflow/core:lib",
@@ -190,7 +188,6 @@ cc_library(
testonly = 1,
srcs = ["fake_clock_env.cc"],
hdrs = ["fake_clock_env.h"],
- visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow",
diff --git a/tensorflow/core/kernels/betainc_op.h b/tensorflow/core/kernels/betainc_op.h
index c4aa9543ab..b941b27ad3 100644
--- a/tensorflow/core/kernels/betainc_op.h
+++ b/tensorflow/core/kernels/betainc_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_BETAINC_OP_H_
-#define TENSORFLOW_KERNELS_BETAINC_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BETAINC_OP_H_
+#define TENSORFLOW_CORE_KERNELS_BETAINC_OP_H_
// Functor definition for BetaincOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -48,4 +48,4 @@ struct Betainc {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_BETAINC_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_BETAINC_OP_H_
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 7b28c8e91f..e15ea82e7d 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -134,8 +134,8 @@ class BiasOp : public BinaryOp<T> {
if (data_format_ == FORMAT_NCHW) {
int32 batch, height, width, channel;
GetBiasValueDims(input, data_format_, &batch, &height, &width, &channel);
- Eigen::DSizes<int32, 4> four_dims(1, channel, 1, 1);
- Eigen::DSizes<int32, 4> broad_cast_dims(batch, 1, height, width);
+ Eigen::DSizes<Eigen::Index, 4> four_dims(1, channel, 1, 1);
+ Eigen::DSizes<Eigen::Index, 4> broad_cast_dims(batch, 1, height, width);
const Device& d = context->eigen_device<Device>();
output->tensor<T, 4>().device(d) =
input.tensor<T, 4>() +
@@ -247,14 +247,14 @@ class BiasGradOp : public OpKernel {
OP_REQUIRES(context, output_backprop.dims() == 4,
errors::InvalidArgument(
"NCHW format supports only 4D input/output tensor."));
- Eigen::DSizes<int, 4> four_dims(batch, channel, height, width);
+ Eigen::DSizes<Eigen::Index, 4> four_dims(batch, channel, height, width);
#ifdef EIGEN_HAS_INDEX_LIST
using idx0 = Eigen::type2index<0>;
using idx2 = Eigen::type2index<2>;
using idx3 = Eigen::type2index<3>;
Eigen::IndexList<idx0, idx2, idx3> reduction_axes;
#else
- Eigen::array<int, 3> reduction_axes = {0, 2, 3};
+ Eigen::array<Eigen::Index, 3> reduction_axes = {0, 2, 3};
#endif
output->template flat<T>().device(context->eigen_device<Device>()) =
output_backprop.flat<T>()
@@ -263,11 +263,12 @@ class BiasGradOp : public OpKernel {
.sum(reduction_axes)
.template cast<T>(); // End of code by intel_tf.
} else {
- Eigen::DSizes<int, 2> two_dims(batch * height * width, channel);
+ Eigen::DSizes<Eigen::Index, 2> two_dims(batch * height * width,
+ channel);
#ifdef EIGEN_HAS_INDEX_LIST
Eigen::IndexList<Eigen::type2index<0> > reduction_axis;
#else
- Eigen::array<int, 1> reduction_axis = {0};
+ Eigen::array<Eigen::Index, 1> reduction_axis = {0};
#endif
output->template flat<T>().device(context->eigen_device<Device>()) =
output_backprop.flat<T>()
diff --git a/tensorflow/core/kernels/bias_op.h b/tensorflow/core/kernels/bias_op.h
index 065934c709..77f683455d 100644
--- a/tensorflow/core/kernels/bias_op.h
+++ b/tensorflow/core/kernels/bias_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_BIAS_OP_H_
-#define TENSORFLOW_KERNELS_BIAS_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BIAS_OP_H_
+#define TENSORFLOW_CORE_KERNELS_BIAS_OP_H_
// Functor definition for BiasOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -52,4 +52,4 @@ struct Bias {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_BIAS_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_BIAS_OP_H_
diff --git a/tensorflow/core/kernels/bincount_op.h b/tensorflow/core/kernels/bincount_op.h
index cd3d560cd1..54cfb79de7 100644
--- a/tensorflow/core/kernels/bincount_op.h
+++ b/tensorflow/core/kernels/bincount_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_BINCOUNT_OP_H_
-#define TENSORFLOW_BINCOUNT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BINCOUNT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_BINCOUNT_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
@@ -38,4 +38,4 @@ struct BincountFunctor {
} // end namespace tensorflow
-#endif // TENSORFLOW_BINCOUNT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_BINCOUNT_OP_H_
diff --git a/tensorflow/core/kernels/bincount_op_gpu.cu.cc b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
index 6074b3e1f6..7d09e9b820 100644
--- a/tensorflow/core/kernels/bincount_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bincount_op_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/boosted_trees/BUILD b/tensorflow/core/kernels/boosted_trees/BUILD
index 4910021c63..4e8bfa02fc 100644
--- a/tensorflow/core/kernels/boosted_trees/BUILD
+++ b/tensorflow/core/kernels/boosted_trees/BUILD
@@ -15,7 +15,9 @@ load(
tf_proto_library(
name = "boosted_trees_proto",
- srcs = ["boosted_trees.proto"],
+ srcs = [
+ "boosted_trees.proto",
+ ],
cc_api_version = 2,
visibility = ["//visibility:public"],
)
@@ -87,9 +89,21 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "quantile_ops",
+ srcs = ["quantile_ops.cc"],
+ deps = [
+ "//tensorflow/core:boosted_trees_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core/kernels/boosted_trees/quantiles:weighted_quantiles",
+ ],
+)
+
+tf_kernel_library(
name = "boosted_trees_ops",
deps = [
":prediction_ops",
+ ":quantile_ops",
":resource_ops",
":stats_ops",
":training_ops",
diff --git a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
index c9664f0c1c..1ab72af059 100644
--- a/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
+++ b/tensorflow/core/kernels/boosted_trees/boosted_trees.proto
@@ -11,6 +11,7 @@ message Node {
oneof node {
Leaf leaf = 1;
BucketizedSplit bucketized_split = 2;
+ CategoricalSplit categorical_split = 3;
}
NodeMetadata metadata = 777;
}
@@ -57,6 +58,18 @@ message BucketizedSplit {
int32 right_id = 4;
}
+message CategoricalSplit {
+ // Categorical feature column and split describing the rule feature value ==
+ // value.
+ int32 feature_id = 1;
+ int32 value = 2;
+
+ // Node children indexing into a contiguous
+ // vector of nodes starting from the root.
+ int32 left_id = 3;
+ int32 right_id = 4;
+}
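+// Example (illustrative): feature_id=3 and value=7 encode the decision rule
+// "categorical feature 3 == 7"; each example is routed to left_id or right_id
+// depending on whether the rule holds for it.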
+
// Tree describes a list of connected nodes.
// Node 0 must be the root and can carry any payload including a leaf
// in the case of representing the bias.
diff --git a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
index b2efa06941..4ae26fb95b 100644
--- a/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
+++ b/tensorflow/core/kernels/boosted_trees/prediction_ops.cc
@@ -334,30 +334,34 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel {
// Proto to store debug outputs, per example.
boosted_trees::DebugOutput example_debug_info;
// Initial bias prediction. E.g., prediction based off training mean.
- example_debug_info.add_logits_path(resource->GetTreeWeight(0) *
- resource->node_value(0, 0));
+ float tree_logit =
+ resource->GetTreeWeight(0) * resource->node_value(0, 0);
+ example_debug_info.add_logits_path(tree_logit);
int32 node_id = 0;
int32 tree_id = 0;
int32 feature_id;
- float tree_logit;
float past_trees_logit = 0; // Sum of leaf logits from prior trees.
- // Populate proto.
+ // Go through each tree and populate proto.
while (tree_id <= last_tree) {
- // Feature id used to split.
- feature_id = resource->feature_id(tree_id, node_id);
- example_debug_info.add_feature_ids(feature_id);
- // Get logit after split.
- node_id = resource->next_node(tree_id, node_id, i,
- batch_bucketized_features);
- tree_logit = resource->GetTreeWeight(tree_id) *
- resource->node_value(tree_id, node_id);
- // Output logit incorporates sum of leaf logits from prior trees.
- example_debug_info.add_logits_path(tree_logit + past_trees_logit);
- if (resource->is_leaf(tree_id, node_id)) {
- // Move onto other trees.
- past_trees_logit += tree_logit;
+        if (resource->is_leaf(tree_id, node_id)) {  // Move on to other trees.
+ // Accumulate tree_logits only if the leaf is non-root, but do so
+ // for bias tree.
+ if (tree_id == 0 || node_id > 0) {
+ past_trees_logit += tree_logit;
+ }
++tree_id;
node_id = 0;
+ } else { // Add to proto.
+ // Feature id used to split.
+ feature_id = resource->feature_id(tree_id, node_id);
+ example_debug_info.add_feature_ids(feature_id);
+ // Get logit after split.
+ node_id = resource->next_node(tree_id, node_id, i,
+ batch_bucketized_features);
+ tree_logit = resource->GetTreeWeight(tree_id) *
+ resource->node_value(tree_id, node_id);
+ // Output logit incorporates sum of leaf logits from prior trees.
+ example_debug_info.add_logits_path(tree_logit + past_trees_logit);
}
}
// Set output as serialized proto containing debug info.
diff --git a/tensorflow/core/kernels/boosted_trees/quantile_ops.cc b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
new file mode 100644
index 0000000000..d1840941c1
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantile_ops.cc
@@ -0,0 +1,453 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include <algorithm>
+#include <iterator>
+#include <string>
+#include <vector>
+
+#include "tensorflow/core/framework/device_base.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+
+const char* const kExampleWeightsName = "example_weights";
+const char* const kMaxElementsName = "max_elements";
+const char* const kGenerateQuantiles = "generate_quantiles";
+const char* const kNumBucketsName = "num_buckets";
+const char* const kEpsilonName = "epsilon";
+const char* const kBucketBoundariesName = "bucket_boundaries";
+const char* const kBucketsName = "buckets";
+const char* const kSummariesName = "summaries";
+const char* const kNumStreamsName = "num_streams";
+const char* const kNumFeaturesName = "num_features";
+const char* const kFloatFeaturesName = "float_values";
+const char* const kResourceHandleName = "quantile_stream_resource_handle";
+
+using QuantileStreamResource = BoostedTreesQuantileStreamResource;
+using QuantileStream =
+ boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+using QuantileSummary =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
+using QuantileSummaryEntry =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float,
+ float>::SummaryEntry;
+
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateBoundaries(const QuantileStream& stream,
+ const int64 num_boundaries) {
+ std::vector<float> boundaries = stream.GenerateBoundaries(num_boundaries);
+
+ // Uniquify elements as we may get dupes.
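+  // For example (illustrative values), generated boundaries {1.0, 1.0, 2.0}
+  // are reduced to {1.0, 2.0}.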
+ auto end_it = std::unique(boundaries.begin(), boundaries.end());
+ boundaries.resize(std::distance(boundaries.begin(), end_it));
+ return boundaries;
+}
+
+// Generates quantiles on a finalized QuantileStream.
+std::vector<float> GenerateQuantiles(const QuantileStream& stream,
+ const int64 num_quantiles) {
+  // Do not de-dup boundaries. Exactly num_quantiles boundary values
+  // will be returned.
+ std::vector<float> boundaries = stream.GenerateQuantiles(num_quantiles - 1);
+ CHECK_EQ(boundaries.size(), num_quantiles);
+ return boundaries;
+}
+
+std::vector<float> GetBuckets(const int32 feature,
+ const OpInputList& buckets_list) {
+ const auto& buckets = buckets_list[feature].flat<float>();
+ std::vector<float> buckets_vector(buckets.data(),
+ buckets.data() + buckets.size());
+ return buckets_vector;
+}
+
+REGISTER_RESOURCE_HANDLE_KERNEL(BoostedTreesQuantileStreamResource);
+
+REGISTER_KERNEL_BUILDER(
+ Name("IsBoostedTreesQuantileStreamResourceInitialized").Device(DEVICE_CPU),
+ IsResourceInitialized<BoostedTreesQuantileStreamResource>);
+
+class BoostedTreesCreateQuantileStreamResourceOp : public OpKernel {
+ public:
+ explicit BoostedTreesCreateQuantileStreamResourceOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kMaxElementsName, &max_elements_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+    // Only create one if one does not exist already. Report status for all
+    // other exceptions. If one already exists, it unrefs the new one.
+    // An epsilon value of zero could cause performance issues and is
+    // therefore disallowed.
+ const Tensor* epsilon_t;
+ OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+ float epsilon = epsilon_t->scalar<float>()();
+ OP_REQUIRES(
+ context, epsilon > 0,
+ errors::InvalidArgument("An epsilon value of zero is not allowed."));
+
+ const Tensor* num_streams_t;
+ OP_REQUIRES_OK(context, context->input(kNumStreamsName, &num_streams_t));
+ int64 num_streams = num_streams_t->scalar<int64>()();
+
+ auto result =
+ new QuantileStreamResource(epsilon, max_elements_, num_streams);
+ auto status = CreateResource(context, HandleFromInput(context, 0), result);
+ if (!status.ok() && status.code() != tensorflow::error::ALREADY_EXISTS) {
+ OP_REQUIRES(context, false, status);
+ }
+ }
+
+ private:
+ // An upper bound on the number of entries that the summaries might have
+ // for a feature.
+ int64 max_elements_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesCreateQuantileStreamResource").Device(DEVICE_CPU),
+ BoostedTreesCreateQuantileStreamResourceOp);
+
+class BoostedTreesMakeQuantileSummariesOp : public OpKernel {
+ public:
+ explicit BoostedTreesMakeQuantileSummariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+    // Read float features list.
+ OpInputList float_features_list;
+ OP_REQUIRES_OK(
+ context, context->input_list(kFloatFeaturesName, &float_features_list));
+
+ // Parse example weights and get batch size.
+ const Tensor* example_weights_t;
+ OP_REQUIRES_OK(context,
+ context->input(kExampleWeightsName, &example_weights_t));
+ auto example_weights = example_weights_t->flat<float>();
+ const int64 batch_size = example_weights.size();
+ const Tensor* epsilon_t;
+ OP_REQUIRES_OK(context, context->input(kEpsilonName, &epsilon_t));
+ float epsilon = epsilon_t->scalar<float>()();
+
+ OpOutputList summaries_output_list;
+ OP_REQUIRES_OK(
+ context, context->output_list(kSummariesName, &summaries_output_list));
+
+ auto do_quantile_summary_gen = [&](const int64 begin, const int64 end) {
+ // Iterating features.
+ for (int64 index = begin; index < end; index++) {
+ const auto feature_values = float_features_list[index].flat<float>();
+ QuantileStream stream(epsilon, batch_size + 1);
+ // Run quantile summary generation.
+ for (int64 j = 0; j < batch_size; j++) {
+ stream.PushEntry(feature_values(j), example_weights(j));
+ }
+ stream.Finalize();
+ const auto summary_entry_list = stream.GetFinalSummary().GetEntryList();
+ Tensor* output_t;
+ OP_REQUIRES_OK(
+ context,
+ summaries_output_list.allocate(
+ index,
+ TensorShape({static_cast<int64>(summary_entry_list.size()), 4}),
+ &output_t));
+ auto output = output_t->matrix<float>();
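+        // Each output row holds one summary entry as
+        // (value, weight, min_rank, max_rank).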
+ for (auto row = 0; row < summary_entry_list.size(); row++) {
+ const auto& entry = summary_entry_list[row];
+ output(row, 0) = entry.value;
+ output(row, 1) = entry.weight;
+ output(row, 2) = entry.min_rank;
+ output(row, 3) = entry.max_rank;
+ }
+ }
+ };
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * batch_size;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+ kCostPerUnit, do_quantile_summary_gen);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesMakeQuantileSummaries").Device(DEVICE_CPU),
+ BoostedTreesMakeQuantileSummariesOp);
+
+class BoostedTreesQuantileStreamResourceAddSummariesOp : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceAddSummariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ OpInputList summaries_list;
+ OP_REQUIRES_OK(context,
+ context->input_list(kSummariesName, &summaries_list));
+ int32 num_streams = stream_resource->num_streams();
+ CHECK_EQ(static_cast<int>(num_streams), summaries_list.size());
+
+ auto do_quantile_add_summary = [&](const int64 begin, const int64 end) {
+ // Iterating all features.
+ for (int64 feature_idx = begin; feature_idx < end; ++feature_idx) {
+ const Tensor& summaries = summaries_list[feature_idx];
+ const auto summary_values = summaries.matrix<float>();
+ const auto& tensor_shape = summaries.shape();
+ const int64 entries_size = tensor_shape.dim_size(0);
+ CHECK_EQ(tensor_shape.dim_size(1), 4);
+ std::vector<QuantileSummaryEntry> summary_entries;
+ summary_entries.reserve(entries_size);
+ for (int64 i = 0; i < entries_size; i++) {
+ float value = summary_values(i, 0);
+ float weight = summary_values(i, 1);
+ float min_rank = summary_values(i, 2);
+ float max_rank = summary_values(i, 3);
+ QuantileSummaryEntry entry(value, weight, min_rank, max_rank);
+ summary_entries.push_back(entry);
+ }
+ stream_resource->stream(feature_idx)->PushSummary(summary_entries);
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_add_summary);
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceAddSummaries").Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceAddSummariesOp);
+
+class BoostedTreesQuantileStreamResourceFlushOp : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceFlushOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context,
+ context->GetAttr(kGenerateQuantiles, &generate_quantiles_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ const Tensor* num_buckets_t;
+ OP_REQUIRES_OK(context, context->input(kNumBucketsName, &num_buckets_t));
+ const int64 num_buckets = num_buckets_t->scalar<int64>()();
+ const int64 num_streams = stream_resource->num_streams();
+
+ auto do_quantile_flush = [&](const int64 begin, const int64 end) {
+ // Iterating over all streams.
+ for (int64 stream_idx = begin; stream_idx < end; ++stream_idx) {
+ QuantileStream* stream = stream_resource->stream(stream_idx);
+ stream->Finalize();
+ stream_resource->set_boundaries(
+ generate_quantiles_ ? GenerateQuantiles(*stream, num_buckets)
+ : GenerateBoundaries(*stream, num_buckets),
+ stream_idx);
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_flush);
+
+ stream_resource->set_buckets_ready(true);
+ }
+
+ private:
+ bool generate_quantiles_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceFlush").Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceFlushOp);
+
+class BoostedTreesQuantileStreamResourceGetBucketBoundariesOp
+ : public OpKernel {
+ public:
+ explicit BoostedTreesQuantileStreamResourceGetBucketBoundariesOp(
+ OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(context,
+ HandleFromInput(context, kResourceHandleName, &handle));
+ QuantileStreamResource* stream_resource;
+ // Create a reference to the underlying resource using the handle.
+ OP_REQUIRES_OK(context, LookupResource(context, handle, &stream_resource));
+ // Remove the reference at the end of this scope.
+ mutex_lock l(*stream_resource->mutex());
+ core::ScopedUnref unref_me(stream_resource);
+
+ const int64 num_streams = stream_resource->num_streams();
+ CHECK_EQ(num_features_, num_streams);
+ OpOutputList bucket_boundaries_list;
+ OP_REQUIRES_OK(context, context->output_list(kBucketBoundariesName,
+ &bucket_boundaries_list));
+
+ auto do_quantile_get_buckets = [&](const int64 begin, const int64 end) {
+ // Iterating over all streams.
+ for (int64 stream_idx = begin; stream_idx < end; stream_idx++) {
+ const auto& boundaries = stream_resource->boundaries(stream_idx);
+ Tensor* bucket_boundaries_t = nullptr;
+ OP_REQUIRES_OK(context,
+ bucket_boundaries_list.allocate(
+ stream_idx, {static_cast<int64>(boundaries.size())},
+ &bucket_boundaries_t));
+ auto* quantiles_flat = bucket_boundaries_t->flat<float>().data();
+ memcpy(quantiles_flat, boundaries.data(),
+ sizeof(float) * boundaries.size());
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_streams;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_streams,
+ kCostPerUnit, do_quantile_get_buckets);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
+ .Device(DEVICE_CPU),
+ BoostedTreesQuantileStreamResourceGetBucketBoundariesOp);
+
+// Given the calculated quantiles thresholds and input data, this operation
+// converts the input features into the buckets (categorical values), depending
+// on which quantile they fall into.
+class BoostedTreesBucketizeOp : public OpKernel {
+ public:
+ explicit BoostedTreesBucketizeOp(OpKernelConstruction* const context)
+ : OpKernel(context) {
+ OP_REQUIRES_OK(context, context->GetAttr(kNumFeaturesName, &num_features_));
+ }
+
+ void Compute(OpKernelContext* const context) override {
+    // Read float features list.
+ OpInputList float_features_list;
+ OP_REQUIRES_OK(
+ context, context->input_list(kFloatFeaturesName, &float_features_list));
+ OpInputList bucket_boundaries_list;
+ OP_REQUIRES_OK(context, context->input_list(kBucketBoundariesName,
+ &bucket_boundaries_list));
+ OP_REQUIRES(context,
+ tensorflow::TensorShapeUtils::IsVector(
+ bucket_boundaries_list[0].shape()),
+ errors::InvalidArgument(
+ strings::Printf("Buckets should be flat vectors.")));
+ OpOutputList buckets_list;
+ OP_REQUIRES_OK(context, context->output_list(kBucketsName, &buckets_list));
+
+ auto do_quantile_get_quantiles = [&](const int64 begin, const int64 end) {
+      // Iterating over all features.
+ for (int64 feature_idx = begin; feature_idx < end; feature_idx++) {
+ const Tensor& values_tensor = float_features_list[feature_idx];
+ const int64 num_values = values_tensor.dim_size(0);
+
+ Tensor* output_t = nullptr;
+ OP_REQUIRES_OK(
+ context, buckets_list.allocate(
+ feature_idx, TensorShape({num_values, 1}), &output_t));
+ auto output = output_t->matrix<int32>();
+
+ const std::vector<float>& bucket_boundaries_vector =
+ GetBuckets(feature_idx, bucket_boundaries_list);
+ CHECK(!bucket_boundaries_vector.empty())
+ << "Got empty buckets for feature " << feature_idx;
+ auto flat_values = values_tensor.flat<float>();
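+        // Illustrative example: with boundaries {1.0, 2.5, 4.0}, a value of
+        // 0.3 maps to bucket 0, values 1.5 and 2.5 map to bucket 1, and 10.0
+        // (beyond the last boundary) maps to bucket 2 after the end-iterator
+        // adjustment below.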
+ for (int64 instance = 0; instance < num_values; instance++) {
+ const float value = flat_values(instance);
+ auto bucket_iter =
+ std::lower_bound(bucket_boundaries_vector.begin(),
+ bucket_boundaries_vector.end(), value);
+ if (bucket_iter == bucket_boundaries_vector.end()) {
+ --bucket_iter;
+ }
+ const int32 bucket = static_cast<int32>(
+ bucket_iter - bucket_boundaries_vector.begin());
+ // Bucket id.
+ output(instance, 0) = bucket;
+ }
+ }
+ };
+
+ // TODO(tanzheny): comment on the magic number.
+ const int64 kCostPerUnit = 500 * num_features_;
+ const DeviceBase::CpuWorkerThreads& worker_threads =
+ *context->device()->tensorflow_cpu_worker_threads();
+ Shard(worker_threads.num_threads, worker_threads.workers, num_features_,
+ kCostPerUnit, do_quantile_get_quantiles);
+ }
+
+ private:
+ int64 num_features_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BoostedTreesBucketize").Device(DEVICE_CPU),
+ BoostedTreesBucketizeOp);
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/BUILD b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
new file mode 100644
index 0000000000..12d9473776
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/BUILD
@@ -0,0 +1,65 @@
+# Description:
+# This directory contains common quantile utilities used in boosted_trees.
+package(
+ default_visibility = ["//tensorflow:internal"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "tf_cc_test")
+
+# Quantiles
+
+cc_library(
+ name = "weighted_quantiles",
+ srcs = [],
+ hdrs = [
+ "quantile_stream_resource.h",
+ "weighted_quantiles_buffer.h",
+ "weighted_quantiles_stream.h",
+ "weighted_quantiles_summary.h",
+ ],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_cc_test(
+ name = "weighted_quantiles_buffer_test",
+ size = "small",
+ srcs = ["weighted_quantiles_buffer_test.cc"],
+ deps = [
+ ":weighted_quantiles",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "weighted_quantiles_summary_test",
+ size = "small",
+ srcs = ["weighted_quantiles_summary_test.cc"],
+ deps = [
+ ":weighted_quantiles",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
+
+tf_cc_test(
+ name = "weighted_quantiles_stream_test",
+ size = "small",
+ srcs = ["weighted_quantiles_stream_test.cc"],
+ deps = [
+ ":weighted_quantiles",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
new file mode 100644
index 0000000000..1c31724272
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/quantile_stream_resource.h
@@ -0,0 +1,96 @@
+// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
+
+#include <vector>
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+using QuantileStream =
+ boosted_trees::quantiles::WeightedQuantilesStream<float, float>;
+
+// Quantile Stream Resource for a list of streams sharing the same number of
+// quantiles, maximum elements, and epsilon.
+class BoostedTreesQuantileStreamResource : public ResourceBase {
+ public:
+ BoostedTreesQuantileStreamResource(const float epsilon,
+ const int64 max_elements,
+ const int64 num_streams)
+ : are_buckets_ready_(false),
+ epsilon_(epsilon),
+ num_streams_(num_streams),
+ max_elements_(max_elements) {
+ streams_.reserve(num_streams_);
+ boundaries_.reserve(num_streams_);
+ for (int64 idx = 0; idx < num_streams; ++idx) {
+ streams_.push_back(QuantileStream(epsilon, max_elements));
+ boundaries_.push_back(std::vector<float>());
+ }
+ }
+
+ string DebugString() override { return "QuantileStreamResource"; }
+
+ tensorflow::mutex* mutex() { return &mu_; }
+
+ QuantileStream* stream(const int64 index) { return &streams_[index]; }
+
+ const std::vector<float>& boundaries(const int64 index) {
+ return boundaries_[index];
+ }
+
+ void set_boundaries(const std::vector<float>& boundaries, const int64 index) {
+ boundaries_[index] = boundaries;
+ }
+
+ float epsilon() const { return epsilon_; }
+ int64 num_streams() const { return num_streams_; }
+
+ bool are_buckets_ready() const { return are_buckets_ready_; }
+ void set_buckets_ready(const bool are_buckets_ready) {
+ are_buckets_ready_ = are_buckets_ready;
+ }
+
+ private:
+ ~BoostedTreesQuantileStreamResource() override {}
+
+ // Mutex for the whole resource.
+ tensorflow::mutex mu_;
+
+ // Quantile streams.
+ std::vector<QuantileStream> streams_;
+
+ // Stores the boundaries. Same size as streams_.
+ std::vector<std::vector<float>> boundaries_;
+
+ // Whether boundaries are created. Initially boundaries are empty until
+ // set_boundaries() is called.
+ bool are_buckets_ready_;
+
+ const float epsilon_;
+ const int64 num_streams_;
+ // An upper bound on the number of elements.
+ int64 max_elements_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(BoostedTreesQuantileStreamResource);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_QUANTILE_STREAM_RESOURCE_H_
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h
new file mode 100644
index 0000000000..07aa9831c4
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h
@@ -0,0 +1,132 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
+
+#include <algorithm>
+#include <unordered_map>
+#include <vector>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace quantiles {
+
+// Buffering container ideally suited for scenarios where we need
+// to sort and dedupe/compact fixed chunks of a stream of weighted elements.
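+//
+// As an illustrative sketch: pushing the weighted entries (2, 3) and (2, 1)
+// and then calling GenerateEntryList() yields the single compacted entry
+// (2, 4), i.e. duplicate values are merged by summing their weights and the
+// result is returned in sorted order.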
+template <typename ValueType, typename WeightType,
+ typename CompareFn = std::less<ValueType>>
+class WeightedQuantilesBuffer {
+ public:
+ struct BufferEntry {
+ BufferEntry(ValueType v, WeightType w)
+ : value(std::move(v)), weight(std::move(w)) {}
+ BufferEntry() : value(), weight(0) {}
+
+ bool operator<(const BufferEntry& other) const {
+ return kCompFn(value, other.value);
+ }
+ bool operator==(const BufferEntry& other) const {
+ return value == other.value && weight == other.weight;
+ }
+ friend std::ostream& operator<<(std::ostream& strm,
+ const BufferEntry& entry) {
+ return strm << "{" << entry.value << ", " << entry.weight << "}";
+ }
+ ValueType value;
+ WeightType weight;
+ };
+
+ explicit WeightedQuantilesBuffer(int64 block_size, int64 max_elements)
+ : max_size_(std::min(block_size << 1, max_elements)) {
+ QCHECK(max_size_ > 0) << "Invalid buffer specification: (" << block_size
+ << ", " << max_elements << ")";
+ vec_.reserve(max_size_);
+ }
+
+ // Disallow copying as it's semantically non-sensical in the Squawd algorithm
+ // but enable move semantics.
+ WeightedQuantilesBuffer(const WeightedQuantilesBuffer& other) = delete;
+ WeightedQuantilesBuffer& operator=(const WeightedQuantilesBuffer&) = delete;
+ WeightedQuantilesBuffer(WeightedQuantilesBuffer&& other) = default;
+ WeightedQuantilesBuffer& operator=(WeightedQuantilesBuffer&& other) = default;
+
+ // Push entry to buffer and maintain a compact representation within
+ // pre-defined size limit.
+ void PushEntry(ValueType value, WeightType weight) {
+ // Callers are expected to act on a full compacted buffer after the
+ // PushEntry call returns.
+ QCHECK(!IsFull()) << "Buffer already full: " << max_size_;
+
+ // Ignore zero and negative weight entries.
+ if (weight <= 0) {
+ return;
+ }
+
+ // Push back the entry to the buffer.
+ vec_.push_back(BufferEntry(std::move(value), std::move(weight)));
+ }
+
+ // Returns a sorted vector view of the base buffer and clears the buffer.
+ // Callers should minimize how often this is called, ideally only right after
+ // the buffer becomes full.
+ std::vector<BufferEntry> GenerateEntryList() {
+ std::vector<BufferEntry> ret;
+ if (vec_.size() == 0) {
+ return ret;
+ }
+ ret.swap(vec_);
+ vec_.reserve(max_size_);
+ std::sort(ret.begin(), ret.end());
+ size_t num_entries = 0;
+ for (size_t i = 1; i < ret.size(); ++i) {
+ if (ret[i].value != ret[i - 1].value) {
+ BufferEntry tmp = ret[i];
+ ++num_entries;
+ ret[num_entries] = tmp;
+ } else {
+ ret[num_entries].weight += ret[i].weight;
+ }
+ }
+ ret.resize(num_entries + 1);
+ return ret;
+ }
+
+ int64 Size() const { return vec_.size(); }
+ bool IsFull() const { return vec_.size() >= max_size_; }
+ void Clear() { vec_.clear(); }
+
+ private:
+ using BufferVector = typename std::vector<BufferEntry>;
+
+ // Comparison function.
+ static constexpr decltype(CompareFn()) kCompFn = CompareFn();
+
+ // Base buffer.
+ size_t max_size_;
+ BufferVector vec_;
+};
+
+template <typename ValueType, typename WeightType, typename CompareFn>
+constexpr decltype(CompareFn())
+ WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>::kCompFn;
+
+} // namespace quantiles
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_BUFFER_H_
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer_test.cc b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer_test.cc
new file mode 100644
index 0000000000..75f05d64f3
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer_test.cc
@@ -0,0 +1,99 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+using Buffer =
+ boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>;
+using BufferEntry =
+ boosted_trees::quantiles::WeightedQuantilesBuffer<double,
+ double>::BufferEntry;
+
+class WeightedQuantilesBufferTest : public ::testing::Test {};
+
+TEST_F(WeightedQuantilesBufferTest, Invalid) {
+ EXPECT_DEATH(
+ ({
+ boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>
+ buffer(2, 0);
+ }),
+ "Invalid buffer specification");
+ EXPECT_DEATH(
+ ({
+ boosted_trees::quantiles::WeightedQuantilesBuffer<double, double>
+ buffer(0, 2);
+ }),
+ "Invalid buffer specification");
+}
+
+TEST_F(WeightedQuantilesBufferTest, PushEntryNotFull) {
+ Buffer buffer(20, 100);
+ buffer.PushEntry(5, 9);
+ buffer.PushEntry(2, 3);
+ buffer.PushEntry(-1, 7);
+ buffer.PushEntry(3, 0); // This entry will be ignored.
+
+ EXPECT_FALSE(buffer.IsFull());
+ EXPECT_EQ(buffer.Size(), 3);
+}
+
+TEST_F(WeightedQuantilesBufferTest, PushEntryFull) {
+ // buffer capacity is 4.
+ Buffer buffer(2, 100);
+ buffer.PushEntry(5, 9);
+ buffer.PushEntry(2, 3);
+ buffer.PushEntry(-1, 7);
+ buffer.PushEntry(2, 1);
+
+ std::vector<BufferEntry> expected;
+ expected.emplace_back(-1, 7);
+ expected.emplace_back(2, 4);
+ expected.emplace_back(5, 9);
+
+ // At this point, we have pushed 4 entries and we expect the buffer to be
+ // full.
+ EXPECT_TRUE(buffer.IsFull());
+ EXPECT_EQ(buffer.GenerateEntryList(), expected);
+ EXPECT_FALSE(buffer.IsFull());
+}
+
+TEST_F(WeightedQuantilesBufferTest, PushEntryFullDeath) {
+ // buffer capacity is 4.
+ Buffer buffer(2, 100);
+ buffer.PushEntry(5, 9);
+ buffer.PushEntry(2, 3);
+ buffer.PushEntry(-1, 7);
+ buffer.PushEntry(2, 1);
+
+ std::vector<BufferEntry> expected;
+ expected.emplace_back(-1, 7);
+ expected.emplace_back(2, 4);
+ expected.emplace_back(5, 9);
+
+ // At this point, we have pushed 4 entries and we expect the buffer to be
+ // full.
+ EXPECT_TRUE(buffer.IsFull());
+ // Can't push any more entries before clearing.
+ EXPECT_DEATH(({ buffer.PushEntry(6, 6); }), "Buffer already full");
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h
new file mode 100644
index 0000000000..525e2a6a64
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h
@@ -0,0 +1,330 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
+
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace quantiles {
+
+// Class to compute approximate quantiles with error bound guarantees for
+// weighted data sets.
+// This implementation is an adaptation of techniques from the following papers:
+// * (2001) Space-efficient online computation of quantile summaries.
+// * (2004) Power-conserving computation of order-statistics over
+// sensor networks.
+// * (2007) A fast algorithm for approximate quantiles in high speed
+// data streams.
+// * (2016) XGBoost: A Scalable Tree Boosting System.
+//
+// The key ideas at play are the following:
+// - Maintain an in-memory multi-level quantile summary in a way to guarantee
+// a maximum approximation error of eps * W per bucket where W is the total
+// weight across all points in the input dataset.
+// - Two base operations are defined: MERGE and COMPRESS. MERGE combines two
+// summaries, guaranteeing epsNew = max(eps1, eps2). COMPRESS compresses
+// a summary to b + 1 elements, guaranteeing epsNew = epsOld + 1/b.
+// - b * sizeof(summary entry) must ideally be small enough to fit in an
+// average CPU L2 cache.
+// - To distribute this algorithm while maintaining error bounds, we need
+// the worker-computed summaries to have no more than eps / h error,
+// where h is the height of the distributed computation graph, which
+// is 2 for a MapReduce with no combiner.
+//
+// We mainly want to max out I/O bandwidth by ensuring we're not compute-bound
+// and are using a reasonable amount of RAM.
+//
+// Complexity:
+// Compute: O(n * log(1/eps * log(eps * n))).
+// Memory: O(1/eps * log^2(eps * n)) <- for one worker streaming through the
+// entire dataset.
+// An epsilon value of zero would make the algorithm extremely inefficient and
+// is therefore disallowed.
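+//
+// A minimal usage sketch (values are illustrative):
+//
+//   WeightedQuantilesStream<float, float> stream(/*eps=*/0.01,
+//                                                /*max_elements=*/1 << 16);
+//   stream.PushEntry(value, weight);  // repeat for every weighted element
+//   stream.Finalize();
+//   std::vector<float> boundaries = stream.GenerateBoundaries(/*num=*/64);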
+template <typename ValueType, typename WeightType,
+ typename CompareFn = std::less<ValueType>>
+class WeightedQuantilesStream {
+ public:
+ using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
+ using BufferEntry = typename Buffer::BufferEntry;
+ using Summary = WeightedQuantilesSummary<ValueType, WeightType, CompareFn>;
+ using SummaryEntry = typename Summary::SummaryEntry;
+
+ explicit WeightedQuantilesStream(double eps, int64 max_elements)
+ : eps_(eps), buffer_(1LL, 2LL), finalized_(false) {
+ // See the class documentation. An epsilon value of zero could cause
+ // performance issues.
+ QCHECK(eps > 0) << "An epsilon value of zero is not allowed.";
+ std::tie(max_levels_, block_size_) = GetQuantileSpecs(eps, max_elements);
+ buffer_ = Buffer(block_size_, max_elements);
+ summary_levels_.reserve(max_levels_);
+ }
+
+ // Disallow copy and assign but enable move semantics for the stream.
+ WeightedQuantilesStream(const WeightedQuantilesStream& other) = delete;
+ WeightedQuantilesStream& operator=(const WeightedQuantilesStream&) = delete;
+ WeightedQuantilesStream(WeightedQuantilesStream&& other) = default;
+ WeightedQuantilesStream& operator=(WeightedQuantilesStream&& other) = default;
+
+ // Pushes one entry while maintaining approximation error invariants.
+ void PushEntry(const ValueType& value, const WeightType& weight) {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() already called.";
+
+ // Push element to base buffer.
+ buffer_.PushEntry(value, weight);
+
+ // When the compacted buffer is full, we need to compress it
+ // and push the weighted quantile summary up the level chain.
+ if (buffer_.IsFull()) {
+ PushBuffer(buffer_);
+ }
+ }
+
+ // Pushes full buffer while maintaining approximation error invariants.
+ void PushBuffer(Buffer& buffer) {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() already called.";
+
+ // Create local compressed summary and propagate.
+ local_summary_.BuildFromBufferEntries(buffer.GenerateEntryList());
+ local_summary_.Compress(block_size_, eps_);
+ PropagateLocalSummary();
+ }
+
+ // Pushes full summary while maintaining approximation error invariants.
+ void PushSummary(const std::vector<SummaryEntry>& summary) {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() already called.";
+
+ // Create local compressed summary and propagate.
+ local_summary_.BuildFromSummaryEntries(summary);
+ local_summary_.Compress(block_size_, eps_);
+ PropagateLocalSummary();
+ }
+
+ // Flushes approximator and finalizes state.
+ void Finalize() {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() may only be called once.";
+
+ // Flush any remaining buffer elements.
+ PushBuffer(buffer_);
+
+ // Create final merged summary.
+ local_summary_.Clear();
+ for (auto& summary : summary_levels_) {
+ local_summary_.Merge(summary);
+ summary.Clear();
+ }
+ summary_levels_.clear();
+ summary_levels_.shrink_to_fit();
+ finalized_ = true;
+ }
+
+ // Generates requested number of quantiles after finalizing stream.
+ // The returned quantiles can be queried using std::lower_bound to get
+ // the bucket for a given value.
+ std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
+ // Validate state.
+ QCHECK(finalized_)
+ << "Finalize() must be called before generating quantiles.";
+ return local_summary_.GenerateQuantiles(num_quantiles);
+ }
+
+ // Generates requested number of boundaries after finalizing stream.
+ // The returned boundaries can be queried using std::lower_bound to get
+ // the bucket for a given value.
+ // The boundaries, while still guaranteeing approximation bounds, don't
+ // necessarily represent the actual quantiles of the distribution.
+ // Boundaries are preferable over quantiles when the caller is less
+ // interested in the actual quantile distribution and more interested in
+ // getting a representative sample of boundary values.
+ std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
+ // Validate state.
+ QCHECK(finalized_)
+ << "Finalize() must be called before generating boundaries.";
+ return local_summary_.GenerateBoundaries(num_boundaries);
+ }
+
+ // Calculates approximation error for the specified level.
+ // If the passed level is negative, the approximation error for the entire
+ // summary is returned. Note that after Finalize is called, only the overall
+ // error is available.
+ WeightType ApproximationError(int64 level = -1) const {
+ if (finalized_) {
+ QCHECK(level <= 0) << "Only overall error is available after Finalize()";
+ return local_summary_.ApproximationError();
+ }
+
+ if (summary_levels_.empty()) {
+ // No error even if base buffer isn't empty.
+ return 0;
+ }
+
+ // If level is negative, we get the approximation error
+ // for the top-most level which is the max approximation error
+ // in all summaries by construction.
+ if (level < 0) {
+ level = summary_levels_.size() - 1;
+ }
+ QCHECK(level < summary_levels_.size()) << "Invalid level.";
+ return summary_levels_[level].ApproximationError();
+ }
+
+ size_t MaxDepth() const { return summary_levels_.size(); }
+
+ // Returns the final merged summary after the stream has been finalized.
+ const Summary& GetFinalSummary() const {
+ // Validate state.
+ QCHECK(finalized_)
+ << "Finalize() must be called before requesting final summary.";
+ return local_summary_;
+ }
+
+ // Helper method which, given the desired approximation error
+ // and an upper bound on the number of elements, computes the optimal
+ // number of levels and block size and returns them in the tuple.
+ static std::tuple<int64, int64> GetQuantileSpecs(double eps,
+ int64 max_elements);
+
+ // Serializes the internal state of the stream.
+ std::vector<Summary> SerializeInternalSummaries() const {
+ // The buffer should be empty for serialize to work.
+ QCHECK_EQ(buffer_.Size(), 0);
+ std::vector<Summary> result;
+ result.reserve(summary_levels_.size() + 1);
+ for (const Summary& summary : summary_levels_) {
+ result.push_back(summary);
+ }
+ result.push_back(local_summary_);
+ return result;
+ }
+
+ // Resets the state of the stream with a serialized state.
+ void DeserializeInternalSummaries(const std::vector<Summary>& summaries) {
+ // Clear the state before deserializing.
+ buffer_.Clear();
+ summary_levels_.clear();
+ local_summary_.Clear();
+ QCHECK_GT(max_levels_, summaries.size() - 1);
+ for (int i = 0; i < summaries.size() - 1; ++i) {
+ summary_levels_.push_back(summaries[i]);
+ }
+ local_summary_ = summaries[summaries.size() - 1];
+ }
+
+ private:
+ // Propagates local summary through summary levels while maintaining
+ // approximation error invariants.
+ void PropagateLocalSummary() {
+ // Validate state.
+ QCHECK(!finalized_) << "Finalize() already called.";
+
+ // No-op if there's nothing to add.
+ if (local_summary_.Size() <= 0) {
+ return;
+ }
+
+ // Propagate summary through levels.
+ size_t level = 0;
+ for (bool settled = false; !settled; ++level) {
+ // Ensure we have enough depth.
+ if (summary_levels_.size() <= level) {
+ summary_levels_.emplace_back();
+ }
+
+ // Merge summaries.
+ Summary& current_summary = summary_levels_[level];
+ local_summary_.Merge(current_summary);
+
+ // Check if we need to compress and propagate summary higher.
+ if (current_summary.Size() == 0 ||
+ local_summary_.Size() <= block_size_ + 1) {
+ current_summary = std::move(local_summary_);
+ settled = true;
+ } else {
+ // Compress, empty current level and propagate.
+ local_summary_.Compress(block_size_, eps_);
+ current_summary.Clear();
+ }
+ }
+ }
+
+ // Desired approximation precision.
+ double eps_;
+ // Maximum number of levels.
+ int64 max_levels_;
+ // Max block size per level.
+ int64 block_size_;
+ // Base buffer.
+ Buffer buffer_;
+ // Local summary used to minimize memory allocation and cache misses.
+ // After the stream is finalized, this summary holds the final quantile
+ // estimates.
+ Summary local_summary_;
+ // Summary levels.
+ std::vector<Summary> summary_levels_;
+ // Flag indicating whether the stream is finalized.
+ bool finalized_;
+};
+
+template <typename ValueType, typename WeightType, typename CompareFn>
+inline std::tuple<int64, int64>
+WeightedQuantilesStream<ValueType, WeightType, CompareFn>::GetQuantileSpecs(
+ double eps, int64 max_elements) {
+ int64 max_level = 1LL;
+ int64 block_size = 2LL;
+ QCHECK(eps >= 0 && eps < 1);
+ QCHECK_GT(max_elements, 0);
+
+ if (eps <= std::numeric_limits<double>::epsilon()) {
+ // Exact quantile computation at the expense of RAM.
+ max_level = 1;
+ block_size = std::max(max_elements, int64{2});
+ } else {
+ // The bottom-most level will become full at most
+ // (max_elements / block_size) times, the level above at most
+ // (max_elements / (2 * block_size)) times, and in general level l becomes
+ // full at most (max_elements / (2^l * block_size)) times, until the last
+ // level max_level becomes full at most once, i.e. once the inequality
+ // (2^max_level * block_size >= max_elements) is satisfied.
+ // In what follows, we jointly solve for max_level and block_size by
+ // gradually increasing the level until the inequality above is satisfied.
+ // We could alternatively set max_level = ceil(log2(eps * max_elements))
+ // and block_size = ceil(max_level / eps) + 1, but that tends to give more
+ // pessimistic bounds and wastes RAM needlessly.
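+ // As a worked example (matching the NonZeroEps test): for eps = 0.1 and
+ // max_elements = 320 the loop settles at max_level = 4 and block_size = 31,
+ // since (1 << 4) * 31 = 496 >= 320.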
+ for (max_level = 1, block_size = 2;
+ (1LL << max_level) * block_size < max_elements; ++max_level) {
+ // Update the upper bound on block size at the current level; we always
+ // increase the estimate by 2 to hold the min/max elements seen so far.
+ block_size = static_cast<size_t>(ceil(max_level / eps)) + 1;
+ }
+ }
+ return std::make_tuple(max_level, std::max(block_size, int64{2}));
+}
+
+} // namespace quantiles
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_STREAM_H_
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream_test.cc b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream_test.cc
new file mode 100644
index 0000000000..6c5b9fd23b
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream_test.cc
@@ -0,0 +1,276 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_stream.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+using Tuple = std::tuple<int64, int64>;
+
+using Summary =
+ boosted_trees::quantiles::WeightedQuantilesSummary<double, double>;
+using SummaryEntry =
+ boosted_trees::quantiles::WeightedQuantilesSummary<double,
+ double>::SummaryEntry;
+using Stream =
+ boosted_trees::quantiles::WeightedQuantilesStream<double, double>;
+
+TEST(GetQuantileSpecs, InvalidEps) {
+ EXPECT_DEATH({ Stream::GetQuantileSpecs(-0.01, 0L); }, "eps >= 0");
+ EXPECT_DEATH({ Stream::GetQuantileSpecs(1.01, 0L); }, "eps < 1");
+}
+
+TEST(GetQuantileSpecs, ZeroEps) {
+ EXPECT_DEATH({ Stream::GetQuantileSpecs(0.0, 0L); }, "max_elements > 0");
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.0, 1LL), Tuple(1LL, 2LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.0, 20LL), Tuple(1LL, 20LL));
+}
+
+TEST(GetQuantileSpecs, NonZeroEps) {
+ EXPECT_DEATH({ Stream::GetQuantileSpecs(0.01, 0L); }, "max_elements > 0");
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.1, 320LL), Tuple(4LL, 31LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.01, 25600LL), Tuple(6LL, 501LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.01, 104857600LL), Tuple(17LL, 1601LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.1, 104857600LL), Tuple(20LL, 191LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.01, 1LL << 40), Tuple(29LL, 2801LL));
+ EXPECT_EQ(Stream::GetQuantileSpecs(0.001, 1LL << 40), Tuple(26LL, 25001LL));
+}
+
+class WeightedQuantilesStreamTest : public ::testing::Test {};
+
+// Stream generators.
+void GenerateFixedUniformSummary(int32 worker_id, int64 max_elements,
+ double *total_weight, Stream *stream) {
+ for (int64 i = 0; i < max_elements; ++i) {
+ const double x = static_cast<double>(i) / max_elements;
+ stream->PushEntry(x, 1.0);
+ ++(*total_weight);
+ }
+ stream->Finalize();
+}
+
+void GenerateFixedNonUniformSummary(int32 worker_id, int64 max_elements,
+ double *total_weight, Stream *stream) {
+ for (int64 i = 0; i < max_elements; ++i) {
+ const double x = static_cast<double>(i) / max_elements;
+ stream->PushEntry(x, x);
+ (*total_weight) += x;
+ }
+ stream->Finalize();
+}
+
+void GenerateRandUniformFixedWeightsSummary(int32 worker_id, int64 max_elements,
+ double *total_weight,
+ Stream *stream) {
+ // Simulate uniform distribution stream.
+ random::PhiloxRandom philox(13 + worker_id);
+ random::SimplePhilox rand(&philox);
+ for (int64 i = 0; i < max_elements; ++i) {
+ const double x = rand.RandDouble();
+ stream->PushEntry(x, 1);
+ ++(*total_weight);
+ }
+ stream->Finalize();
+}
+
+void GenerateRandUniformRandWeightsSummary(int32 worker_id, int64 max_elements,
+ double *total_weight,
+ Stream *stream) {
+ // Simulate uniform distribution stream.
+ random::PhiloxRandom philox(13 + worker_id);
+ random::SimplePhilox rand(&philox);
+ for (int64 i = 0; i < max_elements; ++i) {
+ const double x = rand.RandDouble();
+ const double w = rand.RandDouble();
+ stream->PushEntry(x, w);
+ (*total_weight) += w;
+ }
+ stream->Finalize();
+}
+
+// Single worker tests.
+void TestSingleWorkerStreams(
+ double eps, int64 max_elements,
+ const std::function<void(int32, int64, double *, Stream *)>
+ &worker_summary_generator,
+ std::initializer_list<double> expected_quantiles,
+ double quantiles_matcher_epsilon) {
+ // Generate single stream.
+ double total_weight = 0;
+ Stream stream(eps, max_elements);
+ worker_summary_generator(0, max_elements, &total_weight, &stream);
+
+ // Ensure we didn't lose track of any elements and are
+ // within approximation error bound.
+ EXPECT_LE(stream.ApproximationError(), eps);
+ EXPECT_NEAR(stream.GetFinalSummary().TotalWeight(), total_weight, 1e-6);
+
+ // Verify expected quantiles.
+ int i = 0;
+ auto actuals = stream.GenerateQuantiles(expected_quantiles.size() - 1);
+ for (auto expected_quantile : expected_quantiles) {
+ EXPECT_NEAR(actuals[i], expected_quantile, quantiles_matcher_epsilon);
+ ++i;
+ }
+}
+
+// Stream generators.
+void GenerateOneValue(int32 worker_id, int64 max_elements, double *total_weight,
+ Stream *stream) {
+ stream->PushEntry(10, 1);
+ ++(*total_weight);
+ stream->Finalize();
+}
+
+void GenerateOneZeroWeightedValue(int32 worker_id, int64 max_elements,
+ double *total_weight, Stream *stream) {
+ stream->PushEntry(10, 0);
+ stream->Finalize();
+}
+
+TEST(WeightedQuantilesStreamTest, OneValue) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(eps, max_elements, GenerateOneValue,
+ {10.0, 10.0, 10.0, 10.0, 10.0}, 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, OneZeroWeightValue) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(eps, max_elements, GenerateOneZeroWeightedValue, {},
+ 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, FixedUniform) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(eps, max_elements, GenerateFixedUniformSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0},
+ 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, FixedNonUniform) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(eps, max_elements, GenerateFixedNonUniformSummary,
+ {0, std::sqrt(0.1), std::sqrt(0.2), std::sqrt(0.3),
+ std::sqrt(0.4), std::sqrt(0.5), std::sqrt(0.6),
+ std::sqrt(0.7), std::sqrt(0.8), std::sqrt(0.9), 1.0},
+ 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, RandUniformFixedWeights) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(
+ eps, max_elements, GenerateRandUniformFixedWeightsSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, RandUniformRandWeights) {
+ const double eps = 0.01;
+ const int64 max_elements = 1 << 16;
+ TestSingleWorkerStreams(
+ eps, max_elements, GenerateRandUniformRandWeightsSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+// Distributed tests.
+void TestDistributedStreams(
+ int32 num_workers, double eps, int64 max_elements,
+ const std::function<void(int32, int64, double *, Stream *)>
+ &worker_summary_generator,
+ std::initializer_list<double> expected_quantiles,
+ double quantiles_matcher_epsilon) {
+ // Simulate streams on each worker running independently
+ double total_weight = 0;
+ std::vector<std::vector<SummaryEntry>> worker_summaries;
+ for (int32 i = 0; i < num_workers; ++i) {
+ Stream stream(eps / 2, max_elements);
+ worker_summary_generator(i, max_elements / num_workers, &total_weight,
+ &stream);
+ worker_summaries.push_back(stream.GetFinalSummary().GetEntryList());
+ }
+
+ // In the accumulation phase, we aggregate the summaries from each worker
+ // and build an overall summary while maintaining error bounds by ensuring we
+ // don't increase the error by more than eps / 2.
+ Stream reducer_stream(eps, max_elements);
+ for (const auto &summary : worker_summaries) {
+ reducer_stream.PushSummary(summary);
+ }
+ reducer_stream.Finalize();
+
+ // Ensure we didn't lose track of any elements and are
+ // within approximation error bound.
+ EXPECT_LE(reducer_stream.ApproximationError(), eps);
+ EXPECT_NEAR(reducer_stream.GetFinalSummary().TotalWeight(), total_weight,
+ total_weight);
+
+ // Verify expected quantiles.
+ int i = 0;
+ auto actuals =
+ reducer_stream.GenerateQuantiles(expected_quantiles.size() - 1);
+ for (auto expected_quantile : expected_quantiles) {
+ EXPECT_NEAR(actuals[i], expected_quantile, quantiles_matcher_epsilon);
+ ++i;
+ }
+}
+
+TEST(WeightedQuantilesStreamTest, FixedUniformDistributed) {
+ const int32 num_workers = 10;
+ const double eps = 0.01;
+ const int64 max_elements = num_workers * (1 << 16);
+ TestDistributedStreams(
+ num_workers, eps, max_elements, GenerateFixedUniformSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, FixedNonUniformDistributed) {
+ const int32 num_workers = 10;
+ const double eps = 0.01;
+ const int64 max_elements = num_workers * (1 << 16);
+ TestDistributedStreams(num_workers, eps, max_elements,
+ GenerateFixedNonUniformSummary,
+ {0, std::sqrt(0.1), std::sqrt(0.2), std::sqrt(0.3),
+ std::sqrt(0.4), std::sqrt(0.5), std::sqrt(0.6),
+ std::sqrt(0.7), std::sqrt(0.8), std::sqrt(0.9), 1.0},
+ 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, RandUniformFixedWeightsDistributed) {
+ const int32 num_workers = 10;
+ const double eps = 0.01;
+ const int64 max_elements = num_workers * (1 << 16);
+ TestDistributedStreams(
+ num_workers, eps, max_elements, GenerateRandUniformFixedWeightsSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+TEST(WeightedQuantilesStreamTest, RandUniformRandWeightsDistributed) {
+ const int32 num_workers = 10;
+ const double eps = 0.01;
+ const int64 max_elements = num_workers * (1 << 16);
+ TestDistributedStreams(
+ num_workers, eps, max_elements, GenerateRandUniformRandWeightsSummary,
+ {0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0}, 1e-2);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h
new file mode 100644
index 0000000000..31d7fe25a4
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h
@@ -0,0 +1,344 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#ifndef TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
+#define TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
+
+#include <cstring>
+#include <vector>
+
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_buffer.h"
+
+namespace tensorflow {
+namespace boosted_trees {
+namespace quantiles {
+
+// Summary holding a sorted block of entries with upper bound guarantees
+// over the approximation error.
+template <typename ValueType, typename WeightType,
+ typename CompareFn = std::less<ValueType>>
+class WeightedQuantilesSummary {
+ public:
+ using Buffer = WeightedQuantilesBuffer<ValueType, WeightType, CompareFn>;
+ using BufferEntry = typename Buffer::BufferEntry;
+
+ struct SummaryEntry {
+ SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min,
+ const WeightType& max) {
+ // Explicitly initialize all of memory (including padding from memory
+ // alignment) to allow the struct to be msan-resistant "plain old data".
+ //
+ // POD = http://en.cppreference.com/w/cpp/concept/PODType
+ memset(this, 0, sizeof(*this));
+
+ value = v;
+ weight = w;
+ min_rank = min;
+ max_rank = max;
+ }
+
+ SummaryEntry() {
+ memset(this, 0, sizeof(*this));
+
+ value = ValueType();
+ weight = 0;
+ min_rank = 0;
+ max_rank = 0;
+ }
+
+ bool operator==(const SummaryEntry& other) const {
+ return value == other.value && weight == other.weight &&
+ min_rank == other.min_rank && max_rank == other.max_rank;
+ }
+ friend std::ostream& operator<<(std::ostream& strm,
+ const SummaryEntry& entry) {
+ return strm << "{" << entry.value << ", " << entry.weight << ", "
+ << entry.min_rank << ", " << entry.max_rank << "}";
+ }
+
+ // Max rank estimate for previous smaller value.
+ WeightType PrevMaxRank() const { return max_rank - weight; }
+
+ // Min rank estimate for next larger value.
+ WeightType NextMinRank() const { return min_rank + weight; }
+
+ ValueType value;
+ WeightType weight;
+ WeightType min_rank;
+ WeightType max_rank;
+ };
+
+ // Re-construct summary from the specified buffer.
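+ // As an illustrative sketch: the buffer entries {(-1, 7), (2, 4), (5, 9)}
+ // become the summary entries {(-1, 7, 0, 7), (2, 4, 7, 11), (5, 9, 11, 20)},
+ // where min_rank is the cumulative weight before the entry and max_rank the
+ // cumulative weight including it.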
+ void BuildFromBufferEntries(const std::vector<BufferEntry>& buffer_entries) {
+ entries_.clear();
+ entries_.reserve(buffer_entries.size());
+ WeightType cumulative_weight = 0;
+ for (const auto& entry : buffer_entries) {
+ WeightType current_weight = entry.weight;
+ entries_.emplace_back(entry.value, entry.weight, cumulative_weight,
+ cumulative_weight + current_weight);
+ cumulative_weight += current_weight;
+ }
+ }
+
+ // Re-construct summary from the specified summary entries.
+ void BuildFromSummaryEntries(
+ const std::vector<SummaryEntry>& summary_entries) {
+ entries_.clear();
+ entries_.reserve(summary_entries.size());
+ entries_.insert(entries_.begin(), summary_entries.begin(),
+ summary_entries.end());
+ }
+
+ // Merges two summaries through an algorithm that's derived from MergeSort
+ // for summary entries while guaranteeing that the max approximation error
+ // of the final merged summary is no greater than the approximation errors
+ // of each individual summary.
+ // For example consider summaries where each entry is of the form
+ // (element, weight, min rank, max rank):
+ // summary entries 1: (1, 3, 0, 3), (4, 2, 3, 5)
+ // summary entries 2: (3, 1, 0, 1), (4, 1, 1, 2)
+ // merged: (1, 3, 0, 3), (3, 1, 3, 4), (4, 3, 4, 7).
+ void Merge(const WeightedQuantilesSummary& other_summary) {
+ // Make sure we have something to merge.
+ const auto& other_entries = other_summary.entries_;
+ if (other_entries.empty()) {
+ return;
+ }
+ if (entries_.empty()) {
+ BuildFromSummaryEntries(other_summary.entries_);
+ return;
+ }
+
+ // Move current entries to make room for a new buffer.
+ std::vector<SummaryEntry> base_entries(std::move(entries_));
+ entries_.clear();
+ entries_.reserve(base_entries.size() + other_entries.size());
+
+ // Merge entries maintaining ranks. The idea is to stack values
+ // in order which we can do in linear time as the two summaries are
+ // already sorted. We keep track of the next lower rank from either
+ // summary and update it as we pop elements from the summaries.
+ // We handle the special case when the next two elements from either
+ // summary are equal, in which case we just merge the two elements
+ // and simultaneously update both ranks.
+ auto it1 = base_entries.cbegin();
+ auto it2 = other_entries.cbegin();
+ WeightType next_min_rank1 = 0;
+ WeightType next_min_rank2 = 0;
+ while (it1 != base_entries.cend() && it2 != other_entries.cend()) {
+ if (kCompFn(it1->value, it2->value)) { // value1 < value2
+ // Take value1 and use the last added value2 to compute
+ // the min rank and the current value2 to compute the max rank.
+ entries_.emplace_back(it1->value, it1->weight,
+ it1->min_rank + next_min_rank2,
+ it1->max_rank + it2->PrevMaxRank());
+ // Update next min rank 1.
+ next_min_rank1 = it1->NextMinRank();
+ ++it1;
+ } else if (kCompFn(it2->value, it1->value)) { // value1 > value2
+ // Take value2 and use the last added value1 to compute
+ // the min rank and the current value1 to compute the max rank.
+ entries_.emplace_back(it2->value, it2->weight,
+ it2->min_rank + next_min_rank1,
+ it2->max_rank + it1->PrevMaxRank());
+ // Update next min rank 2.
+ next_min_rank2 = it2->NextMinRank();
+ ++it2;
+ } else { // value1 == value2
+ // Straight additive merger of the two entries into one.
+ entries_.emplace_back(it1->value, it1->weight + it2->weight,
+ it1->min_rank + it2->min_rank,
+ it1->max_rank + it2->max_rank);
+ // Update next min ranks for both.
+ next_min_rank1 = it1->NextMinRank();
+ next_min_rank2 = it2->NextMinRank();
+ ++it1;
+ ++it2;
+ }
+ }
+
+ // Fill in any residual.
+ while (it1 != base_entries.cend()) {
+ entries_.emplace_back(it1->value, it1->weight,
+ it1->min_rank + next_min_rank2,
+ it1->max_rank + other_entries.back().max_rank);
+ ++it1;
+ }
+ while (it2 != other_entries.cend()) {
+ entries_.emplace_back(it2->value, it2->weight,
+ it2->min_rank + next_min_rank1,
+ it2->max_rank + base_entries.back().max_rank);
+ ++it2;
+ }
+ }
+
+ // Compresses buffer into desired size. The size specification is
+ // considered a hint as we always keep the first and last elements and
+ // maintain strict approximation error bounds.
+ // The approximation error delta is taken as the max of the requested
+ // min error and 1 / size_hint.
+ // After compression, the approximation error is guaranteed to increase
+ // by no more than that error delta.
+ // This algorithm is linear in the original size of the summary and is
+ // designed to be cache-friendly.
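+ // For example (cf. the CompressSeparately test), compressing a zero-error
+ // summary down to 5 entries yields an approximation error of at most 1/5.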
+ void Compress(int64 size_hint, double min_eps = 0) {
+ // No-op if we're already within the size requirement.
+ size_hint = std::max(size_hint, int64{2});
+ if (entries_.size() <= size_hint) {
+ return;
+ }
+
+ // First compute the max error bound delta resulting from this compression.
+ double eps_delta = TotalWeight() * std::max(1.0 / size_hint, min_eps);
+
+ // Compress elements ensuring approximation bounds and elements diversity
+ // are both maintained.
+ int64 add_accumulator = 0, add_step = entries_.size();
+ auto write_it = entries_.begin() + 1, last_it = write_it;
+ for (auto read_it = entries_.begin(); read_it + 1 != entries_.end();) {
+ auto next_it = read_it + 1;
+ while (next_it != entries_.end() && add_accumulator < add_step &&
+ next_it->PrevMaxRank() - read_it->NextMinRank() <= eps_delta) {
+ add_accumulator += size_hint;
+ ++next_it;
+ }
+ if (read_it == next_it - 1) {
+ ++read_it;
+ } else {
+ read_it = next_it - 1;
+ }
+ (*write_it++) = (*read_it);
+ last_it = read_it;
+ add_accumulator -= add_step;
+ }
+ // Write last element and resize.
+ if (last_it + 1 != entries_.end()) {
+ (*write_it++) = entries_.back();
+ }
+ entries_.resize(write_it - entries_.begin());
+ }
+
+ // To construct the boundaries we first run a soft compress over a copy
+ // of the summary and retrieve the values.
+ // The resulting boundaries are guaranteed to both contain at least
+ // num_boundaries unique elements and maintain approximation bounds.
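+ // (The BoostedTreesBucketize kernel, for instance, buckets input values by
+ // running std::lower_bound over a vector of such boundaries.)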
+ std::vector<ValueType> GenerateBoundaries(int64 num_boundaries) const {
+ std::vector<ValueType> output;
+ if (entries_.empty()) {
+ return output;
+ }
+
+ // Generate soft compressed summary.
+ WeightedQuantilesSummary<ValueType, WeightType, CompareFn>
+ compressed_summary;
+ compressed_summary.BuildFromSummaryEntries(entries_);
+ // Set an epsilon for compression that's at most 1.0 / num_boundaries
+ // more than the epsilon of our original summary, since the compression
+ // operation adds ~1.0/num_boundaries to the final approximation error.
+ float compression_eps = ApproximationError() + (1.0 / num_boundaries);
+ compressed_summary.Compress(num_boundaries, compression_eps);
+
+ // Return boundaries.
+ output.reserve(compressed_summary.entries_.size());
+ for (const auto& entry : compressed_summary.entries_) {
+ output.push_back(entry.value);
+ }
+ return output;
+ }
+
+ // To construct the desired n-quantiles we repeatedly query n ranks from the
+ // original summary. The following algorithm is an efficient cache-friendly
+ // O(n) implementation of that idea which avoids the cost of the repeated
+ // full rank queries, O(n log n).
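+ // For example, GenerateQuantiles(4) returns 5 values approximating the
+ // minimum, the 25th, 50th and 75th weighted percentiles, and the maximum.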
+ std::vector<ValueType> GenerateQuantiles(int64 num_quantiles) const {
+ std::vector<ValueType> output;
+ if (entries_.empty()) {
+ return output;
+ }
+ num_quantiles = std::max(num_quantiles, int64{2});
+ output.reserve(num_quantiles + 1);
+
+ // Make successive rank queries to get boundaries.
+ // We always keep the first (min) and last (max) entries.
+ for (size_t cur_idx = 0, rank = 0; rank <= num_quantiles; ++rank) {
+ // This step boils down to finding the next element sub-range defined by
+ // r = (rmax[i + 1] + rmin[i + 1]) / 2 where the desired rank d < r.
+ WeightType d_2 = 2 * (rank * entries_.back().max_rank / num_quantiles);
+ size_t next_idx = cur_idx + 1;
+ while (next_idx < entries_.size() &&
+ d_2 >= entries_[next_idx].min_rank + entries_[next_idx].max_rank) {
+ ++next_idx;
+ }
+ cur_idx = next_idx - 1;
+
+ // Determine insertion order.
+ if (next_idx == entries_.size() ||
+ d_2 < entries_[cur_idx].NextMinRank() +
+ entries_[next_idx].PrevMaxRank()) {
+ output.push_back(entries_[cur_idx].value);
+ } else {
+ output.push_back(entries_[next_idx].value);
+ }
+ }
+ return output;
+ }
+
+ // Calculates current approximation error which should always be <= eps.
+ double ApproximationError() const {
+ if (entries_.empty()) {
+ return 0;
+ }
+
+ WeightType max_gap = 0;
+ for (auto it = entries_.cbegin() + 1; it < entries_.end(); ++it) {
+ max_gap = std::max(max_gap,
+ std::max(it->max_rank - it->min_rank - it->weight,
+ it->PrevMaxRank() - (it - 1)->NextMinRank()));
+ }
+ return static_cast<double>(max_gap) / TotalWeight();
+ }
+
+ ValueType MinValue() const {
+ return !entries_.empty() ? entries_.front().value
+ : std::numeric_limits<ValueType>::max();
+ }
+ ValueType MaxValue() const {
+ return !entries_.empty() ? entries_.back().value
+ : std::numeric_limits<ValueType>::max();
+ }
+ WeightType TotalWeight() const {
+ return !entries_.empty() ? entries_.back().max_rank : 0;
+ }
+ int64 Size() const { return entries_.size(); }
+ void Clear() { entries_.clear(); }
+ const std::vector<SummaryEntry>& GetEntryList() const { return entries_; }
+
+ private:
+ // Comparison function.
+ static constexpr decltype(CompareFn()) kCompFn = CompareFn();
+
+ // Summary entries.
+ std::vector<SummaryEntry> entries_;
+};
+
+template <typename ValueType, typename WeightType, typename CompareFn>
+constexpr decltype(CompareFn())
+ WeightedQuantilesSummary<ValueType, WeightType, CompareFn>::kCompFn;
+
+} // namespace quantiles
+} // namespace boosted_trees
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_BOOSTED_TREES_QUANTILES_WEIGHTED_QUANTILES_SUMMARY_H_
diff --git a/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary_test.cc b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary_test.cc
new file mode 100644
index 0000000000..ccd1215cf4
--- /dev/null
+++ b/tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary_test.cc
@@ -0,0 +1,223 @@
+// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+#include "tensorflow/core/kernels/boosted_trees/quantiles/weighted_quantiles_summary.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+namespace {
+
+using Buffer = boosted_trees::quantiles::WeightedQuantilesBuffer<float, float>;
+using BufferEntry =
+ boosted_trees::quantiles::WeightedQuantilesBuffer<float,
+ float>::BufferEntry;
+using Summary =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float, float>;
+using SummaryEntry =
+ boosted_trees::quantiles::WeightedQuantilesSummary<float,
+ float>::SummaryEntry;
+
+class WeightedQuantilesSummaryTest : public ::testing::Test {
+ protected:
+ void SetUp() override {
+ // Constructs a buffer of 10 weighted unique entries.
+ buffer1_.reset(new Buffer(10, 1000));
+ buffer1_->PushEntry(5, 9);
+ buffer1_->PushEntry(2, 3);
+ buffer1_->PushEntry(-1, 7);
+ buffer1_->PushEntry(-7, 1);
+ buffer1_->PushEntry(3, 2);
+ buffer1_->PushEntry(-2, 3);
+ buffer1_->PushEntry(21, 8);
+ buffer1_->PushEntry(-13, 4);
+ buffer1_->PushEntry(8, 2);
+ buffer1_->PushEntry(-5, 6);
+
+ // Constructs a buffer of 7 weighted unique entries.
+ buffer2_.reset(new Buffer(7, 1000));
+ buffer2_->PushEntry(9, 2);
+ buffer2_->PushEntry(-7, 3);
+ buffer2_->PushEntry(2, 1);
+ buffer2_->PushEntry(4, 13);
+ buffer2_->PushEntry(0, 5);
+ buffer2_->PushEntry(-5, 3);
+ buffer2_->PushEntry(11, 3);
+ }
+
+ void TearDown() override { buffer1_->Clear(); }
+
+ std::unique_ptr<Buffer> buffer1_;
+ std::unique_ptr<Buffer> buffer2_;
+ const double buffer1_min_value_ = -13;
+ const double buffer1_max_value_ = 21;
+ const double buffer1_total_weight_ = 45;
+ const double buffer2_min_value_ = -7;
+ const double buffer2_max_value_ = 11;
+ const double buffer2_total_weight_ = 30;
+};
+
+TEST_F(WeightedQuantilesSummaryTest, BuildFromBuffer) {
+ Summary summary;
+ summary.BuildFromBufferEntries(buffer1_->GenerateEntryList());
+
+ // We expect no approximation error because no compress operation occurred.
+ EXPECT_EQ(summary.ApproximationError(), 0);
+
+ // Check first and last elements in the summary.
+ const auto& entries = summary.GetEntryList();
+ // First element's rmin should be zero.
+ EXPECT_EQ(summary.MinValue(), buffer1_min_value_);
+ EXPECT_EQ(entries.front(), SummaryEntry(-13, 4, 0, 4));
+ // Last element's rmax should be cumulative weight.
+ EXPECT_EQ(summary.MaxValue(), buffer1_max_value_);
+ EXPECT_EQ(entries.back(), SummaryEntry(21, 8, 37, 45));
+ // Check total weight.
+ EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_);
+}
+
+TEST_F(WeightedQuantilesSummaryTest, CompressSeparately) {
+ const auto entry_list = buffer1_->GenerateEntryList();
+ for (int new_size = 9; new_size >= 2; --new_size) {
+ Summary summary;
+ summary.BuildFromBufferEntries(entry_list);
+ summary.Compress(new_size);
+
+ // Expect a max approximation error of 1 / n,
+ // i.e. eps0 + 1/n with eps0 = 0.
+ EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2);
+ EXPECT_LE(summary.ApproximationError(), 1.0 / new_size);
+
+ // Min/Max elements and total weight should not change.
+ EXPECT_EQ(summary.MinValue(), buffer1_min_value_);
+ EXPECT_EQ(summary.MaxValue(), buffer1_max_value_);
+ EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_);
+ }
+}
+
+TEST_F(WeightedQuantilesSummaryTest, CompressSequentially) {
+ Summary summary;
+ summary.BuildFromBufferEntries(buffer1_->GenerateEntryList());
+ for (int new_size = 9; new_size >= 2; new_size -= 2) {
+ double prev_eps = summary.ApproximationError();
+ summary.Compress(new_size);
+
+ // Expect a max approximation error of prev_eps + 1 / n.
+ EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2);
+ EXPECT_LE(summary.ApproximationError(), prev_eps + 1.0 / new_size);
+
+ // Min/Max elements and total weight should not change.
+ EXPECT_EQ(summary.MinValue(), buffer1_min_value_);
+ EXPECT_EQ(summary.MaxValue(), buffer1_max_value_);
+ EXPECT_EQ(summary.TotalWeight(), buffer1_total_weight_);
+ }
+}
+
+TEST_F(WeightedQuantilesSummaryTest, CompressRandomized) {
+ // Check multiple size compressions and ensure approximation bounds
+ // are always respected.
+ int prev_size = 1;
+ int size = 2;
+ float max_value = 1 << 20;
+ while (size < (1 << 16)) {
+ // Create buffer of size from uniform random elements.
+ Buffer buffer(size, size << 4);
+ random::PhiloxRandom philox(13);
+ random::SimplePhilox rand(&philox);
+ for (int i = 0; i < size; ++i) {
+ buffer.PushEntry(rand.RandFloat() * max_value,
+ rand.RandFloat() * max_value);
+ }
+
+ // Create summary and compress.
+ Summary summary;
+ summary.BuildFromBufferEntries(buffer.GenerateEntryList());
+ int new_size = std::max(rand.Uniform(size), 2u);
+ summary.Compress(new_size);
+
+ // Ensure approximation error is acceptable.
+ EXPECT_TRUE(summary.Size() >= new_size && summary.Size() <= new_size + 2);
+ EXPECT_LE(summary.ApproximationError(), 1.0 / new_size);
+
+ // Update size to the next Fibonacci number.
+ size_t last_size = size;
+ size += prev_size;
+ prev_size = last_size;
+ }
+}
+
+TEST_F(WeightedQuantilesSummaryTest, MergeSymmetry) {
+ // Create two separate summaries and merge.
+ const auto list_1 = buffer1_->GenerateEntryList();
+ const auto list_2 = buffer2_->GenerateEntryList();
+ Summary summary1;
+ summary1.BuildFromBufferEntries(list_1);
+ Summary summary2;
+ summary2.BuildFromBufferEntries(list_2);
+
+ // Merge summary 2 into 1 and verify.
+ summary1.Merge(summary2);
+ EXPECT_EQ(summary1.ApproximationError(), 0.0);
+ EXPECT_EQ(summary1.MinValue(),
+ std::min(buffer1_min_value_, buffer2_min_value_));
+ EXPECT_EQ(summary1.MaxValue(),
+ std::max(buffer1_max_value_, buffer2_max_value_));
+ EXPECT_EQ(summary1.TotalWeight(),
+ buffer1_total_weight_ + buffer2_total_weight_);
+ EXPECT_EQ(summary1.Size(), 14); // 14 unique values.
+
+ // Merge summary 1 into 2 and verify same result.
+ summary1.BuildFromBufferEntries(list_1);
+ summary2.Merge(summary1);
+ EXPECT_EQ(summary2.ApproximationError(), 0.0);
+ EXPECT_EQ(summary2.MinValue(),
+ std::min(buffer1_min_value_, buffer2_min_value_));
+ EXPECT_EQ(summary2.MaxValue(),
+ std::max(buffer1_max_value_, buffer2_max_value_));
+ EXPECT_EQ(summary2.TotalWeight(),
+ buffer1_total_weight_ + buffer2_total_weight_);
+ EXPECT_EQ(summary2.Size(), 14); // 14 unique values.
+}
+
+TEST_F(WeightedQuantilesSummaryTest, CompressThenMerge) {
+ // Create two separate summaries and merge.
+ Summary summary1;
+ summary1.BuildFromBufferEntries(buffer1_->GenerateEntryList());
+ Summary summary2;
+ summary2.BuildFromBufferEntries(buffer2_->GenerateEntryList());
+
+ // Compress summaries.
+ summary1.Compress(5); // max error is 1/5.
+ const auto eps1 = 1.0 / 5;
+ EXPECT_LE(summary1.ApproximationError(), eps1);
+ summary2.Compress(3); // max error is 1/3.
+ const auto eps2 = 1.0 / 3;
+ EXPECT_LE(summary2.ApproximationError(), eps2);
+
+ // Merge guarantees an approximation error of max(eps1, eps2).
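+ // Here max(eps1, eps2) = max(1/5, 1/3) = 1/3.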
+ // Merge summary 2 into 1 and verify.
+ summary1.Merge(summary2);
+ EXPECT_LE(summary1.ApproximationError(), std::max(eps1, eps2));
+ EXPECT_EQ(summary1.MinValue(),
+ std::min(buffer1_min_value_, buffer2_min_value_));
+ EXPECT_EQ(summary1.MaxValue(),
+ std::max(buffer1_max_value_, buffer2_max_value_));
+ EXPECT_EQ(summary1.TotalWeight(),
+ buffer1_total_weight_ + buffer2_total_weight_);
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/boosted_trees/resources.cc b/tensorflow/core/kernels/boosted_trees/resources.cc
index cc90bb2f45..2798722536 100644
--- a/tensorflow/core/kernels/boosted_trees/resources.cc
+++ b/tensorflow/core/kernels/boosted_trees/resources.cc
@@ -60,14 +60,26 @@ int32 BoostedTreesEnsembleResource::next_node(
DCHECK_LT(tree_id, tree_ensemble_->trees_size());
DCHECK_LT(node_id, tree_ensemble_->trees(tree_id).nodes_size());
const auto& node = tree_ensemble_->trees(tree_id).nodes(node_id);
- DCHECK_EQ(node.node_case(), boosted_trees::Node::kBucketizedSplit);
- const auto& split = node.bucketized_split();
- if (bucketized_features[split.feature_id()](index_in_batch) <=
- split.threshold()) {
- return split.left_id();
- } else {
- return split.right_id();
+
+ switch (node.node_case()) {
+ case boosted_trees::Node::kBucketizedSplit: {
+ const auto& split = node.bucketized_split();
+ return (bucketized_features[split.feature_id()](index_in_batch) <=
+ split.threshold())
+ ? split.left_id()
+ : split.right_id();
+ }
+ case boosted_trees::Node::kCategoricalSplit: {
+ const auto& split = node.categorical_split();
+ return (bucketized_features[split.feature_id()](index_in_batch) ==
+ split.value())
+ ? split.left_id()
+ : split.right_id();
+ }
+ default:
+ DCHECK(false) << "Node type " << node.node_case() << " not supported.";
}
+ return -1;
}
float BoostedTreesEnsembleResource::node_value(const int32 tree_id,
diff --git a/tensorflow/core/kernels/bounds_check.h b/tensorflow/core/kernels/bounds_check.h
index c8c60c5524..18727c0db3 100644
--- a/tensorflow/core/kernels/bounds_check.h
+++ b/tensorflow/core/kernels/bounds_check.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_BOUNDS_CHECK_H_
-#define TENSORFLOW_UTIL_BOUNDS_CHECK_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BOUNDS_CHECK_H_
+#define TENSORFLOW_CORE_KERNELS_BOUNDS_CHECK_H_
#include <type_traits>
@@ -51,4 +51,4 @@ EIGEN_ALWAYS_INLINE EIGEN_DEVICE_FUNC const T SubtleMustCopy(const T &x) {
} // namespace internal
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_BOUNDS_CHECK_H_
+#endif // TENSORFLOW_CORE_KERNELS_BOUNDS_CHECK_H_
diff --git a/tensorflow/core/kernels/broadcast_to_op.h b/tensorflow/core/kernels/broadcast_to_op.h
index 73fdd5d28e..a2327a7272 100644
--- a/tensorflow/core/kernels/broadcast_to_op.h
+++ b/tensorflow/core/kernels/broadcast_to_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_BROADCAST_TO_OP_H_
-#define TENSORFLOW_KERNELS_BROADCAST_TO_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BROADCAST_TO_OP_H_
+#define TENSORFLOW_CORE_KERNELS_BROADCAST_TO_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
@@ -239,4 +239,4 @@ struct BroadcastTo {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_BROADCAST_TO_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_BROADCAST_TO_OP_H_
diff --git a/tensorflow/core/kernels/bucketize_op.h b/tensorflow/core/kernels/bucketize_op.h
index c8e461beb9..32be475f86 100644
--- a/tensorflow/core/kernels/bucketize_op.h
+++ b/tensorflow/core/kernels/bucketize_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_BUCKETIZE_OP_H_
-#define TENSORFLOW_BUCKETIZE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_BUCKETIZE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_BUCKETIZE_OP_H_
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -38,4 +38,4 @@ struct BucketizeFunctor {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_BUCKETIZE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_BUCKETIZE_OP_H_
diff --git a/tensorflow/core/kernels/candidate_sampler_ops.cc b/tensorflow/core/kernels/candidate_sampler_ops.cc
index 654d99301a..663bff3657 100644
--- a/tensorflow/core/kernels/candidate_sampler_ops.cc
+++ b/tensorflow/core/kernels/candidate_sampler_ops.cc
@@ -89,9 +89,9 @@ class BaseCandidateSamplerOp : public OpKernel {
// Pick sampled candidates.
auto local_gen = generator_.ReserveSamples32(samples32);
random::SimplePhilox random(&local_gen);
- sampler_->SampleBatchGetExpectedCount(&random, unique_, &sampled_candidate,
- &sampled_expected_count,
- true_candidate, &true_expected_count);
+ sampler_->SampleBatchGetExpectedCount(&random, unique_, sampled_candidate,
+ sampled_expected_count,
+ true_candidate, true_expected_count);
if (sampler_->NeedsUpdates()) {
sampler_->Update(true_candidate);
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index 0478c93280..3a72567655 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -98,7 +98,13 @@ void CastOpBase::Compute(OpKernelContext* ctx) {
ctx->set_output(0, inp);
} else {
Tensor in;
- in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape());
+ if (external_src_dtype_ != src_dtype_) {
+ // If the type is a quantized type, we need to do an UnsafeCopyFromInternal
+ // since src_dtype_ is different from external_src_dtype_.
+ in.UnsafeCopyFromInternal(inp, src_dtype_, inp.shape());
+ } else {
+ in = inp;
+ }
Tensor* out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, in.shape(), &out));
out->set_dtype(dst_dtype_);
diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h
index 527ab528c9..84c44f6b5e 100644
--- a/tensorflow/core/kernels/cast_op.h
+++ b/tensorflow/core/kernels/cast_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CAST_OP_H_
-#define TENSORFLOW_KERNELS_CAST_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CAST_OP_H_
+#define TENSORFLOW_CORE_KERNELS_CAST_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/bfloat16.h"
@@ -323,4 +323,4 @@ struct functor_traits<scalar_cast_op<float, ::tensorflow::bfloat16>> {
} // namespace internal
} // namespace Eigen
-#endif // TENSORFLOW_KERNELS_CAST_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_CAST_OP_H_
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index 5de41bac72..82e2913b64 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -132,14 +132,20 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
+ // Allocate output on the first pass through this function. This must be
+ // done immediately, while we're still in the executor thread. Otherwise
+ // the memory is not guaranteed to be unused by any concurrently executing
+ // GPU kernel.
+ if (c->mutable_output(0) == nullptr) {
+ // Allocate the output tensor, trying to reuse the input.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(c,
+ c->forward_input_or_allocate_output(
+ {0}, 0, c->input(0).shape(), &output),
+ done);
+ col_params_.instance.shape = c->input(0).shape();
+ }
if (!CanProceedWithCompute(c, col_exec, done)) return;
- // Allocate the output tensor, trying to reuse the input.
- Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(c,
- c->forward_input_or_allocate_output(
- {0}, 0, c->input(0).shape(), &output),
- done);
-
auto actual_done = [c, col_exec, done](const Status& s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
done();
@@ -166,7 +172,7 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
- OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
col_params_.is_source = true;
col_params_.instance.impl_details.subdiv_offsets = {0};
@@ -183,16 +189,24 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
+ // Allocate output on the first pass through this function. This must be
+ // done immediately, while we're still in the executor thread. Otherwise
+ // the memory is not guaranteed to be unused by any concurrently executing
+ // GPU kernel.
+ if (c->mutable_output(0) == nullptr) {
+ // Allocate the output tensor, trying to reuse the input.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(c,
+ c->forward_input_or_allocate_output(
+ {0}, 0, col_params_.instance.shape, &output),
+ done);
+ }
if (!CanProceedWithCompute(c, col_exec, done)) return;
OP_REQUIRES_ASYNC(
- c, shape_.IsSameSize(c->input(0).shape()),
+ c, col_params_.instance.shape.IsSameSize(c->input(0).shape()),
errors::Internal("Declared shape of op ", col_params_.name,
" does not match shape of input"),
done);
- // Allocate the output Tensor, trying to reuse the input.
- Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(
- c, c->forward_input_or_allocate_output({0}, 0, shape_, &output), done);
auto actual_done = [c, col_exec, done](const Status& s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
@@ -202,8 +216,6 @@ class CollectiveBcastSendOpKernel : public CollectiveOpKernel {
}
private:
- TensorShape shape_;
-
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastSendOpKernel);
};
@@ -222,7 +234,7 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
OP_REQUIRES_OK(
c, c->GetAttr("instance_key", &col_params_.instance.instance_key));
OP_REQUIRES_OK(c, c->GetAttr("T", &col_params_.instance.data_type));
- OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+ OP_REQUIRES_OK(c, c->GetAttr("shape", &col_params_.instance.shape));
col_params_.is_source = false;
col_params_.instance.impl_details.subdiv_offsets = {0};
@@ -239,10 +251,17 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
+ // Allocate output on the first pass through this function. This must be
+ // done immediately, while we're still in the executor thread. Otherwise
+ // the memory is not guaranteed to be unused by any concurrently executing
+ // GPU kernel.
+ if (c->mutable_output(0) == nullptr) {
+ // No input, so must allocate output.
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ c, c->allocate_output(0, col_params_.instance.shape, &output), done);
+ }
if (!CanProceedWithCompute(c, col_exec, done)) return;
- // No input, so must allocate output.
- Tensor* output = nullptr;
- OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape_, &output), done);
auto actual_done = [c, col_exec, done](const Status& s) {
OP_REQUIRES_OK_ASYNC(c, s, done);
@@ -252,8 +271,6 @@ class CollectiveBcastRecvOpKernel : public CollectiveOpKernel {
}
private:
- TensorShape shape_;
-
TF_DISALLOW_COPY_AND_ASSIGN(CollectiveBcastRecvOpKernel);
};
diff --git a/tensorflow/core/kernels/colorspace_op.h b/tensorflow/core/kernels/colorspace_op.h
index 90bfce1419..4de14bc339 100644
--- a/tensorflow/core/kernels/colorspace_op.h
+++ b/tensorflow/core/kernels/colorspace_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_COLORSPACE_OP_H_
-#define TENSORFLOW_KERNELS_COLORSPACE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_COLORSPACE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_COLORSPACE_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -91,4 +91,4 @@ struct HSVToRGB {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_COLORSPACE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_COLORSPACE_OP_H_
diff --git a/tensorflow/core/kernels/concat_lib.h b/tensorflow/core/kernels/concat_lib.h
index 16784c4770..8b53ecf121 100644
--- a/tensorflow/core/kernels/concat_lib.h
+++ b/tensorflow/core/kernels/concat_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONCAT_LIB_H_
-#define TENSORFLOW_KERNELS_CONCAT_LIB_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONCAT_LIB_H_
+#define TENSORFLOW_CORE_KERNELS_CONCAT_LIB_H_
#include <vector>
@@ -66,4 +66,4 @@ void ConcatSYCL(
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONCAT_LIB_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_H_
diff --git a/tensorflow/core/kernels/concat_lib_cpu.h b/tensorflow/core/kernels/concat_lib_cpu.h
index 720b506537..29f3a427fe 100644
--- a/tensorflow/core/kernels/concat_lib_cpu.h
+++ b/tensorflow/core/kernels/concat_lib_cpu.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_CONCAT_LIB_CPU_H_
+#define TENSORFLOW_CORE_KERNELS_CONCAT_LIB_CPU_H_
+
#define EIGEN_USE_THREADS
#include <vector>
@@ -162,3 +165,5 @@ void ConcatSYCLImpl(
}
#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_CONCAT_LIB_CPU_H_
diff --git a/tensorflow/core/kernels/conditional_accumulator.h b/tensorflow/core/kernels/conditional_accumulator.h
index 414891b142..390db8fe5a 100644
--- a/tensorflow/core/kernels/conditional_accumulator.h
+++ b/tensorflow/core/kernels/conditional_accumulator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_H_
-#define TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_
+#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/typed_conditional_accumulator_base.h"
@@ -51,9 +51,11 @@ class ConditionalAccumulator
// dtype: The datatype of the gradients to be accumulated.
// shape: The shape of the accumulated gradients.
// name: A name to use for the ConditionalAccumulator.
+ // reduction_type: The reduction type, i.e., MEAN or SUM.
ConditionalAccumulator(const DataType& dtype, const PartialTensorShape& shape,
- const string& name)
- : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name) {}
+ const string& name, const string& reduction_type)
+ : TypedConditionalAccumulatorBase<const Tensor>(dtype, shape, name,
+ reduction_type) {}
~ConditionalAccumulator() override{};
protected:
@@ -133,4 +135,4 @@ class ConditionalAccumulator
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_H_
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.cc b/tensorflow/core/kernels/conditional_accumulator_base.cc
index 90593c56b8..292cf0cd64 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_base.cc
@@ -14,12 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/conditional_accumulator_base.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
ConditionalAccumulatorBase::ConditionalAccumulatorBase(
- const DataType& dtype, const PartialTensorShape& shape, const string& name)
- : dtype_(dtype), shape_(shape), name_(name) {
+ const DataType& dtype, const PartialTensorShape& shape, const string& name,
+ const string& reduction_type)
+ : dtype_(dtype),
+ shape_(shape),
+ name_(name),
+ reduction_type_(reduction_type) {
counter_ = 0;
current_global_step_ = 0;
}
@@ -190,7 +195,9 @@ bool ConditionalAccumulatorBase::TakeGradLockedHelper(OpKernelContext* ctx,
current_global_step_++;
// Average the accumulated gradient
- DivideAccumGradByCounter(ctx);
+ if (reduction_type_ == "MEAN") {
+ DivideAccumGradByCounter(ctx);
+ }
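+ // With reduction_type_ == "SUM" the accumulated gradient is returned
+ // without scaling.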
// Set output for accumulated gradient tensor
bool successful_set_output = SetOutput(ctx);
diff --git a/tensorflow/core/kernels/conditional_accumulator_base.h b/tensorflow/core/kernels/conditional_accumulator_base.h
index c7c7c98369..4a5ec6f0fb 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
-#define TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
+#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
#include <deque>
@@ -52,7 +52,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
// name: A name to use for the ConditionalAccumulator.
ConditionalAccumulatorBase(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name);
+ const string& name, const string& reduction_type);
typedef AsyncOpKernel::DoneCallback DoneCallback;
@@ -125,6 +125,7 @@ class ConditionalAccumulatorBase : public ResourceBase {
const DataType dtype_;
const PartialTensorShape shape_;
const string name_;
+ const string reduction_type_;
mutex mu_;
int counter_ GUARDED_BY(mu_);
int64 current_global_step_ GUARDED_BY(mu_);
@@ -199,4 +200,4 @@ class TypeConverter<Eigen::half, U> {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_H_
diff --git a/tensorflow/core/kernels/conditional_accumulator_base_op.h b/tensorflow/core/kernels/conditional_accumulator_base_op.h
index 33c2d596c8..ca24d690f8 100644
--- a/tensorflow/core/kernels/conditional_accumulator_base_op.h
+++ b/tensorflow/core/kernels/conditional_accumulator_base_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
-#define TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
#define EIGEN_USE_THREADS
@@ -51,6 +51,8 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
&accumulator_handle_, nullptr));
OP_REQUIRES_OK(context, context->GetAttr("shape", &shape_));
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_));
+ OP_REQUIRES_OK(context,
+ context->GetAttr("reduction_type", &reduction_type_));
}
void Compute(OpKernelContext* ctx) override {
@@ -81,6 +83,7 @@ class ConditionalAccumulatorBaseOp : public OpKernel {
DataType dtype_;
PartialTensorShape shape_;
ContainerInfo cinfo_;
+ string reduction_type_;
private:
Status SetAccumulatorHandle(OpKernelContext* ctx)
@@ -234,4 +237,4 @@ class ConditionalAccumulatorBaseTakeGradientOp
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONDITIONAL_ACCUMULATOR_BASE_OP_H_
diff --git a/tensorflow/core/kernels/conditional_accumulator_op.cc b/tensorflow/core/kernels/conditional_accumulator_op.cc
index e13bf8a4c6..52ac51a9b6 100644
--- a/tensorflow/core/kernels/conditional_accumulator_op.cc
+++ b/tensorflow/core/kernels/conditional_accumulator_op.cc
@@ -34,7 +34,8 @@ class ConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
Creator GetCreator() const override {
return [this](ConditionalAccumulatorBase** ret) {
ConditionalAccumulator<Device, T>* accumulator =
- new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name());
+ new ConditionalAccumulator<Device, T>(dtype_, shape_, cinfo_.name(),
+ reduction_type_);
*ret = accumulator;
return Status::OK();
};
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index a888422d49..426c404f43 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -140,44 +140,6 @@ REGISTER_SYCL_KERNEL(SYCL, bool);
#undef REGISTER_SYCL_KERNEL
#endif
-HostConstantOp::HostConstantOp(OpKernelConstruction* ctx)
- : OpKernel(ctx), tensor_(ctx->output_type(0)) {
- const TensorProto* proto = nullptr;
- AllocatorAttributes alloc_attr;
- alloc_attr.set_on_host(true);
- OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
- OP_REQUIRES_OK(
- ctx, ctx->device()->MakeTensorFromProto(*proto, alloc_attr, &tensor_));
- OP_REQUIRES(
- ctx, ctx->output_type(0) == tensor_.dtype(),
- errors::InvalidArgument("Type mismatch between value (",
- DataTypeString(tensor_.dtype()), ") and dtype (",
- DataTypeString(ctx->output_type(0)), ")"));
-}
-
-void HostConstantOp::Compute(OpKernelContext* ctx) {
- ctx->set_output(0, tensor_);
-}
-
-#if GOOGLE_CUDA
-// A special GPU kernel for int32.
-// TODO(b/25387198): Also enable int32 in device memory. This kernel
-// registration requires all int32 inputs and outputs to be in host memory.
-REGISTER_KERNEL_BUILDER(Name("Const")
- .Device(DEVICE_GPU)
- .HostMemory("output")
- .TypeConstraint<int32>("dtype"),
- HostConstantOp);
-#endif
-
-#ifdef TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("Const")
- .Device(DEVICE_SYCL)
- .HostMemory("output")
- .TypeConstraint<int32>("dtype"),
- HostConstantOp);
-#endif // TENSORFLOW_USE_SYCL
-
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
#ifdef TENSORFLOW_USE_SYCL
@@ -297,8 +259,9 @@ class ZerosLikeOp : public OpKernel {
errors::InvalidArgument("ZerosLike non-scalar Tensor with "
"dtype=DT_VARIANT is not supported."));
const Variant& v = input.scalar<Variant>()();
- Tensor out(ctx->device()->GetAllocator(AllocatorAttributes()), DT_VARIANT,
- TensorShape({}));
+ // DT_VARIANT tensors must be allocated on CPU since they wrap C++
+ // objects which cannot be efficiently represented in GPU memory.
+ Tensor out(cpu_allocator(), DT_VARIANT, TensorShape({}));
Variant* out_v = &(out.scalar<Variant>()());
OP_REQUIRES_OK(ctx, UnaryOpVariant<Device>(
ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, out_v));
diff --git a/tensorflow/core/kernels/constant_op.h b/tensorflow/core/kernels/constant_op.h
index b98153e347..77ba441863 100644
--- a/tensorflow/core/kernels/constant_op.h
+++ b/tensorflow/core/kernels/constant_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONSTANT_OP_H_
-#define TENSORFLOW_KERNELS_CONSTANT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONSTANT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_CONSTANT_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
@@ -36,20 +36,6 @@ class ConstantOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(ConstantOp);
};
-// HostConstantOp differs from ConstantOp in that its output is always
-// in host memory.
-class HostConstantOp : public OpKernel {
- public:
- explicit HostConstantOp(OpKernelConstruction* ctx);
- void Compute(OpKernelContext* ctx) override;
- bool IsExpensive() override { return false; }
- ~HostConstantOp() override {}
-
- private:
- Tensor tensor_;
- TF_DISALLOW_COPY_AND_ASSIGN(HostConstantOp);
-};
-
class PlaceholderOp : public OpKernel {
public:
explicit PlaceholderOp(OpKernelConstruction* ctx);
@@ -61,4 +47,4 @@ class PlaceholderOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONSTANT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONSTANT_OP_H_
diff --git a/tensorflow/core/kernels/control_flow_ops.h b/tensorflow/core/kernels/control_flow_ops.h
index 8edbcc9077..c607fcf298 100644
--- a/tensorflow/core/kernels/control_flow_ops.h
+++ b/tensorflow/core/kernels/control_flow_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_
-#define TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONTROL_FLOW_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_CONTROL_FLOW_OPS_H_
#include "tensorflow/core/framework/op_kernel.h"
@@ -115,4 +115,4 @@ class LoopCondOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONTROL_FLOW_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONTROL_FLOW_OPS_H_
diff --git a/tensorflow/core/kernels/conv_2d.h b/tensorflow/core/kernels/conv_2d.h
index 6b7544fd4c..639c3062cc 100644
--- a/tensorflow/core/kernels/conv_2d.h
+++ b/tensorflow/core/kernels/conv_2d.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONV_2D_H_
-#define TENSORFLOW_KERNELS_CONV_2D_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONV_2D_H_
+#define TENSORFLOW_CORE_KERNELS_CONV_2D_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -137,17 +137,16 @@ struct MatMulConvFunctor {
}
};
-// Shuffles a filter tensor from:
-// [<spatial_dims>, in, out]
-// to:
-// [out, in, <spatial_dims>]
+// Shuffles a filter tensor from TensorFlow format HWIO to dst_filter_format.
+//
+// Note: Currently OIHW is the only supported destination format. Support for
+// OHWI format will be added in a follow-up change.
template <typename Device, typename T, typename IndexType, int NDIMS>
struct TransformFilter {
- void operator()(const Device& d,
+ void operator()(const Device& d, FilterTensorFormat dst_filter_format,
typename TTypes<T, NDIMS, IndexType>::ConstTensor in,
typename TTypes<T, NDIMS, IndexType>::Tensor out) {
- // We want a 3, 2, 0, 1 shuffle. Merge the spatial dimensions together
- // to speed up the shuffle operation.
+ // Merge the spatial dimensions together to speed up the shuffle operation.
Eigen::DSizes<IndexType, 3> merged_dims;
merged_dims[0] = in.dimension(0); // spatial dimensions
for (int i = 1; i < NDIMS - 2; ++i) {
@@ -156,16 +155,30 @@ struct TransformFilter {
merged_dims[1] = in.dimension(NDIMS - 2); // input filters
merged_dims[2] = in.dimension(NDIMS - 1); // output filters
+ CHECK(dst_filter_format == FORMAT_OIHW)
+ << "Unsupported destination filter format: "
+ << ToString(dst_filter_format);
+ // Source filter format is FORMAT_HWIO and spatial dimensions HW are merged
+ // in the beginning.
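+ // For example, for NDIMS == 4 the HWIO filter is viewed as [HW, I, O],
+ // shuffled with (2, 1, 0) to [O, I, HW], and then reshaped to OIHW.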
+ Eigen::DSizes<IndexType, 3> shuffling_perm =
+ Eigen::DSizes<IndexType, 3>(2, 1, 0);
+
Eigen::DSizes<IndexType, NDIMS> expanded_dims;
- expanded_dims[0] = in.dimension(NDIMS - 1); // output filters
- expanded_dims[1] = in.dimension(NDIMS - 2); // input filters
- for (int i = 0; i < NDIMS - 2; ++i) { // spatial dimensions
- expanded_dims[i + 2] = in.dimension(i);
+ int out_index = 0;
+ for (int merged_dim = 0; merged_dim < merged_dims.rank(); ++merged_dim) {
+ if (shuffling_perm[merged_dim] == 0) {
+ for (int spatial_dim = 0; spatial_dim < NDIMS - 2; ++spatial_dim) {
+ expanded_dims[out_index++] = in.dimension(spatial_dim);
+ }
+ } else {
+ constexpr int kLastSpatialDim = NDIMS - 3;
+ expanded_dims[out_index++] =
+ in.dimension(kLastSpatialDim + shuffling_perm[merged_dim]);
+ }
}
- out.device(d) = in.reshape(merged_dims)
- .shuffle(Eigen::DSizes<IndexType, 3>(2, 1, 0))
- .reshape(expanded_dims);
+ out.device(d) =
+ in.reshape(merged_dims).shuffle(shuffling_perm).reshape(expanded_dims);
}
};
@@ -282,7 +295,9 @@ struct SwapDimension0And2InTensor3 {
const gtl::ArraySlice<int64>& input_dims, T* out);
};
-// Reverses the effect of TransformFilter above.
+// Transforms the filter back from OIHW to HWOI format to reverse the effect of
+// TransformFilter above.
+// TODO(hinsu): Support reverse transformation from filter format OHWI as well.
template <typename Device, typename T, int NDIMS>
struct ReverseTransformFilter {
void operator()(const Device& d, typename TTypes<T, NDIMS>::ConstTensor in,
@@ -298,4 +313,4 @@ template <>
class ConvAlgorithmMap<Eigen::ThreadPoolDevice> {};
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONV_2D_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONV_2D_H_
diff --git a/tensorflow/core/kernels/conv_3d.h b/tensorflow/core/kernels/conv_3d.h
index 083dec63cc..b819c6f910 100644
--- a/tensorflow/core/kernels/conv_3d.h
+++ b/tensorflow/core/kernels/conv_3d.h
@@ -15,10 +15,11 @@ limitations under the License.
// Functors for 3d convolution.
-#ifndef TENSORFLOW_KERNELS_CONV_3D_H_
-#define TENSORFLOW_KERNELS_CONV_3D_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONV_3D_H_
+#define TENSORFLOW_CORE_KERNELS_CONV_3D_H_
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h"
#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
namespace tensorflow {
@@ -28,6 +29,14 @@ namespace functor {
template <typename Device, typename T>
struct CuboidConvolution;
+// Backward input pass for the cuboid convolution.
+template <typename Device, typename T>
+struct CuboidConvolutionBackwardInput;
+
+// Backward filter pass for the cuboid convolution.
+template <typename Device, typename T>
+struct CuboidConvolutionBackwardFilter;
+
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename T>
@@ -42,7 +51,41 @@ struct CuboidConvolution<CPUDevice, T> {
}
};
+template <typename T>
+struct CuboidConvolutionBackwardInput<CPUDevice, T> {
+ void operator()(const CPUDevice& d,
+ typename TTypes<T, 5>::Tensor input_backward,
+ typename TTypes<T, 5>::ConstTensor filter,
+ typename TTypes<T, 5>::ConstTensor output_backward,
+ int stride_planes, int stride_rows, int stride_cols) {
+ // Need to swap the order of plane/row/col strides when calling Eigen.
+ input_backward.device(d) = Eigen::CuboidConvolutionBackwardInput(
+ filter, output_backward,
+ input_backward.dimension(3), // input_planes
+ input_backward.dimension(2), // input_rows
+ input_backward.dimension(1), // input_cols
+ stride_cols, stride_rows, stride_planes);
+ }
+};
+
+template <typename T>
+struct CuboidConvolutionBackwardFilter<CPUDevice, T> {
+ void operator()(const CPUDevice& d,
+ typename TTypes<T, 5>::Tensor filter_backward,
+ typename TTypes<T, 5>::ConstTensor input,
+ typename TTypes<T, 5>::ConstTensor output_backward,
+ int stride_planes, int stride_rows, int stride_cols) {
+ // Need to swap the order of plane/row/col strides when calling Eigen.
+ filter_backward.device(d) = Eigen::CuboidConvolutionBackwardKernel(
+ input, output_backward,
+ filter_backward.dimension(2), // kernel_planes
+ filter_backward.dimension(1), // kernel_rows
+ filter_backward.dimension(0), // kernel_cols
+ stride_cols, stride_rows, stride_planes);
+ }
+};
+
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONV_3D_H_
+#endif // TENSORFLOW_CORE_KERNELS_CONV_3D_H_
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 63b1bcda43..9e86a16b66 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -1018,7 +1018,8 @@ namespace functor {
extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index d664a11e73..43bb5ea56c 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -901,7 +901,8 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
&transformed_filter));
functor::TransformFilter<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+ ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));
Tensor transformed_out_backprop;
@@ -1090,7 +1091,8 @@ namespace functor {
extern template struct InflatePadAndShuffle<GPUDevice, T, 4, int>; \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index 5bf709af08..507720c998 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -41,6 +41,17 @@ limitations under the License.
namespace tensorflow {
+// Compute padding for the given spatial dimension.
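+// For example, with SAME padding, unit stride and dilation, filter_size 3 and
+// input_size == output_size == 5, the total padding is
+// max(0, (5 - 1) * 1 + (3 - 1) * 1 + 1 - 5) = 2.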
+int ConvBackpropDimensions::SpatialPadding(const Padding& padding,
+ int dim) const {
+ return (padding == VALID)
+ ? 0
+ : std::max<int>(
+ 0, static_cast<int>((output_size(dim) - 1) * stride(dim) +
+ (filter_size(dim) - 1) * dilation(dim) +
+ 1 - input_size(dim)));
+}
+
// The V2 version computes windowed output size with arbitrary dilation_rate,
// while the original version only handles the cases where dilation_rates equal
// to 1.
@@ -63,7 +74,7 @@ Status ConvBackpropExtractAndVerifyDimensionV2(
return errors::InvalidArgument(
label, ": Size of out_backprop doesn't match computed: ", "actual = ",
dim->output_size, ", computed = ", out_size,
- "spatial_dim: ", spatial_dim, " input: ", dim->input_size,
+ " spatial_dim: ", spatial_dim, " input: ", dim->input_size,
" filter: ", dim->filter_size, " output: ", dim->output_size,
" stride: ", dim->stride, " dilation: ", dim->dilation);
}
diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h
index 535586d53a..9551959463 100644
--- a/tensorflow/core/kernels/conv_grad_ops.h
+++ b/tensorflow/core/kernels/conv_grad_ops.h
@@ -234,6 +234,16 @@ struct ConvBackpropDimensions {
// Input and output feature depth.
int64 in_depth, out_depth;
+
+ // Convenience access methods for spatial dimensions properties.
+ int64 input_size(int dim) const { return spatial_dims[dim].input_size; }
+ int64 filter_size(int dim) const { return spatial_dims[dim].filter_size; }
+ int64 output_size(int dim) const { return spatial_dims[dim].output_size; }
+ int64 stride(int dim) const { return spatial_dims[dim].stride; }
+ int64 dilation(int dim) const { return spatial_dims[dim].dilation; }
+
+ // Compute padding for the given spatial dimension.
+ int SpatialPadding(const Padding& padding, int dim) const;
};
// Common code between implementations of Conv?DBackpropInput and
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 15f1bf9aba..bab91f5e86 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_2d.h"
+#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -32,111 +33,130 @@ limitations under the License.
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/use_cudnn.h"
+#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
using stream_executor::dnn::DimIndex;
#endif
+namespace {
+
+// TODO(ezhulenev): Split this file into conv_grad_filter_ops_3d.cc and
+// conv_grad_input_ops_3d.cc.
+
+// TODO(ezhulenev): Generalize Col2im and Im2col for 2-d and 3-d kernels.
+
+// "Depth" is already used for the channel dimension, so for the third spatial
+// dimension in this file we use "plane", although in NDHWC layout it's
+// indicated with a "D".
+
+// Returns in 'im_data' (assumed to be zero-initialized) the image patch in storage
+// order (planes, height, width, depth), constructed from patches in 'col_data',
+// which is required to be in storage order (out_planes * out_height *
+// out_width, filter_planes, filter_height, filter_width, in_depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
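+//
+// With zero padding and unit strides the number of patches per dimension
+// reduces to (input_size - filter_size + 1), e.g. planes_col = planes - filter_p + 1.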
+template <typename T>
+void Col2im(const T* col_data, const int depth, const int planes,
+ const int height, const int width, const int filter_p,
+ const int filter_h, const int filter_w, const int pad_pt,
+ const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+ const int pad_r, const int stride_p, const int stride_h,
+ const int stride_w, T* im_data) {
+ const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+ const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+ int p_pad = -pad_pt;
+ for (int p = 0; p < planes_col; ++p) {
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ T* im_patch_data =
+ im_data + (p_pad * height * width + h_pad * width + w_pad) * depth;
+ for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+ iw < width) {
+ for (int i = 0; i < depth; ++i) {
+ im_patch_data[i] += col_data[i];
+ }
+ }
+ im_patch_data += depth;
+ col_data += depth;
+ }
+ // Jump over the remaining depth * (width - filter_w) entries in this row.
+ im_patch_data += depth * (width - filter_w);
+ }
+ // Jump over the remaining (depth * width) * (height - filter_h) entries
+ // in this plane.
+ im_patch_data += (depth * width) * (height - filter_h);
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+ p_pad += stride_p;
+ }
+}
+
+// Returns in 'col_data' image patches in storage order (planes, height, width,
+// depth) extracted from the image at 'input_data', which is required to be in
+// storage order (batch, planes, height, width, depth).
+//
+// Based on 2-dimensional implementation written by Yangqing Jia (jiayq).
+template <typename T>
+void Im2col(const T* input_data, const int depth, const int planes,
+ const int height, const int width, const int filter_p,
+ const int filter_h, const int filter_w, const int pad_pt,
+ const int pad_t, const int pad_l, const int pad_pb, const int pad_b,
+ const int pad_r, const int stride_p, const int stride_h,
+ const int stride_w, T* col_data) {
+ const int planes_col = (planes + pad_pt + pad_pb - filter_p) / stride_p + 1;
+ const int height_col = (height + pad_t + pad_b - filter_h) / stride_h + 1;
+ const int width_col = (width + pad_l + pad_r - filter_w) / stride_w + 1;
+
+ int p_pad = -pad_pt;
+ for (int p = 0; p < planes_col; ++p) {
+ int h_pad = -pad_t;
+ for (int h = 0; h < height_col; ++h) {
+ int w_pad = -pad_l;
+ for (int w = 0; w < width_col; ++w) {
+ for (int ip = p_pad; ip < p_pad + filter_p; ++ip) {
+ for (int ih = h_pad; ih < h_pad + filter_h; ++ih) {
+ for (int iw = w_pad; iw < w_pad + filter_w; ++iw) {
+ if (ip >= 0 && ip < planes && ih >= 0 && ih < height && iw >= 0 &&
+ iw < width) {
+ memcpy(col_data,
+ input_data +
+ (ip * height * width + ih * width + iw) * depth,
+ sizeof(T) * depth);
+ } else {
+ // Outside the input bounds, simply pad with zeros.
+ memset(col_data, 0, sizeof(T) * depth);
+ }
+ col_data += depth;
+ }
+ }
+ }
+ w_pad += stride_w;
+ }
+ h_pad += stride_h;
+ }
+ p_pad += stride_p;
+ }
+}
+
+} // namespace
+
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
-// TODO(mjanusz): Get rid of the macro and return shapes directly.
-#define EXTRACT_AND_VERIFY_DIMENSIONS(label) \
- const Tensor& out_backprop = context->input(2); \
- OP_REQUIRES( \
- context, input_shape.dims() == 5, \
- errors::InvalidArgument(label, ": input must be 5-dimensional")); \
- OP_REQUIRES( \
- context, filter_shape.dims() == 5, \
- errors::InvalidArgument(label, ": filter must be 5-dimensional")); \
- OP_REQUIRES( \
- context, out_backprop.dims() == 5, \
- errors::InvalidArgument(label, ": out_backprop must be 5-dimensional")); \
- const int64 batch = input_shape.dim_size(0); \
- OP_REQUIRES( \
- context, batch == out_backprop.dim_size(0), \
- errors::InvalidArgument( \
- label, ": input and out_backprop must have the same batch size")); \
- const std::array<int64, 3> input_size = { \
- {GetTensorDim(input_shape, data_format_, '0'), \
- GetTensorDim(input_shape, data_format_, '1'), \
- GetTensorDim(input_shape, data_format_, '2')}}; \
- const int64 in_depth = GetTensorDim(input_shape, data_format_, 'C'); \
- const std::array<int64, 3> filter_size = {{filter_shape.dim_size(0), \
- filter_shape.dim_size(1), \
- filter_shape.dim_size(2)}}; \
- const int64 output_cols = GetTensorDim(out_backprop, data_format_, '2'); \
- const int64 output_rows = GetTensorDim(out_backprop, data_format_, '1'); \
- const int64 output_planes = GetTensorDim(out_backprop, data_format_, '0'); \
- OP_REQUIRES(context, in_depth == filter_shape.dim_size(3), \
- errors::InvalidArgument( \
- label, ": input and filter must have the same depth")); \
- const int64 out_depth = filter_shape.dim_size(4); \
- OP_REQUIRES( \
- context, out_depth == GetTensorDim(out_backprop, data_format_, 'C'), \
- errors::InvalidArgument( \
- label, ": filter and out_backprop must have the same out_depth")); \
- const std::array<int64, 3> dilations = { \
- {GetTensorDim(dilation_, data_format_, '0'), \
- GetTensorDim(dilation_, data_format_, '1'), \
- GetTensorDim(dilation_, data_format_, '2')}}; \
- const std::array<int64, 3> strides = { \
- {GetTensorDim(stride_, data_format_, '0'), \
- GetTensorDim(stride_, data_format_, '1'), \
- GetTensorDim(stride_, data_format_, '2')}}; \
- std::array<int64, 3> out, padding; \
- OP_REQUIRES_OK( \
- context, Get3dOutputSizeV2(input_size, filter_size, dilations, strides, \
- padding_, &out, &padding)); \
- OP_REQUIRES(context, output_planes == out[0], \
- errors::InvalidArgument( \
- label, \
- ": Number of planes of out_backprop doesn't match " \
- "computed: actual = ", \
- output_planes, ", computed = ", out[0])); \
- OP_REQUIRES( \
- context, output_rows == out[1], \
- errors::InvalidArgument( \
- label, ": Number of rows of out_backprop doesn't match computed: ", \
- "actual = ", output_rows, ", computed = ", out[1])); \
- OP_REQUIRES( \
- context, output_cols == out[2], \
- errors::InvalidArgument( \
- label, ": Number of cols of out_backprop doesn't match computed: ", \
- "actual = ", output_cols, ", computed = ", out[2])); \
- const auto expanded_out_planes = (output_planes - 1) * strides[0] + 1; \
- const auto expanded_out_rows = (output_rows - 1) * strides[1] + 1; \
- const auto expanded_out_cols = (output_cols - 1) * strides[2] + 1; \
- const auto padded_out_planes = input_size[0] + filter_size[0] - 1; \
- const auto padded_out_rows = input_size[1] + filter_size[1] - 1; \
- const auto padded_out_cols = input_size[2] + filter_size[2] - 1; \
- const auto top_pad_planes = filter_size[0] - 1 - padding[0]; \
- const auto top_pad_rows = filter_size[1] - 1 - padding[1]; \
- const auto left_pad_cols = filter_size[2] - 1 - padding[2]; \
- const auto bottom_pad_planes = \
- padded_out_planes - expanded_out_planes - top_pad_planes; \
- const auto bottom_pad_rows = \
- padded_out_rows - expanded_out_rows - top_pad_rows; \
- const auto right_pad_cols = \
- padded_out_cols - expanded_out_cols - left_pad_cols; \
- VLOG(2) << "Conv3d: " << label \
- << ": expanded_out_planes = " << expanded_out_planes \
- << ": expanded_out_rows = " << expanded_out_rows \
- << ", expanded_out_cols = " << expanded_out_cols \
- << ", padded_out_planes = " << padded_out_planes \
- << ", padded_out_rows = " << padded_out_rows \
- << ", padded_out_cols = " << padded_out_cols \
- << ", top_pad_planes = " << top_pad_planes \
- << ", top_pad_rows = " << top_pad_rows \
- << ", left_pad_cols = " << left_pad_cols \
- << ", bottom_pad_planes = " << bottom_pad_planes \
- << ", bottom_pad_rows = " << bottom_pad_rows \
- << ", right_pad_cols = " << right_pad_cols
-
-// Backprop for input.
+// Backprop for input that offloads computation to
+// Eigen::CuboidConvolutionBackwardInput.
template <typename Device, class T>
class Conv3DBackpropInputOp : public OpKernel {
public:
@@ -192,6 +212,116 @@ class Conv3DBackpropInputOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& filter = context->input(1);
const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape input_shape;
+ if (takes_shape_) {
+ const Tensor& input_sizes = context->input(0);
+ // MakeShape is able to handle both DT_INT32 and DT_INT64 for input_sizes.
+ OP_REQUIRES_OK(context, MakeShape(input_sizes, &input_shape));
+ } else {
+ input_shape = context->input(0).shape();
+ }
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape,
+ stride_, padding_, data_format_, &dims));
+
+ Tensor* in_backprop;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input_shape, &in_backprop));
+
+ functor::CuboidConvolutionBackwardInput<Device, T>()(
+ context->eigen_device<Device>(),
+ in_backprop->tensor<T, 5>(), // input_backward
+ filter.tensor<T, 5>(), // filter
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+ }
+
+ private:
+ std::vector<int32> dilation_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropInputOp);
+};
+
+// Custom backprop for input that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+template <typename Device, class T>
+class Conv3DCustomBackpropInputOp : public OpKernel {
+ // Limit the maximum size of allocated temporary buffer to
+ // kMaxTempAllocationOverhead times the size of the input tensors (input,
+ // filter, out_backprop). If the size of the temporary buffer exceeds this
+ // limit, fall back on the Eigen implementation.
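+ // For example, a col_buffer with more than roughly 25x as many elements as
+ // the input, filter and out_backprop tensors combined triggers the fallback.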
+ static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+ explicit Conv3DCustomBackpropInputOp(OpKernelConstruction* context)
+ : OpKernel(context),
+ data_format_(FORMAT_NHWC),
+ takes_shape_(type_string().find("V2") != std::string::npos) {
+ // data_format is only available in V2.
+ if (takes_shape_) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Conv3DBackpropInputOpV2 only supports NDHWC on the CPU."));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+ OP_REQUIRES(context, dilation_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+ GetTensorDim(dilation_, data_format_, 'N') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilation rates in the batch and depth dimensions."));
+
+ // TODO(yangzihao): Add CPU version of dilated conv 3D.
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+ GetTensorDim(dilation_, data_format_, '1') == 1 &&
+ GetTensorDim(dilation_, data_format_, '2') == 1),
+ errors::InvalidArgument(
+ "Current CPU implementation does not yet support "
+ "dilation rates larger than 1."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+ GetTensorDim(stride_, data_format_, 'N') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& filter = context->input(1);
+ const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
@@ -200,51 +330,239 @@ class Conv3DBackpropInputOp : public OpKernel {
} else {
input_shape = context->input(0).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
- {0, 0},
- {top_pad_planes, bottom_pad_planes},
- {top_pad_rows, bottom_pad_rows},
- {left_pad_cols, right_pad_cols},
- {0, 0}};
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context, ConvBackpropComputeDimensions(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape,
+ stride_, padding_, data_format_, &dims));
+
Tensor* in_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
- // Fill out a padded out_backprop.
- TensorShape padded_out_shape({batch, padded_out_planes, padded_out_rows,
- padded_out_cols, out_depth});
- Tensor padded_output;
+ int64 top_pad_planes, bottom_pad_planes;
+ int64 top_pad_rows, bottom_pad_rows;
+ int64 left_pad_cols, right_pad_cols;
+
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[0].input_size,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[0].stride, padding_,
+ &dims.spatial_dims[0].output_size,
+ &top_pad_planes, &bottom_pad_planes));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[1].input_size,
+ dims.spatial_dims[1].filter_size,
+ dims.spatial_dims[1].stride, padding_,
+ &dims.spatial_dims[1].output_size,
+ &top_pad_rows, &bottom_pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[2].input_size,
+ dims.spatial_dims[2].filter_size,
+ dims.spatial_dims[2].stride, padding_,
+ &dims.spatial_dims[2].output_size,
+ &left_pad_cols, &right_pad_cols));
+
+ // TODO(ezhulenev): Extract work size and shard estimation to shared
+ // functions in conv_grad_ops, and update 2d convolution backprop.
+
+ // The total dimension size of each kernel.
+ const int64 filter_total_size =
+ dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+ dims.spatial_dims[2].filter_size * dims.in_depth;
+
+ // The output image size is the spatial size of the output.
+ const int64 output_image_size = dims.spatial_dims[0].output_size *
+ dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size;
+
+ const auto cache_sizes = Eigen::internal::CacheSizes();
+ const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+ // Use L3 cache size as target working set size.
+ const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+ // Calculate size of matrices involved in MatMul: C = A x B.
+ const int64 size_A = output_image_size * dims.out_depth;
+
+ const int64 size_B = filter_total_size * dims.out_depth;
+
+ const int64 size_C = output_image_size * filter_total_size;
+
+ const int64 work_unit_size = size_A + size_B + size_C;
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+ // Use parallel tensor contractions if there is no batching.
+ //
+ // Compared to Conv2D code, this version is missing work size estimation. In
+ // benchmarks no case was found where running a parallel contraction is
+ // beneficial compared to sharding and matmuls.
+ const bool use_parallel_contraction = dims.batch_size == 1;
+
+ const size_t shard_size =
+ use_parallel_contraction
+ ? 1
+ : (target_working_set_size + work_unit_size - 1) / work_unit_size;
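+ // shard_size is the ceiling of target_working_set_size / work_unit_size,
+ // i.e. enough images per shard that their combined matmul operands roughly
+ // fill the L3-sized target working set.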
+
+ // Total number of elements in all the tensors used by this kernel.
+ int64 total_tensor_elements = input_shape.num_elements() +
+ filter_shape.num_elements() +
+ out_backprop_shape.num_elements();
+
+ // Shape of the temporary workspace buffer.
+ TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+ static_cast<int64>(output_image_size),
+ static_cast<int64>(filter_total_size)};
+ int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+ // If the temporary allocation overhead is too large, fall back on the Eigen
+ // implementation, which requires much less memory.
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+ if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+ VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropInputOp: "
+ "col_buffer_overhead="
+ << col_buffer_overhead;
+
+ functor::CuboidConvolutionBackwardInput<Device, T>()(
+ context->eigen_device<Device>(),
+ in_backprop->tensor<T, 5>(), // input_backward
+ filter.tensor<T, 5>(), // filter
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+
+ return;
+ }
+
+ Tensor col_buffer;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- padded_out_shape, &padded_output));
- Eigen::DSizes<Eigen::DenseIndex, 5> no_op_shuffle{0, 1, 2, 3, 4};
- Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
- strides[2], 1};
- functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
- eigen_strides, pad_dims, no_op_shuffle, padded_output.tensor<T, 5>());
- const Tensor& padded_output_cref = padded_output;
-
- // Fill a new "reverted" filter. We need to transpose the in_depth and
- // out_depth for the filter and reverse the planes, rows and cols.
- TensorShape r_filter_shape(
- {filter_size[0], filter_size[1], filter_size[2], out_depth, in_depth});
- Tensor r_filter;
- OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
- r_filter_shape, &r_filter));
- Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{0, 1, 2, 4, 3};
- Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter.tensor<T, 5>(), filter_order,
- filter_rev_dims, r_filter.tensor<T, 5>());
- const Tensor& r_filter_cref = r_filter;
-
- // Now we can call conv_3d directly.
- functor::CuboidConvolution<Device, T>()(
- context->eigen_device<Device>(), in_backprop->tensor<T, 5>(),
- padded_output_cref.tensor<T, 5>(), r_filter_cref.tensor<T, 5>(), 1, 1,
- 1, BrainPadding2EigenPadding(VALID));
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ col_buffer_shape, &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int64 input_offset = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size *
+ dims.spatial_dims[2].input_size * dims.in_depth;
+
+ // The output offset corresponding to a single output image.
+ const int64 output_offset =
+ dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size * dims.out_depth;
+
+ const T* filter_data = filter.template flat<T>().data();
+ T* col_buffer_data = col_buffer.template flat<T>().data();
+ const T* out_backprop_data = out_backprop.template flat<T>().data();
+
+ auto in_backprop_flat = in_backprop->template flat<T>();
+ T* input_backprop_data = in_backprop_flat.data();
+ in_backprop_flat.device(context->eigen_device<Device>()) =
+ in_backprop_flat.constant(T(0));
+
+ if (use_parallel_contraction) {
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ TensorMap;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ ConstTensorMap;
+
+ // Initialize contraction dims (we need to transpose 'B' below).
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+ contract_dims[0].first = 1;
+ contract_dims[0].second = 1;
+
+ for (int image_id = 0; image_id < dims.batch_size; ++image_id) {
+ // Compute gradient into col_buffer.
+ TensorMap C(col_buffer_data, output_image_size, filter_total_size);
+
+ ConstTensorMap A(out_backprop_data + output_offset * image_id,
+ output_image_size, dims.out_depth);
+ ConstTensorMap B(filter_data, filter_total_size, dims.out_depth);
+
+ C.device(context->eigen_cpu_device()) = A.contract(B, contract_dims);
+
+ Col2im<T>(col_buffer_data, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ input_backprop_data);
+
+ input_backprop_data += input_offset;
+ }
+ } else {
+ typedef Eigen::Map<
+ Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>>
+ MatrixMap;
+ typedef Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic,
+ Eigen::RowMajor>>
+ ConstMatrixMap;
+
+ for (int image_id = 0; image_id < dims.batch_size;
+ image_id += shard_size) {
+ const int shard_limit =
+ std::min(static_cast<int>(shard_size),
+ static_cast<int>(dims.batch_size) - image_id);
+
+ auto shard = [&dims, &top_pad_planes, &top_pad_rows, &left_pad_cols,
+ &bottom_pad_planes, &bottom_pad_rows, &right_pad_cols,
+ &output_image_size, &filter_total_size,
+ &input_backprop_data, &col_buffer_data,
+ &out_backprop_data, &filter_data, &input_offset,
+ &output_offset, &size_C](int64 start, int64 limit) {
+ for (int shard_id = start; shard_id < limit; ++shard_id) {
+ T* im2col_buf = col_buffer_data + shard_id * size_C;
+ T* input_data = input_backprop_data + shard_id * input_offset;
+ const T* out_data = out_backprop_data + shard_id * output_offset;
+
+ // Compute gradient into 'im2col_buf'.
+ MatrixMap C(im2col_buf, output_image_size, filter_total_size);
+
+ ConstMatrixMap A(out_data, output_image_size, dims.out_depth);
+ ConstMatrixMap B(filter_data, filter_total_size, dims.out_depth);
+
+ C.noalias() = A * B.transpose();
+
+ Col2im<T>(im2col_buf, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ input_data);
+ }
+ };
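+      // For reference: Shard(max_parallelism, workers, total, cost_per_unit,
+      // fn) from util/work_sharder.h partitions [0, total) across the worker
+      // threads and calls fn(start, limit) on each piece; here total is
+      // shard_limit (images in this group) and work_unit_size is a cost hint
+      // that only controls how finely the range is split.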
+ Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+ work_unit_size, shard);
+
+ input_backprop_data += input_offset * shard_limit;
+ out_backprop_data += output_offset * shard_limit;
+ }
+ }
}
private:
@@ -253,21 +571,48 @@ class Conv3DBackpropInputOp : public OpKernel {
Padding padding_;
TensorFormat data_format_;
bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropInputOp);
};
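The contraction dims {1, 1} used in the parallel branch above compute the same product as the sharded branch's C.noalias() = A * B.transpose(). A minimal standalone Eigen sketch of that equivalence, with invented sizes (nothing below is part of the kernel):

#include <unsupported/Eigen/CXX11/Tensor>

int main() {
  // A: [output_image_size, out_depth], B: [filter_total_size, out_depth].
  Eigen::Tensor<float, 2, Eigen::RowMajor> A(4, 3), B(5, 3);
  A.setRandom();
  B.setRandom();
  // Contracting A's dim 1 with B's dim 1 yields C = A * B^T, shape [4, 5].
  Eigen::array<Eigen::IndexPair<int>, 1> contract_dims = {
      Eigen::IndexPair<int>(1, 1)};
  Eigen::Tensor<float, 2, Eigen::RowMajor> C = A.contract(B, contract_dims);
  return (C.dimension(0) == 4 && C.dimension(1) == 5) ? 0 : 1;
}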
+// The custom backprop input kernel is 30%-4x faster when compiled with AVX2
+// than the default Eigen implementation (at the cost of ~2x-8x peak memory
+// usage).
+
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInput").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropInputOp<CPUDevice, T>); \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropInputV2").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropInputOp<CPUDevice, T>);
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropInputOp<CPUDevice, T>);
+
TF_CALL_half(REGISTER_CPU_KERNEL);
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
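The Label()-based registrations above make the implementations selectable per node. To my understanding the label is matched against the node's "_kernel" attribute, so forcing the Eigen variant could look roughly like the sketch below (the node name, input names, and the attribute assumption are invented for illustration):

#include <vector>

#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"

// Sketch: request the "eigen_tensor"-labeled kernel for an invented node,
// assuming the "_kernel" attribute is what kernel selection matches labels on.
tensorflow::Status MakeEigenConv3DGradInputNode(tensorflow::NodeDef* node) {
  std::vector<tensorflow::int32> strides = {1, 1, 1, 1, 1};
  return tensorflow::NodeDefBuilder("grad_input", "Conv3DBackpropInputV2")
      .Input("input_sizes", 0, tensorflow::DT_INT32)
      .Input("filter", 0, tensorflow::DT_FLOAT)
      .Input("out_backprop", 0, tensorflow::DT_FLOAT)
      .Attr("strides", strides)
      .Attr("padding", "SAME")
      .Attr("data_format", "NDHWC")
      .Attr("_kernel", "eigen_tensor")  // pick the labeled registration
      .Finalize(node);
}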
-// Backprop for filter.
+// Backprop for filter that offloads computation to
+// Eigen::CuboidConvolutionBackwardFilter.
template <typename Device, class T>
class Conv3DBackpropFilterOp : public OpKernel {
public:
@@ -323,8 +668,11 @@ class Conv3DBackpropFilterOp : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
- TensorShape filter_shape;
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape filter_shape;
if (takes_shape_) {
const Tensor& filter_sizes = context->input(1);
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
@@ -333,13 +681,13 @@ class Conv3DBackpropFilterOp : public OpKernel {
filter_shape = context->input(1).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
- Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 5> pad_dims{
- {0, 0},
- {top_pad_planes, bottom_pad_planes},
- {top_pad_rows, bottom_pad_rows},
- {left_pad_cols, right_pad_cols},
- {0, 0}};
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensions(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, stride_,
+ padding_, data_format_, &dims));
+
Tensor* filter_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, filter_shape, &filter_backprop));
@@ -349,70 +697,292 @@ class Conv3DBackpropFilterOp : public OpKernel {
return;
}
- // For the backprop of the filter, we need to also transpose the
- // out_backprop.
- // The shape of backprop is
- // [batch, out_z, out_y, out_x, out_depth]
- // And we need to change it to
- // [out_depth, out_x, out_y, out_z, batch]
- Eigen::DSizes<Eigen::DenseIndex, 5> out_order{4, 1, 2, 3, 0};
- TensorShape padded_out_shape({out_depth, padded_out_planes, padded_out_rows,
- padded_out_cols, batch});
- Tensor padded_output;
+ functor::CuboidConvolutionBackwardFilter<Device, T>()(
+ context->eigen_device<Device>(),
+ filter_backprop->tensor<T, 5>(), // filter_backward
+ input.tensor<T, 5>(), // input
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+ }
+
+ private:
+ std::vector<int32> dilation_;
+ std::vector<int32> stride_;
+ Padding padding_;
+ TensorFormat data_format_;
+ bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DBackpropFilterOp);
+};
+
+// Custom backprop for filter that explicitly does the work sharding and calls
+// Eigen only to multiply matrices.
+template <typename Device, class T>
+class Conv3DCustomBackpropFilterOp : public OpKernel {
+ // Limit the maximum size of allocated temporary buffer to
+ // kMaxTempAllocationOverhead times the size of the input tensors (input,
+ // filter, out_backprop). If the size of the temporary buffer exceeds this
+  // limit, fall back on the Eigen implementation.
+ static constexpr int kMaxTempAllocationOverhead = 25;
+
+ public:
+ explicit Conv3DCustomBackpropFilterOp(OpKernelConstruction* context)
+ : OpKernel(context),
+ data_format_(FORMAT_NHWC),
+ takes_shape_(type_string().find("V2") != std::string::npos) {
+ // data_format is only available in V2.
+ if (takes_shape_) {
+ string data_format;
+ OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
+ OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
+ errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES(
+ context, data_format_ == FORMAT_NHWC,
+ errors::InvalidArgument(
+ "Conv3DBackpropFilterOpV2 only supports NDHWC on the CPU."));
+ }
+
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_));
+ OP_REQUIRES(context, dilation_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, 'C') == 1 &&
+ GetTensorDim(dilation_, data_format_, 'N') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilation rates in the batch and depth dimensions."));
+
+ // TODO(yangzihao): Add CPU version of dilated conv 3D.
+ OP_REQUIRES(context,
+ (GetTensorDim(dilation_, data_format_, '0') == 1 &&
+ GetTensorDim(dilation_, data_format_, '1') == 1 &&
+ GetTensorDim(dilation_, data_format_, '2') == 1),
+ errors::InvalidArgument(
+ "Current CPU implementation does not yet support "
+ "dilation rates larger than 1."));
+
+ OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_));
+ OP_REQUIRES(context, stride_.size() == 5,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(stride_, data_format_, 'C') == 1 &&
+ GetTensorDim(stride_, data_format_, 'N') == 1),
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+ const TensorShape& input_shape = input.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
+ TensorShape filter_shape;
+ if (takes_shape_) {
+ const Tensor& filter_sizes = context->input(1);
+ OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
+ filter_sizes.vec<int32>(), &filter_shape));
+ } else {
+ filter_shape = context->input(1).shape();
+ }
+
+ ConvBackpropDimensions dims;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- padded_out_shape, &padded_output));
- Eigen::DSizes<Eigen::DenseIndex, 5> eigen_strides{1, strides[0], strides[1],
- strides[2], 1};
- functor::InflatePadAndShuffle<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), out_backprop.tensor<T, 5>(),
- eigen_strides, pad_dims, out_order, padded_output.tensor<T, 5>());
- const Tensor& padded_output_cref = padded_output;
-
- // For the backprop of the filter, we need to transpose the input.
- // The shape of input is
- // [batch, in_z, in_y, in_x, in_depth]
- // And we need to change it to
- // [in_z, in_y, in_x, batch, in_depth]
- Eigen::DSizes<Eigen::DenseIndex, 5> in_order{1, 2, 3, 0, 4};
- TensorShape in_shuffle_shape(
- {input_size[0], input_size[1], input_size[2], batch, in_depth});
- Tensor in_shuffle;
+ ConvBackpropComputeDimensions(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, stride_,
+ padding_, data_format_, &dims));
+
+ Tensor* filter_backprop;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::v(),
- in_shuffle_shape, &in_shuffle));
- // No need for reversing this time.
- Eigen::array<bool, 5> no_reverse{false, false, false, false, false};
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), input.tensor<T, 5>(), in_order,
- no_reverse, in_shuffle.tensor<T, 5>());
- const Tensor& in_shuffle_cref = in_shuffle;
-
- // The output of the conv_3d would be
- // [out_depth, filter_size[2], filter_size[1], filter_size[0], in_depth]
- // and we need to shuffle it back to
- // [filter_size[2], filter_size[1], filter_size[0], in_depth, out_depth];
- // And we need to reverse the filter backprops.
- // So we need to allocate (sigh) yet another piece of memory to hold the
- // output.
- TensorShape filter_shuffle_shape(
- {out_depth, filter_size[0], filter_size[1], filter_size[2], in_depth});
- Tensor filter_shuffle;
- OP_REQUIRES_OK(
- context, context->allocate_temp(DataTypeToEnum<T>::v(),
- filter_shuffle_shape, &filter_shuffle));
- functor::CuboidConvolution<Device, T>()(
- context->eigen_device<Device>(), filter_shuffle.tensor<T, 5>(),
- padded_output_cref.tensor<T, 5>(), in_shuffle_cref.tensor<T, 5>(), 1, 1,
- 1, BrainPadding2EigenPadding(VALID));
-
- // Now copy the filter_backprop back to the destination.
- Eigen::DSizes<Eigen::DenseIndex, 5> filter_order{1, 2, 3, 4, 0};
- Eigen::array<bool, 5> filter_rev_dims{true, true, true, false, false};
- const Tensor& filter_shuffle_cref = filter_shuffle;
- functor::ShuffleAndReverse<Device, T, 5, Eigen::DenseIndex>()(
- context->eigen_device<Device>(), filter_shuffle_cref.tensor<T, 5>(),
- filter_order, filter_rev_dims, filter_backprop->tensor<T, 5>());
+ context->allocate_output(0, filter_shape, &filter_backprop));
+
+ if (input_shape.num_elements() == 0) {
+ filter_backprop->template flat<T>().setZero();
+ return;
+ }
+
+ int64 top_pad_planes, bottom_pad_planes;
+ int64 top_pad_rows, bottom_pad_rows;
+ int64 left_pad_cols, right_pad_cols;
+
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[0].input_size,
+ dims.spatial_dims[0].filter_size,
+ dims.spatial_dims[0].stride, padding_,
+ &dims.spatial_dims[0].output_size,
+ &top_pad_planes, &bottom_pad_planes));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[1].input_size,
+ dims.spatial_dims[1].filter_size,
+ dims.spatial_dims[1].stride, padding_,
+ &dims.spatial_dims[1].output_size,
+ &top_pad_rows, &bottom_pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+ dims.spatial_dims[2].input_size,
+ dims.spatial_dims[2].filter_size,
+ dims.spatial_dims[2].stride, padding_,
+ &dims.spatial_dims[2].output_size,
+ &left_pad_cols, &right_pad_cols));
+
+ // TODO(ezhulenev): Extract work size and shard estimation to shared
+ // functions in conv_grad_ops, and update 2d convolution backprop.
+
+ // The total dimension size of each kernel.
+ const int64 filter_total_size =
+ dims.spatial_dims[0].filter_size * dims.spatial_dims[1].filter_size *
+ dims.spatial_dims[2].filter_size * dims.in_depth;
+ // The output image size is the spatial size of the output.
+ const int64 output_image_size = dims.spatial_dims[0].output_size *
+ dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size;
+
+ // Shard 'batch' images (volumes) into 'shard_size' groups of images
+ // (volumes) to be fed into the parallel matmul. Calculate 'shard_size' by
+ // dividing the L3 cache size ('target_working_set_size') by the matmul size
+ // of an individual image ('work_unit_size').
+
+ const auto cache_sizes = Eigen::internal::CacheSizes();
+ const ptrdiff_t l3_cache_size = cache_sizes.m_l3;
+
+ // TODO(andydavis)
+ // *) Consider reducing 'target_working_set_size' if L3 is shared by
+ // other concurrently running tensorflow ops.
+ const size_t target_working_set_size = l3_cache_size / sizeof(T);
+
+ const int64 size_A = output_image_size * filter_total_size;
+
+ const int64 size_B = output_image_size * dims.out_depth;
+
+ const int64 size_C = filter_total_size * dims.out_depth;
+
+ const int64 work_unit_size = size_A + size_B + size_C;
+
+ const size_t shard_size =
+ (target_working_set_size + work_unit_size - 1) / work_unit_size;
+
+ // Total number of elements in all the tensors used by this kernel.
+ int64 total_tensor_elements = input_shape.num_elements() +
+ filter_shape.num_elements() +
+ out_backprop_shape.num_elements();
+
+ // Shape of the temporary workspace buffer.
+ TensorShape col_buffer_shape = {static_cast<int64>(shard_size),
+ static_cast<int64>(output_image_size),
+ static_cast<int64>(filter_total_size)};
+ int64 col_buffer_elements = col_buffer_shape.num_elements();
+
+    // If the temporary allocation overhead is too large, fall back on the
+    // Eigen implementation, which requires much less memory.
+ int64 col_buffer_overhead = col_buffer_elements / total_tensor_elements;
+ if (col_buffer_overhead > kMaxTempAllocationOverhead) {
+ VLOG(2) << "Fallback on Eigen implementation of Conv3DBackpropFilterOp: "
+ "col_buffer_overhead="
+ << col_buffer_overhead;
+
+ functor::CuboidConvolutionBackwardFilter<Device, T>()(
+ context->eigen_device<Device>(),
+ filter_backprop->tensor<T, 5>(), // filter_backward
+ input.tensor<T, 5>(), // input
+ out_backprop.tensor<T, 5>(), // output_backward
+ static_cast<int>(dims.spatial_dims[0].stride), // stride_planes
+ static_cast<int>(dims.spatial_dims[1].stride), // stride_rows
+ static_cast<int>(dims.spatial_dims[2].stride)); // stride_cols
+
+ return;
+ }
+
+ Tensor col_buffer;
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(DataTypeToEnum<T>::value,
+ col_buffer_shape, &col_buffer));
+
+ // The input offset corresponding to a single input image.
+ const int64 input_offset = dims.spatial_dims[0].input_size *
+ dims.spatial_dims[1].input_size *
+ dims.spatial_dims[2].input_size * dims.in_depth;
+ // The output offset corresponding to a single output image.
+ const int64 output_offset =
+ dims.spatial_dims[0].output_size * dims.spatial_dims[1].output_size *
+ dims.spatial_dims[2].output_size * dims.out_depth;
+
+ const T* input_data = input.template flat<T>().data();
+ T* col_buffer_data = col_buffer.template flat<T>().data();
+ const T* out_backprop_data = out_backprop.template flat<T>().data();
+ T* filter_backprop_data = filter_backprop->template flat<T>().data();
+
+ typedef Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ TensorMap;
+ typedef Eigen::TensorMap<Eigen::Tensor<const T, 2, Eigen::RowMajor>,
+ Eigen::Unaligned>
+ ConstTensorMap;
+
+ TensorMap C(filter_backprop_data, filter_total_size, dims.out_depth);
+ C.setZero();
+
+ // Initialize contraction dims (we need to transpose 'A' below).
+ Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_dims;
+ contract_dims[0].first = 0;
+ contract_dims[0].second = 0;
+
+ auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
+
+ for (int image_id = 0; image_id < dims.batch_size; image_id += shard_size) {
+ const int shard_limit =
+ std::min(static_cast<int>(shard_size),
+ static_cast<int>(dims.batch_size) - image_id);
+
+ auto shard = [&input_data, &col_buffer_data, &dims, &top_pad_planes,
+ &top_pad_rows, &left_pad_cols, &bottom_pad_planes,
+ &bottom_pad_rows, &right_pad_cols, &input_offset,
+ &size_A](int64 start, int64 limit) {
+ for (int shard_id = start; shard_id < limit; ++shard_id) {
+ const T* input_data_shard = input_data + shard_id * input_offset;
+ T* col_data_shard = col_buffer_data + shard_id * size_A;
+
+ // When we compute the gradient with respect to the filters, we need
+ // to do im2col to allow gemm-type computation.
+ Im2col<T>(input_data_shard, dims.in_depth,
+ // Input spatial dimensions.
+ dims.spatial_dims[0].input_size, // input planes
+ dims.spatial_dims[1].input_size, // input rows
+ dims.spatial_dims[2].input_size, // input cols
+ // Filter spatial dimensions.
+ dims.spatial_dims[0].filter_size, // filter planes
+ dims.spatial_dims[1].filter_size, // filter rows
+ dims.spatial_dims[2].filter_size, // filter cols
+ // Spatial padding.
+ top_pad_planes, top_pad_rows, left_pad_cols,
+ bottom_pad_planes, bottom_pad_rows, right_pad_cols,
+ // Spatial striding.
+ dims.spatial_dims[0].stride, // stride planes
+ dims.spatial_dims[1].stride, // stride rows
+ dims.spatial_dims[2].stride, // stride cols
+ col_data_shard);
+ }
+ };
+ Shard(worker_threads.num_threads, worker_threads.workers, shard_limit,
+ size_A, shard);
+
+ ConstTensorMap A(col_buffer_data, output_image_size * shard_limit,
+ filter_total_size);
+ ConstTensorMap B(out_backprop_data, output_image_size * shard_limit,
+ dims.out_depth);
+
+ // Gradient with respect to filter.
+ C.device(context->eigen_cpu_device()) += A.contract(B, contract_dims);
+
+ input_data += input_offset * shard_limit;
+ out_backprop_data += output_offset * shard_limit;
+ }
}
private:
@@ -421,21 +991,60 @@ class Conv3DBackpropFilterOp : public OpKernel {
Padding padding_;
TensorFormat data_format_;
bool takes_shape_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(Conv3DCustomBackpropFilterOp);
};
+// The custom backprop filter kernel is 30%-4x faster when compiled with AVX2
+// than the default Eigen implementation (at the cost of ~2x-8x peak memory
+// usage).
+
#define REGISTER_CPU_KERNEL(T) \
REGISTER_KERNEL_BUILDER( \
Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
- Conv3DBackpropFilterOp<CPUDevice, T>); \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .Label("custom") \
+ .TypeConstraint<T>("T"), \
+ Conv3DCustomBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>); \
REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
.Device(DEVICE_CPU) \
+ .Label("eigen_tensor") \
.TypeConstraint<T>("T"), \
Conv3DBackpropFilterOp<CPUDevice, T>);
-TF_CALL_half(REGISTER_CPU_KERNEL);
+
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
#undef REGISTER_CPU_KERNEL
+// WARNING: Eigen::half is not trivially copyable and can't be used in the
+// custom backprop filter kernel because of the memcpy and memset in Im2col.
+#define REGISTER_CPU_KERNEL(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Conv3DBackpropFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("Conv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T"), \
+ Conv3DBackpropFilterOp<CPUDevice, T>);
+
+TF_CALL_half(REGISTER_CPU_KERNEL);
+#undef REGISTER_CPU_KERNEL
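A compile-time guard along the following lines would make the constraint from the warning explicit; this is only a sketch of the idea, and SafeForRawMemoryIm2col is an invented name:

#include <type_traits>

// Raw memcpy/memset in Im2col is only well-defined for trivially copyable
// element types; float and double qualify, while the warning above says
// Eigen::half does not, hence the Eigen-only registration for half.
template <typename T>
constexpr bool SafeForRawMemoryIm2col() {
  return std::is_trivially_copyable<T>::value;
}

static_assert(SafeForRawMemoryIm2col<float>(), "float is memcpy-safe");
static_assert(SafeForRawMemoryIm2col<double>(), "double is memcpy-safe");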
+
// GPU definitions of both ops.
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
@@ -445,7 +1054,8 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 5>::operator()( \
- const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 5, int>::ConstTensor in, \
typename TTypes<T, 5, int>::Tensor out); \
template <> \
void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \
@@ -523,6 +1133,10 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& filter = context->input(1);
const TensorShape& filter_shape = filter.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape input_shape;
if (takes_shape_) {
const Tensor& input_sizes = context->input(0);
@@ -531,7 +1145,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
} else {
input_shape = context->input(0).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropInput");
+
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensionsV2(
+ "Conv3DBackpropInputOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, dilation_,
+ stride_, padding_, data_format_, &dims));
+
Tensor* in_backprop;
OP_REQUIRES_OK(context,
context->allocate_output(0, input_shape, &in_backprop));
@@ -539,13 +1160,15 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- if (filter_size[0] == 1 && filter_size[1] == 1 && filter_size[2] == 1 &&
- dilation_[0] == 1 && dilation_[1] == 1 && dilation_[2] == 1 &&
- stride_[0] == 1 && stride_[1] == 1 && stride_[2] == 1 &&
+ if (dims.filter_size(0) == 1 && dims.filter_size(1) == 1 &&
+ dims.filter_size(2) == 1 && dims.dilation(0) == 1 &&
+ dims.dilation(1) == 1 && dims.dilation(2) == 1 && dims.stride(0) == 1 &&
+ dims.stride(1) == 1 && dims.stride(2) == 1 &&
data_format_ == FORMAT_NHWC) {
- const uint64 m = batch * input_size[0] * input_size[1] * input_size[2];
- const uint64 k = out_depth;
- const uint64 n = in_depth;
+ const uint64 m = dims.batch_size * dims.input_size(0) *
+ dims.input_size(1) * dims.input_size(2);
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.in_depth;
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
@@ -567,13 +1190,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
", n=", n, ", k=", k));
}
return;
- } else if (filter_size[0] == input_size[0] &&
- filter_size[1] == input_size[1] &&
- filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
- data_format_ == FORMAT_NHWC) {
- const uint64 m = batch;
- const uint64 k = out_depth;
- const uint64 n = input_size[0] * input_size[1] * input_size[2] * in_depth;
+ } else if (dims.filter_size(0) == dims.input_size(0) &&
+ dims.filter_size(1) == dims.input_size(1) &&
+ dims.filter_size(2) == dims.input_size(2) &&
+ padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+ const uint64 m = dims.batch_size;
+ const uint64 k = dims.out_depth;
+ const uint64 n = dims.input_size(0) * dims.input_size(1) *
+ dims.input_size(2) * dims.in_depth;
auto a_ptr = AsDeviceMemory(out_backprop.template flat<T>().data(),
out_backprop.template flat<T>().size());
@@ -597,65 +1221,59 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
return;
}
- int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
- if (padding_ == Padding::SAME) {
- padding_planes = std::max<int>(
- 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
- padding_cols = std::max<int>(
- 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
- padding_rows = std::max<int>(
- 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
- }
+ int padding_planes = dims.SpatialPadding(padding_, 0);
+ int padding_rows = dims.SpatialPadding(padding_, 1);
+ int padding_cols = dims.SpatialPadding(padding_, 2);
+ const bool planes_odd = (padding_planes % 2 != 0);
const bool rows_odd = (padding_rows % 2 != 0);
const bool cols_odd = (padding_cols % 2 != 0);
- const bool planes_odd = (padding_planes % 2 != 0);
TensorShape compatible_input_shape;
if (rows_odd || cols_odd || planes_odd) {
// cuDNN only supports the same amount of padding on both sides.
compatible_input_shape = {
- batch,
- in_depth,
- input_size[0] + planes_odd,
- input_size[1] + rows_odd,
- input_size[2] + cols_odd,
+ dims.batch_size,
+ dims.in_depth,
+ dims.input_size(0) + planes_odd,
+ dims.input_size(1) + rows_odd,
+ dims.input_size(2) + cols_odd,
};
} else {
- compatible_input_shape = {batch, in_depth, input_size[0], input_size[1],
- input_size[2]};
+ compatible_input_shape = {dims.batch_size, dims.in_depth,
+ dims.input_size(0), dims.input_size(1),
+ dims.input_size(2)};
}
CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
se::dnn::BatchDescriptor input_desc(3);
- input_desc.set_count(batch)
+ input_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
.set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
.set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
- .set_feature_map_count(in_depth)
+ .set_feature_map_count(dims.in_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::BatchDescriptor output_desc(3);
- output_desc.set_count(batch)
- .set_spatial_dim(DimIndex::X, output_cols)
- .set_spatial_dim(DimIndex::Y, output_rows)
- .set_spatial_dim(DimIndex::Z, output_planes)
- .set_feature_map_count(out_depth)
+ output_desc.set_count(dims.batch_size)
+ .set_spatial_dim(DimIndex::X, dims.output_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+ .set_feature_map_count(dims.out_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc(3);
- filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
- .set_spatial_dim(DimIndex::Y, filter_size[1])
- .set_spatial_dim(DimIndex::Z, filter_size[0])
- .set_input_feature_map_count(in_depth)
- .set_output_feature_map_count(out_depth);
+ filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
se::dnn::ConvolutionDescriptor conv_desc(3);
- conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
- .set_dilation_rate(DimIndex::Y, dilations[1])
- .set_dilation_rate(DimIndex::Z, dilations[0])
- .set_filter_stride(DimIndex::X, strides[2])
- .set_filter_stride(DimIndex::Y, strides[1])
- .set_filter_stride(DimIndex::Z, strides[0])
+ conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+ .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+ .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+ .set_filter_stride(DimIndex::X, dims.stride(2))
+ .set_filter_stride(DimIndex::Y, dims.stride(1))
+ .set_filter_stride(DimIndex::Z, dims.stride(0))
.set_zero_padding(DimIndex::X, padding_cols / 2)
.set_zero_padding(DimIndex::Y, padding_rows / 2)
.set_zero_padding(DimIndex::Z, padding_planes / 2);
@@ -664,20 +1282,23 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
Tensor transformed_filter;
OP_REQUIRES_OK(
context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({out_depth, in_depth, filter_size[0],
- filter_size[1], filter_size[2]}),
- &transformed_filter));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+ dims.filter_size(1), dims.filter_size(2)}),
+ &transformed_filter));
functor::TransformFilter<GPUDevice, T, int, 5>()(
- context->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
+ context->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 5>()),
To32Bit(transformed_filter.tensor<T, 5>()));
// Shape: batch, filters, z, y, x.
Tensor transformed_out_backprop;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
- output_cols};
- if (out_depth > 1) {
+ TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+ dims.output_size(0), dims.output_size(1),
+ dims.output_size(2)};
+ if (dims.out_depth > 1) {
OP_REQUIRES_OK(context, context->allocate_temp(
DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
@@ -713,14 +1334,14 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
const int device_id = stream->parent()->device_ordinal();
DataType dtype = context->input(0).dtype();
const ConvParameters conv_parameters = {
- batch,
- in_depth,
- {{input_size[0], input_size[1], input_size[2]}},
+ dims.batch_size,
+ dims.in_depth,
+ {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
FORMAT_NCHW,
- out_depth,
- {{filter_size[0], filter_size[1], filter_size[2]}},
- {{dilations[0], dilations[1], dilations[2]}},
- {{strides[0], strides[1], strides[2]}},
+ dims.out_depth,
+ {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+ {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+ {{dims.stride(0), dims.stride(1), dims.stride(2)}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
device_id,
@@ -799,10 +1420,11 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
if (rows_odd || cols_odd || planes_odd) {
Tensor in_backprop_remove_padding;
OP_REQUIRES_OK(context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- {batch, in_depth, input_size[0],
- input_size[1], input_size[2]},
- &in_backprop_remove_padding));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ {dims.batch_size, dims.in_depth, dims.input_size(0),
+ dims.input_size(1), dims.input_size(2)},
+ &in_backprop_remove_padding));
// Remove the padding for odd spatial dimensions.
functor::PadInput<GPUDevice, T, int, 5>()(
@@ -896,6 +1518,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
void Compute(OpKernelContext* context) override {
const Tensor& input = context->input(0);
const TensorShape& input_shape = input.shape();
+
+ const Tensor& out_backprop = context->input(2);
+ const TensorShape& out_backprop_shape = out_backprop.shape();
+
TensorShape filter_shape;
if (takes_shape_) {
const Tensor& filter_sizes = context->input(1);
@@ -905,7 +1531,12 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
filter_shape = context->input(1).shape();
}
- EXTRACT_AND_VERIFY_DIMENSIONS("Conv3DBackpropFilter");
+ ConvBackpropDimensions dims;
+ OP_REQUIRES_OK(context,
+ ConvBackpropComputeDimensionsV2(
+ "Conv3DBackpropFilterOp", /*num_spatial_dims=*/3,
+ input_shape, filter_shape, out_backprop_shape, dilation_,
+ stride_, padding_, data_format_, &dims));
Tensor* filter_backprop;
OP_REQUIRES_OK(context,
@@ -914,13 +1545,15 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- if (filter_size[1] == 1 && filter_size[2] == 1 && filter_size[0] == 1 &&
- dilations[2] == 1 && dilations[1] == 1 && dilations[0] == 1 &&
- strides[2] == 1 && strides[1] == 1 && strides[0] == 1 &&
+ if (dims.filter_size(1) == 1 && dims.filter_size(2) == 1 &&
+ dims.filter_size(0) == 1 && dims.dilation(2) == 1 &&
+ dims.dilation(1) == 1 && dims.dilation(0) == 1 && dims.stride(2) == 1 &&
+ dims.stride(1) == 1 && dims.stride(0) == 1 &&
data_format_ == FORMAT_NHWC) {
- const uint64 m = in_depth;
- const uint64 k = batch * input_size[1] * input_size[2] * input_size[0];
- const uint64 n = out_depth;
+ const uint64 m = dims.in_depth;
+ const uint64 k = dims.batch_size * dims.input_size(1) *
+ dims.input_size(2) * dims.input_size(0);
+ const uint64 n = dims.out_depth;
// The shape of output backprop is
// [batch, out_z, out_y, out_x, out_depth]
@@ -951,13 +1584,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
", n=", n, ", k=", k));
}
return;
- } else if (filter_size[0] == input_size[0] &&
- filter_size[1] == input_size[1] &&
- filter_size[2] == input_size[2] && padding_ == Padding::VALID &&
- data_format_ == FORMAT_NHWC) {
- const uint64 m = input_size[0] * input_size[1] * input_size[2] * in_depth;
- const uint64 k = batch;
- const uint64 n = out_depth;
+ } else if (dims.filter_size(0) == dims.input_size(0) &&
+ dims.filter_size(1) == dims.input_size(1) &&
+ dims.filter_size(2) == dims.input_size(2) &&
+ padding_ == Padding::VALID && data_format_ == FORMAT_NHWC) {
+ const uint64 m = dims.input_size(0) * dims.input_size(1) *
+ dims.input_size(2) * dims.in_depth;
+ const uint64 k = dims.batch_size;
+ const uint64 n = dims.out_depth;
auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
input.template flat<T>().size());
@@ -979,30 +1613,24 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
return;
}
- int padding_rows = 0, padding_cols = 0, padding_planes = 0;
-
- if (padding_ == Padding::SAME) {
- padding_planes = std::max<int>(
- 0, (output_planes - 1) * strides[0] + filter_size[0] - input_size[0]);
- padding_cols = std::max<int>(
- 0, (output_cols - 1) * strides[2] + filter_size[2] - input_size[2]);
- padding_rows = std::max<int>(
- 0, (output_rows - 1) * strides[1] + filter_size[1] - input_size[1]);
- }
- bool rows_odd = (padding_rows % 2 != 0);
- bool cols_odd = (padding_cols % 2 != 0);
- bool planes_odd = (padding_planes % 2 != 0);
+ int padding_planes = dims.SpatialPadding(padding_, 0);
+ int padding_rows = dims.SpatialPadding(padding_, 1);
+ int padding_cols = dims.SpatialPadding(padding_, 2);
+ const bool planes_odd = (padding_planes % 2 != 0);
+ const bool rows_odd = (padding_rows % 2 != 0);
+ const bool cols_odd = (padding_cols % 2 != 0);
Tensor compatible_input;
if (rows_odd || cols_odd || planes_odd) {
- OP_REQUIRES_OK(context, context->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(data_format_, batch,
- {{input_size[0] + planes_odd,
- input_size[1] + rows_odd,
- input_size[2] + cols_odd}},
- in_depth),
- &compatible_input));
+ OP_REQUIRES_OK(context,
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ ShapeFromFormat(data_format_, dims.batch_size,
+ {{dims.input_size(0) + planes_odd,
+ dims.input_size(1) + rows_odd,
+ dims.input_size(2) + cols_odd}},
+ dims.in_depth),
+ &compatible_input));
functor::PadInput<GPUDevice, T, int, 5>()(
context->template eigen_device<GPUDevice>(),
To32Bit(input.tensor<T, 5>()), {{0, 0, 0}},
@@ -1016,35 +1644,35 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
se::dnn::BatchDescriptor input_desc(3);
- input_desc.set_count(batch)
+ input_desc.set_count(dims.batch_size)
.set_spatial_dim(DimIndex::X,
GetTensorDim(compatible_input, data_format_, '2'))
.set_spatial_dim(DimIndex::Y,
GetTensorDim(compatible_input, data_format_, '1'))
.set_spatial_dim(DimIndex::Z,
GetTensorDim(compatible_input, data_format_, '0'))
- .set_feature_map_count(in_depth)
+ .set_feature_map_count(dims.in_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::BatchDescriptor output_desc(3);
- output_desc.set_count(batch)
- .set_spatial_dim(DimIndex::X, output_cols)
- .set_spatial_dim(DimIndex::Y, output_rows)
- .set_spatial_dim(DimIndex::Z, output_planes)
- .set_feature_map_count(out_depth)
+ output_desc.set_count(dims.batch_size)
+ .set_spatial_dim(DimIndex::X, dims.output_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.output_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.output_size(0))
+ .set_feature_map_count(dims.out_depth)
.set_layout(se::dnn::DataLayout::kBatchDepthYX);
se::dnn::FilterDescriptor filter_desc(3);
- filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
- .set_spatial_dim(DimIndex::Y, filter_size[1])
- .set_spatial_dim(DimIndex::Z, filter_size[0])
- .set_input_feature_map_count(in_depth)
- .set_output_feature_map_count(out_depth);
+ filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2))
+ .set_spatial_dim(DimIndex::Y, dims.filter_size(1))
+ .set_spatial_dim(DimIndex::Z, dims.filter_size(0))
+ .set_input_feature_map_count(dims.in_depth)
+ .set_output_feature_map_count(dims.out_depth);
se::dnn::ConvolutionDescriptor conv_desc(3);
- conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
- .set_dilation_rate(DimIndex::Y, dilations[1])
- .set_dilation_rate(DimIndex::Z, dilations[0])
- .set_filter_stride(DimIndex::X, strides[2])
- .set_filter_stride(DimIndex::Y, strides[1])
- .set_filter_stride(DimIndex::Z, strides[0])
+ conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2))
+ .set_dilation_rate(DimIndex::Y, dims.dilation(1))
+ .set_dilation_rate(DimIndex::Z, dims.dilation(0))
+ .set_filter_stride(DimIndex::X, dims.stride(2))
+ .set_filter_stride(DimIndex::Y, dims.stride(1))
+ .set_filter_stride(DimIndex::Z, dims.stride(0))
.set_zero_padding(DimIndex::X, padding_cols / 2)
.set_zero_padding(DimIndex::Y, padding_rows / 2)
.set_zero_padding(DimIndex::Z, padding_planes / 2);
@@ -1052,19 +1680,21 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
Tensor pre_transformed_filter_backprop;
OP_REQUIRES_OK(
context,
- context->allocate_temp(DataTypeToEnum<T>::value,
- TensorShape({out_depth, in_depth, filter_size[0],
- filter_size[1], filter_size[2]}),
- &pre_transformed_filter_backprop));
+ context->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({dims.out_depth, dims.in_depth, dims.filter_size(0),
+ dims.filter_size(1), dims.filter_size(2)}),
+ &pre_transformed_filter_backprop));
Tensor transformed_out_backprop;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, out_depth, output_planes, output_rows,
- output_cols};
+ TensorShape nchw_shape = {dims.batch_size, dims.out_depth,
+ dims.output_size(0), dims.output_size(1),
+ dims.output_size(2)};
OP_REQUIRES_OK(
context, context->allocate_temp(DataTypeToEnum<T>::value, nchw_shape,
&transformed_out_backprop));
- if (out_depth > 1) {
+ if (dims.out_depth > 1) {
functor::NHWCToNCHW<GPUDevice, T, 5>()(
context->eigen_device<GPUDevice>(), out_backprop.tensor<T, 5>(),
transformed_out_backprop.tensor<T, 5>());
@@ -1076,10 +1706,10 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
}
Tensor transformed_input;
if (data_format_ == FORMAT_NHWC) {
- TensorShape nchw_shape = {batch, in_depth, compatible_input.dim_size(1),
- compatible_input.dim_size(2),
- compatible_input.dim_size(3)};
- if (in_depth > 1) {
+ TensorShape nchw_shape = {
+ dims.batch_size, dims.in_depth, compatible_input.dim_size(1),
+ compatible_input.dim_size(2), compatible_input.dim_size(3)};
+ if (dims.in_depth > 1) {
OP_REQUIRES_OK(context,
context->allocate_temp(DataTypeToEnum<T>::value,
nchw_shape, &transformed_input));
@@ -1110,14 +1740,14 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
const int device_id = stream->parent()->device_ordinal();
DataType dtype = input.dtype();
const ConvParameters conv_parameters = {
- batch,
- in_depth,
- {{input_size[0], input_size[1], input_size[2]}},
+ dims.batch_size,
+ dims.in_depth,
+ {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}},
FORMAT_NCHW,
- out_depth,
- {{filter_size[0], filter_size[1], filter_size[2]}},
- {{dilations[0], dilations[1], dilations[2]}},
- {{strides[0], strides[1], strides[2]}},
+ dims.out_depth,
+ {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}},
+ {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}},
+ {{dims.stride(0), dims.stride(1), dims.stride(2)}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
device_id,
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index ef692418d6..78856c4a99 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -264,150 +264,198 @@ class LaunchXsmmConvOp<CPUDevice, float> {
};
#endif
+#define TF_REQUIRES(EXP, STATUS) \
+ do { \
+ if (!TF_PREDICT_TRUE(EXP)) return (STATUS); \
+ } while (false)
+
+Status InitConv2DParameters(const OpKernelConstruction* context,
+ Conv2DParameters* params) {
+ TF_RETURN_IF_ERROR(context->GetAttr("dilations", &params->dilations));
+ TF_RETURN_IF_ERROR(context->GetAttr("strides", &params->strides));
+ TF_RETURN_IF_ERROR(context->GetAttr("padding", &params->padding));
+ string data_format_string;
+ TF_RETURN_IF_ERROR(context->GetAttr("data_format", &data_format_string));
+ TF_REQUIRES(FormatFromString(data_format_string, &params->data_format),
+ errors::InvalidArgument("Invalid data format"));
+
+ const auto& strides = params->strides;
+ const auto& dilations = params->dilations;
+ const auto& data_format = params->data_format;
+
+ TF_REQUIRES(dilations.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ TF_REQUIRES(strides.size() == 4,
+ errors::InvalidArgument("Sliding window strides field must "
+ "specify 4 dimensions"));
+ const int64 stride_n = GetTensorDim(strides, data_format, 'N');
+ const int64 stride_c = GetTensorDim(strides, data_format, 'C');
+ const int64 stride_h = GetTensorDim(strides, data_format, 'H');
+ const int64 stride_w = GetTensorDim(strides, data_format, 'W');
+ TF_REQUIRES(
+ stride_n == 1 && stride_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "strides in the batch and depth dimensions."));
+ TF_REQUIRES(stride_h > 0 && stride_w > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
+
+ const int64 dilation_n = GetTensorDim(dilations, data_format, 'N');
+ const int64 dilation_c = GetTensorDim(dilations, data_format, 'C');
+ const int64 dilation_h = GetTensorDim(dilations, data_format, 'H');
+ const int64 dilation_w = GetTensorDim(dilations, data_format, 'W');
+ TF_REQUIRES(
+ dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument("Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ TF_REQUIRES(
+ dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ return Status::OK();
+}
+
+Status ComputeConv2DDimension(const Conv2DParameters& params,
+ const Tensor& input, const Tensor& filter,
+ Conv2DDimensions* dimensions) {
+ // Check that 2D convolution input and filter have exactly 4 dimensions.
+ TF_REQUIRES(input.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input.shape().DebugString()));
+ TF_REQUIRES(filter.dims() == 4,
+ errors::InvalidArgument("filter must be 4-dimensional: ",
+ filter.shape().DebugString()));
+ for (int i = 0; i < 3; i++) {
+ TF_REQUIRES(
+ FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
+ errors::InvalidArgument("filter too large"));
+ }
+
+  // The last dimension of the input is in_depth. Check that it either matches
+  // the filter's in_depth or is evenly divisible by it.
+ const int64 in_depth_raw = GetTensorDim(input, params.data_format, 'C');
+ const int64 patch_depth_raw = filter.dim_size(2);
+ TF_REQUIRES(FastBoundsCheck(in_depth_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input depth too large"));
+ TF_REQUIRES(FastBoundsCheck(patch_depth_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Patch depth too large"));
+ const int in_depth = static_cast<int>(in_depth_raw);
+ const int patch_depth = static_cast<int>(patch_depth_raw);
+ TF_REQUIRES(in_depth % patch_depth == 0,
+ errors::InvalidArgument(
+ "input depth must be evenly divisible by filter depth: ",
+ in_depth, " vs ", patch_depth));
+
+ // The last dimension for filter is out_depth.
+ const int out_depth = static_cast<int>(filter.dim_size(3));
+
+ // The second dimension for input is rows/height.
+ // The first dimension for filter is rows/height.
+ const int64 input_rows_raw = GetTensorDim(input, params.data_format, 'H');
+ TF_REQUIRES(FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input rows too large"));
+ const int input_rows = static_cast<int>(input_rows_raw);
+ const int filter_rows = static_cast<int>(filter.dim_size(0));
+
+ // The third dimension for input is columns/width.
+ // The second dimension for filter is columns/width.
+ const int64 input_cols_raw = GetTensorDim(input, params.data_format, 'W');
+ TF_REQUIRES(FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("Input cols too large"));
+ const int input_cols = static_cast<int>(input_cols_raw);
+ const int filter_cols = static_cast<int>(filter.dim_size(1));
+
+ // The first dimension for input is batch.
+ const int64 batch_raw = GetTensorDim(input, params.data_format, 'N');
+ TF_REQUIRES(FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("batch is too large"));
+ const int batch = static_cast<int>(batch_raw);
+
+ // Take the stride and dilation from the second and third dimensions only (we
+ // do not support striding or dilation on the batch or depth dimension).
+ const int stride_rows = GetTensorDim(params.strides, params.data_format, 'H');
+ const int stride_cols = GetTensorDim(params.strides, params.data_format, 'W');
+ const int dilation_rows =
+ GetTensorDim(params.dilations, params.data_format, 'H');
+ const int dilation_cols =
+ GetTensorDim(params.dilations, params.data_format, 'W');
+
+ // Compute windowed output sizes for rows and columns.
+ int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
+ input_rows, filter_rows, dilation_rows, stride_rows, params.padding,
+ &out_rows, &pad_rows));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeV2(
+ input_cols, filter_cols, dilation_cols, stride_cols, params.padding,
+ &out_cols, &pad_cols));
+
+ dimensions->batch = batch;
+ dimensions->input_rows = input_rows;
+ dimensions->input_cols = input_cols;
+ dimensions->in_depth = in_depth;
+ dimensions->filter_rows = filter_rows;
+ dimensions->filter_cols = filter_cols;
+ dimensions->patch_depth = patch_depth;
+ dimensions->out_depth = out_depth;
+ dimensions->stride_rows = stride_rows;
+ dimensions->stride_cols = stride_cols;
+ dimensions->dilation_rows = dilation_rows;
+ dimensions->dilation_cols = dilation_cols;
+ dimensions->out_rows = out_rows;
+ dimensions->out_cols = out_cols;
+ dimensions->pad_rows = pad_rows;
+ dimensions->pad_cols = pad_cols;
+
+ return Status::OK();
+}
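A quick numeric check of the windowed-output computation above, with made-up sizes (and assuming the padding returned by GetWindowedOutputSizeV2 is the leading padding):

void WindowedOutputSizeExample() {
  int64 out_rows = 0, pad_rows = 0;
  // Effective filter height: (3 - 1) * 2 + 1 = 5. With SAME padding,
  // out_rows = ceil(7 / 2) = 4 and total padding = (4 - 1) * 2 + 5 - 7 = 4,
  // of which the leading half is pad_rows = 2.
  TF_CHECK_OK(GetWindowedOutputSizeV2(/*input_size=*/7, /*filter_size=*/3,
                                      /*dilation_rate=*/2, /*stride=*/2,
                                      Padding::SAME, &out_rows, &pad_rows));
  DCHECK_EQ(out_rows, 4);
  DCHECK_EQ(pad_rows, 2);
}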
+
+#undef TF_REQUIRES
+
template <typename Device, typename T>
class Conv2DOp : public BinaryOp<T> {
public:
explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
- OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
- OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
- string data_format;
- OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
- OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
- errors::InvalidArgument("Invalid data format"));
+ OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
+
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- OP_REQUIRES(context, strides_.size() == 4,
- errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
- const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
- const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
- const int64 stride_h = GetTensorDim(strides_, data_format_, 'H');
- const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
- OP_REQUIRES(
- context, stride_n == 1 && stride_c == 1,
- errors::InvalidArgument("Current implementation does not yet support "
- "strides in the batch and depth dimensions."));
- OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
- errors::InvalidArgument(
- "Row and column strides should be larger than 0."));
-
- const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
- OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
void Compute(OpKernelContext* context) override {
// Input tensor is of the following dimensions:
// [ batch, in_rows, in_cols, in_depth ]
-
const Tensor& input = context->input(0);
// Input filter is of the following dimensions:
// [ filter_rows, filter_cols, in_depth, out_depth]
const Tensor& filter = context->input(1);
- // For 2D convolution, there should be 4 dimensions.
- OP_REQUIRES(context, input.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input.shape().DebugString()));
- OP_REQUIRES(context, filter.dims() == 4,
- errors::InvalidArgument("filter must be 4-dimensional: ",
- filter.shape().DebugString()));
-
- for (int i = 0; i < 3; i++) {
- OP_REQUIRES(
- context,
- FastBoundsCheck(filter.dim_size(i), std::numeric_limits<int>::max()),
- errors::InvalidArgument("filter too large"));
- }
+ Conv2DDimensions dimensions;
+ OP_REQUIRES_OK(context,
+ ComputeConv2DDimension(params_, input, filter, &dimensions));
- // The last dimension for input is in_depth. It must be the same as the
- // filter's in_depth or be evenly divisible by filter's in_depth.
- const int64 in_depth = GetTensorDim(input, data_format_, 'C');
- const int64 patch_depth = filter.dim_size(2);
- OP_REQUIRES(context, in_depth % patch_depth == 0,
- errors::InvalidArgument(
- "input depth must be evenly divisible by filter depth: ",
- in_depth, " vs ", patch_depth));
-
- // The last dimension for filter is out_depth.
- const int out_depth = static_cast<int>(filter.dim_size(3));
-
- // The second dimension for input is rows/height.
- // The first dimension for filter is rows/height.
- const int64 input_rows_raw = GetTensorDim(input, data_format_, 'H');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input rows too large"));
- const int input_rows = static_cast<int>(input_rows_raw);
- const int filter_rows = static_cast<int>(filter.dim_size(0));
-
- // The third dimension for input is columns/width.
- // The second dimension for filter is columns/width.
- const int64 input_cols_raw = GetTensorDim(input, data_format_, 'W');
- OP_REQUIRES(
- context,
- FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("Input cols too large"));
- const int input_cols = static_cast<int>(input_cols_raw);
- const int filter_cols = static_cast<int>(filter.dim_size(1));
-
- // The first dimension for input is batch.
- const int64 batch_raw = GetTensorDim(input, data_format_, 'N');
- OP_REQUIRES(context,
- FastBoundsCheck(batch_raw, std::numeric_limits<int>::max()),
- errors::InvalidArgument("batch is too large"));
- const int batch = static_cast<int>(batch_raw);
-
- // For now we take the stride and dilation from the second and third
- // dimensions only (we do not support striding or dilation on the batch or
- // depth dimension).
- const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
- const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
-
- const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
- const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
-
- int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
- OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
- input_rows, filter_rows, dilation_rows,
- stride_rows, padding_, &out_rows, &pad_rows));
- OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
- input_cols, filter_cols, dilation_cols,
- stride_cols, padding_, &out_cols, &pad_cols));
- TensorShape out_shape =
- ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
+ TensorShape out_shape = ShapeFromFormat(
+ params_.data_format, dimensions.batch, dimensions.out_rows,
+ dimensions.out_cols, dimensions.out_depth);
// Output tensor is of the following dimensions:
// [ in_batch, out_rows, out_cols, out_depth ]
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
- VLOG(2) << "Conv2D: in_depth = " << in_depth
- << ", patch_depth = " << patch_depth
- << ", input_cols = " << input_cols
- << ", filter_cols = " << filter_cols
- << ", input_rows = " << input_rows
- << ", filter_rows = " << filter_rows
- << ", stride_rows = " << stride_rows
- << ", stride_cols = " << stride_cols
- << ", dilation_rows = " << dilation_rows
- << ", dilation_cols = " << dilation_cols
- << ", out_depth = " << out_depth;
+ VLOG(2) << "Conv2D: in_depth = " << dimensions.in_depth
+ << ", patch_depth = " << dimensions.patch_depth
+ << ", input_cols = " << dimensions.input_cols
+ << ", filter_cols = " << dimensions.filter_cols
+ << ", input_rows = " << dimensions.input_rows
+ << ", filter_rows = " << dimensions.filter_rows
+ << ", stride_rows = " << dimensions.stride_rows
+ << ", stride_cols = " << dimensions.stride_cols
+ << ", dilation_rows = " << dimensions.dilation_rows
+ << ", dilation_cols = " << dimensions.dilation_cols
+ << ", out_depth = " << dimensions.out_depth;
// If there is nothing to compute, return.
if (out_shape.num_elements() == 0) {
@@ -416,36 +464,41 @@ class Conv2DOp : public BinaryOp<T> {
#ifdef TENSORFLOW_USE_LIBXSMM_CONVOLUTIONS
if (LaunchXsmmConvOp<Device, T>::Run(
- context, input, filter, batch, input_rows, input_cols, in_depth,
- filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
- output, data_format_)) {
+ context, input, filter, dimensions.batch, dimensions.input_rows,
+ dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
+ dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols,
+ dimensions.out_rows, dimensions.out_cols, dimensions.out_depth,
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, output,
+ params_.data_format)) {
return;
}
#endif
if (LaunchDeepConvOp<Device, T>::Run(
- context, input, filter, batch, input_rows, input_cols, in_depth,
- filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
- output, data_format_)) {
+ context, input, filter, dimensions.batch, dimensions.input_rows,
+ dimensions.input_cols, dimensions.in_depth, dimensions.filter_rows,
+ dimensions.filter_cols, dimensions.pad_rows, dimensions.pad_cols,
+ dimensions.out_rows, dimensions.out_cols, dimensions.out_depth,
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, output,
+ params_.data_format)) {
return;
}
launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
- output, data_format_);
+ dimensions.dilation_rows, dimensions.dilation_cols,
+ dimensions.stride_rows, dimensions.stride_cols, params_.padding,
+ output, params_.data_format);
}
private:
- std::vector<int32> dilations_;
- std::vector<int32> strides_;
+ Conv2DParameters params_;
bool use_cudnn_;
- Padding padding_;
- TensorFormat data_format_;
- LaunchConv2DOp<Device, T> launcher_;
bool cudnn_use_autotune_;
+ LaunchConv2DOp<Device, T> launcher_;
+
TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
};
@@ -680,9 +733,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
TensorShape({filter.dim_size(3), filter.dim_size(2),
filter.dim_size(0), filter.dim_size(1)}),
&transformed_filter));
-
functor::TransformFilter<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+ ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 4>()),
To32Bit(transformed_filter.tensor<T, 4>()));
Tensor transformed_output;
@@ -731,9 +784,15 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
if (cudnn_use_autotune &&
!AutoTuneConv::GetInstance()->Find(conv_parameters, &algorithm_config)) {
std::vector<AlgorithmDesc> algorithms;
- CHECK(stream->parent()->GetConvolveAlgorithms(
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(stream->parent()),
- &algorithms));
+ OP_REQUIRES(
+ ctx,
+ stream->parent()->GetConvolveAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
+ stream->parent()),
+ &algorithms),
+ errors::Unknown("Failed to get convolution algorithm. This is probably "
+ "because cuDNN failed to initialize, so try looking to "
+ "see if a warning log message was printed above."));
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
@@ -823,7 +882,8 @@ namespace functor {
extern template struct MatMulConvFunctor<GPUDevice, T>; \
template <> \
void TransformFilter<GPUDevice, T, int, 4>::operator()( \
- const GPUDevice& d, typename TTypes<T, 4, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 4, int>::ConstTensor in, \
typename TTypes<T, 4, int>::Tensor out); \
extern template struct TransformFilter<GPUDevice, T, int, 4>; \
template <> \
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
index 09a3b78776..7ec878e0b2 100644
--- a/tensorflow/core/kernels/conv_ops.h
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CONV_OPS_H_
-#define TENSORFLOW_KERNELS_CONV_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CONV_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_CONV_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -66,6 +66,50 @@ struct Im2ColBufferResource : public ResourceBase {
string DebugString() { return "Im2ColBufferResource"; }
};
+// Convolution parameters specified by Op attributes.
+struct Conv2DParameters {
+ std::vector<int32> dilations;
+ std::vector<int32> strides;
+ Padding padding;
+ TensorFormat data_format;
+};
+
+// Convolution dimensions inferred from parameters, input and filter tensors.
+struct Conv2DDimensions {
+ int batch;
+ int input_rows;
+ int input_cols;
+ int in_depth;
+
+ int filter_rows;
+ int filter_cols;
+ int patch_depth;
+ int out_depth;
+
+ int stride_rows;
+ int stride_cols;
+
+ int dilation_rows;
+ int dilation_cols;
+
+ int64 out_rows;
+ int64 out_cols;
+ int64 pad_rows;
+ int64 pad_cols;
+};
+
+// Initializes and validates Conv2D parameters configured by OpKernel
+// attributes.
+Status InitConv2DParameters(const OpKernelConstruction* context,
+ Conv2DParameters* params);
+
+// Computes and validates convolution dimensions from Conv2D parameters. If
+// the parameters are valid, `dimensions` is updated with the derived
+// convolution dimensions; otherwise an error is returned.
+Status ComputeConv2DDimension(const Conv2DParameters& params,
+ const Tensor& input, const Tensor& filter,
+ Conv2DDimensions* dimensions);
+
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CONV_OPS_H
+#endif // TENSORFLOW_CORE_KERNELS_CONV_OPS_H_
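The two helpers declared above replace the per-kernel validation code removed from conv_ops.cc earlier in this diff. A minimal sketch of the intended call pattern, assuming a hypothetical kernel class (only InitConv2DParameters, ComputeConv2DDimension, Conv2DParameters and Conv2DDimensions come from the declarations above; everything else is illustrative):

    class SketchConv2DOp : public OpKernel {
     public:
      explicit SketchConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
        // Parse and validate strides, dilations, padding and data_format once.
        OP_REQUIRES_OK(context, InitConv2DParameters(context, &params_));
      }

      void Compute(OpKernelContext* context) override {
        const Tensor& input = context->input(0);
        const Tensor& filter = context->input(1);
        // Derive batch, spatial and depth sizes plus output shape and padding
        // from the validated parameters and the runtime tensor shapes.
        Conv2DDimensions dims;
        OP_REQUIRES_OK(context,
                       ComputeConv2DDimension(params_, input, filter, &dims));
        TensorShape out_shape =
            ShapeFromFormat(params_.data_format, dims.batch, dims.out_rows,
                            dims.out_cols, dims.out_depth);
        Tensor* output = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
        // ... launch the convolution using dims.*, as Conv2DOp does above.
      }

     private:
      Conv2DParameters params_;
    };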
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index a1eed4e68c..83df4dce38 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -386,7 +386,8 @@ struct LaunchConvOp<GPUDevice, T> {
// filter: [x, y, z, in, out]
// t_filter: [out, in, x, y, z]
functor::TransformFilter<GPUDevice, T, int, 5>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 5>()),
+ ctx->eigen_device<GPUDevice>(), FORMAT_OIHW,
+ To32Bit(filter.tensor<T, 5>()),
To32Bit(transformed_filter.tensor<T, 5>()));
Tensor transformed_output;
@@ -434,10 +435,16 @@ struct LaunchConvOp<GPUDevice, T> {
if (cudnn_use_autotune && !AutoTuneConv3d::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
std::vector<AlgorithmDesc> algorithms;
- CHECK(stream->parent()->GetConvolveAlgorithms(
- conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
- stream->parent()),
- &algorithms));
+ OP_REQUIRES(ctx,
+ stream->parent()->GetConvolveAlgorithms(
+ conv_parameters.ShouldIncludeWinogradNonfusedAlgo<T>(
+ stream->parent()),
+ &algorithms),
+ errors::Unknown(
+ "Failed to get convolution algorithm. This is probably "
+ "because cuDNN failed to initialize, so try looking to "
+ "see if a warning log message was printed above."));
+
ProfileResult best_result;
ProfileResult best_result_no_scratch;
for (auto profile_algorithm : algorithms) {
@@ -514,7 +521,8 @@ namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
void TransformFilter<GPUDevice, T, int, 5>::operator()( \
- const GPUDevice& d, typename TTypes<T, 5, int>::ConstTensor in, \
+ const GPUDevice& d, FilterTensorFormat dst_filter_format, \
+ typename TTypes<T, 5, int>::ConstTensor in, \
typename TTypes<T, 5, int>::Tensor out); \
template <> \
void ReverseTransformFilter<GPUDevice, T, 5>::operator()( \
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index afc611f277..21d135decd 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -142,8 +142,12 @@ class ConvParameters {
template <typename T>
bool ShouldIncludeWinogradNonfusedAlgo(
se::StreamExecutor* stream_exec) const {
+ auto* dnn_support = stream_exec->AsDnn();
+ if (!dnn_support) {
+ return false;
+ }
// Skip this check for cuDNN 7 and newer.
- auto version = stream_exec->AsDnn()->GetVersion();
+ auto version = dnn_support->GetVersion();
if (version.ok() && version.ValueOrDie().major_version() >= 7) {
return true;
}
diff --git a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
index a5fa48f85e..46167db3a2 100644
--- a/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
+++ b/tensorflow/core/kernels/conv_ops_gpu_3.cu.cc
@@ -170,51 +170,33 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Index<IndexCount> FlatToTensorIndex(
return tensor_index;
}
-// A Cuda custom kernel that swaps dimension-0 and dimension-2 of a 3D tensor.
-template <typename T, bool conjugate = false>
-__global__ void SwapDimension0And2InTensor3Simple(int nthreads, const T* input,
- Dimension<3> input_dims,
- T* output) {
- Dimension<3> output_dims;
- output_dims[0] = input_dims[2];
- output_dims[1] = input_dims[1];
- output_dims[2] = input_dims[0];
-
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int output_index = index;
-
- Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
-
- Index<3> input_tensor_index;
- input_tensor_index[0] = output_tensor_index[2];
- input_tensor_index[1] = output_tensor_index[1];
- input_tensor_index[2] = output_tensor_index[0];
-
- int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
-
- output[output_index] =
- maybe_conj<T, conjugate>::run(ldg(input + input_index));
- }
-}
-
-// A Cuda custom kernel that swaps dimension-1 and dimension-2 of a 3D tensor.
-template <typename T, bool conjugate = false>
-__global__ void SwapDimension1And2InTensor3Simple(int nthreads, const T* input,
- Dimension<3> input_dims,
- T* output) {
+// A simple CUDA kernel that shuffles the dimensions of a 3D tensor according
+// to the shuffle permutation given as template parameters. The permutation
+// <sp0, sp1, sp2> maps input dimension 0 to output dimension sp0, 1 to sp1,
+// and 2 to sp2. For example, the permutation <2, 0, 1> populates the output
+// so that input[x][y][z] is equal to (*output)[y][z][x].
+//
+// Requires that nthreads equals the total number of elements in the input
+// tensor.
+template <typename T, int sp0, int sp1, int sp2, bool conjugate = false>
+__global__ void ShuffleInTensor3Simple(int nthreads, const T* input,
+ Dimension<3> input_dims, T* output) {
Dimension<3> output_dims;
- output_dims[0] = input_dims[0];
- output_dims[1] = input_dims[2];
- output_dims[2] = input_dims[1];
-
- CUDA_1D_KERNEL_LOOP(index, nthreads) {
- int output_index = index;
+ output_dims[sp0] = input_dims[0];
+ output_dims[sp1] = input_dims[1];
+ output_dims[sp2] = input_dims[2];
+
+  // Iterate over the output rather than the input for better performance.
+  // Iterating over the output generates sequential writes and random reads,
+  // which perform better than sequential reads and random writes.
+ CUDA_1D_KERNEL_LOOP(output_index, nthreads) {
Index<3> output_tensor_index = FlatToTensorIndex(output_index, output_dims);
Index<3> input_tensor_index;
- input_tensor_index[0] = output_tensor_index[0];
- input_tensor_index[1] = output_tensor_index[2];
- input_tensor_index[2] = output_tensor_index[1];
+ input_tensor_index[0] = output_tensor_index[sp0];
+ input_tensor_index[1] = output_tensor_index[sp1];
+ input_tensor_index[2] = output_tensor_index[sp2];
int input_index = TensorIndexToFlat(input_tensor_index, input_dims);
@@ -439,7 +421,7 @@ __global__ void PadInputCustomKernelNCHW(int nthreads, const T* input,
template <typename T, int NDIMS>
struct TransformFilter<GPUDevice, T, int, NDIMS> {
typedef GPUDevice Device;
- void operator()(const Device& d,
+ void operator()(const Device& d, FilterTensorFormat dst_filter_format,
typename TTypes<T, NDIMS, int>::ConstTensor in,
typename TTypes<T, NDIMS, int>::Tensor out) {
Dimension<3> combined_dims;
@@ -450,13 +432,18 @@ struct TransformFilter<GPUDevice, T, int, NDIMS> {
combined_dims[1] = in.dimension(NDIMS - 2); // input filters
combined_dims[2] = in.dimension(NDIMS - 1); // output filters
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
- SwapDimension0And2InTensor3Simple<T>
+
+ CHECK(dst_filter_format == FORMAT_OIHW)
+ << "Unsupported output layout: " << ToString(dst_filter_format);
+
+ ShuffleInTensor3Simple<T, 2, 1, 0>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), combined_dims, out.data());
}
};
-// Converts Cudnn filter format back to TensorFlow filter format.
+// Converts Cudnn filter format OIHW back to TensorFlow filter format HWIO.
+// TODO(hinsu): Support reverse transformation from filter format OHWI as well.
template <typename T, int NDIMS>
struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
typedef GPUDevice Device;
@@ -470,7 +457,7 @@ struct ReverseTransformFilter<GPUDevice, T, NDIMS> {
combined_dims[2] *= in.dimension(i);
}
CudaLaunchConfig config = GetCudaLaunchConfig(out.size(), d);
- SwapDimension0And2InTensor3Simple<T>
+ ShuffleInTensor3Simple<T, 2, 1, 0>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in.data(), combined_dims, out.data());
}
@@ -937,7 +924,7 @@ void RunSwapDimension1And2InTensor3(const GPUDevice& d, const T* input,
} else {
int total_element_count = input_dims[0] * input_dims[1] * input_dims[2];
CudaLaunchConfig config = GetCudaLaunchConfig(total_element_count, d);
- SwapDimension1And2InTensor3Simple<T, conjugate>
+ ShuffleInTensor3Simple<T, 0, 2, 1, conjugate>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, input, input_dims, output);
}
@@ -969,7 +956,7 @@ struct SwapDimension0And2InTensor3<GPUDevice, T, conjugate> {
static_cast<int>(combined_dims[2])};
size_t total_size = combined_dims[0] * combined_dims[1] * combined_dims[2];
CudaLaunchConfig config = GetCudaLaunchConfig(total_size, d);
- SwapDimension0And2InTensor3Simple<T, conjugate>
+ ShuffleInTensor3Simple<T, 2, 1, 0, conjugate>
<<<config.block_count, config.thread_per_block, 0, d.stream()>>>(
config.virtual_thread_count, in, input_dims, out);
}
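For reference, the permutation semantics of ShuffleInTensor3Simple can be reproduced on the host with a few lines of plain C++; this sketch is illustrative only and is not part of the patch. With <2, 1, 0> it swaps dimensions 0 and 2 (the OIHW filter transform used above); with <0, 2, 1> it swaps dimensions 1 and 2.

    #include <array>
    #include <vector>

    // Host-side illustration of the same mapping: input dimension i goes to
    // output dimension sp_i, so with <2, 1, 0> input[x][y][z] lands at
    // output[z][y][x].
    template <int sp0, int sp1, int sp2, typename T>
    std::vector<T> ShuffleTensor3(const std::vector<T>& in,
                                  const std::array<int, 3>& in_dims) {
      std::array<int, 3> out_dims;
      out_dims[sp0] = in_dims[0];
      out_dims[sp1] = in_dims[1];
      out_dims[sp2] = in_dims[2];
      std::vector<T> out(in.size());
      for (int i = 0; i < in_dims[0]; ++i)
        for (int j = 0; j < in_dims[1]; ++j)
          for (int k = 0; k < in_dims[2]; ++k) {
            std::array<int, 3> out_idx;
            out_idx[sp0] = i;
            out_idx[sp1] = j;
            out_idx[sp2] = k;
            const int in_flat = (i * in_dims[1] + j) * in_dims[2] + k;
            const int out_flat =
                (out_idx[0] * out_dims[1] + out_idx[1]) * out_dims[2] + out_idx[2];
            out[out_flat] = in[in_flat];
          }
      return out;
    }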
diff --git a/tensorflow/core/kernels/cross_op.h b/tensorflow/core/kernels/cross_op.h
index ca6beba52b..45bc46a921 100644
--- a/tensorflow/core/kernels/cross_op.h
+++ b/tensorflow/core/kernels/cross_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_COLORSPACE_OP_H_
-#define TENSORFLOW_KERNELS_COLORSPACE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CROSS_OP_H_
+#define TENSORFLOW_CORE_KERNELS_CROSS_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -51,4 +51,4 @@ struct Cross {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_COLORSPACE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_CROSS_OP_H_
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index b2e8ee23a9..2c30d036df 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================
*/
+#ifndef TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
+#define TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
+
// This header declares the class CudaSolver, which contains wrappers of linear
// algebra solvers in the cuBlas and cuSolverDN libraries for use in TensorFlow
// kernels.
@@ -433,3 +436,5 @@ inline DeviceLapackInfo CudaSolver::GetDeviceLapackInfo(
} // namespace tensorflow
#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_KERNELS_CUDA_SOLVERS_H_
diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.h b/tensorflow/core/kernels/cudnn_pooling_gpu.h
index 280d697fc2..738e928246 100644
--- a/tensorflow/core/kernels/cudnn_pooling_gpu.h
+++ b/tensorflow/core/kernels/cudnn_pooling_gpu.h
@@ -15,8 +15,8 @@ limitations under the License.
// Helper functions to run 3d pooling on GPU using CuDNN.
-#ifndef TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_
-#define TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CUDNN_POOLING_GPU_H_
+#define TENSORFLOW_CORE_KERNELS_CUDNN_POOLING_GPU_H_
#include <array>
@@ -67,4 +67,4 @@ class DnnPooling3dGradOp {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CUDNN_POOLING_GPU_H_
+#endif // TENSORFLOW_CORE_KERNELS_CUDNN_POOLING_GPU_H_
diff --git a/tensorflow/core/kernels/cwise_op_div.cc b/tensorflow/core/kernels/cwise_op_div.cc
index b12652f7fb..313d976e2c 100644
--- a/tensorflow/core/kernels/cwise_op_div.cc
+++ b/tensorflow/core/kernels/cwise_op_div.cc
@@ -24,6 +24,8 @@ REGISTER5(BinaryOp, CPU, "TruncateDiv", functor::safe_div, uint8, uint16, int16,
int32, int64);
REGISTER6(BinaryOp, CPU, "RealDiv", functor::div, float, Eigen::half, double,
bfloat16, complex64, complex128);
+REGISTER2(BinaryOp, CPU, "DivNoNan", functor::div_no_nan, float, double);
+
#if GOOGLE_CUDA
REGISTER9(BinaryOp, GPU, "Div", functor::div, float, Eigen::half, double, uint8,
uint16, int16, int64, complex64, complex128);
@@ -31,6 +33,7 @@ REGISTER4(BinaryOp, GPU, "TruncateDiv", functor::div, uint8, uint16, int16,
int64);
REGISTER5(BinaryOp, GPU, "RealDiv", functor::div, float, Eigen::half, double,
complex64, complex128);
+REGISTER2(BinaryOp, GPU, "DivNoNan", functor::div_no_nan, float, double);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
index 0b05416274..25ccdcfb00 100644
--- a/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_div.cu.cc
@@ -21,6 +21,7 @@ namespace tensorflow {
namespace functor {
DEFINE_BINARY10(div, Eigen::half, float, double, uint8, uint16, int16, int32,
int64, complex64, complex128);
+DEFINE_BINARY2(div_no_nan, float, double);
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/optional.cc b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc
index 8dea073788..e4b21a66c6 100644
--- a/tensorflow/core/lib/gtl/optional.cc
+++ b/tensorflow/core/kernels/cwise_op_gpu_xdivy.cu.cc
@@ -1,4 +1,4 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -13,13 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/lib/gtl/optional.h"
+#if GOOGLE_CUDA
-namespace tensorflow {
-namespace gtl {
-
-nullopt_t::init_t nullopt_t::init;
-extern const nullopt_t nullopt{nullopt_t::init};
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
-} // namespace gtl
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY5(xdivy, Eigen::half, float, double, complex64, complex128);
+} // namespace functor
} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc
new file mode 100644
index 0000000000..1e1b5a426e
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_gpu_xlogy.cu.cc
@@ -0,0 +1,26 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/core/kernels/cwise_ops_gpu_common.cu.h"
+
+namespace tensorflow {
+namespace functor {
+DEFINE_BINARY5(xlogy, Eigen::half, float, double, complex64, complex128);
+} // namespace functor
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/cwise_op_select.cc b/tensorflow/core/kernels/cwise_op_select.cc
index 98df0844ea..d6988a562c 100644
--- a/tensorflow/core/kernels/cwise_op_select.cc
+++ b/tensorflow/core/kernels/cwise_op_select.cc
@@ -33,6 +33,11 @@ typedef Eigen::GpuDevice GPUDevice;
typedef Eigen::SyclDevice SYCLDevice;
#endif // TENSORFLOW_USE_SYCL
+namespace functor {
+template <typename Device, typename T>
+struct SelectScalarHandler;
+} // namespace functor
+
template <typename Device, typename T>
class SelectOp : public OpKernel {
public:
@@ -131,16 +136,8 @@ class SelectOp : public OpKernel {
then->shape().DebugString(), " vs. ",
else_->shape().DebugString()));
- Tensor* output = nullptr;
- OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
- {"t", "e"}, "output", then->shape(), &output));
-
- if (output->NumElements() > 0) {
- functor::SelectScalarFunctor<Device, T> func;
- TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
- func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
- then->flat<T>(), else_->flat<T>());
- }
+ functor::SelectScalarHandler<Device, T> handler;
+ handler(ctx, cond, then, else_);
}
private:
@@ -209,6 +206,40 @@ struct SelectFunctor<SYCLDevice, T> : SelectFunctorBase<SYCLDevice, T> {};
#endif // TENSORFLOW_USE_SYCL
template <typename Device, typename T>
+struct SelectScalarHandler {
+ void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
+ const Tensor* else_) {
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->forward_input_or_allocate_output(
+ {"t", "e"}, "output", then->shape(), &output));
+
+ if (output->NumElements() > 0) {
+ functor::SelectScalarFunctor<Device, T> func;
+ TTypes<bool>::ConstScalar cond_scalar = cond->scalar<bool>();
+ func(ctx->eigen_device<Device>(), output->flat<T>(), cond_scalar,
+ then->flat<T>(), else_->flat<T>());
+ }
+ }
+};
+
+// Specialization for the CPU device. Forwards either the `then` or the
+// `else_` input to the output, depending on the `cond` value.
+// TODO(sjhwang): Consider specializing for GPUDevice as well by using
+// GPUDevice::memcpyDeviceToHost() to fetch bool value.
+template <typename T>
+struct SelectScalarHandler<CPUDevice, T> {
+ void operator()(OpKernelContext* ctx, const Tensor* cond, const Tensor* then,
+ const Tensor* else_) {
+ if (cond->scalar<bool>()()) {
+ OP_REQUIRES_OK(ctx, ctx->set_output("output", *then));
+ } else {
+ OP_REQUIRES_OK(ctx, ctx->set_output("output", *else_));
+ }
+ }
+};
+
+#ifdef TENSORFLOW_USE_SYCL
+template <typename Device, typename T>
struct SelectScalarFunctorBase {
void operator()(const Device& d, typename TTypes<T>::Flat out,
TTypes<bool>::ConstScalar cond,
@@ -218,11 +249,6 @@ struct SelectScalarFunctorBase {
}
};
-// CPU Specializations of Select functors with scalar
-template <typename T>
-struct SelectScalarFunctor<CPUDevice, T>
- : SelectScalarFunctorBase<CPUDevice, T> {};
-#ifdef TENSORFLOW_USE_SYCL
template <typename T>
struct SelectScalarFunctor<SYCLDevice, T>
: SelectScalarFunctorBase<SYCLDevice, T> {};
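The CPU specialization of SelectScalarHandler above forwards whichever input the scalar condition selects instead of allocating an output and copying every element. A stand-alone sketch of the idea, using std::shared_ptr as a stand-in for TensorFlow's reference-counted tensor buffers (illustrative only):

    #include <memory>
    #include <vector>

    using TensorBuf = std::shared_ptr<std::vector<float>>;

    // Forwarding shares the underlying buffer in O(1); the generic handler
    // above instead allocates an output and copies NumElements() values.
    TensorBuf SelectScalar(bool cond, const TensorBuf& then_t,
                           const TensorBuf& else_t) {
      return cond ? then_t : else_t;
    }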
diff --git a/tensorflow/core/kernels/cwise_op_xdivy.cc b/tensorflow/core/kernels/cwise_op_xdivy.cc
new file mode 100644
index 0000000000..6a6aec5e86
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_xdivy.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Xdivy", functor::xdivy, float, Eigen::half, double,
+ complex64, complex128);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Xdivy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
+ BinaryOp<SYCLDevice, functor::xdivy<TYPE>>);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+#undef REGISTER_SYCL_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER5(BinaryOp, GPU, "Xdivy", functor::xdivy, float, Eigen::half, double,
+ complex64, complex128);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_xlogy.cc b/tensorflow/core/kernels/cwise_op_xlogy.cc
new file mode 100644
index 0000000000..e71a9109b2
--- /dev/null
+++ b/tensorflow/core/kernels/cwise_op_xlogy.cc
@@ -0,0 +1,41 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/cwise_ops_common.h"
+
+namespace tensorflow {
+REGISTER5(BinaryOp, CPU, "Xlogy", functor::xlogy, float, Eigen::half, double,
+ complex64, complex128);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(TYPE) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("Xlogy").Device(DEVICE_SYCL).TypeConstraint<TYPE>("T"), \
+ BinaryOp<SYCLDevice, functor::xlogy<TYPE>>);
+REGISTER_SYCL_KERNEL(Eigen::half);
+REGISTER_SYCL_KERNEL(float);
+REGISTER_SYCL_KERNEL(double);
+REGISTER_SYCL_KERNEL(complex64);
+REGISTER_SYCL_KERNEL(complex128);
+#undef REGISTER_SYCL_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
+#if GOOGLE_CUDA
+REGISTER5(BinaryOp, GPU, "Xlogy", functor::xlogy, float, Eigen::half, double,
+ complex64, complex128);
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_op_zeta.cc b/tensorflow/core/kernels/cwise_op_zeta.cc
index 2c5538534c..dc064eec5f 100644
--- a/tensorflow/core/kernels/cwise_op_zeta.cc
+++ b/tensorflow/core/kernels/cwise_op_zeta.cc
@@ -18,4 +18,9 @@ limitations under the License.
namespace tensorflow {
REGISTER2(BinaryOp, CPU, "Zeta", functor::zeta, float, double);
REGISTER2(BinaryOp, CPU, "Polygamma", functor::polygamma, float, double);
+
+#if GOOGLE_CUDA
+REGISTER2(BinaryOp, GPU, "Zeta", functor::zeta, float, double);
+REGISTER2(BinaryOp, GPU, "Polygamma", functor::polygamma, float, double);
+#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/cwise_ops.h b/tensorflow/core/kernels/cwise_ops.h
index 1b1a704d42..66ba827a90 100644
--- a/tensorflow/core/kernels/cwise_ops.h
+++ b/tensorflow/core/kernels/cwise_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CWISE_OPS_H_
-#define TENSORFLOW_KERNELS_CWISE_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
#include <cmath>
#include <functional>
@@ -153,6 +153,27 @@ struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
};
};
+template <typename T>
+struct div_no_nan_op {
+ EIGEN_EMPTY_STRUCT_CTOR(div_no_nan_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const T operator()(const T& a,
+ const T& b) const {
+ if (b != 0) {
+ return scalar_quotient_op<T>()(a, b);
+ } else {
+ return 0;
+ }
+ }
+};
+
+template <typename T>
+struct functor_traits<div_no_nan_op<T>> {
+ enum {
+ Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost,
+ PacketAccess = false,
+ };
+};
+
// scalar_left and scalar_right are template helpers to partially
// apply a binary function.
//
@@ -450,6 +471,45 @@ struct functor_traits<bitwise_xor_op<Scalar>> {
enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true };
};
+// TODO(srvasude): Add packet versions of this operation.
+template <typename Scalar>
+struct xlogy_op {
+ EIGEN_EMPTY_STRUCT_CTOR(xlogy_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
+ operator()(const Scalar& x, const Scalar& y) const {
+ if (x == Scalar(0.)) {
+ return Scalar(0.);
+ }
+ return x * numext::log(y);
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<xlogy_op<Scalar>> {
+ enum {
+ Cost = (sizeof(Scalar) == 4 ? 40 : 85) + Eigen::NumTraits<Scalar>::MulCost,
+ PacketAccess = false
+ };
+};
+
+// TODO(srvasude): Add packet versions of this operation.
+template <typename Scalar>
+struct xdivy_op {
+ EIGEN_EMPTY_STRUCT_CTOR(xdivy_op)
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar
+ operator()(const Scalar& x, const Scalar& y) const {
+ if (x == Scalar(0.)) {
+ return Scalar(0.);
+ }
+ return x / y;
+ }
+};
+
+template <typename Scalar>
+struct functor_traits<xdivy_op<Scalar>> {
+ enum { Cost = Eigen::NumTraits<Scalar>::MulCost, PacketAccess = false };
+};
+
} // end namespace internal
} // end namespace Eigen
@@ -721,6 +781,9 @@ struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op<
};
template <typename T>
+struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {};
+
+template <typename T>
struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {};
template <typename T>
@@ -806,6 +869,12 @@ struct squared_difference
Eigen::internal::scalar_difference_op<T>>> {};
template <typename T>
+struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {};
+
+template <typename T>
+struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {};
+
+template <typename T>
struct less : base<T, Eigen::internal::less<T>, bool> {};
template <typename T>
@@ -1012,4 +1081,4 @@ struct BatchSelectFunctor {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CWISE_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
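For quick reference, the scalar semantics of the three new functors defined in this header can be written out in a few self-contained lines; this is a sketch of the behavior only, not the Eigen functor code above:

    #include <cmath>

    // DivNoNan(a, b) = a / b, except 0 when b == 0.
    double DivNoNan(double a, double b) { return b == 0.0 ? 0.0 : a / b; }

    // Xdivy(x, y) = x / y, except 0 when x == 0 (even if y == 0).
    double Xdivy(double x, double y) { return x == 0.0 ? 0.0 : x / y; }

    // Xlogy(x, y) = x * log(y), except 0 when x == 0.
    double Xlogy(double x, double y) { return x == 0.0 ? 0.0 : x * std::log(y); }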
diff --git a/tensorflow/core/kernels/cwise_ops_common.cc b/tensorflow/core/kernels/cwise_ops_common.cc
index 980edffceb..8ad3b4d1fc 100644
--- a/tensorflow/core/kernels/cwise_ops_common.cc
+++ b/tensorflow/core/kernels/cwise_ops_common.cc
@@ -20,9 +20,9 @@ namespace tensorflow {
BinaryOpShared::BinaryOpShared(OpKernelConstruction* ctx, DataType out,
DataType in)
: OpKernel(ctx) {
-#ifndef INTEL_MKL
+#if !defined(INTEL_MKL) || !defined(ENABLE_MKL)
OP_REQUIRES_OK(ctx, ctx->MatchSignature({in, in}, {out}));
-#endif
+#endif // !INTEL_MKL || !ENABLE_MKL
}
void BinaryOpShared::SetUnimplementedError(OpKernelContext* ctx) {
diff --git a/tensorflow/core/kernels/cwise_ops_common.h b/tensorflow/core/kernels/cwise_ops_common.h
index e32eccf547..f77d7238af 100644
--- a/tensorflow/core/kernels/cwise_ops_common.h
+++ b/tensorflow/core/kernels/cwise_ops_common.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_
-#define TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
+#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
// See docs in ../ops/math_ops.cc.
@@ -602,4 +602,4 @@ struct ApproximateEqual<CPUDevice, T> {
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CWISE_OPS_COMMON_H_
+#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_COMMON_H_
diff --git a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
index 965e42dcce..cfae273bf4 100644
--- a/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
+++ b/tensorflow/core/kernels/cwise_ops_gpu_common.cu.h
@@ -17,8 +17,8 @@ limitations under the License.
#error This file must only be included when building with Cuda support
#endif
-#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
-#define TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
+#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
#define EIGEN_USE_GPU
@@ -188,4 +188,4 @@ struct ApproximateEqual<GPUDevice, T> {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
+#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_COMMON_CU_H_
diff --git a/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
index e81b840a50..15e5de0f72 100644
--- a/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
+++ b/tensorflow/core/kernels/cwise_ops_gpu_gradients.cu.h
@@ -17,8 +17,8 @@ limitations under the License.
#error This file must only be included when building with Cuda support
#endif
-#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
-#define TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
#define EIGEN_USE_GPU
@@ -68,4 +68,4 @@ struct SimpleBinaryFunctor<GPUDevice, Functor> {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
+#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GPU_GRADIENTS_CU_H_
diff --git a/tensorflow/core/kernels/cwise_ops_gradients.h b/tensorflow/core/kernels/cwise_ops_gradients.h
index 7a6f14babc..53b53cc277 100644
--- a/tensorflow/core/kernels/cwise_ops_gradients.h
+++ b/tensorflow/core/kernels/cwise_ops_gradients.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
-#define TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
+#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -208,4 +208,4 @@ struct igamma_grad_a : base<T, Eigen::internal::scalar_igamma_der_a_op<T>> {};
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_CWISE_OPS_GRADIENTS_H_
+#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_GRADIENTS_H_
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 607a694dba..451f8c1a6c 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -51,6 +51,7 @@ cc_library(
hdrs = ["captured_function.h"],
deps = [
":dataset",
+ ":single_threaded_executor",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -61,6 +62,42 @@ cc_library(
)
cc_library(
+ name = "single_threaded_executor",
+ srcs = ["single_threaded_executor.cc"],
+ hdrs = ["single_threaded_executor.h"],
+ deps = [
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:lib",
+ ],
+ alwayslink = 1,
+)
+
+tf_cc_test(
+ name = "single_threaded_executor_test",
+ srcs = ["single_threaded_executor_test.cc"],
+ deps = [
+ ":single_threaded_executor",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:framework_internal",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ "//tensorflow/core/kernels:array",
+ "//tensorflow/core/kernels:control_flow_ops",
+ "//tensorflow/core/kernels:function_ops",
+ "//tensorflow/core/kernels:math",
+ "//tensorflow/core/kernels:random_ops",
+ "//tensorflow/core/kernels:state",
+ ],
+)
+
+cc_library(
name = "window_dataset",
srcs = ["window_dataset.cc"],
hdrs = ["window_dataset.h"],
@@ -233,6 +270,16 @@ cc_library(
)
tf_kernel_library(
+ name = "parse_example_dataset_op",
+ srcs = ["parse_example_dataset_op.cc"],
+ deps = [
+ ":parallel_map_iterator",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ ],
+)
+
+tf_kernel_library(
name = "parallel_map_dataset_op",
srcs = ["parallel_map_dataset_op.cc"],
deps = [
@@ -411,6 +458,7 @@ tf_kernel_library(
srcs = ["stats_aggregator_dataset_op.cc"],
deps = [
":dataset",
+ "//tensorflow/core:core_cpu_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib_internal",
],
@@ -471,8 +519,7 @@ tf_kernel_library(
":dataset",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
+ "//tensorflow/core:graph",
],
)
@@ -495,8 +542,7 @@ tf_kernel_library(
":dataset",
"//tensorflow/core:dataset_ops_op_lib",
"//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:lib_internal",
+ "//tensorflow/core:graph",
],
)
@@ -583,6 +629,20 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "multi_device_iterator_ops",
+ srcs = ["multi_device_iterator_ops.cc"],
+ deps = [
+ ":dataset",
+ ":dataset_utils",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ "//tensorflow/core/kernels:ops_util",
+ ],
+)
+
+tf_kernel_library(
name = "optional_ops",
srcs = ["optional_ops.cc"],
hdrs = ["optional_ops.h"],
@@ -630,6 +690,19 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "model_dataset_op",
+ srcs = ["model_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
name = "dataset_ops",
srcs = ["dataset_ops.cc"],
deps = [
@@ -663,11 +736,14 @@ tf_kernel_library(
":map_and_batch_dataset_op",
":map_dataset_op",
":map_defun_op",
+ ":model_dataset_op",
+ ":multi_device_iterator_ops",
":optimize_dataset_op",
":optional_ops",
":padded_batch_dataset_op",
":parallel_interleave_dataset_op",
":parallel_map_dataset_op",
+ ":parse_example_dataset_op",
":prefetch_dataset_op",
":random_dataset_op",
":range_dataset_op",
@@ -690,6 +766,7 @@ tf_kernel_library(
":window_dataset_op",
":writer_ops",
":zip_dataset_op",
+ "//tensorflow/core/kernels/data/experimental:dataset_kernels",
],
)
diff --git a/tensorflow/core/kernels/data/batch_dataset_op.cc b/tensorflow/core/kernels/data/batch_dataset_op.cc
index 5295c9d2a6..d1db1d7bec 100644
--- a/tensorflow/core/kernels/data/batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/batch_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -49,11 +49,11 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 batch_size, bool drop_remainder,
const DatasetBase* input)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
batch_size_(batch_size),
drop_remainder_(drop_remainder),
input_(input) {
@@ -117,6 +117,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
@@ -241,5 +242,5 @@ REGISTER_KERNEL_BUILDER(Name("BatchDatasetV2").Device(DEVICE_CPU),
BatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/cache_dataset_ops.cc b/tensorflow/core/kernels/data/cache_dataset_ops.cc
index 3762e403a9..34c6c86538 100644
--- a/tensorflow/core/kernels/data/cache_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/cache_dataset_ops.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/util/tensor_bundle/tensor_bundle.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level description of
@@ -46,11 +46,11 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class FileDataset : public GraphDatasetBase {
+ class FileDataset : public DatasetBase {
public:
explicit FileDataset(OpKernelContext* ctx, const DatasetBase* input,
string filename, Env* env)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
filename_(std::move(filename)),
env_(env),
@@ -69,7 +69,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(
- new FileIterator({this, strings::StrCat(prefix, "::FileIterator")}));
+ new FileIterator({this, strings::StrCat(prefix, "::FileCache")}));
}
const DataTypeVector& output_dtypes() const override {
@@ -539,10 +539,12 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
const string tensor_format_string_;
}; // FileDataset
- class MemoryDataset : public GraphDatasetBase {
+ class MemoryDataset : public DatasetBase {
public:
explicit MemoryDataset(OpKernelContext* ctx, const DatasetBase* input)
- : GraphDatasetBase(ctx), input_(input), cache_(new MemoryCache()) {
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ cache_(new MemoryCache()) {
input->Ref();
}
@@ -551,7 +553,7 @@ class CacheDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
return std::unique_ptr<IteratorBase>(new MemoryIterator(
- {this, strings::StrCat(prefix, "::MemoryIterator")}, cache_));
+ {this, strings::StrCat(prefix, "::MemoryCache")}, cache_));
}
const DataTypeVector& output_dtypes() const override {
@@ -889,5 +891,5 @@ REGISTER_KERNEL_BUILDER(Name("CacheDataset").Device(DEVICE_CPU),
CacheDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc
index 82da385405..0bb929b3ce 100644
--- a/tensorflow/core/kernels/data/captured_function.cc
+++ b/tensorflow/core/kernels/data/captured_function.cc
@@ -17,33 +17,101 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/notification.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
+namespace data {
+
+namespace {
+
+// Simplistic implementation of the `StepStatsCollectorInterface` that only
+// cares about collecting the CPU time needed to execute a captured function.
+class SimpleStepStatsCollector : public StepStatsCollectorInterface {
+ public:
+ void IncrementProcessingTime(int64 delta) {
+ mutex_lock l(mu_);
+ processing_time_ += delta;
+ }
+
+ NodeExecStatsInterface* CreateNodeExecStats(const Node* node) override {
+ return new SimpleNodeExecStats(this);
+ }
+
+ string ReportAllocsOnResourceExhausted(const string& err) override {
+ return "";
+ }
+
+ int64 processing_time() {
+ tf_shared_lock l(mu_);
+ return processing_time_;
+ }
+
+ private:
+ class SimpleNodeExecStats : public NodeExecStatsInterface {
+ public:
+ explicit SimpleNodeExecStats(SimpleStepStatsCollector* step_stats_collector)
+ : step_stats_collector_(step_stats_collector) {}
+
+ void Done(const string& device) override {
+ step_stats_collector_->IncrementProcessingTime(end_time_ns_ -
+ start_time_ns_);
+ delete this;
+ }
+
+ void RecordExecutorStarted() override {
+ start_time_ns_ = Env::Default()->NowNanos();
+ }
+
+ void RecordComputeStarted() override {}
+
+ void RecordComputeEnded() override {}
+
+ void RecordExecutorEnded() override {
+ end_time_ns_ = Env::Default()->NowNanos();
+ }
+
+ void SetMemory(OpKernelContext* ctx) override {}
+
+ void SetOutput(int slot, const Tensor* tensor) override {}
+
+ void SetReferencedTensors(const TensorReferenceVector& tensors) override {}
+
+ void SetScheduled(int64 nanos) override {}
+
+ private:
+ int64 start_time_ns_ = 0;
+ int64 end_time_ns_ = 0;
+ SimpleStepStatsCollector* step_stats_collector_; // Not owned.
+ };
+
+ mutex mu_;
+ int64 processing_time_ GUARDED_BY(mu_) = 0;
+};
+
+} // namespace
/* static */
Status CapturedFunction::Create(
- const NameAttrList& func, std::vector<Tensor> captured_inputs,
+ const NameAttrList& func, OpKernelContext* ctx, const string& argument,
std::unique_ptr<CapturedFunction>* out_function) {
- out_function->reset(new CapturedFunction(func, std::move(captured_inputs)));
- return Status::OK();
+ return CapturedFunction::Create(func, ctx, argument, true, out_function);
}
-/* static */
Status CapturedFunction::Create(
const NameAttrList& func, OpKernelContext* ctx, const string& argument,
+ bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction>* out_function) {
- OpInputList argument_inputs;
- TF_RETURN_IF_ERROR(ctx->input_list(argument, &argument_inputs));
- std::vector<Tensor> arguments_t;
- arguments_t.reserve(argument_inputs.size());
- for (const Tensor& t : argument_inputs) {
- arguments_t.push_back(t);
- }
- return CapturedFunction::Create(func, std::move(arguments_t), out_function);
+ OpInputList inputs;
+ TF_RETURN_IF_ERROR(ctx->input_list(argument, &inputs));
+ std::vector<Tensor> arguments(inputs.begin(), inputs.end());
+ *out_function = WrapUnique(new CapturedFunction(func, std::move(arguments),
+ use_inter_op_parallelism));
+ return Status::OK();
}
CapturedFunction::~CapturedFunction() {
@@ -172,31 +240,17 @@ class BorrowedArgsCallFrame : public CallFrameBase {
} // namespace
-Status CapturedFunction::MaybeInstantiate(
- IteratorContext* ctx, FunctionLibraryRuntime::Handle* out_handle) {
- mutex_lock l(mu_);
+Status CapturedFunction::GetHandle(IteratorContext* ctx,
+ FunctionLibraryRuntime::Handle* out_handle) {
+ tf_shared_lock l(mu_);
if (lib_ == nullptr) {
- // The context's runtime will be used for all subsequent calls.
- lib_ = ctx->lib();
- DCHECK(f_handle_ == kInvalidHandle);
- FunctionLibraryRuntime::InstantiateOptions inst_opts;
- inst_opts.overlay_lib = ctx->function_library().get();
- inst_opts.state_handle = std::to_string(random::New64());
- TF_RETURN_IF_ERROR(lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
- inst_opts, &f_handle_));
- const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
- if (fbody == nullptr) {
- return errors::Internal("Failed to instantiate function body.");
- }
- ret_types_ = fbody->ret_types;
- } else {
- // TODO(mrry): Consider moving this under a shared lock, as it is
- // the common case.
- if (ctx->lib() != lib_) {
- return errors::Internal(
- "Captured function was called with a different "
- "FunctionLibraryRuntime*, which is not permitted.");
- }
+ return errors::Internal("Captured function \"", func_.name(),
+ "\" was called before it was instantiated.");
+ }
+ if (ctx->lib() != lib_) {
+ return errors::Internal("Captured function \"", func_.name(),
+ "\" was called with a different "
+ "FunctionLibraryRuntime*, which is not permitted.");
}
*out_handle = f_handle_;
return Status::OK();
@@ -205,7 +259,7 @@ Status CapturedFunction::MaybeInstantiate(
Status CapturedFunction::Run(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets) {
FunctionLibraryRuntime::Handle handle;
- TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
+ TF_RETURN_IF_ERROR(GetHandle(ctx, &handle));
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
@@ -242,7 +296,7 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
const std::vector<Tensor>& args,
std::vector<Tensor>* rets) {
FunctionLibraryRuntime::Handle handle;
- TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &handle));
+ TF_RETURN_IF_ERROR(GetHandle(ctx, &handle));
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
@@ -277,9 +331,33 @@ Status CapturedFunction::RunWithBorrowedArgs(IteratorContext* ctx,
}
Status CapturedFunction::Instantiate(IteratorContext* ctx) {
- FunctionLibraryRuntime::Handle unused_handle;
- TF_RETURN_IF_ERROR(MaybeInstantiate(ctx, &unused_handle));
mutex_lock l(mu_);
+ if (lib_ == nullptr) {
+ // The context's runtime will be used for all subsequent calls.
+ lib_ = ctx->lib();
+ DCHECK(f_handle_ == kInvalidHandle);
+ FunctionLibraryRuntime::InstantiateOptions inst_opts;
+ inst_opts.overlay_lib = ctx->function_library().get();
+ inst_opts.state_handle = std::to_string(random::New64());
+ inst_opts.create_kernels_eagerly = true;
+ if (!use_inter_op_parallelism_) {
+ inst_opts.executor_type = "SINGLE_THREADED_EXECUTOR";
+ }
+ Status s = (lib_->Instantiate(func_.name(), AttrSlice(&func_.attr()),
+ inst_opts, &f_handle_));
+ TF_RETURN_IF_ERROR(s);
+ const FunctionBody* fbody = lib_->GetFunctionBody(f_handle_);
+ if (fbody == nullptr) {
+ return errors::Internal("Failed to instantiate function body.");
+ }
+ ret_types_ = fbody->ret_types;
+ } else {
+ if (ctx->lib() != lib_) {
+ return errors::Internal(
+ "Captured function was called with a different "
+ "FunctionLibraryRuntime*, which is not permitted.");
+ }
+ }
if (captured_runner_ == nullptr) {
captured_runner_ = *ctx->runner();
}
@@ -338,23 +416,24 @@ Status CapturedFunction::RunInstantiated(const std::vector<Tensor>& args,
void CapturedFunction::RunAsync(IteratorContext* ctx,
std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done) {
+ FunctionLibraryRuntime::DoneCallback done,
+ const string& prefix) {
// NOTE(mrry): This method does not transfer ownership of `ctx`, and it may
// be deleted before `done` is called. Take care not to capture `ctx` in any
// code that may execute asynchronously in this function.
FunctionLibraryRuntime::Handle handle;
- Status s = MaybeInstantiate(ctx, &handle);
+ Status s = GetHandle(ctx, &handle);
if (!s.ok()) {
done(s);
return;
}
- auto frame =
+ OwnedArgsCallFrame* frame =
new OwnedArgsCallFrame(std::move(args), &captured_inputs_, ret_types_);
FunctionLibraryRuntime::Options f_opts;
f_opts.step_id = CapturedFunction::generate_step_id();
ResourceMgr* resource_mgr = ctx->lib()->device()->resource_manager();
- auto step_container = new ScopedStepContainer(
+ ScopedStepContainer* step_container = new ScopedStepContainer(
f_opts.step_id, [resource_mgr](const string& name) {
resource_mgr->Cleanup(name).IgnoreError();
});
@@ -369,32 +448,50 @@ void CapturedFunction::RunAsync(IteratorContext* ctx,
// (such as queue kernels) that depend on the non-nullness of
// `OpKernelContext::cancellation_manager()`, but additional effort
// will be required to plumb it through the `IteratorContext`.
- auto c_mgr = new CancellationManager;
+ CancellationManager* c_mgr = new CancellationManager;
f_opts.cancellation_manager = c_mgr;
-
- tf_shared_lock l(mu_);
- ctx->lib()->Run(f_opts, handle, frame,
- std::bind(
- [rets, step_container, c_mgr, frame](
- FunctionLibraryRuntime::DoneCallback done,
- // Begin unbound arguments.
- Status s) {
- delete step_container;
- delete c_mgr;
- if (s.ok()) {
- s = frame->ConsumeRetvals(rets);
- }
- delete frame;
- done(s);
- },
- std::move(done), std::placeholders::_1));
+ std::shared_ptr<SimpleStepStatsCollector> stats_collector;
+ if (ctx->model()) {
+ stats_collector = MakeUnique<SimpleStepStatsCollector>();
+ }
+ f_opts.stats_collector = stats_collector.get();
+
+ auto callback = std::bind(
+ [rets, step_container, c_mgr, frame](
+ const FunctionLibraryRuntime::DoneCallback& done,
+ const std::shared_ptr<model::Model>& model, const string& prefix,
+ const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
+ // Begin unbound arguments.
+ Status s) {
+ delete step_container;
+ delete c_mgr;
+ if (s.ok()) {
+ s = frame->ConsumeRetvals(rets);
+ }
+ delete frame;
+ if (model) {
+ model->AddProcessingTime(prefix, stats_collector->processing_time());
+ model->RecordStart(prefix, false /* stop_output */);
+ }
+ done(s);
+ if (model) {
+ model->RecordStop(prefix, false /* start_output */);
+ }
+ },
+ std::move(done), ctx->model(), prefix, std::move(stats_collector),
+ std::placeholders::_1);
+
+ ctx->lib()->Run(f_opts, handle, frame, std::move(callback));
}
CapturedFunction::CapturedFunction(const NameAttrList& func,
- std::vector<Tensor> captured_inputs)
+ std::vector<Tensor> captured_inputs,
+ bool use_inter_op_parallelism)
: func_(func),
lib_(nullptr),
f_handle_(kInvalidHandle),
- captured_inputs_(std::move(captured_inputs)) {}
+ captured_inputs_(std::move(captured_inputs)),
+ use_inter_op_parallelism_(use_inter_op_parallelism) {}
+} // namespace data
} // namespace tensorflow
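The collector added above records only the wall-clock time between executor start and end for each node and accumulates the total under a mutex. A hedged sketch of that timing pattern as an RAII helper; the helper class is hypothetical, while Env::Default()->NowNanos() and IncrementProcessingTime are the calls used in the patch:

    // Hypothetical RAII wrapper around the pattern used by SimpleNodeExecStats.
    class TimingScope {
     public:
      explicit TimingScope(SimpleStepStatsCollector* collector)
          : collector_(collector), start_ns_(Env::Default()->NowNanos()) {}
      ~TimingScope() {
        collector_->IncrementProcessingTime(Env::Default()->NowNanos() -
                                            start_ns_);
      }

     private:
      SimpleStepStatsCollector* const collector_;  // Not owned.
      const int64 start_ns_;
    };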
diff --git a/tensorflow/core/kernels/data/captured_function.h b/tensorflow/core/kernels/data/captured_function.h
index e9ad3e381d..a10376bf97 100644
--- a/tensorflow/core/kernels/data/captured_function.h
+++ b/tensorflow/core/kernels/data/captured_function.h
@@ -32,6 +32,8 @@ class Device;
class OpKernelContext;
class ResourceMgr;
+namespace data {
+
// A `CapturedFunction` encapsulates a TensorFlow function and all of
// the runtime support required to execute it.
//
@@ -40,18 +42,19 @@ class ResourceMgr;
// context.
class CapturedFunction {
public:
- // Creates a new instance from a list of named attributes and captured inputs.
- //
- // NOTE(mrry): The `captured_inputs` are passed by value. For
- // efficiency, you are recommended to move this argument into the call.
- static Status Create(const NameAttrList& func,
- std::vector<Tensor> captured_inputs,
+ // Creates a new instance using a list of named attributes, fetching captured
+ // inputs from a context argument.
+ static Status Create(const NameAttrList& func, OpKernelContext* ctx,
+ const string& argument,
std::unique_ptr<CapturedFunction>* out_function);
// Creates a new instance using a list of named attributes, fetching captured
// inputs from a context argument.
+ //
+ // If `use_inter_op_parallelism` is false, the runtime may use an executor
+ // that is optimized for small functions.
static Status Create(const NameAttrList& func, OpKernelContext* ctx,
- const string& argument,
+ const string& argument, bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction>* out_function);
~CapturedFunction();
@@ -93,7 +96,8 @@ class CapturedFunction {
// in order to be able to deallocate them as early as possible.
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
- FunctionLibraryRuntime::DoneCallback done);
+ FunctionLibraryRuntime::DoneCallback done,
+ const string& prefix);
// Returns the named list of function arguments.
const NameAttrList& func() { return func_; }
@@ -114,10 +118,11 @@ class CapturedFunction {
private:
CapturedFunction(const NameAttrList& func,
- std::vector<Tensor> captured_inputs);
+ std::vector<Tensor> captured_inputs,
+ bool use_inter_op_parallelism);
- Status MaybeInstantiate(IteratorContext* ctx,
- FunctionLibraryRuntime::Handle* out_handle);
+ Status GetHandle(IteratorContext* ctx,
+ FunctionLibraryRuntime::Handle* out_handle);
mutex mu_;
const NameAttrList func_;
@@ -126,10 +131,17 @@ class CapturedFunction {
const std::vector<Tensor> captured_inputs_;
DataTypeSlice ret_types_;
std::function<void(std::function<void()>)> captured_runner_ = nullptr;
+ const bool use_inter_op_parallelism_;
TF_DISALLOW_COPY_AND_ASSIGN(CapturedFunction);
};
+} // namespace data
+
+// TODO(b/114112161): Remove these aliases when all users have moved over to the
+// `tensorflow::data` namespace.
+using data::CapturedFunction;
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_CAPTURED_FUNCTION_H_
diff --git a/tensorflow/core/kernels/data/concatenate_dataset_op.cc b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
index 6393005cdc..9607e9444c 100644
--- a/tensorflow/core/kernels/data/concatenate_dataset_op.cc
+++ b/tensorflow/core/kernels/data/concatenate_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -39,11 +39,11 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
const DatasetBase* to_concatenate)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
to_concatenate_(to_concatenate) {
input_->Ref();
@@ -171,16 +171,16 @@ class ConcatenateDatasetOp : public BinaryDatasetOpKernel {
static PartialTensorShape MostSpecificCompatibleShape(
const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
- PartialTensorShape output_tensorshape;
if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
- return output_tensorshape;
+ return PartialTensorShape();
+ PartialTensorShape output_tensorshape({});
auto dims1 = ts1.dim_sizes();
auto dims2 = ts2.dim_sizes();
for (int d = 0; d < ts1.dims(); d++) {
if (dims1[d] == dims2[d])
- output_tensorshape.Concatenate(dims1[d]);
+ output_tensorshape.AddDim(dims1[d]);
else
- output_tensorshape.Concatenate(-1);
+ output_tensorshape.AddDim(-1);
}
return output_tensorshape;
}
@@ -195,5 +195,5 @@ REGISTER_KERNEL_BUILDER(Name("ConcatenateDataset").Device(DEVICE_CPU),
ConcatenateDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
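
For reference, a self-contained analog of the MostSpecificCompatibleShape change above, using plain vectors of dimension sizes with -1 standing in for an unknown dimension instead of the PartialTensorShape API; this is a sketch of the merging rule only, not the TensorFlow implementation.

#include <cstdint>
#include <iostream>
#include <vector>

std::vector<int64_t> MostSpecificCompatibleShape(
    const std::vector<int64_t>& a, const std::vector<int64_t>& b) {
  if (a.size() != b.size()) return {};  // rank mismatch -> unknown rank
  std::vector<int64_t> out;
  out.reserve(a.size());
  for (size_t d = 0; d < a.size(); ++d) {
    out.push_back(a[d] == b[d] ? a[d] : -1);  // keep matching dims, else unknown
  }
  return out;
}

int main() {
  // Concatenating datasets with shapes [4, 3] and [4, 5] yields [4, -1]:
  // the first dimension agrees, the second does not.
  for (int64_t d : MostSpecificCompatibleShape({4, 3}, {4, 5})) {
    std::cout << d << " ";
  }
  std::cout << "\n";  // prints: 4 -1
  return 0;
}
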
diff --git a/tensorflow/core/kernels/data/dataset_ops.cc b/tensorflow/core/kernels/data/dataset_ops.cc
index c71d027f23..bd1ccd5b5d 100644
--- a/tensorflow/core/kernels/data/dataset_ops.cc
+++ b/tensorflow/core/kernels/data/dataset_ops.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
+namespace data {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
@@ -48,4 +49,5 @@ class DatasetToGraphOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("DatasetToGraph").Device(DEVICE_CPU),
DatasetToGraphOp);
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.cc b/tensorflow/core/kernels/data/dataset_utils.cc
index d85ef1cbab..e10833f525 100644
--- a/tensorflow/core/kernels/data/dataset_utils.cc
+++ b/tensorflow/core/kernels/data/dataset_utils.cc
@@ -17,8 +17,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
namespace tensorflow {
-
-namespace dataset {
+namespace data {
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
@@ -45,6 +44,42 @@ Status MakeIteratorFromInputElement(
ctx, strings::StrCat(prefix, "[", thread_index, "]"), out_iterator);
}
-} // namespace dataset
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " types but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (expected[i] != received[i]) {
+ return errors::InvalidArgument("Data type mismatch at component ", i,
+ ": expected ", DataTypeString(expected[i]),
+ " but got ", DataTypeString(received[i]),
+ ".");
+ }
+ }
+ return Status::OK();
+}
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " shapes but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (!expected[i].IsCompatibleWith(received[i])) {
+ return errors::InvalidArgument("Incompatible shapes at component ", i,
+ ": expected ", expected[i].DebugString(),
+ " but got ", received[i].DebugString(),
+ ".");
+ }
+ }
+
+ return Status::OK();
+}
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/dataset_utils.h b/tensorflow/core/kernels/data/dataset_utils.h
index 6c4191c2be..6ec1350cd4 100644
--- a/tensorflow/core/kernels/data/dataset_utils.h
+++ b/tensorflow/core/kernels/data/dataset_utils.h
@@ -20,16 +20,24 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
-namespace dataset {
+namespace data {
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, CapturedFunction* captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator);
-} // namespace dataset
+// Returns Status::OK() if `expected` and `received` types match,
+// errors::InvalidArgument otherwise.
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received);
+
+// Returns Status::OK() if `expected` and `received` shapes are compatible,
+// errors::InvalidArgument otherwise.
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received);
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_DATASET_UTILS_H_
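
A stand-alone sketch of the dimension-compatibility rule that VerifyShapesCompatible relies on: two dimensions are compatible when they are equal or when either is unknown. Plain int64_t vectors with -1 for unknown dimensions stand in for PartialTensorShape, and unknown rank is not modeled here; this is illustrative, not the TensorFlow API.

#include <cstdint>
#include <iostream>
#include <vector>

bool DimsCompatible(const std::vector<int64_t>& expected,
                    const std::vector<int64_t>& received) {
  if (expected.size() != received.size()) return false;  // rank mismatch
  for (size_t d = 0; d < expected.size(); ++d) {
    if (expected[d] != -1 && received[d] != -1 && expected[d] != received[d]) {
      return false;  // both dims known and different
    }
  }
  return true;
}

int main() {
  std::cout << DimsCompatible({4, -1}, {4, 3}) << "\n";  // 1: unknown matches 3
  std::cout << DimsCompatible({4, 2}, {4, 3}) << "\n";   // 0: 2 vs 3 mismatch
  return 0;
}
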
diff --git a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
index 9105587cf4..237511a07d 100644
--- a/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/dense_to_sparse_batch_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -76,11 +76,11 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
private:
// TODO(mrry): Push the templated code down to the raw copying routine.
template <class T>
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 batch_size,
const PartialTensorShape& row_shape, const DatasetBase* input)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
batch_size_(batch_size),
row_shape_(row_shape),
input_(input) {
@@ -301,5 +301,5 @@ REGISTER_KERNEL_BUILDER(Name("DenseToSparseBatchDataset").Device(DEVICE_CPU),
DenseToSparseBatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
new file mode 100644
index 0000000000..43406db3ed
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -0,0 +1,139 @@
+# Description:
+# Contains experimental kernels for datasets and iterators.
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_kernel_library",
+)
+
+cc_library(
+ name = "indexed_dataset_headers",
+ hdrs = ["indexed_dataset.h"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "indexed_dataset",
+ srcs = [
+ "identity_indexed_dataset.cc",
+ "indexed_dataset.cc",
+ ],
+ deps = [
+ ":indexed_dataset_headers",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "prefetching_kernels",
+ srcs = ["prefetching_kernels.cc"],
+ deps = [
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
+ name = "directed_interleave_dataset_op",
+ srcs = ["directed_interleave_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "csv_dataset_op",
+ srcs = ["csv_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
+ name = "ignore_errors_dataset_op",
+ srcs = ["ignore_errors_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "lmdb_dataset_op",
+ srcs = ["lmdb_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ "@lmdb",
+ ],
+)
+
+tf_kernel_library(
+ name = "threadpool_dataset_op",
+ srcs = ["threadpool_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "unique_dataset_op",
+ srcs = ["unique_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "assert_next_dataset_op",
+ srcs = ["assert_next_dataset_op.cc"],
+ deps = [
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//third_party/eigen3",
+ ],
+)
+
+tf_kernel_library(
+ name = "dataset_kernels",
+ deps = [
+ ":assert_next_dataset_op",
+ ":csv_dataset_op",
+ ":directed_interleave_dataset_op",
+ ":ignore_errors_dataset_op",
+ ":indexed_dataset",
+ ":lmdb_dataset_op",
+ ":prefetching_kernels",
+ ":threadpool_dataset_op",
+ ":unique_dataset_op",
+ ],
+)
diff --git a/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
new file mode 100644
index 0000000000..3511cca0f5
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
@@ -0,0 +1,156 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <map>
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+class AssertNextDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit AssertNextDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ std::vector<string> transformations;
+ OP_REQUIRES_OK(ctx, ParseVectorArgument<string>(ctx, "transformations",
+ &transformations));
+ *output =
+ new Dataset(ctx, input, transformations, output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const std::vector<string>& transformations,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ transformations_(transformations),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Assert")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "AssertNextDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* transformations_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(transformations_, &transformations_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, transformations_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ std::vector<string> tokens =
+ str_util::Split(prefix(), ':', str_util::SkipEmpty());
+ if (dataset()->transformations_.size() > tokens.size() - 2) {
+ return errors::InvalidArgument(
+ "Asserted next ", dataset()->transformations_.size(),
+ " transformations but encountered only ", tokens.size() - 2, ".");
+ }
+ int n = tokens.size();
+ for (size_t i = 0; i < dataset()->transformations_.size(); ++i) {
+ if (dataset()->transformations_[i] != tokens[n - 2 - i]) {
+ return errors::InvalidArgument(
+ "Asserted ", dataset()->transformations_[i],
+ " transformation at offset ", i, " but encountered ",
+ tokens[n - 2 - i], " transformation instead.");
+ }
+ }
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ std::unique_ptr<IteratorBase> input_impl_;
+ };
+
+ const DatasetBase* input_;
+ const std::vector<string> transformations_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalAssertNextDataset").Device(DEVICE_CPU),
+ AssertNextDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
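
A stand-alone sketch of the prefix check performed in Iterator::Initialize above: the iterator prefix encodes the chain of downstream transformations (for example a prefix of the assumed form "Iterator::Prefetch::Map::Assert"), and each asserted name is compared against the tokens that precede the trailing "Assert", walking outward. The concrete prefix and the std::getline-based splitting below are assumptions for illustration, standing in for str_util::Split.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

std::vector<std::string> SplitNonEmpty(const std::string& s, char sep) {
  std::vector<std::string> tokens;
  std::stringstream ss(s);
  std::string tok;
  while (std::getline(ss, tok, sep)) {
    if (!tok.empty()) tokens.push_back(tok);
  }
  return tokens;
}

int main() {
  const std::string prefix = "Iterator::Prefetch::Map::Assert";
  const std::vector<std::string> asserted = {"Map", "Prefetch"};

  std::vector<std::string> tokens = SplitNonEmpty(prefix, ':');
  const int n = static_cast<int>(tokens.size());
  bool ok = n >= 2 && asserted.size() <= static_cast<size_t>(n - 2);
  for (size_t i = 0; ok && i < asserted.size(); ++i) {
    ok = (asserted[i] == tokens[n - 2 - i]);  // walk outward from "Assert"
  }
  std::cout << (ok ? "assertion holds" : "assertion fails") << "\n";
  return 0;
}
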
diff --git a/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
new file mode 100644
index 0000000000..7451ca4cb1
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
@@ -0,0 +1,860 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/parsing_ops.cc.
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/io/inputstream_interface.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
+#include "tensorflow/core/lib/io/zlib_compression_options.h"
+#include "tensorflow/core/lib/io/zlib_inputstream.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class CSVDatasetOp : public DatasetOpKernel {
+ public:
+ explicit CSVDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ const Tensor* filenames_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
+ OP_REQUIRES(
+ ctx, filenames_tensor->dims() <= 1,
+ errors::InvalidArgument("`filenames` must be a scalar or a vector."));
+
+ string compression_type;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "compression_type",
+ &compression_type));
+
+ OpInputList record_defaults_list;
+ OP_REQUIRES_OK(ctx,
+ ctx->input_list("record_defaults", &record_defaults_list));
+ for (int i = 0; i < record_defaults_list.size(); ++i) {
+ OP_REQUIRES(ctx, record_defaults_list[i].dims() <= 1,
+ errors::InvalidArgument(
+ "Each record default should be at most rank 1"));
+ OP_REQUIRES(ctx, record_defaults_list[i].NumElements() < 2,
+ errors::InvalidArgument(
+ "There should only be 1 default per field but field ", i,
+ " has ", record_defaults_list[i].NumElements()));
+ }
+
+ const Tensor* select_cols_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("select_cols", &select_cols_tensor));
+ OP_REQUIRES(ctx, select_cols_tensor->dims() == 1,
+ errors::InvalidArgument("`select_cols` must be a vector."));
+
+ int64 buffer_size;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
+ OP_REQUIRES(ctx, buffer_size > 0,
+ errors::InvalidArgument("buffer_size should be positive"));
+
+ string delim;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "field_delim", &delim));
+ OP_REQUIRES(ctx, delim.size() == 1,
+ errors::InvalidArgument("field_delim should be only 1 char"));
+
+ bool header;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "header", &header));
+
+ bool use_quote_delim;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "use_quote_delim",
+ &use_quote_delim));
+ string na_value;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<string>(ctx, "na_value", &na_value));
+
+ std::vector<Tensor> record_defaults;
+ record_defaults.reserve(record_defaults_list.size());
+ for (const Tensor& t : record_defaults_list) {
+ record_defaults.push_back(t);
+ }
+
+ std::vector<string> filenames;
+ filenames.reserve(filenames_tensor->NumElements());
+ for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
+ filenames.push_back(filenames_tensor->flat<string>()(i));
+ }
+
+ io::ZlibCompressionOptions zlib_compression_options =
+ io::ZlibCompressionOptions::DEFAULT();
+ if (compression_type == "ZLIB") {
+ zlib_compression_options = io::ZlibCompressionOptions::DEFAULT();
+ } else if (compression_type == "GZIP") {
+ zlib_compression_options = io::ZlibCompressionOptions::GZIP();
+ } else {
+ OP_REQUIRES(ctx, compression_type.empty(),
+ errors::InvalidArgument(
+ "Unsupported compression_type: ", compression_type, "."));
+ }
+ zlib_compression_options.input_buffer_size = buffer_size;
+
+ std::vector<int64> select_cols;
+ select_cols.reserve(select_cols_tensor->NumElements());
+ for (int i = 0; i < select_cols_tensor->NumElements(); ++i) {
+ select_cols.push_back(select_cols_tensor->flat<int64>()(i));
+ }
+ OP_REQUIRES(
+ ctx, output_types_.size() == select_cols.size() || select_cols.empty(),
+ errors::InvalidArgument("select_cols should match output size"));
+ for (int i = 1; i < select_cols.size(); i++) {
+ OP_REQUIRES(ctx, select_cols[i - 1] < select_cols[i],
+ errors::InvalidArgument(
+ "select_cols should be strictly increasing indices"));
+ }
+ OP_REQUIRES(
+ ctx, select_cols.empty() || select_cols.front() >= 0,
+ errors::InvalidArgument("select_cols should be non-negative indices"));
+
+ *output = new Dataset(ctx, std::move(filenames), header,
+ std::move(compression_type), zlib_compression_options,
+ output_types_, output_shapes_,
+ std::move(record_defaults), std::move(select_cols),
+ use_quote_delim, delim[0], std::move(na_value));
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, std::vector<string> filenames, bool header,
+ string compression_type, io::ZlibCompressionOptions options,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes,
+ std::vector<Tensor> record_defaults, std::vector<int64> select_cols,
+ bool use_quote_delim, char delim, string na_value)
+ : DatasetBase(DatasetContext(ctx)),
+ filenames_(std::move(filenames)),
+ header_(header),
+ out_type_(output_types),
+ output_shapes_(output_shapes),
+ record_defaults_(std::move(record_defaults)),
+ select_cols_(std::move(select_cols)),
+ use_quote_delim_(use_quote_delim),
+ delim_(delim),
+ na_value_(std::move(na_value)),
+ use_compression_(!compression_type.empty()),
+ compression_type_(std::move(compression_type)),
+ options_(options) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::CSV")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override { return out_type_; }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override { return "CSVDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* filenames = nullptr;
+ Node* compression_type = nullptr;
+ Node* buffer_size = nullptr;
+ Node* header = nullptr;
+ Node* delim = nullptr;
+ Node* use_quote_delim = nullptr;
+ Node* na_value = nullptr;
+ Node* select_cols = nullptr;
+
+ std::vector<Node*> record_defaults;
+ record_defaults.reserve(record_defaults_.size());
+ for (const Tensor& t : record_defaults_) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ record_defaults.emplace_back(node);
+ }
+
+ TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+ TF_RETURN_IF_ERROR(b->AddScalar(compression_type_, &compression_type));
+ TF_RETURN_IF_ERROR(
+ b->AddScalar(options_.input_buffer_size, &buffer_size));
+ TF_RETURN_IF_ERROR(b->AddScalar(header_, &header));
+
+ string delim_string(1, delim_);
+ TF_RETURN_IF_ERROR(b->AddScalar(delim_string, &delim));
+ TF_RETURN_IF_ERROR(b->AddScalar(use_quote_delim_, &use_quote_delim));
+ TF_RETURN_IF_ERROR(b->AddScalar(na_value_, &na_value));
+ TF_RETURN_IF_ERROR(b->AddVector(select_cols_, &select_cols));
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this,
+ {std::make_pair(0, filenames), std::make_pair(1, compression_type),
+ std::make_pair(2, buffer_size), std::make_pair(3, header),
+ std::make_pair(4, delim), std::make_pair(5, use_quote_delim),
+ std::make_pair(6, na_value),
+ std::make_pair(7, select_cols)}, // Single tensor inputs
+ {std::make_pair(8, record_defaults)}, // Tensor list inputs
+ {}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ bool select_all = dataset()->select_cols_.empty();
+ do {
+ // We are currently processing a file, so try to read the next record
+ if (input_stream_) {
+ Status s = ReadRecord(ctx, out_tensors, select_all,
+ dataset()->select_cols_);
+ if (s.ok()) {
+ // Validate output
+ if (out_tensors->size() != dataset()->out_type_.size()) {
+ return errors::InvalidArgument(
+ "Expect ", dataset()->out_type_.size(), " fields but have ",
+ out_tensors->size(), " in record");
+ }
+
+ *end_of_sequence = false;
+ return s;
+ }
+ if (!errors::IsOutOfRange(s)) {
+ // Not at the end of file, return OK or non-EOF errors to caller.
+ *end_of_sequence = false;
+ return s;
+ }
+ // We have reached the end of the current file, so maybe
+ // move on to next file.
+ ResetStreamsLocked();
+ ++current_file_index_;
+ }
+ // Iteration ends when there are no more files to process.
+ if (current_file_index_ == dataset()->filenames_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
+ current_file_index_));
+ // `input_stream_` is empty if
+ // 1. GetNext has not been called even once.
+ // 2. All files have been read and the iterator has been exhausted.
+ if (input_stream_ && num_buffer_reads_ > 0) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("pos"), pos_));
+ // If num_buffer_reads_ == 0, the buffer hasn't been filled even once.
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_buffer_reads"),
+ num_buffer_reads_));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ ResetStreamsLocked();
+ int64 current_file_index;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("current_file_index"),
+ &current_file_index));
+ current_file_index_ = size_t(current_file_index);
+ // The keys "pos" and "num_buffer_reads" are written only if
+ // the iterator was saved with an open, partially read file.
+ if (reader->Contains(full_name("pos"))) {
+ int64 pos, num_buffer_reads;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("pos"), &pos));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_buffer_reads"),
+ &num_buffer_reads));
+
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+
+ num_buffer_reads_ = size_t(num_buffer_reads - 1);
+
+ // Restores the most recently held buffer
+ Status s = input_stream_->SkipNBytes(
+ num_buffer_reads_ * dataset()->options_.input_buffer_size);
+ if (!s.ok() && !errors::IsOutOfRange(s)) {
+ // We might get out of range error here if the size of the file
+ // is not an exact multiple of the buffer size, and the last buffer
+ // read is < buffer_size. This is valid and we do not surface the
+ // error.
+ return s;
+ }
+
+ Status s2 = FillBuffer(&buffer_);
+ if (!s2.ok() && !errors::IsOutOfRange(s2)) {
+ return s2;
+ }
+ pos_ = size_t(pos);
+ }
+ return Status::OK();
+ }
+
+ private:
+ // Reads an entire CSV row from the input stream, either from the
+ // existing buffer or by filling the buffer as needed. Converts extracted
+ // fields to output tensors as we go.
+ //
+ // When this function is called, pos_ should be the index of the first
+ // character of the record in buffer_, or past the end of the buffer.
+ // Note: ctx and out_tensors are only used in this function
+ // when fields are included in the record.
+ Status ReadRecord(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+ bool select_all, const std::vector<int64>& selected)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (pos_ >= buffer_.size()) {
+ // At the end of the file, this will return errors::OutOfRange
+ TF_RETURN_IF_ERROR(FillBuffer(&buffer_));
+ pos_ = 0;
+ }
+
+ // The first character may be \n if this is the continuation of a
+ // \r\n linebreak between this and the previous record. If so, skip it.
+
+ bool end_of_record = false; // Keep track of when we find \n, \r or EOF
+ size_t num_parsed = 0;
+ size_t num_selected_parsed = 0;
+
+ Status result;
+
+ while (!end_of_record) { // Read till we reach \n, \r or EOF
+ bool include =
+ select_all || (num_selected_parsed < selected.size() &&
+ selected[num_selected_parsed] == num_parsed);
+
+ // Don't fail fast, so that the next call to GetNext may still return
+ // a valid record
+ result.Update(
+ ParseOneField(ctx, out_tensors, &end_of_record, include));
+
+ num_parsed++;
+ if (include) num_selected_parsed++;
+ }
+
+ return result;
+ }
+
+ // Parses one field from position pos_ in the buffer. Fields are
+ // delimited by delim, CRLF, or EOF. Advances pos_ to the first char of
+ // the next field.
+ Status ParseOneField(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_record, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (pos_ >= buffer_.size()) {
+ // If we get here, this means the previous field's end coincided
+          // with the end of the buffer, so we can refill the buffer without
+          // losing any partially read field data.
+ Status s = FillBuffer(&buffer_);
+
+ if (errors::IsOutOfRange(s)) {
+ // Reached EOF, and last field is empty
+ *end_of_record = true;
+ if (include) {
+ return FieldToOutput(ctx, StringPiece(), out_tensors);
+ } else {
+ return Status::OK();
+ }
+ } else if (!s.ok()) {
+ return s; // Surface other errors back to caller
+ }
+
+ pos_ = 0;
+ }
+
+ if (dataset()->use_quote_delim_ && buffer_[pos_] == '"') {
+ return ParseQuotedField(ctx, out_tensors, end_of_record, include);
+ }
+
+ return ParseUnquotedField(ctx, out_tensors, end_of_record, include);
+ }
+
+ // For keeping track of relevant parts of a field from a previous buffer
+ struct Piece {
+ size_t start;
+ size_t len;
+ string buffer;
+
+ Piece(string buffer, size_t start, size_t len)
+ : start(start), len(len), buffer(std::move(buffer)) {}
+ };
+
+ // Given that pos_ exceeds the buffer, saves the relevant part of the
+ // current buffer (if necessary), fills the buffer, and resets indices to
+ // 0.
+ Status SaveAndFillBuffer(std::vector<Piece>* earlier_pieces,
+ size_t* start, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ string temp_buffer;
+
+ buffer_.swap(temp_buffer);
+ if (include && pos_ > *start) {
+ earlier_pieces->push_back(
+ Piece(std::move(temp_buffer), *start, pos_ - *start));
+ }
+ pos_ = 0;
+ *start = 0;
+ return FillBuffer(&buffer_);
+ }
+
+      // Parses a quoted field from position pos_ in the buffer. Continually
+      // reads from the buffer until the closing quote and the end of the
+      // field are reached (delim, CRLF, or EOF).
+ // Advances pos_ to keep track of our position in the buffer as we go,
+ // stopping at the first character of the next field.
+ Status ParseQuotedField(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_record, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ std::vector<Piece> earlier_pieces;
+ size_t start = pos_;
+ pos_++; // Starting quotation mark
+
+ Status parse_result;
+ while (true) { // Each iter reads 1 char, filling buffer if necessary
+ if (pos_ >= buffer_.size()) {
+ Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+ if (errors::IsOutOfRange(s)) {
+ return errors::InvalidArgument(
+ "Reached end of file without closing quoted field in "
+ "record");
+ } else if (!s.ok()) {
+ return s; // Surface all other errors to caller
+ }
+ }
+
+ char ch = buffer_[pos_];
+ if (ch == '"') {
+ // When we encounter a quote, we look ahead to the next character to
+ // decide what to do
+ pos_++;
+ if (pos_ >= buffer_.size()) {
+ Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+ if (errors::IsOutOfRange(s)) {
+ // This was the last field. We are done
+ *end_of_record = true;
+ parse_result.Update(QuotedFieldToOutput(
+ ctx, StringPiece(), out_tensors, earlier_pieces, include));
+ return parse_result;
+ } else if (!s.ok()) {
+ return s;
+ }
+ }
+
+ char next = buffer_[pos_];
+ pos_++;
+ if (next == dataset()->delim_) {
+ parse_result.Update(QuotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
+ out_tensors, earlier_pieces, include));
+ return parse_result;
+
+ } else if (next == '\n' || next == '\r') {
+ *end_of_record = true;
+ parse_result.Update(QuotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - 1 - start),
+ out_tensors, earlier_pieces, include));
+ if (next == '\r') SkipNewLineIfNecessary();
+ return parse_result;
+ } else if (next != '"') {
+ // Take note of the error, but keep going to end of field.
+ include = false; // So we don't get funky errors when trying to
+ // unescape the quotes.
+ parse_result.Update(errors::InvalidArgument(
+ "Quote inside a string has to be escaped by another quote"));
+ }
+
+ } else {
+ pos_++;
+ }
+ }
+ }
+
+ // Converts quoted field to an output tensor, removing the starting
+ // and ending quotes from it and unescaping double quotations if
+ // necessary.
+ Status QuotedFieldToOutput(IteratorContext* ctx, StringPiece field,
+ std::vector<Tensor>* out_tensors,
+ const std::vector<Piece>& earlier_pieces,
+ bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!include) return Status::OK();
+
+ if (earlier_pieces.empty()) {
+ if (field.find('\"', 1) == field.size() - 1) {
+ // `field` contains no escaped quotation marks.
+ // Exclude framing quotation marks
+ field.remove_prefix(1);
+ field.remove_suffix(1);
+ return FieldToOutput(ctx, field, out_tensors);
+ }
+ }
+ string field_complete;
+ size_t str_len = field.size();
+ for (const Piece& p : earlier_pieces) {
+ str_len += p.len;
+ }
+ field_complete.reserve(str_len);
+
+ // This bool flips every time we see a quote, so that we skip the second
+ // quote of every pair of adjacent quotes in the field. We need to track
+ // this across iterations of the for loop because adjacent double quotes
+ // may be in different buffers. Initialize to true because we also skip
+ // the opening quotation mark of the quoted field.
+ bool skip_next_quote = true;
+ for (const Piece& p : earlier_pieces) {
+ AppendUnescapedPiece(StringPiece(&p.buffer[p.start], p.len),
+ &field_complete, &skip_next_quote);
+ }
+ AppendUnescapedPiece(field, &field_complete, &skip_next_quote);
+ StringPiece result = StringPiece(field_complete);
+ result.remove_suffix(1); // Skip final quote
+
+ return FieldToOutput(ctx, result, out_tensors);
+ }
+
+ void AppendUnescapedPiece(StringPiece piece, string* field_complete,
+ bool* skip_next_quote) {
+ size_t from = 0;
+ size_t found = piece.find('\"', from);
+ while (found != string::npos) {
+ if (!*skip_next_quote) {
+ // This is the first quote in a pair of adjacent double quotes
+ field_complete->append(piece.data() + from, found + 1 - from);
+ }
+ *skip_next_quote = !*skip_next_quote;
+ from = found + 1;
+ found = piece.find('\"', from);
+ }
+ // Include the chunk after the last quotation mark in the string
+ if (from < piece.size()) {
+ field_complete->append(piece.data() + from, piece.size() - from);
+ }
+ }
+
+ // Parses unquoted field from position pos_ in the buffer. Continually
+ // reads from buffer until end of field is reached (delim, CRLF, or EOF).
+ // Advances pos_ to keep track of our position in the buffer as we go,
+ // stopping at the first character of the next field.
+ Status ParseUnquotedField(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_record, bool include)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ std::vector<Piece> earlier_pieces;
+ size_t start = pos_;
+ Status parse_result;
+
+ while (true) { // Each iter reads 1 char, filling buffer if necessary
+ if (pos_ >= buffer_.size()) {
+ Status s = SaveAndFillBuffer(&earlier_pieces, &start, include);
+ // Handle errors
+ if (errors::IsOutOfRange(s)) {
+ // Whatever we have is the last field of the last record
+ *end_of_record = true;
+ parse_result.Update(UnquotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+ earlier_pieces, include));
+ return parse_result;
+ } else if (!s.ok()) {
+ return s; // Surface all other errors to caller
+ }
+ }
+
+ char ch = buffer_[pos_];
+
+ if (ch == dataset()->delim_) {
+ parse_result.Update(UnquotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+ earlier_pieces, include));
+ pos_++;
+ return parse_result;
+ }
+ if (ch == '\n' || ch == '\r') {
+ // need special case to skip over first \n of record if the line
+ // breaks are \r\n
+ parse_result.Update(UnquotedFieldToOutput(
+ ctx, StringPiece(&buffer_[start], pos_ - start), out_tensors,
+ earlier_pieces, include));
+ *end_of_record = true;
+ pos_++;
+ if (ch == '\r') SkipNewLineIfNecessary();
+ return parse_result;
+ }
+ if (dataset()->use_quote_delim_ && ch == '"') {
+ // Take note of the error, but keep going to end of field.
+ parse_result.Update(errors::InvalidArgument(
+ "Unquoted fields cannot have quotes inside"));
+ }
+ // Otherwise, go to next character
+ pos_++;
+ }
+ }
+
+ Status FillBuffer(string* result) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ result->clear();
+ ++num_buffer_reads_;
+ Status s = input_stream_->ReadNBytes(
+ dataset()->options_.input_buffer_size, result);
+
+ if (errors::IsOutOfRange(s) && !result->empty()) {
+ // Ignore OutOfRange error when ReadNBytes read < N bytes.
+ return Status::OK();
+ }
+ return s;
+ }
+
+ // Given a field, converts it to the right output tensor type
+ Status FieldToOutput(IteratorContext* ctx, StringPiece field,
+ std::vector<Tensor>* out_tensors) {
+ size_t output_idx = out_tensors->size();
+ if (output_idx >= dataset()->out_type_.size()) {
+ // We can get here if we're selecting all columns, but the number of
+ // fields exceeds the number of defaults provided
+ return errors::InvalidArgument("Expect ", dataset()->out_type_.size(),
+ " fields but have more in record");
+ }
+ const DataType& dtype = dataset()->out_type_[output_idx];
+ Tensor component(ctx->allocator({}), dtype, {});
+ if ((field.empty() || field == dataset()->na_value_) &&
+ dataset()->record_defaults_[output_idx].NumElements() != 1) {
+ // If the field is empty or NA value, and default is not given,
+ // report error.
+ return errors::InvalidArgument("Field ", output_idx,
+ " is required but missing in record!");
+ }
+
+ switch (dtype) {
+ // For each case, if the field is empty, we use the default.
+ // Otherwise, we convert it to the right type.
+ case DT_INT32: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<int32>()() =
+ dataset()->record_defaults_[output_idx].flat<int32>()(0);
+ } else {
+ int32 value;
+ if (!strings::safe_strto32(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid int32: ", field);
+ }
+ component.scalar<int32>()() = value;
+ }
+ break;
+ }
+ case DT_INT64: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<int64>()() =
+ dataset()->record_defaults_[output_idx].flat<int64>()(0);
+ } else {
+ int64 value;
+ if (!strings::safe_strto64(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid int64: ", field);
+ }
+ component.scalar<int64>()() = value;
+ }
+ break;
+ }
+ case DT_FLOAT: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<float>()() =
+ dataset()->record_defaults_[output_idx].flat<float>()(0);
+ } else {
+ float value;
+ if (!strings::safe_strtof(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid float: ", field);
+ }
+ component.scalar<float>()() = value;
+ }
+ break;
+ }
+ case DT_DOUBLE: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<double>()() =
+ dataset()->record_defaults_[output_idx].flat<double>()(0);
+ } else {
+ double value;
+ if (!strings::safe_strtod(field, &value)) {
+ return errors::InvalidArgument(
+ "Field ", output_idx,
+ " in record is not a valid double: ", field);
+ }
+ component.scalar<double>()() = value;
+ }
+ break;
+ }
+ case DT_STRING: {
+ if (field.empty() || field == dataset()->na_value_) {
+ component.scalar<string>()() =
+ dataset()->record_defaults_[output_idx].flat<string>()(0);
+ } else {
+ component.scalar<string>()() = string(field);
+ }
+ break;
+ }
+ default:
+ return errors::InvalidArgument("csv: data type ", dtype,
+ " not supported in field ",
+ output_idx);
+ }
+ out_tensors->push_back(std::move(component));
+ return Status::OK();
+ }
+
+ // Records can be delimited by "\r\n" line breaks. When we encounter a
+ // '\r', we have to check the next character to see if it is part of the
+ // linebreak, and ignore it if so.
+ void SkipNewLineIfNecessary() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (pos_ >= buffer_.size()) {
+ Status s = FillBuffer(&buffer_);
+ pos_ = 0;
+ // If we failed to fill buffer, it doesn't matter because we're done
+ // with the record
+ if (!s.ok()) return;
+ }
+ if (buffer_[pos_] == '\n') {
+ pos_++;
+ }
+ }
+
+      // Given an unquoted field (possibly spanning earlier buffers), converts
+      // it to a Tensor of the right type and adds it to the out_tensors
+      // vector.
+ Status UnquotedFieldToOutput(IteratorContext* ctx, StringPiece field,
+ std::vector<Tensor>* out_tensors,
+ const std::vector<Piece>& earlier_pieces,
+ bool include) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!include) return Status::OK();
+
+ if (earlier_pieces.empty()) {
+ return FieldToOutput(ctx, field, out_tensors);
+ }
+
+ size_t str_len = field.size();
+ for (const Piece& p : earlier_pieces) {
+ str_len += p.len;
+ }
+ string field_complete;
+ field_complete.reserve(str_len);
+
+ for (const Piece& p : earlier_pieces) {
+ field_complete.append(p.buffer, p.start, p.len);
+ }
+
+ field_complete.append(field.data(), field.size());
+ return FieldToOutput(ctx, field_complete, out_tensors);
+ }
+
+ // Sets up reader streams to read from the file at `current_file_index_`.
+ Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (current_file_index_ >= dataset()->filenames_.size()) {
+ return errors::InvalidArgument(
+ "current_file_index_:", current_file_index_,
+ " >= filenames_.size():", dataset()->filenames_.size());
+ }
+
+ // Actually move on to next file.
+ TF_RETURN_IF_ERROR(env->NewRandomAccessFile(
+ dataset()->filenames_[current_file_index_], &file_));
+ random_access_input_stream_ =
+ std::make_shared<io::RandomAccessInputStream>(file_.get(), false);
+
+ if (dataset()->use_compression_) {
+ input_stream_ = std::make_shared<io::ZlibInputStream>(
+ random_access_input_stream_.get(),
+ dataset()->options_.input_buffer_size,
+ dataset()->options_.input_buffer_size, dataset()->options_);
+ } else {
+ input_stream_ = random_access_input_stream_;
+ }
+ buffer_.clear();
+ pos_ = 0;
+ num_buffer_reads_ = 0;
+ if (dataset()->header_) {
+ // Read one line, but don't include it. Pass nullptrs as dummy
+ // pointers to objects that shouldn't be invoked anyway
+ // We need to process this as a record here instead of just finding
+ // the first newline because it might contain quoted fields with
+ // newlines in the header as well
+ std::vector<int64> empty;
+ Status s = ReadRecord(nullptr, nullptr, false, empty);
+ if (!s.ok()) {
+ return errors::InvalidArgument("Can't read header of file");
+ }
+ }
+ return Status::OK();
+ }
+
+ // Resets all reader streams.
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ input_stream_.reset();
+ file_.reset();
+ }
+
+ mutex mu_;
+ string buffer_ GUARDED_BY(mu_); // Maintain our own buffer
+ size_t pos_ GUARDED_BY(
+ mu_); // Index into the buffer must be maintained between iters
+ size_t num_buffer_reads_ GUARDED_BY(mu_);
+ std::shared_ptr<io::RandomAccessInputStream> random_access_input_stream_
+ GUARDED_BY(mu_);
+ std::shared_ptr<io::InputStreamInterface> input_stream_ GUARDED_BY(mu_);
+ size_t current_file_index_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<RandomAccessFile> file_
+ GUARDED_BY(mu_); // must outlive input_stream_
+ }; // class Iterator
+
+ const std::vector<string> filenames_;
+ const bool header_;
+ const DataTypeVector out_type_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ const std::vector<Tensor> record_defaults_;
+ const std::vector<int64> select_cols_;
+ const bool use_quote_delim_;
+ const char delim_;
+ const string na_value_;
+ const bool use_compression_;
+ const string compression_type_;
+ const io::ZlibCompressionOptions options_;
+ }; // class Dataset
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+}; // class CSVDatasetOp
+
+// Register the kernel implementation for CSVDataset.
+REGISTER_KERNEL_BUILDER(Name("ExperimentalCSVDataset").Device(DEVICE_CPU),
+ CSVDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
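
A stand-alone sketch of the double-quote unescaping that AppendUnescapedPiece and QuotedFieldToOutput implement above: inside a quoted CSV field, a pair of adjacent quotes encodes one literal quote, and the skip_next_quote flag drops the opening framing quote and the second quote of each adjacent pair. The raw input string is an assumed example; std::string stands in for StringPiece.

#include <iostream>
#include <string>

void AppendUnescapedPiece(const std::string& piece, std::string* out,
                          bool* skip_next_quote) {
  size_t from = 0;
  size_t found = piece.find('"', from);
  while (found != std::string::npos) {
    if (!*skip_next_quote) {
      // Keep this chunk, including the quote just found; the quote that
      // immediately follows it will be skipped.
      out->append(piece, from, found + 1 - from);
    }
    *skip_next_quote = !*skip_next_quote;
    from = found + 1;
    found = piece.find('"', from);
  }
  // Include the chunk after the last quotation mark.
  if (from < piece.size()) out->append(piece, from, piece.size() - from);
}

int main() {
  // Raw field as read from the file, including framing quotes: "a""b"
  const std::string raw = "\"a\"\"b\"";
  std::string unescaped;
  bool skip_next_quote = true;  // also skips the opening framing quote
  AppendUnescapedPiece(raw, &unescaped, &skip_next_quote);
  unescaped.pop_back();  // drop the trailing framing quote, as in the kernel
  std::cout << unescaped << "\n";  // prints: a"b
  return 0;
}
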
diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
new file mode 100644
index 0000000000..c47a9099c4
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
@@ -0,0 +1,281 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class DirectedInterleaveDatasetOp : public DatasetOpKernel {
+ public:
+ explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx)
+ : DatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ DatasetBase* selector_input;
+ OP_REQUIRES_OK(ctx,
+ GetDatasetFromVariantTensor(ctx->input(0), &selector_input));
+
+ OP_REQUIRES(
+ ctx,
+ selector_input->output_dtypes().size() == 1 &&
+ selector_input->output_dtypes()[0] == DT_INT64 &&
+ selector_input->output_shapes().size() == 1 &&
+ selector_input->output_shapes()[0].IsCompatibleWith(
+ PartialTensorShape({})),
+ errors::InvalidArgument(
+ "The selector input must be a dataset of scalar int64 elements."));
+
+ std::vector<DatasetBase*> data_inputs;
+ for (size_t i = 1; i < ctx->num_inputs(); ++i) {
+ DatasetBase* input;
+ OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(i), &input));
+ data_inputs.push_back(input);
+
+ OP_REQUIRES(
+ ctx, data_inputs[0]->output_dtypes() == input->output_dtypes(),
+ errors::InvalidArgument(
+ "All inputs must have the same output_dtypes. First input "
+ "has types ",
+ DataTypeVectorString(data_inputs[0]->output_dtypes()),
+ ", and input ", i - 1, " has types ",
+ DataTypeVectorString(input->output_dtypes())));
+ }
+ *output = new Dataset(ctx, selector_input, std::move(data_inputs));
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* selector_input,
+ std::vector<DatasetBase*> data_inputs)
+ : DatasetBase(DatasetContext(ctx)),
+ selector_input_(selector_input),
+ data_inputs_(std::move(data_inputs)) {
+ selector_input_->Ref();
+
+ output_shapes_ = data_inputs_[0]->output_shapes();
+ data_inputs_[0]->Ref();
+ for (size_t i = 1; i < data_inputs_.size(); ++i) {
+ const DatasetBase* data_input = data_inputs_[i];
+ data_input->Ref();
+ for (size_t j = 0; j < output_shapes_.size(); ++j) {
+ output_shapes_[j] = MostSpecificCompatibleShape(
+ output_shapes_[j], data_input->output_shapes()[j]);
+ }
+ }
+ }
+
+ ~Dataset() override {
+ selector_input_->Unref();
+ for (DatasetBase* data_input : data_inputs_) {
+ data_input->Unref();
+ }
+ }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::DirectedInterleave")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return data_inputs_[0]->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return strings::StrCat("DirectedInterleaveDatasetOp::Dataset");
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* selector_input_node;
+ TF_RETURN_IF_ERROR(
+ b->AddInputDataset(ctx, selector_input_, &selector_input_node));
+ std::vector<Node*> data_input_nodes(data_inputs_.size());
+ for (size_t i = 0; i < data_inputs_.size(); ++i) {
+ TF_RETURN_IF_ERROR(
+ b->AddInputDataset(ctx, data_inputs_[i], &data_input_nodes[i]));
+ }
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {{0, selector_input_node}},
+ {{1, data_input_nodes}}, {}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ num_active_inputs_(params.dataset->data_inputs_.size()) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(dataset()->selector_input_->MakeIterator(
+ ctx, strings::StrCat(prefix(), ".selector"),
+ &selector_input_impl_));
+ data_input_impls_.resize(dataset()->data_inputs_.size());
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ const DatasetBase* data_input = dataset()->data_inputs_[i];
+ TF_RETURN_IF_ERROR(data_input->MakeIterator(
+ ctx, strings::StrCat(prefix(), "[", i, "]"),
+ &data_input_impls_[i]));
+ }
+ return Status::OK();
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (!selector_input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ while (true) {
+ std::vector<Tensor> selector_result;
+ *end_of_sequence = false;
+ TF_RETURN_IF_ERROR(selector_input_impl_->GetNext(
+ ctx, &selector_result, end_of_sequence));
+ if (*end_of_sequence) {
+ selector_input_impl_.reset();
+ for (auto& data_input_impl : data_input_impls_) {
+ data_input_impl.reset();
+ }
+ return Status::OK();
+ }
+
+ int64 selected_input = selector_result[0].scalar<int64>()();
+          if (selected_input < 0 ||
+              selected_input >= data_input_impls_.size()) {
+ return errors::InvalidArgument(
+ "Selector index out of range: ", selected_input,
+ " >= ", data_input_impls_.size());
+ }
+
+ if (data_input_impls_[selected_input]) {
+ bool end_of_selected_input = false;
+ TF_RETURN_IF_ERROR(data_input_impls_[selected_input]->GetNext(
+ ctx, out_tensors, &end_of_selected_input));
+
+ if (!end_of_selected_input) {
+ return Status::OK();
+ }
+
+ data_input_impls_[selected_input].reset();
+ --num_active_inputs_;
+
+ if (num_active_inputs_ == 0) {
+ selector_input_impl_.reset();
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ }
+
+ LOG(WARNING) << "DirectedInterleave selected an exhausted input: "
+ << selected_input;
+ }
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (selector_input_impl_) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, selector_input_impl_));
+ } else {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
+ }
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ const auto& data_input_impl = data_input_impls_[i];
+ if (data_input_impl) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, data_input_impl));
+ } else {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
+ ""));
+ }
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (!reader->Contains(full_name("selector_input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
+ } else {
+ selector_input_impl_.reset();
+ }
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ if (!reader->Contains(full_name(
+ strings::StrCat("data_input_impl_empty[", i, "]")))) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
+ } else {
+ data_input_impls_[i].reset();
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> selector_input_impl_ GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
+ GUARDED_BY(mu_);
+ int64 num_active_inputs_ GUARDED_BY(mu_);
+ };
+
+ static PartialTensorShape MostSpecificCompatibleShape(
+ const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
+ PartialTensorShape output_tensorshape;
+ if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
+ return output_tensorshape;
+ auto dims1 = ts1.dim_sizes();
+ auto dims2 = ts2.dim_sizes();
+ for (int d = 0; d < ts1.dims(); d++) {
+ if (dims1[d] == dims2[d])
+ output_tensorshape.Concatenate(dims1[d]);
+ else
+ output_tensorshape.Concatenate(-1);
+ }
+ return output_tensorshape;
+ }
+
+ const DatasetBase* const selector_input_;
+ const std::vector<DatasetBase*> data_inputs_;
+ std::vector<PartialTensorShape> output_shapes_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
+ DirectedInterleaveDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
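
A stand-alone sketch of the selection loop in DirectedInterleaveDatasetOp::Iterator::GetNextInternal above: each selector value picks which input produces the next element, an input is retired the first time it is observed to be exhausted, later selections of a retired input only log a warning, and iteration ends when the selector ends or every input is exhausted. Plain vectors stand in for the nested iterators; the selector values and input data are assumed examples.

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::vector<int> selector = {0, 1, 0, 1, 1, 0};
  const std::vector<std::vector<std::string>> inputs = {{"a", "b"},
                                                        {"x", "y", "z"}};
  std::vector<size_t> pos(inputs.size(), 0);
  std::vector<bool> active(inputs.size(), true);
  size_t num_active = inputs.size();

  for (int selected : selector) {
    if (active[selected]) {
      if (pos[selected] < inputs[selected].size()) {
        std::cout << inputs[selected][pos[selected]++] << " ";
        continue;
      }
      // First time we observe this input is exhausted: retire it.
      active[selected] = false;
      if (--num_active == 0) break;  // every input is exhausted
    } else {
      std::cerr << "selected an exhausted input: " << selected << "\n";
    }
  }
  std::cout << "\n";  // prints: a x b y z
  return 0;
}
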
diff --git a/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
new file mode 100644
index 0000000000..2141f118ca
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
@@ -0,0 +1,156 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
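+// An IndexedDataset that yields the integers 0, 1, ..., size - 1 as scalar
+// uint64 elements; random access is trivial, since Get(index) simply returns
+// `index` itself.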
+class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel {
+ public:
+ using IndexedDatasetOpKernel::IndexedDatasetOpKernel;
+
+ void MakeIndexedDataset(OpKernelContext* ctx,
+ IndexedDataset** output) override {
+ uint64 size = -1;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<uint64>(ctx, "size", &size));
+ OP_REQUIRES(ctx, size > 0, errors::InvalidArgument("`size` must be > 0"));
+ *output = new Dataset(ctx, size);
+ }
+
+ class Dataset : public IndexedDataset {
+ public:
+ Dataset(OpKernelContext* ctx, uint64 size)
+ : IndexedDataset(DatasetContext(ctx)), size_(size) {}
+
+ Status MaterializeDataset(
+ std::shared_ptr<MaterializedIndexedDataset>* materialized) override {
+ materialized->reset(new Materialized(this));
+ return Status::OK();
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_UINT64});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::IdentityIndexedDataset")}));
+ }
+
+ string DebugString() const override {
+ return "IdentityIndexedDataset::Dataset";
+ }
+
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** node) const override {
+ return errors::Unimplemented(
+ "identity_indexed_dataset.AsGraphDefInternal");
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ if (cur_ < dataset()->size_) {
+ Tensor result_tensor(ctx->allocator({}), DT_UINT64, {});
+ result_tensor.scalar<uint64>()() = cur_++;
+ out_tensors->emplace_back(std::move(result_tensor));
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+      uint64 cur_ GUARDED_BY(mu_) = 0;
+ };
+
+ class Materialized : public MaterializedIndexedDataset {
+ public:
+ explicit Materialized(Dataset* dataset) : dataset_(dataset) {
+ dataset->Ref();
+ }
+
+ ~Materialized() override {
+ // TODO(saeta): Pull this into MaterializedIndexedDataset
+ dataset_->Unref();
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return dataset_->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return dataset_->output_shapes();
+ }
+
+ Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) const override {
+ LOG(INFO) << "Materialized(" << dataset_->size_ << ")::Get(" << index
+ << ")";
+ if (index >= dataset_->size_) {
+ // Note: use InvalidArgument instead of OutOfRange error because many
+ // things consider OutOfRange to be a "clean termination" error.
+ return errors::InvalidArgument(
+ "Index ", index,
+ " is out of range for this dataset. (Size is: ", dataset_->size_,
+ ".)");
+ }
+ Tensor result_tensor(ctx.allocator({}), DT_UINT64, {});
+ result_tensor.scalar<uint64>()() = index;
+ out_tensors->emplace_back(std::move(result_tensor));
+ return Status::OK();
+ }
+
+ Status Size(uint64* size) const override {
+ *size = dataset_->size_;
+ return Status::OK();
+ }
+
+ private:
+ const Dataset* const dataset_; // Not owned.
+ };
+
+ const uint64 size_;
+ std::shared_ptr<Materialized> materialized_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIdentityIndexedDataset").Device(DEVICE_CPU),
+ IdentityIndexedDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
new file mode 100644
index 0000000000..b34377c642
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
@@ -0,0 +1,141 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit IgnoreErrorsDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ *output = new Dataset(ctx, input);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::IgnoreErrors")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+ return "IgnoreErrorsDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ {
+ tf_shared_lock l(mu_);
+ if (!input_impl_) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
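+          // Swallow errors from the input: discard the partially produced
+          // element and keep pulling until we get a valid element or reach
+          // the end of the input.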
+ while (!s.ok()) {
+ out_tensors->clear();
+ s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ }
+ }
+ if (*end_of_sequence) {
+ mutex_lock l(mu_);
+ input_impl_.reset();
+ }
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (input_impl_)
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ else
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impls_empty"), ""));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (reader->Contains(full_name("input_impls_empty")))
+ input_impl_.reset();
+ else
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* const input_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIgnoreErrorsDataset").Device(DEVICE_CPU),
+ IgnoreErrorsDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
new file mode 100644
index 0000000000..75ea462f40
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
@@ -0,0 +1,375 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
+
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+Status VerifyTypesMatch(const DataTypeVector& expected,
+ const DataTypeVector& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " types but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (expected[i] != received[i]) {
+ return errors::InvalidArgument("Data type mismatch at component ", i,
+ ": expected ", DataTypeString(expected[i]),
+ " but got ", DataTypeString(received[i]),
+ ".");
+ }
+ }
+ return Status::OK();
+}
+
+Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
+ const std::vector<PartialTensorShape>& received) {
+ if (expected.size() != received.size()) {
+ return errors::InvalidArgument(
+ "Number of components does not match: expected ", expected.size(),
+ " shapes but got ", received.size(), ".");
+ }
+ for (size_t i = 0; i < expected.size(); ++i) {
+ if (!expected[i].IsCompatibleWith(received[i])) {
+ return errors::InvalidArgument("Incompatible shapes at component ", i,
+ ": expected ", expected[i].DebugString(),
+ " but got ", received[i].DebugString(),
+ ".");
+ }
+ }
+
+ return Status::OK();
+}
+
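+// A resource that holds a MaterializedIndexedDataset (possibly not yet set)
+// together with the dtypes and shapes its elements are expected to have. The
+// handle op creates it empty, the materialize op fills it in, and the get op
+// reads elements from it.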
+class MaterializedDatasetResource : public ResourceBase {
+ public:
+ MaterializedDatasetResource(
+ const DataTypeVector& output_dtypes,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : output_dtypes_(output_dtypes), output_shapes_(output_shapes) {}
+
+ string DebugString() override {
+ return "Materialized IndexedDataset resource";
+ }
+
+ Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) {
+ std::shared_ptr<MaterializedIndexedDataset> captured(materialized_);
+ if (captured) {
+ return captured->Get(std::move(ctx), index, out_tensors);
+ } else {
+ return errors::FailedPrecondition(
+ "Get() failed because the MaterializedIndexedDataset has not been "
+ "initialized. Ensure that you have run the materialization operation "
+ "for this MaterializedIndexedDataset before retrieving elements.");
+ }
+ }
+
+ // TODO(saeta): Implement Save and Restore
+
+ const DataTypeVector& output_dtypes() const { return output_dtypes_; }
+ const std::vector<PartialTensorShape>& output_shapes() const {
+ return output_shapes_;
+ }
+
+ Status set_materialized_dataset(
+ const std::shared_ptr<MaterializedIndexedDataset>& dataset) {
+ if (dataset) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_dtypes_, dataset->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, dataset->output_shapes()));
+ }
+ materialized_ = dataset;
+ return Status::OK();
+ }
+
+ private:
+ std::shared_ptr<MaterializedIndexedDataset> materialized_;
+ const DataTypeVector output_dtypes_;
+ const std::vector<PartialTensorShape> output_shapes_;
+};
+
+// A wrapper class for storing an `IndexedDataset` instance in a DT_VARIANT
+// tensor. Objects of the wrapper class own a reference on an instance of an
+// `IndexedDataset`, and the wrapper's copy constructor and destructor take
+// care of managing the reference count.
+//
+// NOTE: This is not a feature-complete implementation of the DT_VARIANT
+// specification. In particular, we cannot currently serialize an arbitrary
+// `IndexedDataset` object, so the `Encode()` and `Decode()` methods are not
+// implemented.
+//
+// NOTE(saeta): When `IndexedDataset`s get merged into core, we can instead just
+// use `tensorflow::DatasetVariantWrapper`.
+class IndexedDatasetVariantWrapper {
+ public:
+ IndexedDatasetVariantWrapper() : dataset_(nullptr) {}
+
+ // Transfers ownership of `dataset` to `*this`.
+ explicit IndexedDatasetVariantWrapper(IndexedDataset* dataset)
+ : dataset_(dataset) {}
+
+ IndexedDatasetVariantWrapper(const IndexedDatasetVariantWrapper& other)
+ : dataset_(other.dataset_) {
+ if (dataset_) dataset_->Ref();
+ }
+
+ ~IndexedDatasetVariantWrapper() {
+ if (dataset_) dataset_->Unref();
+ }
+
+ IndexedDataset* get() const { return dataset_; }
+
+ string TypeName() const { return "tensorflow::IndexedDatasetVariantWrapper"; }
+ string DebugString() const {
+ if (dataset_) {
+ return dataset_->DebugString();
+ } else {
+ return "<Uninitialized IndexedDatasetVariantWrapper>";
+ }
+ }
+
+ void Encode(VariantTensorData* data) const {
+ LOG(ERROR) << "The Encode() method is not implemented for "
+ "IndexedDatasetVariantWrapper objects.";
+ }
+
+ bool Decode(const VariantTensorData& data) {
+ LOG(ERROR) << "The Decode() method is not implemented for "
+ "IndexedDatasetVariantWrapper objects.";
+ return false;
+ }
+
+ private:
+ IndexedDataset* const dataset_; // Owns one reference.
+};
+
+} // namespace
+
+Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
+ IndexedDataset** out_dataset) {
+  if (!(tensor.dtype() == DT_VARIANT &&
+        TensorShapeUtils::IsScalar(tensor.shape()))) {
+ return errors::InvalidArgument(
+ "IndexedDataset tensor must be a scalar of dtype DT_VARIANT.");
+ }
+ const Variant& variant = tensor.scalar<Variant>()();
+ const IndexedDatasetVariantWrapper* wrapper =
+ variant.get<IndexedDatasetVariantWrapper>();
+ if (wrapper == nullptr) {
+ return errors::InvalidArgument("Tensor must be an IndexedDataset object.");
+ }
+ *out_dataset = wrapper->get();
+ if (*out_dataset == nullptr) {
+ return errors::Internal("Read uninitialized IndexedDataset variant.");
+ }
+ return Status::OK();
+}
+
+Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
+ Tensor* tensor) {
+  if (!(tensor->dtype() == DT_VARIANT &&
+        TensorShapeUtils::IsScalar(tensor->shape()))) {
+ return errors::InvalidArgument(
+ "Dataset tensor must be a scalar of dtype DT_VARIANT.");
+ }
+ tensor->scalar<Variant>()() = IndexedDatasetVariantWrapper(dataset);
+ return Status::OK();
+}
+
+void IndexedDatasetOpKernel::Compute(OpKernelContext* ctx) {
+ IndexedDataset* dataset = nullptr;
+ MakeIndexedDataset(ctx, &dataset);
+
+ if (ctx->status().ok()) {
+ OP_REQUIRES(ctx, dataset != nullptr,
+ errors::Internal("MakeIndexedDataset did not correctly "
+ "construct the IndexedDataset"));
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
+ OP_REQUIRES_OK(ctx, StoreIndexedDatasetInVariantTensor(dataset, output));
+ }
+}
+
+namespace {
+
+class MaterializedHandleOp : public OpKernel {
+ public:
+ explicit MaterializedHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_dtypes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ ~MaterializedHandleOp() override {
+ if (resource_ != nullptr) {
+ resource_->Unref();
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->template Delete<MaterializedDatasetResource>(
+ cinfo_.container(), cinfo_.name())
+ .ok()) {
+          // Do nothing; the resource may have been deleted by session resets.
+          // Note: cargo-culted from $tf/core/framework/resource_op_kernel.h
+ }
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (resource_ == nullptr) {
+ ResourceMgr* mgr = context->resource_manager();
+ OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+
+ MaterializedDatasetResource* resource;
+ OP_REQUIRES_OK(context,
+ mgr->LookupOrCreate<MaterializedDatasetResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this](MaterializedDatasetResource** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ *ret = new MaterializedDatasetResource(
+ output_dtypes_, output_shapes_);
+ return Status::OK();
+ }));
+ Status s = VerifyResource(resource);
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ resource->Unref();
+ context->SetStatus(s);
+ return;
+ }
+
+ resource_ = resource;
+ }
+ }
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<MaterializedDatasetResource>()));
+ }
+
+ private:
+  // During the first Compute(), the resource is either created or looked up
+  // using shared_name. In the latter case, the resource found is verified to
+  // be compatible with this op's configuration. The verification may fail,
+  // for example, if two graphs ask resources with the same shared name to
+  // have inconsistent output types or shapes.
+ Status VerifyResource(MaterializedDatasetResource* resource) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_dtypes_, resource->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
+ return Status::OK();
+ }
+
+ mutex mu_;
+ ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
+ MaterializedDatasetResource* resource_ GUARDED_BY(mu_) = nullptr;
+ DataTypeVector output_dtypes_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+// TODO(saeta): Make async.
+class MaterializeDatasetOp : public OpKernel {
+ public:
+ explicit MaterializeDatasetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ IndexedDataset* dataset;
+ OP_REQUIRES_OK(ctx,
+ GetIndexedDatasetFromVariantTensor(ctx->input(0), &dataset));
+
+ MaterializedDatasetResource* materialized_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
+ &materialized_resource));
+ core::ScopedUnref unref(materialized_resource);
+ std::shared_ptr<MaterializedIndexedDataset> materialized;
+ OP_REQUIRES_OK(ctx, dataset->MaterializeDataset(&materialized));
+ OP_REQUIRES_OK(
+ ctx, materialized_resource->set_materialized_dataset(materialized));
+ }
+};
+
+// TODO(saeta): Make async
+class IndexedDatasetGet : public OpKernel {
+ public:
+ explicit IndexedDatasetGet(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ MaterializedDatasetResource* materialized_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0),
+ &materialized_resource));
+ auto cleanup = gtl::MakeCleanup([materialized_resource] {
+ materialized_resource->Unref(); // Note: can't use core::ScopedUnref.
+ });
+
+ const Tensor* index_t;
+ OP_REQUIRES_OK(ctx, ctx->input("index", &index_t));
+ // TODO(saeta): Support batch reads (indexes should be non-scalar!)
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(index_t->shape()),
+ errors::InvalidArgument("index must be a scalar"));
+ const uint64 index = index_t->scalar<uint64>()();
+
+ std::vector<Tensor> out_tensors;
+ Status s =
+ materialized_resource->Get(IteratorContext(ctx), index, &out_tensors);
+
+ // Note: Unref materialized_resource to avoid destruction races. (Important
+ // in a [future] async op implementation.)
+ cleanup.release()();
+
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else {
+ auto expected_shapes = materialized_resource->output_shapes();
+ auto expected_types = materialized_resource->output_dtypes();
+ for (size_t i = 0; i < out_tensors.size(); ++i) {
+        OP_REQUIRES(
+            ctx, expected_shapes[i].IsCompatibleWith(out_tensors[i].shape()),
+            errors::Internal(
+                "Materialized dataset output at index ", i,
+                " is incompatible with the expected shape. (Expected: ",
+                expected_shapes[i].DebugString(),
+                ", got: ", out_tensors[i].shape().DebugString(), ")"));
+        OP_REQUIRES(ctx, out_tensors[i].dtype() == expected_types[i],
+                    errors::Internal("Materialized dataset output at index ", i,
+                                     " was not the expected dtype. (Expected: ",
+                                     DataTypeString(expected_types[i]),
+                                     ", got: ",
+                                     DataTypeString(out_tensors[i].dtype()),
+                                     ")"));
+ ctx->set_output(i, out_tensors[i]);
+ }
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalMaterializedIndexDatasetHandle").Device(DEVICE_CPU),
+ MaterializedHandleOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetMaterialize").Device(DEVICE_CPU),
+ MaterializeDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetGet").Device(DEVICE_CPU),
+ IndexedDatasetGet);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/indexed_dataset.h b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
new file mode 100644
index 0000000000..27a8360cbc
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
@@ -0,0 +1,119 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+namespace data {
+
+// TODO(saeta): Urgh, this is ugly.
+class MaterializedIndexedDataset {
+ public:
+ virtual ~MaterializedIndexedDataset() = default;
+
+ // Retrieve the element at a given index. The output tensors are stored in
+ // out_tensors.
+ //
+  // If `index` is greater than or equal to `Size()`, an error is returned.
+ //
+ // Get is thread-safe.
+ virtual Status Get(IteratorContext&& ctx, uint64 index,
+ std::vector<Tensor>* out_tensors) const = 0;
+
+ // Size determines the number of elements in this IndexedDataset.
+ //
+ // Size is thread-safe.
+ virtual Status Size(uint64* size) const = 0;
+
+ // Returns a vector of DataType values, representing the respective
+ // element types of each tuple component in the outputs of this dataset.
+ virtual const DataTypeVector& output_dtypes() const = 0;
+
+ // Returns a vector of tensor shapes, representing the respective
+ // (and possibly partially defined) shapes of each tuple component
+ // in the outputs of this dataset.
+ virtual const std::vector<PartialTensorShape>& output_shapes() const = 0;
+};
+
+// IndexedDataset represents a dataset that supports random access in addition
+// to iterator-based sequential access.
+//
+// Note: IndexedDatasets are HIGHLY experimental at this time. Expect
+// significant (backwards incompatible) changes!
+class IndexedDataset : public DatasetBase {
+ public:
+ IndexedDataset(DatasetContext&& ctx) : DatasetBase(std::move(ctx)) {}
+
+ // Materialize (if necessary) the dataset, and return a pointer.
+ // TODO(saeta): Add in `IteratorContext* ctx` when materializing.
+ virtual Status MaterializeDataset(
+ std::shared_ptr<MaterializedIndexedDataset>* materialized) = 0;
+};
+
+// IndexedDatasetOpKernel abstracts away interfacing IndexedDatasets with the
+// rest of the TensorFlow runtime.
+//
+// Most IndexedDatasets will be private members of classes inheriting from
+// this class.
+class IndexedDatasetOpKernel : public OpKernel {
+ public:
+ IndexedDatasetOpKernel(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+ void Compute(OpKernelContext* ctx) final;
+
+ protected:
+ // Subclasses should implement this method. It will be called during Compute
+ // execution.
+ virtual void MakeIndexedDataset(OpKernelContext* ctx,
+ IndexedDataset** output) = 0;
+
+ template <typename T>
+ Status ParseScalarArgument(OpKernelContext* ctx,
+ const StringPiece& argument_name, T* output) {
+ const Tensor* argument_t;
+ TF_RETURN_IF_ERROR(ctx->input(argument_name, &argument_t));
+ if (!TensorShapeUtils::IsScalar(argument_t->shape())) {
+ return errors::InvalidArgument(argument_name, " must be a scalar");
+ }
+ *output = argument_t->scalar<T>()();
+ return Status::OK();
+ }
+};
+
+// Validates and extracts an `IndexedDataset` object from `tensor`.
+//
+// `tensor` must have been written by a call to
+// `StoreIndexedDatasetInVariantTensor`
+//
+// The retrieved pointer is a borrowed reference to the dataset, which is owned
+// by the tensor. The consumer must either acquire its own reference to the
+// dataset by calling `(*out_dataset)->Ref()`, or ensure that `tensor` is not
+// destroyed or mutated while the retrieved pointer is in use.
+Status GetIndexedDatasetFromVariantTensor(const Tensor& tensor,
+ IndexedDataset** out_dataset);
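+//
+// A minimal usage sketch (assuming `ctx` is the calling kernel's
+// OpKernelContext and the dataset arrives as input 0):
+//
+//   IndexedDataset* dataset;
+//   OP_REQUIRES_OK(ctx,
+//                  GetIndexedDatasetFromVariantTensor(ctx->input(0), &dataset));
+//   dataset->Ref();  // Only needed if `dataset` may outlive the input tensor.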
+
+// Stores an `IndexedDataset` object in `tensor`.
+//
+// The ownership of `dataset` is transferred to `tensor`.
+Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
+ Tensor* tensor);
+
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
diff --git a/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
new file mode 100644
index 0000000000..8a88d32f0c
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
@@ -0,0 +1,218 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <sys/stat.h>
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/platform/file_system.h"
+
+#include "lmdb.h" // NOLINT(build/include)
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class LMDBDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ const Tensor* filenames_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("filenames", &filenames_tensor));
+ OP_REQUIRES(
+ ctx, filenames_tensor->dims() <= 1,
+ errors::InvalidArgument("`filenames` must be a scalar or a vector."));
+
+ std::vector<string> filenames;
+ filenames.reserve(filenames_tensor->NumElements());
+ for (int i = 0; i < filenames_tensor->NumElements(); ++i) {
+ filenames.push_back(filenames_tensor->flat<string>()(i));
+ }
+
+ *output = new Dataset(ctx, filenames);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const std::vector<string>& filenames)
+ : DatasetBase(DatasetContext(ctx)), filenames_(filenames) {}
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::LMDB")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes =
+ new DataTypeVector({DT_STRING, DT_STRING});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}, {}});
+ return *shapes;
+ }
+
+ string DebugString() const override { return "LMDBDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* filenames = nullptr;
+ TF_RETURN_IF_ERROR(b->AddVector(filenames_, &filenames));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {filenames}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
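+        // Each element is a (key, value) pair of strings read from the current
+        // LMDB cursor; when the current database is exhausted we advance to
+        // the next filename and try again.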
+ do {
+ if (mdb_cursor_) {
+ Tensor key_tensor(ctx->allocator({}), DT_STRING, {});
+ key_tensor.scalar<string>()() = string(
+ static_cast<const char*>(mdb_key_.mv_data), mdb_key_.mv_size);
+ out_tensors->emplace_back(std::move(key_tensor));
+
+ Tensor value_tensor(ctx->allocator({}), DT_STRING, {});
+ value_tensor.scalar<string>()() =
+ string(static_cast<const char*>(mdb_value_.mv_data),
+ mdb_value_.mv_size);
+ out_tensors->emplace_back(std::move(value_tensor));
+
+          int val =
+              mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_NEXT);
+ if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ if (val == MDB_NOTFOUND) {
+ ResetStreamsLocked();
+ ++current_file_index_;
+ }
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+ if (current_file_index_ == dataset()->filenames_.size()) {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+
+ TF_RETURN_IF_ERROR(SetupStreamsLocked(ctx->env()));
+ } while (true);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ return errors::Unimplemented(
+ "Checkpointing is currently not supported for LMDBDataset.");
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ return errors::Unimplemented(
+ "Checkpointing is currently not supported for LMDBDataset.");
+ }
+
+ private:
+ Status SetupStreamsLocked(Env* env) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (current_file_index_ >= dataset()->filenames_.size()) {
+ return errors::InvalidArgument(
+ "current_file_index_:", current_file_index_,
+ " >= filenames_.size():", dataset()->filenames_.size());
+ }
+ const string& filename = dataset()->filenames_[current_file_index_];
+
+ int val = mdb_env_create(&mdb_env_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ int flags = MDB_RDONLY | MDB_NOTLS | MDB_NOLOCK;
+
+ struct stat source_stat;
+ if (stat(filename.c_str(), &source_stat) == 0 &&
+ (source_stat.st_mode & S_IFREG)) {
+ flags |= MDB_NOSUBDIR;
+ }
+ val = mdb_env_open(mdb_env_, filename.c_str(), flags, 0664);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_txn_begin(mdb_env_, nullptr, MDB_RDONLY, &mdb_txn_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_dbi_open(mdb_txn_, nullptr, 0, &mdb_dbi_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_);
+ if (val != MDB_SUCCESS) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ val = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, MDB_FIRST);
+ if (val != MDB_SUCCESS && val != MDB_NOTFOUND) {
+ return errors::InvalidArgument(mdb_strerror(val));
+ }
+ if (val == MDB_NOTFOUND) {
+ ResetStreamsLocked();
+ }
+ return Status::OK();
+ }
+ void ResetStreamsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (mdb_env_ != nullptr) {
+ if (mdb_cursor_) {
+ mdb_cursor_close(mdb_cursor_);
+ mdb_cursor_ = nullptr;
+ }
+ mdb_dbi_close(mdb_env_, mdb_dbi_);
+ mdb_txn_abort(mdb_txn_);
+ mdb_env_close(mdb_env_);
+ mdb_txn_ = nullptr;
+ mdb_dbi_ = 0;
+ mdb_env_ = nullptr;
+ }
+ }
+ mutex mu_;
+ size_t current_file_index_ GUARDED_BY(mu_) = 0;
+ MDB_env* mdb_env_ GUARDED_BY(mu_) = nullptr;
+ MDB_txn* mdb_txn_ GUARDED_BY(mu_) = nullptr;
+ MDB_dbi mdb_dbi_ GUARDED_BY(mu_) = 0;
+ MDB_cursor* mdb_cursor_ GUARDED_BY(mu_) = nullptr;
+
+ MDB_val mdb_key_ GUARDED_BY(mu_);
+ MDB_val mdb_value_ GUARDED_BY(mu_);
+ };
+
+ const std::vector<string> filenames_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalLMDBDataset").Device(DEVICE_CPU),
+ LMDBDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
new file mode 100644
index 0000000000..2c6179d9f5
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
@@ -0,0 +1,482 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <deque>
+
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+struct BufferElement {
+ // The producer sets `status` if getting the input element fails.
+ Status status;
+ // The buffered data element.
+ std::vector<Tensor> value;
+};
+
+using FunctionBufferCallback = std::function<void(const BufferElement&)>;
+
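+// A resource that asynchronously invokes `func_` on `target_device_` and
+// buffers up to `buffer_size_` results, so that MaybeGet() can usually hand a
+// consumer an already-produced element instead of blocking on a new call.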
+class FunctionBufferingResource : public ResourceBase {
+ public:
+ FunctionBufferingResource(FunctionLibraryRuntime* lib,
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+ const NameAttrList& func, int64 buffer_size,
+ const string& source_device,
+ const string& target_device,
+ const std::vector<Tensor>& func_args,
+ const DataTypeVector& output_types)
+ : lib_(lib),
+ pflr_(std::move(pflr)),
+ func_(func),
+ buffer_size_(buffer_size),
+ source_device_(source_device),
+ target_device_(target_device),
+ func_args_(func_args),
+ output_types_(output_types),
+ handle_(kInvalidHandle),
+ is_buffering_(false),
+ end_of_sequence_(false),
+ cancelled_(false) {}
+
+ ~FunctionBufferingResource() override {
+ Cancel();
+ }
+
+ string DebugString() override {
+ return strings::StrCat("FunctionBufferingResource. Size: ", buffer_size_,
+ "; target_device: ", target_device_);
+ }
+
+ // Instantiates the function the first time it's called. After that it caches
+ // the handle.
+ Status Instantiate() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ // Re-use existing handle if it's been set, effectively caching it.
+ if (handle_ != kInvalidHandle) {
+ return Status::OK();
+ }
+ AttrValueMap attr_values = func_.attr();
+ FunctionLibraryRuntime::InstantiateOptions opts;
+ opts.target = target_device_;
+ return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), opts,
+ &handle_);
+ }
+
+ // Returns true if we've got to the end of the sequence and exhausted the
+ // buffer.
+ bool Finished() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ return end_of_sequence_ && buffer_.empty();
+ }
+
+ // Cancels any buffering / prefetching going on.
+ void Cancel() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ cancelled_ = true;
+ while (is_buffering_) {
+ cond_var_.wait(l);
+ }
+ }
+
+ // Cancels all pending operations and then clears out the state.
+ void Reset() LOCKS_EXCLUDED(mu_) {
+ Cancel();
+ mutex_lock l(mu_);
+ buffer_.clear();
+ requests_.clear();
+ is_buffering_ = false;
+ end_of_sequence_ = false;
+ cancelled_ = false;
+ }
+
+  // If the buffer has anything, runs `callback` on the first element in the
+  // buffer; otherwise stores `callback` so that it runs as soon as the next
+  // element becomes available.
+ void MaybeGet(FunctionBufferCallback callback) LOCKS_EXCLUDED(mu_) {
+ bool start_buffering = false;
+ bool produced_output = false;
+ BufferElement buffer_element;
+ {
+ mutex_lock l(mu_);
+ if (!is_buffering_ && !end_of_sequence_) {
+ start_buffering = true;
+ }
+ if (!buffer_.empty()) {
+ produced_output = true;
+ std::swap(buffer_element, buffer_.front());
+ buffer_.pop_front();
+ } else {
+ produced_output = false;
+ requests_.push_back(std::move(callback));
+ }
+ }
+ if (produced_output) {
+ callback(buffer_element);
+ }
+ if (start_buffering) {
+ FillBuffer();
+ }
+ }
+
+ private:
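+  // Schedules one asynchronous invocation of `func_` to produce the next
+  // buffer element. The completion callback appends the result to `buffer_`,
+  // fulfills a pending request if one is queued, and calls FillBuffer() again
+  // until the buffer is full, the function reports end of sequence, or
+  // buffering is cancelled.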
+ void FillBuffer() LOCKS_EXCLUDED(mu_) {
+ FunctionLibraryRuntime::Handle handle;
+ std::vector<FunctionBufferCallback> cancellation_callbacks;
+ std::vector<BufferElement> cancellation_buffer_elements;
+ bool cancelled = false;
+ {
+ mutex_lock l(mu_);
+ handle = handle_;
+ if (cancelled_) {
+ cancelled = true;
+ // Run through and fulfill all pending requests, if possible.
+ while (!requests_.empty()) {
+ if (!buffer_.empty()) {
+ cancellation_buffer_elements.push_back(std::move(buffer_.front()));
+ buffer_.pop_front();
+ cancellation_callbacks.push_back(std::move(requests_.front()));
+ requests_.pop_front();
+ } else {
+ LOG(ERROR) << "Buffer ran out of elements and we couldn't satisfy: "
+ << requests_.size() << " requests";
+ break;
+ }
+ }
+ is_buffering_ = false;
+ } else {
+ is_buffering_ = true;
+ }
+ }
+ if (cancelled) {
+      for (size_t i = 0; i < cancellation_callbacks.size(); ++i) {
+ cancellation_callbacks[i](cancellation_buffer_elements[i]);
+ }
+ cond_var_.notify_all();
+ return;
+ }
+ FunctionLibraryRuntime::Options opts;
+ // Copied from CapturedFunction::generate_step_id();
+ opts.step_id = -std::abs(static_cast<int64>(random::New64()));
+ opts.source_device = source_device_;
+ AllocatorAttributes arg_alloc_attr;
+ arg_alloc_attr.set_on_host(true);
+ opts.args_alloc_attrs.push_back(arg_alloc_attr);
+ for (const auto& dtype : output_types_) {
+ AllocatorAttributes ret_alloc_attrs;
+ if (DataTypeAlwaysOnHost(dtype)) {
+ ret_alloc_attrs.set_on_host(true);
+ }
+ opts.rets_alloc_attrs.push_back(ret_alloc_attrs);
+ }
+ if (opts.source_device != target_device_) {
+ opts.remote_execution = true;
+ }
+ opts.create_rendezvous = true;
+ auto* rets = new std::vector<Tensor>;
+ lib_->Run(opts, handle, func_args_, rets,
+ [this, rets](const Status& status) {
+ FunctionBufferCallback callback = nullptr;
+ BufferElement buffer_front;
+ bool restart_buffering = false;
+ {
+ mutex_lock l(mu_);
+ BufferElement buffer_element;
+ buffer_element.status = status;
+ if (status.ok()) {
+ buffer_element.value.swap(*rets);
+ } else {
+ end_of_sequence_ = true;
+ is_buffering_ = false;
+ }
+ buffer_.push_back(std::move(buffer_element));
+ if (!requests_.empty()) {
+ buffer_front = std::move(buffer_.front());
+ buffer_.pop_front();
+ callback = std::move(requests_.front());
+ requests_.pop_front();
+ }
+ if (buffer_.size() < buffer_size_ && !end_of_sequence_) {
+ restart_buffering = true;
+ } else {
+ // When the buffer is full, we don't want to call
+ // FillBuffer() unless we're in cancellation phase in which
+ // case FillBuffer() will do the final cleanup post
+ // cancellation.
+ if (cancelled_) {
+ restart_buffering = true;
+ }
+ is_buffering_ = false;
+ }
+ }
+ if (callback != nullptr) {
+ callback(buffer_front);
+ }
+ if (restart_buffering) {
+ FillBuffer();
+ }
+ });
+ }
+
+ mutex mu_;
+ FunctionLibraryRuntime* lib_;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ NameAttrList func_;
+ const int64 buffer_size_;
+ const string source_device_;
+ const string target_device_;
+ const std::vector<Tensor> func_args_;
+ const DataTypeVector output_types_;
+ FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_);
+ std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
+ std::deque<FunctionBufferCallback> requests_ GUARDED_BY(mu_);
+ bool is_buffering_ GUARDED_BY(mu_);
+ bool end_of_sequence_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_);
+ condition_variable cond_var_;
+};
+
+class FunctionBufferResourceHandleOp : public OpKernel {
+ public:
+ explicit FunctionBufferResourceHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), flib_def_(nullptr) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ }
+
+ ~FunctionBufferResourceHandleOp() override {
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->Delete<FunctionBufferingResource>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+        // Do nothing; the resource may have been deleted by session resets.
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* string_arg;
+ OP_REQUIRES_OK(ctx, ctx->input("string_arg", &string_arg));
+ std::vector<Tensor> func_args;
+ func_args.push_back(*string_arg);
+
+ const string& source_device = ctx->device()->name();
+
+ // Obtain and canonicalize target_device.
+ const Tensor* target_arg;
+ OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg));
+ string target_device;
+ OP_REQUIRES_OK(ctx, DeviceNameUtils::CanonicalizeDeviceName(
+ target_arg->scalar<string>()(), source_device,
+ &target_device));
+
+ FunctionLibraryRuntime* lib = ctx->function_library();
+ OP_REQUIRES(ctx, lib != nullptr,
+ errors::Internal("No function library is provided."));
+
+ mutex_lock l(mu_);
+ if (!initialized_) {
+ OP_REQUIRES_OK(ctx, cinfo_.Init(ctx->resource_manager(), def()));
+ FunctionLibraryRuntime* clone_lib;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr;
+ OP_REQUIRES_OK(ctx, lib->Clone(&flib_def_, &pflr, &clone_lib));
+ // Create the resource.
+ FunctionBufferingResource* buffer;
+ OP_REQUIRES_OK(
+ ctx,
+ ctx->resource_manager()->LookupOrCreate<FunctionBufferingResource>(
+ cinfo_.container(), cinfo_.name(), &buffer,
+ [clone_lib, &pflr, &source_device, &target_device, func_args,
+ this](FunctionBufferingResource** ptr) {
+ *ptr = new FunctionBufferingResource(
+ clone_lib, std::move(pflr), func_, buffer_size_,
+ source_device, target_device, func_args, output_types_);
+ return Status::OK();
+ }));
+ core::ScopedUnref s(buffer);
+ OP_REQUIRES_OK(ctx, buffer->Instantiate());
+ initialized_ = true;
+ }
+
+ OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+ ctx, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<FunctionBufferingResource>()));
+ }
+
+ private:
+ mutex mu_;
+ ContainerInfo cinfo_ GUARDED_BY(mu_);
+ bool initialized_ GUARDED_BY(mu_) = false;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+ NameAttrList func_;
+ int64 buffer_size_;
+ string container_;
+ string name_;
+ DataTypeVector output_types_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
+ .Device(DEVICE_CPU)
+ .HostMemory("resource")
+ .HostMemory("string_arg")
+ .HostMemory("target_device"),
+ FunctionBufferResourceHandleOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
+ .Device(DEVICE_GPU)
+ .HostMemory("resource")
+ .HostMemory("string_arg")
+ .HostMemory("target_device"),
+ FunctionBufferResourceHandleOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
+ .Device(DEVICE_SYCL)
+ .HostMemory("resource")
+ .HostMemory("string_arg")
+ .HostMemory("target_device"),
+ FunctionBufferResourceHandleOp);
+#endif // TENSORFLOW_USE_SYCL
+
+// Retrieves the next element from a FunctionBufferingResource's buffer,
+// triggering further prefetching when the buffer has room.
+class FunctionBufferingResourceGetNextOp : public AsyncOpKernel {
+ public:
+ explicit FunctionBufferingResourceGetNextOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx) {}
+
+ ~FunctionBufferingResourceGetNextOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, HandleFromInput(ctx, "function_buffer_resource", &handle), done);
+ FunctionBufferingResource* buffer = nullptr;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer),
+ done);
+
+ if (buffer->Finished()) {
+ buffer->Unref();
+ ctx->SetStatus(errors::OutOfRange("end_of_sequence"));
+ done();
+ return;
+ }
+
+ FunctionBufferCallback callback =
+ [ctx, buffer, done](const BufferElement& buffer_element) {
+ Status s = buffer_element.status;
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ buffer->Unref();
+ done();
+ return;
+ }
+ for (size_t i = 0; i < buffer_element.value.size(); ++i) {
+ ctx->set_output(i, buffer_element.value[i]);
+ }
+ buffer->Unref();
+ done();
+ };
+ buffer->MaybeGet(std::move(callback));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
+ .Device(DEVICE_CPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceGetNextOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
+ .Device(DEVICE_GPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceGetNextOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
+ .Device(DEVICE_SYCL)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceGetNextOp);
+#endif // TENSORFLOW_USE_SYCL
+
+// Resets the FunctionBufferingResource, cancelling all pending requests and
+// clearing out the buffer.
+class FunctionBufferingResourceResetOp : public OpKernel {
+ public:
+ explicit FunctionBufferingResourceResetOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ ~FunctionBufferingResourceResetOp() override {}
+
+ void Compute(OpKernelContext* ctx) override {
+ ResourceHandle handle;
+ OP_REQUIRES_OK(ctx,
+ HandleFromInput(ctx, "function_buffer_resource", &handle));
+ FunctionBufferingResource* buffer = nullptr;
+ OP_REQUIRES_OK(
+ ctx, LookupResource<FunctionBufferingResource>(ctx, handle, &buffer));
+ core::ScopedUnref s(buffer);
+
+ buffer->Reset();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
+ .Device(DEVICE_CPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceResetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
+ .Device(DEVICE_GPU)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceResetOp);
+#if TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
+ .Device(DEVICE_SYCL)
+ .HostMemory("function_buffer_resource"),
+ FunctionBufferingResourceResetOp);
+#endif // TENSORFLOW_USE_SYCL
+
+class IteratorGetDeviceOp : public OpKernel {
+ public:
+ using OpKernel::OpKernel;
+
+ void Compute(OpKernelContext* ctx) override {
+    // NOTE(mrry): We do not currently validate that the handle
+ // corresponds to a real IteratorResource, because that symbol is
+ // not exposed from the framework library.
+ Tensor* device_name_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &device_name_t));
+ // NOTE(mrry): Since the operation's input is a resource, we must be
+ // colocated with it, and so we can simply return the current device's
+ // name without looking at the input.
+ device_name_t->scalar<string>()() = ctx->device()->name();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIteratorGetDevice").Device(DEVICE_CPU),
+ IteratorGetDeviceOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
new file mode 100644
index 0000000000..8d561ca0e3
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
@@ -0,0 +1,220 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class ThreadPoolResource : public ResourceBase {
+ public:
+ ThreadPoolResource(Env* env, const ThreadOptions& thread_options,
+ const string& name, int num_threads, bool low_latency_hint,
+ int max_intra_op_parallelism)
+ : thread_pool_(env, thread_options, name, num_threads, low_latency_hint),
+ max_intra_op_parallelism_(max_intra_op_parallelism) {}
+
+ // Schedules fn() for execution in the pool of threads.
+ void Schedule(std::function<void()> fn) {
+ if (max_intra_op_parallelism_ < 0) {
+ thread_pool_.Schedule(std::move(fn));
+ } else {
+ thread_pool_.Schedule(std::bind(
+ [this](std::function<void()> bound_fn) {
+ // TODO(mrry): Consider moving this thread-local configuration to
+ // the threads themselves.
+ ScopedPerThreadMaxParallelism scope(max_intra_op_parallelism_);
+ bound_fn();
+ },
+ std::move(fn)));
+ }
+ }
+
+ string DebugString() override { return "ThreadPoolResource"; }
+
+ private:
+ thread::ThreadPool thread_pool_;
+ const int max_intra_op_parallelism_;
+};
+
+// Creates a handle to a ThreadPool resource. Note that we don't use
+// ResourceOpKernel here because the ThreadPoolResource constructor requires
+// access to `OpKernelContext::env()`, which isn't provided by
+// `ResourceOpKernel<T>::CreateResource()`.
+class ThreadPoolHandleOp : public OpKernel {
+ public:
+ explicit ThreadPoolHandleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("display_name", &display_name_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("num_threads", &num_threads_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("max_intra_op_parallelism",
+ &max_intra_op_parallelism_));
+ OP_REQUIRES(
+ ctx, num_threads_ > 0,
+ errors::InvalidArgument("`num_threads` must be greater than zero."));
+ }
+
+  // The resource is deleted from the resource manager only when it is private
+  // to the kernel. Ideally the resource should be deleted when it is no longer
+  // held by anyone, but doing so would break backward compatibility.
+ ~ThreadPoolHandleOp() override {
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->Delete<ThreadPoolResource>(cinfo_.container(), cinfo_.name())
+ .ok()) {
+        // Do nothing; the resource may have been deleted by session resets.
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ if (!initialized_) {
+ ResourceMgr* mgr = ctx->resource_manager();
+ OP_REQUIRES_OK(ctx, cinfo_.Init(mgr, def()));
+ ThreadPoolResource* resource;
+ OP_REQUIRES_OK(ctx, mgr->LookupOrCreate<ThreadPoolResource>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, ctx](ThreadPoolResource** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+                              *ret = new ThreadPoolResource(
+                                  ctx->env(), {}, display_name_,
+                                  num_threads_,
+                                  /* low_latency_hint = */ false,
+                                  max_intra_op_parallelism_);
+ return Status::OK();
+ }));
+ initialized_ = true;
+ }
+ OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
+ ctx, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<ThreadPoolResource>()));
+ }
+
+ private:
+ mutex mu_;
+ ContainerInfo cinfo_ GUARDED_BY(mu_);
+ bool initialized_ GUARDED_BY(mu_) = false;
+ string display_name_;
+ int num_threads_;
+ int max_intra_op_parallelism_;
+};
+
+class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit ThreadPoolDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ ThreadPoolResource* threadpool_resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
+ &threadpool_resource));
+ core::ScopedUnref unref_iterator(threadpool_resource);
+
+ *output = new Dataset(ctx, input, threadpool_resource);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ ThreadPoolResource* threadpool)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ threadpool_(threadpool) {
+ input_->Ref();
+ threadpool_->Ref();
+ }
+
+ ~Dataset() override {
+ input_->Unref();
+ threadpool_->Unref();
+ }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::ThreadPool")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+ return "ThreadPoolDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+      return errors::Unimplemented(DebugString(),
+                                   " does not support serialization");
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
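+        // Delegate to the input iterator, but give it a context whose runner
+        // schedules work on the custom thread pool rather than the runner
+        // supplied by the caller.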
+ ThreadPoolResource* pool = dataset()->threadpool_;
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = [pool](std::function<void()> c) {
+ pool->Schedule(std::move(c));
+ };
+ params.stats_aggregator = ctx->stats_aggregator();
+ params.lib = ctx->lib();
+ params.function_library = ctx->function_library();
+ params.allocator_getter = ctx->allocator_getter();
+ IteratorContext threadpool_ctx(params);
+ return input_impl_->GetNext(&threadpool_ctx, out_tensors,
+ end_of_sequence);
+ }
+
+ private:
+ std::unique_ptr<IteratorBase> input_impl_;
+ };
+
+ const DatasetBase* const input_;
+ ThreadPoolResource* const threadpool_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
+ ThreadPoolHandleOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalThreadPoolDataset").Device(DEVICE_CPU),
+ ThreadPoolDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
new file mode 100644
index 0000000000..cd612e0eb2
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
@@ -0,0 +1,224 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/hash/hash.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class UniqueDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit UniqueDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ OP_REQUIRES(ctx, input->output_dtypes().size() == 1,
+ errors::InvalidArgument("UniqueDataset only supports "
+ "inputs with a single component."));
+
+ DataType input_dtype = input->output_dtypes()[0];
+ OP_REQUIRES(ctx,
+ input_dtype == DT_INT32 || input_dtype == DT_INT64 ||
+ input_dtype == DT_STRING,
+ errors::InvalidArgument(
+ "UniqueDataset only supports inputs with a single "
+ "`tf.int32`, `tf.int64`, or `tf.string` component."));
+
+ *output = new Dataset(ctx, input);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input)
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Unique")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override {
+ return strings::StrCat("UniqueDatasetOp::Dataset");
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const typename Iterator::Params& params)
+ : DatasetIterator<Dataset>(params) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ bool saw_new_value;
+ do {
+ saw_new_value = false;
+ out_tensors->clear();
+ TF_RETURN_IF_ERROR(
+ input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
+ if (*end_of_sequence) {
+ break;
+ }
+ DCHECK_EQ(1, out_tensors->size());
+ saw_new_value = unique_elements_.insert((*out_tensors)[0]).second;
+ } while (!saw_new_value);
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ if (input_impl_) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ } else {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("input_impl_empty"), ""));
+ }
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name("unique_elements_size"), unique_elements_.size()));
+ size_t i = 0;
+ for (const Tensor& t : unique_elements_) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("unique_elements[", i++, "]")), t));
+ }
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (!reader->Contains(full_name("input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ } else {
+ input_impl_.reset();
+ }
+ int64 num_unique_elements;
+ unique_elements_.clear();
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("unique_elements_size"),
+ &num_unique_elements));
+ for (int64 i = 0; i < num_unique_elements; ++i) {
+ Tensor unique_element;
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("unique_elements[", i, "]")),
+ &unique_element));
+ auto insert_result = unique_elements_.insert(unique_element);
+ if (!insert_result.second) {
+ return errors::InvalidArgument(
+ "Checkpoint contained two unique elements with the same "
+ "value.");
+ }
+ }
+ return Status::OK();
+ }
+
+ private:
+ struct TensorHash {
+ size_t operator()(const Tensor& t) const {
+ if (t.dtype() == DT_INT32 || t.dtype() == DT_INT64) {
+ return Hash64(t.tensor_data().data(), t.tensor_data().size());
+ } else {
+ DCHECK_EQ(DT_STRING, t.dtype());
+ auto flat_t = t.flat<string>();
+ uint64 hash = 0;
+ for (int64 i = 0; i < t.NumElements(); ++i) {
+ hash = Hash64Combine(hash, Hash64(flat_t(i)));
+ }
+ return static_cast<size_t>(hash);
+ }
+ }
+ };
+
+ struct TensorKeyEqual {
+ bool operator()(const Tensor& lhs, const Tensor& rhs) const {
+ if (lhs.shape() != rhs.shape() || lhs.dtype() != rhs.dtype()) {
+ return false;
+ }
+ switch (lhs.dtype()) {
+#define HANDLE_TYPE(T) \
+ case T: \
+ do { \
+ auto lhs_flat = lhs.flat<EnumToDataType<T>::Type>(); \
+ auto rhs_flat = rhs.flat<EnumToDataType<T>::Type>(); \
+ for (int64 i = 0; i < lhs.NumElements(); ++i) { \
+ if (lhs_flat(i) != rhs_flat(i)) { \
+ return false; \
+ } \
+ } \
+ return true; \
+ } while (0)
+
+ HANDLE_TYPE(DT_INT32);
+ HANDLE_TYPE(DT_INT64);
+ HANDLE_TYPE(DT_STRING);
+ default:
+ DCHECK(false) << "UniqueDataset unhandled data type: "
+ << DataTypeString(lhs.dtype());
+ return false;
+ }
+ }
+ };
+
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::unordered_set<Tensor, TensorHash, TensorKeyEqual> unique_elements_
+ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* const input_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExperimentalUniqueDataset").Device(DEVICE_CPU),
+ UniqueDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
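UniqueDatasetOp's iterator keeps pulling from its input until insertion into unique_elements_ succeeds, with TensorHash/TensorKeyEqual making structurally equal tensors collide. The same dedup-by-hash-set loop in stand-alone C++, with a plain vector<int64_t> standing in for a single-component tensor:

#include <cstdint>
#include <functional>
#include <iostream>
#include <unordered_set>
#include <vector>

using IntVecTensor = std::vector<int64_t>;

struct VecHash {
  size_t operator()(const IntVecTensor& v) const {
    size_t h = 0;
    for (int64_t x : v) h = h * 1099511628211ULL + std::hash<int64_t>()(x);
    return h;
  }
};

int main() {
  std::vector<IntVecTensor> input = {{1, 2}, {3}, {1, 2}, {3}, {4}};
  std::unordered_set<IntVecTensor, VecHash> seen;  // equality falls back to operator==
  for (const auto& element : input) {
    if (seen.insert(element).second) {             // .second is true only for new values
      for (int64_t x : element) std::cout << x << ' ';
      std::cout << '\n';
    }                                              // duplicates are silently skipped
  }
  return 0;
}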
diff --git a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
index 4b6d808af0..a7e3a56727 100644
--- a/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_by_component_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -48,12 +48,12 @@ class FilterByLastComponentDatasetOp : public UnaryDatasetOpKernel {
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const DataTypeVector& output_types,
std::vector<PartialTensorShape> output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
output_types_(output_types),
output_shapes_(std::move(output_shapes)) {
@@ -166,5 +166,5 @@ REGISTER_KERNEL_BUILDER(Name("FilterByLastComponentDataset").Device(DEVICE_CPU),
FilterByLastComponentDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/filter_dataset_op.cc b/tensorflow/core/kernels/data/filter_dataset_op.cc
index b11d7cf2ef..00884314a9 100644
--- a/tensorflow/core/kernels/data/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/data/filter_dataset_op.cc
@@ -14,14 +14,16 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -37,14 +39,6 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
FunctionLibraryRuntime::Handle pred_handle;
OP_REQUIRES_OK(ctx,
ctx->function_library()->Instantiate(
@@ -61,9 +55,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
Node* ret_node = pred_body->ret_nodes[0];
Node* ret_input_node;
OP_REQUIRES_OK(ctx, ret_node->input_node(0, &ret_input_node));
+
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
if (ret_input_node->def().op() == "_Arg") {
int32 index = -1;
@@ -79,12 +74,12 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
private:
const int graph_def_version_;
- class FilterDatasetBase : public GraphDatasetBase {
+ class FilterDatasetBase : public DatasetBase {
public:
FilterDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)) {
@@ -112,7 +107,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
@@ -146,10 +141,18 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<FilterDatasetBase> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<FilterDatasetBase>(params) {}
+ : DatasetIterator<FilterDatasetBase>(params),
+ filtered_elements_(0),
+ dropped_elements_(0) {
+ std::vector<string> components =
+ str_util::Split(params.prefix, "::", str_util::SkipEmpty());
+ prefix_end_ = components.back();
+ }
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
@@ -159,6 +162,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
// `input_impl_` and `f` are thread-safe. However, if multiple
// threads enter this method, outputs may be observed in a
// non-deterministic order.
+ auto stats_aggregator = ctx->stats_aggregator();
bool matched;
do {
{
@@ -181,8 +185,34 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
if (!matched) {
// Clear the output tensor list since it didn't match.
out_tensors->clear();
+ if (stats_aggregator) {
+ mutex_lock l(mu_);
+ dropped_elements_++;
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::dropped_elements"),
+ static_cast<float>((dropped_elements_)));
+ // TODO(shivaniagrawal): multiple pipelines would collect
+ // aggregated number of dropped elements for all the pipelines,
+ // exploit tagged_context here.
+ stats_aggregator->IncrementCounter(
+ prefix_end_, "dropped_elements", static_cast<float>(1));
+ }
}
} while (!matched);
+ // TODO(shivaniagrawal): add ratio of dropped_elements and
+ // filtered_elements as a histogram.
+ if (stats_aggregator) {
+ mutex_lock l(mu_);
+ filtered_elements_++;
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::filtered_elements"),
+ static_cast<float>((filtered_elements_)));
+ // TODO(shivaniagrawal): multiple pipelines would collect aggregated
+ // number of filtered elements for all the pipelines, exploit
+ // tagged_context here.
+ stats_aggregator->IncrementCounter(prefix_end_, "filtered_elements",
+ static_cast<float>(1));
+ }
*end_of_sequence = false;
return Status::OK();
}
@@ -195,6 +225,10 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
else
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name("input_impls_empty"), ""));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("filtered_elements"),
+ filtered_elements_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("dropped_elements"),
+ dropped_elements_));
return Status::OK();
}
@@ -205,12 +239,19 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
input_impl_.reset();
else
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("filtered_elements"),
+ &filtered_elements_));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("dropped_elements"),
+ &dropped_elements_));
return Status::OK();
}
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ int64 filtered_elements_ GUARDED_BY(mu_);
+ int64 dropped_elements_ GUARDED_BY(mu_);
+ string prefix_end_;
};
const DatasetBase* const input_;
@@ -278,5 +319,5 @@ REGISTER_KERNEL_BUILDER(Name("FilterDataset").Device(DEVICE_CPU),
FilterDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
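The stats added to FilterDatasetOp sit on top of the same basic loop: pull an element, evaluate the predicate, count it as dropped or kept, and only return once something matches. A stand-alone sketch of that loop with hypothetical names (two plain counters in place of the stats aggregator):

#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

int main() {
  std::vector<int> input = {1, 4, 6, 3, 8};
  auto predicate = [](int x) { return x % 2 == 0; };  // keep even values
  int64_t dropped = 0, filtered = 0;
  size_t pos = 0;

  auto get_next = [&]() -> std::optional<int> {
    while (pos < input.size()) {
      int candidate = input[pos++];
      if (predicate(candidate)) { ++filtered; return candidate; }
      ++dropped;                      // element did not match; try the next one
    }
    return std::nullopt;              // end of sequence
  };

  while (auto v = get_next()) std::cout << *v << ' ';
  std::cout << "\nkept=" << filtered << " dropped=" << dropped << '\n';
  return 0;
}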
diff --git a/tensorflow/core/kernels/data/flat_map_dataset_op.cc b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
index 3419eed6c6..2fada22a21 100644
--- a/tensorflow/core/kernels/data/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/flat_map_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -39,31 +39,22 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
-
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(captured_func),
output_types_, output_shapes_);
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
@@ -94,7 +85,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
@@ -129,7 +120,9 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
@@ -243,7 +236,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
private:
Status BuildCurrentElementIteratorLocked(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return dataset::MakeIteratorFromInputElement(
+ return MakeIteratorFromInputElement(
ctx, captured_func_inputs_, element_index_++,
dataset()->captured_func_.get(), prefix(),
&current_element_iterator_);
@@ -283,5 +276,5 @@ REGISTER_KERNEL_BUILDER(Name("FlatMapDataset").Device(DEVICE_CPU),
FlatMapDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
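FlatMapDatasetOp maps each input element to an inner iterator via MakeIteratorFromInputElement and drains that inner iterator before advancing element_index_. The control flow reduces to a nested loop; a tiny stand-alone illustration with an arbitrary map function:

#include <iostream>
#include <vector>

int main() {
  std::vector<int> input = {1, 2, 3};
  auto map_fn = [](int x) { return std::vector<int>(x, x); };  // x copies of x
  for (int element : input)                                    // outer dataset
    for (int v : map_fn(element)) std::cout << v << ' ';       // inner sequence: 1 2 2 3 3 3
  std::cout << '\n';
  return 0;
}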
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.cc b/tensorflow/core/kernels/data/generator_dataset_op.cc
index c4dd849b8b..b4367d5a11 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/generator_dataset_op.cc
@@ -19,21 +19,23 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/captured_function.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class GeneratorDatasetOp::Dataset : public GraphDatasetBase {
+class GeneratorDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, std::unique_ptr<CapturedFunction> init_func,
std::unique_ptr<CapturedFunction> next_func,
std::unique_ptr<CapturedFunction> finalize_func,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
init_func_(std::move(init_func)),
next_func_(std::move(next_func)),
finalize_func_(std::move(finalize_func)),
@@ -47,12 +49,21 @@ class GeneratorDatasetOp::Dataset : public GraphDatasetBase {
}
const DataTypeVector& output_dtypes() const override { return output_types_; }
+
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
string DebugString() const override { return "GeneratorDatasetOp::Dataset"; }
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ return errors::Unimplemented(DebugString(),
+ " does not support serialization");
+ }
+
private:
class Iterator : public DatasetIterator<Dataset> {
public:
@@ -71,6 +82,13 @@ class GeneratorDatasetOp::Dataset : public GraphDatasetBase {
}
}
+ Status Initialize(IteratorContext* ctx) override {
+ TF_RETURN_IF_ERROR(dataset()->init_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->next_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
+ return Status::OK();
+ }
+
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
@@ -79,9 +97,6 @@ class GeneratorDatasetOp::Dataset : public GraphDatasetBase {
if (!initialized_) {
TF_RETURN_IF_ERROR(
dataset()->init_func_->RunWithBorrowedArgs(ctx, {}, &state_));
- // Explicitly instantiate the finalize function here so that
- // we can invoke it in the destructor.
- TF_RETURN_IF_ERROR(dataset()->finalize_func_->Instantiate(ctx));
initialized_ = true;
}
@@ -135,54 +150,31 @@ GeneratorDatasetOp::GeneratorDatasetOp(OpKernelConstruction* ctx)
void GeneratorDatasetOp::MakeDataset(OpKernelContext* ctx,
DatasetBase** output) {
- OpInputList init_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("init_func_other_args",
- &init_func_other_args_input));
- std::vector<Tensor> init_func_other_args;
- init_func_other_args.reserve(init_func_other_args_input.size());
- for (const Tensor& t : init_func_other_args_input) {
- init_func_other_args.push_back(t);
- }
std::unique_ptr<CapturedFunction> init_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(init_func_, std::move(init_func_other_args),
- &init_func));
-
- OpInputList next_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("next_func_other_args",
- &next_func_other_args_input));
- std::vector<Tensor> next_func_other_args;
- next_func_other_args.reserve(next_func_other_args_input.size());
- for (const Tensor& t : next_func_other_args_input) {
- next_func_other_args.push_back(t);
- }
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(
+ init_func_, ctx, "init_func_other_args", &init_func));
+
std::unique_ptr<CapturedFunction> next_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(next_func_, std::move(next_func_other_args),
- &next_func));
-
- OpInputList finalize_func_other_args_input;
- OP_REQUIRES_OK(ctx, ctx->input_list("finalize_func_other_args",
- &finalize_func_other_args_input));
- std::vector<Tensor> finalize_func_other_args;
- finalize_func_other_args.reserve(finalize_func_other_args_input.size());
- for (const Tensor& t : finalize_func_other_args_input) {
- finalize_func_other_args.push_back(t);
- }
- std::unique_ptr<CapturedFunction> finalize_func;
OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- finalize_func_, std::move(finalize_func_other_args),
- &finalize_func));
+ next_func_, ctx, "next_func_other_args", &next_func));
+
+ std::unique_ptr<CapturedFunction> finalize_func;
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(finalize_func_, ctx,
+ "finalize_func_other_args",
+ &finalize_func));
*output =
new Dataset(ctx, std::move(init_func), std::move(next_func),
std::move(finalize_func), output_types_, output_shapes_);
}
+namespace {
REGISTER_KERNEL_BUILDER(Name("GeneratorDataset").Device(DEVICE_CPU),
GeneratorDatasetOp);
REGISTER_KERNEL_BUILDER(
Name("GeneratorDataset").Device(DEVICE_GPU).HostMemory("handle"),
GeneratorDatasetOp);
+} // namespace
+} // namespace data
} // namespace tensorflow
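GeneratorDatasetOp now instantiates all three captured functions in Iterator::Initialize rather than lazily instantiating finalize_func_ on the first GetNext. The contract it expects from init/next/finalize can be sketched with ordinary callables (hypothetical user code, not the TensorFlow implementation):

#include <iostream>
#include <optional>

int main() {
  int state = 0;
  auto init = [&]() { state = 0; };                  // build the initial state
  auto next = [&]() -> std::optional<int> {          // produce the next element
    if (state >= 3) return std::nullopt;             // end of sequence
    return state++;
  };
  auto finalize = [&]() { std::cout << "released state\n"; };  // cleanup

  init();
  while (auto v = next()) std::cout << *v << '\n';
  finalize();
  return 0;
}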
diff --git a/tensorflow/core/kernels/data/generator_dataset_op.h b/tensorflow/core/kernels/data/generator_dataset_op.h
index 3f84fa9c2e..d23ed97ec3 100644
--- a/tensorflow/core/kernels/data/generator_dataset_op.h
+++ b/tensorflow/core/kernels/data/generator_dataset_op.h
@@ -17,9 +17,9 @@ limitations under the License.
#define TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_
#include "tensorflow/core/framework/dataset.h"
-#include "tensorflow/core/kernels/data/captured_function.h"
namespace tensorflow {
+namespace data {
class GeneratorDatasetOp : public DatasetOpKernel {
public:
@@ -37,5 +37,6 @@ class GeneratorDatasetOp : public DatasetOpKernel {
NameAttrList finalize_func_;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_GENERATOR_DATASET_OP_H_
diff --git a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
index bcf0adacc7..e7244ee208 100644
--- a/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_reducer_dataset_op.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -29,8 +30,7 @@ namespace {
class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
public:
explicit GroupByReducerDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("init_func", &init_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
@@ -66,7 +66,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
std::unique_ptr<CapturedFunction> captured_key_func,
@@ -75,7 +75,7 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<CapturedFunction> captured_finalize_func,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
captured_key_func_(std::move(captured_key_func)),
captured_init_func_(std::move(captured_init_func)),
@@ -109,11 +109,10 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), key_func().name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), init_func().name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), reduce_func().name()));
- TF_RETURN_IF_ERROR(
- b->AddFunction(ctx->flib_def(), finalize_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, init_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func().name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, finalize_func().name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
@@ -190,7 +189,14 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->captured_init_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(
+ dataset()->captured_finalize_func_->Instantiate(ctx));
+ return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
@@ -414,7 +420,6 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList key_func_;
@@ -427,4 +432,5 @@ REGISTER_KERNEL_BUILDER(Name("GroupByReducerDataset").Device(DEVICE_CPU),
GroupByReducerDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
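GroupByReducerDatasetOp wires four captured functions together: key_func picks a group, init_func seeds per-key state, reduce_func folds each element into its group's state, and finalize_func maps the final state to the output. A stand-alone sketch of that pipeline over integers (all lambdas are illustrative stand-ins):

#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  std::vector<int64_t> input = {3, 1, 4, 1, 5, 9, 2, 6};
  auto key_fn = [](int64_t x) { return x % 2; };             // group by parity
  auto init_fn = []() { return int64_t{0}; };                // per-key initial state
  auto reduce_fn = [](int64_t state, int64_t x) { return state + x; };
  auto finalize_fn = [](int64_t state) { return state; };

  std::map<int64_t, int64_t> state;
  for (int64_t x : input) {
    int64_t k = key_fn(x);
    auto it = state.find(k);
    if (it == state.end()) it = state.emplace(k, init_fn()).first;  // new group
    it->second = reduce_fn(it->second, x);                          // fold element in
  }
  for (const auto& [k, s] : state)
    std::cout << "key " << k << " -> " << finalize_fn(s) << '\n';
  return 0;
}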
diff --git a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
index 683a50e71c..14aefe5d54 100644
--- a/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/group_by_window_dataset_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -30,8 +31,7 @@ namespace {
class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
public:
explicit GroupByWindowDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("key_func", &key_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("reduce_func", &reduce_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("window_size_func", &window_size_func_));
@@ -41,50 +41,19 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- // Get captured inputs for the key, reduce, and window_size functions.
- OpInputList key_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("key_func_other_arguments",
- &key_func_other_argument_inputs));
- std::vector<Tensor> key_func_other_arguments;
- key_func_other_arguments.reserve(key_func_other_argument_inputs.size());
- for (const Tensor& t : key_func_other_argument_inputs) {
- key_func_other_arguments.push_back(t);
- }
- OpInputList reduce_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("reduce_func_other_arguments",
- &reduce_func_other_argument_inputs));
- std::vector<Tensor> reduce_func_other_arguments;
- reduce_func_other_arguments.reserve(
- reduce_func_other_argument_inputs.size());
- for (const Tensor& t : reduce_func_other_argument_inputs) {
- reduce_func_other_arguments.push_back(t);
- }
- OpInputList window_size_func_other_argument_inputs;
- OP_REQUIRES_OK(ctx,
- ctx->input_list("window_size_func_other_arguments",
- &window_size_func_other_argument_inputs));
- std::vector<Tensor> window_size_func_other_arguments;
- window_size_func_other_arguments.reserve(
- window_size_func_other_argument_inputs.size());
- for (const Tensor& t : window_size_func_other_argument_inputs) {
- window_size_func_other_arguments.push_back(t);
- }
- // TODO(mrry): Refactor CapturedFunction to share the runtime
- // state between multiple functions?
std::unique_ptr<CapturedFunction> captured_key_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- key_func_, std::move(key_func_other_arguments),
- &captured_key_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(key_func_, ctx,
+ "key_func_other_arguments",
+ &captured_key_func));
std::unique_ptr<CapturedFunction> captured_reduce_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(reduce_func_,
- std::move(reduce_func_other_arguments),
- &captured_reduce_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(reduce_func_, ctx,
+ "reduce_func_other_arguments",
+ &captured_reduce_func));
std::unique_ptr<CapturedFunction> captured_window_size_func;
- OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- window_size_func_, std::move(window_size_func_other_arguments),
- &captured_window_size_func));
+ OP_REQUIRES_OK(ctx,
+ CapturedFunction::Create(window_size_func_, ctx,
+ "window_size_func_other_arguments",
+ &captured_window_size_func));
*output = new Dataset(
ctx, input, key_func_, reduce_func_, window_size_func_,
@@ -93,7 +62,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& key_func, const NameAttrList& reduce_func,
@@ -103,7 +72,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<CapturedFunction> captured_window_size_func,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
key_func_(key_func),
reduce_func_(reduce_func),
@@ -139,10 +108,9 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), key_func_.name()));
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), reduce_func_.name()));
- TF_RETURN_IF_ERROR(
- b->AddFunction(ctx->flib_def(), window_size_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, key_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, reduce_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, window_size_func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
@@ -205,7 +173,13 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ TF_RETURN_IF_ERROR(dataset()->captured_key_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(dataset()->captured_reduce_func_->Instantiate(ctx));
+ TF_RETURN_IF_ERROR(
+ dataset()->captured_window_size_func_->Instantiate(ctx));
+ return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx,
@@ -532,7 +506,6 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList key_func_;
@@ -544,4 +517,5 @@ REGISTER_KERNEL_BUILDER(Name("GroupByWindowDataset").Device(DEVICE_CPU),
GroupByWindowDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/interleave_dataset_op.cc b/tensorflow/core/kernels/data/interleave_dataset_op.cc
index 8fee29d4d0..0aa802b874 100644
--- a/tensorflow/core/kernels/data/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/interleave_dataset_op.cc
@@ -1,4 +1,3 @@
-
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -40,14 +39,6 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
const Tensor* cycle_length_t;
OP_REQUIRES_OK(ctx, ctx->input("cycle_length", &cycle_length_t));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(cycle_length_t->shape()),
@@ -67,8 +58,8 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
errors::InvalidArgument("block_length must be greater than zero."));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output =
new Dataset(ctx, input, func_, std::move(captured_func), cycle_length,
@@ -76,14 +67,14 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
int64 block_length, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
@@ -117,7 +108,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node;
@@ -156,7 +147,9 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
args_list_(params.dataset->cycle_length_) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
void AdvanceToNextInCycle() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
@@ -200,7 +193,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(input_impl_->GetNext(
ctx, &args_list_[cycle_index_], &end_of_input_));
if (!end_of_input_) {
- TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+ TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, args_list_[cycle_index_], cycle_index_,
dataset()->captured_func_.get(), prefix(),
&current_elements_[cycle_index_]));
@@ -287,7 +280,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
&args_list_[idx][i]));
}
- TF_RETURN_IF_ERROR(dataset::MakeIteratorFromInputElement(
+ TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
ctx, args_list_[idx], idx, dataset()->captured_func_.get(),
prefix(), &current_elements_[idx]));
TF_RETURN_IF_ERROR(
@@ -329,5 +322,5 @@ REGISTER_KERNEL_BUILDER(Name("InterleaveDataset").Device(DEVICE_CPU),
InterleaveDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
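InterleaveDatasetOp's iterator keeps cycle_length inner iterators open, takes up to block_length elements from the current one, then advances cycle_index_; when a slot runs dry it is refilled from the next input element. A stand-alone sketch of that bookkeeping, with plain vectors standing in for the inner iterators:

#include <cstddef>
#include <deque>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  // Inner sequences that stand in for the per-element iterators.
  std::deque<std::vector<int>> pending = {{1, 1, 1, 1}, {2, 2, 2}, {3, 3}};
  const std::size_t cycle_length = 2, block_length = 2;

  struct Slot { std::vector<int> data; std::size_t pos = 0; bool open = false; };
  std::vector<Slot> slots(cycle_length);
  std::size_t cycle_index = 0, block_count = 0, remaining = 0;
  for (const auto& s : pending) remaining += s.size();

  auto advance = [&] { block_count = 0; cycle_index = (cycle_index + 1) % cycle_length; };

  while (remaining > 0) {
    Slot& slot = slots[cycle_index];
    if (!slot.open && !pending.empty()) {            // open the next input element here
      slot = Slot{std::move(pending.front()), 0, true};
      pending.pop_front();
    }
    if (slot.open && slot.pos < slot.data.size()) {
      std::cout << slot.data[slot.pos++] << ' ';
      --remaining;
      ++block_count;
      if (slot.pos == slot.data.size()) slot.open = false;   // slot exhausted
      if (block_count == block_length || !slot.open) advance();
    } else {
      slot.open = false;                             // nothing left for this slot
      advance();
    }
  }
  std::cout << '\n';                                 // prints: 1 1 2 2 1 1 2 3 3
  return 0;
}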
diff --git a/tensorflow/core/kernels/data/iterator_ops.cc b/tensorflow/core/kernels/data/iterator_ops.cc
index da9d29dd76..7a833668ac 100644
--- a/tensorflow/core/kernels/data/iterator_ops.cc
+++ b/tensorflow/core/kernels/data/iterator_ops.cc
@@ -36,7 +36,7 @@ limitations under the License.
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -44,43 +44,6 @@ namespace {
const char kIteratorVariantTypeName[] = "tensorflow::Iterator";
-Status VerifyTypesMatch(const DataTypeVector& expected,
- const DataTypeVector& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " types but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (expected[i] != received[i]) {
- return errors::InvalidArgument("Data type mismatch at component ", i,
- ": expected ", DataTypeString(expected[i]),
- " but got ", DataTypeString(received[i]),
- ".");
- }
- }
- return Status::OK();
-}
-
-Status VerifyShapesCompatible(const std::vector<PartialTensorShape>& expected,
- const std::vector<PartialTensorShape>& received) {
- if (expected.size() != received.size()) {
- return errors::InvalidArgument(
- "Number of components does not match: expected ", expected.size(),
- " shapes but got ", received.size(), ".");
- }
- for (size_t i = 0; i < expected.size(); ++i) {
- if (!expected[i].IsCompatibleWith(received[i])) {
- return errors::InvalidArgument("Incompatible shapes at component ", i,
- ": expected ", expected[i].DebugString(),
- " but got ", received[i].DebugString(),
- ".");
- }
- }
-
- return Status::OK();
-}
-
} // namespace
class IteratorResource : public ResourceBase {
@@ -104,9 +67,8 @@ class IteratorResource : public ResourceBase {
bool* end_of_sequence) {
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
if (captured_iterator) {
- if (lib_ != nullptr) {
- ctx->set_lib(lib_);
- }
+ CHECK_NOTNULL(lib_);
+ ctx->set_lib(lib_);
return captured_iterator->GetNext(ctx, out_tensors, end_of_sequence);
} else {
return errors::FailedPrecondition(
@@ -130,7 +92,7 @@ class IteratorResource : public ResourceBase {
Status Restore(OpKernelContext* ctx, IteratorStateReader* reader) {
string serialized_graph_def;
- TF_RETURN_IF_ERROR(reader->ReadScalar(GraphDatasetBase::kDatasetGraphKey,
+ TF_RETURN_IF_ERROR(reader->ReadScalar(DatasetBase::kDatasetGraphKey,
&serialized_graph_def));
GraphDef graph_def;
if (!graph_def.ParseFromString(serialized_graph_def)) {
@@ -138,7 +100,7 @@ class IteratorResource : public ResourceBase {
}
string output_node;
TF_RETURN_IF_ERROR(reader->ReadScalar(
- GraphDatasetBase::kDatasetGraphOutputNodeKey, &output_node));
+ DatasetBase::kDatasetGraphOutputNodeKey, &output_node));
DatasetBase* dataset = nullptr;
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
@@ -161,9 +123,11 @@ class IteratorResource : public ResourceBase {
graph_runner.Run(&graph, lib, {}, {output_node}, &outputs));
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(outputs[0], &dataset));
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
- TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(lib);
+ TF_RETURN_IF_ERROR(
+ dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
TF_RETURN_IF_ERROR(set_iterator(std::move(iterator)));
std::shared_ptr<IteratorBase> captured_iterator(iterator_);
@@ -198,6 +162,8 @@ class IteratorResource : public ResourceBase {
return lib_def_;
}
+ FunctionLibraryRuntime* function_library_runtime() { return lib_; }
+
// Transfers ownership of iterator to this. This method is thread-safe.
Status set_iterator(std::unique_ptr<IteratorBase> iterator) {
if (iterator) {
@@ -233,6 +199,8 @@ class IteratorResource : public ResourceBase {
const std::vector<PartialTensorShape> output_shapes_;
};
+namespace {
+
// Helper class for reading data from a VariantTensorData object.
class VariantTensorDataReader : public IteratorStateReader {
public:
@@ -258,7 +226,7 @@ class VariantTensorDataReader : public IteratorStateReader {
}
bool Contains(StringPiece key) override {
- return map_.find(key.ToString()) != map_.end();
+ return map_.find(string(key)) != map_.end();
}
private:
@@ -279,18 +247,18 @@ class VariantTensorDataReader : public IteratorStateReader {
template <typename T>
Status ReadScalarInternal(StringPiece key, T* val) {
- if (map_.find(key.ToString()) == map_.end()) {
+ if (map_.find(string(key)) == map_.end()) {
return errors::NotFound(key);
}
- *val = data_->tensors(map_[key.ToString()]).scalar<T>()();
+ *val = data_->tensors(map_[string(key)]).scalar<T>()();
return Status::OK();
}
Status ReadTensorInternal(StringPiece key, Tensor* val) {
- if (map_.find(key.ToString()) == map_.end()) {
+ if (map_.find(string(key)) == map_.end()) {
return errors::NotFound(key);
}
- *val = data_->tensors(map_[key.ToString()]);
+ *val = data_->tensors(map_[string(key)]);
return Status::OK();
}
@@ -339,7 +307,7 @@ class VariantTensorDataWriter : public IteratorStateWriter {
// Write key to the metadata proto. This gets written to `data_`
// when `Flush()` is called. We do this lazily to avoid multiple
// serialization calls.
- metadata_proto_.add_keys(key.ToString());
+ metadata_proto_.add_keys(string(key));
// Update tensors.
*(data_->add_tensors()) = val;
@@ -398,12 +366,12 @@ class IteratorStateVariant {
}
string TypeName() const { return kIteratorVariantTypeName; }
void Encode(VariantTensorData* data) const { *data = *data_; }
- bool Decode(const VariantTensorData& data) {
+ bool Decode(VariantTensorData data) {
if (data.type_name() != TypeName()) {
return false;
}
std::unique_ptr<VariantTensorData> tensor_data(new VariantTensorData);
- *tensor_data = data;
+ std::swap(*tensor_data, data);
std::unique_ptr<VariantTensorDataReader> reader(
new VariantTensorDataReader(tensor_data.get()));
status_ = reader->status();
@@ -440,6 +408,8 @@ class IteratorStateVariant {
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(IteratorStateVariant,
kIteratorVariantTypeName);
+} // namespace
+
// Note that IteratorHandleOp holds a reference to the resource it creates. If
// cleaning up resources with DestroyResourceOp is important, consider creating
// resource containers with AnonymousIteratorHandleOp instead.
@@ -611,12 +581,16 @@ void MakeIteratorOp::Compute(OpKernelContext* ctx) {
ctx, LookupResource(ctx, HandleFromInput(ctx, 1), &iterator_resource));
core::ScopedUnref unref(iterator_resource);
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
- OP_REQUIRES_OK(ctx, dataset->MakeIterator(&iter_ctx, "Iterator", &iterator));
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(iterator_resource->function_library_runtime());
+ OP_REQUIRES_OK(
+ ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
OP_REQUIRES_OK(ctx, iterator_resource->set_iterator(std::move(iterator)));
}
+namespace {
+
class ToSingleElementOp : public AsyncOpKernel {
public:
explicit ToSingleElementOp(OpKernelConstruction* ctx)
@@ -633,11 +607,11 @@ class ToSingleElementOp : public AsyncOpKernel {
DatasetBase* dataset;
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
OP_REQUIRES_OK_ASYNC(
ctx,
- dataset->MakeIterator(&iter_ctx, "SingleElementIterator", &iterator),
+ dataset->MakeIterator(IteratorContext(ctx), "SingleElementIterator",
+ &iterator),
done);
// NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
@@ -651,8 +625,8 @@ class ToSingleElementOp : public AsyncOpKernel {
components.reserve(dataset->output_dtypes().size());
bool end_of_sequence = false;
- Status s =
- raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+ Status s = raw_iterator->GetNext(IteratorContext(ctx), &components,
+ &end_of_sequence);
if (!s.ok()) {
ctx->SetStatus(s);
return;
@@ -667,8 +641,8 @@ class ToSingleElementOp : public AsyncOpKernel {
}
components.clear();
- Status s2 =
- raw_iterator->GetNext(&iter_ctx, &components, &end_of_sequence);
+ Status s2 = raw_iterator->GetNext(IteratorContext(ctx), &components,
+ &end_of_sequence);
if (!s2.ok()) {
ctx->SetStatus(s2);
return;
@@ -685,6 +659,115 @@ class ToSingleElementOp : public AsyncOpKernel {
BackgroundWorker background_worker_;
};
+class ReduceDatasetOp : public AsyncOpKernel {
+ public:
+ explicit ReduceDatasetOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ background_worker_(
+ ctx->env(),
+ strings::StrCat("reduce_thread_", SanitizeThreadSuffix(name()))) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &reduce_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+ &use_inter_op_parallelism_));
+ }
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ // The call to `iterator->GetNext()` may block and depend on an
+ // inter-op thread pool thread, so we issue the call from the
+ // owned thread pool.
+ background_worker_.Schedule([this, ctx, done]() {
+ DatasetBase* dataset;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
+ OpInputList inputs;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("initial_state", &inputs),
+ done);
+ std::vector<Tensor> state(inputs.begin(), inputs.end());
+
+ std::unique_ptr<CapturedFunction> captured_func;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ CapturedFunction::Create(reduce_func_, ctx, "other_arguments",
+ use_inter_op_parallelism_, &captured_func),
+ done);
+
+ IteratorContext iter_ctx(ctx);
+ OP_REQUIRES_OK_ASYNC(ctx, captured_func->Instantiate(&iter_ctx), done);
+
+ std::unique_ptr<IteratorBase> iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, dataset->MakeIterator(&iter_ctx, "ReduceIterator", &iterator),
+ done);
+
+ // NOTE(jsimsa): We must destroy the iterator before calling `done()`, to
+ // avoid destruction races.
+ IteratorBase* raw_iterator = iterator.release();
+ auto cleanup = gtl::MakeCleanup([raw_iterator, done] {
+ delete raw_iterator;
+ done();
+ });
+
+ // Iterate through the input dataset.
+ Status status;
+ while (true) {
+ std::vector<Tensor> next_input_element;
+ bool end_of_input;
+ status = raw_iterator->GetNext(&iter_ctx, &next_input_element,
+ &end_of_input);
+ if (!status.ok() || end_of_input) {
+ break;
+ }
+
+ // Run the reduce function to update the current state.
+ std::vector<Tensor> args;
+ args.reserve(state.size() + next_input_element.size());
+ std::copy(state.begin(), state.end(), std::back_inserter(args));
+ std::copy(next_input_element.begin(), next_input_element.end(),
+ std::back_inserter(args));
+
+ std::vector<Tensor> reduce_func_output;
+ status =
+ captured_func->Run(&iter_ctx, std::move(args), &reduce_func_output);
+ if (!status.ok()) {
+ break;
+ }
+ std::swap(reduce_func_output, state);
+ }
+
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ return;
+ }
+ for (int i = 0; i < state.size(); ++i) {
+ OP_REQUIRES_ASYNC(
+ ctx, state[i].dtype() == output_types_[i],
+ errors::InvalidArgument(
+ "The result does not match the expected type for component ", i,
+ ". Expected: ", DataTypeString(output_types_[i]),
+ ". Actual: ", DataTypeString(state[i].dtype()), "."),
+ done);
+ OP_REQUIRES_ASYNC(
+ ctx, output_shapes_[i].IsCompatibleWith(state[i].shape()),
+ errors::InvalidArgument(
+ "The result does not match the expected shape for component ",
+ i, ". Expected: ", output_shapes_[i].DebugString(),
+ ". Actual: ", state[i].shape().DebugString(), "."),
+ done);
+ ctx->set_output(i, state[i]);
+ }
+ });
+ }
+
+ private:
+ NameAttrList reduce_func_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ bool use_inter_op_parallelism_;
+ BackgroundWorker background_worker_;
+};
+
class OneShotIteratorOp : public AsyncOpKernel {
public:
explicit OneShotIteratorOp(OpKernelConstruction* ctx)
@@ -836,9 +919,11 @@ class OneShotIteratorOp : public AsyncOpKernel {
// factory function.
DatasetBase* dataset;
TF_RETURN_IF_ERROR(GetDatasetFromVariantTensor(return_values[0], &dataset));
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iter;
- TF_RETURN_IF_ERROR(dataset->MakeIterator(&iter_ctx, "Iterator", &iter));
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(lib);
+ TF_RETURN_IF_ERROR(
+ dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iter));
TF_RETURN_IF_ERROR((*iterator)->set_iterator(std::move(iter)));
(*iterator)->Ref();
@@ -880,6 +965,8 @@ class OneShotIteratorOp : public AsyncOpKernel {
const int graph_def_version_;
};
+} // namespace
+
void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
IteratorResource* iterator;
OP_REQUIRES_OK_ASYNC(
@@ -922,39 +1009,35 @@ void IteratorGetNextOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
std::move(done)));
}
-class IteratorGetNextSyncOp : public OpKernel {
- public:
- explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
-
- void Compute(OpKernelContext* ctx) override {
- IteratorResource* iterator;
- OP_REQUIRES_OK(ctx,
- LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
- core::ScopedUnref unref_iterator(iterator);
-
- std::vector<Tensor> components;
- bool end_of_sequence = false;
-
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = *(ctx->runner());
- params.function_library = iterator->function_library();
- DeviceBase* device = ctx->function_library()->device();
- params.allocator_getter = [device](AllocatorAttributes attrs) {
- return device->GetAllocator(attrs);
- };
- IteratorContext iter_ctx(std::move(params));
+void IteratorGetNextSyncOp::Compute(OpKernelContext* ctx) {
+ IteratorResource* iterator;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator));
+ core::ScopedUnref unref_iterator(iterator);
+
+ std::vector<Tensor> components;
+ bool end_of_sequence = false;
+
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.function_library = iterator->function_library();
+ DeviceBase* device = ctx->function_library()->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ IteratorContext iter_ctx(std::move(params));
- OP_REQUIRES_OK(ctx,
- iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
- OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));
+ OP_REQUIRES_OK(ctx,
+ iterator->GetNext(&iter_ctx, &components, &end_of_sequence));
+ OP_REQUIRES(ctx, !end_of_sequence, errors::OutOfRange("End of sequence"));
- for (int i = 0; i < components.size(); ++i) {
- // TODO(mrry): Check that the shapes match the shape attrs.
- ctx->set_output(i, components[i]);
- }
+ for (int i = 0; i < components.size(); ++i) {
+ // TODO(mrry): Check that the shapes match the shape attrs.
+ ctx->set_output(i, components[i]);
}
-};
+}
+
+namespace {
class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
public:
@@ -1036,6 +1119,8 @@ class IteratorGetNextAsOptionalOp : public AsyncOpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
+} // namespace
+
void IteratorToStringHandleOp::Compute(OpKernelContext* ctx) {
const Tensor& resource_handle_t = ctx->input(0);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
@@ -1107,6 +1192,8 @@ void IteratorFromStringHandleOp::Compute(OpKernelContext* ctx) {
resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
}
+namespace {
+
class SerializeIteratorOp : public OpKernel {
public:
explicit SerializeIteratorOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -1168,6 +1255,8 @@ REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE_GPU),
AnonymousIteratorHandleOp);
REGISTER_KERNEL_BUILDER(Name("DatasetToSingleElement").Device(DEVICE_CPU),
ToSingleElementOp);
+REGISTER_KERNEL_BUILDER(Name("ReduceDataset").Device(DEVICE_CPU),
+ ReduceDatasetOp);
REGISTER_KERNEL_BUILDER(Name("OneShotIterator").Device(DEVICE_CPU),
OneShotIteratorOp);
REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE_CPU),
@@ -1201,4 +1290,7 @@ REGISTER_KERNEL_BUILDER(Name("SerializeIterator").Device(DEVICE_CPU),
REGISTER_KERNEL_BUILDER(Name("DeserializeIterator").Device(DEVICE_CPU),
DeserializeIteratorOp);
+} // namespace
+
+} // namespace data
} // namespace tensorflow
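The new ReduceDatasetOp in this file is a fold: it seeds state from the initial_state inputs, then for every element calls the reduce function on the concatenation of the current state and that element, swapping the result back into state until the input is exhausted. A stand-alone sketch of that loop (plain C++ stand-ins, not the TensorFlow API):

#include <iostream>
#include <utility>
#include <vector>

int main() {
  std::vector<long long> dataset = {1, 2, 3, 4, 5};
  std::vector<long long> state = {0};                 // the initial_state input
  auto reduce_fn = [](std::vector<long long> args) {  // args = state ++ element
    return std::vector<long long>{args[0] + args[1]}; // new state
  };
  for (long long element : dataset) {
    std::vector<long long> args = state;
    args.push_back(element);                          // append next_input_element
    state = reduce_fn(std::move(args));               // swap in the new state
  }
  std::cout << "final state: " << state[0] << '\n';   // prints 15
  return 0;
}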
diff --git a/tensorflow/core/kernels/data/iterator_ops.h b/tensorflow/core/kernels/data/iterator_ops.h
index e426febcce..8a2b2639a7 100644
--- a/tensorflow/core/kernels/data/iterator_ops.h
+++ b/tensorflow/core/kernels/data/iterator_ops.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
namespace tensorflow {
+namespace data {
class IteratorResource;
@@ -116,6 +117,13 @@ class IteratorGetNextOp : public AsyncOpKernel {
BackgroundWorker background_worker_;
};
+class IteratorGetNextSyncOp : public OpKernel {
+ public:
+ explicit IteratorGetNextSyncOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override;
+};
+
class IteratorToStringHandleOp : public OpKernel {
public:
explicit IteratorToStringHandleOp(OpKernelConstruction* ctx)
@@ -135,6 +143,7 @@ class IteratorFromStringHandleOp : public OpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_ITERATOR_OPS_H_
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 51a7fd23a8..bf08970560 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
+#include <atomic>
#include <utility>
#include "tensorflow/core/common_runtime/function.h"
@@ -26,20 +27,22 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -49,14 +52,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int64 batch_size;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "batch_size", &batch_size));
OP_REQUIRES(
@@ -77,7 +72,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
case 2:
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx,
+ num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
break;
@@ -92,8 +88,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ParseScalarArgument(ctx, "drop_remainder", &drop_remainder));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, batch_size, num_parallel_calls,
drop_remainder, output_types_, output_shapes_, func_,
@@ -101,7 +97,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 batch_size,
int64 num_parallel_calls, bool drop_remainder,
@@ -110,7 +106,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const Eigen::ThreadPoolDevice* device)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
batch_size_(batch_size),
num_parallel_calls_(num_parallel_calls),
@@ -147,7 +143,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), map_fn_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, map_fn_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
Node* batch_size_node;
@@ -190,21 +186,36 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
+ : DatasetIterator<Dataset>(params),
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ params.dataset->num_parallel_calls_, mu_, cond_var_)) {}
~Iterator() override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
- cond_var_.notify_all();
+ cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ mutex_lock l(*mu_);
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ port::NumSchedulableCPUs());
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
+ }
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
@@ -212,25 +223,27 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
bool* end_of_sequence) override {
std::shared_ptr<BatchResult> result;
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (batch_results_.empty() ||
batch_results_.front()->num_calls > 0) {
- cond_var_.wait(l);
+ RecordStop(ctx);
+ cond_var_->wait(l);
+ RecordStart(ctx);
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
+ cond_var_->notify_all();
}
- cond_var_.notify_all();
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
@@ -246,7 +259,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("call_counter"), &call_counter_));
@@ -287,7 +300,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void Callback(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<BatchResult>& result,
const std::shared_ptr<std::vector<Tensor>>& return_values,
- int64 offset, const Status& status) LOCKS_EXCLUDED(mu_) {
+ int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) {
result->UpdateStatus(status);
if (status.ok()) {
EnsureOutputAllocated(ctx, result, return_values);
@@ -323,18 +336,16 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void CallCompleted(const std::shared_ptr<BatchResult>& result)
- LOCKS_EXCLUDED(mu_) {
- {
- mutex_lock l(mu_);
- num_calls_--;
- result->num_calls--;
- }
- cond_var_.notify_all();
+ LOCKS_EXCLUDED(*mu_) {
+ mutex_lock l(*mu_);
+ num_calls_--;
+ result->num_calls--;
+ cond_var_->notify_all();
}
void CallFunction(std::shared_ptr<IteratorContext> ctx,
const std::shared_ptr<BatchResult>& result,
- int64 offset) LOCKS_EXCLUDED(mu_) {
+ int64 offset) LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
bool end_of_input;
@@ -363,7 +374,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
ctx.get(), std::move(input_element), return_values.get(),
[this, ctx, result, return_values, offset](Status status) {
Callback(ctx, result, return_values, offset, status);
- });
+ },
+ prefix());
},
ctx, std::move(input_element)));
}
@@ -390,7 +402,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
@@ -420,11 +432,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
result->output_allocated = true;
}
- int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- return (dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) /
- dataset()->batch_size_;
- }
-
Status ProcessResult(IteratorContext* ctx,
const std::shared_ptr<BatchResult>& result,
std::vector<Tensor>* out_tensors,
@@ -471,28 +478,36 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
- new_calls.reserve(dataset()->num_parallel_calls_);
+ RecordStart(ctx.get());
+ auto stop_cleanup =
+ gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
+ new_calls.reserve(num_parallel_calls_->value);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_->value;
+ int64 max_batch_results =
+ (num_parallel_calls + dataset()->batch_size_ - 1) /
+ dataset()->batch_size_;
+ return num_calls_ >= num_parallel_calls ||
+ (batch_results_.size() > max_batch_results ||
+ (batch_results_.size() == max_batch_results &&
+ call_counter_ % dataset()->batch_size_ == 0));
+ };
while (true) {
{
- mutex_lock l(mu_);
- while (!cancelled_ &&
- (num_calls_ >= dataset()->num_parallel_calls_ ||
- batch_results_.size() > MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ == 0))) {
- cond_var_.wait(l);
+ mutex_lock l(*mu_);
+ while (!cancelled_ && busy()) {
+ RecordStop(ctx.get());
+ cond_var_->wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
- while (num_calls_ < dataset()->num_parallel_calls_ &&
- (batch_results_.size() < MaxBatchResults() ||
- (batch_results_.size() == MaxBatchResults() &&
- call_counter_ % dataset()->batch_size_ != 0))) {
+ while (!busy()) {
if (call_counter_ % dataset()->batch_size_ == 0) {
batch_results_.emplace_back(
new BatchResult(dataset()->batch_size_));
@@ -511,7 +526,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
- size_t index) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
batch_results_.emplace_back(new BatchResult(dataset()->batch_size_));
std::shared_ptr<BatchResult> result = batch_results_.back();
string prefix = strings::StrCat("batch_results_", index);
@@ -556,7 +571,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status ReadStatus(IteratorStateReader* reader, const string& prefix,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(prefix, "_code")), &code_int));
@@ -574,7 +589,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
std::shared_ptr<BatchResult> result = batch_results_[index];
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
@@ -615,7 +630,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status WriteStatus(IteratorStateWriter* writer, const string& prefix,
- const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const Status& status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
static_cast<int64>(status.code())));
@@ -629,22 +644,24 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// Used for coordination between the main thread, the runner thread, and
// the callback threads.
- mutex mu_;
+ const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread, the runner thread, and
// the callback threads. In particular, the runner thread should only
- // schedule new calls when the number of in-flight calls is less than the
- // user specified level of parallelism and there are slots available in
- // the `batch_results_` buffer.
- condition_variable cond_var_;
+ // schedule new calls when the number of in-flight calls is less than
+ // `num_parallel_calls_->value` and there are slots available in the
+ // `batch_results_` buffer.
+ const std::shared_ptr<condition_variable> cond_var_;
+ // Identifies the maximum number of parallel calls.
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Counts the number of outstanding calls for this batch.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
// Counts the total number of calls.
- int64 call_counter_ GUARDED_BY(mu_) = 0;
+ int64 call_counter_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the (intermediate) batch results.
- std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
- bool cancelled_ GUARDED_BY(mu_) = false;
+ std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ bool cancelled_ GUARDED_BY(*mu_) = false;
};
const DatasetBase* const input_;
@@ -659,7 +676,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const Eigen::ThreadPoolDevice* device_; // not owned
};
- const int graph_def_version_;
const int op_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
@@ -673,5 +689,5 @@ REGISTER_KERNEL_BUILDER(Name("MapAndBatchDatasetV2").Device(DEVICE_CPU),
MapAndBatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
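
The map_and_batch hunks above replace the iterator's own mutex and condition variable with shared_ptr-wrapped ones so the new model::SharedState for num_parallel_calls can share them with the autotuner, and they fold the removed MaxBatchResults() helper into a busy() predicate that re-reads the current parallelism on every wake-up. A minimal standalone sketch of that shared-tunable pattern, using only the C++ standard library (SharedState below is an illustrative stand-in, not the TensorFlow class):

#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>

// Stand-in for the shared tunable state: a value plus the mutex/condvar it
// shares with the iterator, so a tuner can change the value and wake
// whoever is blocked on a predicate that reads it.
struct SharedState {
  int64_t value;
  std::shared_ptr<std::mutex> mu;
  std::shared_ptr<std::condition_variable> cond_var;
};

int main() {
  auto mu = std::make_shared<std::mutex>();
  auto cond_var = std::make_shared<std::condition_variable>();
  SharedState parallelism{1, mu, cond_var};

  const int64_t batch_size = 4;
  int64_t num_calls = 1;   // one call already in flight (guarded by *mu)
  bool cancelled = false;  // guarded by *mu

  // Mirrors the busy() lambda: recompute the bound from the *current*
  // parallelism on every wake-up; (p + batch_size - 1) / batch_size is the
  // same ceiling division the removed MaxBatchResults() computed.
  auto busy = [&]() {
    int64_t p = parallelism.value;
    int64_t max_batch_results = (p + batch_size - 1) / batch_size;
    (void)max_batch_results;  // the real predicate also bounds the buffer
    return num_calls >= p;
  };

  std::thread runner([&] {
    std::unique_lock<std::mutex> l(*mu);
    while (!cancelled && busy()) cond_var->wait(l);
    std::cout << "runner resumed, parallelism=" << parallelism.value << "\n";
  });

  std::this_thread::sleep_for(std::chrono::milliseconds(20));
  {
    // The "tuner": raise the shared value and notify, the way the
    // performance model adjusts parallelism for the iterator.
    std::lock_guard<std::mutex> l(*mu);
    parallelism.value = 2;  // num_calls < parallelism -> no longer busy
  }
  cond_var->notify_all();
  runner.join();
}
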
diff --git a/tensorflow/core/kernels/data/map_dataset_op.cc b/tensorflow/core/kernels/data/map_dataset_op.cc
index ec9e12453b..f112e1dc43 100644
--- a/tensorflow/core/kernels/data/map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_dataset_op.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -28,41 +28,34 @@ namespace {
class MapDatasetOp : public UnaryDatasetOpKernel {
public:
- explicit MapDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ explicit MapDatasetOp(OpKernelConstruction* ctx) : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+ &use_inter_op_parallelism_));
}
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ use_inter_op_parallelism_,
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(captured_func),
output_types_, output_shapes_);
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
std::unique_ptr<CapturedFunction> captured_func,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
captured_func_(std::move(captured_func)),
@@ -92,7 +85,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
@@ -127,7 +120,9 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
@@ -181,14 +176,14 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList func_;
+ bool use_inter_op_parallelism_;
};
REGISTER_KERNEL_BUILDER(Name("MapDataset").Device(DEVICE_CPU), MapDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
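
Both MapDataset and MapAndBatchDataset now hand the gathering of other_arguments to CapturedFunction::Create(func_, ctx, "other_arguments", ...) and defer instantiation of the captured function to Iterator::Initialize. A toy sketch of that capture-now, instantiate-later shape; CapturedFn and Instantiate are illustrative names, not the real TensorFlow API:

#include <functional>
#include <iostream>
#include <utility>
#include <vector>

// Stores captured arguments when the kernel is constructed and binds the
// callable lazily, the way Initialize() now calls
// captured_func_->Instantiate(ctx) before the first GetNext.
class CapturedFn {
 public:
  explicit CapturedFn(std::vector<int> captured_args)
      : captured_args_(std::move(captured_args)) {}

  // Deferred setup; in the kernel this is where per-iterator resources
  // (such as the function library handle) get resolved.
  void Instantiate() {
    fn_ = [this](int x) {
      int sum = x;
      for (int a : captured_args_) sum += a;
      return sum;
    };
  }

  int Run(int x) const { return fn_(x); }

 private:
  std::vector<int> captured_args_;
  std::function<int(int)> fn_;
};

int main() {
  CapturedFn f({10, 20});         // captured at MakeDataset time
  f.Instantiate();                // done in Initialize(), not the constructor
  std::cout << f.Run(1) << "\n";  // prints 31
}

The point of the split is that the added captured_func_->Instantiate(ctx) calls run against the iterator's context rather than at kernel construction time.
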
diff --git a/tensorflow/core/kernels/data/map_defun_op.cc b/tensorflow/core/kernels/data/map_defun_op.cc
index d66716ef66..705b0393de 100644
--- a/tensorflow/core/kernels/data/map_defun_op.cc
+++ b/tensorflow/core/kernels/data/map_defun_op.cc
@@ -18,18 +18,20 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_util.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/util/batch_util.h"
#include "tensorflow/core/util/reffed_status_callback.h"
namespace tensorflow {
+namespace data {
namespace {
void SetRunOptions(OpKernelContext* ctx, FunctionLibraryRuntime::Options* opts,
bool always_collect_stats) {
opts->step_id = ctx->step_id();
opts->rendezvous = ctx->rendezvous();
- opts->cancellation_manager = ctx->cancellation_manager();
if (always_collect_stats) {
opts->stats_collector = ctx->stats_collector();
}
@@ -61,105 +63,186 @@ class MapDefunOp : public AsyncOpKernel {
~MapDefunOp() override {}
void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
- int64 batch_size = ctx->input(0).dim_size(0);
- // Inputs
- auto* args = new std::vector<Tensor>;
- auto* arg_shapes = new std::vector<TensorShape>;
- arg_shapes->reserve(ctx->num_inputs());
- args->reserve(ctx->num_inputs());
-
- for (size_t i = 0; i < ctx->num_inputs(); ++i) {
- args->push_back(ctx->input(i));
- arg_shapes->push_back(ctx->input(i).shape());
- arg_shapes->at(i).RemoveDim(0); // Remove the first batch dimension
- OP_REQUIRES_ASYNC(
- ctx, batch_size == ctx->input(i).dim_size(0),
- errors::InvalidArgument("All inputs must have the same dimension 0."),
- done);
- }
+ ComputeOptions* compute_opts = nullptr;
- // Outputs
- auto* output = new OpOutputList;
- OP_REQUIRES_OK_ASYNC(ctx, ctx->output_list("output", output), done);
+ OP_REQUIRES_OK_ASYNC(ctx, SetupArgs(ctx, &compute_opts), done);
- for (size_t i = 0; i < output_types().size(); ++i) {
- Tensor* out = nullptr;
- TensorShape output_shape = output_shapes_.at(i);
- output_shape.InsertDim(0, batch_size);
- OP_REQUIRES_OK_ASYNC(ctx, output->allocate(i, output_shape, &out), done);
- }
+ Status s = SetupOutputs(ctx, compute_opts);
+ if (!s.ok()) delete compute_opts;
+ OP_REQUIRES_OK_ASYNC(ctx, s, done);
- SetRunOptions(ctx, &opts_, false);
+ FunctionLibraryRuntime::Options opts;
+ SetRunOptions(ctx, &opts, false);
// Run loop
StatusCallback callback = std::bind(
- [](OpKernelContext* ctx, std::vector<Tensor>* args,
- std::vector<TensorShape>* arg_shapes, OpOutputList* output,
+ [](OpKernelContext* ctx, ComputeOptions* compute_opts,
DoneCallback& done, const Status& status) {
- delete args;
- delete arg_shapes;
- delete output;
+ delete compute_opts;
ctx->SetStatus(status);
done();
},
- ctx, args, arg_shapes, output, std::move(done), std::placeholders::_1);
+ ctx, compute_opts, std::move(done), std::placeholders::_1);
auto* refcounted = new ReffedStatusCallback(std::move(callback));
- for (size_t i = 1; i < static_cast<size_t>(batch_size); ++i) {
- // Start from i = 1 because refcounted is initialized with refcount = 1
+ CancellationManager* parent_mgr = ctx->cancellation_manager();
+
+ for (size_t i = 0; i < static_cast<size_t>(compute_opts->batch_size); ++i) {
+ // We use a different cancellation manager each time the function is run
+ // to avoid the race condition between a function run error and other
+ // functions being cancelled as a result.
+ CancellationManager* c_mgr = new CancellationManager;
+ CancellationToken token = parent_mgr->get_cancellation_token();
+ const bool success = parent_mgr->RegisterCallback(
+ token, [c_mgr]() { c_mgr->StartCancel(); });
+
+ opts.cancellation_manager = c_mgr;
+ if (!success) {
+ delete c_mgr;
+ refcounted->UpdateStatus(errors::Cancelled(
+ "MapDefunOp functions cancelled because parent graph cancelled"));
+ break;
+ }
+
+ auto* call_frame = new MapFunctionCallFrame(compute_opts, this, i);
+
refcounted->Ref();
+ ctx->function_library()->Run(opts, func_handle_, call_frame,
+ [call_frame, refcounted, c_mgr, parent_mgr,
+ token](const Status& func_status) {
+ parent_mgr->DeregisterCallback(token);
+ delete c_mgr;
+ delete call_frame;
+ refcounted->UpdateStatus(func_status);
+ refcounted->Unref();
+ });
}
- for (size_t i = 0; i < static_cast<size_t>(batch_size); ++i) {
- auto* call_frame =
- new MapFunctionCallFrame(*args, *arg_shapes, output, this, i);
- ctx->function_library()->Run(
- opts_, func_handle_, call_frame,
- [call_frame, refcounted](const Status& func_status) {
- delete call_frame;
- refcounted->UpdateStatus(func_status);
- refcounted->Unref();
- });
- }
+
+ // Unref 1 because refcounted is initialized with refcount = 1
+ refcounted->Unref();
}
private:
FunctionLibraryRuntime::Handle func_handle_;
- FunctionLibraryRuntime::Options opts_;
- std::vector<TensorShape> output_shapes_;
+ std::vector<PartialTensorShape> output_shapes_;
+
+ struct ComputeOptions {
+ // These vary per MapDefunOp::ComputeAsync call, but must persist until
+ // all calls to the function are complete. This struct also encapsulates
+ // all the components that need to be passed to each MapFunctionCallFrame.
+
+ OpInputList args;
+ const std::vector<TensorShape> arg_shapes;
+ OpInputList captured_inputs;
+ const int64 batch_size;
+
+ // Output of a compute call
+ std::vector<PartialTensorShape> output_shapes GUARDED_BY(mu);
+ OpOutputList output GUARDED_BY(mu);
+ mutex mu;
+
+ // Create a copy of output_shapes because every `Compute` may expect a
+ // different output shape.
+ ComputeOptions(OpInputList args, OpInputList captured_inputs,
+ std::vector<TensorShape> arg_shapes, int64 batch_size,
+ const std::vector<PartialTensorShape>& output_shapes_attr)
+ : args(args),
+ arg_shapes(std::move(arg_shapes)),
+ captured_inputs(captured_inputs),
+ batch_size(batch_size),
+ output_shapes(output_shapes_attr) {}
+ };
+
+ // Get inputs to Compute and check that they are valid.
+ Status SetupArgs(OpKernelContext* ctx, ComputeOptions** compute_opts) {
+ OpInputList arguments;
+ TF_RETURN_IF_ERROR(ctx->input_list("arguments", &arguments));
+ OpInputList captured_inputs;
+ TF_RETURN_IF_ERROR(ctx->input_list("captured_inputs", &captured_inputs));
+
+ int64 batch_size = arguments[0].dims() > 0 ? arguments[0].dim_size(0) : -1;
+
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ if (arguments[i].dims() == 0) {
+ return errors::InvalidArgument(
+ "All inputs must have rank at least 1. Input ", i,
+ " has a rank of 0.");
+ } else if (arguments[i].dim_size(0) != batch_size) {
+ return errors::InvalidArgument(
+ "All inputs must have the same dimension 0. Input ", i,
+ " has leading dimension ", ctx->input(i).dim_size(0),
+ ", while all previous inputs have leading dimension ", batch_size);
+ }
+ }
+
+ std::vector<TensorShape> arg_shapes;
+ arg_shapes.reserve(arguments.size());
+
+ for (size_t i = 0; i < arguments.size(); ++i) {
+ arg_shapes.push_back(arguments[i].shape());
+ arg_shapes.at(i).RemoveDim(0);
+ }
+
+ *compute_opts =
+ new ComputeOptions(arguments, captured_inputs, std::move(arg_shapes),
+ batch_size, output_shapes_);
+ return Status::OK();
+ }
+
+ Status SetupOutputs(OpKernelContext* ctx, ComputeOptions* opts) {
+ mutex_lock l(opts->mu);
+ TF_RETURN_IF_ERROR(ctx->output_list("output", &opts->output));
+
+ for (size_t i = 0; i < output_types().size(); ++i) {
+ if (output_shapes_.at(i).IsFullyDefined()) {
+ Tensor* out = nullptr;
+ TensorShape output_shape;
+ output_shapes_.at(i).AsTensorShape(&output_shape);
+ output_shape.InsertDim(0, opts->batch_size);
+ TF_RETURN_IF_ERROR(opts->output.allocate(i, output_shape, &out));
+ }
+ }
+ return Status::OK();
+ }
class MapFunctionCallFrame : public CallFrameInterface {
public:
- MapFunctionCallFrame(const std::vector<Tensor>& args,
- const std::vector<TensorShape>& arg_shapes,
- OpOutputList* output, OpKernel* kernel, size_t iter)
- : args_(args),
- arg_shapes_(arg_shapes),
- output_(output),
- kernel_(kernel),
- iter_(iter) {}
+ MapFunctionCallFrame(ComputeOptions* compute_opts, OpKernel* kernel,
+ size_t iter)
+ : compute_opts_(compute_opts), kernel_(kernel), iter_(iter) {}
~MapFunctionCallFrame() override {}
- size_t num_args() const override { return args_.size(); }
+ size_t num_args() const override { return compute_opts_->args.size(); }
+
size_t num_retvals() const override {
return static_cast<size_t>(kernel_->num_outputs());
}
Status GetArg(int index, Tensor* val) const override {
- if (index < 0 || index >= args_.size()) {
+ if (index < 0 || index >= compute_opts_->args.size() +
+ compute_opts_->captured_inputs.size()) {
return errors::InvalidArgument(
"Mismatch in number of function inputs.");
}
- bool result = val->CopyFrom(args_.at(index).Slice(iter_, iter_ + 1),
- arg_shapes_.at(index));
+
+ if (index >= compute_opts_->args.size()) {
+ // The function is calling for a captured input
+ *val =
+ compute_opts_->captured_inputs[index - compute_opts_->args.size()];
+ return Status::OK();
+ }
+
+ bool result =
+ val->CopyFrom(compute_opts_->args[index].Slice(iter_, iter_ + 1),
+ compute_opts_->arg_shapes.at(index));
if (!result) {
return errors::Internal("GetArg failed.");
} else if (!val->IsAligned()) {
// Ensure alignment
*val = tensor::DeepCopy(*val);
}
-
return Status::OK();
}
@@ -175,18 +258,39 @@ class MapDefunOp : public AsyncOpKernel {
"output: ",
index);
}
- return batch_util::CopyElementToSlice(val, (*output_)[index], iter_);
+ { // Locking scope
+ mutex_lock l(compute_opts_->mu);
+ if (!compute_opts_->output_shapes.at(index).IsCompatibleWith(
+ val.shape())) {
+ return errors::InvalidArgument(
+ "Mismatch in function retval shape, ", val.shape(),
+ ", and expected output shape, ",
+ compute_opts_->output_shapes.at(index).DebugString(), ".");
+ }
+ if (!compute_opts_->output_shapes.at(index).IsFullyDefined()) {
+ // Given val, we have new information about the output shape at
+ // this index. Store the shape and allocate the output accordingly.
+ compute_opts_->output_shapes.at(index) = val.shape();
+
+ Tensor* out = nullptr;
+ TensorShape actual_shape = val.shape();
+ actual_shape.InsertDim(0, compute_opts_->batch_size);
+ TF_RETURN_IF_ERROR(
+ compute_opts_->output.allocate(index, actual_shape, &out));
+ }
+ return batch_util::CopyElementToSlice(
+ val, (compute_opts_->output)[index], iter_);
+ }
}
private:
- const std::vector<Tensor>& args_;
- const std::vector<TensorShape>& arg_shapes_;
- OpOutputList* output_;
+ ComputeOptions* const compute_opts_; // Not owned
const OpKernel* kernel_;
const size_t iter_;
};
-}; // namespace
+};
REGISTER_KERNEL_BUILDER(Name("MapDefun").Device(DEVICE_CPU), MapDefunOp);
} // namespace
+} // namespace data
} // namespace tensorflow
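
The rewritten MapDefunOp::ComputeAsync now launches one function call per batch element starting at i = 0, takes an extra reference before each launch, and drops the initial reference after the launch loop ("Unref 1 because refcounted is initialized with refcount = 1"). A self-contained sketch of that reference-counted completion protocol; ReffedCallback stands in for ReffedStatusCallback, and the per-call CancellationManager wiring is omitted:

#include <atomic>
#include <functional>
#include <iostream>
#include <thread>
#include <utility>
#include <vector>

// Starts with one reference and invokes the wrapped callback exactly once,
// when the last reference is dropped.
class ReffedCallback {
 public:
  explicit ReffedCallback(std::function<void()> done) : done_(std::move(done)) {}
  void Ref() { refs_.fetch_add(1, std::memory_order_relaxed); }
  void Unref() {
    if (refs_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
      done_();
      delete this;
    }
  }

 private:
  ~ReffedCallback() = default;
  std::atomic<int> refs_{1};
  std::function<void()> done_;
};

int main() {
  const int batch_size = 4;
  auto* refcounted = new ReffedCallback(
      [] { std::cout << "all per-element calls finished\n"; });

  std::vector<std::thread> workers;
  for (int i = 0; i < batch_size; ++i) {
    refcounted->Ref();  // one Ref per launched call, now starting at i = 0
    workers.emplace_back([refcounted] {
      refcounted->Unref();  // each completed call drops its reference
    });
  }
  // The launcher drops the initial reference once every call has been
  // issued, instead of skipping the Ref for the first element.
  refcounted->Unref();

  for (auto& t : workers) t.join();
}

Whichever thread performs the final Unref fires the completion callback exactly once, which is what lets ComputeAsync return before the individual function runs have finished.
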
diff --git a/tensorflow/core/kernels/data/model_dataset_op.cc b/tensorflow/core/kernels/data/model_dataset_op.cc
new file mode 100644
index 0000000000..9aa505f4f1
--- /dev/null
+++ b/tensorflow/core/kernels/data/model_dataset_op.cc
@@ -0,0 +1,183 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/data/dataset.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/platform/cpu_info.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+const int kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMicros;
+
+class ModelDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit ModelDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ *output = new Dataset(ctx, input);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ explicit Dataset(OpKernelContext* ctx, const DatasetBase* input)
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Model")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return input_->output_dtypes();
+ }
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return input_->output_shapes();
+ }
+
+ string DebugString() const override { return "ModelDatasetOp::Dataset"; }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_graph_node}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ model_(std::make_shared<model::Model>()) {}
+
+ ~Iterator() override {
+ // Signal the optimize thread to terminate it. We will then join that
+ // thread when we delete `this->optimize_thread_`.
+ mutex_lock l(mu_);
+ cancelled_ = true;
+ cond_var_.notify_all();
+ }
+
+ Status Initialize(IteratorContext* ctx) override {
+ IteratorContext ctx_with_model(CreateParams(ctx));
+ return dataset()->input_->MakeIterator(&ctx_with_model, prefix(),
+ &input_impl_);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(EnsureOptimizeThreadStarted(ctx));
+ IteratorContext ctx_with_model(CreateParams(ctx));
+ return input_impl_->GetNext(&ctx_with_model, out_tensors,
+ end_of_sequence);
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
+ IteratorContext::Params CreateParams(IteratorContext* ctx) {
+ IteratorContext::Params params = ctx->params();
+ params.model = model_;
+ return params;
+ }
+
+ private:
+ Status EnsureOptimizeThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!optimize_thread_) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ optimize_thread_.reset(ctx->env()->StartThread(
+ {}, "optimize_thread",
+ [this, new_ctx]() { OptimizeThread(new_ctx); }));
+ }
+ return Status::OK();
+ }
+
+ void OptimizeThread(const std::shared_ptr<IteratorContext>& ctx) {
+ int64 last_optimization_ms = 0;
+ int64 optimization_period_ms = 10;
+ while (true) {
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ last_optimization_ms + optimization_period_ms >=
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros) {
+ cond_var_.wait_for(
+ l, std::chrono::milliseconds(
+ last_optimization_ms + optimization_period_ms -
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros));
+ }
+ if (cancelled_) return;
+ }
+ model_->Optimize(port::NumSchedulableCPUs());
+ // Exponentially increase the period of running the optimization
+ // until a threshold is reached.
+ if (optimization_period_ms < kOptimizationPeriodThresholdMs) {
+ if (optimization_period_ms << 1 < kOptimizationPeriodThresholdMs) {
+ optimization_period_ms <<= 1;
+ } else {
+ optimization_period_ms = kOptimizationPeriodThresholdMs;
+ }
+ }
+ last_optimization_ms =
+ ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
+ }
+ }
+
+ mutex mu_;
+ condition_variable cond_var_;
+ std::shared_ptr<model::Model> model_;
+ std::unique_ptr<Thread> optimize_thread_ GUARDED_BY(mu_);
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ };
+
+ const DatasetBase* input_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("ModelDataset").Device(DEVICE_CPU),
+ ModelDatasetOp);
+} // namespace
+} // namespace data
+} // namespace tensorflow
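
The new ModelDatasetOp runs a background OptimizeThread that sleeps on a condition variable and doubles its optimization period after each pass until it saturates at a threshold. A standalone sketch of just that scheduling arithmetic, using a one-minute cap expressed in milliseconds (note that kOptimizationPeriodThresholdMs above is built from EnvTime::kSecondsToMicros yet compared against a millisecond counter, so its literal value is 60,000,000):

#include <cstdint>
#include <iostream>

int main() {
  const int64_t kThresholdMs = 60 * 1000;  // illustrative one-minute cap
  int64_t period_ms = 10;
  for (int pass = 0; pass < 16; ++pass) {
    std::cout << "pass " << pass << ": next optimization in " << period_ms
              << " ms\n";
    // Same update as OptimizeThread(): double until the threshold is hit.
    if (period_ms < kThresholdMs) {
      period_ms =
          (period_ms << 1 < kThresholdMs) ? period_ms << 1 : kThresholdMs;
    }
  }
  // Periods: 10, 20, 40, ..., 40960, 60000, 60000, ...
}
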
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
new file mode 100644
index 0000000000..d909b9e9d3
--- /dev/null
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -0,0 +1,637 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <deque>
+
+#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_op_kernel.h"
+#include "tensorflow/core/kernels/data/dataset_utils.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/util/device_name_utils.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+struct HostBufferElement {
+ Status status;
+ bool end_of_sequence;
+ std::vector<Tensor> value;
+};
+
+using MultiDeviceIteratorCallback =
+ std::function<void(const HostBufferElement&)>;
+
+class MultiDeviceIterator : public ResourceBase {
+ public:
+ MultiDeviceIterator(const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes,
+ const std::vector<string>& devices,
+ std::unique_ptr<FunctionLibraryDefinition> flib_def,
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr,
+ FunctionLibraryRuntime* lib)
+ : output_types_(output_types),
+ output_shapes_(output_shapes),
+ devices_(devices),
+ flib_def_(std::move(flib_def)),
+ pflr_(std::move(pflr)),
+ lib_(lib) {
+ DCHECK(lib_ != nullptr);
+ }
+
+ string DebugString() override {
+ return strings::StrCat("MultiDeviceIterator for ", devices_.size(),
+ " devices");
+ }
+
+ Status Init(std::unique_ptr<IteratorBase> iterator, int64 max_buffer_size,
+ int64* incarnation_id) {
+ if (iterator) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_types_, iterator->output_dtypes()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, iterator->output_shapes()));
+ }
+
+ mutex_lock l(mu_);
+ if (multi_device_buffer_) {
+ multi_device_buffer_->Reset();
+ }
+
+ ++incarnation_id_;
+ *incarnation_id = incarnation_id_;
+
+ multi_device_buffer_.reset(
+ new MultiDeviceBuffer(devices_.size(), max_buffer_size, incarnation_id_,
+ std::move(iterator)));
+ return Status::OK();
+ }
+
+ void GetNextFromShard(IteratorContext* ctx, int shard_num,
+ int64 incarnation_id,
+ MultiDeviceIteratorCallback callback) {
+ if (lib_ != nullptr) {
+ ctx->set_lib(lib_);
+ }
+ tf_shared_lock l(mu_);
+ multi_device_buffer_->GetNextFromShard(ctx, shard_num, incarnation_id,
+ std::move(callback));
+ }
+
+ const DataTypeVector& output_types() const { return output_types_; }
+
+ const std::vector<PartialTensorShape>& output_shapes() const {
+ return output_shapes_;
+ }
+
+ std::shared_ptr<const FunctionLibraryDefinition> function_library() {
+ tf_shared_lock l(mu_);
+ return lib_def_;
+ }
+
+ FunctionLibraryRuntime* const lib() {
+ tf_shared_lock l(mu_);
+ return lib_;
+ }
+
+ private:
+ // A private class that uses a background thread to keep a per device buffer
+ // full.
+ class MultiDeviceBuffer {
+ public:
+ MultiDeviceBuffer(size_t size, int64 max_buffer_size, int64 incarnation_id,
+ std::unique_ptr<IteratorBase> host_iterator)
+ : buffer_(size),
+ size_(size),
+ max_buffer_size_(max_buffer_size),
+ incarnation_id_(incarnation_id),
+ host_iterator_(std::move(host_iterator)) {}
+
+ ~MultiDeviceBuffer() {
+ {
+ mutex_lock l(mu_);
+ if (!background_thread_started_) return;
+ }
+ Reset();
+ }
+
+ void Reset() LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (!background_thread_finished_) {
+ cancelled_ = true;
+ // Wake up the background thread.
+ for (int i = 0; i < size_; ++i) {
+ buffer_[i].cond_var.notify_all();
+ }
+
+ // Make sure background thread has finished first.
+ while (!background_thread_finished_) {
+ shutdown_cond_var_.wait(l);
+ }
+ }
+ }
+ RunPendingCallbacks();
+ }
+
+ void GetNextFromShard(IteratorContext* ctx, int shard_num,
+ int64 incarnation_id,
+ MultiDeviceIteratorCallback callback) {
+ HostBufferElement elem;
+ if (incarnation_id_ != incarnation_id) {
+ elem.status = errors::InvalidArgument("Invalid incarnation id");
+ callback(elem);
+ return;
+ }
+
+ bool produced_output = false;
+ {
+ mutex_lock l(mu_);
+ if (cancelled_) {
+ elem.status = errors::Cancelled("Cancelled Multidevice iterator");
+ callback(elem);
+ return;
+ }
+
+ EnsureBackgroundThreadStarted(ctx);
+
+ if (!buffer_[shard_num].data.empty()) {
+ produced_output = true;
+ std::swap(elem, buffer_[shard_num].data.front());
+ buffer_[shard_num].data.pop_front();
+ // Wake up background thread if it is blocked on this element.
+ if (buffer_[shard_num].data.size() == max_buffer_size_ - 1) {
+ buffer_[shard_num].cond_var.notify_all();
+ }
+ } else {
+ if (end_of_iterator_) {
+ produced_output = true;
+ elem.end_of_sequence = true;
+ } else {
+ buffer_[shard_num].callbacks.push_back(std::move(callback));
+ callback = nullptr;
+ }
+ }
+ }
+
+ if (produced_output) {
+ callback(elem);
+ }
+ }
+
+ private:
+ void EnsureBackgroundThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (!background_thread_) {
+ background_thread_.reset(ctx->env()->StartThread(
+ {}, "multi_device_iterator_background_thread",
+ std::bind(&MultiDeviceIterator::MultiDeviceBuffer::BackgroundThread,
+ this, new IteratorContext(*ctx))));
+ }
+ }
+
+ void RunPendingCallbacks() LOCKS_EXCLUDED(mu_) {
+ // Run all remaining callbacks.
+ std::vector<MultiDeviceIteratorCallback> cancellation_callbacks;
+ std::vector<HostBufferElement> cancellation_elements;
+ {
+ mutex_lock l(mu_);
+
+ for (int i = 0; i < size_; ++i) {
+ while (!buffer_[i].callbacks.empty()) {
+ if (buffer_[i].data.empty()) {
+ HostBufferElement elem;
+ if (end_of_iterator_) {
+ elem.end_of_sequence = true;
+ } else {
+ elem.status =
+ errors::Cancelled("Cancelled and buffer not filled.");
+ }
+ cancellation_elements.push_back(std::move(elem));
+ } else {
+ cancellation_elements.push_back(
+ std::move(buffer_[i].data.front()));
+ buffer_[i].data.pop_front();
+ }
+ cancellation_callbacks.push_back(
+ std::move(buffer_[i].callbacks.front()));
+ buffer_[i].callbacks.pop_front();
+ }
+ }
+ }
+ for (int i = 0; i < cancellation_callbacks.size(); ++i) {
+ cancellation_callbacks[i](cancellation_elements[i]);
+ }
+ }
+
+ void BackgroundThread(IteratorContext* ctx) {
+ {
+ mutex_lock l(mu_);
+ background_thread_started_ = true;
+ }
+ std::unique_ptr<IteratorContext> cleanup(ctx);
+ int shard_to_fetch = 0;
+ while (true) {
+ HostBufferElement elem;
+ MultiDeviceIteratorCallback callback = nullptr;
+ bool end_of_iterator = false;
+
+ {
+ mutex_lock l(mu_);
+ while (!cancelled_ &&
+ buffer_[shard_to_fetch].data.size() >= max_buffer_size_) {
+ buffer_[shard_to_fetch].cond_var.wait(l);
+ }
+
+ if (cancelled_) {
+ background_thread_finished_ = true;
+ shutdown_cond_var_.notify_all();
+ return;
+ }
+ }
+
+ elem.status =
+ host_iterator_->GetNext(ctx, &elem.value, &elem.end_of_sequence);
+
+ if (elem.status.ok() && elem.end_of_sequence) {
+ end_of_iterator = true;
+ }
+
+ {
+ mutex_lock l(mu_);
+ // Try to find a callback, else just push stuff into buffer.
+ if (!buffer_[shard_to_fetch].callbacks.empty()) {
+ callback = buffer_[shard_to_fetch].callbacks.front();
+ buffer_[shard_to_fetch].callbacks.pop_front();
+ } else {
+ buffer_[shard_to_fetch].data.push_back(std::move(elem));
+ elem = HostBufferElement();
+ }
+ }
+
+ if (callback) {
+ (*ctx->runner())(std::bind(std::move(callback), std::move(elem)));
+ }
+
+ // Finish off the thread if we reach the end of the iterator. Runs
+ // pending callbacks.
+ if (end_of_iterator) {
+ {
+ mutex_lock l(mu_);
+ background_thread_finished_ = true;
+ end_of_iterator_ = true;
+ shutdown_cond_var_.notify_all();
+ }
+ RunPendingCallbacks();
+ return;
+ }
+ shard_to_fetch = (shard_to_fetch + 1) % size_;
+ }
+ }
+
+ struct HostBuffer {
+ condition_variable cond_var;
+ std::deque<HostBufferElement> data;
+ std::deque<MultiDeviceIteratorCallback> callbacks;
+ };
+
+ mutex mu_;
+ std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
+ bool background_thread_finished_ GUARDED_BY(mu_) = false;
+ bool background_thread_started_ GUARDED_BY(mu_) = false;
+ bool end_of_iterator_ GUARDED_BY(mu_) = false;
+ bool cancelled_ GUARDED_BY(mu_) = false;
+ condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
+
+ std::vector<HostBuffer> buffer_;
+
+ const size_t size_;
+ const int64 max_buffer_size_;
+ const int64 incarnation_id_;
+ const std::unique_ptr<IteratorBase> host_iterator_;
+ };
+
+ mutex mu_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ const std::vector<string> devices_;
+ const std::unique_ptr<FunctionLibraryDefinition> flib_def_;
+ const std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
+ FunctionLibraryRuntime* const lib_ = nullptr; // not owned.
+ std::shared_ptr<const FunctionLibraryDefinition> lib_def_ GUARDED_BY(mu_);
+
+ int64 incarnation_id_ GUARDED_BY(mu_) = 0;
+ std::unique_ptr<MultiDeviceBuffer> multi_device_buffer_ GUARDED_BY(mu_);
+};
+
+// Just creates a MultiDeviceIterator and returns it.
+class MultiDeviceIteratorHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("devices", &devices_));
+ }
+
+ // The resource is deleted from the resource manager only when it is
+ // private to the kernel.
+ ~MultiDeviceIteratorHandleOp() override {
+ if (resource_ != nullptr) {
+ resource_->Unref();
+ if (cinfo_.resource_is_private_to_kernel()) {
+ if (!cinfo_.resource_manager()
+ ->template Delete<MultiDeviceIterator>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource can have been deleted by session resets.
+ }
+ }
+ }
+ }
+
+ void Compute(OpKernelContext* context) override LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ if (resource_ == nullptr) {
+ FunctionLibraryRuntime* lib;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def(nullptr);
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(nullptr);
+ OP_REQUIRES_OK(context, context->function_library()->Clone(
+ &flib_def, &pflr, &lib));
+ ResourceMgr* mgr = context->resource_manager();
+ OP_REQUIRES_OK(context, cinfo_.Init(mgr, def()));
+
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(
+ context,
+ mgr->LookupOrCreate<MultiDeviceIterator>(
+ cinfo_.container(), cinfo_.name(), &resource,
+ [this, lib, &flib_def, &pflr](MultiDeviceIterator** ret)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ *ret = new MultiDeviceIterator(
+ output_types_, output_shapes_, devices_,
+ std::move(flib_def), std::move(pflr), lib);
+ return Status::OK();
+ }));
+
+ Status s = VerifyResource(resource);
+ if (TF_PREDICT_FALSE(!s.ok())) {
+ resource->Unref();
+ context->SetStatus(s);
+ return;
+ }
+
+ resource_ = resource;
+ }
+ }
+ OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
+ context, 0, cinfo_.container(), cinfo_.name(),
+ MakeTypeIndex<MultiDeviceIterator>()));
+ }
+
+ private:
+ // During the first Compute(), resource is either created or looked up using
+ // shared_name. In the latter case, the resource found should be verified if
+ // it is compatible with this op's configuration. The verification may fail in
+ // cases such as two graphs asking queues of the same shared name to have
+ // inconsistent capacities.
+ Status VerifyResource(MultiDeviceIterator* resource) {
+ TF_RETURN_IF_ERROR(
+ VerifyTypesMatch(output_types_, resource->output_types()));
+ TF_RETURN_IF_ERROR(
+ VerifyShapesCompatible(output_shapes_, resource->output_shapes()));
+ return Status::OK();
+ }
+
+ mutex mu_;
+ ContainerInfo cinfo_; // Written once under mu_ then constant afterwards.
+ MultiDeviceIterator* resource_ GUARDED_BY(mu_) = nullptr;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ const int graph_def_version_;
+ string name_;
+ string container_;
+ std::vector<string> devices_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("MultiDeviceIterator").Device(DEVICE_CPU),
+ MultiDeviceIteratorHandleOp);
+
+// Calls init on the MultiDeviceIterator.
+class MultiDeviceIteratorInitOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorInitOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* tensor_max_buffer_size;
+ OP_REQUIRES_OK(ctx, ctx->input("max_buffer_size", &tensor_max_buffer_size));
+ int64 max_buffer_size = tensor_max_buffer_size->scalar<int64>()();
+
+ DatasetBase* dataset;
+ OP_REQUIRES_OK(ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset));
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 1), &resource));
+ core::ScopedUnref unref(resource);
+
+ std::unique_ptr<IteratorBase> iterator;
+ IteratorContext iter_ctx(ctx);
+ iter_ctx.set_lib(resource->lib());
+ OP_REQUIRES_OK(
+ ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
+ int64 incarnation_id;
+ OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
+ &incarnation_id));
+ Tensor tensor_incarnation_id(DT_INT64, TensorShape({}));
+ tensor_incarnation_id.scalar<int64>()() = incarnation_id;
+ OP_REQUIRES_OK(ctx,
+ ctx->set_output("incarnation_id", tensor_incarnation_id));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MultiDeviceIteratorInit").Device(DEVICE_CPU),
+ MultiDeviceIteratorInitOp);
+
+// Calls GetNextFromShard(shard) and returns a vector of Tensors as output.
+// TODO(rohanj): Implement using BackgroundWorker that Derek built?
+class MultiDeviceIteratorGetNextFromShardOp : public AsyncOpKernel {
+ public:
+ explicit MultiDeviceIteratorGetNextFromShardOp(OpKernelConstruction* ctx)
+ : AsyncOpKernel(ctx),
+ thread_pool_(new thread::ThreadPool(
+ ctx->env(), ThreadOptions(),
+ strings::StrCat("multi_device_iterator_get_next_thread_",
+ SanitizeThreadSuffix(name())),
+ 1 /* num_threads */, false /* low_latency_hint */)) {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ const Tensor* tensor_shard_num;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input("shard_num", &tensor_shard_num), done);
+ int32 shard_num = tensor_shard_num->scalar<int32>()();
+
+ const Tensor* tensor_incarnation_id;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ctx->input("incarnation_id", &tensor_incarnation_id), done);
+ int64 incarnation_id = tensor_incarnation_id->scalar<int64>()();
+
+ MultiDeviceIterator* iterator;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &iterator), done);
+ thread_pool_->Schedule(std::bind(
+ [ctx, iterator, shard_num, incarnation_id](DoneCallback done) {
+ IteratorContext::Params params;
+ params.env = ctx->env();
+ params.runner = *(ctx->runner());
+ params.function_library = iterator->function_library();
+ DeviceBase* device = ctx->function_library()->device();
+ params.allocator_getter = [device](AllocatorAttributes attrs) {
+ return device->GetAllocator(attrs);
+ };
+ IteratorContext iter_ctx(std::move(params));
+
+ MultiDeviceIteratorCallback callback = std::bind(
+ [ctx](const HostBufferElement& elem, DoneCallback done) {
+ // iterator->Unref();
+ Status s = elem.status;
+ if (!s.ok()) {
+ ctx->SetStatus(s);
+ } else if (elem.end_of_sequence) {
+ ctx->SetStatus(errors::OutOfRange("End of sequence"));
+ } else {
+ for (int i = 0; i < elem.value.size(); ++i) {
+ ctx->set_output(i, elem.value[i]);
+ }
+ }
+ done();
+ },
+ std::placeholders::_1, std::move(done));
+
+ iterator->GetNextFromShard(&iter_ctx, shard_num, incarnation_id,
+ callback);
+ iterator->Unref();
+ },
+ std::move(done)));
+ }
+
+ private:
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorGetNextFromShard").Device(DEVICE_CPU),
+ MultiDeviceIteratorGetNextFromShardOp);
+
+class MultiDeviceIteratorToStringHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorToStringHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& resource_handle_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(resource_handle_t.shape()),
+ errors::InvalidArgument("resource_handle must be a scalar"));
+
+ // Validate that the handle corresponds to a real resource, and
+ // that it is a MultiDeviceIterator.
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx,
+ LookupResource(ctx, HandleFromInput(ctx, 0), &resource));
+ resource->Unref();
+
+ Tensor* string_handle_t;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &string_handle_t));
+ string_handle_t->scalar<string>()() =
+ resource_handle_t.scalar<ResourceHandle>()().SerializeAsString();
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorToStringHandle").Device(DEVICE_CPU),
+ MultiDeviceIteratorToStringHandleOp);
+
+class MultiDeviceIteratorFromStringHandleOp : public OpKernel {
+ public:
+ explicit MultiDeviceIteratorFromStringHandleOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES(
+ ctx,
+ output_types_.empty() || output_shapes_.empty() ||
+ output_types_.size() == output_shapes_.size(),
+ errors::InvalidArgument("If both 'output_types' and 'output_shapes' "
+ "are set, they must have the same length."));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& string_handle_t = ctx->input(0);
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(string_handle_t.shape()),
+ errors::InvalidArgument("string_handle must be a scalar"));
+
+ ResourceHandle resource_handle;
+ OP_REQUIRES(
+ ctx,
+ resource_handle.ParseFromString(string_handle_t.scalar<string>()()),
+ errors::InvalidArgument(
+ "Could not parse string_handle as a valid ResourceHandle"));
+
+ OP_REQUIRES(
+ ctx, resource_handle.device() == ctx->device()->attributes().name(),
+ errors::InvalidArgument("Attempted create an iterator on device \"",
+ ctx->device()->attributes().name(),
+ "\" from handle defined on device \"",
+ resource_handle.device(), "\""));
+
+ // Validate that the handle corresponds to a real resource, and
+ // that it is a MultiDeviceIterator.
+ MultiDeviceIterator* resource;
+ OP_REQUIRES_OK(ctx, LookupResource(ctx, resource_handle, &resource));
+ core::ScopedUnref unref_iterator(resource);
+ if (!output_types_.empty()) {
+ OP_REQUIRES_OK(ctx,
+ VerifyTypesMatch(output_types_, resource->output_types()));
+ }
+ if (!output_shapes_.empty()) {
+ OP_REQUIRES_OK(ctx, VerifyShapesCompatible(output_shapes_,
+ resource->output_shapes()));
+ }
+
+ Tensor* resource_handle_t;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_output(0, TensorShape({}), &resource_handle_t));
+ resource_handle_t->scalar<ResourceHandle>()() = resource_handle;
+ }
+
+ private:
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("MultiDeviceIteratorFromStringHandle").Device(DEVICE_CPU),
+ MultiDeviceIteratorFromStringHandleOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
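
The core of the new multi_device_iterator_ops.cc is MultiDeviceBuffer: one background thread fills a bounded per-shard buffer in round-robin order, and GetNextFromShard either returns a buffered element immediately or parks a callback for the background thread to fulfil. A simplified standalone sketch of that shape; ShardedBuffer and its integer "elements" are illustrative stand-ins, and incarnation ids, cancellation, and end-of-sequence handling are omitted:

#include <condition_variable>
#include <cstddef>
#include <deque>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

class ShardedBuffer {
 public:
  using Callback = std::function<void(int)>;

  ShardedBuffer(int num_shards, std::size_t max_buffer_size, int total_elements)
      : shards_(num_shards), max_buffer_size_(max_buffer_size) {
    // Background producer: round-robins over shards, blocking while a
    // shard's buffer is full, and prefers a waiting callback over buffering.
    producer_ = std::thread([this, total_elements] {
      int shard = 0;
      for (int v = 0; v < total_elements; ++v) {
        Callback cb = nullptr;
        {
          std::unique_lock<std::mutex> l(mu_);
          shards_[shard].not_full.wait(l, [&] {
            return shards_[shard].data.size() < max_buffer_size_;
          });
          if (!shards_[shard].callbacks.empty()) {
            cb = std::move(shards_[shard].callbacks.front());
            shards_[shard].callbacks.pop_front();
          } else {
            shards_[shard].data.push_back(v);
          }
        }
        if (cb) cb(v);  // run the callback outside the lock
        shard = (shard + 1) % static_cast<int>(shards_.size());
      }
    });
  }

  ~ShardedBuffer() { producer_.join(); }

  void GetNextFromShard(int shard, Callback cb) {
    int value = 0;
    bool ready = false;
    {
      std::lock_guard<std::mutex> l(mu_);
      if (!shards_[shard].data.empty()) {
        value = shards_[shard].data.front();
        shards_[shard].data.pop_front();
        shards_[shard].not_full.notify_all();  // a slot was freed
        ready = true;
      } else {
        shards_[shard].callbacks.push_back(std::move(cb));  // defer
      }
    }
    if (ready) cb(value);
  }

 private:
  struct Shard {
    std::deque<int> data;
    std::deque<Callback> callbacks;
    std::condition_variable not_full;
  };

  std::mutex mu_;
  std::vector<Shard> shards_;
  const std::size_t max_buffer_size_;
  std::thread producer_;
};

int main() {
  ShardedBuffer buffer(/*num_shards=*/2, /*max_buffer_size=*/2,
                       /*total_elements=*/6);
  for (int i = 0; i < 6; ++i) {
    int shard = i % 2;
    buffer.GetNextFromShard(shard, [shard](int v) {
      std::cout << "shard " << shard << " got " << v << "\n";
    });
  }
}

As in the kernel, callbacks run outside the lock, and a consumer that takes an element from a full shard wakes the producer so it can refill that slot.
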
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index 8add049123..1cb7caa738 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
namespace tensorflow {
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -59,13 +60,13 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const std::vector<string>& optimizations,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
optimizations_(optimizations),
output_types_(output_types),
@@ -92,24 +93,35 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
DatasetGraphDefBuilder db(&b);
Node* input_node = nullptr;
SerializationContext::Params params;
+ std::vector<std::pair<string, Tensor>> input_list;
+ params.allow_stateful_functions = true;
params.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
+ params.input_list = &input_list;
SerializationContext serialization_ctx(params);
TF_RETURN_IF_ERROR(
db.AddInputDataset(&serialization_ctx, input_, &input_node));
string output_node = input_node->name();
+
GraphDef graph_def;
TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
VLOG(3) << "Before optimization: " << graph_def.DebugString();
+
TF_RETURN_IF_ERROR(ApplyOptimizations(ctx, &graph_def, &output_node));
VLOG(3) << "After optimization: " << graph_def.DebugString();
- flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(),
- graph_def.library()));
+
+ // Instantiate the optimized input pipeline by running the optimized graph
+ // using the optimized function library.
+ TF_RETURN_IF_ERROR(
+ ctx->function_library()->Clone(&flib_def_, &pflr_, &lib_));
+ TF_RETURN_IF_ERROR(flib_def_->AddLibrary(graph_def.library()));
+
Graph graph(OpRegistry::Global());
TF_RETURN_IF_ERROR(ImportGraphDef({}, graph_def, &graph, nullptr));
std::vector<Tensor> outputs;
GraphRunner graph_runner(ctx->function_library()->device());
- TF_RETURN_IF_ERROR(graph_runner.Run(&graph, ctx->function_library(), {},
- {output_node}, &outputs));
+
+ TF_RETURN_IF_ERROR(
+ graph_runner.Run(&graph, lib_, input_list, {output_node}, &outputs));
TF_RETURN_IF_ERROR(
GetDatasetFromVariantTensor(outputs[0], &optimized_input_));
optimized_input_->Ref();
@@ -142,22 +154,19 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->optimized_input_->MakeIterator(ctx, prefix(),
- &input_impl_);
+ IteratorContext::Params params = ctx->params();
+ params.lib = dataset()->lib_;
+ return dataset()->optimized_input_->MakeIterator(
+ IteratorContext(params), prefix(), &input_impl_);
}
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = *(ctx->runner());
- params.stats_aggregator_getter = ctx->stats_aggregator_getter();
- params.lib = ctx->lib();
- params.function_library = dataset()->flib_def_;
- params.allocator_getter = ctx->allocator_getter();
- IteratorContext iter_ctx(params);
- return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
+ IteratorContext::Params params = ctx->params();
+ params.lib = dataset()->lib_;
+ return input_impl_->GetNext(IteratorContext(params), out_tensors,
+ end_of_sequence);
}
protected:
@@ -236,7 +245,9 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
}
DatasetBase* optimized_input_;
- std::shared_ptr<FunctionLibraryDefinition> flib_def_;
+ FunctionLibraryRuntime* lib_ = nullptr;
+ std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_ = nullptr;
+ std::unique_ptr<FunctionLibraryDefinition> flib_def_ = nullptr;
const DatasetBase* input_;
const std::vector<string> optimizations_;
const DataTypeVector output_types_;
@@ -252,4 +263,5 @@ REGISTER_KERNEL_BUILDER(Name("OptimizeDataset").Device(DEVICE_CPU),
OptimizeDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
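
Note on the change above: the optimized dataset forwards the caller's IteratorContext but substitutes the cloned function library runtime (`params.lib = dataset()->lib_;`) both when creating the inner iterator and on every GetNext call. A minimal standalone sketch of that copy-params-and-override pattern follows; `Params`/`Context` are illustrative stand-ins, not TensorFlow's IteratorContext types.

    #include <iostream>
    #include <string>
    #include <utility>

    // Stand-ins for IteratorContext / IteratorContext::Params; only the
    // copy-then-override pattern is the point here.
    struct Params {
      std::string lib = "caller-provided library";
      int num_threads = 8;
    };

    class Context {
     public:
      explicit Context(Params params) : params_(std::move(params)) {}
      const Params& params() const { return params_; }

     private:
      Params params_;
    };

    // Copy every field the caller configured and override only the library,
    // mirroring `params.lib = dataset()->lib_;` in the kernel above.
    Context WithLibrary(const Context& ctx, const std::string& lib) {
      Params params = ctx.params();
      params.lib = lib;
      return Context(std::move(params));
    }

    int main() {
      Context caller{Params{}};
      Context forwarded = WithLibrary(caller, "cloned, optimized library");
      std::cout << forwarded.params().lib << " with "
                << forwarded.params().num_threads << " threads\n";
    }
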
diff --git a/tensorflow/core/kernels/data/optional_ops.cc b/tensorflow/core/kernels/data/optional_ops.cc
index cfac45dbc7..2ab5c83082 100644
--- a/tensorflow/core/kernels/data/optional_ops.cc
+++ b/tensorflow/core/kernels/data/optional_ops.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_op_registry.h"
namespace tensorflow {
+namespace data {
namespace {
const char kOptionalVariantTypeName[] = "tensorflow::data::Optional";
@@ -107,11 +108,8 @@ class OptionalFromValueOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
OpInputList components_input;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &components_input));
- std::vector<Tensor> components;
- components.reserve(components_input.size());
- for (const Tensor& component_t : components_input) {
- components.push_back(component_t);
- }
+ std::vector<Tensor> components(components_input.begin(),
+ components_input.end());
OP_REQUIRES_OK(
ctx, WriteOptionalWithValueToOutput(ctx, 0, std::move(components)));
}
@@ -215,6 +213,14 @@ static Status OptionalDeviceCopy(
std::vector<Tensor> to_values;
to_values.reserve(from_values.size());
for (const Tensor& t : from_values) {
+ if (t.dtype() == DT_VARIANT) {
+ // TODO(b/116349787): Implement support for nested variants.
+ return errors::Unimplemented(
+ "Support for copying nested variants to device has not yet been "
+ "implemented.");
+ }
+ }
+ for (const Tensor& t : from_values) {
if (DMAHelper::CanUseDMA(&t)) {
Tensor tmp(t.dtype());
TF_RETURN_IF_ERROR(copy(t, &tmp));
@@ -230,10 +236,9 @@ static Status OptionalDeviceCopy(
return Status::OK();
}
-#define REGISTER_OPTIONAL_COPY(DIRECTION) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- OptionalVariant, DIRECTION, kOptionalVariantTypeName, \
- OptionalDeviceCopy)
+#define REGISTER_OPTIONAL_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
+ OptionalVariant, DIRECTION, OptionalDeviceCopy)
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_OPTIONAL_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
@@ -267,4 +272,5 @@ Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index) {
return Status::OK();
}
+} // namespace data
} // namespace tensorflow
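
Note on the OptionalDeviceCopy change above: the new first pass rejects nested DT_VARIANT values before any element is copied, so an unsupported element can never leave a half-copied destination behind. A small standalone sketch of that validate-then-copy shape, using stand-in types rather than TensorFlow's Tensor/Status:

    #include <iostream>
    #include <stdexcept>
    #include <vector>

    // Stand-in element type; `nested_variant` plays the role of DT_VARIANT above.
    struct Value {
      bool nested_variant = false;
      int payload = 0;
    };

    std::vector<Value> CopyAll(const std::vector<Value>& from) {
      // Pass 1: reject unsupported elements before doing any work.
      for (const Value& v : from) {
        if (v.nested_variant) {
          throw std::runtime_error("copying nested variants is not supported");
        }
      }
      // Pass 2: copy, knowing the whole batch is supported.
      std::vector<Value> to;
      to.reserve(from.size());
      for (const Value& v : from) {
        to.push_back(v);
      }
      return to;
    }

    int main() {
      std::vector<Value> values = {{false, 1}, {false, 2}};
      std::cout << "copied " << CopyAll(values).size() << " values\n";
    }
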
diff --git a/tensorflow/core/kernels/data/optional_ops.h b/tensorflow/core/kernels/data/optional_ops.h
index 6f25567678..2cbf2933f5 100644
--- a/tensorflow/core/kernels/data/optional_ops.h
+++ b/tensorflow/core/kernels/data/optional_ops.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/variant_tensor_data.h"
namespace tensorflow {
+namespace data {
// Stores a DT_VARIANT value representing an Optional with the given value
// in the `output_index`^th output of the given kernel execution context.
@@ -31,6 +32,7 @@ Status WriteOptionalWithValueToOutput(OpKernelContext* ctx, int output_index,
// in the `output_index`^th output of the given kernel execution context.
Status WriteOptionalNoneToOutput(OpKernelContext* ctx, int output_index);
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_OPTIONAL_OPS_H_
diff --git a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
index 755d46dac2..7b01c3b4e0 100644
--- a/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/padded_batch_dataset_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -98,12 +98,12 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 batch_size, bool drop_remainder,
std::vector<PartialTensorShape> padded_shapes,
std::vector<Tensor> padding_values, const DatasetBase* input)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
batch_size_(batch_size),
drop_remainder_(drop_remainder),
padded_shapes_(std::move(padded_shapes)),
@@ -207,6 +207,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
+ AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
@@ -382,5 +383,5 @@ REGISTER_KERNEL_BUILDER(Name("PaddedBatchDatasetV2").Device(DEVICE_CPU),
PaddedBatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index d2b83f9eab..6b6b3d6ab9 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <atomic>
#include <deque>
+#include <utility>
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
@@ -21,11 +23,12 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -34,8 +37,7 @@ namespace {
class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
@@ -43,14 +45,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int64 cycle_length = 0;
OP_REQUIRES_OK(ctx,
ParseScalarArgument(ctx, "cycle_length", &cycle_length));
@@ -82,8 +76,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<CapturedFunction> captured_func;
OP_REQUIRES_OK(
- ctx, CapturedFunction::Create(
- interleave_func_, std::move(other_arguments), &captured_func));
+ ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
+ &captured_func));
*output =
new Dataset(ctx, input, interleave_func_, std::move(captured_func),
@@ -92,7 +86,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func,
@@ -100,7 +94,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
int64 block_length, bool sloppy, int64 buffer_output_elements,
int64 prefetch_input_elements, const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
interleave_func_(func),
captured_func_(std::move(captured_func)),
@@ -125,6 +119,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector& output_dtypes() const override {
return output_types_;
}
+
const std::vector<PartialTensorShape>& output_shapes() const override {
return output_shapes_;
}
@@ -137,8 +132,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(
- b->AddFunction(ctx->flib_def(), interleave_func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
Node* cycle_length_node;
@@ -251,7 +245,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ AddConstantParameter(ctx, "parallelism", dataset()->cycle_length_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
// It is implemented so that it matches the deterministic interleave
@@ -279,7 +276,12 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (!current_worker->outputs.empty()) {
// We have an element!
next_index_ = index;
- if (i == 0) {
+ const bool element_acquired_sloppily =
+ dataset()->sloppy_ && i > 1;
+ if (!element_acquired_sloppily) {
+ // If the element was acquired in the regular (non-sloppy)
+ // order, then advance the current block and cycle pointers to
+ // the next element in the regular order.
block_count_++;
if (block_count_ == dataset()->block_length_) {
next_index_ = (index + 1) % interleave_indices_.size();
@@ -343,11 +345,13 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (must_wait_for_input) {
// Wait for elements to become available.
+ RecordStop(ctx);
if (dataset()->sloppy_) {
sloppy_cond_var_.wait(l);
} else {
workers_[interleave_indices_[next_index_]].cond_var.wait(l);
}
+ RecordStart(ctx);
}
}
return errors::Cancelled(
@@ -476,10 +480,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (reader->Contains(full_name("worker_threads_running"))) {
worker_threads_.reserve(dataset()->num_threads());
for (size_t i = 0; i < dataset()->num_threads(); ++i) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
- std::bind(&Iterator::WorkerThread, this,
- new IteratorContext(*ctx), i)));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
}
}
return Status::OK();
@@ -575,10 +579,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
return Status::OK();
}
workers_[i].SetInputs(s, std::move(args));
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
worker_threads_.emplace_back(ctx->env()->StartThread(
{}, "worker_thread",
- std::bind(&Iterator::WorkerThread, this,
- new IteratorContext(*ctx), i)));
+ [this, new_ctx, i]() { WorkerThread(new_ctx, i); }));
if (i < dataset()->cycle_length_) {
interleave_indices_.push_back(i);
} else {
@@ -593,7 +597,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
// Produces elements into the worker's output buffers.
- void WorkerThread(IteratorContext* ctx_ptr, const int64 thread_index) {
+ void WorkerThread(const std::shared_ptr<IteratorContext>& ctx,
+ const int64 thread_index) {
// Notes on checkpointing thread local state, i.e., `WorkerThreadState`:
//
// 1. Any local state that may need to be checkpointed should be kept
@@ -614,10 +619,11 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// std::function arguments are copy-constructable, so we pass raw
// pointers, and then immediately wrap them to ensure correct ownership.
- std::unique_ptr<IteratorContext> ctx(ctx_ptr);
- auto cleanup = gtl::MakeCleanup([this, thread_index] {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, thread_index, ctx] {
mutex_lock l(mu_);
workers_[thread_index].cond_var.notify_all();
+ RecordStop(ctx.get());
});
bool make_new_iterator;
{
@@ -643,9 +649,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// 1. Build a new iterator or use the existing one.
if (make_new_iterator) {
           // 1a. Get new input tensors or use the existing ones.
-
bool read_new_input;
-
{
tf_shared_lock l(ckpt_mu_);
// worker_thread_states_[thread_index].input will be non-empty
@@ -657,7 +661,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
if (read_new_input) {
mutex_lock l(mu_);
while (!cancelled_ && !workers_[thread_index].is_producing) {
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) return;
// Copy the input tensors so that we do not need to block on `mu_`
@@ -678,7 +684,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
{
tf_shared_lock l(ckpt_mu_);
worker_thread_states_[thread_index].iterator_creation_status =
- dataset::MakeIteratorFromInputElement(
+ MakeIteratorFromInputElement(
ctx.get(), worker_thread_states_[thread_index].input,
thread_index, dataset()->captured_func_.get(), prefix(),
&worker_thread_states_[thread_index].iterator);
@@ -707,7 +713,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) return;
tf_shared_lock ckpt_l(ckpt_mu_);
@@ -756,7 +764,9 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
// Wait for space in the prefetch queue.
while (!cancelled_ && workers_[thread_index].outputs.size() ==
dataset()->buffer_output_elements_) {
+ RecordStop(ctx.get());
workers_[thread_index].cond_var.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) return;
@@ -908,7 +918,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
worker_thread_states_[index].iterator.reset();
} else {
std::unique_ptr<IteratorBase> iterator;
- Status s = dataset::MakeIteratorFromInputElement(
+ Status s = MakeIteratorFromInputElement(
ctx, worker_thread_states_[index].input, index,
dataset()->captured_func_.get(), prefix(), &iterator);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, iterator));
@@ -1052,7 +1062,6 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
NameAttrList interleave_func_;
@@ -1061,6 +1070,605 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
ParallelInterleaveDatasetOp);
-} // namespace
+// The motivation for creating an alternative implementation of parallel
+// interleave is to decouple the degree of parallelism from the cycle length.
+// This makes it possible to change the degree of parallelism (e.g. through
+// auto-tuning) without changing the cycle length (which would change the order
+// in which elements are produced).
+//
+// Furthermore, this class favors modularity over extended functionality. In
+// particular, it refrains from implementing configurable buffering of output
+// elements and prefetching of input iterators, relying on other parts of
+// tf.data to provide this functionality if necessary.
+//
+// The above design choices were made with automated optimizations in mind,
+// isolating the degree of parallelism as the single tunable knob of this
+// implementation.
+//
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
+class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
+ public:
+ explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &interleave_func_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ }
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 cycle_length = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument(ctx, "cycle_length", &cycle_length));
+ OP_REQUIRES(ctx, cycle_length > 0,
+ errors::InvalidArgument("`cycle_length` must be > 0"));
+
+ int64 block_length = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument(ctx, "block_length", &block_length));
+ OP_REQUIRES(ctx, block_length > 0,
+ errors::InvalidArgument("`block_length` must be > 0"));
+
+ int64 num_parallel_calls;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
+ &num_parallel_calls));
+ OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
+ errors::InvalidArgument(
+ "num_parallel_calls must be greater than zero."));
+ OP_REQUIRES(
+ ctx, num_parallel_calls <= cycle_length,
+ errors::InvalidArgument(
+            "num_parallel_calls must be less than or equal to cycle_length."));
+
+ std::unique_ptr<CapturedFunction> captured_func;
+ OP_REQUIRES_OK(
+ ctx, CapturedFunction::Create(interleave_func_, ctx, "other_arguments",
+ &captured_func));
+
+ *output = new Dataset(ctx, input, interleave_func_,
+ std::move(captured_func), cycle_length, block_length,
+ num_parallel_calls, output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ const NameAttrList& func,
+ std::unique_ptr<CapturedFunction> captured_func, int64 cycle_length,
+ int64 block_length, int64 num_parallel_calls,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ interleave_func_(func),
+ captured_func_(std::move(captured_func)),
+ cycle_length_(cycle_length),
+ block_length_(block_length),
+ num_parallel_calls_(num_parallel_calls),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(new Iterator(
+ {this, strings::StrCat(prefix, "::ParallelInterleaveV2")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+ string DebugString() const override {
+ return "ParallelInterleaveDatasetV2Op::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, interleave_func_.name()));
+ Node* input_node;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
+ Node* cycle_length_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(cycle_length_, &cycle_length_node));
+ Node* block_length_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
+ Node* num_parallel_calls_node;
+ TF_RETURN_IF_ERROR(
+ b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
+ DataTypeVector other_arguments_types;
+ other_arguments_types.reserve(captured_func_->captured_inputs().size());
+ std::vector<Node*> other_arguments;
+ other_arguments.reserve(captured_func_->captured_inputs().size());
+ for (const Tensor& t : captured_func_->captured_inputs()) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ other_arguments.emplace_back(node);
+ other_arguments_types.emplace_back(t.dtype());
+ }
+ AttrValue f;
+ b->BuildAttrValue(interleave_func_, &f);
+ AttrValue other_arguments_types_attr;
+ b->BuildAttrValue(other_arguments_types, &other_arguments_types_attr);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this,
+ {{0, input_node},
+ {2, cycle_length_node},
+ {3, block_length_node},
+ {4, num_parallel_calls_node}},
+ {{1, other_arguments}},
+ {{"f", f}, {"Targuments", other_arguments_types_attr}}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ params.dataset->num_parallel_calls_, mu_, cond_var_)),
+ args_list_(params.dataset->cycle_length_),
+ current_elements_(params.dataset->cycle_length_),
+ element_in_use_(params.dataset->cycle_length_, false),
+ thread_pool_(new thread::ThreadPool(
+ Env::Default(), ThreadOptions(), "parallel_interleave",
+ dataset()->cycle_length_ /* num_threads */,
+ false /* low_latency_hint */)) {}
+
+ ~Iterator() override {
+ mutex_lock l(*mu_);
+ // Cancel the runner thread.
+ cancelled_ = true;
+ cond_var_->notify_all();
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_->wait(l);
+ }
+ }
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(*mu_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ dataset()->cycle_length_);
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
+ }
+ AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
+ }
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ std::shared_ptr<InvocationResult> result;
+ do {
+ {
+ mutex_lock l(*mu_);
+ EnsureRunnerThreadStarted(ctx);
+ while (invocation_results_.empty() &&
+ (!end_of_input_ || num_open_ > 0)) {
+ RecordStop(ctx);
+ cond_var_->wait(l);
+ RecordStart(ctx);
+ }
+ if (!invocation_results_.empty()) {
+ std::swap(result, invocation_results_.front());
+ invocation_results_.pop_front();
+ } else {
+ *end_of_sequence = true;
+ return Status::OK();
+ }
+ cond_var_->notify_all();
+ }
+ RecordStop(ctx);
+ result->notification.WaitForNotification();
+ RecordStart(ctx);
+ } while (result->skip);
+
+ if (result->status.ok()) {
+ *out_tensors = std::move(result->return_values);
+ }
+ *end_of_sequence = false;
+ return result->status;
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(*mu_);
+ // Wait for all in-flight calls to complete.
+ while (num_calls_ > 0) {
+ cond_var_->wait(l);
+ }
+ CHECK_EQ(num_calls_, 0);
+ TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name("invocation_results.size"), invocation_results_.size()));
+ for (size_t i = 0; i < invocation_results_.size(); i++) {
+ std::shared_ptr<InvocationResult> result = invocation_results_[i];
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ result->return_values.size()));
+ for (size_t j = 0; j < result->return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(
+ strings::StrCat("invocation_results[", i, "][", j, "]")),
+ result->return_values[j]));
+ }
+ if (result->skip) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].skip")),
+ ""));
+ }
+ }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("cycle_index"), cycle_index_));
+ if (end_of_input_) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("end_of_input"), ""));
+ }
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("num_open"), num_open_));
+ TF_RETURN_IF_ERROR(WriteCurrentElements(writer));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(*mu_);
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ int64 invocation_results_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name("invocation_results.size"), &invocation_results_size));
+ for (size_t i = 0; i < invocation_results_size; i++) {
+ std::shared_ptr<InvocationResult> result(new InvocationResult());
+ invocation_results_.push_back(result);
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ size_t num_return_values;
+ {
+ int64 size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("invocation_results[", i, "].size")),
+ &size));
+ num_return_values = static_cast<size_t>(size);
+ if (num_return_values != size) {
+ return errors::InvalidArgument(strings::StrCat(
+ full_name(
+ strings::StrCat("invocation_results[", i, "].size")),
+ ": ", size, " is not a valid value of type size_t."));
+ }
+ }
+ result->return_values.reserve(num_return_values);
+ for (size_t j = 0; j < num_return_values; j++) {
+ result->return_values.emplace_back();
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(full_name(strings::StrCat(
+ "invocation_results[", i, "][", j, "]")),
+ &result->return_values.back()));
+ }
+ result->skip = reader->Contains(
+ full_name(strings::StrCat("invocation_results[", i, "].skip")));
+ result->notification.Notify();
+ }
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("cycle_index"), &cycle_index_));
+ if (reader->Contains(full_name("end_of_input"))) end_of_input_ = true;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(full_name("num_open"), &num_open_));
+ TF_RETURN_IF_ERROR(ReadCurrentElements(ctx, reader));
+ return Status::OK();
+ }
+
+ private:
+ struct InvocationResult {
+ Notification notification; // used for coordination with the consumer
+ Status status; // the invocation status
+ std::vector<Tensor> return_values; // the invocation result values
+ bool skip; // if set the result should be skipped
+ };
+
+ void EnsureRunnerThreadStarted(IteratorContext* ctx)
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+ if (!runner_thread_) {
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ runner_thread_.reset(ctx->env()->StartThread(
+ {}, "runner_thread",
+ [this, new_ctx]() { RunnerThread(new_ctx); }));
+ }
+ }
+
+ // Fetches up to `results.size()` outputs from the cycle element at
+ // position `cycle_index`.
+ //
+ // If end of input is encountered, the `skip` field of the invocation
+ // result is used to identify results that should be skipped.
+ void FetchOutputs(
+ const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
+ const std::vector<std::shared_ptr<InvocationResult>>& results)
+ LOCKS_EXCLUDED(*mu_) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
+ bool end_of_input = false;
+ for (auto& result : results) {
+ if (!end_of_input) {
+ result->status = current_elements_[cycle_index]->GetNext(
+ ctx.get(), &result->return_values, &end_of_input);
+ }
+ if (end_of_input) {
+ result->skip = true;
+ }
+ result->notification.Notify();
+ if (!result->status.ok()) {
+ break;
+ }
+ }
+
+ // Release the ownership of the cycle element iterator, closing the
+ // iterator if end of input was encountered.
+ if (end_of_input) {
+ current_elements_[cycle_index].reset();
+ }
+ mutex_lock l(*mu_);
+ element_in_use_[cycle_index] = false;
+ num_calls_--;
+ if (end_of_input) {
+ args_list_[cycle_index].clear();
+ num_open_--;
+ }
+ cond_var_->notify_all();
+ }
+
+ // Method responsible for 1) creating iterators out of input elements, 2)
+ // determining the order in which elements are fetched from the iterators,
+ // and 3) scheduling the fetching of the elements to a threadpool.
+ //
+ // This method runs in the `runner_thread` background thread.
+ void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+ return element_in_use_[cycle_index_] ||
+ num_calls_ >= num_parallel_calls_->value ||
+ invocation_results_.size() >=
+ dataset()->cycle_length_ * dataset()->block_length_;
+ };
+ while (true) {
+ mutex_lock l(*mu_);
+ // Wait until this thread is cancelled, the end of input has been
+ // reached, or the cycle element at the `cycle_index_` position is
+ // not in use and there is space in the `invocation_results_` queue.
+ while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
+ RecordStop(ctx.get());
+ cond_var_->wait(l);
+ RecordStart(ctx.get());
+ }
+
+ if (cancelled_ || (end_of_input_ && num_open_ == 0)) {
+ return;
+ }
+
+ while ((!end_of_input_ || num_open_ > 0) && !busy()) {
+ if (!current_elements_[cycle_index_]) {
+ // Try to create a new iterator from the next input element.
+ Status status = input_impl_->GetNext(
+ ctx.get(), &args_list_[cycle_index_], &end_of_input_);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ if (!end_of_input_) {
+ Status status = MakeIteratorFromInputElement(
+ ctx.get(), args_list_[cycle_index_], cycle_index_,
+ dataset()->captured_func_.get(), prefix(),
+ &current_elements_[cycle_index_]);
+ if (!status.ok()) {
+ invocation_results_.emplace_back(new InvocationResult());
+ std::shared_ptr<InvocationResult>& result =
+ invocation_results_.back();
+ result->status.Update(status);
+ result->notification.Notify();
+ break;
+ }
+ ++num_open_;
+ }
+ }
+ if (current_elements_[cycle_index_]) {
+ // Pre-allocate invocation results for outputs to be fetched
+ // and then fetch the outputs asynchronously.
+ std::vector<std::shared_ptr<InvocationResult>> results;
+ results.reserve(dataset()->block_length_);
+ for (int i = 0; i < dataset()->block_length_; ++i) {
+ invocation_results_.emplace_back(new InvocationResult());
+ results.push_back(invocation_results_.back());
+ }
+ num_calls_++;
+ element_in_use_[cycle_index_] = true;
+ thread_pool_->Schedule(std::bind(&Iterator::FetchOutputs, this,
+ ctx, cycle_index_,
+ std::move(results)));
+ }
+ cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
+ }
+ cond_var_->notify_all();
+ }
+ }
+
+ Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ CodeKey(index), static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+ status.error_message()));
+ }
+ return Status::OK();
+ }
+
+ Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
+ }
+ return Status::OK();
+ }
+
+ string CodeKey(size_t index) {
+ return full_name(
+ strings::StrCat("invocation_results[", index, "].code"));
+ }
+
+ string ErrorMessageKey(size_t index) {
+ return full_name(
+ strings::StrCat("invocation_results[", index, "].error_message"));
+ }
+
+ Status WriteCurrentElements(IteratorStateWriter* writer)
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+ for (int idx = 0; idx < current_elements_.size(); idx++) {
+ if (current_elements_[idx]) {
+ TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ full_name(strings::StrCat("args_size[", idx, "]")),
+ args_list_[idx].size()));
+ for (int i = 0; i < args_list_[idx].size(); i++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+ args_list_[idx][i]));
+ }
+ }
+ }
+ return Status::OK();
+ }
+
+ Status ReadCurrentElements(IteratorContext* ctx,
+ IteratorStateReader* reader)
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
+ for (int idx = 0; idx < current_elements_.size(); idx++) {
+ if (reader->Contains(
+ full_name(strings::StrCat("args_size[", idx, "]")))) {
+ int64 args_size;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ full_name(strings::StrCat("args_size[", idx, "]")),
+ &args_size));
+ args_list_[idx].resize(args_size);
+ for (int i = 0; i < args_size; i++) {
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("args_list_[", idx, "][", i, "]")),
+ &args_list_[idx][i]));
+ }
+ TF_RETURN_IF_ERROR(MakeIteratorFromInputElement(
+ ctx, args_list_[idx], idx, dataset()->captured_func_.get(),
+ prefix(), &current_elements_[idx]));
+ TF_RETURN_IF_ERROR(
+ RestoreInput(ctx, reader, current_elements_[idx]));
+ } else {
+ current_elements_[idx].reset();
+ }
+ }
+ return Status::OK();
+ }
+
+ // Used for coordination between the main thread, the runner thread, and
+ // the worker threads.
+ const std::shared_ptr<mutex> mu_;
+
+ // Used for coordination between the main thread, the runner thread, and
+ // the worker threads. In particular, the runner thread should only
+ // schedule new calls when the number of in-flight calls is less than the
+ // user specified level of parallelism, there are slots available in the
+ // `invocation_results_` buffer, the current cycle element is not in use,
+ // and there are elements left to be fetched.
+ const std::shared_ptr<condition_variable> cond_var_;
+
+ // Identifies the maximum number of parallel calls.
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
+
+ // Iterator for input elements.
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(*mu_);
+
+ // Identifies current cycle element.
+ int64 cycle_index_ = 0;
+
+ // Arguments for creating an iterator for cycle elements.
+ std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(*mu_);
+
+ // Iterators for the current cycle elements. Concurrent access is
+ // protected by `element_in_use_`.
+ std::vector<std::unique_ptr<IteratorBase>> current_elements_;
+
+ // Identifies cycle elements that are in use by worker threads.
+ std::vector<bool> element_in_use_ GUARDED_BY(*mu_);
+
+ // Buffer for storing the invocation results.
+ std::deque<std::shared_ptr<InvocationResult>> invocation_results_
+ GUARDED_BY(*mu_);
+
+ // Identifies whether end of input has been reached.
+ bool end_of_input_ GUARDED_BY(*mu_) = false;
+
+ // Identifies the number of open iterators.
+ int64 num_open_ GUARDED_BY(*mu_) = 0;
+
+ // Identifies the number of outstanding calls.
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
+
+ std::unique_ptr<thread::ThreadPool> thread_pool_;
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+
+ // Identifies whether background activity should be cancelled.
+ bool cancelled_ GUARDED_BY(*mu_) = false;
+ };
+
+ const DatasetBase* const input_;
+ const NameAttrList interleave_func_;
+ const std::unique_ptr<CapturedFunction> captured_func_;
+ const int64 cycle_length_;
+ const int64 block_length_;
+ const int64 num_parallel_calls_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ NameAttrList interleave_func_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDatasetV2").Device(DEVICE_CPU),
+ ParallelInterleaveDatasetV2Op);
+
+} // namespace
+} // namespace data
} // namespace tensorflow
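
The motivation comment for ParallelInterleaveDatasetV2 above hinges on one property: the cycle fixes the output order, while `num_parallel_calls` only bounds how many fetches run at once. The standalone sketch below (standard C++ only, `block_length` fixed at 1, futures standing in for the invocation results and the thread pool) shows why consuming results in scheduling order makes the output independent of the parallelism knob: changing `parallelism` changes throughput, never the sequence.

    #include <chrono>
    #include <deque>
    #include <future>
    #include <iostream>
    #include <string>
    #include <thread>

    // Simulated "fetch one element from cycle source `source`"; completion order
    // varies, which is exactly what must not leak into the output order.
    std::string FetchFromSource(int source, int element) {
      std::this_thread::sleep_for(std::chrono::milliseconds(10 * (source % 3)));
      return "source " + std::to_string(source) + ", element " + std::to_string(element);
    }

    int main() {
      const int cycle_length = 4;   // fixes the order in which sources are visited
      const int parallelism = 2;    // bounds concurrency only; 1 or 4 gives identical output
      const int elements_per_source = 3;
      const int total = cycle_length * elements_per_source;

      std::deque<std::future<std::string>> in_flight;  // results, in scheduling order
      int next = 0;      // next element to schedule, walked in cycle order
      int consumed = 0;

      while (consumed < total) {
        // Schedule in cycle order until the parallelism budget is exhausted.
        while (next < total && static_cast<int>(in_flight.size()) < parallelism) {
          const int source = next % cycle_length;
          const int element = next / cycle_length;
          in_flight.push_back(
              std::async(std::launch::async, FetchFromSource, source, element));
          ++next;
        }
        // Consume strictly in scheduling order, regardless of completion order.
        std::cout << in_flight.front().get() << "\n";
        in_flight.pop_front();
        ++consumed;
      }
    }
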
diff --git a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
index c56a7ea808..6abe6c8338 100644
--- a/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_map_dataset_op.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -33,53 +33,50 @@ namespace {
class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ParallelMapDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("use_inter_op_parallelism",
+ &use_inter_op_parallelism_));
}
protected:
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
-
int32 num_parallel_calls;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
- OP_REQUIRES(ctx, num_parallel_calls > 0,
+ OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ use_inter_op_parallelism_,
+ &captured_func));
*output = new Dataset(ctx, input, func_, num_parallel_calls, output_types_,
- output_shapes_, std::move(captured_func));
+ output_shapes_, use_inter_op_parallelism_,
+ std::move(captured_func));
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func, int32 num_parallel_calls,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes,
+ bool use_inter_op_parallelism,
std::unique_ptr<CapturedFunction> captured_func)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
num_parallel_calls_(num_parallel_calls),
output_types_(output_types),
output_shapes_(output_shapes),
+ use_inter_op_parallelism_(use_inter_op_parallelism),
captured_func_(std::move(captured_func)) {
input_->Ref();
}
@@ -88,16 +85,30 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
- auto map_func = [this](IteratorContext* ctx,
+ auto init_func = [this](IteratorContext* ctx) {
+ return captured_func_->Instantiate(ctx);
+ };
+
+ const string& new_prefix = strings::StrCat(prefix, "::ParallelMap");
+ ParallelMapIteratorFunction map_func =
+ [this, new_prefix](IteratorContext* ctx,
std::vector<Tensor> input_element,
std::vector<Tensor>* result, StatusCallback done) {
- captured_func_->RunAsync(ctx, std::move(input_element), result,
- std::move(done));
- };
+ captured_func_->RunAsync(ctx, std::move(input_element), result,
+ std::move(done), new_prefix);
+ };
+ if (!use_inter_op_parallelism_) {
+ map_func = [map_func](
+ IteratorContext* ctx, std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ (*ctx->runner())(std::bind(map_func, ctx, std::move(input_element),
+ result, std::move(done)));
+ };
+ }
- return NewParallelMapIterator(
- {this, strings::StrCat(prefix, "::ParallelMap")}, input_,
- std::move(map_func), num_parallel_calls_);
+ return NewParallelMapIterator({this, new_prefix}, input_,
+ std::move(init_func), std::move(map_func),
+ num_parallel_calls_);
}
const DataTypeVector& output_dtypes() const override {
@@ -138,7 +149,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
b->AddScalar(num_parallel_calls_, &num_parallel_calls));
// Attr: f
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
AttrValue f;
b->BuildAttrValue(func_, &f);
@@ -163,12 +174,13 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
const int32 num_parallel_calls_;
const DataTypeVector output_types_;
const std::vector<PartialTensorShape> output_shapes_;
+ const bool use_inter_op_parallelism_;
const std::unique_ptr<CapturedFunction> captured_func_;
};
- const int graph_def_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
+ bool use_inter_op_parallelism_;
NameAttrList func_;
};
@@ -176,5 +188,5 @@ REGISTER_KERNEL_BUILDER(Name("ParallelMapDataset").Device(DEVICE_CPU),
ParallelMapDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
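
The `use_inter_op_parallelism` handling above wraps the map function in a second lambda that re-dispatches the call through `(*ctx->runner())(...)` instead of letting the captured function fan out onto the inter-op pool. A compact standalone sketch of that wrapping pattern; `Runner`, `MapFn`, and `DoneFn` are illustrative stand-ins for the TensorFlow types.

    #include <cctype>
    #include <functional>
    #include <iostream>
    #include <string>
    #include <utility>

    using Runner = std::function<void(std::function<void()>)>;
    using DoneFn = std::function<void(std::string)>;
    using MapFn = std::function<void(const std::string&, DoneFn)>;

    // Wrap `fn` so its body runs wherever `runner` decides to run it, mirroring
    // the use_inter_op_parallelism == false branch above.
    MapFn RunViaRunner(MapFn fn, Runner runner) {
      return [fn = std::move(fn), runner = std::move(runner)](
                 const std::string& input, DoneFn done) {
        runner([fn, input, done = std::move(done)]() mutable {
          fn(input, std::move(done));
        });
      };
    }

    int main() {
      // A trivial runner that executes work inline on the calling thread.
      Runner inline_runner = [](std::function<void()> work) { work(); };

      MapFn to_upper = [](const std::string& input, DoneFn done) {
        std::string out = input;
        for (char& c : out) {
          c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
        }
        done(std::move(out));
      };

      MapFn wrapped = RunViaRunner(to_upper, inline_runner);
      wrapped("tf.data", [](std::string result) { std::cout << result << "\n"; });
    }
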
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index 4d32b719a4..13bd4b6036 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -14,87 +14,111 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
+#include <atomic>
#include <deque>
#include <functional>
#include <utility>
#include <vector>
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/cpu_info.h"
+
namespace tensorflow {
+namespace data {
namespace {
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class ParallelMapIterator : public DatasetBaseIterator {
public:
explicit ParallelMapIterator(
const typename DatasetBaseIterator::BaseParams& params,
- const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
- int32 num_parallel_calls)
+ const DatasetBase* input_dataset,
+ std::function<Status(IteratorContext*)> init_func,
+ ParallelMapIteratorFunction map_func, int32 num_parallel_calls)
: DatasetBaseIterator(params),
input_dataset_(input_dataset),
+ init_func_(std::move(init_func)),
map_func_(std::move(map_func)),
- num_parallel_calls_(num_parallel_calls) {}
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ num_parallel_calls, mu_, cond_var_)) {}
~ParallelMapIterator() override {
- // TODO(mrry): Replace this cancellation logic with a
- // CancellationManager. The syntax would be more heavyweight,
- // but it would be possible to thread a cancellation manager
- // through the IteratorContext to upstream,
- // potentially-blocking iterators, when we add these.
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
- cond_var_.notify_all();
+ cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
- return input_dataset_->MakeIterator(ctx, prefix(), &input_impl_);
+ mutex_lock l(*mu_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
+ // TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
+ // use it here for the maximum.
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ port::NumSchedulableCPUs());
+ } else {
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
+ }
+ TF_RETURN_IF_ERROR(
+ input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
+ if (init_func_) {
+ TF_RETURN_IF_ERROR(init_func_(ctx));
+ }
+ return Status::OK();
}
Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
std::shared_ptr<InvocationResult> result;
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty()) {
- cond_var_.wait(l);
+ RecordStop(ctx);
+ cond_var_->wait(l);
+ RecordStart(ctx);
}
std::swap(result, invocation_results_.front());
invocation_results_.pop_front();
+ cond_var_->notify_all();
}
- cond_var_.notify_all();
+ RecordStop(ctx);
result->notification.WaitForNotification();
+ RecordStart(ctx);
return ProcessResult(result, out_tensors, end_of_sequence);
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- TF_RETURN_IF_ERROR(
- writer->WriteScalar(full_name("invocation_results.size"),
- invocation_results_.size()));
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
+ invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
- std::shared_ptr<InvocationResult> result = invocation_results_[i];
- TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ const auto& result = *(invocation_results_[i]);
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("invocation_results[", i, "].size")),
- result->return_values.size()));
- for (size_t j = 0; j < result->return_values.size(); j++) {
- TF_RETURN_IF_ERROR(
- writer->WriteTensor(full_name(strings::StrCat(
- "invocation_results[", i, "][", j, "]")),
- result->return_values[j]));
+ result.return_values.size()));
+ for (size_t j = 0; j < result.return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
+ result.return_values[j]));
}
- if (result->end_of_input) {
+ if (result.end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")),
@@ -106,15 +130,15 @@ class ParallelMapIterator : public DatasetBaseIterator {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name("invocation_results.size"), &invocation_results_size));
for (size_t i = 0; i < invocation_results_size; i++) {
- std::shared_ptr<InvocationResult> result(new InvocationResult());
- invocation_results_.push_back(result);
- TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ invocation_results_.push_back(std::make_shared<InvocationResult>());
+ auto& result = *invocation_results_.back();
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
size_t num_return_values;
{
int64 size;
@@ -130,17 +154,16 @@ class ParallelMapIterator : public DatasetBaseIterator {
": ", size, " is not a valid value of type size_t."));
}
}
- result->return_values.reserve(num_return_values);
+ result.return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
- result->return_values.emplace_back();
- TF_RETURN_IF_ERROR(
- reader->ReadTensor(full_name(strings::StrCat(
- "invocation_results[", i, "][", j, "]")),
- &result->return_values.back()));
+ result.return_values.emplace_back();
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
+ &result.return_values.back()));
}
- result->end_of_input = reader->Contains(full_name(
+ result.end_of_input = reader->Contains(full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")));
- result->notification.Notify();
+ result.notification.Notify();
}
return Status::OK();
}
@@ -154,7 +177,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
};
void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
@@ -164,18 +187,18 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
void CallCompleted(const std::shared_ptr<InvocationResult>& result)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
num_calls_--;
+ cond_var_->notify_all();
}
result->notification.Notify();
- cond_var_.notify_all();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
result->status =
@@ -185,9 +208,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
- // Call `func_(input_element)`, store the result in
- // `result->return_values`, and notify `result->notification` to unblock
- // a consumer.
+ // Call `func_(input_element)`, store the result in `result->return_values`,
+ // and notify `result->notification` to unblock a consumer.
auto done = [this, result](Status status) {
result->status.Update(status);
CallCompleted(result);
@@ -197,8 +219,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
std::move(done));
}
- int64 MaxInvocationResults() { return num_parallel_calls_; }
-
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
@@ -218,27 +238,33 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
- new_calls.reserve(num_parallel_calls_);
+ new_calls.reserve(num_parallel_calls_->value);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_->value;
+ return num_calls_ >= num_parallel_calls ||
+ invocation_results_.size() >= num_parallel_calls;
+ };
while (true) {
{
- mutex_lock l(mu_);
- while (!cancelled_ &&
- (num_calls_ >= num_parallel_calls_ ||
- invocation_results_.size() >= MaxInvocationResults())) {
- cond_var_.wait(l);
+ mutex_lock l(*mu_);
+ while (!cancelled_ && busy()) {
+ RecordStop(ctx.get());
+ cond_var_->wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
- while (num_calls_ < num_parallel_calls_ &&
- invocation_results_.size() < MaxInvocationResults()) {
- invocation_results_.emplace_back(new InvocationResult());
+ while (!busy()) {
+ invocation_results_.push_back(std::make_shared<InvocationResult>());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
+ cond_var_->notify_all();
}
- cond_var_.notify_all();
for (const auto& call : new_calls) {
CallFunction(ctx, call);
}
@@ -247,7 +273,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
- const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
@@ -258,7 +285,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
@@ -285,24 +312,26 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
const DatasetBase* const input_dataset_; // Not owned.
+ const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
- const int32 num_parallel_calls_;
// Used for coordination between the main thread and the runner thread.
- mutex mu_;
+ const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread and the runner thread. In
// particular, the runner thread should only schedule new calls when the
// number of in-flight calls is less than the user specified level of
// parallelism and there are slots available in the `invocation_results_`
// buffer.
- condition_variable cond_var_;
+ const std::shared_ptr<condition_variable> cond_var_;
+ // Identifies the maximum number of parallel calls.
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Counts the number of outstanding calls.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
- GUARDED_BY(mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
- bool cancelled_ GUARDED_BY(mu_) = false;
+ GUARDED_BY(*mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ bool cancelled_ GUARDED_BY(*mu_) = false;
};
} // namespace
@@ -311,8 +340,19 @@ std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
int32 num_parallel_calls) {
- return std::unique_ptr<IteratorBase>(new ParallelMapIterator(
- params, input_dataset, std::move(map_func), num_parallel_calls));
+ return NewParallelMapIterator(params, input_dataset, nullptr,
+ std::move(map_func), num_parallel_calls);
+}
+
+std::unique_ptr<IteratorBase> NewParallelMapIterator(
+ const DatasetBaseIterator::BaseParams& params,
+ const DatasetBase* input_dataset,
+ std::function<Status(IteratorContext*)> init_func,
+ ParallelMapIteratorFunction map_func, int32 num_parallel_calls) {
+ return std::unique_ptr<IteratorBase>(
+ new ParallelMapIterator(params, input_dataset, std::move(init_func),
+ std::move(map_func), num_parallel_calls));
}
+} // namespace data
} // namespace tensorflow
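
The parallel_map_iterator.cc rewrite above turns `mu_`, `cond_var_`, and the parallelism knob into shared state (`model::SharedState`) precisely so that an external tuner can change `num_parallel_calls_->value` and wake the runner thread through the shared condition variable. A standalone sketch of that coordination follows; the `SharedState` struct here is a plain stand-in, not TensorFlow's model::SharedState.

    #include <chrono>
    #include <condition_variable>
    #include <iostream>
    #include <memory>
    #include <mutex>
    #include <thread>
    #include <utility>

    // Illustrative stand-in: a tunable value plus the mutex and condition
    // variable it shares with whoever consumes the value.
    struct SharedState {
      SharedState(int v, std::shared_ptr<std::mutex> m,
                  std::shared_ptr<std::condition_variable> c)
          : value(v), mu(std::move(m)), cond(std::move(c)) {}
      int value;
      std::shared_ptr<std::mutex> mu;
      std::shared_ptr<std::condition_variable> cond;
    };

    int main() {
      auto mu = std::make_shared<std::mutex>();
      auto cond = std::make_shared<std::condition_variable>();
      auto parallelism = std::make_shared<SharedState>(1, mu, cond);

      bool cancelled = false;
      int in_flight = 0;

      // "Runner" thread: schedules a call only while in_flight is below the
      // budget, re-checking whenever the condition variable is signalled.
      std::thread runner([&] {
        std::unique_lock<std::mutex> l(*mu);
        while (!cancelled) {
          while (!cancelled && in_flight >= parallelism->value) {
            cond->wait(l);
          }
          if (cancelled) break;
          ++in_flight;
          std::cout << "scheduled call #" << in_flight
                    << " (budget = " << parallelism->value << ")\n";
        }
      });

      // "Tuner": raising the budget under the shared mutex and notifying the
      // shared condition variable immediately unblocks the runner.
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
      {
        std::lock_guard<std::mutex> l(*mu);
        parallelism->value = 2;
      }
      cond->notify_all();

      std::this_thread::sleep_for(std::chrono::milliseconds(50));
      {
        std::lock_guard<std::mutex> l(*mu);
        cancelled = true;
      }
      cond->notify_all();
      runner.join();
    }
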
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.h b/tensorflow/core/kernels/data/parallel_map_iterator.h
index 2ce36c3869..dc26c5cf25 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.h
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
namespace tensorflow {
+namespace data {
// A function that transforms elements of one dataset into another
// asynchronously. The arguments are:
@@ -33,12 +34,21 @@ using ParallelMapIteratorFunction =
std::vector<Tensor>*, StatusCallback)>;
// Returns a new iterator that applies `map_func` to the elements of
-// `input_dataset` using the given degree of parallelism.
+// `input_dataset` using the given degree of parallelism. `init_func` (if
+// specified) will be executed when the iterator is initialized (see
+// `IteratorBase::Initialize()`) and enables the user to specify error checking
+// logic that can fail early.
+std::unique_ptr<IteratorBase> NewParallelMapIterator(
+ const DatasetBaseIterator::BaseParams& params,
+ const DatasetBase* input_dataset,
+ std::function<Status(IteratorContext*)> init_func,
+ ParallelMapIteratorFunction map_func, int32 num_parallel_calls);
std::unique_ptr<IteratorBase> NewParallelMapIterator(
const DatasetBaseIterator::BaseParams& params,
const DatasetBase* input_dataset, ParallelMapIteratorFunction map_func,
int32 num_parallel_calls);
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_PARALLEL_MAP_ITERATOR_H_
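
The new overload declared above exists so that fallible setup (in the kernel change earlier, `captured_func_->Instantiate(ctx)`) runs once in `Initialize()` rather than being deferred to the first `GetNext()`. A minimal standalone sketch of the hook, with simplified `Status`/`Iterator` stand-ins rather than the TensorFlow classes:

    #include <functional>
    #include <iostream>
    #include <string>
    #include <utility>

    // Simplified stand-in for tensorflow::Status.
    struct Status {
      std::string message;  // empty means OK
      bool ok() const { return message.empty(); }
    };

    class Iterator {
     public:
      // `init_func` may be empty; when present it runs exactly once, in
      // Initialize(), so setup errors surface before any element is requested.
      explicit Iterator(std::function<Status()> init_func)
          : init_func_(std::move(init_func)) {}

      Status Initialize() {
        if (init_func_) {
          return init_func_();
        }
        return Status{};
      }

     private:
      std::function<Status()> init_func_;
    };

    int main() {
      Iterator it([] { return Status{"captured function failed to instantiate"}; });
      Status s = it.Initialize();
      std::cout << (s.ok() ? "ok" : s.message) << "\n";
    }
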
diff --git a/tensorflow/core/kernels/data/parse_example_dataset_op.cc b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
new file mode 100644
index 0000000000..1d1a717062
--- /dev/null
+++ b/tensorflow/core/kernels/data/parse_example_dataset_op.cc
@@ -0,0 +1,369 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <deque>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
+#include "tensorflow/core/kernels/data/parallel_map_iterator.h"
+#include "tensorflow/core/util/example_proto_fast_parsing.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
+ public:
+ explicit ParseExampleDatasetOp(OpKernelConstruction* ctx)
+ : UnaryDatasetOpKernel(ctx),
+ graph_def_version_(ctx->graph_def_version()) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_keys", &sparse_keys_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_keys", &dense_keys_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("sparse_types", &sparse_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("Tdense", &dense_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dense_shapes", &dense_shapes_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_shapes", &output_shapes_));
+ for (int i = 0; i < dense_shapes_.size(); ++i) {
+ bool shape_ok = true;
+ if (dense_shapes_[i].dims() == -1) {
+ shape_ok = false;
+ } else {
+ for (int d = 1; d < dense_shapes_[i].dims(); ++d) {
+ if (dense_shapes_[i].dim_size(d) == -1) {
+ shape_ok = false;
+ }
+ }
+ }
+ OP_REQUIRES(ctx, shape_ok,
+ errors::InvalidArgument(
+ "dense_shapes[", i,
+ "] has unknown rank or unknown inner dimensions: ",
+ dense_shapes_[i].DebugString()));
+ TensorShape dense_shape;
+ if (dense_shapes_[i].dims() > 0 && dense_shapes_[i].dim_size(0) == -1) {
+ variable_length_.push_back(true);
+ for (int d = 1; d < dense_shapes_[i].dims(); ++d) {
+ dense_shape.AddDim(dense_shapes_[i].dim_size(d));
+ }
+ } else {
+ variable_length_.push_back(false);
+ dense_shapes_[i].AsTensorShape(&dense_shape);
+ }
+ elements_per_stride_.push_back(dense_shape.num_elements());
+ }
+ }
+
+ protected:
+ void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
+ DatasetBase** output) override {
+ int64 num_parallel_calls;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
+ &num_parallel_calls));
+ OP_REQUIRES(ctx, num_parallel_calls > 0,
+ errors::InvalidArgument(
+ "num_parallel_calls must be greater than zero."));
+
+ OpInputList dense_default_tensors;
+ OP_REQUIRES_OK(ctx,
+ ctx->input_list("dense_defaults", &dense_default_tensors));
+
+ OP_REQUIRES(ctx, dense_default_tensors.size() == dense_keys_.size(),
+ errors::InvalidArgument(
+ "Expected len(dense_defaults) == len(dense_keys) but got: ",
+ dense_default_tensors.size(), " vs. ", dense_keys_.size()));
+
+ std::vector<Tensor> dense_defaults(dense_default_tensors.begin(),
+ dense_default_tensors.end());
+
+ for (int d = 0; d < dense_keys_.size(); ++d) {
+ const Tensor& def_value = dense_defaults[d];
+ if (variable_length_[d]) {
+ OP_REQUIRES(ctx, def_value.NumElements() == 1,
+ errors::InvalidArgument(
+ "dense_shape[", d, "] is a variable length shape: ",
+ dense_shapes_[d].DebugString(),
+ ", therefore "
+ "def_value[",
+ d,
+ "] must contain a single element ("
+ "the padding element). But its shape is: ",
+ def_value.shape().DebugString()));
+ } else if (def_value.NumElements() > 0) {
+ OP_REQUIRES(ctx, dense_shapes_[d].IsCompatibleWith(def_value.shape()),
+ errors::InvalidArgument(
+ "def_value[", d,
+ "].shape() == ", def_value.shape().DebugString(),
+ " is not compatible with dense_shapes_[", d,
+ "] == ", dense_shapes_[d].DebugString()));
+ }
+ OP_REQUIRES(ctx, def_value.dtype() == dense_types_[d],
+ errors::InvalidArgument(
+ "dense_defaults[", d, "].dtype() == ",
+ DataTypeString(def_value.dtype()), " != dense_types_[", d,
+ "] == ", DataTypeString(dense_types_[d])));
+ }
+
+ example::FastParseExampleConfig config;
+ std::map<string, int> key_to_output_index;
+ for (int d = 0; d < dense_keys_.size(); ++d) {
+ config.dense.push_back({dense_keys_[d], dense_types_[d], dense_shapes_[d],
+ dense_default_tensors[d], variable_length_[d],
+ elements_per_stride_[d]});
+ auto result = key_to_output_index.insert({dense_keys_[d], 0});
+ OP_REQUIRES(ctx, result.second,
+ errors::InvalidArgument("Duplicate key not allowed: ",
+ dense_keys_[d]));
+ }
+ for (int d = 0; d < sparse_keys_.size(); ++d) {
+ config.sparse.push_back({sparse_keys_[d], sparse_types_[d]});
+ auto result = key_to_output_index.insert({sparse_keys_[d], 0});
+ OP_REQUIRES(ctx, result.second,
+ errors::InvalidArgument("Duplicate key not allowed: ",
+ sparse_keys_[d]));
+ }
+ int i = 0;
+ for (auto it = key_to_output_index.begin(); it != key_to_output_index.end();
+ it++) {
+ it->second = i++;
+ }
+
+ *output = new Dataset(ctx, input, std::move(dense_defaults),
+ std::move(sparse_keys_), std::move(dense_keys_),
+ std::move(key_to_output_index), std::move(config),
+ num_parallel_calls, sparse_types_, dense_types_,
+ dense_shapes_, output_types_, output_shapes_);
+ }
+
+ private:
+ class Dataset : public DatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, const DatasetBase* input,
+ std::vector<Tensor> dense_defaults, std::vector<string> sparse_keys,
+ std::vector<string> dense_keys,
+ std::map<string, int> key_to_output_index,
+ example::FastParseExampleConfig config, int32 num_parallel_calls,
+ const DataTypeVector& sparse_types,
+ const DataTypeVector& dense_types,
+ const std::vector<PartialTensorShape>& dense_shapes,
+ const DataTypeVector& output_types,
+ const std::vector<PartialTensorShape>& output_shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ dense_defaults_(std::move(dense_defaults)),
+ sparse_keys_(std::move(sparse_keys)),
+ dense_keys_(std::move(dense_keys)),
+ key_to_output_index_(std::move(key_to_output_index)),
+ config_(std::move(config)),
+ num_parallel_calls_(num_parallel_calls),
+ sparse_types_(sparse_types),
+ dense_types_(dense_types),
+ dense_shapes_(dense_shapes),
+ output_types_(output_types),
+ output_shapes_(output_shapes) {
+ input_->Ref();
+ }
+
+ ~Dataset() override { input_->Unref(); }
+
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override {
+ auto map_fn = [this](IteratorContext* ctx,
+ std::vector<Tensor> input_element,
+ std::vector<Tensor>* result, StatusCallback done) {
+ (*ctx->runner())([this, ctx, input_element, result, done]() {
+ thread::ThreadPool* device_threadpool =
+ ctx->lib()->device()->tensorflow_cpu_worker_threads()->workers;
+ std::vector<string> slice_vec;
+ for (Tensor t : input_element) {
+ auto serialized_t = t.flat<string>();
+ gtl::ArraySlice<string> slice(serialized_t.data(),
+ serialized_t.size());
+ for (auto it = slice.begin(); it != slice.end(); it++)
+ slice_vec.push_back(*it);
+ }
+          // Make a local copy of config_ so that it can be modified below
+          // (e.g. to enable feature-stats collection).
+          example::FastParseExampleConfig config = config_;
+ auto stats_aggregator = ctx->stats_aggregator();
+ if (stats_aggregator) {
+ config.collect_feature_stats = true;
+ }
+ example::Result example_result;
+ Status s = FastParseExample(config, slice_vec, {}, device_threadpool,
+ &example_result);
+ if (s.ok()) {
+ (*result).resize(key_to_output_index_.size());
+ for (int d = 0; d < dense_keys_.size(); ++d) {
+ int output_index = key_to_output_index_.at(dense_keys_[d]);
+ CHECK(example_result.dense_values[d].dtype() ==
+ output_dtypes()[output_index])
+ << "Got wrong type for FastParseExample return value " << d
+ << " (expected "
+ << DataTypeString(output_dtypes()[output_index]) << ", got "
+ << DataTypeString(example_result.dense_values[d].dtype())
+ << ").";
+ CHECK(output_shapes()[output_index].IsCompatibleWith(
+ example_result.dense_values[d].shape()))
+ << "Got wrong shape for FastParseExample return value " << d
+ << " (expected "
+ << output_shapes()[output_index].DebugString() << ", got "
+ << example_result.dense_values[d].shape().DebugString()
+ << ").";
+ (*result)[output_index] = example_result.dense_values[d];
+ }
+ for (int d = 0; d < sparse_keys_.size(); ++d) {
+ Tensor serialized_sparse = Tensor(DT_VARIANT, TensorShape({3}));
+ auto serialized_sparse_t = serialized_sparse.vec<Variant>();
+ serialized_sparse_t(0) = example_result.sparse_indices[d];
+ serialized_sparse_t(1) = example_result.sparse_values[d];
+ serialized_sparse_t(2) = example_result.sparse_shapes[d];
+ int output_index = key_to_output_index_.at(sparse_keys_[d]);
+ CHECK(serialized_sparse.dtype() == output_dtypes()[output_index])
+ << "Got wrong type for FastParseExample return value " << d
+ << " (expected "
+ << DataTypeString(output_dtypes()[output_index]) << ", got "
+ << DataTypeString(serialized_sparse.dtype()) << ").";
+ CHECK(output_shapes()[output_index].IsCompatibleWith(
+ serialized_sparse.shape()))
+ << "Got wrong shape for FastParseExample return value " << d
+ << " (expected "
+ << output_shapes()[output_index].DebugString() << ", got "
+ << serialized_sparse.shape().DebugString() << ").";
+ (*result)[output_index] = serialized_sparse;
+ }
+ // TODO(b/111553342): User provided tags instead of fixed tag.
+ if (stats_aggregator) {
+ stats_aggregator->IncrementCounter(
+ "examples_count", "trainer",
+ example_result.feature_stats.size());
+ for (example::PerExampleFeatureStats feature_stats :
+ example_result.feature_stats) {
+ stats_aggregator->AddToHistogram(
+ "features",
+ {static_cast<double>(feature_stats.features_count)});
+ stats_aggregator->IncrementCounter(
+ "features_count", "trainer", feature_stats.features_count);
+ stats_aggregator->IncrementCounter(
+ "feature_values_count", "trainer",
+ feature_stats.feature_values_count);
+ stats_aggregator->AddToHistogram(
+ "feature-values",
+ {static_cast<double>(feature_stats.feature_values_count)});
+ }
+ }
+ }
+ done(s);
+ });
+ };
+
+ return NewParallelMapIterator(
+ {this, strings::StrCat(prefix, "::ParseExample")}, input_,
+ std::move(map_fn), num_parallel_calls_);
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ return output_types_;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ return output_shapes_;
+ }
+
+ string DebugString() const override {
+ return "ParseExampleDatasetOp::Dataset";
+ }
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+
+    Node* num_parallel_calls_node;
+ std::vector<Node*> dense_defaults_nodes;
+ dense_defaults_nodes.reserve(dense_defaults_.size());
+
+ TF_RETURN_IF_ERROR(
+        b->AddScalar(num_parallel_calls_, &num_parallel_calls_node));
+
+ for (const Tensor& dense_default : dense_defaults_) {
+ Node* node;
+ TF_RETURN_IF_ERROR(b->AddTensor(dense_default, &node));
+ dense_defaults_nodes.emplace_back(node);
+ }
+
+ AttrValue sparse_keys_attr;
+ AttrValue dense_keys_attr;
+ AttrValue sparse_types_attr;
+ AttrValue dense_attr;
+ AttrValue dense_shapes_attr;
+
+ b->BuildAttrValue(sparse_keys_, &sparse_keys_attr);
+ b->BuildAttrValue(dense_keys_, &dense_keys_attr);
+ b->BuildAttrValue(sparse_types_, &sparse_types_attr);
+ b->BuildAttrValue(dense_types_, &dense_attr);
+ b->BuildAttrValue(dense_shapes_, &dense_shapes_attr);
+
+ TF_RETURN_IF_ERROR(b->AddDataset(this,
+ {
+                                         {0, input_graph_node},
+                                         {1, num_parallel_calls_node},
+ },
+ {{2, dense_defaults_nodes}},
+ {{"sparse_keys", sparse_keys_attr},
+ {"dense_keys", dense_keys_attr},
+ {"sparse_types", sparse_types_attr},
+ {"Tdense", dense_attr},
+ {"dense_shapes", dense_shapes_attr}},
+ output));
+ return Status::OK();
+ }
+
+ private:
+ const DatasetBase* const input_;
+ const std::vector<Tensor> dense_defaults_;
+ const std::vector<string> sparse_keys_;
+ const std::vector<string> dense_keys_;
+ const std::map<string, int> key_to_output_index_;
+ const example::FastParseExampleConfig config_;
+ const int64 num_parallel_calls_;
+ const DataTypeVector sparse_types_;
+ const DataTypeVector dense_types_;
+ const std::vector<PartialTensorShape> dense_shapes_;
+ const DataTypeVector output_types_;
+ const std::vector<PartialTensorShape> output_shapes_;
+ };
+
+ const int graph_def_version_;
+ DataTypeVector output_types_;
+ std::vector<PartialTensorShape> output_shapes_;
+ std::vector<string> sparse_keys_;
+ std::vector<string> dense_keys_;
+ DataTypeVector sparse_types_;
+ DataTypeVector dense_types_;
+ std::vector<PartialTensorShape> dense_shapes_;
+ std::vector<bool> variable_length_;
+ std::vector<std::size_t> elements_per_stride_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParseExampleDataset").Device(DEVICE_CPU),
+ ParseExampleDatasetOp);
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
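
The output ordering established above comes from inserting every dense and
sparse key into a std::map and numbering the keys in iteration order, so the
components produced by the dataset are ordered lexicographically by feature
key rather than by attr order. A standalone sketch of just that mapping, with
hypothetical feature keys, independent of the kernel:

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Keys as they might arrive via the op attrs; attr order is not significant.
  const std::vector<std::string> dense_keys = {"label", "age"};
  const std::vector<std::string> sparse_keys = {"terms"};

  std::map<std::string, int> key_to_output_index;
  for (const std::string& k : dense_keys) key_to_output_index.insert({k, 0});
  for (const std::string& k : sparse_keys) key_to_output_index.insert({k, 0});

  int i = 0;
  for (auto& kv : key_to_output_index) kv.second = i++;

  // Prints: age -> 0, label -> 1, terms -> 2 (lexicographic key order).
  for (const auto& kv : key_to_output_index) {
    std::cout << kv.first << " -> " << kv.second << "\n";
  }
  return 0;
}
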
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.cc b/tensorflow/core/kernels/data/prefetch_autotuner.cc
index b3272f6bcd..da357339c9 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/prefetch_autotuner.h"
namespace tensorflow {
+namespace data {
PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
: buffer_limit_(initial_buffer_size) {
@@ -25,6 +26,13 @@ PrefetchAutotuner::PrefetchAutotuner(int64 initial_buffer_size)
}
}
+namespace {
+// Determines what strategy to use for increasing the buffer size limit. For
+// limits less than the threshold, an exponential increase is used, while for
+// limits greater than or equal to the threshold, a linear increase is used.
+constexpr size_t kBufferLimitThreshold = 2048;
+} // namespace
+
void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
switch (mode_) {
case Mode::kDisabled:
@@ -36,11 +44,16 @@ void PrefetchAutotuner::RecordConsumption(size_t current_buffer_size) {
return;
case Mode::kDownswing:
if (current_buffer_size == 0) {
- buffer_limit_ *= 2; // Increase the buffer size.
+ if (buffer_limit_ >= kBufferLimitThreshold) {
+ buffer_limit_ += kBufferLimitThreshold;
+ } else {
+ buffer_limit_ *= 2;
+ }
mode_ = Mode::kUpswing;
}
return;
}
}
+} // namespace data
} // namespace tensorflow
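
With the threshold introduced above, the prefetch buffer limit grows
exponentially while it is below 2048 elements and then linearly, in steps of
2048, once it reaches the threshold. A standalone sketch of the resulting
growth schedule; it mirrors the branch added to RecordConsumption() rather
than exercising the class itself:

#include <cstddef>
#include <iostream>

int main() {
  constexpr std::size_t kThreshold = 2048;
  std::size_t limit = 1;
  // Simulate the buffer repeatedly draining while the autotuner is in the
  // downswing mode, which is when the limit is increased.
  for (int step = 0; step < 16; ++step) {
    if (limit >= kThreshold) {
      limit += kThreshold;  // Linear growth at or above the threshold.
    } else {
      limit *= 2;  // Exponential growth below the threshold.
    }
    std::cout << "step " << step << ": limit = " << limit << "\n";
  }
  // The tail of the output reads ... 2048, 4096, 6144, 8192, 10240, 12288
  // instead of doubling without bound.
  return 0;
}
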
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner.h b/tensorflow/core/kernels/data/prefetch_autotuner.h
index fa8a184072..8693205512 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner.h
+++ b/tensorflow/core/kernels/data/prefetch_autotuner.h
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
+namespace data {
// PrefetchAutotuner dynamically adjusts the buffer size of a prefetch iterator.
//
@@ -66,6 +67,7 @@ class PrefetchAutotuner {
Mode mode_ = Mode::kDisabled;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_AUTOTUNER_H_
diff --git a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
index 29a8cc50cd..cfc324fc7e 100644
--- a/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
+++ b/tensorflow/core/kernels/data/prefetch_autotuner_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
+namespace data {
namespace {
TEST(PrefetchAutotuner, Disabled) {
@@ -79,4 +80,5 @@ TEST(PrefetchAutotuner, EnabledSteady) {
}
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.cc b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
index 20148a4378..754ed772db 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.cc
@@ -12,23 +12,29 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include <deque>
-
#include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
+#include <deque>
+
#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
+namespace data {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
-class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
+class PrefetchDatasetOp::Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 buffer_size)
- : GraphDatasetBase(ctx), input_(input), buffer_size_(buffer_size) {
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ buffer_size_(buffer_size) {
input_->Ref();
}
@@ -68,7 +74,11 @@ class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- auto_tuner_(params.dataset->buffer_size_) {}
+ auto_tuner_(params.dataset->buffer_size_) {
+ std::vector<string> components =
+ str_util::Split(params.prefix, "::", str_util::SkipEmpty());
+ prefix_end_ = components.back();
+ }
~Iterator() override {
// Signal the prefetch thread to terminate it. We will then
@@ -93,6 +103,7 @@ class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
+ auto stats_aggregator = ctx->stats_aggregator();
{
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(EnsurePrefetchThreadStarted(ctx));
@@ -101,7 +112,9 @@ class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
while (!cancelled_ && buffer_.empty() && !prefetch_thread_finished_ &&
auto_tuner_.buffer_limit() != 0) {
auto_tuner_.RecordEmpty();
+ RecordStop(ctx);
cond_var_.wait(l);
+ RecordStart(ctx);
}
if (cancelled_) {
@@ -110,7 +123,7 @@ class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
}
if (!buffer_.empty()) {
- return Consume(out_tensors, end_of_sequence);
+ return Consume(out_tensors, end_of_sequence, stats_aggregator);
}
if (prefetch_thread_finished_) {
@@ -123,6 +136,14 @@ class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
mutex_lock parent_l(parent_mu_);
mutex_lock l(mu_);
+ if (stats_aggregator) {
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_size"),
+ static_cast<float>(buffer_.size()));
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_capacity"),
+ static_cast<float>(auto_tuner_.buffer_limit()));
+ }
return input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
}
@@ -198,14 +219,28 @@ class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
std::vector<Tensor> value;
};
- Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence)
+ Status Consume(std::vector<Tensor>* out_tensors, bool* end_of_sequence,
+ const std::shared_ptr<StatsAggregator>& stats_aggregator)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ if (stats_aggregator) {
+ stats_aggregator->AddToHistogram(
+ strings::StrCat(prefix_end_, "::buffer_utilization"),
+ {static_cast<float>(buffer_.size()) /
+ static_cast<float>(auto_tuner_.buffer_limit())});
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_size"),
+ static_cast<float>(buffer_.size()));
+ stats_aggregator->AddScalar(
+ strings::StrCat(prefix_end_, "::buffer_capacity"),
+ static_cast<float>(auto_tuner_.buffer_limit()));
+ }
// A new element is available. Forward the status from computing it, and
// (if we successfully got an element) the output values.
Status s = buffer_.front().status;
if (s.ok()) {
*out_tensors = std::move(buffer_.front().value);
}
+ auto_tuner_.RecordConsumption(buffer_.size());
buffer_.pop_front();
*end_of_sequence = false;
@@ -221,10 +256,10 @@ class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
Status EnsurePrefetchThreadStarted(IteratorContext* ctx)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (!prefetch_thread_) {
- prefetch_thread_.reset(
- ctx->env()->StartThread({}, "prefetch_thread",
- std::bind(&Iterator::PrefetchThread, this,
- new IteratorContext(*ctx))));
+ std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
+ prefetch_thread_.reset(ctx->env()->StartThread(
+ {}, "prefetch_thread",
+ [this, new_ctx]() { PrefetchThread(new_ctx); }));
}
return Status::OK();
}
@@ -233,8 +268,9 @@ class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
// buffer.
//
// It owns the iterator context passed to it.
- void PrefetchThread(IteratorContext* ctx) {
- std::unique_ptr<IteratorContext> cleanup(ctx);
+ void PrefetchThread(const std::shared_ptr<IteratorContext>& ctx) {
+ RecordStart(ctx.get());
+ auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
while (true) {
std::vector<Tensor> value;
@@ -242,7 +278,9 @@ class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
{
mutex_lock l(mu_);
while (!cancelled_ && buffer_.size() >= auto_tuner_.buffer_limit()) {
+ RecordStop(ctx.get());
cond_var_.wait(l);
+ RecordStart(ctx.get());
}
if (cancelled_) {
@@ -259,8 +297,8 @@ class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
mutex_lock parent_l(parent_mu_);
bool end_of_sequence;
BufferElement buffer_element;
- buffer_element.status =
- input_impl_->GetNext(ctx, &buffer_element.value, &end_of_sequence);
+ buffer_element.status = input_impl_->GetNext(
+ ctx.get(), &buffer_element.value, &end_of_sequence);
if (buffer_element.status.ok() && end_of_sequence) {
mutex_lock l(mu_);
prefetch_thread_finished_ = true;
@@ -322,6 +360,7 @@ class PrefetchDatasetOp::Dataset : public GraphDatasetBase {
mutex parent_mu_ ACQUIRED_BEFORE(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(parent_mu_);
condition_variable cond_var_;
+ string prefix_end_;
PrefetchAutotuner auto_tuner_ GUARDED_BY(mu_);
std::deque<BufferElement> buffer_ GUARDED_BY(mu_);
std::unique_ptr<Thread> prefetch_thread_ GUARDED_BY(mu_);
@@ -344,6 +383,7 @@ void PrefetchDatasetOp::MakeDataset(OpKernelContext* ctx, DatasetBase* input,
*output = new Dataset(ctx, input, buffer_size);
}
+namespace {
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset").Device(DEVICE_CPU),
PrefetchDatasetOp);
REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
@@ -352,4 +392,7 @@ REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")
.HostMemory("input_dataset")
.HostMemory("handle"),
PrefetchDatasetOp);
+} // namespace
+
+} // namespace data
} // namespace tensorflow
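
The prefix_end_ string computed in the iterator constructor is the last
"::"-separated component of the iterator prefix, and it scopes the stats
names emitted above (for example "<prefix_end>::buffer_utilization"). A
standalone sketch of that derivation, using a hand-rolled split in place of
the str_util helper and a made-up prefix string:

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Splits on "::" and drops empty pieces, mimicking the behavior relied on
// above; shown only for illustration.
std::vector<std::string> SplitOnDoubleColon(const std::string& s) {
  std::vector<std::string> parts;
  std::size_t start = 0;
  while (true) {
    const std::size_t pos = s.find("::", start);
    const std::string piece = s.substr(start, pos - start);
    if (!piece.empty()) parts.push_back(piece);
    if (pos == std::string::npos) break;
    start = pos + 2;
  }
  return parts;
}

int main() {
  const std::string prefix = "Iterator::Model::Prefetch";
  const std::string prefix_end = SplitOnDoubleColon(prefix).back();
  // Stats would then be published under names such as the one printed here.
  std::cout << prefix_end << "::buffer_utilization\n";
  return 0;
}
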
diff --git a/tensorflow/core/kernels/data/prefetch_dataset_op.h b/tensorflow/core/kernels/data/prefetch_dataset_op.h
index c40c4b00da..588fb25a06 100644
--- a/tensorflow/core/kernels/data/prefetch_dataset_op.h
+++ b/tensorflow/core/kernels/data/prefetch_dataset_op.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/prefetch_autotuner.h"
namespace tensorflow {
+namespace data {
class PrefetchDatasetOp : public UnaryDatasetOpKernel {
public:
@@ -34,6 +35,7 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
class Dataset;
};
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_PREFETCH_DATASET_OP_H_
diff --git a/tensorflow/core/kernels/data/random_dataset_op.cc b/tensorflow/core/kernels/data/random_dataset_op.cc
index 7e48428b3f..044a791a3f 100644
--- a/tensorflow/core/kernels/data/random_dataset_op.cc
+++ b/tensorflow/core/kernels/data/random_dataset_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random_distributions.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -49,10 +49,10 @@ class RandomDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 seed, int64 seed2)
- : GraphDatasetBase(ctx), seed_(seed), seed2_(seed2) {}
+ : DatasetBase(DatasetContext(ctx)), seed_(seed), seed2_(seed2) {}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
@@ -151,5 +151,5 @@ REGISTER_KERNEL_BUILDER(Name("RandomDataset").Device(DEVICE_CPU),
RandomDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/range_dataset_op.cc b/tensorflow/core/kernels/data/range_dataset_op.cc
index 50bd3dac4e..89fbaae369 100644
--- a/tensorflow/core/kernels/data/range_dataset_op.cc
+++ b/tensorflow/core/kernels/data/range_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -43,10 +43,13 @@ class RangeDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 start, int64 stop, int64 step)
- : GraphDatasetBase(ctx), start_(start), stop_(stop), step_(step) {}
+ : DatasetBase(DatasetContext(ctx)),
+ start_(start),
+ stop_(stop),
+ step_(step) {}
std::unique_ptr<IteratorBase> MakeIteratorInternal(
const string& prefix) const override {
@@ -139,5 +142,5 @@ REGISTER_KERNEL_BUILDER(Name("RangeDataset").Device(DEVICE_CPU),
RangeDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/reader_dataset_ops.cc b/tensorflow/core/kernels/data/reader_dataset_ops.cc
index 6a71a7af1d..c474cb4773 100644
--- a/tensorflow/core/kernels/data/reader_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/reader_dataset_ops.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/zlib_inputstream.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -78,12 +78,12 @@ class TextLineDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, std::vector<string> filenames,
const string& compression_type,
const io::ZlibCompressionOptions& options)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
filenames_(std::move(filenames)),
compression_type_(compression_type),
use_compression_(!compression_type.empty()),
@@ -312,12 +312,12 @@ class FixedLengthRecordDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames,
int64 header_bytes, int64 record_bytes, int64 footer_bytes,
int64 buffer_size)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
filenames_(std::move(filenames)),
header_bytes_(header_bytes),
record_bytes_(record_bytes),
@@ -531,11 +531,11 @@ class TFRecordDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, std::vector<string> filenames,
const string& compression_type, int64 buffer_size)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
filenames_(std::move(filenames)),
compression_type_(compression_type),
options_(io::RecordReaderOptions::CreateRecordReaderOptions(
@@ -691,5 +691,5 @@ REGISTER_KERNEL_BUILDER(Name("TFRecordDataset").Device(DEVICE_CPU),
TFRecordDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/repeat_dataset_op.cc b/tensorflow/core/kernels/data/repeat_dataset_op.cc
index 093ea563b4..94e96635ab 100644
--- a/tensorflow/core/kernels/data/repeat_dataset_op.cc
+++ b/tensorflow/core/kernels/data/repeat_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -39,10 +39,10 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
- : GraphDatasetBase(ctx), count_(count), input_(input) {
+ : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
input_->Ref();
}
@@ -172,32 +172,39 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
class ForeverIterator : public DatasetIterator<Dataset> {
public:
explicit ForeverIterator(const Params& params)
- : DatasetIterator<Dataset>(params), input_impl_(nullptr) {}
+ : DatasetIterator<Dataset>(params),
+ input_impl_(nullptr),
+ first_call_(true) {}
+
+ Status Initialize(IteratorContext* ctx) override {
+ mutex_lock l(mu_);
+ return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ }
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_); // TODO(mrry): Make locking less conservative.
do {
- bool first_call = false;
if (!input_impl_) {
- first_call = true;
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
}
- TF_RETURN_IF_ERROR(
- input_impl_->GetNext(ctx, out_tensors, end_of_sequence));
- if (!*end_of_sequence) {
+ Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
+ if (first_call_ && *end_of_sequence) {
+ // If the first call to GetNext() fails because the end
+ // of sequence has been reached, we terminate the
+ // iteration immediately. (Otherwise, this iterator
+ // would loop infinitely and never produce a value.)
+ input_impl_.reset();
return Status::OK();
+ }
+ first_call_ = false;
+ if (!*end_of_sequence) {
+ return s;
} else {
input_impl_.reset();
- if (first_call) {
- // If the first call to GetNext() fails because the end
- // of sequence has been reached, we terminate the
- // iteration immediately. (Otherwise, this iterator
- // would loop infinitely and never produce a value.)
- return Status::OK();
- }
+ first_call_ = true;
}
} while (true);
}
@@ -205,7 +212,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
- if (input_impl_)
+ if (!first_call_)
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
else
TF_RETURN_IF_ERROR(
@@ -218,10 +225,12 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
mutex_lock l(mu_);
if (reader->Contains(full_name("uninitialized"))) {
input_impl_.reset();
+ first_call_ = true;
} else {
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
+ first_call_ = false;
}
return Status::OK();
}
@@ -229,6 +238,7 @@ class RepeatDatasetOp : public UnaryDatasetOpKernel {
private:
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ bool first_call_ GUARDED_BY(mu_);
};
const int64 count_;
@@ -240,5 +250,5 @@ REGISTER_KERNEL_BUILDER(Name("RepeatDataset").Device(DEVICE_CPU),
RepeatDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
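
The first_call_ flag added above exists so that an infinite repeat over an
empty input terminates instead of spinning: only a GetNext() call that hits
end-of-sequence on the very first pull from a freshly created input iterator
is treated as "the input is empty". A standalone sketch of the same control
flow over a plain vector, for illustration only:

#include <cstddef>
#include <iostream>
#include <utility>
#include <vector>

// Stand-in for the ForeverIterator logic: repeat an input sequence forever,
// but report "no more values" immediately if the very first pull from a
// freshly (re)created input hits end-of-sequence.
class ForeverRepeater {
 public:
  explicit ForeverRepeater(std::vector<int> input)
      : input_(std::move(input)) {}

  // Returns false once it is known that no value will ever be produced.
  bool GetNext(int* out) {
    while (true) {
      if (pos_ < input_.size()) {
        *out = input_[pos_++];
        first_call_ = false;
        return true;
      }
      // End of one pass: if this was the first pull, the input is empty and
      // repeating it forever would never produce a value, so stop here.
      if (first_call_) return false;
      pos_ = 0;            // Otherwise rewind (recreate the input iterator)
      first_call_ = true;  // and treat the next pull as a fresh first call.
    }
  }

 private:
  std::vector<int> input_;
  std::size_t pos_ = 0;
  bool first_call_ = true;
};

int main() {
  int v = 0;
  ForeverRepeater empty({});
  std::cout << std::boolalpha << empty.GetNext(&v) << "\n";  // Prints: false

  ForeverRepeater repeated({1, 2});
  for (int i = 0; i < 5 && repeated.GetNext(&v); ++i) std::cout << v << " ";
  std::cout << "\n";  // Prints: 1 2 1 2 1
  return 0;
}
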
diff --git a/tensorflow/core/kernels/data/scan_dataset_op.cc b/tensorflow/core/kernels/data/scan_dataset_op.cc
index 7c59874d96..2a911aa368 100644
--- a/tensorflow/core/kernels/data/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/data/scan_dataset_op.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -32,8 +32,7 @@ namespace {
class ScanDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ScanDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx),
- graph_def_version_(ctx->graph_def_version()) {
+ : UnaryDatasetOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("Tstate", &state_types_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
@@ -45,23 +44,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
OpInputList initial_state_inputs;
OP_REQUIRES_OK(ctx,
ctx->input_list("initial_state", &initial_state_inputs));
- std::vector<Tensor> initial_state;
- initial_state.reserve(initial_state_inputs.size());
- for (const Tensor& t : initial_state_inputs) {
- initial_state.push_back(t);
- }
-
- OpInputList inputs;
- OP_REQUIRES_OK(ctx, ctx->input_list("other_arguments", &inputs));
- std::vector<Tensor> other_arguments;
- other_arguments.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- other_arguments.push_back(t);
- }
+ std::vector<Tensor> initial_state(initial_state_inputs.begin(),
+ initial_state_inputs.end());
std::unique_ptr<CapturedFunction> captured_func;
- OP_REQUIRES_OK(ctx, CapturedFunction::Create(
- func_, std::move(other_arguments), &captured_func));
+ OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
+ &captured_func));
*output = new Dataset(ctx, input, func_, std::move(initial_state),
std::move(captured_func), state_types_, output_types_,
@@ -69,7 +57,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const DatasetBase* input,
const NameAttrList& func, std::vector<Tensor> initial_state,
@@ -77,7 +65,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
const DataTypeVector& state_types,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
func_(func),
initial_state_(std::move(initial_state)),
@@ -109,7 +97,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,
Node** output) const override {
- TF_RETURN_IF_ERROR(b->AddFunction(ctx->flib_def(), func_.name()));
+ TF_RETURN_IF_ERROR(b->AddFunction(ctx, func_.name()));
Node* input_node;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
std::vector<Node*> initial_state_nodes;
@@ -153,7 +141,9 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
state_(params.dataset->initial_state_) {}
Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
+ TF_RETURN_IF_ERROR(
+ dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
+ return dataset()->captured_func_->Instantiate(ctx);
}
Status GetNextInternal(IteratorContext* ctx,
@@ -267,7 +257,6 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
const std::vector<PartialTensorShape> output_shapes_;
};
- const int graph_def_version_;
DataTypeVector state_types_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
@@ -277,5 +266,5 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("ScanDataset").Device(DEVICE_CPU), ScanDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/shuffle_dataset_op.cc b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
index 603c3feb79..66466d6a36 100644
--- a/tensorflow/core/kernels/data/shuffle_dataset_op.cc
+++ b/tensorflow/core/kernels/data/shuffle_dataset_op.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
const int64 kLogIntervalMicros = 10 * 1000000; // 10 seconds.
@@ -40,11 +40,11 @@ class ShuffleDatasetOpBase : public UnaryDatasetOpKernel {
protected:
// Abstract base dataset that implements a shuffling iterator.
- class ShuffleDatasetBase : public GraphDatasetBase {
+ class ShuffleDatasetBase : public DatasetBase {
public:
ShuffleDatasetBase(OpKernelContext* ctx, const DatasetBase* input,
int64 buffer_size, int64 count)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
buffer_size_(buffer_size),
count_(count) {
@@ -620,5 +620,5 @@ REGISTER_KERNEL_BUILDER(Name("ShuffleAndRepeatDataset").Device(DEVICE_CPU),
ShuffleAndRepeatDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/single_threaded_executor.cc b/tensorflow/core/kernels/data/single_threaded_executor.cc
new file mode 100644
index 0000000000..5b084a16f0
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor.cc
@@ -0,0 +1,380 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/single_threaded_executor.h"
+
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/executor_factory.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+typedef gtl::InlinedVector<TensorValue, 4> TensorValueVec;
+typedef gtl::InlinedVector<DeviceContext*, 4> DeviceContextVec;
+typedef gtl::InlinedVector<AllocatorAttributes, 4> AllocatorAttributeVec;
+
+class SingleThreadedExecutorImpl : public Executor {
+ public:
+ explicit SingleThreadedExecutorImpl(const LocalExecutorParams& params)
+ : params_(params) {}
+
+ ~SingleThreadedExecutorImpl() override {
+ for (const KernelState& kernel_state : kernels_) {
+ params_.delete_kernel(kernel_state.kernel);
+ }
+ }
+
+ Status Initialize(const Graph& graph) {
+    // Topologically sort `graph` to get a sequence of OpKernels.
+ std::vector<Node*> ordered_nodes;
+ ordered_nodes.reserve(graph.num_nodes());
+ GetReversePostOrder(graph, &ordered_nodes);
+
+ if (ordered_nodes.size() != graph.num_nodes()) {
+ return errors::InvalidArgument("Graph had ", graph.num_nodes(),
+                                     " nodes but reverse post-order had ",
+ ordered_nodes.size());
+ }
+
+ kernels_.resize(ordered_nodes.size());
+
+ std::unordered_map<Node*, size_t> node_to_index_map;
+
+ // Create the kernel and input-related structures for each node in `graph`.
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+ Node* n = ordered_nodes[i];
+ node_to_index_map[n] = i;
+
+ for (DataType dt : n->output_types()) {
+ if (IsRefType(dt)) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support reference-typed "
+ "edges.");
+ }
+ }
+
+ if (n->IsControlFlow()) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support control flow.");
+ }
+ if (n->IsSend() || n->IsHostSend() || n->IsRecv() || n->IsHostRecv()) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support partitioned graphs.");
+ }
+ if (n->IsCollective()) {
+ return errors::Unimplemented(
+ "Single-threaded executor does not support collective ops.");
+ }
+
+ KernelState& kernel_state = kernels_[i];
+ TF_RETURN_IF_ERROR(params_.create_kernel(n->def(), &kernel_state.kernel));
+ kernel_state.num_inputs = n->num_inputs();
+ kernel_state.num_outputs = n->num_outputs();
+
+ if (i == 0) {
+ kernel_state.input_start_index = 0;
+ } else {
+ const KernelState& previous_kernel_state = kernels_[i - 1];
+ kernel_state.input_start_index =
+ previous_kernel_state.input_start_index +
+ previous_kernel_state.num_inputs;
+ }
+ }
+
+ // Build the mapping from each node output to the input slot for the
+ // corresponding destination node.
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+ Node* n = ordered_nodes[i];
+ KernelState& kernel_state = kernels_[i];
+ kernel_state.output_locations.resize(kernel_state.num_outputs);
+ for (const Edge* e : n->out_edges()) {
+ if (!e->IsControlEdge()) {
+ kernel_state.output_locations[e->src_output()].push_back(
+ kernels_[node_to_index_map[e->dst()]].input_start_index +
+ e->dst_input());
+ }
+ }
+
+ // Compute allocator attributes for each node output, and corresponding
+ // node input.
+ kernel_state.output_alloc_attrs.resize(kernel_state.num_outputs);
+ AllocatorAttributes* attrs = kernel_state.output_alloc_attrs.data();
+
+ OpKernel* op_kernel = kernel_state.kernel;
+ for (int out = 0; out < n->num_outputs(); out++) {
+ DCHECK_LT(out, op_kernel->output_memory_types().size());
+ bool on_host = op_kernel->output_memory_types()[out] == HOST_MEMORY;
+ if (on_host) {
+ AllocatorAttributes h;
+ h.set_on_host(on_host);
+ attrs[out].Merge(h);
+ }
+ }
+ }
+
+ if (!kernels_.empty()) {
+ const KernelState& last_kernel_state = kernels_.back();
+ total_num_inputs_ =
+ last_kernel_state.input_start_index + last_kernel_state.num_inputs;
+ input_alloc_attrs_.resize(total_num_inputs_);
+ for (size_t i = 0; i < ordered_nodes.size(); ++i) {
+ for (size_t j = 0; j < kernels_[i].output_locations.size(); ++j) {
+ for (size_t output_location : kernels_[i].output_locations[j]) {
+ input_alloc_attrs_[output_location] =
+ kernels_[i].output_alloc_attrs[j];
+ }
+ }
+ }
+ } else {
+ total_num_inputs_ = 0;
+ }
+ return Status::OK();
+ }
+
+ // TODO(mrry): Consider specializing the implementation of Executor::Run()
+ // instead, to avoid unnecessary atomic operations in the callback when
+ // running synchronously.
+ void RunAsync(const Args& args, DoneCallback done) override {
+ // The inputs to each kernel are stored contiguously in `inputs`.
+ //
+ // We use `kernels_[i].input_start_index` and `kernels_[i].num_inputs` to
+ // determine the range of elements in this vector that correspond to
+ // the inputs of `kernels_[i]`.
+ //
+ // This vector has the following layout:
+ //
+ // * Kernel 0, input 0.
+ // * Kernel 0, input 1.
+ // * ...
+ // * Kernel 0, input `kernels_[0].num_inputs - 1`.
+ // * Kernel 1, input 0.
+ // * ...
+ // * Kernel 1, input `kernels_[1].num_inputs - 1`.
+ // * ...
+ // * Kernel `kernels_.size() - 1`, input 0.
+ // * ...
+ // * Kernel `kernels_.size() - 1`, input `kernels_.back().num_inputs - 1`.
+ //
+ // Note that kernels with zero inputs do not correspond to any elements in
+ // this vector.
+ //
+ // We use `ManualConstructor<Tensor>` to avoid the overhead of
+ // default-constructing an invalid `Tensor` for each slot at the beginning
+ // of execution:
+ // * Elements are initialized when the outputs of a kernel execution are
+ // propagated to the inputs of kernels that depend on them.
+ // * The elements corresponding to the inputs for kernel `i` are destroyed
+ // after kernel `i` executes.
+ // * In an error case (see below), we use the connectivity information in
+ // `KernelState::output_locations` to determine which locations have been
+ // initialized, and manually destroy them.
+ std::vector<ManualConstructor<Tensor>> inputs(total_num_inputs_);
+
+ // TODO(mrry): Can we avoid copying into these vectors? Consider modifying
+ // OpKernelContext to take the TensorValueVec as a pointer into `inputs`.
+ TensorValueVec node_inputs;
+ DeviceContextVec input_device_contexts;
+ AllocatorAttributeVec input_alloc_attrs;
+
+ // Prepare the parameters that will be the same for all kernels.
+ OpKernelContext::Params params;
+ params.step_id = args.step_id;
+ Device* device = params_.device;
+ params.device = device;
+ params.log_memory = false; // TODO(mrry): Too severe?
+ params.record_tensor_accesses = false; // TODO(mrry): Too severe?
+ params.rendezvous = args.rendezvous;
+ params.session_state = args.session_state;
+ params.tensor_store = args.tensor_store;
+ params.cancellation_manager = args.cancellation_manager;
+ // TODO(mrry): ArgOp is a relatively expensive OpKernel due to the Tensor
+ // allocations that it performs. Consider specializing its handling in the
+ // executor.
+ params.call_frame = args.call_frame;
+ params.function_library = params_.function_library;
+ params.resource_manager = device->resource_manager();
+ params.step_container = args.step_container;
+ params.slice_reader_cache = nullptr; // TODO(mrry): Too severe?
+ params.inputs = &node_inputs;
+ params.input_device_contexts = &input_device_contexts;
+ params.input_alloc_attrs = &input_alloc_attrs;
+
+ Args::Runner runner_copy = args.runner;
+ params.runner = &runner_copy;
+ params.stats_collector = args.stats_collector;
+
+ // NOTE(mrry): We are assuming that the graph is loopless and condless.
+ params.frame_iter = FrameAndIter(0, 0);
+ params.is_input_dead = false;
+
+ // TODO(mrry): Add non-default device context inference.
+ params.op_device_context = nullptr;
+ // TODO(mrry): Consider implementing forwarding.
+ params.forward_from_array = nullptr;
+
+ // Execute the kernels one-at-a-time in topological order.
+ for (size_t i = 0; i < kernels_.size(); ++i) {
+ const KernelState& kernel_state = kernels_[i];
+
+ // Prepare the per-kernel parameters.
+ const size_t input_start_index = kernel_state.input_start_index;
+ const size_t num_inputs = kernel_state.num_inputs;
+ const size_t num_outputs = kernel_state.num_outputs;
+
+ node_inputs.clear();
+ node_inputs.resize(num_inputs);
+ input_alloc_attrs.clear();
+ input_alloc_attrs.resize(num_inputs);
+ for (size_t j = 0; j < num_inputs; ++j) {
+ auto t = inputs[input_start_index + j].get();
+ node_inputs[j].tensor = t;
+ input_alloc_attrs[j] = input_alloc_attrs_[input_start_index + j];
+ }
+ params.op_kernel = kernel_state.kernel;
+ input_device_contexts.clear();
+ input_device_contexts.resize(num_inputs);
+ params.output_attr_array = kernel_state.output_alloc_attrs.data();
+ OpKernelContext ctx(&params, num_outputs);
+
+ // Actually execute the kernel.
+ device->Compute(kernel_state.kernel, &ctx);
+
+ if (!ctx.status().ok()) {
+ // On failure, we must manually free all intermediate tensors. We have
+ // already freed all the inputs for kernels up to (but not including)
+ // the `i`th kernel. We scan through the previously executed kernels and
+ // destroy any tensors that were destined to be the input for a kernel
+ // that has not yet executed.
+ for (size_t j = 0; j < i; ++j) {
+ const KernelState& executed_kernel_state = kernels_[j];
+ for (size_t k = 0; k < executed_kernel_state.num_outputs; ++k) {
+ for (size_t output_location :
+ executed_kernel_state.output_locations[k]) {
+ if (output_location >= input_start_index) {
+ // Only destroy an output location if it is an input to an
+ // operation that has not yet executed.
+ inputs[output_location].Destroy();
+ }
+ }
+ }
+ }
+ done(ctx.status());
+ return;
+ }
+
+ // Free the inputs to the current kernel.
+ for (size_t j = 0; j < num_inputs; ++j) {
+ inputs[input_start_index + j].Destroy();
+ }
+
+ // Forward the outputs of the kernel to the inputs of subsequent kernels.
+ for (size_t j = 0; j < num_outputs; ++j) {
+ TensorValue val = ctx.release_output(j);
+ // TODO(mrry): Consider flattening the `output_locations` vector
+ // to improve the cache-friendliness of this loop.
+ for (size_t output_location : kernel_state.output_locations[j]) {
+ // TODO(mrry): Validate that the types match the expected values or
+ // ensure that the necessary validation has already happened.
+ inputs[output_location].Init(*val.tensor);
+ }
+ delete val.tensor;
+ }
+ }
+ done(Status::OK());
+ }
+
+ private:
+ const LocalExecutorParams params_;
+
+ // All following members are read-only after Initialize().
+
+ // The sum of the number of inputs for each node in the graph. This determines
+ // the length of the flat `inputs` vector. See comment at the beginning of
+ // `RunAsync()` for details.
+ size_t total_num_inputs_;
+
+ // Represents cached graph structure state for each kernel.
+ struct KernelState {
+ // The kernel object. Not owned.
+ //
+ // This pointer is managed by `params_.create_kernel()` and
+ // `params_.delete_kernel()`.
+ OpKernel* kernel;
+
+ // These fields determine the range of elements in `inputs` that corresponds
+ // to the inputs of `kernel`.
+ size_t input_start_index;
+ size_t num_inputs;
+
+ size_t num_outputs;
+
+ // For the `j`th output of `kernel`, `output_locations[j]` contains the
+ // locations in the flat `inputs` vector to which that output must be
+ // copied. See comment at the beginning of `RunAsync()` for details.
+ std::vector<std::vector<size_t>>
+ output_locations; // Length = `num_outputs`.
+
+ // Memory space information for each output of `kernel`.
+ std::vector<AllocatorAttributes>
+ output_alloc_attrs; // Length = `num_outputs`.
+ };
+ std::vector<KernelState> kernels_;
+
+ // Memory space information for each input. This information is stored in the
+ // same order as the flat `inputs` vector. See comment at the beginning of
+ // `RunAsync()` for details.
+ std::vector<AllocatorAttributes>
+ input_alloc_attrs_; // Length = `total_num_inputs_`.
+};
+
+class SingleThreadedExecutorRegistrar {
+ public:
+ SingleThreadedExecutorRegistrar() {
+ ExecutorFactory::Register("SINGLE_THREADED_EXECUTOR", new Factory());
+ }
+
+ private:
+ class Factory : public ExecutorFactory {
+ Status NewExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ std::unique_ptr<Executor>* out_executor) override {
+ Executor* ret;
+ TF_RETURN_IF_ERROR(
+ NewSingleThreadedExecutor(params, std::move(graph), &ret));
+ out_executor->reset(ret);
+ return Status::OK();
+ }
+ };
+};
+static SingleThreadedExecutorRegistrar registrar;
+
+} // namespace
+
+Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ Executor** executor) {
+ std::unique_ptr<SingleThreadedExecutorImpl> impl(
+ new SingleThreadedExecutorImpl(params));
+ TF_RETURN_IF_ERROR(impl->Initialize(*graph));
+ *executor = impl.release();
+ return Status::OK();
+}
+
+} // namespace data
+} // namespace tensorflow
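
The flat inputs layout described in the RunAsync() comment can be made
concrete with a tiny standalone example: for three kernels with 0, 2 and 1
inputs in topological order (say an Arg, an Add and a Retval), the computed
input_start_index values are 0, 0 and 2, and total_num_inputs_ is 3.

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
  // Number of inputs per kernel, in topological order.
  const std::vector<std::size_t> num_inputs = {0, 2, 1};

  std::vector<std::size_t> input_start_index(num_inputs.size());
  for (std::size_t i = 0; i < num_inputs.size(); ++i) {
    input_start_index[i] =
        i == 0 ? 0 : input_start_index[i - 1] + num_inputs[i - 1];
    std::cout << "kernel " << i << ": slots [" << input_start_index[i] << ", "
              << input_start_index[i] + num_inputs[i] << ")\n";
  }
  // total_num_inputs_ = last start index + last num_inputs = 2 + 1 = 3.
  std::cout << "total = " << input_start_index.back() + num_inputs.back()
            << "\n";
  return 0;
}
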
diff --git a/tensorflow/core/kernels/data/single_threaded_executor.h b/tensorflow/core/kernels/data/single_threaded_executor.h
new file mode 100644
index 0000000000..e934352a1d
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor.h
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
+
+#include "tensorflow/core/common_runtime/executor.h"
+
+namespace tensorflow {
+namespace data {
+
+// Creates a new `Executor` for executing `graph` synchronously on the caller
+// thread.
+//
+// NOTE(mrry): The returned executor is optimized to impose low overhead on
+// graphs that perform a small amount of work (e.g. <15us of work per graph on
+// present architectures). It eschews concurrency, because issuing work to
+// multiple threads can dominate the cost of executing small ops synchronously,
+// and because contention in the executor data structures can reduce throughput
+// (in terms of ops executed per unit time).
+//
+// However, the current implementation has the following limitations:
+//
+// 1. Reference-typed tensors are not supported and will not be supported in
+//    the future.
+// 2. Graphs with control flow (containing "Switch" and "Merge" nodes) are not
+// currently supported. The current plan is to extend support to "functional"
+// control flow after the TensorFlow APIs transition to building graphs in
+// that form (e.g. `tf.cond_v2()`).
+// 3. Partitioned graphs (containing "_Recv" nodes) are not currently supported.
+// The present implementation executes kernels one at a time in topological
+// order, and cannot currently distinguish between disconnected subgraphs
+// that are logically connected by subgraphs on a different device.
+// 4. Memory logging is not currently supported.
+// 5. Allocation forwarding is not currently supported.
+// 6. Non-default device contexts are not currently supported. In effect, this
+// limits the executor to CPU devices.
+// 7. Ops that rely on `OpKernelContext::slice_reader_cache()` being non-null
+// are not currently supported.
+//
+// The single-threaded executor is primarily suitable for executing simple
+// TensorFlow functions, such as one might find in a `tf.data` pipeline.
+Status NewSingleThreadedExecutor(const LocalExecutorParams& params,
+ std::unique_ptr<const Graph> graph,
+ Executor** executor);
+
+} // namespace data
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_DATA_SINGLE_THREADED_EXECUTOR_H_
diff --git a/tensorflow/core/kernels/data/single_threaded_executor_test.cc b/tensorflow/core/kernels/data/single_threaded_executor_test.cc
new file mode 100644
index 0000000000..6244e287bb
--- /dev/null
+++ b/tensorflow/core/kernels/data/single_threaded_executor_test.cc
@@ -0,0 +1,332 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/data/single_threaded_executor.h"
+
+#include <algorithm>
+
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/common_runtime/device_factory.h"
+#include "tensorflow/core/common_runtime/executor.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/common_runtime/process_util.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/rendezvous.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/graph/algorithm.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/tracing.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace data {
+namespace {
+
+class ExecutorTest : public ::testing::Test {
+ protected:
+ ExecutorTest()
+ : device_(DeviceFactory::NewDevice("CPU", {},
+ "/job:localhost/replica:0/task:0")) {}
+
+ ~ExecutorTest() override {
+ // There should always be exactly one Ref left on the Rendezvous
+ // when the test completes.
+ CHECK(rendez_->Unref());
+ delete exec_;
+ delete device_;
+ }
+
+  // Resets exec_ with a new executor built from the given graph.
+ void Create(std::unique_ptr<const Graph> graph) {
+ const int version = graph->versions().producer();
+ LocalExecutorParams params;
+ params.device = device_;
+ params.create_kernel = [this, version](const NodeDef& ndef,
+ OpKernel** kernel) {
+ return CreateNonCachedKernel(device_, nullptr, ndef, version, kernel);
+ };
+ params.delete_kernel = [](OpKernel* kernel) {
+ DeleteNonCachedKernel(kernel);
+ };
+ delete exec_;
+ TF_CHECK_OK(NewSingleThreadedExecutor(params, std::move(graph), &exec_));
+ runner_ = [](std::function<void()> fn) { fn(); };
+ rendez_ = NewLocalRendezvous();
+ }
+
+ Status Run(Rendezvous* rendez) {
+ Executor::Args args;
+ args.rendezvous = rendez;
+ args.runner = runner_;
+ return exec_->Run(args);
+ }
+
+ Status Run(CallFrameInterface* call_frame) {
+ Executor::Args args;
+ args.call_frame = call_frame;
+ args.runner = runner_;
+ return exec_->Run(args);
+ }
+
+ Device* device_ = nullptr;
+ Executor* exec_ = nullptr;
+ Executor::Args::Runner runner_;
+ Rendezvous* rendez_ = nullptr;
+};
+
+// A float val -> Tensor<float>
+Tensor V(const float val) {
+ Tensor tensor(DT_FLOAT, TensorShape({}));
+ tensor.scalar<float>()() = val;
+ return tensor;
+}
+
+// A int32 val -> Tensor<int32>
+Tensor VI(const int32 val) {
+ Tensor tensor(DT_INT32, TensorShape({}));
+ tensor.scalar<int32>()() = val;
+ return tensor;
+}
+
+// A bool val -> Tensor<bool>
+Tensor VB(const bool val) {
+ Tensor tensor(DT_BOOL, TensorShape({}));
+ tensor.scalar<bool>()() = val;
+ return tensor;
+}
+
+// A double val -> Tensor<double>
+Tensor VD(const double val) {
+ Tensor tensor(DT_DOUBLE, TensorShape({}));
+ tensor.scalar<double>()() = val;
+ return tensor;
+}
+
+// Tensor<float> -> a float val.
+float V(const Tensor& tensor) {
+ CHECK_EQ(tensor.dtype(), DT_FLOAT);
+ CHECK(TensorShapeUtils::IsScalar(tensor.shape()));
+ return tensor.scalar<float>()();
+}
+
+Rendezvous::ParsedKey Key(const string& sender, const uint64 incarnation,
+ const string& receiver, const string& name) {
+ Rendezvous::ParsedKey result;
+ TF_CHECK_OK(
+ Rendezvous::ParseKey(Rendezvous::CreateKey(sender, incarnation, receiver,
+ name, FrameAndIter(0, 0)),
+ &result));
+ return result;
+}
+
+TEST_F(ExecutorTest, SimpleAdd) {
+ // c = a + b
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ auto in0 = test::graph::Arg(g.get(), 0, DT_FLOAT);
+ auto in1 = test::graph::Arg(g.get(), 1, DT_FLOAT);
+ auto tmp = test::graph::Add(g.get(), in0, in1);
+ test::graph::Retval(g.get(), 0, tmp);
+ FixupSourceAndSinkEdges(g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({DT_FLOAT, DT_FLOAT}, {DT_FLOAT});
+ TF_ASSERT_OK(call_frame.SetArgs({V(1.0), V(1.0)}));
+ TF_ASSERT_OK(Run(&call_frame));
+ std::vector<Tensor> retvals;
+ TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+ EXPECT_EQ(2.0, V(retvals[0])); // out = 1.0 + 1.0 = 2.0
+}
+
+TEST_F(ExecutorTest, SelfAdd) {
+ // v0 <- a
+ // v1 = v0 + v0
+ // v2 = v1 + v1
+ // ... ...
+ // v10 = v9 + v9
+ //
+ // b <- v10
+ // All nodes are executed by one thread.
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ auto v = test::graph::Arg(g.get(), 0, DT_FLOAT);
+ const int N = 10;
+ for (int i = 1; i <= N; ++i) {
+ v = test::graph::Add(g.get(), v, v);
+ }
+ // out <- v10
+ test::graph::Retval(g.get(), 0, v);
+ FixupSourceAndSinkEdges(g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
+ // a = 1.0
+ TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
+ TF_ASSERT_OK(Run(&call_frame));
+ std::vector<Tensor> retvals;
+ TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+ EXPECT_EQ(1024.0, V(retvals[0])); // b=v10=2*v9=4*v8=...=1024*a=1024.0
+}
+
+// Builds a graph which adds N copies of one variable "in". I.e.,
+// a + a + a + ... + a
+// The returned graph is parenthesized randomly. I.e.,
+// a + ((a + a) + a)
+// (a + a) + (a + a)
+// ((a + a) + a) + a
+// are all possible outcomes.
+void BuildTree(int N, Graph* g) {
+ CHECK_GT(N, 1);
+ // A single input node "in".
+ auto in = test::graph::Arg(g, 0, DT_FLOAT);
+ std::vector<Node*> nodes;
+ int i = 0;
+ // Duplicate "in" N times. Each copies is named as l0, l1, l2, ....
+ for (; i < N; ++i) {
+ nodes.push_back(test::graph::Identity(g, in, 0));
+ }
+ random::PhiloxRandom philox(0, 17);
+ random::SimplePhilox rnd(&philox);
+ while (nodes.size() > 1) {
+ // Randomly pick two nodes from "nodes" and add them. The resulting node
+ // is named like n10, n11, ..., and is put back into "nodes".
+ int x = rnd.Uniform(nodes.size());
+ auto in0 = nodes[x];
+ nodes[x] = nodes.back();
+ nodes.resize(nodes.size() - 1);
+ x = rnd.Uniform(nodes.size());
+ auto in1 = nodes[x];
+ // node = in0 + in1.
+ nodes[x] = test::graph::Add(g, in0, in1);
+ }
+ // The final output node "out".
+ test::graph::Retval(g, 0, nodes.back());
+ FixupSourceAndSinkEdges(g);
+}
+
+TEST_F(ExecutorTest, RandomTree) {
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ BuildTree(4096, g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({DT_FLOAT}, {DT_FLOAT});
+ TF_ASSERT_OK(call_frame.SetArgs({V(1.0)}));
+ TF_ASSERT_OK(Run(&call_frame));
+ std::vector<Tensor> retvals;
+ TF_ASSERT_OK(call_frame.ConsumeRetvals(&retvals, false));
+ EXPECT_EQ(4096.0, V(retvals[0]));
+}
+
+TEST_F(ExecutorTest, OpError) {
+ std::unique_ptr<Graph> g(new Graph(OpRegistry::Global()));
+ auto zero = test::graph::Constant(g.get(), V(0.0));
+ auto inf = test::graph::Unary(g.get(), "Reciprocal", zero);
+ auto check = test::graph::CheckNumerics(g.get(), inf, "message");
+ auto two = test::graph::Constant(g.get(), V(2.0));
+ test::graph::Binary(g.get(), "Mul", check, two);
+ FixupSourceAndSinkEdges(g.get());
+ Create(std::move(g));
+ FunctionCallFrame call_frame({}, {});
+ // Fails because CheckNumerics detects the Inf produced by 1 / 0.
+ EXPECT_TRUE(errors::IsInvalidArgument(Run(&call_frame)));
+}
+
+static void BM_executor(int iters, int width, int depth) {
+#ifdef PLATFORM_GOOGLE
+ BenchmarkUseRealTime();
+#endif // PLATFORM_GOOGLE
+ Graph* g = new Graph(OpRegistry::Global());
+ random::PhiloxRandom philox(1729, 17);
+ random::SimplePhilox rand(&philox);
+ uint64 cur = 0;
+ uint32 r = 1 + rand.Rand32() % width;
+ std::vector<Node*> ready_nodes;
+ for (int i = 0; i < r; ++i) {
+ ready_nodes.push_back(test::graph::NoOp(g, {}));
+ ++cur;
+ }
+ for (int i = 0; i < depth; ++i) {
+ std::random_shuffle(ready_nodes.begin(), ready_nodes.end());
+ r = 1 + rand.Rand32() % (ready_nodes.size());
+ std::vector<Node*> control_inputs;
+ for (int j = 0; j < r; ++j) {
+ control_inputs.push_back(ready_nodes.back());
+ ready_nodes.pop_back();
+ }
+ Node* n = test::graph::NoOp(g, control_inputs);
+ ++cur;
+ r = 1 + rand.Rand32() % width;
+ for (int j = 0; j < r; ++j) {
+ ready_nodes.push_back(test::graph::NoOp(g, {n}));
+ ++cur;
+ }
+ }
+ FixupSourceAndSinkEdges(g);
+#ifdef PLATFORM_GOOGLE
+ SetBenchmarkLabel(strings::StrCat("Nodes = ", cur));
+ SetBenchmarkItemsProcessed(cur * static_cast<int64>(iters));
+#endif // PLATFORM_GOOGLE
+ test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
+ "SINGLE_THREADED_EXECUTOR")
+ .Run(iters);
+}
+
+// Tall skinny graphs
+BENCHMARK(BM_executor)->ArgPair(16, 1024);
+BENCHMARK(BM_executor)->ArgPair(32, 8192);
+
+// Short fat graphs
+BENCHMARK(BM_executor)->ArgPair(1024, 16);
+BENCHMARK(BM_executor)->ArgPair(8192, 32);
+
+// Tall fat graph
+BENCHMARK(BM_executor)->ArgPair(1024, 1024);
+
+// TODO(mrry): This benchmark currently crashes with a use-after-free, because
+// test::Benchmark::RunWithArgs() assumes that the executor will take ownership
+// of the given graph, *and* keep its nodes (`x`, `y` and `z`) alive for the
+// duration of the benchmark. Since the single threaded executor does not retain
+// a copy of the graph, this fails.
+//
+// TODO(mrry): Add support for Arg/Retval "function call convention" in
+// `test::Benchmark::RunWithArgs()`.
+#if 0
+#define ALICE "/job:j/replica:0/task:0/cpu:0"
+#define BOB "/job:j/replica:0/task:0/gpu:0"
+
+static void BM_FeedInputFetchOutput(int iters) {
+ Graph* g = new Graph(OpRegistry::Global());
+ // z = x + y: x and y are provided as benchmark inputs. z is the
+ // output of the benchmark. Conceptually, the caller is ALICE, the
+ // benchmark is BOB.
+ Node* x = test::graph::Recv(g, "x", "float", ALICE, 1, BOB);
+ Node* y = test::graph::Recv(g, "y", "float", ALICE, 1, BOB);
+ Node* sum = test::graph::Add(g, x, y);
+ Node* z = test::graph::Send(g, sum, "z", BOB, 1, ALICE);
+ FixupSourceAndSinkEdges(g);
+ Tensor val(DT_FLOAT, TensorShape({}));
+ val.scalar<float>()() = 3.14;
+ SetBenchmarkItemsProcessed(static_cast<int64>(iters));
+ test::Benchmark("cpu", g, nullptr, nullptr, nullptr,
+ "SINGLE_THREADED_EXECUTOR")
+ .RunWithArgs({{x, val}, {y, val}}, {z}, iters);
+}
+BENCHMARK(BM_FeedInputFetchOutput);
+#endif
+
+} // namespace
+} // namespace data
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/skip_dataset_op.cc b/tensorflow/core/kernels/data/skip_dataset_op.cc
index 61db6a0a54..b8c7fb15f4 100644
--- a/tensorflow/core/kernels/data/skip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/skip_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -38,10 +38,10 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
- : GraphDatasetBase(ctx), count_(count), input_(input) {
+ : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
input_->Ref();
}
@@ -187,5 +187,5 @@ class SkipDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("SkipDataset").Device(DEVICE_CPU), SkipDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/slide_dataset_op.cc b/tensorflow/core/kernels/data/slide_dataset_op.cc
index fd8c5ccd92..1e73cfc753 100644
--- a/tensorflow/core/kernels/data/slide_dataset_op.cc
+++ b/tensorflow/core/kernels/data/slide_dataset_op.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -63,11 +63,11 @@ class SlideDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 window_size, int64 window_shift,
int64 window_stride, const DatasetBase* input)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
window_size_(window_size),
window_shift_(window_shift),
window_stride_(window_stride),
@@ -293,5 +293,5 @@ REGISTER_KERNEL_BUILDER(Name("SlideDataset").Device(DEVICE_CPU),
SlideDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
index 9bb86e76a2..85b1e50695 100644
--- a/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/sparse_tensor_slice_dataset_op.cc
@@ -21,18 +21,18 @@ limitations under the License.
#include "tensorflow/core/util/sparse/sparse_tensor.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
template <typename T>
-class Dataset : public GraphDatasetBase {
+class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx,
const sparse::SparseTensor& sparse_tensor)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
sparse_tensor_(sparse_tensor),
dtypes_({DT_INT64, sparse_tensor.dtype(), DT_INT64}),
shapes_({{-1, sparse_tensor.dims() - 1},
@@ -274,5 +274,5 @@ TF_CALL_DATASET_TYPES(REGISTER_DATASET_KERNEL);
#undef REGISTER_DATASET_KERNEL
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sql/driver_manager.cc b/tensorflow/core/kernels/data/sql/driver_manager.cc
index ffabda1a8a..783d1e6cb2 100644
--- a/tensorflow/core/kernels/data/sql/driver_manager.cc
+++ b/tensorflow/core/kernels/data/sql/driver_manager.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/sql/sqlite_query_connection.h"
namespace tensorflow {
-
+namespace data {
namespace sql {
std::unique_ptr<QueryConnection> DriverManager::CreateQueryConnection(
@@ -30,5 +30,5 @@ std::unique_ptr<QueryConnection> DriverManager::CreateQueryConnection(
}
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sql/driver_manager.h b/tensorflow/core/kernels/data/sql/driver_manager.h
index a34691b5a2..c5428f396b 100644
--- a/tensorflow/core/kernels/data/sql/driver_manager.h
+++ b/tensorflow/core/kernels/data/sql/driver_manager.h
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/sql/query_connection.h"
namespace tensorflow {
-
+namespace data {
namespace sql {
// A factory class for creating `QueryConnection` instances.
@@ -35,7 +35,7 @@ class DriverManager {
};
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_DRIVER_MANAGER_H_
diff --git a/tensorflow/core/kernels/data/sql/query_connection.h b/tensorflow/core/kernels/data/sql/query_connection.h
index e9ffca202f..2fd229a9bf 100644
--- a/tensorflow/core/kernels/data/sql/query_connection.h
+++ b/tensorflow/core/kernels/data/sql/query_connection.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
namespace tensorflow {
+namespace data {
class IteratorContext;
@@ -63,7 +64,7 @@ class QueryConnection {
};
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_QUERY_CONNECTION_H_
diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
index 7cd07bd8ec..5108e83976 100644
--- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
+++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
-
+namespace data {
namespace sql {
SqliteQueryConnection::SqliteQueryConnection() {}
@@ -115,5 +115,5 @@ void SqliteQueryConnection::FillTensorWithResultSetEntry(
}
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
index 81b19530b7..175492c49d 100644
--- a/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
+++ b/tensorflow/core/kernels/data/sql/sqlite_query_connection.h
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
-
+namespace data {
namespace sql {
class SqliteQueryConnection : public QueryConnection {
@@ -50,7 +50,7 @@ class SqliteQueryConnection : public QueryConnection {
};
} // namespace sql
-
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_SQL_SQLITE_QUERY_CONNECTION_H_
diff --git a/tensorflow/core/kernels/data/sql_dataset_ops.cc b/tensorflow/core/kernels/data/sql_dataset_ops.cc
index 9b0190e3fc..6bbe459332 100644
--- a/tensorflow/core/kernels/data/sql_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/sql_dataset_ops.cc
@@ -24,8 +24,9 @@ limitations under the License.
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace tensorflow {
-
+namespace data {
namespace {
+
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following ops.
@@ -75,13 +76,13 @@ class SqlDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, const string& driver_name,
const string& data_source_name, const string& query,
const DataTypeVector& output_types,
const std::vector<PartialTensorShape>& output_shapes)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
driver_name_(driver_name),
data_source_name_(data_source_name),
query_(query),
@@ -211,5 +212,5 @@ class SqlDatasetOp : public DatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("SqlDataset").Device(DEVICE_CPU), SqlDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index 8465a1d2c0..c09a73fff1 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -12,15 +12,64 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <memory>
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/stats_aggregator.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
namespace {
+class StatsAggregatorWithTagAndPrefix : public StatsAggregator {
+ public:
+ StatsAggregatorWithTagAndPrefix(
+ std::shared_ptr<StatsAggregator> stats_aggregator, const string& tag,
+ const string& prefix)
+ : wrapped_(stats_aggregator), tag_(tag), prefix_(prefix) {}
+
+ void AddToHistogram(const string& name,
+ gtl::ArraySlice<double> values) override {
+ if (!tag_.empty()) {
+ wrapped_->AddToHistogram(strings::StrCat(tag_, "_", name), values);
+ } else {
+ wrapped_->AddToHistogram(name, values);
+ }
+ }
+
+ void AddScalar(const string& name, float value) override {
+ if (!tag_.empty()) {
+ wrapped_->AddScalar(strings::StrCat(tag_, "_", name), value);
+ } else {
+ wrapped_->AddScalar(name, value);
+ }
+ }
+
+ void EncodeToProto(Summary* out_summary) override {
+ wrapped_->EncodeToProto(out_summary);
+ }
+
+ void IncrementCounter(const string& name, const string& label,
+ int64 val) override {
+ if (!prefix_.empty()) {
+ wrapped_->IncrementCounter(strings::StrCat(prefix_, "/", name), label,
+ val);
+ } else {
+ wrapped_->IncrementCounter(strings::StrCat("/tensorflow/", name), label,
+ val);
+ }
+ }
+
+ private:
+ std::shared_ptr<StatsAggregator> wrapped_;
+ string tag_;
+ string prefix_;
+ TF_DISALLOW_COPY_AND_ASSIGN(StatsAggregatorWithTagAndPrefix);
+};
+
class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
public:
explicit SetStatsAggregatorDatasetOp(OpKernelConstruction* ctx)
@@ -32,18 +81,28 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 1),
&stats_aggregator_resource));
core::ScopedUnref unref_stats_aggregator(stats_aggregator_resource);
+ string tag;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
+ string prefix;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "counter_prefix", &prefix));
- *output = new Dataset(ctx, input, stats_aggregator_resource);
+ *output = new Dataset(ctx, input, ctx->input(1), stats_aggregator_resource,
+ tag, prefix);
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input,
- StatsAggregatorResource* stats_aggregator_resource)
- : GraphDatasetBase(ctx),
+ const Tensor& resource_handle,
+ StatsAggregatorResource* stats_aggregator_resource,
+ const string& tag, const string& prefix)
+ : DatasetBase(DatasetContext(ctx)),
input_(input),
- stats_aggregator_resource_(stats_aggregator_resource) {
+ resource_handle_(resource_handle),
+ stats_aggregator_resource_(stats_aggregator_resource),
+ tag_(tag),
+ prefix_(prefix) {
input_->Ref();
stats_aggregator_resource_->Ref();
}
@@ -70,6 +129,24 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
return "SetStatsAggregatorDatasetOp::Dataset";
}
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_graph_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
+ Node* resource_handle_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddTensor(resource_handle_, &resource_handle_node));
+ Node* tag_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node));
+ Node* prefix_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(prefix_, &prefix_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(
+ this, {input_graph_node, resource_handle_node, tag_node, prefix_node},
+ output));
+ return Status::OK();
+ }
+
private:
class Iterator : public DatasetIterator<Dataset> {
public:
@@ -89,9 +166,10 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
IteratorContext::Params params;
params.env = ctx->env();
params.runner = *(ctx->runner());
- params.stats_aggregator_getter = [stats_aggregator_resource]() {
- return stats_aggregator_resource->stats_aggregator();
- };
+ params.stats_aggregator = std::shared_ptr<StatsAggregator>(
+ new StatsAggregatorWithTagAndPrefix(
+ stats_aggregator_resource->stats_aggregator(), dataset()->tag_,
+ dataset()->prefix_));
params.lib = ctx->lib();
params.function_library = ctx->function_library();
params.allocator_getter = ctx->allocator_getter();
@@ -102,16 +180,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- return Status::OK();
+ return errors::Unimplemented(dataset()->DebugString(),
+ " does not support checkpointing");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- return Status::OK();
+ return errors::Unimplemented(dataset()->DebugString(),
+ " does not support checkpointing");
}
private:
@@ -120,11 +196,15 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
};
const DatasetBase* const input_;
+ const Tensor resource_handle_;
StatsAggregatorResource* stats_aggregator_resource_;
+ string tag_;
+ string prefix_;
};
};
REGISTER_KERNEL_BUILDER(Name("SetStatsAggregatorDataset").Device(DEVICE_CPU),
SetStatsAggregatorDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
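
For reference, the StatsAggregatorWithTagAndPrefix wrapper added above only changes how metric names are composed before delegating to the wrapped aggregator: histograms and scalars become "<tag>_<name>" when a tag is set, and counters become "<prefix>/<name>" when a prefix is set, otherwise "/tensorflow/<name>". A minimal standalone sketch of that naming scheme, using plain std::string and hypothetical tag/prefix values rather than TensorFlow code:

#include <iostream>
#include <string>

// Sketch of the name composition used by the wrapper above: histograms and
// scalars get "<tag>_<name>" when a tag is set; counters get "<prefix>/<name>"
// when a prefix is set, otherwise "/tensorflow/<name>".
std::string HistogramName(const std::string& tag, const std::string& name) {
  return tag.empty() ? name : tag + "_" + name;
}

std::string CounterName(const std::string& prefix, const std::string& name) {
  return prefix.empty() ? "/tensorflow/" + name : prefix + "/" + name;
}

int main() {
  // Hypothetical tag/prefix values, not taken from the diff.
  std::cout << HistogramName("dataset", "record_stats") << "\n";  // dataset_record_stats
  std::cout << CounterName("", "examples_count") << "\n";         // /tensorflow/examples_count
  std::cout << CounterName("regular", "examples_count") << "\n";  // regular/examples_count
}
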
diff --git a/tensorflow/core/kernels/data/stats_aggregator_ops.cc b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
index b133cfab54..2d51467616 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_ops.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_ops.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
+namespace data {
namespace {
static mutex* get_counters_map_lock() {
@@ -81,11 +82,12 @@ class StatsAggregatorImpl : public StatsAggregator {
auto counters_map = get_counters_map();
if (counters_map->find(name) == counters_map->end()) {
counters_map->emplace(
- name, monitoring::Counter<1>::New(
- /*streamz name*/ "/tensorflow/" + name,
- /*streamz description*/
- name + " generated or consumed by the component.",
- /*streamz label name*/ "component_descriptor"));
+ name,
+ monitoring::Counter<1>::New(
+ /*streamz name*/ name,
+ /*streamz description*/
+ strings::StrCat(name, " generated or consumed by the component."),
+ /*streamz label name*/ "component_descriptor"));
}
counters_map->at(name)->GetCell(label)->IncrementBy(val);
}
@@ -145,4 +147,5 @@ REGISTER_KERNEL_BUILDER(Name("StatsAggregatorSummary").Device(DEVICE_CPU),
StatsAggregatorSummaryOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/stats_dataset_ops.cc b/tensorflow/core/kernels/data/stats_dataset_ops.cc
index 85fed31773..e9e42f05a1 100644
--- a/tensorflow/core/kernels/data/stats_dataset_ops.cc
+++ b/tensorflow/core/kernels/data/stats_dataset_ops.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
+namespace data {
namespace {
// This op defines a `Dataset` that passes through its input elements and
@@ -49,10 +50,12 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
- : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) {
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ tag_(std::move(tag)) {
input_->Ref();
}
@@ -149,10 +152,12 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
- : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) {
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ tag_(std::move(tag)) {
input_->Ref();
}
@@ -238,204 +243,11 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
};
};
-class FeatureStatsDatasetOp : public UnaryDatasetOpKernel {
- public:
- explicit FeatureStatsDatasetOp(OpKernelConstruction* ctx)
- : UnaryDatasetOpKernel(ctx) {}
-
- void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
- DatasetBase** output) override {
- string tag;
- OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
- OP_REQUIRES(ctx, input->output_dtypes()[0] == DT_STRING,
- errors::InvalidArgument("FeatureStatsDataset only supports "
- "input with a single `tf.string` "
- "component."));
- *output = new Dataset(ctx, input, std::move(tag));
- }
-
- private:
- class Dataset : public GraphDatasetBase {
- public:
- explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
- : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) {
- input_->Ref();
- }
-
- ~Dataset() override { input_->Unref(); }
-
- std::unique_ptr<IteratorBase> MakeIteratorInternal(
- const string& prefix) const override {
- return std::unique_ptr<IteratorBase>(new Iterator(
- {this, strings::StrCat(prefix, "::FeatureStatsDataset")}));
- }
-
- const DataTypeVector& output_dtypes() const override {
- return input_->output_dtypes();
- }
- const std::vector<PartialTensorShape>& output_shapes() const override {
- return input_->output_shapes();
- }
-
- string DebugString() const override {
- return "FeatureStatsDatasetOp::Dataset";
- }
-
- protected:
- Status AsGraphDefInternal(SerializationContext* ctx,
- DatasetGraphDefBuilder* b,
- Node** output) const override {
- Node* input_node;
- TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_node));
- Node* tag_node;
- TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node));
- TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output));
- return Status::OK();
- }
-
- private:
- class Iterator : public DatasetIterator<Dataset> {
- public:
- explicit Iterator(const Params& params)
- : DatasetIterator<Dataset>(params) {}
-
- Status Initialize(IteratorContext* ctx) override {
- return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
- }
-
- Status GetNextInternal(IteratorContext* ctx,
- std::vector<Tensor>* out_tensors,
- bool* end_of_sequence) override {
- tf_shared_lock l(mu_);
- Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
- auto stats_aggregator = ctx->stats_aggregator();
- if (stats_aggregator && s.ok() && !*end_of_sequence) {
- for (const Tensor& t : *out_tensors) {
- auto record_t = t.flat<string>();
- Example example;
- // TODO(b/111553342): redundant parsing here, potential solutions
- // to improve performance is to a) have a potential
- // ParseExampleDataset and collect stats from there and b) make
- // changes to parse_example() where it returns stats as well.
- for (int i = 0; i < record_t.size(); ++i) {
- if (example.ParseFromString(record_t(i))) {
- stats_aggregator->IncrementCounter("examples_count", "trainer",
- 1);
- AddStatsFeatures(example, stats_aggregator);
- } else {
- SequenceExample sequence_example;
- if (sequence_example.ParseFromString(record_t(i))) {
- stats_aggregator->IncrementCounter("sequence_examples_count",
- "trainer", 1);
- AddStatsFeatures(sequence_example, stats_aggregator);
- }
- }
- }
- }
- }
- return s;
- }
-
- int AddStatsFeatureValues(const Feature& feature) {
- int feature_values_list_size = 0;
- switch (feature.kind_case()) {
- case Feature::kBytesList: {
- feature_values_list_size = feature.bytes_list().value().size();
- break;
- }
- case Feature::kFloatList: {
- feature_values_list_size = feature.float_list().value().size();
- break;
- }
- case Feature::kInt64List: {
- feature_values_list_size = feature.int64_list().value().size();
- break;
- }
- case Feature::KIND_NOT_SET:
- break;
- }
- return feature_values_list_size;
- }
-
- void AddStatsFeatures(
- const Example& example,
- const std::shared_ptr<StatsAggregator>& stats_aggregator) {
- stats_aggregator->AddToHistogram(
- strings::StrCat(dataset()->tag_, ":features"),
- {static_cast<double>(example.features().feature().size())});
-
- int feature_values_list_size_sum = 0;
- for (const auto& feature : example.features().feature()) {
- stats_aggregator->IncrementCounter("features_count", "trainer", 1);
- feature_values_list_size_sum += AddStatsFeatureValues(feature.second);
- }
- stats_aggregator->IncrementCounter("feature_values_count", "trainer",
- feature_values_list_size_sum);
- stats_aggregator->AddToHistogram(
- strings::StrCat(dataset()->tag_, ":feature-values"),
- {static_cast<double>(feature_values_list_size_sum)});
- }
-
- void AddStatsFeatures(
- const SequenceExample& example,
- const std::shared_ptr<StatsAggregator>& stats_aggregator) {
- stats_aggregator->AddToHistogram(
- strings::StrCat(dataset()->tag_, ":features"),
- {static_cast<double>(
- example.context().feature().size() +
- example.feature_lists().feature_list().size())});
-
- int feature_values_list_size_sum = 0;
- for (const auto& feature : example.context().feature()) {
- stats_aggregator->IncrementCounter("features_count", "trainer", 1);
- feature_values_list_size_sum += AddStatsFeatureValues(feature.second);
- }
-
- for (const auto& feature_list :
- example.feature_lists().feature_list()) {
- stats_aggregator->IncrementCounter("feature_lists_count", "trainer",
- 1);
- for (const auto& feature : feature_list.second.feature()) {
- feature_values_list_size_sum += AddStatsFeatureValues(feature);
- }
- }
- stats_aggregator->IncrementCounter("feature_values_count", "trainer",
- feature_values_list_size_sum);
- stats_aggregator->AddToHistogram(
- strings::StrCat(dataset()->tag_, ":feature-values"),
- {static_cast<double>(feature_values_list_size_sum)});
- }
-
- protected:
- Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- return Status::OK();
- }
-
- Status RestoreInternal(IteratorContext* ctx,
- IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- return Status::OK();
- }
-
- private:
- mutex mu_;
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
- };
-
- const DatasetBase* const input_;
- const string tag_;
- };
-};
-
-REGISTER_KERNEL_BUILDER(Name("FeatureStatsDataset").Device(DEVICE_CPU),
- FeatureStatsDatasetOp);
REGISTER_KERNEL_BUILDER(Name("LatencyStatsDataset").Device(DEVICE_CPU),
LatencyStatsDatasetOp);
REGISTER_KERNEL_BUILDER(Name("BytesProducedStatsDataset").Device(DEVICE_CPU),
BytesProducedStatsDatasetOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/take_dataset_op.cc b/tensorflow/core/kernels/data/take_dataset_op.cc
index d4a3c7a978..e5cdfdd732 100644
--- a/tensorflow/core/kernels/data/take_dataset_op.cc
+++ b/tensorflow/core/kernels/data/take_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -38,10 +38,10 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, int64 count, const DatasetBase* input)
- : GraphDatasetBase(ctx), count_(count), input_(input) {
+ : DatasetBase(DatasetContext(ctx)), count_(count), input_(input) {
input_->Ref();
}
@@ -174,5 +174,5 @@ class TakeDatasetOp : public UnaryDatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("TakeDataset").Device(DEVICE_CPU), TakeDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/tensor_dataset_op.cc b/tensorflow/core/kernels/data/tensor_dataset_op.cc
index ac2015c865..ca4ea25b89 100644
--- a/tensorflow/core/kernels/data/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_dataset_op.cc
@@ -14,10 +14,11 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -28,25 +29,19 @@ class TensorDatasetOp : public DatasetOpKernel {
explicit TensorDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- // Create a new TensorDatasetOp::Dataset, insert it in the step
- // container, and return it as the output.
OpInputList inputs;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
// TODO(mrry): Validate that the shapes of the "components" tensors match
// the "shapes" attr.;
- std::vector<Tensor> components;
- components.reserve(inputs.size());
- for (const Tensor& t : inputs) {
- components.push_back(t);
- }
+ std::vector<Tensor> components(inputs.begin(), inputs.end());
*output = new Dataset(ctx, std::move(components));
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
- : GraphDatasetBase(ctx), tensors_(std::move(tensors)) {
+ : DatasetBase(DatasetContext(ctx)), tensors_(std::move(tensors)) {
for (const Tensor& t : tensors_) {
dtypes_.push_back(t.dtype());
shapes_.emplace_back(t.shape().dim_sizes());
@@ -74,7 +69,13 @@ class TensorDatasetOp : public DatasetOpKernel {
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
- TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ std::vector<std::pair<string, Tensor>>* input_list = ctx->input_list();
+ if (input_list) {
+ TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
+ input_list->emplace_back(node->name(), t);
+ } else {
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ }
components.emplace_back(node);
}
AttrValue dtypes;
@@ -135,5 +136,5 @@ REGISTER_KERNEL_BUILDER(Name("TensorDataset").Device(DEVICE_CPU),
TensorDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
index ea472e2b79..2ed636a400 100644
--- a/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_queue_dataset_op.cc
@@ -24,7 +24,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
bool IsGreaterEqualToOrCompatibleWith(const PartialTensorShape& a,
@@ -61,14 +61,14 @@ std::vector<PartialTensorShape> PrependQueueShapeWithBatch(
class EnqueueInQueueDatasetOp;
-class PrependFromQueueAndPaddedBatchDataset : public GraphDatasetBase {
+class PrependFromQueueAndPaddedBatchDataset : public DatasetBase {
public:
PrependFromQueueAndPaddedBatchDataset(
OpKernelContext* ctx, const int64 batch_size, const DatasetBase* input,
const DataTypeVector& dtypes,
const std::vector<PartialTensorShape>& shapes,
std::vector<Tensor> padding_values)
- : GraphDatasetBase(ctx),
+ : DatasetBase(DatasetContext(ctx)),
batch_size_(batch_size),
input_(input),
dtypes_(dtypes),
@@ -648,5 +648,5 @@ REGISTER_KERNEL_BUILDER(Name("EnqueueInQueueDataset").Device(DEVICE_CPU),
EnqueueInQueueDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
index 8f18d38f83..7dc64b0a75 100644
--- a/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/data/tensor_slice_dataset_op.cc
@@ -14,11 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -30,8 +31,6 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
: DatasetOpKernel(ctx) {}
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
- // Create a new TensorDatasetOp::Dataset, insert it in the step
- // container, and return it as the output.
OpInputList inputs;
OP_REQUIRES_OK(ctx, ctx->input_list("components", &inputs));
std::vector<Tensor> components;
@@ -54,10 +53,10 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, std::vector<Tensor> tensors)
- : GraphDatasetBase(ctx), tensors_(std::move(tensors)) {
+ : DatasetBase(DatasetContext(ctx)), tensors_(std::move(tensors)) {
for (const Tensor& t : tensors_) {
dtypes_.push_back(t.dtype());
gtl::InlinedVector<int64, 4> partial_dim_sizes;
@@ -93,7 +92,13 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
- TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ std::vector<std::pair<string, Tensor>>* input_list = ctx->input_list();
+ if (input_list) {
+ TF_RETURN_IF_ERROR(b->AddPlaceholder(t, &node));
+ input_list->emplace_back(node->name(), t);
+ } else {
+ TF_RETURN_IF_ERROR(b->AddTensor(t, &node));
+ }
components.emplace_back(node);
}
AttrValue dtypes;
@@ -163,5 +168,5 @@ REGISTER_KERNEL_BUILDER(Name("TensorSliceDataset").Device(DEVICE_CPU),
TensorSliceDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/unbatch_dataset_op.cc b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
index 02c3c5315a..74908994b4 100644
--- a/tensorflow/core/kernels/data/unbatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/unbatch_dataset_op.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include "tensorflow/core/util/batch_util.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -35,17 +35,22 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx, DatasetBase* input)
- : GraphDatasetBase(ctx), input_(input) {
+ : DatasetBase(DatasetContext(ctx)), input_(input) {
input_->Ref();
for (const PartialTensorShape& shape : input->output_shapes()) {
- gtl::InlinedVector<int64, 4> partial_dim_sizes;
- for (int i = 1; i < shape.dims(); ++i) {
- partial_dim_sizes.push_back(shape.dim_size(i));
+ if (!shape.unknown_rank()) {
+ gtl::InlinedVector<int64, 4> partial_dim_sizes;
+ for (int i = 1; i < shape.dims(); ++i) {
+ partial_dim_sizes.push_back(shape.dim_size(i));
+ }
+ shapes_.emplace_back(std::move(partial_dim_sizes));
+ } else {
+ // If the input shape is unknown, the output shape will be unknown.
+ shapes_.emplace_back();
}
- shapes_.emplace_back(std::move(partial_dim_sizes));
}
}
@@ -204,5 +209,5 @@ REGISTER_KERNEL_BUILDER(Name("UnbatchDataset").Device(DEVICE_CPU),
UnbatchDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/window_dataset.cc b/tensorflow/core/kernels/data/window_dataset.cc
index 17551bccd9..2ad4711aab 100644
--- a/tensorflow/core/kernels/data/window_dataset.cc
+++ b/tensorflow/core/kernels/data/window_dataset.cc
@@ -13,17 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/window_dataset.h"
+#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
+namespace data {
namespace {
-// TODO(b/110981596): Support checkpointing.
class WindowDataset : public DatasetBase {
public:
WindowDataset(std::vector<std::vector<Tensor>> elements,
DataTypeVector output_types,
std::vector<PartialTensorShape> output_shapes)
- : elements_(std::move(elements)),
+ : DatasetBase(DatasetContext({"Window"})),
+ elements_(std::move(elements)),
output_types_(std::move(output_types)),
output_shapes_(std::move(output_shapes)) {}
@@ -41,6 +43,15 @@ class WindowDataset : public DatasetBase {
string DebugString() const override { return "WindowDataset"; }
+ protected:
+ // TODO(b/110981596): Support checkpointing.
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ return errors::Unimplemented("%s does not support serialization",
+ DebugString());
+ }
+
private:
class Iterator : public DatasetIterator<WindowDataset> {
public:
@@ -97,4 +108,5 @@ Status NewWindowDataset(std::vector<std::vector<Tensor>> elements,
return Status::OK();
}
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/window_dataset.h b/tensorflow/core/kernels/data/window_dataset.h
index 7bd31a0bc7..84cb3c7860 100644
--- a/tensorflow/core/kernels/data/window_dataset.h
+++ b/tensorflow/core/kernels/data/window_dataset.h
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
+namespace data {
// Creates a dataset representing an eagerly-collected window of elements.
//
@@ -43,6 +44,7 @@ Status NewWindowDataset(std::vector<std::vector<Tensor>> elements,
std::vector<PartialTensorShape> output_shapes,
DatasetBase** out_dataset);
+} // namespace data
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_DATA_WINDOW_DATASET_H_
diff --git a/tensorflow/core/kernels/data/window_dataset_op.cc b/tensorflow/core/kernels/data/window_dataset_op.cc
index f9fd5b5a83..ac44623ce2 100644
--- a/tensorflow/core/kernels/data/window_dataset_op.cc
+++ b/tensorflow/core/kernels/data/window_dataset_op.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/window_dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -33,20 +33,44 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
void MakeDataset(OpKernelContext* ctx, DatasetBase* input,
DatasetBase** output) override {
int64 window_size = 0;
- OP_REQUIRES_OK(
- ctx, ParseScalarArgument<int64>(ctx, "window_size", &window_size));
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "size", &window_size));
OP_REQUIRES(
ctx, window_size > 0,
errors::InvalidArgument("Window size must be greater than zero."));
- *output = new Dataset(ctx, window_size, input);
+ int64 window_shift = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "shift", &window_shift));
+ OP_REQUIRES(
+ ctx, window_shift > 0,
+ errors::InvalidArgument("Window shift must be greater than zero."));
+
+ int64 window_stride = 0;
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int64>(ctx, "stride", &window_stride));
+ OP_REQUIRES(
+ ctx, window_stride > 0,
+ errors::InvalidArgument("Window stride must be greater than zero."));
+
+ bool drop_remainder;
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<bool>(ctx, "drop_remainder", &drop_remainder));
+
+ *output = new Dataset(ctx, input, window_size, window_shift, window_stride,
+ drop_remainder);
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
- Dataset(OpKernelContext* ctx, int64 window_size, const DatasetBase* input)
- : GraphDatasetBase(ctx), window_size_(window_size), input_(input) {
+ Dataset(OpKernelContext* ctx, const DatasetBase* input, int64 window_size,
+ int64 window_shift, int64 window_stride, bool drop_remainder)
+ : DatasetBase(DatasetContext(ctx)),
+ input_(input),
+ window_size_(window_size),
+ window_shift_(window_shift),
+ window_stride_(window_stride),
+ drop_remainder_(drop_remainder) {
input_->Ref();
}
@@ -70,7 +94,8 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
}
string DebugString() const override {
- return strings::StrCat("WindowDatasetOp(", window_size_, ")::Dataset");
+ return strings::StrCat("WindowDatasetOp(", window_size_, window_shift_,
+ window_stride_, drop_remainder_, ")::Dataset");
}
protected:
@@ -79,10 +104,19 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
Node** output) const override {
Node* input_graph_node = nullptr;
TF_RETURN_IF_ERROR(b->AddInputDataset(ctx, input_, &input_graph_node));
- Node* window_size = nullptr;
- TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size));
+ Node* window_size_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_size_, &window_size_node));
+ Node* window_shift_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_shift_, &window_shift_node));
+ Node* window_stride_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(window_stride_, &window_stride_node));
+ Node* drop_remainder_node = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(drop_remainder_, &drop_remainder_node));
TF_RETURN_IF_ERROR(
- b->AddDataset(this, {input_graph_node, window_size}, output));
+ b->AddDataset(this,
+ {input_graph_node, window_size_node, window_shift_node,
+ window_stride_node, drop_remainder_node},
+ output));
return Status::OK();
}
@@ -99,37 +133,79 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- // Each row of `window_elements` is a tuple of tensors from the
- // input iterator.
+ const int64 window_size = dataset()->window_size_;
+ const int64 window_shift = dataset()->window_shift_;
+ const int64 window_stride = dataset()->window_stride_;
std::vector<std::vector<Tensor>> window_elements;
+ Status status = Status::OK();
{
mutex_lock l(mu_);
- if (!input_impl_) {
+ if (!input_impl_ && buffer_.empty()) {
*end_of_sequence = true;
return Status::OK();
}
- window_elements.reserve(dataset()->window_size_);
- *end_of_sequence = false;
- for (int i = 0; i < dataset()->window_size_ && !*end_of_sequence;
- ++i) {
- std::vector<Tensor> window_element_tuple;
- TF_RETURN_IF_ERROR(input_impl_->GetNext(ctx, &window_element_tuple,
- end_of_sequence));
- if (!*end_of_sequence) {
- window_elements.emplace_back(std::move(window_element_tuple));
- } else {
- input_impl_.reset();
+
+ // Add elements to the buffer.
+ size_t target_size = TargetBufferSize(window_size, window_stride);
+ if (input_impl_) {
+ *end_of_sequence = false;
+ for (size_t i = buffer_.size();
+ i < target_size && !*end_of_sequence; ++i) {
+ std::vector<Tensor> element;
+ Status status =
+ input_impl_->GetNext(ctx, &element, end_of_sequence);
+ if (!*end_of_sequence) {
+ buffer_.emplace_back(std::move(element), status);
+ } else {
+ input_impl_.reset();
+ }
+ }
+ }
+
+ // If there are not enough elements and `drop_remainder` is set, we do
+ // not wish to return a smaller window.
+ if (buffer_.empty() ||
+ (dataset()->drop_remainder_ && buffer_.size() < target_size)) {
+ DCHECK(*end_of_sequence);
+ return Status::OK();
+ }
+
+ int num_elements = 1 + (buffer_.size() - 1) / window_stride;
+ window_elements.reserve(num_elements);
+ for (size_t i = 0; i < num_elements; ++i) {
+ status.Update(buffer_[window_stride * i].status);
+ if (!status.ok()) {
+ break;
+ }
+ window_elements.emplace_back(buffer_[window_stride * i].result);
+ }
+
+ // Shift the window, discarding elements if necessary.
+ int buffer_size = buffer_.size();
+ if (window_shift >= buffer_size) {
+ for (size_t i = buffer_size; input_impl_ && i < window_shift; ++i) {
+ bool end_of_input;
+ std::vector<Tensor> element;
+ // Ignore non-error status of discarded elements.
+ input_impl_->GetNext(ctx, &element, &end_of_input).IgnoreError();
+ if (end_of_input) {
+ input_impl_.reset();
+ }
}
+ buffer_.clear();
+ } else {
+ buffer_.erase(buffer_.begin(), buffer_.begin() + window_shift);
}
}
- if (window_elements.empty()) {
- DCHECK(*end_of_sequence);
- return Status::OK();
+ if (!status.ok()) {
+ return status;
}
+ // Construct output tensors.
const size_t num_tuple_components = window_elements[0].size();
const int64 num_window_elements = window_elements.size();
+ *end_of_sequence = false;
for (size_t idx = 0; idx < num_tuple_components; ++idx) {
DatasetBase* window_dataset;
std::vector<std::vector<Tensor>> window_component_elements;
@@ -152,7 +228,6 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(StoreDatasetInVariantTensor(window_dataset,
&out_tensors->back()));
}
- *end_of_sequence = false;
return Status::OK();
}
@@ -165,6 +240,20 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
} else {
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
}
+ // Save buffer.
+ TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"),
+ buffer_.size()));
+ for (int64 i = 0; i < buffer_.size(); i++) {
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, buffer_[i].status));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(strings::StrCat("buffer[", i, "].size"),
+ buffer_[i].result.size()));
+ for (int64 j = 0; j < buffer_[i].result.size(); j++) {
+ TF_RETURN_IF_ERROR(
+ writer->WriteTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+ buffer_[i].result[j]));
+ }
+ }
return Status::OK();
}
@@ -176,22 +265,92 @@ class WindowDatasetOp : public UnaryDatasetOpKernel {
} else {
input_impl_.reset();
}
+ // Restore buffer.
+ int64 buffer_size;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(strings::StrCat("buffer_size"), &buffer_size));
+ buffer_.resize(buffer_size);
+ for (int64 i = 0; i < buffer_size; i++) {
+ int64 vector_size;
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &buffer_[i].status));
+ TF_RETURN_IF_ERROR(reader->ReadScalar(
+ strings::StrCat("buffer[", i, "].size"), &vector_size));
+ buffer_[i].result.resize(vector_size);
+ for (int64 j = 0; j < vector_size; j++) {
+ TF_RETURN_IF_ERROR(
+ reader->ReadTensor(strings::StrCat("buffer[", i, "][", j, "]"),
+ &buffer_[i].result[j]));
+ }
+ }
return Status::OK();
}
private:
+ struct InvocationResult {
+ InvocationResult() = default;
+ InvocationResult(std::vector<Tensor>&& result, const Status& status)
+ : result(std::move(result)), status(status) {}
+
+ std::vector<Tensor> result;
+ Status status;
+ };
+
+ Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(
+ CodeKey(index), static_cast<int64>(status.code())));
+ if (!status.ok()) {
+ TF_RETURN_IF_ERROR(writer->WriteScalar(ErrorMessageKey(index),
+ status.error_message()));
+ }
+ return Status::OK();
+ }
+
+ Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ int64 code_int;
+ TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
+ error::Code code = static_cast<error::Code>(code_int);
+
+ if (code != error::Code::OK) {
+ string error_message;
+ TF_RETURN_IF_ERROR(
+ reader->ReadScalar(ErrorMessageKey(index), &error_message));
+ *status = Status(code, error_message);
+ } else {
+ *status = Status::OK();
+ }
+ return Status::OK();
+ }
+
+ string CodeKey(size_t index) {
+ return full_name(strings::StrCat("buffer[", index, "].code"));
+ }
+
+ string ErrorMessageKey(size_t index) {
+ return full_name(strings::StrCat("buffer[", index, "].error_message"));
+ }
+
+ size_t TargetBufferSize(int64 window_size, int64 window_stride) {
+ return (window_size - 1) * window_stride + 1;
+ }
+
mutex mu_;
+ std::deque<InvocationResult> buffer_ GUARDED_BY(mu_);
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
- const int64 window_size_;
const DatasetBase* const input_;
+ const int64 window_size_;
+ const int64 window_shift_;
+ const int64 window_stride_;
+ const bool drop_remainder_;
};
};
REGISTER_KERNEL_BUILDER(Name("WindowDataset").Device(DEVICE_CPU),
WindowDatasetOp);
-
} // namespace
-
+} // namespace data
} // namespace tensorflow
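
The rewritten WindowDataset iterator above buffers (window_size - 1) * window_stride + 1 input elements, forms a window from every window_stride-th buffered element, and starts the next window window_shift input elements later. A minimal standalone sketch of that arithmetic, with hypothetical parameter values not taken from the diff:

#include <cstdint>
#include <iostream>

// Mirrors TargetBufferSize() in the iterator above.
int64_t TargetBufferSize(int64_t window_size, int64_t window_stride) {
  return (window_size - 1) * window_stride + 1;
}

int main() {
  // Hypothetical parameters, not taken from the diff.
  const int64_t window_size = 3, window_shift = 4, window_stride = 2;
  std::cout << "buffer " << TargetBufferSize(window_size, window_stride)
            << " elements\n";                 // buffer 5 elements
  std::cout << "window uses buffered indices:";
  for (int64_t i = 0; i < window_size; ++i) {
    std::cout << " " << window_stride * i;    // 0 2 4
  }
  std::cout << "\nnext window starts " << window_shift
            << " input elements later\n";
}
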
diff --git a/tensorflow/core/kernels/data/writer_ops.cc b/tensorflow/core/kernels/data/writer_ops.cc
index 80d9a5b867..3f76695bb1 100644
--- a/tensorflow/core/kernels/data/writer_ops.cc
+++ b/tensorflow/core/kernels/data/writer_ops.cc
@@ -22,7 +22,7 @@ limitations under the License.
#include "tensorflow/core/platform/file_system.h"
namespace tensorflow {
-
+namespace data {
namespace {
class ToTFRecordOp : public AsyncOpKernel {
@@ -70,20 +70,21 @@ class ToTFRecordOp : public AsyncOpKernel {
DatasetBase* dataset;
OP_REQUIRES_OK_ASYNC(
ctx, GetDatasetFromVariantTensor(ctx->input(0), &dataset), done);
- IteratorContext iter_ctx = dataset::MakeIteratorContext(ctx);
std::unique_ptr<IteratorBase> iterator;
OP_REQUIRES_OK_ASYNC(
ctx,
- dataset->MakeIterator(&iter_ctx, "ToTFRecordOpIterator", &iterator),
+ dataset->MakeIterator(IteratorContext(ctx), "ToTFRecordOpIterator",
+ &iterator),
done);
std::vector<Tensor> components;
components.reserve(dataset->output_dtypes().size());
bool end_of_sequence;
do {
- OP_REQUIRES_OK_ASYNC(
- ctx, iterator->GetNext(&iter_ctx, &components, &end_of_sequence),
- done);
+ OP_REQUIRES_OK_ASYNC(ctx,
+ iterator->GetNext(IteratorContext(ctx),
+ &components, &end_of_sequence),
+ done);
if (!end_of_sequence) {
OP_REQUIRES_OK_ASYNC(
@@ -103,4 +104,5 @@ REGISTER_KERNEL_BUILDER(Name("DatasetToTFRecord").Device(DEVICE_CPU),
ToTFRecordOp);
} // namespace
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/zip_dataset_op.cc b/tensorflow/core/kernels/data/zip_dataset_op.cc
index 63e9b99d4b..61a2078f46 100644
--- a/tensorflow/core/kernels/data/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/data/zip_dataset_op.cc
@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset.h"
namespace tensorflow {
-
+namespace data {
namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
@@ -38,11 +38,11 @@ class ZipDatasetOp : public DatasetOpKernel {
}
private:
- class Dataset : public GraphDatasetBase {
+ class Dataset : public DatasetBase {
public:
explicit Dataset(OpKernelContext* ctx,
const std::vector<DatasetBase*>& inputs)
- : GraphDatasetBase(ctx), inputs_(inputs) {
+ : DatasetBase(DatasetContext(ctx)), inputs_(inputs) {
for (const auto& input : inputs_) {
input->Ref();
for (DataType dt : input->output_dtypes()) {
@@ -175,5 +175,5 @@ class ZipDatasetOp : public DatasetOpKernel {
REGISTER_KERNEL_BUILDER(Name("ZipDataset").Device(DEVICE_CPU), ZipDatasetOp);
} // namespace
-
+} // namespace data
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data_format_ops.h b/tensorflow/core/kernels/data_format_ops.h
index 1ca144cb40..bc416fa78b 100644
--- a/tensorflow/core/kernels/data_format_ops.h
+++ b/tensorflow/core/kernels/data_format_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_DATA_FORMAT_OPS_H_
-#define TENSORFLOW_KERNELS_DATA_FORMAT_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_FORMAT_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_FORMAT_OPS_H_
// Functor definition for data format dim mapping ops, must be compilable
// by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -83,4 +83,4 @@ struct DataFormatVecPermute {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_DATA_FORMAT_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_FORMAT_OPS_H_
diff --git a/tensorflow/core/kernels/debug_ops.h b/tensorflow/core/kernels/debug_ops.h
index 53a23b1306..d705e82b0d 100644
--- a/tensorflow/core/kernels/debug_ops.h
+++ b/tensorflow/core/kernels/debug_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_DEBUG_OP_H_
-#define TENSORFLOW_KERNELS_DEBUG_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DEBUG_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_DEBUG_OPS_H_
#if GOOGLE_CUDA
#include "tensorflow/core/common_runtime/gpu/gpu_util.h"
@@ -177,8 +177,10 @@ class BaseDebugOp : public OpKernel {
// Publish a tensor to all debug URLs of the debug op.
// Log an error if the publishing failed.
- void PublishTensor(const Tensor& tensor) {
- if (!debug_urls_.empty()) {
+ Status PublishTensor(const Tensor& tensor) {
+ if (debug_urls_.empty()) {
+ return Status::OK();
+ } else {
Status status = DebugIO::PublishDebugTensor(*debug_watch_key_, tensor,
Env::Default()->NowMicros(),
debug_urls_, gated_grpc_);
@@ -189,6 +191,7 @@ class BaseDebugOp : public OpKernel {
<< str_util::Join(debug_urls_, ", ")
<< ", due to: " << status.error_message();
}
+ return status;
}
}
@@ -213,7 +216,7 @@ class DebugIdentityOp : public BaseDebugOp {
return;
}
- PublishTensor(context->input(0));
+ OP_REQUIRES_OK(context, PublishTensor(context->input(0)));
context->set_output(0, context->input(0));
}
};
@@ -252,7 +255,7 @@ class DebugNanCountOp : public BaseDebugOp {
TensorShape shape({1});
OP_REQUIRES_OK(context, context->allocate_output(0, shape, &output_tensor));
output_tensor->vec<int64>()(0) = nan_count;
- PublishTensor(*output_tensor);
+ OP_REQUIRES_OK(context, PublishTensor(*output_tensor));
}
};
@@ -377,7 +380,7 @@ class DebugNumericSummaryOp : public BaseDebugOp {
bool mute = mute_if_healthy_ && nan_count == 0 && negative_inf_count == 0 &&
positive_inf_count == 0;
if (!mute) {
- PublishTensor(*output_tensor);
+ OP_REQUIRES_OK(context, PublishTensor(*output_tensor));
}
}
@@ -389,4 +392,4 @@ class DebugNumericSummaryOp : public BaseDebugOp {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_DEBUG_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_DEBUG_OPS_H_
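The hunk above changes PublishTensor() from a void function that only logs failures into one that reports them through its Status return value, so call sites can fail the op with OP_REQUIRES_OK. A minimal self-contained sketch of that control-flow pattern, using stand-in Status and RETURN_IF_ERROR names rather than the real TensorFlow classes and macros:

#include <iostream>
#include <string>
#include <utility>

// Stand-in for a status type, for illustration only.
struct Status {
  static Status OK() { return Status{}; }
  static Status Error(std::string m) { return Status{std::move(m)}; }
  bool ok() const { return message.empty(); }
  std::string message;
};

// Bail out of the enclosing function on failure (analogue of OP_REQUIRES_OK).
#define RETURN_IF_ERROR(expr)   \
  do {                          \
    Status _s = (expr);         \
    if (!_s.ok()) return _s;    \
  } while (0)

// Analogue of the new PublishTensor(): log the failure, but also report it.
Status Publish(bool have_urls, bool backend_ok) {
  if (!have_urls) return Status::OK();
  Status status = backend_ok ? Status::OK() : Status::Error("publish failed");
  if (!status.ok()) std::cerr << "debug publish error: " << status.message << "\n";
  return status;
}

Status Compute() {
  RETURN_IF_ERROR(Publish(/*have_urls=*/true, /*backend_ok=*/false));
  return Status::OK();  // only reached when publishing succeeded
}

int main() { std::cout << Compute().message << "\n"; }

The effect of the change is the same as in the sketch: a failed publish now propagates to the caller instead of being swallowed after logging.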
diff --git a/tensorflow/core/kernels/decode_bmp_op.cc b/tensorflow/core/kernels/decode_bmp_op.cc
index b4dcf0a74b..ae451be7e2 100644
--- a/tensorflow/core/kernels/decode_bmp_op.cc
+++ b/tensorflow/core/kernels/decode_bmp_op.cc
@@ -91,8 +91,10 @@ class DecodeBmpOp : public OpKernel {
errors::InvalidArgument(
"Number of channels must be 1, 3 or 4, was ", channels_));
- OP_REQUIRES(context, width > 0 && header_size >= 0,
+ OP_REQUIRES(context, width > 0,
errors::InvalidArgument("Width must be positive"));
+ OP_REQUIRES(context, height != 0,
+ errors::InvalidArgument("Height must be nonzero"));
OP_REQUIRES(context, header_size >= 0,
errors::InvalidArgument("header size must be nonnegative"));
@@ -108,8 +110,7 @@ class DecodeBmpOp : public OpKernel {
const int32 abs_height = abs(height);
// there may be padding bytes when the width is not a multiple of 4 bytes
- // 8 * channels == bits per pixel
- const int row_size = (8 * channels_ * width + 31) / 32 * 4;
+ const int row_size = (channels_ * width + 3) / 4 * 4;
const int64 last_pixel_offset = static_cast<int64>(header_size) +
(abs_height - 1) * row_size +
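The new row_size expression rounds the per-row byte count up to the next multiple of 4, which matches the old bit-based formula whenever channels_ is a whole number of bytes (1, 3, or 4 here). A small standalone check of that equivalence, with an arbitrary 3-channel example:

#include <cstdio>

int main() {
  const int channels = 3;
  for (int width = 1; width <= 8; ++width) {
    const int old_row_size = (8 * channels * width + 31) / 32 * 4;  // bits -> 32-bit words -> bytes
    const int new_row_size = (channels * width + 3) / 4 * 4;        // bytes, rounded up to a multiple of 4
    std::printf("width=%d  old=%d  new=%d\n", width, old_row_size, new_row_size);
  }
  return 0;
}

For width = 5 and 3 channels, both formulas give 16 bytes per row: 15 payload bytes plus 1 byte of padding.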
diff --git a/tensorflow/core/kernels/decode_csv_op.cc b/tensorflow/core/kernels/decode_csv_op.cc
index 3eed847c16..6bfb5bd5bc 100644
--- a/tensorflow/core/kernels/decode_csv_op.cc
+++ b/tensorflow/core/kernels/decode_csv_op.cc
@@ -61,6 +61,9 @@ class DecodeCSVOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input_list("record_defaults", &record_defaults));
for (int i = 0; i < record_defaults.size(); ++i) {
+ OP_REQUIRES(ctx, record_defaults[i].dims() <= 1,
+ errors::InvalidArgument(
+ "Each record default should be at most rank 1"));
OP_REQUIRES(ctx, record_defaults[i].NumElements() < 2,
errors::InvalidArgument(
"There should only be 1 default per field but field ", i,
diff --git a/tensorflow/core/kernels/dense_update_functor.h b/tensorflow/core/kernels/dense_update_functor.h
index 240c13261e..61b5731250 100644
--- a/tensorflow/core/kernels/dense_update_functor.h
+++ b/tensorflow/core/kernels/dense_update_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DENSE_UPDATE_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_DENSE_UPDATE_FUNCTOR_H_
#define EIGEN_USE_THREADS
@@ -105,4 +105,4 @@ Status VariantCopyFn<GPUDevice>(OpKernelContext* context, const Tensor& from,
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_DENSE_UPDATE_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_DENSE_UPDATE_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 2a25459194..76afd6f18c 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -17,7 +17,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/util_ptx.cuh"
+#include "third_party/cub/util_ptx.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc
index 42fbf95cd3..28940e0849 100644
--- a/tensorflow/core/kernels/dequantize_op.cc
+++ b/tensorflow/core/kernels/dequantize_op.cc
@@ -96,8 +96,6 @@ class DequantizeOp : public OpKernel {
output);
}
} else if (mode_ == QUANTIZE_MODE_SCALED) {
- // TODO(pauldonnelly): Update QuantizeAndDequantizeV2 and
- // QuantizeAndDequantizeV3 to match this SCALED mode again.
const float scale_factor =
std::numeric_limits<T>::min() == 0
? (max_range / std::numeric_limits<T>::max())
diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
index 862a97723f..e7882acc80 100644
--- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
@@ -35,10 +35,10 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "external/cub_archive/cub/device/device_radix_sort.cuh"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/iterator/constant_input_iterator.cuh"
-#include "external/cub_archive/cub/thread/thread_operators.cuh"
+#include "third_party/cub/device/device_radix_sort.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/iterator/constant_input_iterator.cuh"
+#include "third_party/cub/thread/thread_operators.cuh"
#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
index b01db91720..fb2a4cc8ef 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -247,8 +247,8 @@ class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
data.shaped<T, 2>({indices_vec.dimension(0), slice_size});
if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
- T* merged_base = &merged_flat(0, 0);
- const T* data_base = &data_flat(0, 0);
+ T* merged_base = merged_flat.data();
+ const T* data_base = data_flat.data();
for (int i = 0; i < indices_vec.size(); i++) {
int32 index = internal::SubtleMustCopy(indices_vec(i));
OP_REQUIRES(
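merged_flat and data_flat are Eigen tensor maps, and &merged_flat(0, 0) indexes element (0, 0), which can be out of bounds when the first dimension is zero; .data() simply returns the underlying pointer without touching any element. A small Eigen-only illustration, using the same unsupported Eigen tensor module these kernels build on:

#include <unsupported/Eigen/CXX11/Tensor>
#include <cstdio>

int main() {
  Eigen::Tensor<float, 2> t(0, 16);        // zero rows: the tensor has no elements
  const float* base = t.data();            // fine: pointer to the (empty) storage
  // const float* bad = &t(0, 0);          // would index a nonexistent element
  std::printf("size=%ld base=%p\n", static_cast<long>(t.size()),
              static_cast<const void*>(base));
  return 0;
}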
diff --git a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
index e13e548f86..8edf7d4a2c 100644
--- a/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h
@@ -51,48 +51,55 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
internal::traits<OutputBackward>::NumDimensions>,
const TensorContractionOp<
const array<
- IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
- const TensorReshapingOp<
+ IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
- 3>,
- const TensorReverseOp<const array<bool, 5>, const Kernel> >,
+ 2>,
+ const TensorShufflingOp<
+ const array<
+ typename internal::traits<OutputBackward>::Index, 5>,
+ const TensorReverseOp<const Eigen::array<bool, 5>,
+ const Kernel>>>>,
const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
- 3>,
+ 2>,
const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> > > >,
+ const OutputBackward>>>>,
TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
internal::traits<OutputBackward>::NumDimensions>,
const TensorContractionOp<
const array<
- IndexPair<typename internal::traits<OutputBackward>::Index>, 2>,
+ IndexPair<typename internal::traits<OutputBackward>::Index>, 1>,
const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
- 3>,
+ 2>,
const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> >,
- const TensorReshapingOp<
+ const OutputBackward>>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<OutputBackward>::Index,
- 3>,
- const TensorReverseOp<const array<bool, 5>,
- const Kernel> > > > >::type
+ 2>,
+ const TensorShufflingOp<
+ const array<
+ typename internal::traits<OutputBackward>::Index, 5>,
+ const TensorReverseOp<const Eigen::array<bool, 5>,
+ const Kernel>>>>>>>::type
CuboidConvolutionBackwardInput(
const Kernel& kernel, const OutputBackward& output_backward,
typename internal::traits<OutputBackward>::Index inputPlanes,
typename internal::traits<OutputBackward>::Index inputRows,
typename internal::traits<OutputBackward>::Index inputCols,
- const DenseIndex stridePlanes = 1, const DenseIndex strideRows = 1,
- const DenseIndex strideCols = 1) {
+ const DenseIndex plane_stride = 1, const DenseIndex row_stride = 1,
+ const DenseIndex col_stride = 1) {
typedef typename internal::traits<OutputBackward>::Index TensorIndex;
const TensorRef<const Tensor<typename internal::traits<Kernel>::Scalar,
internal::traits<Kernel>::NumDimensions,
- internal::traits<Kernel>::Layout, TensorIndex> >
+ internal::traits<Kernel>::Layout, TensorIndex>>
kern(kernel);
const TensorRef<
const Tensor<typename internal::traits<OutputBackward>::Scalar,
internal::traits<OutputBackward>::NumDimensions,
- internal::traits<OutputBackward>::Layout, TensorIndex> >
+ internal::traits<OutputBackward>::Layout, TensorIndex>>
out(output_backward);
EIGEN_STATIC_ASSERT(internal::traits<Kernel>::Layout ==
@@ -125,58 +132,45 @@ CuboidConvolutionBackwardInput(
const TensorIndex outputCols =
isColMajor ? out.dimensions()[3] : out.dimensions()[NumDims - 4];
- TensorIndex forward_pad_z, forward_pad_y, forward_pad_x;
- const TensorIndex size_z =
- Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
- const TensorIndex size_y =
- Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
- const TensorIndex size_x =
- Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
-
- // Infer padding type.
- if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) {
- // SAME padding.
- const TensorIndex dz = numext::maxi<TensorIndex>(
- 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes);
- const TensorIndex dy = numext::maxi<TensorIndex>(
- 0, (size_y - 1) * strideRows + kernelRows - inputRows);
- const TensorIndex dx = numext::maxi<TensorIndex>(
- 0, (size_x - 1) * strideCols + kernelCols - inputCols);
-
- forward_pad_z = dz / 2;
- forward_pad_y = dy / 2;
- forward_pad_x = dx / 2;
- } else {
- // VALID padding.
- forward_pad_z = 0;
- forward_pad_y = 0;
- forward_pad_x = 0;
- }
- const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z;
- const TensorIndex padding_top = kernelRows - 1 - forward_pad_y;
- const TensorIndex padding_left = kernelCols - 1 - forward_pad_x;
-
- const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 -
- (outputPlanes - 1) * stridePlanes - 1 -
- padding_ztop;
- const TensorIndex padding_bottom = inputRows + kernelRows - 1 -
- (outputRows - 1) * strideRows - 1 -
- padding_top;
- const TensorIndex padding_right = inputCols + kernelCols - 1 -
- (outputCols - 1) * strideCols - 1 -
- padding_left;
-
- eigen_assert(padding_ztop >= 0);
- eigen_assert(padding_zbottom >= 0);
+ // TODO(ezhulenev): Add support for inflated strides. Without inflated strides
+ // effective kernel planes/rows/cols are always the same as the kernel itself
+ // (see eigen_spatial_convolutions for details).
+ const TensorIndex kernelPlanesEff = kernelPlanes;
+ const TensorIndex kernelRowsEff = kernelRows;
+ const TensorIndex kernelColsEff = kernelCols;
+
+ // Computing the forward padding.
+ const TensorIndex forward_pad_top_z = numext::maxi<Index>(
+ 0,
+ ((outputPlanes - 1) * plane_stride + kernelPlanesEff - inputPlanes) / 2);
+ const TensorIndex forward_pad_top = numext::maxi<Index>(
+ 0, ((outputRows - 1) * row_stride + kernelRowsEff - inputRows) / 2);
+ const TensorIndex forward_pad_left = numext::maxi<Index>(
+ 0, ((outputCols - 1) * col_stride + kernelColsEff - inputCols) / 2);
+
+ const TensorIndex padding_top_z = kernelPlanesEff - 1 - forward_pad_top_z;
+ const TensorIndex padding_top = kernelRowsEff - 1 - forward_pad_top;
+ const TensorIndex padding_left = kernelColsEff - 1 - forward_pad_left;
+
+ const TensorIndex padding_bottom_z = inputPlanes -
+ (outputPlanes - 1) * plane_stride - 2 -
+ padding_top_z + kernelPlanesEff;
+ const TensorIndex padding_bottom = inputRows - (outputRows - 1) * row_stride -
+ 2 - padding_top + kernelRowsEff;
+ const TensorIndex padding_right = inputCols - (outputCols - 1) * col_stride -
+ 2 - padding_left + kernelColsEff;
+
+ eigen_assert(padding_top_z >= 0);
eigen_assert(padding_top >= 0);
eigen_assert(padding_left >= 0);
+ eigen_assert(padding_bottom_z >= 0);
eigen_assert(padding_bottom >= 0);
eigen_assert(padding_right >= 0);
- // The kernel has dimensions filters X channels X patch_planes X patch_rows X
- // patch_cols.
+ // The kernel has dimensions :
+ // filters x channels x patch_planes x patch_rows x patch_cols.
// We need to reverse the kernel along the spatial dimensions.
- array<bool, 5> kernel_reverse;
+ Eigen::array<bool, 5> kernel_reverse;
if (isColMajor) {
kernel_reverse[0] = false;
kernel_reverse[1] = false;
@@ -191,15 +185,35 @@ CuboidConvolutionBackwardInput(
kernel_reverse[4] = false;
}
- DSizes<TensorIndex, 3> kernel_dims;
+ // Reorder the dimensions to:
+ // filters x patch_planes x patch_rows x patch_cols x channels
+ array<TensorIndex, 5> kernel_shuffle;
if (isColMajor) {
- kernel_dims[0] = kernelFilters;
- kernel_dims[1] = kernelChannels;
- kernel_dims[2] = kernelRows * kernelCols * kernelPlanes;
+ // From: filters x channels x planes x rows x cols
+ // To: filters x planes x rows x cols x channels
+ kernel_shuffle[0] = 0;
+ kernel_shuffle[1] = 2;
+ kernel_shuffle[2] = 3;
+ kernel_shuffle[3] = 4;
+ kernel_shuffle[4] = 1;
} else {
- kernel_dims[0] = kernelRows * kernelCols * kernelPlanes;
+ // From: cols x rows x planes x channels x filters
+ // To: channels x cols x rows x planes x filters
+ kernel_shuffle[0] = 3;
+ kernel_shuffle[1] = 0;
+ kernel_shuffle[2] = 1;
+ kernel_shuffle[3] = 2;
+ kernel_shuffle[4] = 4;
+ }
+
+ // Collapse the dims
+ DSizes<TensorIndex, 2> kernel_dims;
+ if (isColMajor) {
+ kernel_dims[0] = kernelFilters * kernelPlanes * kernelRows * kernelCols;
kernel_dims[1] = kernelChannels;
- kernel_dims[2] = kernelFilters;
+ } else {
+ kernel_dims[1] = kernelFilters * kernelPlanes * kernelRows * kernelCols;
+ kernel_dims[0] = kernelChannels;
}
// The output_backward has dimensions out_depth X out_planes X out_rows X
@@ -208,36 +222,32 @@ CuboidConvolutionBackwardInput(
// dimensions:
// out_depth X (patch_planes * patch_rows * patch_cols) X (input_planes *
// input_rows * input_cols * OTHERS)
- DSizes<TensorIndex, 3> pre_contract_dims;
+ DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) {
- pre_contract_dims[0] = kernelFilters;
- pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[2] = inputRows * inputCols * inputPlanes;
+ pre_contract_dims[0] =
+ kernelFilters * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[1] = inputPlanes * inputRows * inputCols;
for (int i = 4; i < NumDims; ++i) {
- pre_contract_dims[2] *= out.dimension(i);
+ pre_contract_dims[1] *= out.dimension(i);
}
} else {
- pre_contract_dims[2] = kernelFilters;
- pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[0] = inputRows * inputCols * inputPlanes;
+ pre_contract_dims[1] =
+ kernelFilters * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[0] = inputPlanes * inputRows * inputCols;
for (int i = 0; i < NumDims - 4; ++i) {
pre_contract_dims[0] *= out.dimension(i);
}
}
- // We will contract along dimensions (0, 2) in kernel and (0, 1) in
- // output_backward, if this is col-major, and
- // dimensions (0, 2) in kernel and (1, 2) in output_backward, if this
- // row-major.
- array<IndexPair<TensorIndex>, 2> contract_dims;
+ // We will contract along the collapsed dimension that contains the
+ // kernelFilters, kernelPlanes, kernelRows and kernelCols.
+ array<IndexPair<TensorIndex>, 1> contract_dims;
if (isColMajor) {
// col-major: kernel.contract(output.patches)
contract_dims[0] = IndexPair<TensorIndex>(0, 0);
- contract_dims[1] = IndexPair<TensorIndex>(2, 1);
} else {
// row-major: output.patches.contract(kernel)
- contract_dims[0] = IndexPair<TensorIndex>(1, 0);
- contract_dims[1] = IndexPair<TensorIndex>(2, 2);
+ contract_dims[0] = IndexPair<TensorIndex>(1, 1);
}
// Post contraction, the dimensions of the input_backprop is
@@ -261,40 +271,31 @@ CuboidConvolutionBackwardInput(
}
}
- DSizes<TensorIndex, NumDims> strides;
- for (int i = 0; i < NumDims; i++) {
- strides[i] = 1;
- }
- if (isColMajor) {
- strides[1] = stridePlanes;
- strides[2] = strideRows;
- strides[3] = strideCols;
- } else {
- strides[NumDims - 2] = stridePlanes;
- strides[NumDims - 3] = strideRows;
- strides[NumDims - 4] = strideCols;
- }
-
return choose(
Cond<internal::traits<OutputBackward>::Layout == ColMajor>(),
kernel.reverse(kernel_reverse)
+ .shuffle(kernel_shuffle)
.reshape(kernel_dims)
+ .eval()
.contract(output_backward
.extract_volume_patches(
kernelPlanes, kernelRows, kernelCols, 1, 1, 1,
- stridePlanes, strideRows, strideCols, padding_ztop,
- padding_zbottom, padding_top, padding_bottom,
+ plane_stride, row_stride, col_stride, padding_top_z,
+ padding_bottom_z, padding_top, padding_bottom,
padding_left, padding_right)
.reshape(pre_contract_dims),
contract_dims)
.reshape(post_contract_dims),
output_backward
.extract_volume_patches(kernelPlanes, kernelRows, kernelCols, 1, 1, 1,
- stridePlanes, strideRows, strideCols,
- padding_ztop, padding_zbottom, padding_top,
+ plane_stride, row_stride, col_stride,
+ padding_top_z, padding_bottom_z, padding_top,
padding_bottom, padding_left, padding_right)
.reshape(pre_contract_dims)
- .contract(kernel.reverse(kernel_reverse).reshape(kernel_dims),
+ .contract(kernel.reverse(kernel_reverse)
+ .shuffle(kernel_shuffle)
+ .reshape(kernel_dims)
+ .eval(),
contract_dims)
.reshape(post_contract_dims));
}
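A worked instance of the padding arithmetic used above, for a SAME-padded forward convolution with inputRows = 5, kernelRows = 3 and row_stride = 1 (so outputRows = 5); the sizes are chosen only for illustration:

#include <algorithm>
#include <cstdio>

int main() {
  const long inputRows = 5, outputRows = 5, kernelRows = 3, row_stride = 1;

  // Forward padding, recomputed from the input/output sizes as in the hunk above.
  const long forward_pad_top =
      std::max(0L, ((outputRows - 1) * row_stride + kernelRows - inputRows) / 2);  // = 1

  // Padding applied to output_backward before extracting volume patches.
  const long padding_top = kernelRows - 1 - forward_pad_top;                       // = 1
  const long padding_bottom =
      inputRows - (outputRows - 1) * row_stride - 2 - padding_top + kernelRows;    // = 1

  std::printf("forward_pad_top=%ld padding_top=%ld padding_bottom=%ld\n",
              forward_pad_top, padding_top, padding_bottom);
  return 0;
}

One row of forward padding on top becomes kernelRows - 1 - 1 = 1 row of top padding on output_backward, and the bottom padding also works out to 1, so the extracted patches line up with the original input positions.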
@@ -322,48 +323,69 @@ CuboidConvolutionBackwardInput(
*/
template <typename OutputBackward, typename Input>
EIGEN_ALWAYS_INLINE static const typename internal::conditional<
- internal::traits<OutputBackward>::Layout == ColMajor,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index, 5>,
- const TensorReverseOp<
- const array<bool, 5>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<OutputBackward>::Index,
- 5>,
+ internal::traits<Input>::Layout == ColMajor,
+ const TensorReverseOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorReshapingOp<
+ const Eigen::DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
const TensorContractionOp<
const array<
- IndexPair<typename internal::traits<Input>::Index>, 2>,
- const TensorReshapingOp<
+ IndexPair<typename internal::traits<Input>::Index>, 1>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index,
- 3>,
- const Input>,
+ 2>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const OutputBackward>>>,
const TensorReshapingOp<
- const DSizes<
- typename internal::traits<OutputBackward>::Index,
- 4>,
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
const TensorVolumePatchOp<
Dynamic, Dynamic, Dynamic,
- const OutputBackward> > > > > >,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index, 5>,
- const TensorReverseOp<
- const array<bool, 5>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<OutputBackward>::Index,
- 5>,
+ const Eigen::TensorForcedEvalOp<
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Input>>>>>>>>,
+ const TensorReverseOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Eigen::TensorReshapingOp<
+ const Eigen::DSizes<typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
const TensorContractionOp<
const array<
- IndexPair<typename internal::traits<Input>::Index>, 2>,
- const TensorReshapingOp<
- const DSizes<
- typename internal::traits<OutputBackward>::Index,
- 4>,
- const TensorVolumePatchOp<Dynamic, Dynamic, Dynamic,
- const OutputBackward> >,
+ IndexPair<typename internal::traits<Input>::Index>, 1>,
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index,
- 3>,
- const Input> > > > > >::type
+ 2>,
+ const TensorVolumePatchOp<
+ Dynamic, Dynamic, Dynamic,
+ const Eigen::TensorForcedEvalOp<
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const Input>>>>,
+ const Eigen::TensorForcedEvalOp<const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index,
+ 2>,
+ const Eigen::TensorShufflingOp<
+ const Eigen::array<
+ typename internal::traits<Input>::Index,
+ internal::traits<Input>::NumDimensions>,
+ const OutputBackward>>>>>>>>::type
CuboidConvolutionBackwardKernel(
const Input& input, const OutputBackward& output_backward,
typename internal::traits<Input>::Index kernelPlanes,
@@ -374,11 +396,11 @@ CuboidConvolutionBackwardKernel(
typedef typename internal::traits<Input>::Index TensorIndex;
TensorRef<Tensor<typename internal::traits<Input>::Scalar,
internal::traits<Input>::NumDimensions,
- internal::traits<Input>::Layout, TensorIndex> >
+ internal::traits<Input>::Layout, TensorIndex>>
in(input);
TensorRef<Tensor<typename internal::traits<OutputBackward>::Scalar,
internal::traits<OutputBackward>::NumDimensions,
- internal::traits<OutputBackward>::Layout, TensorIndex> >
+ internal::traits<OutputBackward>::Layout, TensorIndex>>
out(output_backward);
EIGEN_STATIC_ASSERT(internal::traits<Input>::Layout ==
@@ -392,6 +414,13 @@ CuboidConvolutionBackwardKernel(
internal::traits<OutputBackward>::NumDimensions,
YOU_MADE_A_PROGRAMMING_MISTAKE);
+ // We do not support higher dimensional backward convolutions, or convolutions
+ // without batch dimension.
+ // TODO(ezhulenev): Relax this constraint, and turn on tests without batch
+ // dimension in eigen_backward_cuboid_convolutions_test.cc.
+ EIGEN_STATIC_ASSERT(internal::traits<Input>::NumDimensions == 5,
+ YOU_MADE_A_PROGRAMMING_MISTAKE);
+
const TensorIndex inputPlanes =
isColMajor ? in.dimension(1) : in.dimension(NumDims - 2);
const TensorIndex inputRows =
@@ -406,213 +435,174 @@ CuboidConvolutionBackwardKernel(
const TensorIndex outputCols =
isColMajor ? out.dimension(3) : out.dimension(NumDims - 4);
+ // Number of filters. This is the same as the output depth.
const TensorIndex kernelFilters =
isColMajor ? out.dimension(0) : out.dimension(NumDims - 1);
+ // Number of channels. This is the same as the input depth.
const TensorIndex kernelChannels =
isColMajor ? in.dimension(0) : in.dimension(NumDims - 1);
- TensorIndex forward_pad_z, forward_pad_y, forward_pad_x;
- const TensorIndex size_z =
- Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
- const TensorIndex size_y =
- Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
- const TensorIndex size_x =
- Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
-
- // Infer padding type.
- if (size_z == outputPlanes && size_y == outputRows && size_x == outputCols) {
- // SAME padding.
- const TensorIndex dz = numext::maxi<TensorIndex>(
- 0, (size_z - 1) * stridePlanes + kernelPlanes - inputPlanes);
- const TensorIndex dy = numext::maxi<TensorIndex>(
- 0, (size_y - 1) * strideRows + kernelRows - inputRows);
- const TensorIndex dx = numext::maxi<TensorIndex>(
- 0, (size_x - 1) * strideCols + kernelCols - inputCols);
-
- forward_pad_z = dz / 2;
- forward_pad_y = dy / 2;
- forward_pad_x = dx / 2;
+ // Number of batches in the input tensor.
+ const TensorIndex batch =
+ isColMajor ? in.dimension(4) : in.dimension(NumDims - 5);
+
+ // TODO(ezhulenev): Add support for inflated strides. Without inflated strides
+ // effective kernel planes/rows/cols are always the same as the kernel itself
+ // (see eigen_spatial_convolutions for details).
+ const TensorIndex kernelPlanesEff = kernelPlanes;
+ const TensorIndex kernelRowsEff = kernelRows;
+ const TensorIndex kernelColsEff = kernelCols;
+
+ // Compute forward padding from input and output_backward dimensions.
+ const TensorIndex padPlanes = numext::maxi<Index>(
+ 0, (outputPlanes - 1) * stridePlanes + kernelPlanesEff - inputPlanes);
+ const TensorIndex padRows = numext::maxi<Index>(
+ 0, (outputRows - 1) * strideRows + kernelRowsEff - inputRows);
+ const TensorIndex padCols = numext::maxi<Index>(
+ 0, (outputCols - 1) * strideCols + kernelColsEff - inputCols);
+
+ const TensorIndex padding_top_z = padPlanes / 2;
+ const TensorIndex padding_top = padRows / 2;
+ const TensorIndex padding_left = padCols / 2;
+
+ // Compute paddings for output_backward before extracting patches.
+ const auto expanded_out_planes = (outputPlanes - 1) * stridePlanes + 1;
+ const auto expanded_out_rows = (outputRows - 1) * strideRows + 1;
+ const auto expanded_out_cols = (outputCols - 1) * strideCols + 1;
+ const auto padded_out_planes = inputPlanes + kernelPlanes - 1;
+ const auto padded_out_rows = inputRows + kernelRows - 1;
+ const auto padded_out_cols = inputCols + kernelCols - 1;
+ const auto top_pad_planes = kernelPlanes - 1 - padding_top_z;
+ const auto top_pad_rows = kernelRows - 1 - padding_top;
+ const auto left_pad_cols = kernelCols - 1 - padding_left;
+ const auto bottom_pad_planes =
+ padded_out_planes - expanded_out_planes - top_pad_planes;
+ const auto bottom_pad_rows =
+ padded_out_rows - expanded_out_rows - top_pad_rows;
+ const auto right_pad_cols =
+ padded_out_cols - expanded_out_cols - left_pad_cols;
+
+ // Reorder output_backward dimensions.
+ array<TensorIndex, 5> output_backward_shuffle;
+ if (isColMajor) {
+ // From: [out_depth, out_planes, out_rows, out_cols, batch]
+ // To: [batch, out_planes, out_rows, out_cols, out_depth]
+ output_backward_shuffle = {4, 1, 2, 3, 0};
} else {
- // VALID padding.
- forward_pad_z = 0;
- forward_pad_y = 0;
- forward_pad_x = 0;
+ // From: [batch, out_cols, out_rows, out_planes, out_depth]
+ // To: [out_depth, out_cols, out_rows, out_planes, batch]
+ output_backward_shuffle = {4, 1, 2, 3, 0};
}
- const TensorIndex padding_ztop = kernelPlanes - 1 - forward_pad_z;
- const TensorIndex padding_top = kernelRows - 1 - forward_pad_y;
- const TensorIndex padding_left = kernelCols - 1 - forward_pad_x;
-
- const TensorIndex padding_zbottom = inputPlanes + kernelPlanes - 1 -
- (outputPlanes - 1) * stridePlanes - 1 -
- padding_ztop;
- const TensorIndex padding_bottom = inputRows + kernelRows - 1 -
- (outputRows - 1) * strideRows - 1 -
- padding_top;
- const TensorIndex padding_right = inputCols + kernelCols - 1 -
- (outputCols - 1) * strideCols - 1 -
- padding_left;
-
- eigen_assert(padding_ztop >= 0);
- eigen_assert(padding_zbottom >= 0);
- eigen_assert(padding_top >= 0);
- eigen_assert(padding_left >= 0);
- eigen_assert(padding_bottom >= 0);
- eigen_assert(padding_right >= 0);
-
- // The output_backward has dimensions out_depth X out_plaens X out_rows X
- // out_cols X OTHERS
- // When we extract the image patches from output_backward (with input as the
- // kernel), it will have dimensions
- // (out_depth) X (input_planes * input_rows * input_cols) X (kernel_planes *
- // kernel_rows * kernel_cols) X OTHERS
- DSizes<TensorIndex, 4> pre_contract_dims;
+ // Reorder input dimensions.
+ array<TensorIndex, 5> input_shuffle;
if (isColMajor) {
- pre_contract_dims[0] = kernelFilters;
- pre_contract_dims[1] = inputRows * inputCols * inputPlanes;
- pre_contract_dims[2] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[3] = 1;
- for (int i = 4; i < NumDims; ++i) {
- pre_contract_dims[3] *= out.dimension(i);
- }
+ // From: [in_depth, in_planes, in_rows, in_cols, batch]
+ // To: [in_depth, batch, in_planes, in_rows, in_cols]
+ input_shuffle = {0, 4, 1, 2, 3};
} else {
- pre_contract_dims[3] = kernelFilters;
- pre_contract_dims[2] = inputRows * inputCols * inputPlanes;
- pre_contract_dims[1] = kernelRows * kernelCols * kernelPlanes;
- pre_contract_dims[0] = 1;
- for (int i = 0; i < NumDims - 4; ++i) {
- pre_contract_dims[0] *= out.dimension(i);
- }
+ // From: [batch, in_cols, in_rows, in_planes, in_depth]
+ // To: [in_cols, in_rows, in_planes, batch, in_depth]
+ input_shuffle = {1, 2, 3, 0, 4};
}
- // The input has dimensions in_depth X (input_planes * input_rows *
- // input_cols) X OTHERS
- DSizes<TensorIndex, 3> input_dims;
+ // Input is playing the role of a "kernel" in this convolution.
+ DSizes<TensorIndex, 2> input_dims;
if (isColMajor) {
input_dims[0] = kernelChannels;
- input_dims[1] = inputRows * inputCols * inputPlanes;
- input_dims[2] = 1;
- for (int i = 4; i < NumDims; ++i) {
- input_dims[2] *= in.dimension(i);
- }
- eigen_assert(input_dims[2] == pre_contract_dims[3]);
+ input_dims[1] = batch * inputPlanes * inputRows * inputCols;
} else {
- input_dims[2] = kernelChannels;
- input_dims[1] = inputRows * inputCols * inputPlanes;
- input_dims[0] = 1;
- for (int i = 0; i < NumDims - 4; ++i) {
- input_dims[0] *= in.dimension(i);
- }
- eigen_assert(input_dims[0] == pre_contract_dims[0]);
+ input_dims[1] = kernelChannels;
+ input_dims[0] = inputCols * inputRows * inputPlanes * batch;
}
- // We will contract along dimensions (1, 2) in and (1, 3) in out, if
- // this is col-major.
- // For row-major, it's dimensions (0, 1) in and (0, 2) in out.
- array<IndexPair<TensorIndex>, 2> contract_dims;
+ // Molds the output of the patch extraction result into a 2D tensor:
+ // - the first dimension (dims[0]): the patch values to be multiplied with the
+ // kernels
+ // - the second dimension (dims[1]): everything else
+ DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) {
- // col-major: in.contract(output.patches)
- contract_dims[0] = IndexPair<TensorIndex>(1, 1);
- contract_dims[1] = IndexPair<TensorIndex>(2, 3);
+ pre_contract_dims[0] = batch * inputPlanes * inputRows * inputCols;
+ pre_contract_dims[1] =
+ kernelPlanes * kernelRows * kernelCols * kernelFilters;
} else {
- // row-major: output.patches.contract(in)
- contract_dims[0] = IndexPair<TensorIndex>(0, 0);
- contract_dims[1] = IndexPair<TensorIndex>(2, 1);
+ pre_contract_dims[1] = inputCols * inputRows * inputPlanes * batch;
+ pre_contract_dims[0] =
+ kernelFilters * kernelCols * kernelRows * kernelPlanes;
}
- // After the contraction, the kernel will have dimension
- // in_depth X out_depth X kernel_patches X kernel_rows X kernel_cols
- // We will need to shuffle the first two dimensions and reverse the spatial
- // dimensions.
- // The end shape is:
- // out_depth X in_shape X kernel_planes X kernel_rows X kernel_cols
+ // We will contract along the collapsed dimension that contains the
+ // batch, inputPlanes, inputRows and inputCols.
+ array<IndexPair<TensorIndex>, 1> contract_dims;
+ contract_dims[0] = IndexPair<TensorIndex>(1, 0);
- // This is the shape of the kernel *before* the shuffling.
- DSizes<TensorIndex, 5> kernel_dims;
+ // Dimensions after contraction.
+ DSizes<TensorIndex, NumDims> post_contract_dims;
if (isColMajor) {
- kernel_dims[0] = kernelChannels;
- kernel_dims[1] = kernelFilters;
- kernel_dims[2] = kernelPlanes;
- kernel_dims[3] = kernelRows;
- kernel_dims[4] = kernelCols;
+ post_contract_dims[0] = kernelChannels;
+ post_contract_dims[1] = kernelPlanes;
+ post_contract_dims[2] = kernelRows;
+ post_contract_dims[3] = kernelCols;
+ post_contract_dims[4] = kernelFilters;
} else {
- kernel_dims[0] = kernelCols;
- kernel_dims[1] = kernelRows;
- kernel_dims[2] = kernelPlanes;
- kernel_dims[3] = kernelFilters;
- kernel_dims[4] = kernelChannels;
+ post_contract_dims[0] = kernelFilters;
+ post_contract_dims[1] = kernelCols;
+ post_contract_dims[2] = kernelRows;
+ post_contract_dims[3] = kernelPlanes;
+ post_contract_dims[4] = kernelChannels;
}
- // Flip filters and channels.
+ // Reorder output of contraction to valid filter shape.
array<TensorIndex, 5> kernel_shuffle;
if (isColMajor) {
- kernel_shuffle[0] = 1;
- kernel_shuffle[1] = 0;
- kernel_shuffle[2] = 2;
- kernel_shuffle[3] = 3;
- kernel_shuffle[4] = 4;
+ // From: [in_depth, kernel_planes, kernel_rows, kernel_cols, out_depth]
+ // To: [out_depth, in_depth, kernel_planes, kernel_rows, kernel_cols]
+ kernel_shuffle = {4, 0, 1, 2, 3};
} else {
- kernel_shuffle[0] = 0;
- kernel_shuffle[1] = 1;
- kernel_shuffle[2] = 2;
- kernel_shuffle[3] = 4;
- kernel_shuffle[4] = 3;
+ // From: [out_depth, kernel_cols, kernel_rows, kernel_planes, in_depth]
+ // To: [kernel_cols, kernel_rows, kernel_planes, in_depth, out_depth]
+ kernel_shuffle = {1, 2, 3, 4, 0};
}
- // Reverse the spatial dimensions.
- array<bool, 5> kernel_reverse;
+ // Reverse kernel backprop dimensions.
+ array<TensorIndex, 5> kernel_reverse;
if (isColMajor) {
- kernel_reverse[0] = false;
- kernel_reverse[1] = false;
- kernel_reverse[2] = true;
- kernel_reverse[3] = true;
- kernel_reverse[4] = true;
+ kernel_reverse = {false, false, true, true, true};
} else {
- kernel_reverse[0] = true;
- kernel_reverse[1] = true;
- kernel_reverse[2] = true;
- kernel_reverse[3] = false;
- kernel_reverse[4] = false;
+ kernel_reverse = {true, true, true, false, false};
}
- DSizes<TensorIndex, NumDims> strides;
- for (int i = 0; i < NumDims; i++) {
- strides[i] = 1;
- }
- if (isColMajor) {
- strides[1] = stridePlanes;
- strides[2] = strideRows;
- strides[3] = strideCols;
- } else {
- strides[NumDims - 2] = stridePlanes;
- strides[NumDims - 3] = strideRows;
- strides[NumDims - 4] = strideCols;
- }
- return choose(
- Cond<internal::traits<Input>::Layout == ColMajor>(),
- input.reshape(input_dims)
- .contract(output_backward
+ // Create convolution input (aka source of patches) from output backward
+ // tensor by shuffling dimensions.
+ const auto the_input =
+ output_backward.shuffle(output_backward_shuffle).eval();
+
+ // Create convolution kernel (aka filter) from input by shuffling and
+ // reshaping.
+ const auto the_kernel =
+ input.shuffle(input_shuffle).reshape(input_dims).eval();
+
+ return choose(Cond<internal::traits<Input>::Layout == ColMajor>(),
+ the_kernel.contract(
+ the_input
.extract_volume_patches(
inputPlanes, inputRows, inputCols, 1, 1, 1,
stridePlanes, strideRows, strideCols,
-
- padding_ztop, padding_zbottom, padding_top,
- padding_bottom, padding_left, padding_right)
+ top_pad_planes, bottom_pad_planes, top_pad_rows,
+ bottom_pad_rows, left_pad_cols, right_pad_cols)
.reshape(pre_contract_dims),
- contract_dims)
- .reshape(kernel_dims)
- .reverse(kernel_reverse)
- .shuffle(kernel_shuffle),
- output_backward
- .extract_volume_patches(inputPlanes, inputRows, inputCols, 1, 1, 1,
- stridePlanes, strideRows, strideCols,
- padding_ztop, padding_zbottom, padding_top,
- padding_bottom, padding_left, padding_right)
- .reshape(pre_contract_dims)
- .contract(input.reshape(input_dims), contract_dims)
- .reshape(kernel_dims)
- .reverse(kernel_reverse)
- .shuffle(kernel_shuffle));
+ contract_dims),
+ the_input
+ .extract_volume_patches(
+ inputPlanes, inputRows, inputCols, 1, 1, 1,
+ stridePlanes, strideRows, strideCols, top_pad_planes,
+ bottom_pad_planes, top_pad_rows, bottom_pad_rows,
+ left_pad_cols, right_pad_cols)
+ .reshape(pre_contract_dims)
+ .contract(the_kernel, contract_dims))
+ .reshape(post_contract_dims)
+ .shuffle(kernel_shuffle)
+ .reverse(kernel_reverse);
}
} // end namespace Eigen
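The kernel preparation in both backward passes above is composed from three Eigen tensor expressions: reverse the spatial dimensions, shuffle the channels behind them, and reshape (collapse) to a 2-D matrix so a single contraction dimension suffices. A small col-major illustration with arbitrary sizes (7 filters, 3 channels, a 2x4x5 patch), independent of the TensorFlow wrappers:

#include <unsupported/Eigen/CXX11/Tensor>
#include <cstdio>

int main() {
  Eigen::Tensor<float, 5> kernel(7, 3, 2, 4, 5);  // filters, channels, planes, rows, cols
  kernel.setRandom();

  Eigen::array<bool, 5> reverse = {false, false, true, true, true};  // flip spatial dims
  Eigen::array<int, 5> shuffle = {0, 2, 3, 4, 1};                    // move channels last
  Eigen::array<Eigen::Index, 2> collapsed = {7 * 2 * 4 * 5, 3};      // [filters*spatial, channels]

  Eigen::Tensor<float, 2> prepared =
      kernel.reverse(reverse).shuffle(shuffle).reshape(collapsed).eval();

  std::printf("prepared: %ld x %ld\n", static_cast<long>(prepared.dimension(0)),
              static_cast<long>(prepared.dimension(1)));  // 280 x 3
  return 0;
}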
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
index 099696105b..960920c55b 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions.h
@@ -189,14 +189,19 @@ SpatialConvolutionBackwardInput(
}
#endif
- // Reorder the dimensions to filters X patch_rows X patch_cols X channels
+ // Reorder the dimensions to:
+ // filters x patch_rows x patch_cols x channels
array<TensorIndex, 4> kernel_shuffle;
if (isColMajor) {
+ // From: filters x channels x rows x cols
+ // To: filters x rows x cols x channels
kernel_shuffle[0] = 0;
kernel_shuffle[1] = 2;
kernel_shuffle[2] = 3;
kernel_shuffle[3] = 1;
} else {
+ // From: cols x rows x channels x filters
+ // To: channels x cols x rows x filters
kernel_shuffle[0] = 2;
kernel_shuffle[1] = 0;
kernel_shuffle[2] = 1;
@@ -233,8 +238,8 @@ SpatialConvolutionBackwardInput(
}
}
- // We will contract along the fused dimension that contains the kernelFilters,
- // the kernelRows and the kernelCols.
+ // We will contract along the collapsed dimension that contains the
+ // kernelFilters, the kernelRows and the kernelCols.
array<IndexPair<TensorIndex>, 1> contract_dims;
if (isColMajor) {
// col-major: kernel.contract(output.patches)
@@ -327,23 +332,16 @@ EIGEN_ALWAYS_INLINE static const typename internal::conditional<
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorImagePatchOp<Dynamic, Dynamic,
- const Input> > > > >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> > > >,
TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 4>,
const TensorContractionOp<
const array<IndexPair<typename internal::traits<Input>::Index>, 1>,
- const TensorShufflingOp<
- const array<typename internal::traits<OutputBackward>::Index,
- 2>,
- const TensorReshapingOp<
- const DSizes<typename internal::traits<Input>::Index, 2>,
- const TensorImagePatchOp<Dynamic, Dynamic, const Input> > >,
+ const TensorReshapingOp<
+ const DSizes<typename internal::traits<Input>::Index, 2>,
+ const TensorImagePatchOp<Dynamic, Dynamic, const Input> >,
const TensorReshapingOp<
const DSizes<typename internal::traits<Input>::Index, 2>,
const OutputBackward> > > >::type
@@ -451,12 +449,16 @@ SpatialConvolutionBackwardKernel(
eigen_assert(output_dims[0] == pre_contract_dims[0]);
}
- array<TensorIndex, 2> shuffle_dims;
- shuffle_dims[0] = 1;
- shuffle_dims[1] = 0;
-
+ // We will contract along the collapsed dimension that contains the
+ // outputCols, outputRows and OTHERS.
array<IndexPair<TensorIndex>, 1> contract_dims;
- contract_dims[0] = IndexPair<TensorIndex>(1, 0);
+ if (isColMajor) {
+ // col-major: output_backward.contract(input.patches)
+ contract_dims[0] = IndexPair<TensorIndex>(1, 1);
+ } else {
+ // row-major: input.patches.contract(output_backward)
+ contract_dims[0] = IndexPair<TensorIndex>(0, 0);
+ }
// After the contraction, the kernel will have the desired shape
// out_depth X in_shape X kernel_rows X kernel_cols
@@ -482,8 +484,7 @@ SpatialConvolutionBackwardKernel(
kernelRows, kernelCols, row_stride, col_stride,
row_in_stride, col_in_stride, 1, 1, padding_top,
padding_bottom, padding_left, padding_right, OutScalar(0))
- .reshape(pre_contract_dims)
- .shuffle(shuffle_dims),
+ .reshape(pre_contract_dims),
contract_dims)
.reshape(kernel_dims),
input
@@ -492,11 +493,10 @@ SpatialConvolutionBackwardKernel(
padding_top, padding_bottom, padding_left,
padding_right, OutScalar(0))
.reshape(pre_contract_dims)
- .shuffle(shuffle_dims)
.contract(output_backward.reshape(output_dims), contract_dims)
.reshape(kernel_dims));
}
} // end namespace Eigen
-#endif // EIGEN_CXX11_NEURAL_NETWORKS_BACKWARD_SPATIAL_CONVOLUTIONS_H
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BACKWARD_SPATIAL_CONVOLUTIONS_H_
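The contract_dims changes above only select which dimension pair the single contraction reduces over: IndexPair(i, j) pairs dimension i of the left operand with dimension j of the right operand. A tiny standalone example, an ordinary matrix multiply written as a tensor contraction:

#include <unsupported/Eigen/CXX11/Tensor>
#include <cstdio>

int main() {
  Eigen::Tensor<float, 2> a(2, 3), b(3, 4);
  a.setRandom();
  b.setRandom();

  // Contract dim 1 of `a` with dim 0 of `b`: a (2x3)*(3x4) matmul.
  Eigen::array<Eigen::IndexPair<int>, 1> contract_dims = {Eigen::IndexPair<int>(1, 0)};
  Eigen::Tensor<float, 2> c = a.contract(b, contract_dims);

  std::printf("c: %ld x %ld\n", static_cast<long>(c.dimension(0)),
              static_cast<long>(c.dimension(1)));  // 2 x 4
  return 0;
}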
diff --git a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
index 2229ec9659..673ec1458b 100644
--- a/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
+++ b/tensorflow/core/kernels/eigen_backward_spatial_convolutions_test.cc
@@ -1248,11 +1248,14 @@ TEST(EigenBackwardSpatialConvolutionsTest,
const int output_cols = input_cols - patch_cols + 1;
const int output_planes = input_planes - patch_planes + 1;
- Tensor<float, 4> input(input_depth, input_planes, input_rows, input_cols);
+ // TODO(ezhulenev): Support backward kernel convolution without batch
+ // dimension.
+ Tensor<float, 5> input(input_depth, input_planes, input_rows, input_cols,
+ /*num_batches*/ 1);
Tensor<float, 5> kernel(output_depth, input_depth, patch_planes, patch_rows,
patch_cols);
- Tensor<float, 4> output_backward(output_depth, output_planes, output_rows,
- output_cols);
+ Tensor<float, 5> output_backward(output_depth, output_planes, output_rows,
+ output_cols, /*num_batches*/ 1);
output_backward = output_backward.constant(11.0f) + output_backward.random();
input = input.constant(2.0f) + input.random();
@@ -1282,9 +1285,9 @@ TEST(EigenBackwardSpatialConvolutionsTest,
if (output_i >= 0 && output_i < output_planes &&
output_j >= 0 && output_j < output_rows &&
output_k >= 0 && output_k < output_cols) {
- expected +=
- input(id, i, j, k) *
- output_backward(od, output_i, output_j, output_k);
+ expected += input(id, i, j, k, /*batch*/ 0) *
+ output_backward(od, output_i, output_j,
+ output_k, /*batch*/ 0);
}
}
}
@@ -1311,12 +1314,14 @@ TEST(EigenBackwardSpatialConvolutionsTest,
const int output_cols = input_cols - patch_cols + 1;
const int output_planes = input_planes - patch_planes + 1;
- Tensor<float, 4, RowMajor> input(input_cols, input_rows, input_planes,
- input_depth);
+ // TODO(ezhulenev): Support backward kernel convolution without batch
+ // dimension.
+ Tensor<float, 5, RowMajor> input(/*num_batches*/ 1, input_cols, input_rows,
+ input_planes, input_depth);
Tensor<float, 5, RowMajor> kernel(patch_cols, patch_rows, patch_planes,
input_depth, output_depth);
- Tensor<float, 4, RowMajor> output_backward(output_cols, output_rows,
- output_planes, output_depth);
+ Tensor<float, 5, RowMajor> output_backward(
+ /*num_batches*/ 1, output_cols, output_rows, output_planes, output_depth);
output_backward = output_backward.constant(11.0f) + output_backward.random();
input = input.constant(2.0f) + input.random();
@@ -1346,9 +1351,9 @@ TEST(EigenBackwardSpatialConvolutionsTest,
if (output_i >= 0 && output_i < output_planes &&
output_j >= 0 && output_j < output_rows &&
output_k >= 0 && output_k < output_cols) {
- expected +=
- input(k, j, i, id) *
- output_backward(output_k, output_j, output_i, od);
+ expected += input(/*batch*/ 0, k, j, i, id) *
+ output_backward(/*batch*/ 0, output_k, output_j,
+ output_i, od);
}
}
}
diff --git a/tensorflow/core/kernels/eigen_benchmark.h b/tensorflow/core/kernels/eigen_benchmark.h
new file mode 100644
index 0000000000..87e41b89b3
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_benchmark.h
@@ -0,0 +1,304 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_
+#define TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_backward_cuboid_convolutions.h"
+#include "tensorflow/core/kernels/eigen_backward_spatial_convolutions.h"
+#include "tensorflow/core/kernels/eigen_cuboid_convolution.h"
+#include "tensorflow/core/kernels/eigen_spatial_convolutions.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+using ::tensorflow::TTypes;
+
+template <typename Scalar, typename Device>
+class SpatialConvolutionBenchmarksSuite {
+ public:
+ using Input = TTypes<float, 4>::ConstTensor;
+ using Filter = TTypes<float, 4>::ConstTensor;
+ using Output = TTypes<float, 4>::Tensor;
+
+ using Dimensions = Eigen::DSizes<Eigen::Index, 4>;
+
+ SpatialConvolutionBenchmarksSuite(int iters, Device& device)
+ : iters_(iters), device_(device) {}
+
+ Eigen::Index BufferSize(const Dimensions& dims) {
+ return dims.TotalSize() * sizeof(Scalar);
+ }
+
+ void SpatialConvolution(Dimensions input_dims, Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ filter_dims[3]); // filter_count
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+
+ Input input(input_data, input_dims);
+ Filter filter(filter_data, filter_dims);
+ Output output(output_data, output_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ output.device(device_) = Eigen::SpatialConvolution(input, filter);
+ tensorflow::testing::DoNotOptimize(output);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(filter_data);
+ device_.deallocate(output_data);
+ }
+
+ void SpatialConvolutionBackwardInput(Dimensions input_dims,
+ Dimensions filter_dims) {
+ using OutputBackward = TTypes<float, 4>::ConstTensor;
+ using InputBackward = TTypes<float, 4>::Tensor;
+
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ filter_dims[3]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index input_rows = input_dims[1];
+ Eigen::Index input_cols = input_dims[2];
+
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* input_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Filter filter(filter_data, filter_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ InputBackward input_backward(input_backward_data, input_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ input_backward.device(device_) = Eigen::SpatialConvolutionBackwardInput(
+ filter, output_backward, input_rows, input_cols);
+ tensorflow::testing::DoNotOptimize(input_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(filter_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(input_backward_data);
+ }
+
+ void SpatialConvolutionBackwardKernel(Dimensions input_dims,
+ Dimensions filter_dims) {
+ using OutputBackward = TTypes<float, 4>::ConstTensor;
+ using FilterBackward = TTypes<float, 4>::Tensor;
+
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ filter_dims[3]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index filter_rows = filter_dims[0];
+ Eigen::Index filter_cols = filter_dims[1];
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* filter_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Input input(input_data, input_dims);
+ OutputBackward output_backward(output_backward_data, input_dims);
+ FilterBackward filter_backward(filter_backward_data, filter_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ filter_backward.device(device_) = Eigen::SpatialConvolutionBackwardKernel(
+ input, output_backward, filter_rows, filter_cols);
+ tensorflow::testing::DoNotOptimize(filter_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(filter_backward_data);
+ }
+
+ private:
+ int iters_;
+ Device& device_;
+};
+
+template <typename Scalar, typename Device>
+class CuboidConvolutionBenchmarksSuite {
+ public:
+ using Input = TTypes<float, 5>::ConstTensor;
+ using Filter = TTypes<float, 5>::ConstTensor;
+ using Output = TTypes<float, 5>::Tensor;
+
+ using Dimensions = Eigen::DSizes<Eigen::Index, 5>;
+
+ CuboidConvolutionBenchmarksSuite(int iters, Device& device)
+ : iters_(iters), device_(device) {}
+
+ Eigen::Index BufferSize(const Dimensions& dims) {
+ return dims.TotalSize() * sizeof(Scalar);
+ }
+
+ void CuboidConvolution(Dimensions input_dims, Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+
+ Input input(input_data, input_dims);
+ Filter filter(filter_data, filter_dims);
+ Output output(output_data, output_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ output.device(device_) = Eigen::CuboidConvolution(input, filter);
+ tensorflow::testing::DoNotOptimize(output);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(filter_data);
+ device_.deallocate(output_data);
+ }
+
+ void CuboidConvolutionBackwardInput(Dimensions input_dims,
+ Dimensions filter_dims) {
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ using OutputBackward = TTypes<float, 5>::ConstTensor;
+ using InputBackward = TTypes<float, 5>::Tensor;
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index input_rows = input_dims[1];
+ Eigen::Index input_cols = input_dims[2];
+ Eigen::Index input_planes = input_dims[3];
+
+ Scalar* filter_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* input_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+
+ device_.memset(filter_data, 123, BufferSize(filter_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Filter filter(filter_data, filter_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ InputBackward input_backward(input_backward_data, input_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ input_backward.device(device_) = Eigen::CuboidConvolutionBackwardInput(
+ filter, output_backward, input_planes, input_rows, input_cols);
+ tensorflow::testing::DoNotOptimize(input_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(filter_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(input_backward_data);
+ }
+
+ void CuboidConvolutionBackwardKernel(Dimensions input_dims,
+ Dimensions filter_dims) {
+ using OutputBackward = TTypes<float, 5>::ConstTensor;
+ using FilterBackward = TTypes<float, 5>::Tensor;
+
+ Dimensions output_dims(input_dims[0], // batch
+ input_dims[1], // input_height
+ input_dims[2], // input_width
+ input_dims[3], // input_planes
+ filter_dims[4]); // filter_count
+
+ // Assuming that the convolution had SAME padding.
+ Eigen::Index filter_rows = filter_dims[0];
+ Eigen::Index filter_cols = filter_dims[1];
+ Eigen::Index filter_planes = filter_dims[2];
+
+ Scalar* input_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(input_dims)));
+ Scalar* output_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(output_dims)));
+ Scalar* filter_backward_data =
+ static_cast<Scalar*>(device_.allocate(BufferSize(filter_dims)));
+
+ device_.memset(input_data, 123, BufferSize(input_dims));
+ device_.memset(output_backward_data, 123, BufferSize(output_dims));
+
+ Input input(input_data, input_dims);
+ OutputBackward output_backward(output_backward_data, output_dims);
+ FilterBackward filter_backward(filter_backward_data, filter_dims);
+
+ ::tensorflow::testing::StartTiming();
+ for (int i = 0; i < iters_; ++i) {
+ filter_backward.device(device_) = Eigen::CuboidConvolutionBackwardKernel(
+ input, output_backward, filter_planes, filter_rows, filter_cols);
+ tensorflow::testing::DoNotOptimize(filter_backward);
+ }
+ ::tensorflow::testing::StopTiming();
+
+ device_.deallocate(input_data);
+ device_.deallocate(output_backward_data);
+ device_.deallocate(filter_backward_data);
+ }
+
+ private:
+ int iters_;
+ Device& device_;
+};
+
+#endif // TENSORFLOW_CORE_KERNELS_EIGEN_BENCHMARK_H_
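BufferSize() in the header above is just the element count times the element size. A quick worked instance for a hypothetical 32x64x64x3 float input (sizes chosen only for illustration):

#include <unsupported/Eigen/CXX11/Tensor>
#include <cstdio>

int main() {
  Eigen::DSizes<Eigen::Index, 4> dims(32, 64, 64, 3);
  const Eigen::Index elements = dims.TotalSize();        // 393216 elements
  const Eigen::Index bytes = elements * sizeof(float);   // 1572864 bytes
  std::printf("%ld elements, %ld bytes\n", static_cast<long>(elements),
              static_cast<long>(bytes));
  return 0;
}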
diff --git a/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
new file mode 100644
index 0000000000..ec949ddc84
--- /dev/null
+++ b/tensorflow/core/kernels/eigen_benchmark_cpu_test.cc
@@ -0,0 +1,422 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#define EIGEN_USE_CUSTOM_THREAD_POOL
+#define EIGEN_USE_THREADS
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/kernels/eigen_benchmark.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+#define CREATE_THREAD_POOL(threads) \
+ Eigen::ThreadPool tp(threads); \
+ Eigen::ThreadPoolDevice device(&tp, threads)
+
+// -------------------------------------------------------------------------- //
+// Spatial Convolutions //
+// -------------------------------------------------------------------------- //
+
+void SpatialConvolution(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height, int input_width,
+ int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height, int filter_width) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(input_batches, input_height,
+ input_width, input_depth);
+ typename Benchmark::Dimensions filter_dims(filter_height, filter_width,
+ input_depth, filter_count);
+
+ benchmark.SpatialConvolution(input_dims, filter_dims);
+
+ auto num_computed_elements =
+ (input_dims.TotalSize() / input_depth) * filter_count;
+ auto flops =
+ num_computed_elements * (input_depth * filter_height * filter_width);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
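+
+// Flops accounting example (illustrative): for the "conv2_00" configuration
+// registered below (batch 32, 56x56 spatial size, 64 input channels, 192
+// filters of 3x3), num_computed_elements == 32 * 56 * 56 * 192, and each
+// output element costs 64 * 3 * 3 multiply-adds, i.e. roughly 1.1e10 flops
+// per iteration. The element count assumes the output keeps the input
+// spatial size (stride-1 SAME padding).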
+
+void SpatialConvolutionBackwardInput(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(input_batches, input_height,
+ input_width, input_depth);
+ typename Benchmark::Dimensions filter_dims(filter_height, filter_width,
+ input_depth, filter_count);
+
+ benchmark.SpatialConvolutionBackwardInput(input_dims, filter_dims);
+
+ auto num_computed_elements = input_dims.TotalSize();
+ auto flops =
+ num_computed_elements * (input_depth * filter_height * filter_width);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void SpatialConvolutionBackwardKernel(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ SpatialConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(input_batches, input_height,
+ input_width, input_depth);
+ typename Benchmark::Dimensions filter_dims(filter_height, filter_width,
+ input_depth, filter_count);
+
+ benchmark.SpatialConvolutionBackwardKernel(input_dims, filter_dims);
+
+ auto num_computed_elements = filter_dims.TotalSize();
+ auto flops =
+ num_computed_elements * (input_batches * input_height * input_width);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+// Macro argument names: ---------------------------------------------------- //
+// NT: num threads
+// N: batch size
+// H: height
+// W: width
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+
+#define BM_SPATIAL_NAME(prefix, NT, N, H, W, C, FC, FH, FW) \
+ BM_##prefix##_CPU_##NT##T_in_##N##_##H##_##W##_##C##_f_##FC##_##FH##_##FW
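+
+// For example, BM_SPATIAL_NAME(SpatialConvolution, 4, 32, 56, 56, 64, 192, 3, 3)
+// expands to BM_SpatialConvolution_CPU_4T_in_32_56_56_64_f_192_3_3.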
+
+#define BM_SpatialConvolution(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ SpatialConvolution(iters, NT, N, H, W, C, FC, FH, FW); \
+ } \
+ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolution, NT, N, H, W, C, FC, FH, FW))
+
+#define BM_SpatialConvolutionBwdInput(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, \
+ FH, FW)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ SpatialConvolutionBackwardInput(iters, NT, N, H, W, C, FC, FH, FW); \
+ } \
+ BENCHMARK( \
+ BM_SPATIAL_NAME(SpatialConvolutionBwdInput, NT, N, H, W, C, FC, FH, FW))
+
+#define BM_SpatialConvolutionBwdKernel(NT, N, H, W, C, FC, FH, FW, LABEL) \
+ static void BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \
+ FH, FW)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ SpatialConvolutionBackwardKernel(iters, NT, N, H, W, C, FC, FH, FW); \
+ } \
+ BENCHMARK(BM_SPATIAL_NAME(SpatialConvolutionBwdKernel, NT, N, H, W, C, FC, \
+ FH, FW))
+
+#define BM_SpatialConvolutions(N, H, W, C, FC, FH, FW, LABEL) \
+ BM_SpatialConvolution(2, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolution(4, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolution(8, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolution(16, N, H, W, C, FC, FH, FW, LABEL);
+
+#define BM_SpatialConvolutionsBwdInput(N, H, W, C, FC, FH, FW, LABEL) \
+ BM_SpatialConvolutionBwdInput(2, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdInput(4, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdInput(8, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdInput(16, N, H, W, C, FC, FH, FW, LABEL);
+
+#define BM_SpatialConvolutionsBwdKernel(N, H, W, C, FC, FH, FW, LABEL) \
+ BM_SpatialConvolutionBwdKernel(2, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdKernel(4, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdKernel(8, N, H, W, C, FC, FH, FW, LABEL); \
+ BM_SpatialConvolutionBwdKernel(16, N, H, W, C, FC, FH, FW, LABEL);
+
+// ImageNet Forward Convolutions -------------------------------------------- //
+
+BM_SpatialConvolutions(32, // batch size
+ 56, 56, 64, // input: height, width, depth
+ 192, 3, 3, // filter: count, height, width
+ "conv2_00");
+
+BM_SpatialConvolutions(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3");
+BM_SpatialConvolutions(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5");
+BM_SpatialConvolutions(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3");
+BM_SpatialConvolutions(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 24, 64, 5, 5,
+ "conv4b_00_5x5 / conv4c_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5");
+BM_SpatialConvolutions(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3");
+BM_SpatialConvolutions(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5");
+BM_SpatialConvolutions(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3");
+BM_SpatialConvolutions(32, 7, 7, 48, 128, 5, 5, "conv5a_00_5x5 / conv5_00_5x5");
+BM_SpatialConvolutions(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3");
+
+// Benchmarks from https://github.com/soumith/convnet-benchmarks
+BM_SpatialConvolutions(128, 128, 128, 3, 96, 11, 11, "convnet-layer1");
+BM_SpatialConvolutions(128, 64, 64, 64, 128, 9, 9, "convnet-layer2");
+BM_SpatialConvolutions(128, 32, 32, 128, 128, 9, 9, "convnet-layer3");
+BM_SpatialConvolutions(128, 16, 16, 128, 128, 7, 7, "convnet-layer4");
+BM_SpatialConvolutions(128, 13, 13, 384, 384, 3, 3, "convnet-layer5");
+
+// ImageNet BackwardInput Convolutions -------------------------------------- //
+
+BM_SpatialConvolutionsBwdInput(32, 56, 56, 64, 192, 3, 3, "conv2_00");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 24, 64, 5, 5,
+ "conv4b_00_5x5 / conv4c_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3");
+BM_SpatialConvolutionsBwdInput(32, 7, 7, 48, 128, 5, 5,
+ "conv5a_00_5x5 / conv5_00_5x5");
+BM_SpatialConvolutionsBwdInput(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3");
+
+// ImageNet BackwardKernel Convolutions ------------------------------------- //
+
+BM_SpatialConvolutionsBwdKernel(32, 56, 56, 64, 192, 3, 3, "conv2_00");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 96, 128, 3, 3, "conv3a_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 16, 32, 5, 5, "conv3a_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 128, 192, 3, 3, "conv3_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 28, 28, 32, 96, 5, 5, "conv3_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 96, 204, 3, 3, "conv4a_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 16, 48, 5, 5, "conv4a_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 112, 224, 3, 3, "conv4b_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 24, 64, 5, 5,
+ "conv4b_00_5x5 / conv4c_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 128, 256, 3, 3, "conv4c_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 144, 288, 3, 3, "conv4d_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 32, 64, 5, 5, "conv4d_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 160, 320, 3, 3, "conv4_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 14, 14, 32, 128, 5, 5, "conv4_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 7, 7, 160, 320, 3, 3, "conv5a_00_3x3");
+BM_SpatialConvolutionsBwdKernel(32, 7, 7, 48, 128, 5, 5,
+ "conv5a_00_5x5 / conv5_00_5x5");
+BM_SpatialConvolutionsBwdKernel(32, 7, 7, 192, 384, 3, 3, "conv5_00_3x3");
+
+// -------------------------------------------------------------------------- //
+// Cuboid Convolutions //
+// -------------------------------------------------------------------------- //
+
+void CuboidConvolution(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height, int input_width,
+ int input_planes, int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height, int filter_width,
+ int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolution(input_dims, filter_dims);
+
+ auto num_computed_elements =
+ (input_dims.TotalSize() / input_depth) * filter_count;
+ auto flops = num_computed_elements *
+ (input_depth * filter_height * filter_width * filter_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void CuboidConvolutionBackwardInput(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_planes,
+ int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width, int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolutionBackwardInput(input_dims, filter_dims);
+
+ auto num_computed_elements = input_dims.TotalSize();
+ auto flops = num_computed_elements *
+ (input_depth * filter_height * filter_width * filter_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+void CuboidConvolutionBackwardKernel(int iters, int num_threads,
+ /* Input dimensions: */
+ int input_batches, int input_height,
+ int input_width, int input_planes,
+ int input_depth,
+ /* Filter (kernel) dimensions: */
+ int filter_count, int filter_height,
+ int filter_width, int filter_planes) {
+ ::tensorflow::testing::StopTiming();
+
+ CREATE_THREAD_POOL(num_threads);
+
+ using Benchmark =
+ CuboidConvolutionBenchmarksSuite<float, Eigen::ThreadPoolDevice>;
+ auto benchmark = Benchmark(iters, device);
+
+ typename Benchmark::Dimensions input_dims(
+ input_batches, input_height, input_width, input_planes, input_depth);
+ typename Benchmark::Dimensions filter_dims(
+ filter_height, filter_width, filter_planes, input_depth, filter_count);
+
+ benchmark.CuboidConvolutionBackwardKernel(input_dims, filter_dims);
+
+ auto num_computed_elements = filter_dims.TotalSize();
+ auto flops = num_computed_elements *
+ (input_batches * input_height * input_width * input_planes);
+ ::tensorflow::testing::ItemsProcessed(flops * iters);
+}
+
+// Macro argument names: ---------------------------------------------------- //
+// NT: num threads
+// N: batch size
+// H: height
+// W: width
+// P: planes
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+// FP: filter planes
+
+#define BM_CONCAT(a, b) a##b
+
+#define BM_CUBOID_NAME(p, NT, N, H, W, P, C, FC, FH, FW, FP) \
+ BM_CONCAT(BM_##p##_CPU_##NT##T_in_##N##_##H##_##W##_##P##_##C, \
+ _f_##FC##_##FH##_##FW##_##FP)
+
+#define BM_CuboidConvolution(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, \
+ FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ CuboidConvolution(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK( \
+ BM_CUBOID_NAME(CuboidConvolution, NT, N, H, W, P, C, FC, FH, FW, FP))
+
+#define BM_CuboidConvolutionBwdInput(NT, N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
+ FH, FW, FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ CuboidConvolutionBackwardInput(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdInput, NT, N, H, W, P, C, FC, \
+ FH, FW, FP))
+
+#define BM_CuboidConvolutionBwdKernel(NT, N, H, W, P, C, FC, FH, FW, FP, \
+ LABEL) \
+ static void BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, \
+ FC, FH, FW, FP)(int iters) { \
+ ::tensorflow::testing::SetLabel(LABEL); \
+ CuboidConvolutionBackwardKernel(iters, NT, N, H, W, P, C, FC, FH, FW, FP); \
+ } \
+ BENCHMARK(BM_CUBOID_NAME(CuboidConvolutionBwdKernel, NT, N, H, W, P, C, FC, \
+ FH, FW, FP))
+
+#define BM_CuboidConvolutions(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolution(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolution(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+#define BM_CuboidConvolutionsBwdInput(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolutionBwdInput(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdInput(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+#define BM_CuboidConvolutionsBwdKernel(N, H, W, P, C, FC, FH, FW, FP, LABEL) \
+ BM_CuboidConvolutionBwdKernel(2, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(4, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(8, N, H, W, P, C, FC, FH, FW, FP, LABEL); \
+ BM_CuboidConvolutionBwdKernel(16, N, H, W, P, C, FC, FH, FW, FP, LABEL);
+
+// Random Cuboid Convolutions ----------------------------------------------- //
+// TODO(ezhulenev): find representative dims for cuboid convolutions (find
+// models using Conv3D ops).
+
+BM_CuboidConvolutions(8, // batch size
+ 25, 25, 25, 4, // input: height, width, planes, depth
+ 16, 5, 5, 5, // filter: count, height, width, planes
+ "conv3d_depth4");
+BM_CuboidConvolutions(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutions(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutions(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
+
+BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4");
+BM_CuboidConvolutionsBwdInput(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutionsBwdInput(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutionsBwdInput(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
+
+BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 4, 16, 5, 5, 5, "conv3d_depth4");
+BM_CuboidConvolutionsBwdKernel(8, 25, 25, 25, 8, 16, 5, 5, 5, "conv3d_depth8");
+BM_CuboidConvolutionsBwdKernel(2, 9, 31, 31, 64, 64, 5, 5, 5, "b2_conv3d_1");
+BM_CuboidConvolutionsBwdKernel(2, 5, 27, 27, 64, 64, 5, 5, 5, "b2_conv3d_2");
diff --git a/tensorflow/core/kernels/eigen_cuboid_convolution.h b/tensorflow/core/kernels/eigen_cuboid_convolution.h
index 62e9f9123d..6a9a2accd8 100644
--- a/tensorflow/core/kernels/eigen_cuboid_convolution.h
+++ b/tensorflow/core/kernels/eigen_cuboid_convolution.h
@@ -21,6 +21,1412 @@ limitations under the License.
namespace Eigen {
+namespace internal {
+
+// WARNING: Most of the code here implicitly assumes that the matrix is in
+// ColMajor layout. This is guaranteed by the tensor contraction (see
+// TensorContraction.h).
+//
+// Inside Eigen a tensor contraction is represented by a matrix multiplication.
+// We don't want to actually extract volume patches and reshape the result into
+// a matrix (this involves allocating huge extra memory), so the patch
+// extraction and reshape operations are implicit.
+//
+// TensorContractionInputMapper takes a matrix index and returns the coefficient
+// (or the packet) of the "virtual tensor" that would be at that index if we
+// were to actually reshape the result of patch extraction.
+//
+// TensorContractionSubMapper provides a similar view into the "virtual matrix"
+// at the given vertical and horizontal offsets.
+//
+// "Virtual matrix" dimensions:
+// *0: kernelChannels * kernelPlanes * kernelRows * kernelCols
+// 1: out_planes * out_height * out_width * OTHERS (e.g. batches, etc.)
+//
+// *) extracted patches are contiguous in memory (innermost dimension assuming
+// col major layout)
+//
+// With these dimensions:
+// row - offset within a single patch (in code: patchId)
+// col - index of the extracted patch (in code: patchIndex)
+// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
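+//
+// Illustrative decomposition (sizes chosen arbitrarily): with
+// kernelChannels = 2, kernelPlanes = 3, kernelRows = 4 and kernelCols = 5,
+// virtual-matrix row 37 addresses channel 1, plane 0, row 2, col 1 of a patch,
+// because 37 == 1 + 2 * (0 + 3 * (2 + 4 * 1)); the channel is the innermost
+// dimension, followed by plane, row and column.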
+//
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar_,
+ typename Index, typename nocontract_t, typename contract_t, int Side,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment>
+class TensorContractionInputMapper<
+ Scalar_, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<NewDimension,
+ const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment> {
+ public:
+ typedef Scalar_ Scalar;
+ typedef TensorContractionInputMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ Self;
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+ typedef SubMapper VectorMapper;
+ typedef SubMapper LinearMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_DEVICE_FUNC
+ TensorContractionInputMapper(
+ const TensorEvaluator<
+ const TensorReshapingOp<
+ NewDimension,
+ const TensorVolumePatchOp<Planes, Rows, Cols, ArgType> >,
+ Device>& tensor,
+ const nocontract_t&, const nocontract_t&, const contract_t&,
+ const contract_t&)
+ : m_impl(tensor.impl().impl()) {
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ m_patch_depth = tensor.impl().dimensions()[0];
+ m_patch_planes = tensor.impl().dimensions()[1];
+ m_patch_rows = tensor.impl().dimensions()[2];
+ m_patch_cols = tensor.impl().dimensions()[3];
+ m_num_patches = tensor.impl().dimensions()[4];
+ } else {
+ const int NumDims = tensor.impl().dimensions().size();
+ m_patch_depth = tensor.impl().dimensions()[NumDims - 1];
+ m_patch_planes = tensor.impl().dimensions()[NumDims - 2];
+ m_patch_rows = tensor.impl().dimensions()[NumDims - 3];
+ m_patch_cols = tensor.impl().dimensions()[NumDims - 4];
+ m_num_patches = tensor.impl().dimensions()[NumDims - 5];
+ }
+
+ // Strides for navigating within a single patch.
+ m_patch_plane_stride = m_patch_depth;
+ m_patch_row_stride = m_patch_planes * m_patch_plane_stride;
+ m_patch_col_stride = m_patch_rows * m_patch_row_stride;
+
+ // Strides for the output tensor.
+ // IMPORTANT: These strides are used to locate an element in a patch at
+ // depth zero (channel), which is not quite the same as "traditional"
+ // stride.
+ m_rowStride = m_patch_planes;
+ m_colStride = m_patch_rows * m_rowStride;
+ m_patchStride = m_colStride * m_patch_cols * m_patch_depth;
+ m_otherStride = m_patchStride * m_num_patches;
+
+ m_outputPlanes = tensor.impl().outputPlanes();
+ m_outputRows = tensor.impl().outputRows();
+ m_outputCols = tensor.impl().outputCols();
+
+ m_outputPlanesRows = m_outputPlanes * m_outputRows;
+
+ m_plane_strides = tensor.impl().userPlaneStride();
+ m_row_strides = tensor.impl().userRowStride();
+ m_col_strides = tensor.impl().userColStride();
+
+ m_in_plane_strides = tensor.impl().userInPlaneStride();
+ m_in_row_strides = tensor.impl().userInRowStride();
+ m_in_col_strides = tensor.impl().userInColStride();
+
+ m_patch_plane_inflate_strides = tensor.impl().planeInflateStride();
+ m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
+ m_patch_col_inflate_strides = tensor.impl().colInflateStride();
+
+ if (internal::traits<ArgType>::Layout == ColMajor) {
+ m_inputDepth = tensor.impl().impl().dimensions()[0];
+ m_inputPlanes = tensor.impl().impl().dimensions()[1];
+ m_inputRows = tensor.impl().impl().dimensions()[2];
+ m_inputCols = tensor.impl().impl().dimensions()[3];
+ } else {
+ const int NumDims = tensor.impl().impl().dimensions().size();
+ m_inputDepth = tensor.impl().impl().dimensions()[NumDims - 1];
+ m_inputPlanes = tensor.impl().impl().dimensions()[NumDims - 2];
+ m_inputRows = tensor.impl().impl().dimensions()[NumDims - 3];
+ m_inputCols = tensor.impl().impl().dimensions()[NumDims - 4];
+ }
+
+ // Strides for navigating through the input tensor.
+ m_planeInputStride = m_inputDepth;
+ m_rowInputStride = m_inputDepth * m_inputPlanes;
+ m_colInputStride = m_inputDepth * m_inputRows * m_inputPlanes;
+ m_patchInputStride =
+ m_inputDepth * m_inputRows * m_inputCols * m_inputPlanes;
+
+ m_planePaddingTop = tensor.impl().planePaddingTop();
+ m_rowPaddingTop = tensor.impl().rowPaddingTop();
+ m_colPaddingLeft = tensor.impl().colPaddingLeft();
+
+ m_fastNumPatches = internal::TensorIntDivisor<Index>(m_num_patches);
+
+ m_fastPatchPlaneStride =
+ internal::TensorIntDivisor<Index>(m_patch_plane_stride);
+ m_fastPatchRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_stride);
+ m_fastPatchColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_stride);
+
+ m_fastInputPlaneStride =
+ internal::TensorIntDivisor<Index>(m_patch_plane_inflate_strides);
+ m_fastInputRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
+ m_fastInputColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_inflate_strides);
+
+ m_fastRowStride = internal::TensorIntDivisor<Index>(m_rowStride);
+ m_fastColStride = internal::TensorIntDivisor<Index>(m_colStride);
+
+ m_fastDimZero = internal::TensorIntDivisor<Index>(m_patch_depth);
+ m_fastOutputRows = internal::TensorIntDivisor<Index>(m_outputRows);
+ m_fastOutputPlanes = internal::TensorIntDivisor<Index>(m_outputPlanes);
+ m_fastOutputCols = internal::TensorIntDivisor<Index>(m_outputCols);
+
+ m_fastOutputPlanesRows =
+ internal::TensorIntDivisor<Index>(m_outputPlanesRows);
+ }
+
+ EIGEN_DEVICE_FUNC
+ TensorContractionInputMapper(const TensorContractionInputMapper& base_mapper)
+ : m_impl(base_mapper.m_impl) {
+ m_patch_depth = base_mapper.m_patch_depth;
+ m_patch_planes = base_mapper.m_patch_planes;
+ m_patch_rows = base_mapper.m_patch_rows;
+ m_patch_cols = base_mapper.m_patch_cols;
+ m_num_patches = base_mapper.m_num_patches;
+
+ m_patch_plane_stride = base_mapper.m_patch_plane_stride;
+ m_patch_row_stride = base_mapper.m_patch_row_stride;
+ m_patch_col_stride = base_mapper.m_patch_col_stride;
+
+ m_rowStride = base_mapper.m_rowStride;
+ m_colStride = base_mapper.m_colStride;
+ m_patchStride = base_mapper.m_patchStride;
+ m_otherStride = base_mapper.m_otherStride;
+
+ m_planeInputStride = base_mapper.m_planeInputStride;
+ m_rowInputStride = base_mapper.m_rowInputStride;
+ m_colInputStride = base_mapper.m_colInputStride;
+ m_patchInputStride = base_mapper.m_patchInputStride;
+ m_otherInputStride = base_mapper.m_otherInputStride;
+
+ m_inputDepth = base_mapper.m_inputDepth;
+ m_inputPlanes = base_mapper.m_inputPlanes;
+ m_inputRows = base_mapper.m_inputRows;
+ m_inputCols = base_mapper.m_inputCols;
+
+ m_outputPlanes = base_mapper.m_outputPlanes;
+ m_outputRows = base_mapper.m_outputRows;
+ m_outputCols = base_mapper.m_outputCols;
+
+ m_plane_strides = base_mapper.m_plane_strides;
+ m_row_strides = base_mapper.m_row_strides;
+ m_col_strides = base_mapper.m_col_strides;
+
+ m_in_plane_strides = base_mapper.m_in_plane_strides;
+ m_in_row_strides = base_mapper.m_in_row_strides;
+ m_in_col_strides = base_mapper.m_in_col_strides;
+
+ m_patch_plane_inflate_strides = base_mapper.m_patch_plane_inflate_strides;
+ m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
+ m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
+
+ m_planePaddingTop = base_mapper.m_planePaddingTop;
+ m_rowPaddingTop = base_mapper.m_rowPaddingTop;
+ m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+
+ m_outputPlanesRows = base_mapper.m_outputPlanesRows;
+
+ m_fastNumPatches = base_mapper.m_fastNumPatches;
+ m_fastPatchPlaneStride = base_mapper.m_fastPatchPlaneStride;
+ m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
+ m_fastPatchColStride = base_mapper.m_fastPatchColStride;
+ m_fastInputPlaneStride = base_mapper.m_fastInputPlaneStride;
+ m_fastInputRowStride = base_mapper.m_fastInputRowStride;
+ m_fastInputColStride = base_mapper.m_fastInputColStride;
+ m_fastRowStride = base_mapper.m_fastRowStride;
+ m_fastColStride = base_mapper.m_fastColStride;
+ m_fastOutputPlanes = base_mapper.m_fastOutputPlanes;
+ m_fastOutputRows = base_mapper.m_fastOutputRows;
+ m_fastOutputCols = base_mapper.m_fastOutputCols;
+ m_fastDimZero = base_mapper.m_fastDimZero;
+ m_fastOutputPlanesRows = base_mapper.m_fastOutputPlanesRows;
+ }
+
+ // If this returns true, some optimizations for loading packets are turned
+ // off, because the image patches are "non-standard": there are non-trivial
+ // strides or inflations in the input.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+ return m_in_plane_strides != 1 || m_in_row_strides != 1 ||
+ m_in_col_strides != 1 || m_patch_plane_inflate_strides != 1 ||
+ m_patch_row_inflate_strides != 1 || m_patch_col_inflate_strides != 1;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE SubMapper getSubMapper(Index i, Index j) const {
+ return SubMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE LinearMapper getLinearMapper(Index i, Index j) const {
+ return LinearMapper(*this, i, j);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Scalar operator()(Index row) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ // Load the coefficient at the patchIndex location instead of the usual
+ // m_rowIndex, m_colIndex, m_otherIndex. This is currently only used by the
+ // gpu code.
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar operator()(Index row, Index patchIndex) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadCoeff(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index row) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(0, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ // Load the packet at the patchIndex location instead of the usual m_rowIndex,
+ // m_colIndex, m_otherIndex. This is currently only used by the gpu code.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index row, Index patchIndex) const {
+ Index planeIndex, rowIndex, colIndex, otherIndex;
+ computeBaseIndices(patchIndex, planeIndex, rowIndex, colIndex, otherIndex);
+ return loadPacket(row, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE const TensorEvaluator<ArgType, Device>& impl() const {
+ return m_impl;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchDepth() const { return m_planeInputStride; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlanes() const { return m_rowStride; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRows() const { return m_patch_rows; }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchCols() const { return m_patch_cols; }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
+ const Index baseIndex) const {
+ const Index inputIndex = depth + baseIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ private:
+ friend class TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>;
+
+ // Load coefficient from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index inputCol = colIndex + colOffset * m_in_col_strides;
+ const Index origInputCol =
+ (m_patch_col_inflate_strides == 1)
+ ? inputCol
+ : ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
+
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
+ const Index origInputRow =
+ (m_patch_row_inflate_strides == 1)
+ ? inputRow
+ : ((inputRow >= 0) ? (inputRow / m_fastInputRowStride) : 0);
+
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+ const Index inputPlane = planeIndex + planeOffset * m_in_plane_strides;
+ const Index origInputPlane =
+ (m_patch_plane_inflate_strides == 1)
+ ? inputPlane
+ : ((inputPlane >= 0) ? (inputPlane / m_fastInputPlaneStride) : 0);
+
+ if (origInputCol < 0 || origInputRow < 0 || origInputPlane < 0 ||
+ origInputCol >= m_inputCols || origInputRow >= m_inputRows ||
+ origInputPlane >= m_inputPlanes ||
+ (inputCol != origInputCol * m_patch_col_inflate_strides) ||
+ (inputRow != origInputRow * m_patch_row_inflate_strides) ||
+ (inputPlane != origInputPlane * m_patch_plane_inflate_strides)) {
+ return Scalar(0);
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + origInputPlane * m_planeInputStride +
+ origInputRow * m_rowInputStride +
+ origInputCol * m_colInputStride + otherIndex;
+
+ return m_impl.coeff(inputIndex);
+ }
+
+ // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
+ // and `in_strides` equal to 1 (template specialization without templates).
+ EIGEN_DEVICE_FUNC
+ EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ eigen_assert(!nonStandardPatches());
+
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+
+ const Index inputCol = colIndex + colOffset;
+ const Index inputRow = rowIndex + rowOffset;
+ const Index inputPlane = planeIndex + planeOffset;
+
+ if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
+ inputRow >= m_inputRows || inputPlane < 0 ||
+ inputPlane >= m_inputPlanes) {
+ return Scalar(0);
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + inputPlane * m_planeInputStride +
+ inputRow * m_rowInputStride +
+ inputCol * m_colInputStride + otherIndex;
+
+ return m_impl.coeff(inputIndex);
+ }
+
+ // Load packet from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+
+ if (nonStandardPatches()) {
+ return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+ return loadPacketStandard(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacketStandard(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+ eigen_assert(!nonStandardPatches());
+
+ if ((patchDepth() % packetSize) == 0) {
+ return loadPacketFast(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ } else {
+ // Offsets and input calculation here are identical to
+ // loadCoeffStandard(...), but repeated twice.
+
+ const Index patchOffsets[2] = {
+ patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
+
+ const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
+ patchOffsets[1] / m_fastColStride};
+ eigen_assert(colOffsets[0] <= colOffsets[1]);
+
+ const Index inputCols[2] = {colIndex + colOffsets[0],
+ colIndex + colOffsets[1]};
+ if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputCols[0] == inputCols[1]) {
+ const Index rowOffsets[2] = {
+ (patchOffsets[0] - colOffsets[0] * m_colStride) / m_fastRowStride,
+ (patchOffsets[1] - colOffsets[1] * m_colStride) / m_fastRowStride};
+ eigen_assert(rowOffsets[0] <= rowOffsets[1]);
+ const Index inputRows[2] = {rowIndex + rowOffsets[0],
+ rowIndex + rowOffsets[1]};
+
+ if (inputRows[0] >= m_inputRows || inputRows[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputRows[0] == inputRows[1]) {
+ const Index planeOffsets[2] = {
+ patchOffsets[0] - colOffsets[0] * m_colStride -
+ rowOffsets[0] * m_rowStride,
+ patchOffsets[1] - colOffsets[1] * m_colStride -
+ rowOffsets[1] * m_rowStride};
+ eigen_assert(planeOffsets[0] <= planeOffsets[1]);
+ const Index inputPlanes[2] = {planeIndex + planeOffsets[0],
+ planeIndex + planeOffsets[1]};
+
+ if (inputPlanes[0] >= m_inputPlanes || inputPlanes[1] < 0) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ if (inputPlanes[0] >= 0 && inputPlanes[1] < m_inputPlanes) {
+ const Index depth = patchId - patchOffsets[0] * patchDepth();
+ const Index inputIndex =
+ depth + inputPlanes[0] * m_planeInputStride +
+ inputRows[0] * m_rowInputStride +
+ inputCols[0] * m_colInputStride + otherIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+ }
+ }
+ }
+
+ return packetWithPossibleZero(patchId, planeIndex, rowIndex, colIndex,
+ otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index patchId, Index planeIndex,
+ Index rowIndex, Index colIndex,
+ Index otherIndex) const {
+ const Index packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_STATIC_ASSERT(packetSize > 1, YOU_MADE_A_PROGRAMMING_MISTAKE)
+ eigen_assert(patchId <
+ patchDepth() * patchPlanes() * patchRows() * patchCols());
+
+ eigen_assert(!nonStandardPatches());
+ eigen_assert((patchDepth() % packetSize) == 0);
+
+ // Find the offset of the element wrt the location of the first element.
+ const Index patchOffset = patchId / m_fastDimZero;
+ eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
+
+ const Index colOffset = patchOffset / m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_colStride) / m_fastRowStride;
+ const Index planeOffset =
+ patchOffset - colOffset * m_colStride - rowOffset * m_rowStride;
+
+ const Index inputCol = colIndex + colOffset;
+ const Index inputRow = rowIndex + rowOffset;
+ const Index inputPlane = planeIndex + planeOffset;
+
+ if (inputCol < 0 || inputRow < 0 || inputPlane < 0 ||
+ inputCol >= m_inputCols || inputRow >= m_inputRows ||
+ inputPlane >= m_inputPlanes) {
+ return internal::pset1<Packet>(Scalar(0));
+ }
+
+ const Index depth = patchId - patchOffset * patchDepth();
+ const Index inputIndex = depth + inputPlane * m_planeInputStride +
+ inputRow * m_rowInputStride +
+ inputCol * m_colInputStride + otherIndex;
+ return m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
+ packetWithPossibleZero(Index patchId, Index planeIndex, Index rowIndex,
+ Index colIndex, Index otherIndex) const {
+ const int packetSize = internal::unpacket_traits<Packet>::size;
+ EIGEN_ALIGN_MAX
+ typename internal::remove_const<Scalar>::type values[packetSize];
+ for (int i = 0; i < packetSize; ++i) {
+ values[i] =
+ loadCoeff(patchId + i, planeIndex, rowIndex, colIndex, otherIndex);
+ }
+ Packet rslt = internal::pload<Packet>(values);
+ return rslt;
+ }
+
+ // Precompute the indices (plane, row, col, other) of the first element of
+ // the patch with the given patch index, within the output tensor of the
+ // TensorVolumePatchOp.
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
+ Index patchIndex, Index& planeIndex, Index& rowIndex, Index& colIndex,
+ Index& otherIndex) const {
+ const size_t NumInputDims = array_size<
+ typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
+
+ // Check if patchIndex might contain batch and other dimensions.
+ otherIndex = (NumInputDims == 4) ? 0 : patchIndex / m_fastNumPatches;
+
+ // Compute index of the patch within the batch (and other dimensions).
+ const Index patch3DIndex = (NumInputDims == 4)
+ ? patchIndex
+ : (patchIndex - otherIndex * m_num_patches);
+
+ otherIndex *= m_patchInputStride;
+
+ colIndex = patch3DIndex / m_fastOutputPlanesRows;
+ rowIndex =
+ (patch3DIndex - colIndex * m_outputPlanesRows) / m_fastOutputPlanes;
+ planeIndex =
+ patch3DIndex - (colIndex * m_outputRows + rowIndex) * m_outputPlanes;
+
+ colIndex = colIndex * m_col_strides - m_colPaddingLeft;
+ rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
+ planeIndex = planeIndex * m_plane_strides - m_planePaddingTop;
+ }
+
+ Index m_patch_depth; // number of channels in the patch
+ Index m_patch_planes; // number of planes in the patch
+ Index m_patch_rows; // number of rows in the patch
+ Index m_patch_cols; // number of columns in the patch
+ Index m_num_patches; // number of patches to extract
+
+ // Strides for navigating within a single patch.
+ Index m_patch_plane_stride;
+ Index m_patch_row_stride;
+ Index m_patch_col_stride;
+
+ // Strides for the output tensor (depth is not the part of the stride).
+ Index m_rowStride;
+ Index m_colStride;
+ Index m_patchStride;
+ Index m_otherStride;
+
+ Index m_planeInputStride; // Plane stride in the input tensor
+ Index m_rowInputStride; // Row stride in the input tensor
+ Index m_colInputStride; // Col stride in the input tensor
+ Index m_patchInputStride; // Patch stride in the input tensor
+ Index m_otherInputStride;
+
+ Index m_inputDepth; // Depth of the input tensor
+ Index m_inputPlanes; // Number of planes in the input tensor
+ Index m_inputRows; // Number of rows in the input tensor
+ Index m_inputCols; // Number of cols in the input tensor
+
+ Index m_outputPlanes; // Number of output planes
+ Index m_outputRows; // Number of output rows
+ Index m_outputCols; // Number of output cols
+ Index m_outputPlanesRows; // Cached outputPlanes * outputRows.
+
+ Index m_plane_strides; // User specified plane stride
+ Index m_row_strides; // User specified row stride
+ Index m_col_strides; // User specified col stride
+
+ // User specified plane/row/col atrous convolution strides.
+ Index m_in_plane_strides;
+ Index m_in_row_strides;
+ Index m_in_col_strides;
+
+ // User specified plane/row/col inflation strides in the image patch.
+ Index m_patch_plane_inflate_strides;
+ Index m_patch_row_inflate_strides;
+ Index m_patch_col_inflate_strides;
+
+ Index m_planePaddingTop; // Plane padding
+ Index m_rowPaddingTop; // Row padding
+ Index m_colPaddingLeft; // Column padding
+
+ // Fast representation of various divisors.
+ internal::TensorIntDivisor<Index> m_fastNumPatches;
+
+ internal::TensorIntDivisor<Index> m_fastPatchPlaneStride;
+ internal::TensorIntDivisor<Index> m_fastPatchRowStride;
+ internal::TensorIntDivisor<Index> m_fastPatchColStride;
+
+ internal::TensorIntDivisor<Index> m_fastInputPlaneStride;
+ internal::TensorIntDivisor<Index> m_fastInputRowStride;
+ internal::TensorIntDivisor<Index> m_fastInputColStride;
+
+ internal::TensorIntDivisor<Index> m_fastRowStride;
+ internal::TensorIntDivisor<Index> m_fastColStride;
+
+ internal::TensorIntDivisor<Index> m_fastDimZero; // aka output depth
+ internal::TensorIntDivisor<Index> m_fastOutputPlanes;
+ internal::TensorIntDivisor<Index> m_fastOutputRows;
+ internal::TensorIntDivisor<Index> m_fastOutputCols;
+ internal::TensorIntDivisor<Index> m_fastOutputPlanesRows;
+
+ const TensorEvaluator<ArgType, Device> m_impl;
+};
+
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t, int Side,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment>
+class TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<NewDimension,
+ const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment> {
+ public:
+ typedef typename packet_traits<Scalar>::type Packet;
+ typedef typename packet_traits<Scalar>::half HalfPacket;
+
+ typedef TensorContractionInputMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ ParentMapper;
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Side,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ Self;
+ typedef Self LinearMapper;
+
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
+ const ParentMapper& base_mapper, Index vert_offset, Index horiz_offset)
+ : m_base_mapper(base_mapper),
+ m_depth_offset(vert_offset),
+ m_col_offset(horiz_offset) {
+ m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE TensorContractionSubMapper(
+ const Self& base_mapper, Index vert_offset, Index horiz_offset)
+ : m_base_mapper(base_mapper.m_base_mapper),
+ m_depth_offset(vert_offset + base_mapper.m_depth_offset),
+ m_col_offset(horiz_offset + base_mapper.m_col_offset) {
+ m_base_mapper.computeBaseIndices(m_col_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i) const {
+ return m_base_mapper.loadCoeff(i + m_depth_offset, m_planeIndex, m_rowIndex,
+ m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar operator()(Index i,
+ Index j) const {
+ return m_base_mapper(i + m_depth_offset, j + m_col_offset);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i) const {
+ return m_base_mapper.loadPacket(i + m_depth_offset, m_planeIndex,
+ m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacket(Index i,
+ Index j) const {
+ return m_base_mapper.template loadPacket<Alignment>(i + m_depth_offset,
+ j + m_col_offset);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar
+ loadCoeffStandard(Index i) const {
+ return m_base_mapper.loadCoeffStandard(
+ i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet loadPacketFast(Index i) const {
+ return m_base_mapper.loadPacketFast(i + m_depth_offset, m_planeIndex,
+ m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet
+ loadPacketStandard(Index i) const {
+ return m_base_mapper.loadPacketStandard(
+ i + m_depth_offset, m_planeIndex, m_rowIndex, m_colIndex, m_otherIndex);
+ }
+ template <typename Packet>
+ EIGEN_DEVICE_FUNC bool aligned(Index) const {
+ return false;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool nonStandardPatches() const {
+ return m_base_mapper.nonStandardPatches();
+ }
+
+ // Max(Col|Row|Plane|Depth): compute the upper limits for the column, row,
+ // plane and depth indices, respectively, that fit into the peeled_k elements
+ // starting at m_depth_offset.
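+ //
+ // Illustrative example (arbitrary sizes): with patch depth 4, 2 planes,
+ // 3 rows and 5 columns, one full patch column holds 4 * 2 * 3 == 24
+ // elements, so for m_depth_offset == 0 and peeled_k == 40 the peeled
+ // elements cover column 0 and part of column 1, and maxCol(peeled_k)
+ // returns min(1 + 40 / 24, 5) == 2.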
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
+ const Index max_col =
+ fastPatchColStride().divide(m_depth_offset + peeled_k);
+ return std::min<Index>(1 + max_col, patchCols());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
+ const Index col) const {
+ const Index max_row = fastPatchRowStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride());
+ return std::min<Index>(1 + max_row, patchRows());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxPlane(const Index peeled_k, const Index col,
+ const Index row) const {
+ const Index max_plane = fastPatchPlaneStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride() -
+ row * patchRowStride());
+ return std::min<Index>(1 + max_plane, patchPlanes());
+ }
+
+ // MaxDepth uses only the remaining number of elements in the peeled_k.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
+ const Index start_depth) const {
+ return std::min<Index>(start_depth + num_elements, patchDepth());
+ }
+
+ // Every register matters in this code, so sometimes to prevent register
+ // spilling, instead of the variable that you would expect to see, we use
+ // another one that is guaranteed to have the same value. E.g. patch depth is
+ // always the same as input depth, and it's also the same as input plane
+ // stride. A number of other parameters have similar relations.
+
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchDepth() const {
+ eigen_assert(m_base_mapper.m_patch_depth ==
+ m_base_mapper.m_planeInputStride &&
+ "Patch depth must be equal to plane input stride.");
+ return m_base_mapper.m_planeInputStride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlanes() const {
+ eigen_assert(m_base_mapper.m_patch_planes == m_base_mapper.m_rowStride &&
+ "Patch planes must be equal to row stride.");
+ return m_base_mapper.m_rowStride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRows() const {
+ return m_base_mapper.m_patch_rows;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchCols() const {
+ return m_base_mapper.m_patch_cols;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchPlaneStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
+ "Patch depth must be equal to patch plane stride.");
+ return patchDepth();
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRowStride() const {
+ return m_base_mapper.m_patch_row_stride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchColStride() const {
+ return m_base_mapper.m_patch_col_stride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchPlaneStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_plane_stride &&
+ "Patch depth must be equal to patch plane stride.");
+ return m_base_mapper.m_fastDimZero; // patch_depth
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
+ return m_base_mapper.m_fastPatchRowStride;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
+ return m_base_mapper.m_fastPatchColStride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
+ const Index baseIndex) const {
+ const Index inputIndex = depth + baseIndex;
+ return m_base_mapper.m_impl.template packet<Unaligned>(inputIndex);
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padPlane(const Index plane) const {
+ const Index p = m_planeIndex + plane;
+ return p < 0 || p >= m_base_mapper.m_inputPlanes;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padRow(const Index row) const {
+ const Index r = m_rowIndex + row;
+ return r < 0 || r >= m_base_mapper.m_inputRows;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE bool padCol(const Index col) const {
+ const Index c = m_colIndex + col;
+ return c < 0 || c >= m_base_mapper.m_inputCols;
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index baseIndex(const Index plane, const Index row,
+ const Index col) const {
+ const Index p = m_planeIndex + plane;
+ const Index r = m_rowIndex + row;
+ const Index c = m_colIndex + col;
+ return p * m_base_mapper.m_planeInputStride +
+ r * m_base_mapper.m_rowInputStride +
+ c * m_base_mapper.m_colInputStride + m_otherIndex;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index planeOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_base_mapper.m_colStride) /
+ m_base_mapper.m_fastRowStride;
+ const Index planeOffset = patchOffset -
+ colOffset * m_base_mapper.m_colStride -
+ rowOffset * m_base_mapper.m_rowStride;
+ return planeOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index rowOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ const Index rowOffset =
+ (patchOffset - colOffset * m_base_mapper.m_colStride) /
+ m_base_mapper.m_fastRowStride;
+ return rowOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index colOffset() const {
+ const Index patchOffset = m_depth_offset / m_base_mapper.m_fastDimZero;
+ const Index colOffset = patchOffset / m_base_mapper.m_fastColStride;
+ return colOffset;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index depthOffset() const {
+ return m_depth_offset % patchDepth();
+ }
+
+ EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
+ getLinearMapper(Index i, Index j) const {
+ return LinearMapper(m_base_mapper, i + m_depth_offset, j + m_col_offset);
+ }
+
+ private:
+ const ParentMapper& m_base_mapper;
+ Index m_depth_offset; // First row in the input matrix
+ Index m_col_offset; // First col in the input matrix
+
+ // Knowing that col_offset == patchIndex * OTHERS, we keep precomputed base
+ // indices for the first element in a patch specified by col_offset
+ // (see computeBaseIndices(...) for details).
+ Index m_planeIndex;
+ Index m_rowIndex;
+ Index m_colIndex;
+ Index m_otherIndex;
+};
+
+// Arrange a block of the right input matrix (in our case it's always a "virtual
+// matrix" constructed from extracted volume patches) in contiguous memory.
+//
+// Given column major input (A0 beside A1 in memory):
+// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
+// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
+// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
+// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
+// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
+// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
+// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
+// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
+// A8 ...
+// ...
+//
+// *) A, B, C, ... - patches extracted from the original input.
+// *) A0, A1, A2 ... - values from the same patch at different offsets.
+//
+// The traversal (packed rhs memory) order (B0 besides A0 in memory):
+// A0 B0 C0 D0 A1 B1 C1 D1 ...
+// E0 F0 G0 H0 E1 F1 G1 H1 ...
+// ...
+// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
+//
+// This traversal order must be the same as in the default gemm_pack_rhs
+// defined in GeneralBlockPanelKernel.h.
+//
+// *) nr - number of registers along the 'n' dimension.
+// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
+// Multiplication" paper.
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ int packet_size, bool inner_dim_contiguous, bool inner_dim_reordered,
+ int Alignment, int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, packet_size, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+
+ typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ const Index packet_cols4 = (cols / 4) * 4;
+ const Index peeled_k = (depth / packet_size) * packet_size;
+ const bool non_standard_patches = rhs.nonStandardPatches();
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ Index k = 0;
+ if ((packet_size % 4) == 0 && !non_standard_patches) {
+ // FAST PATH:
+ // Iterate over patch columns, rows and planes if we know that a single
+ // packet does not span across multiple planes, rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
+
+ const bool pad_col0 = dm0.padCol(c);
+ const bool pad_col1 = dm1.padCol(c);
+ const bool pad_col2 = dm2.padCol(c);
+ const bool pad_col3 = dm3.padCol(c);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_plane = ((c == start_col) && (r == start_row))
+ ? rhs.planeOffset()
+ : 0;
+ const Index max_plane = rhs.maxPlane(peeled_k, c, r);
+
+ const bool pad_row0 = pad_col0 || dm0.padRow(r);
+ const bool pad_row1 = pad_col1 || dm1.padRow(r);
+ const bool pad_row2 = pad_col2 || dm2.padRow(r);
+ const bool pad_row3 = pad_col3 || dm3.padRow(r);
+
+ for (Index p = start_plane; p < max_plane; ++p) {
+ eigen_assert(k <= peeled_k);
+
+ const bool pad0 = pad_row0 || dm0.padPlane(p);
+ const bool pad1 = pad_row1 || dm1.padPlane(p);
+ const bool pad2 = pad_row2 || dm2.padPlane(p);
+ const bool pad3 = pad_row3 || dm3.padPlane(p);
+
+ const Index idx0 = dm0.baseIndex(p, r, c);
+ const Index idx1 = dm1.baseIndex(p, r, c);
+ const Index idx2 = dm2.baseIndex(p, r, c);
+ const Index idx3 = dm3.baseIndex(p, r, c);
+
+ const Index start_depth =
+ ((c == start_col) && (r == start_row) && (p == start_plane))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx0);
+ kernel.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx1);
+ kernel.packet[2] = pad2 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx2);
+ kernel.packet[3] = pad3 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx3);
+ ptranspose(kernel);
+ pstoreu(block + 0 * packet_size, kernel.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel.packet[1]);
+ pstoreu(block + 2 * packet_size, kernel.packet[2]);
+ pstoreu(block + 3 * packet_size, kernel.packet[3]);
+ block += 4 * packet_size;
+ k += packet_size;
+ }
+ }
+ }
+ }
+
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
+ } else {
+        // A packet can span multiple planes, rows or columns, so we have to go
+        // through the slower "standard" path.
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 4> kernel;
+ kernel.packet[0] = dm0.loadPacketStandard(k);
+ kernel.packet[1] = dm1.loadPacketStandard(k);
+ kernel.packet[2] = dm2.loadPacketStandard(k);
+ kernel.packet[3] = dm3.loadPacketStandard(k);
+ ptranspose(kernel);
+ pstoreu(block + 0 * packet_size, kernel.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel.packet[1]);
+ pstoreu(block + 2 * packet_size, kernel.packet[2]);
+ pstoreu(block + 3 * packet_size, kernel.packet[3]);
+ block += 4 * packet_size;
+ }
+ }
+ }
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!non_standard_patches) {
+ for (; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // Copy the remaining columns one at a time (nr==1).
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+// Template specialization for packet_size = 2. We must special-case packet
+// blocks with nr > packet_size, e.g. PacketBlock<Packet2d, 4>.
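+//
+// As a rough sketch, assuming Packet2d (packet_size == 2): kernel0 holds one
+// packet from column A and one from column B, kernel1 holds packets from
+// columns C and D. After the two 2x2 ptranspose calls, kernel0.packet[0] ==
+// [A0 B0], kernel1.packet[0] == [C0 D0], kernel0.packet[1] == [A1 B1] and
+// kernel1.packet[1] == [C1 D1], so storing them in the order kernel0[0],
+// kernel1[0], kernel0[1], kernel1[1] reproduces the A0 B0 C0 D0 A1 B1 C1 D1
+// layout required above.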
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+ int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 2, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>
+ SubMapper;
+ typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ const int packet_size = 2;
+
+ const Index packet_cols4 = (cols / 4) * 4;
+ const Index peeled_k = (depth / packet_size) * packet_size;
+ const bool non_standard_patches = rhs.nonStandardPatches();
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ Index k = 0;
+ if (!non_standard_patches) {
+ // FAST PATH:
+      // Iterate over patch columns, rows and planes if we know that a single
+      // packet does not span multiple planes, rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
+
+ const bool pad_col0 = dm0.padCol(c);
+ const bool pad_col1 = dm1.padCol(c);
+ const bool pad_col2 = dm2.padCol(c);
+ const bool pad_col3 = dm3.padCol(c);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_plane = ((c == start_col) && (r == start_row))
+ ? rhs.planeOffset()
+ : 0;
+ const Index max_plane = rhs.maxPlane(peeled_k, c, r);
+
+ const bool pad_row0 = dm0.padRow(r);
+ const bool pad_row1 = dm1.padRow(r);
+ const bool pad_row2 = dm2.padRow(r);
+ const bool pad_row3 = dm3.padRow(r);
+
+ for (Index p = start_plane; p < max_plane; ++p) {
+ eigen_assert(k <= peeled_k);
+
+ const bool pad0 = pad_col0 || pad_row0 || dm0.padPlane(p);
+ const bool pad1 = pad_col1 || pad_row1 || dm1.padPlane(p);
+ const bool pad2 = pad_col2 || pad_row2 || dm2.padPlane(p);
+ const bool pad3 = pad_col3 || pad_row3 || dm3.padPlane(p);
+
+ const Index idx0 = dm0.baseIndex(p, r, c);
+ const Index idx1 = dm1.baseIndex(p, r, c);
+ const Index idx2 = dm2.baseIndex(p, r, c);
+ const Index idx3 = dm3.baseIndex(p, r, c);
+
+ const Index start_depth =
+ ((c == start_col) && (r == start_row) && (p == start_plane))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
+ eigen_assert(k < peeled_k);
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx0);
+ kernel0.packet[1] = pad1 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx1);
+ kernel1.packet[0] = pad2 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx2);
+ kernel1.packet[1] = pad3 ? pset1<Packet>(Scalar(0))
+ : rhs.packetNoPadding(d, idx3);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ k += packet_size;
+ }
+ }
+ }
+ }
+
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
+ } else {
+ for (; k < peeled_k; k += packet_size) {
+ PacketBlock<Packet, 2> kernel0;
+ PacketBlock<Packet, 2> kernel1;
+ kernel0.packet[0] = dm0.loadPacketStandard(k);
+ kernel0.packet[1] = dm1.loadPacketStandard(k);
+ kernel1.packet[0] = dm2.loadPacketStandard(k);
+ kernel1.packet[1] = dm3.loadPacketStandard(k);
+ ptranspose(kernel0);
+ ptranspose(kernel1);
+ pstoreu(block + 0 * packet_size, kernel0.packet[0]);
+ pstoreu(block + 1 * packet_size, kernel1.packet[0]);
+ pstoreu(block + 2 * packet_size, kernel0.packet[1]);
+ pstoreu(block + 3 * packet_size, kernel1.packet[1]);
+ block += 4 * packet_size;
+ }
+ }
+ }
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!rhs.nonStandardPatches()) {
+ for (; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // Copy the remaining columns one at a time (nr==1).
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+// Special case for non-vectorized types such as float16 (packet_size = 1).
+template <typename NewDimension, DenseIndex Planes, DenseIndex Rows,
+ DenseIndex Cols, typename ArgType, typename Device, typename Scalar,
+ typename Index, typename nocontract_t, typename contract_t,
+ bool inner_dim_contiguous, bool inner_dim_reordered, int Alignment,
+ int nr>
+struct gemm_pack_rhs<
+ Scalar, Index,
+ TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, /*packet_size*/ 1, inner_dim_contiguous,
+ inner_dim_reordered, Alignment>,
+ nr, ColMajor, false, false> {
+ typedef TensorContractionSubMapper<
+ Scalar, Index, Rhs,
+ TensorEvaluator<const TensorReshapingOp<
+ NewDimension, const TensorVolumePatchOp<
+ Planes, Rows, Cols, ArgType> >,
+ Device>,
+ nocontract_t, contract_t, 1, inner_dim_contiguous, inner_dim_reordered,
+ Alignment>
+ SubMapper;
+ typedef SubMapper DataMapper;
+
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
+ Index depth, Index cols, Index stride = 0,
+ Index offset = 0) const {
+ eigen_assert(stride == 0);
+ eigen_assert(offset == 0);
+
+ const Index packet_cols4 = (cols / 4) * 4;
+
+ for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2 + 0);
+ const SubMapper dm1 = rhs.getLinearMapper(0, j2 + 1);
+ const SubMapper dm2 = rhs.getLinearMapper(0, j2 + 2);
+ const SubMapper dm3 = rhs.getLinearMapper(0, j2 + 3);
+
+ if (!rhs.nonStandardPatches()) {
+ for (Index k = 0; k < depth; k++) {
+ block[0] = dm0.loadCoeffStandard(k);
+ block[1] = dm1.loadCoeffStandard(k);
+ block[2] = dm2.loadCoeffStandard(k);
+ block[3] = dm3.loadCoeffStandard(k);
+ block += 4;
+ }
+ } else {
+ for (Index k = 0; k < depth; k++) {
+ block[0] = dm0(k);
+ block[1] = dm1(k);
+ block[2] = dm2(k);
+ block[3] = dm3(k);
+ block += 4;
+ }
+ }
+ }
+
+ // Copy the remaining columns one at a time (nr==1).
+ for (Index j2 = packet_cols4; j2 < cols; ++j2) {
+ const SubMapper dm0 = rhs.getLinearMapper(0, j2);
+ for (Index k = 0; k < depth; k++) {
+ *block = dm0(k);
+ block += 1;
+ }
+ }
+ }
+};
+
+} // namespace internal
+
/** CuboidConvolution
* \ingroup CXX11_NeuralNetworks_Module
*
@@ -98,7 +1504,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
isColMajor ? kern.dimensions()[1] : kern.dimensions()[3];
// Spatial size of the kernel.
- const TensorIndex kernelDepth =
+ const TensorIndex kernelPlanes =
isColMajor ? kern.dimensions()[2] : kern.dimensions()[2];
const TensorIndex kernelRows =
isColMajor ? kern.dimensions()[3] : kern.dimensions()[1];
@@ -118,27 +1524,27 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
const TensorIndex inputCols =
isColMajor ? in.dimension(3) : in.dimension(NumDims - 4);
- TensorIndex out_depth;
+ TensorIndex out_planes;
TensorIndex out_height;
TensorIndex out_width;
switch (padding_type) {
case PADDING_VALID:
- out_depth = Eigen::divup(inputPlanes - kernelDepth + 1,
- static_cast<TensorIndex>(stridePlanes));
+ out_planes = Eigen::divup(inputPlanes - kernelPlanes + 1,
+ static_cast<TensorIndex>(stridePlanes));
out_height = Eigen::divup(inputRows - kernelRows + 1,
static_cast<TensorIndex>(strideRows));
out_width = Eigen::divup(inputCols - kernelCols + 1,
static_cast<TensorIndex>(strideCols));
break;
case PADDING_SAME:
- out_depth =
+ out_planes =
Eigen::divup(inputPlanes, static_cast<TensorIndex>(stridePlanes));
out_height =
Eigen::divup(inputRows, static_cast<TensorIndex>(strideRows));
out_width = Eigen::divup(inputCols, static_cast<TensorIndex>(strideCols));
break;
default:
- out_depth = 0;
+ out_planes = 0;
out_height = 0;
out_width = 0;
eigen_assert(false && "unexpected padding");
@@ -147,9 +1553,9 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, 2> kernel_dims;
if (isColMajor) {
kernel_dims[0] = kernelFilters;
- kernel_dims[1] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ kernel_dims[1] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
} else {
- kernel_dims[0] = kernelChannels * kernelDepth * kernelRows * kernelCols;
+ kernel_dims[0] = kernelChannels * kernelPlanes * kernelRows * kernelCols;
kernel_dims[1] = kernelFilters;
}
@@ -160,15 +1566,15 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, 2> pre_contract_dims;
if (isColMajor) {
pre_contract_dims[0] =
- kernelChannels * kernelDepth * kernelRows * kernelCols;
- pre_contract_dims[1] = out_depth * out_height * out_width;
+ kernelChannels * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[1] = out_planes * out_height * out_width;
for (int i = 4; i < NumDims; ++i) {
pre_contract_dims[1] *= in.dimension(i);
}
} else {
pre_contract_dims[1] =
- kernelChannels * kernelDepth * kernelRows * kernelCols;
- pre_contract_dims[0] = out_depth * out_height * out_width;
+ kernelChannels * kernelPlanes * kernelRows * kernelCols;
+ pre_contract_dims[0] = out_planes * out_height * out_width;
for (int i = 0; i < NumDims - 4; ++i) {
pre_contract_dims[0] *= in.dimension(i);
}
@@ -187,7 +1593,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
DSizes<TensorIndex, NumDims> post_contract_dims;
if (isColMajor) {
post_contract_dims[0] = kernelFilters;
- post_contract_dims[1] = out_depth;
+ post_contract_dims[1] = out_planes;
post_contract_dims[2] = out_height;
post_contract_dims[3] = out_width;
for (int i = 4; i < NumDims; ++i) {
@@ -195,7 +1601,7 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
}
} else {
post_contract_dims[NumDims - 1] = kernelFilters;
- post_contract_dims[NumDims - 2] = out_depth;
+ post_contract_dims[NumDims - 2] = out_planes;
post_contract_dims[NumDims - 3] = out_height;
post_contract_dims[NumDims - 4] = out_width;
for (int i = 0; i < NumDims - 4; ++i) {
@@ -208,13 +1614,13 @@ CuboidConvolution(const Input& input, const Kernel& kernel,
kernel.reshape(kernel_dims)
.contract(input
.extract_volume_patches(
- kernelDepth, kernelRows, kernelCols, stridePlanes,
+ kernelPlanes, kernelRows, kernelCols, stridePlanes,
strideRows, strideCols, padding_type)
.reshape(pre_contract_dims),
contract_dims)
.reshape(post_contract_dims),
input
- .extract_volume_patches(kernelDepth, kernelRows, kernelCols,
+ .extract_volume_patches(kernelPlanes, kernelRows, kernelCols,
stridePlanes, strideRows, strideCols,
padding_type)
.reshape(pre_contract_dims)
diff --git a/tensorflow/core/kernels/eigen_spatial_convolutions.h b/tensorflow/core/kernels/eigen_spatial_convolutions.h
index a4dff4b91c..e926d73f87 100644
--- a/tensorflow/core/kernels/eigen_spatial_convolutions.h
+++ b/tensorflow/core/kernels/eigen_spatial_convolutions.h
@@ -22,8 +22,36 @@ namespace Eigen {
namespace internal {
-// TODO: Consolidate this part of the code with the image patch extraction code
-// since they are both very similar.
+// WARNING: Most of the code here implicitly assumes that the matrix is in
+// ColMajor layout. This is guaranteed by the tensor contraction (see
+// TensorContraction.h).
+//
+// Inside Eigen a tensor contraction is represented by a matrix multiplication.
+// We don't want to actually extract image patches and reshape the result into
+// a matrix (this would require allocating a huge amount of extra memory), so
+// the patch extraction and reshape operations are implicit.
+//
+// TensorContractionInputMapper takes a matrix index and returns the coefficient
+// (or the packet) of the "virtual tensor" that would be at that index if we
+// were to actually reshape the result of patch extraction.
+//
+// TensorContractionSubMapper provides a similar view into the "virtual matrix"
+// at the given vertical and horizontal offsets.
+//
+// "Virtual matrix" dimensions:
+// *0: kernelChannels * kernelRows * kernelCols;
+//  1: out_height * out_width; * OTHERS (e.g. batches, etc...)
+//
+// *) extracted patches are contiguous in memory (innermost dimension assuming
+//    col major layout)
+//
+// With these dimensions:
+// row - offset within a single patch (in code: patchId)
+// col - index of the extracted patch (in code: patchIndex)
+// patchIndex ∈ [0..num_patches * OTHERS] (batch and other dimensions)
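+//
+// As a rough worked example with assumed sizes: for a patch with
+// patch_depth == 3 (channels) and patch_rows == 5, the "within patch" offset
+// decomposes as patchId = depth + 3 * (row + 5 * col), so patchId == 22 maps
+// to depth == 1, row == 2, col == 1. This is the decomposition performed by
+// loadCoeff(...) and loadCoeffStandard(...) below.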
+//
+// TODO(ezhulenev): Consolidate this part of the code with the image patch
+// extraction code since they are both very similar.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
typename ArgType, typename Device, typename Scalar_, typename Index,
typename nocontract_t, typename contract_t, int Side, int packet_size,
@@ -77,12 +105,17 @@ class TensorContractionInputMapper<
m_patch_cols = tensor.impl().dimensions()[2];
m_num_patches = tensor.impl().dimensions()[3];
} else {
- const int NumDims = tensor.impl().dimensions().size();
+ const size_t NumDims = tensor.impl().dimensions().size();
patch_depth = tensor.impl().dimensions()[NumDims - 1];
patch_rows = tensor.impl().dimensions()[NumDims - 2];
m_patch_cols = tensor.impl().dimensions()[NumDims - 3];
m_num_patches = tensor.impl().dimensions()[NumDims - 4];
}
+
+ // Strides for navigating through the single patch.
+ m_patch_row_stride = patch_depth;
+ m_patch_col_stride = patch_rows * m_patch_row_stride;
+
m_patch_row_inflate_strides = tensor.impl().rowInflateStride();
m_patch_col_inflate_strides = tensor.impl().colInflateStride();
@@ -111,6 +144,10 @@ class TensorContractionInputMapper<
m_rowPaddingTop = tensor.impl().rowPaddingTop();
m_colPaddingLeft = tensor.impl().colPaddingLeft();
+ m_fastPatchRowStride =
+ internal::TensorIntDivisor<Index>(m_patch_row_stride);
+ m_fastPatchColStride =
+ internal::TensorIntDivisor<Index>(m_patch_col_stride);
m_fastInputRowStride =
internal::TensorIntDivisor<Index>(m_patch_row_inflate_strides);
m_fastInputColStride =
@@ -126,6 +163,10 @@ class TensorContractionInputMapper<
: m_impl(base_mapper.m_impl) {
m_patch_cols = base_mapper.m_patch_cols;
m_num_patches = base_mapper.m_num_patches;
+
+ m_patch_row_stride = base_mapper.m_patch_row_stride;
+ m_patch_col_stride = base_mapper.m_patch_col_stride;
+
m_patch_row_inflate_strides = base_mapper.m_patch_row_inflate_strides;
m_patch_col_inflate_strides = base_mapper.m_patch_col_inflate_strides;
@@ -148,6 +189,8 @@ class TensorContractionInputMapper<
m_rowPaddingTop = base_mapper.m_rowPaddingTop;
m_colPaddingLeft = base_mapper.m_colPaddingLeft;
+ m_fastPatchRowStride = base_mapper.m_fastPatchRowStride;
+ m_fastPatchColStride = base_mapper.m_fastPatchColStride;
m_fastInputRowStride = base_mapper.m_fastInputRowStride;
m_fastInputColStride = base_mapper.m_fastInputColStride;
m_fastNumPatches = base_mapper.m_fastNumPatches;
@@ -238,6 +281,8 @@ class TensorContractionInputMapper<
nocontract_t, contract_t, packet_size, inner_dim_contiguous,
inner_dim_reordered, Alignment>;
+ // Load coefficient from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar loadCoeff(Index patchId, Index rowIndex,
Index colIndex, Index otherIndex) const {
@@ -250,6 +295,7 @@ class TensorContractionInputMapper<
(m_patch_col_inflate_strides == 1)
? inputCol
: ((inputCol >= 0) ? (inputCol / m_fastInputColStride) : 0);
+
const Index rowOffset = patchOffset - colOffset * m_colStride;
const Index inputRow = rowIndex + rowOffset * m_in_row_strides;
const Index origInputRow =
@@ -268,6 +314,8 @@ class TensorContractionInputMapper<
return m_impl.coeff(inputIndex);
}
+ // This is the same as loadCoeff(...), but optimized for all `inflate_strides`
+ // and `in_strides` equal to 1 (template specialization without templates).
EIGEN_DEVICE_FUNC
EIGEN_STRONG_INLINE Scalar loadCoeffStandard(Index patchId, Index rowIndex,
Index colIndex,
@@ -276,10 +324,9 @@ class TensorContractionInputMapper<
// Find the offset of the element wrt the location of the first element.
const Index patchOffset = patchId / m_fastDimZero;
-
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
const Index rowOffset = patchOffset - colOffset * m_colStride;
+ const Index inputCol = colIndex + colOffset;
const Index inputRow = rowIndex + rowOffset;
if (inputCol < 0 || inputCol >= m_inputCols || inputRow < 0 ||
inputRow >= m_inputRows) {
@@ -291,6 +338,8 @@ class TensorContractionInputMapper<
return m_impl.coeff(inputIndex);
}
+ // Load packet from a patch specified by the "within patch offset"
+ // (patchId) and the precomputed indices of the first element of the patch.
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet loadPacket(Index patchId, Index rowIndex,
Index colIndex,
@@ -318,12 +367,14 @@ class TensorContractionInputMapper<
if ((patchDepth() % packetSize) == 0) {
return loadPacketFast(patchId, rowIndex, colIndex, otherIndex);
} else {
+        // The offset and input index calculations here are identical to
+        // loadCoeffStandard(...), but repeated twice.
+
const Index patchOffsets[2] = {
patchId / m_fastDimZero, (patchId + packetSize - 1) / m_fastDimZero};
const Index colOffsets[2] = {patchOffsets[0] / m_fastColStride,
patchOffsets[1] / m_fastColStride};
-
const Index inputCols[2] = {colIndex + colOffsets[0],
colIndex + colOffsets[1]};
if (inputCols[0] >= m_inputCols || inputCols[1] < 0) {
@@ -371,8 +422,8 @@ class TensorContractionInputMapper<
eigen_assert((patchId + packetSize - 1) / m_fastDimZero == patchOffset);
const Index colOffset = patchOffset / m_fastColStride;
- const Index inputCol = colIndex + colOffset;
const Index rowOffset = patchOffset - colOffset * m_colStride;
+ const Index inputCol = colIndex + colOffset;
const Index inputRow = rowIndex + rowOffset;
if (inputCol < 0 || inputRow < 0 || inputCol >= m_inputCols ||
inputRow >= m_inputRows) {
@@ -401,7 +452,7 @@ class TensorContractionInputMapper<
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void computeBaseIndices(
Index patchIndex, Index& rowIndex, Index& colIndex,
Index& otherIndex) const {
- const int NumInputDims = array_size<
+ const size_t NumInputDims = array_size<
typename TensorEvaluator<ArgType, Device>::Dimensions>::value;
otherIndex = (NumInputDims == 3) ? 0 : patchIndex / m_fastNumPatches;
const Index patch2DIndex = (NumInputDims == 3)
@@ -414,8 +465,15 @@ class TensorContractionInputMapper<
rowIndex = rowIndex * m_row_strides - m_rowPaddingTop;
}
- Index m_patch_cols; // number of colums in the patch
- Index m_num_patches; // number of patches to extract.
+ Index m_patch_cols; // number of columns in the patch
+ Index m_num_patches; // number of patches to extract.
+
+ // Strides for navigating through the single patch.
+ Index m_patch_row_stride;
+ Index m_patch_col_stride;
+ internal::TensorIntDivisor<Index> m_fastPatchRowStride;
+ internal::TensorIntDivisor<Index> m_fastPatchColStride;
+
Index m_patch_row_inflate_strides; // the strides for row inflation in the
// image patch
Index m_patch_col_inflate_strides; // the strides for col inflation in the
@@ -549,6 +607,40 @@ class TensorContractionSubMapper<
return m_base_mapper.nonStandardPatches();
}
+  // Max(Col|Row|Depth): compute the upper limits for the column, row and depth
+  // indices, respectively, that fit into the peeled_k elements starting at
+  // m_depth_offset.
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxCol(const Index peeled_k) const {
+ const Index max_col =
+ fastPatchColStride().divide(m_depth_offset + peeled_k);
+ return std::min<Index>(1 + max_col, patchCols());
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxRow(const Index peeled_k,
+ const Index col) const {
+ const Index max_row = fastPatchRowStride().divide(
+ m_depth_offset + peeled_k - col * patchColStride());
+ return std::min<Index>(1 + max_row, patchRows());
+ }
+
+ // MaxDepth uses only the remaining number of elements in the peeled_k.
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index maxDepth(const Index num_elements,
+ const Index start_depth) const {
+ return std::min<Index>(start_depth + num_elements, patchDepth());
+ }
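+
+  // As a rough worked example with assumed sizes (patchDepth() == 4,
+  // patchRows() == 2, patchCols() == 3, m_depth_offset == 0, peeled_k == 16):
+  // patchRowStride() == 4 and patchColStride() == 8, so maxCol(16) == 3 and
+  // maxRow(16, c) == 2 for c == 0 and c == 1. These are conservative upper
+  // bounds that may overshoot by one when the peeled range ends exactly on a
+  // row/column boundary; the packing loops still stop at the right element
+  // because maxDepth(...) is called with the remaining count (peeled_k - k).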
+
+  // Every register matters in this code, so to prevent register spilling we
+  // sometimes use a different variable that is guaranteed to hold the same
+  // value as the one you would expect to see. E.g. the patch depth is always
+  // the same as the input depth, which is also the same as the input row
+  // stride. A number of other parameters have similar relationships.
+
+ typedef internal::TensorIntDivisor<Index> IndexDivisor;
+
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index patchDepth() const {
return m_base_mapper.m_rowInputStride;
@@ -563,6 +655,28 @@ class TensorContractionSubMapper<
}
EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchRowStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
+ "Patch depth must be equal to patch row stride.");
+ return patchDepth();
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE Index patchColStride() const {
+ return m_base_mapper.m_patch_col_stride;
+ }
+
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchRowStride() const {
+ eigen_assert(patchDepth() == m_base_mapper.m_patch_row_stride &&
+ "Patch depth must be equal to patch row stride.");
+ return m_base_mapper.m_fastDimZero; // patch_depth
+ }
+ EIGEN_DEVICE_FUNC
+ EIGEN_ALWAYS_INLINE IndexDivisor fastPatchColStride() const {
+ return m_base_mapper.m_fastPatchColStride;
+ }
+
+ EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Packet packetNoPadding(const Index depth,
const Index baseIndex) const {
const Index inputIndex = depth + baseIndex;
@@ -603,8 +717,7 @@ class TensorContractionSubMapper<
EIGEN_DEVICE_FUNC
EIGEN_ALWAYS_INLINE Index depthOffset() const {
- const Index patchOffset = m_depth_offset % m_base_mapper.patchDepth();
- return patchOffset;
+ return m_depth_offset % patchDepth();
}
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE LinearMapper
@@ -617,12 +730,44 @@ class TensorContractionSubMapper<
Index m_depth_offset; // First row in the input matrix
Index m_col_offset; // First col in the input matrix
- Index m_rowIndex; // precomputed row index corresponding to the col offset
- Index m_colIndex; // precomputed col index corresponding to the col offset
- Index
- m_otherIndex; // precomputed other index corresponding to the col offset
+  // Because col_offset == patchIndex * OTHERS, we keep precomputed base
+  // indices for the first element of the patch specified by col_offset
+  // (see computeBaseIndices(...) for details).
+ Index m_rowIndex;
+ Index m_colIndex;
+ Index m_otherIndex;
};
+// Arrange a block of the right input matrix (in our case it's always a "virtual
+// matrix" constructed from extracted image patches) in contiguous memory.
+//
+// Given column major input (A0 beside A1 in memory):
+// A0 B0 C0 D0 E0 F0 G0 H0 ... Z0
+// A1 B1 C1 D1 E1 F1 G1 H1 ... Z1
+// A2 B2 C2 D2 E2 F2 G2 H2 ... Z2
+// A3 B3 C3 D3 E3 F3 G3 H3 ... Z3
+// A4 B4 C4 D4 E4 F4 G4 H4 ... Z4
+// A5 B5 C5 D5 E5 F5 G5 H5 ... Z5
+// A6 B6 C6 D6 E6 F6 G6 H6 ... Z6
+// A7 B7 C7 D7 E7 F7 G7 H7 ... Z7
+// A8 ...
+// ...
+//
+// *) A, B, C, ... - patches extracted from the original input.
+// *) A0, A1, A2 ... - values from the same patch at different offsets.
+//
+// The traversal (packed rhs memory) order (B0 beside A0 in memory):
+// A0 B0 C0 D0 A1 B1 C1 D1 ...
+// E0 F0 G0 H0 E1 F1 G1 H1 ...
+// ...
+// Z0 Z1 Z2 Z3 Z4 Z5 Z6 Z7 ... <- doesn't belong to any block (nr = 4)
+//
+// This traversal order must be the same as in default gemm_pack_rhs defined in
+// GeneralBlockPanelKernel.h.
+//
+// *) nr - number of registers along the 'n' dimension.
+// See GeneralBlockPanelKernel.h and "Anatomy of High-Performance Matrix
+// Multiplication" paper.
template <typename NewDimension, DenseIndex Rows, DenseIndex Cols,
typename ArgType, typename Device, typename Scalar, typename Index,
typename nocontract_t, typename contract_t, int packet_size,
@@ -649,9 +794,9 @@ struct gemm_pack_rhs<
inner_dim_reordered, Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -660,9 +805,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
const bool non_standard_patches = rhs.nonStandardPatches();
@@ -675,30 +817,27 @@ struct gemm_pack_rhs<
Index k = 0;
if ((packet_size % 4) == 0 && !non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+        // Iterate over patch columns and rows if we know that a single
+        // packet does not span multiple rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
const bool pad0 = pad_col0 || dm0.padRow(r);
const bool pad1 = pad_col1 || dm1.padRow(r);
const bool pad2 = pad_col2 || dm2.padRow(r);
@@ -709,14 +848,13 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(r, c);
const Index idx3 = dm3.baseIndex(r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
- const Index max_depth =
- std::min<Index>(peeled_k - c * patch_rows * patch_depth -
- r * patch_depth + startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index start_depth = ((c == start_col) && (r == start_row))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 4> kernel;
kernel.packet[0] = pad0 ? pset1<Packet>(Scalar(0))
@@ -738,19 +876,9 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 4> kernel;
- kernel.packet[0] = dm0.loadPacketFast(k);
- kernel.packet[1] = dm1.loadPacketFast(k);
- kernel.packet[2] = dm2.loadPacketFast(k);
- kernel.packet[3] = dm3.loadPacketFast(k);
- ptranspose(kernel);
- pstoreu(block + 0 * packet_size, kernel.packet[0]);
- pstoreu(block + 1 * packet_size, kernel.packet[1]);
- pstoreu(block + 2 * packet_size, kernel.packet[2]);
- pstoreu(block + 3 * packet_size, kernel.packet[3]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 4> kernel;
@@ -767,6 +895,8 @@ struct gemm_pack_rhs<
}
}
}
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
if (!rhs.nonStandardPatches()) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
@@ -824,9 +954,9 @@ struct gemm_pack_rhs<
Alignment>
SubMapper;
typedef SubMapper DataMapper;
+ typedef typename packet_traits<Scalar>::type Packet;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -835,9 +965,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
- typedef typename packet_traits<Scalar>::type Packet;
-
const int packet_size = 2;
const Index packet_cols4 = (cols / 4) * 4;
const Index peeled_k = (depth / packet_size) * packet_size;
@@ -851,30 +978,27 @@ struct gemm_pack_rhs<
Index k = 0;
if (!non_standard_patches) {
- const Index patch_depth = rhs.patchDepth();
- if ((patch_depth % packet_size) == 0) {
- const Index patch_cols = rhs.patchCols();
- const Index patch_rows = rhs.patchRows();
-
- const Index startCol = rhs.colOffset();
- const Index max_cols = std::min<Index>(
- ceil_div(peeled_k, patch_rows * patch_depth) + startCol,
- patch_cols);
-
- for (Index c = startCol; c < max_cols; ++c) {
- eigen_assert(k < peeled_k);
- const Index startRow = (c == startCol) ? rhs.rowOffset() : 0;
- const Index max_rows = std::min<Index>(
- ceil_div(peeled_k - c * patch_rows * patch_depth, patch_depth) +
- startRow,
- patch_rows);
+ // FAST PATH:
+        // Iterate over patch columns and rows if we know that a single
+        // packet does not span multiple rows or columns.
+ if ((rhs.patchDepth() % packet_size) == 0) {
+ const Index start_col = rhs.colOffset();
+ const Index max_col = rhs.maxCol(peeled_k);
+
+ for (Index c = start_col; c < max_col; ++c) {
+ eigen_assert(k <= peeled_k);
+
+ const Index start_row = (c == start_col) ? rhs.rowOffset() : 0;
+ const Index max_row = rhs.maxRow(peeled_k, c);
const bool pad_col0 = dm0.padCol(c);
const bool pad_col1 = dm1.padCol(c);
const bool pad_col2 = dm2.padCol(c);
const bool pad_col3 = dm3.padCol(c);
- for (Index r = startRow; r < max_rows; ++r) {
- eigen_assert(k < peeled_k);
+
+ for (Index r = start_row; r < max_row; ++r) {
+ eigen_assert(k <= peeled_k);
+
const bool pad0 = pad_col0 || dm0.padRow(r);
const bool pad1 = pad_col1 || dm1.padRow(r);
const bool pad2 = pad_col2 || dm2.padRow(r);
@@ -885,14 +1009,13 @@ struct gemm_pack_rhs<
const Index idx2 = dm2.baseIndex(r, c);
const Index idx3 = dm3.baseIndex(r, c);
- const Index startDepth =
- ((c == startCol) && (r == startRow)) ? rhs.depthOffset() : 0;
- const Index max_depth =
- std::min<Index>(peeled_k - c * patch_rows * patch_depth -
- r * patch_depth + startDepth,
- patch_depth);
- eigen_assert((max_depth - startDepth) % packet_size == 0);
- for (Index d = startDepth; d < max_depth; d += packet_size) {
+ const Index start_depth = ((c == start_col) && (r == start_row))
+ ? rhs.depthOffset()
+ : 0;
+ const Index max_depth = rhs.maxDepth(peeled_k - k, start_depth);
+ eigen_assert((max_depth - start_depth) % packet_size == 0);
+
+ for (Index d = start_depth; d < max_depth; d += packet_size) {
eigen_assert(k < peeled_k);
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -916,22 +1039,12 @@ struct gemm_pack_rhs<
}
}
- for (; k < peeled_k; k += packet_size) {
- PacketBlock<Packet, 2> kernel0;
- PacketBlock<Packet, 2> kernel1;
- kernel0.packet[0] = dm0.loadPacketFast(k);
- kernel0.packet[1] = dm1.loadPacketFast(k);
- kernel1.packet[0] = dm2.loadPacketFast(k);
- kernel1.packet[1] = dm3.loadPacketFast(k);
- ptranspose(kernel0);
- ptranspose(kernel1);
- pstoreu(block + 0 * packet_size, kernel0.packet[0]);
- pstoreu(block + 1 * packet_size, kernel1.packet[0]);
- pstoreu(block + 2 * packet_size, kernel0.packet[1]);
- pstoreu(block + 3 * packet_size, kernel1.packet[1]);
- block += 4 * packet_size;
- }
+ // The loop above should fill peeled_k elements.
+ eigen_assert(peeled_k == k);
+
} else {
+        // A packet can span multiple rows or columns, so we have to go
+        // through the slower "standard" path.
for (; k < peeled_k; k += packet_size) {
PacketBlock<Packet, 2> kernel0;
PacketBlock<Packet, 2> kernel1;
@@ -949,7 +1062,9 @@ struct gemm_pack_rhs<
}
}
}
- if (!rhs.nonStandardPatches()) {
+
+ // Copy the remaining coefficients of the column block after the peeled_k.
+ if (!non_standard_patches) {
for (; k < depth; k++) {
block[0] = dm0.loadCoeffStandard(k);
block[1] = dm1.loadCoeffStandard(k);
@@ -968,7 +1083,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
@@ -1006,8 +1121,7 @@ struct gemm_pack_rhs<
SubMapper;
typedef SubMapper DataMapper;
- EIGEN_DEVICE_FUNC
- static inline Index ceil_div(Index a, Index b) { return (a + b - 1) / b; }
+ EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
EIGEN_DEVICE_FUNC
EIGEN_DONT_INLINE void operator()(Scalar* block, const DataMapper& rhs,
@@ -1016,8 +1130,6 @@ struct gemm_pack_rhs<
eigen_assert(stride == 0);
eigen_assert(offset == 0);
- EIGEN_STATIC_ASSERT((nr == 4), YOU_MADE_A_PROGRAMMING_MISTAKE);
-
const Index packet_cols4 = (cols / 4) * 4;
for (Index j2 = 0; j2 < packet_cols4; j2 += 4) {
@@ -1045,7 +1157,7 @@ struct gemm_pack_rhs<
}
}
- // copy the remaining columns one at a time (nr==1)
+ // Copy the remaining columns one at a time (nr==1).
for (Index j2 = packet_cols4; j2 < cols; ++j2) {
const SubMapper dm0 = rhs.getLinearMapper(0, j2);
for (Index k = 0; k < depth; k++) {
diff --git a/tensorflow/core/kernels/eigen_volume_patch.h b/tensorflow/core/kernels/eigen_volume_patch.h
index a3d795813d..80ab745bfe 100644
--- a/tensorflow/core/kernels/eigen_volume_patch.h
+++ b/tensorflow/core/kernels/eigen_volume_patch.h
@@ -43,6 +43,7 @@ struct CustomTensorEvaluator {
IsAligned = false,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
BlockAccess = false,
+ PreferBlockAccess = false,
Layout = TensorEvaluator<ArgType, Device>::Layout,
CoordAccess = NumDims == 6,
RawAccess = false
diff --git a/tensorflow/core/kernels/example_parsing_ops.cc b/tensorflow/core/kernels/example_parsing_ops.cc
index 83cd0e9b47..528b3c6bf0 100644
--- a/tensorflow/core/kernels/example_parsing_ops.cc
+++ b/tensorflow/core/kernels/example_parsing_ops.cc
@@ -264,9 +264,168 @@ class ParseSingleExampleOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("ParseSingleExample").Device(DEVICE_CPU),
ParseSingleExampleOp);
-class SingleSequenceExampleParserOp : public OpKernel {
+class ParseSequenceExampleOp : public OpKernel {
public:
- explicit SingleSequenceExampleParserOp(OpKernelConstruction* ctx)
+ explicit ParseSequenceExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* debug_name;
+ const Tensor* serialized;
+ OpInputList context_dense_defaults;
+
+ OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name));
+ OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
+ OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults",
+ &context_dense_defaults));
+
+ bool has_debug_name = (debug_name->NumElements() > 0);
+ if (has_debug_name) {
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(debug_name->shape()),
+ errors::InvalidArgument(
+ "Expected debug_name to be a vector, got shape: ",
+ debug_name->shape().DebugString()));
+ }
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsVector(serialized->shape()),
+ errors::InvalidArgument(
+ "Expected serialized to be a vector, got shape: ",
+ serialized->shape().DebugString()));
+
+ OP_REQUIRES(ctx, context_dense_defaults.size() == attrs_.num_context_dense,
+ errors::InvalidArgument("Expected len(context_dense_defaults) "
+ "== len(context_dense_keys) but got: ",
+ context_dense_defaults.size(), " vs. ",
+ attrs_.num_context_dense));
+
+ std::vector<bool> required(attrs_.num_context_dense);
+ for (int d = 0; d < attrs_.num_context_dense; ++d) {
+ const Tensor& def_value = context_dense_defaults[d];
+ required[d] = (def_value.NumElements() == 0); // No default provided.
+
+ if (def_value.NumElements() > 0) {
+ OP_REQUIRES(ctx, def_value.shape() == attrs_.context_dense_shapes[d],
+ errors::InvalidArgument(
+ "default_value[", d,
+ "].shape() == ", def_value.shape().DebugString(),
+ " != context_dense_shapes[", d,
+ "] == ", attrs_.context_dense_shapes[d].DebugString()));
+ OP_REQUIRES(
+ ctx, def_value.dtype() == attrs_.context_dense_types[d],
+ errors::InvalidArgument(
+ "context_dense_defaults[", d, "].dtype() == ",
+ DataTypeString(def_value.dtype()), " != context_dense_types[",
+ d, "] == ", DataTypeString(attrs_.context_dense_types[d])));
+ }
+ }
+
+ example::Result context_result, feature_list_result;
+ std::vector<Tensor> dense_feature_lengths;
+
+ example::FastParseExampleConfig context_config;
+ for (int d = 0; d < attrs_.num_context_dense; ++d) {
+ context_config.dense.push_back(
+ {attrs_.context_dense_keys[d], attrs_.context_dense_types[d],
+ attrs_.context_dense_shapes[d], context_dense_defaults[d],
+ false /* attrs_.context_variable_length[d] */,
+ 0 /*attrs_.context_elements_per_stride[d] */});
+ }
+ for (int d = 0; d < attrs_.num_context_sparse; ++d) {
+ context_config.sparse.push_back(
+ {attrs_.context_sparse_keys[d], attrs_.context_sparse_types[d]});
+ }
+ example::FastParseExampleConfig feature_list_config;
+ for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
+ DataType dtype = attrs_.feature_list_dense_types[d];
+ Tensor default_value = Tensor(dtype, TensorShape({}));
+ feature_list_config.dense.push_back(
+ {attrs_.feature_list_dense_keys[d], dtype,
+ attrs_.feature_list_dense_shapes[d], default_value,
+ (attrs_.feature_list_dense_missing_assumed_empty.count(
+ attrs_.feature_list_dense_keys[d]) > 0),
+ 0 /*attrs_.context_elements_per_stride[d] */});
+ }
+ for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
+ feature_list_config.sparse.push_back(
+ {attrs_.feature_list_sparse_keys[d],
+ attrs_.feature_list_sparse_types[d]});
+ }
+
+ auto serialized_t = serialized->flat<string>();
+ auto debug_name_t = debug_name->flat<string>();
+ gtl::ArraySlice<string> slice(serialized_t.data(), serialized_t.size());
+ gtl::ArraySlice<string> names_slice(debug_name_t.data(),
+ debug_name_t.size());
+
+ OP_REQUIRES_OK(
+ ctx,
+ FastParseSequenceExample(
+ context_config, feature_list_config, slice, names_slice,
+ ctx->device()->tensorflow_cpu_worker_threads()->workers,
+ &context_result, &feature_list_result, &dense_feature_lengths));
+
+ OpOutputList context_sparse_indices;
+ OpOutputList context_sparse_values;
+ OpOutputList context_sparse_shapes;
+ OpOutputList context_dense_values;
+ OpOutputList feature_list_sparse_indices;
+ OpOutputList feature_list_sparse_values;
+ OpOutputList feature_list_sparse_shapes;
+ OpOutputList feature_list_dense_values;
+ OpOutputList feature_list_dense_lengths;
+
+ OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
+ &context_sparse_indices));
+ OP_REQUIRES_OK(
+ ctx, ctx->output_list("context_sparse_values", &context_sparse_values));
+ OP_REQUIRES_OK(
+ ctx, ctx->output_list("context_sparse_shapes", &context_sparse_shapes));
+ OP_REQUIRES_OK(
+ ctx, ctx->output_list("context_dense_values", &context_dense_values));
+ OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
+ &context_sparse_indices));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices",
+ &feature_list_sparse_indices));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values",
+ &feature_list_sparse_values));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes",
+ &feature_list_sparse_shapes));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values",
+ &feature_list_dense_values));
+ OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_lengths",
+ &feature_list_dense_lengths));
+ for (int d = 0; d < attrs_.num_context_dense; ++d) {
+ context_dense_values.set(d, context_result.dense_values[d]);
+ }
+ TensorShape lengths_shape;
+ lengths_shape.AddDim(serialized_t.size());
+ for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
+ feature_list_dense_values.set(d, feature_list_result.dense_values[d]);
+ feature_list_dense_lengths.set(d, dense_feature_lengths[d]);
+ }
+ for (int d = 0; d < attrs_.num_context_sparse; ++d) {
+ context_sparse_indices.set(d, context_result.sparse_indices[d]);
+ context_sparse_values.set(d, context_result.sparse_values[d]);
+ context_sparse_shapes.set(d, context_result.sparse_shapes[d]);
+ }
+ for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
+ feature_list_sparse_indices.set(d, feature_list_result.sparse_indices[d]);
+ feature_list_sparse_values.set(d, feature_list_result.sparse_values[d]);
+ feature_list_sparse_shapes.set(d, feature_list_result.sparse_shapes[d]);
+ }
+ }
+
+ protected:
+ ParseSequenceExampleAttrs attrs_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ParseSequenceExample").Device(DEVICE_CPU),
+ ParseSequenceExampleOp);
+
+class ParseSingleSequenceExampleOp : public OpKernel {
+ public:
+ explicit ParseSingleSequenceExampleOp(OpKernelConstruction* ctx)
: OpKernel(ctx) {
OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
}
@@ -658,7 +817,7 @@ class SingleSequenceExampleParserOp : public OpKernel {
};
REGISTER_KERNEL_BUILDER(Name("ParseSingleSequenceExample").Device(DEVICE_CPU),
- SingleSequenceExampleParserOp);
+ ParseSingleSequenceExampleOp);
#ifndef IS_MOBILE_PLATFORM
// when using lite protos on mobile, decoding JSON is not available.
diff --git a/tensorflow/core/kernels/extract_image_patches_op.h b/tensorflow/core/kernels/extract_image_patches_op.h
index e430a23d20..64b8c0338b 100644
--- a/tensorflow/core/kernels/extract_image_patches_op.h
+++ b/tensorflow/core/kernels/extract_image_patches_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_EXTRACT_IMAGE_PATCHES_OP_H_
-#define TENSORFLOW_KERNELS_EXTRACT_IMAGE_PATCHES_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_EXTRACT_IMAGE_PATCHES_OP_H_
+#define TENSORFLOW_CORE_KERNELS_EXTRACT_IMAGE_PATCHES_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -53,4 +53,4 @@ struct ExtractImagePatchesForward {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_EXTRACT_IMAGE_PATCHES_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_EXTRACT_IMAGE_PATCHES_OP_H_
diff --git a/tensorflow/core/kernels/extract_volume_patches_op.cc b/tensorflow/core/kernels/extract_volume_patches_op.cc
new file mode 100644
index 0000000000..52cd078a35
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op.cc
@@ -0,0 +1,197 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+/*
+See extract_image_patches_op* files and docs for extract_image_patches in
+../ops/image_ops.cc.
+
+Rates are not supported for now, but the comments indicate how the code should
+be edited once support for rates is added.
+*/
+
+#define USE_EIGEN_TENSOR
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/extract_volume_patches_op.h"
+#include <vector>
+#include "tensorflow/core/framework/numeric_op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/util/tensor_format.h"
+
+namespace tensorflow {
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+static inline void ParseAttributeVec5(OpKernelConstruction* context,
+ const string& attr_name,
+ std::vector<int32>* attr) {
+ OP_REQUIRES_OK(context, context->GetAttr(attr_name, attr));
+ OP_REQUIRES(
+ context, (*attr)[0] == 1 && (*attr)[4] == 1,
+ errors::Unimplemented("Only support ", attr_name, " across space."));
+ OP_REQUIRES(context, (*attr)[1] >= 1 && (*attr)[2] >= 1 && (*attr)[3] >= 1,
+ errors::OutOfRange(attr_name, " is out of range."));
+}
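+
+// For illustration (values assumed, not from the op definition): with
+// ksizes = {1, 2, 3, 3, 1} the checks above pass (a 2x3x3 patch, no patching
+// across the batch or channel dimension), while ksizes = {2, 2, 3, 3, 1} is
+// rejected because elements 0 and 4 must both be 1.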
+
+template <typename Device, typename T>
+class ExtractVolumePatchesOp : public UnaryOp<T> {
+ public:
+ explicit ExtractVolumePatchesOp(OpKernelConstruction* context)
+ : UnaryOp<T>(context) {
+ ParseAttributeVec5(context, "ksizes", &ksizes_);
+ ParseAttributeVec5(context, "strides", &strides_);
+ // ParseAttributeVec5(context, "rates", &rates_);
+ OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ // Input tensor is of the following dimensions:
+ // [ batch, in_planes, in_rows, in_cols, channels ]
+ const Tensor& input = context->input(0);
+ OP_REQUIRES(context, input.dims() == 5,
+ errors::InvalidArgument("input must be 5-dimensional",
+ input.shape().DebugString()));
+
+ const int batch = input.dim_size(0);
+ const int in_planes = input.dim_size(1);
+ const int in_rows = input.dim_size(2);
+ const int in_cols = input.dim_size(3);
+ const int depth = input.dim_size(4);
+
+ const int ksize_planes = ksizes_[1];
+ const int ksize_rows = ksizes_[2];
+ const int ksize_cols = ksizes_[3];
+
+ const int stride_planes = strides_[1];
+ const int stride_rows = strides_[2];
+ const int stride_cols = strides_[3];
+
+ /*
+ // TODO(hsgkim): enable rates
+    // Rates are disabled for now because of Eigen's definitions of the
+    // `extract_volume_patches` functions; none of them accepts rates
+    // as an argument, so rates are fixed to (1, 1, 1, 1, 1). A
+    // workaround has to be found for this.
+ // In order to enable rates, uncomment the following lines and use
+ // ksize_*_eff instead of ksize_* for the second argument of
+ // GetWindowedOutputSize calls.
+
+ const int rate_planes = rates_[1];
+ const int rate_rows = rates_[2];
+ const int rate_cols = rates_[3];
+
+ const int ksize_planes_eff = ksize_planes +
+ (ksize_planes - 1) * (rate_planes - 1);
+ const int ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
+ const int ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
+ */
+
+ int64 out_planes = 0, out_rows = 0, out_cols = 0;
+ int64 pad_planes = 0, pad_rows = 0, pad_cols = 0;
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_planes, ksize_planes, stride_planes,
+ padding_, &out_planes, &pad_planes));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_rows, ksize_rows, stride_rows,
+ padding_, &out_rows, &pad_rows));
+ OP_REQUIRES_OK(context,
+ GetWindowedOutputSize(in_cols, ksize_cols, stride_cols,
+ padding_, &out_cols, &pad_cols));
+
+ const std::vector<int64> out_sizes = {
+ batch, out_planes, out_rows, out_cols,
+ ksize_planes * ksize_rows * ksize_cols * depth};
+ TensorShape out_shape(out_sizes);
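+
+    // Rough worked example with assumed sizes: a VALID-padded input of shape
+    // [batch, 6, 6, 6, 2] with ksizes {1, 2, 2, 2, 1} and strides
+    // {1, 2, 2, 2, 1} gives out_planes == out_rows == out_cols == 3, so the
+    // output shape is [batch, 3, 3, 3, 2 * 2 * 2 * 2] == [batch, 3, 3, 3, 16].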
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
+
+ // If there is nothing to compute, return.
+ if (out_shape.num_elements() == 0) {
+ return;
+ }
+
+ functor::ExtractVolumePatchesForward<Device, T>()(
+ context->eigen_device<Device>(), input.tensor<T, 5>(), ksize_planes,
+ ksize_rows, ksize_cols, stride_planes, stride_rows, stride_cols,
+ /* rate_planes, rate_rows, rate_cols, */
+ BrainPadding2EigenPadding(padding_), output->tensor<T, 5>());
+ }
+
+ private:
+ std::vector<int32> ksizes_;
+ std::vector<int32> strides_;
+ // std::vector<int32> rates_;
+
+ Padding padding_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(ExtractVolumePatchesOp);
+};
+
+// Registration of the CPU implementations.
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ExtractVolumePatches").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
+ ExtractVolumePatchesOp<CPUDevice, T>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+#if GOOGLE_CUDA
+
+// Forward declarations of the functor specializations for GPU.
+namespace functor {
+
+// clang-format off
+#define DECLARE_GPU_SPEC(T) \
+ template <> \
+ void ExtractVolumePatchesForward<GPUDevice, T>::operator()( \
+ const GPUDevice& d, typename TTypes<T, 5>::ConstTensor input, \
+ int patch_planes, int patch_rows, int patch_cols, \
+ int stride_planes, int stride_rows, int stride_cols, \
+ /* int rate_planes, int rate_rows, int rate_cols, */ \
+ const Eigen::PaddingType& padding, \
+ typename TTypes<T, 5>::Tensor output); \
+ extern template struct ExtractVolumePatchesForward<GPUDevice, T>;
+// clang-format on
+
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+
+#undef DECLARE_GPU_SPEC
+
+} // namespace functor
+
+// Registration of the GPU implementations.
+#define REGISTER(T) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ExtractVolumePatches").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+ ExtractVolumePatchesOp<GPUDevice, T>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/extract_volume_patches_op.h b/tensorflow/core/kernels/extract_volume_patches_op.h
new file mode 100644
index 0000000000..7e0502b770
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op.h
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
+#define TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
+
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/kernels/eigen_volume_patch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T>
+struct ExtractVolumePatchesForward {
+ void operator()(const Device& d, typename TTypes<T, 5>::ConstTensor input,
+ int patch_planes, int patch_rows, int patch_cols,
+ int stride_planes, int stride_rows, int stride_cols,
+ /* int rate_planes, int rate_rows, int rate_cols, */
+ const Eigen::PaddingType& padding,
+ typename TTypes<T, 5>::Tensor output) {
+ const int64 N = std::max(input.size(), output.size());
+ if (N <= std::numeric_limits<Index32>::max()) {
+ auto output_32bit = To32Bit(output);
+ output_32bit.device(d) =
+ To32Bit(input)
+ .extract_volume_patches(patch_cols, patch_rows, patch_planes,
+ stride_cols, stride_rows, stride_planes,
+ padding)
+ .reshape(output_32bit.dimensions());
+ } else {
+ output.device(d) =
+ input
+ .extract_volume_patches(patch_cols, patch_rows, patch_planes,
+ stride_cols, stride_rows, stride_planes,
+ padding)
+ .reshape(output.dimensions());
+ }
+ }
+};
+
+} // namespace functor
+} // namespace tensorflow
+
+#endif // TENSORFLOW_KERNELS_EXTRACT_VOLUME_PATCHES_OP_H_
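
The functor above dispatches on the combined element count so Eigen can use 32-bit indexing whenever both tensors fit, which is the cheaper path on most devices. A minimal sketch of that dispatch pattern in isolation, assuming the TensorFlow `To32Bit`/`Index32` helpers from tensor_types.h; the `process_32bit`/`process_64bit` callables are hypothetical stand-ins for the two Eigen expressions in the functor:

  // Sketch only: process_32bit/process_64bit stand in for the To32Bit(...)
  // and full-width Eigen expressions used by ExtractVolumePatchesForward.
  template <typename In, typename Out, typename F32, typename F64>
  void DispatchOnIndexWidth(const In& input, Out& output, F32 process_32bit,
                            F64 process_64bit) {
    const int64 N = std::max(input.size(), output.size());
    if (N <= std::numeric_limits<Index32>::max()) {
      process_32bit(To32Bit(input), To32Bit(output));  // cheaper 32-bit indexing
    } else {
      process_64bit(input, output);  // only needed past 2^31 - 1 elements
    }
  }
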
diff --git a/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc
new file mode 100644
index 0000000000..c636493602
--- /dev/null
+++ b/tensorflow/core/kernels/extract_volume_patches_op_gpu.cu.cc
@@ -0,0 +1,38 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/extract_volume_patches_op.h"
+#include "tensorflow/core/framework/register_types.h"
+
+namespace tensorflow {
+
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+
+#define REGISTER(T) template struct ExtractVolumePatchesForward<GPUDevice, T>;
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER);
+
+#undef REGISTER
+
+} // end namespace functor
+} // end namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/fake_quant_ops_functor.h b/tensorflow/core/kernels/fake_quant_ops_functor.h
index d51acc38ef..045a96ac1e 100644
--- a/tensorflow/core/kernels/fake_quant_ops_functor.h
+++ b/tensorflow/core/kernels/fake_quant_ops_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
-#define TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_FAKE_QUANT_OPS_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_FAKE_QUANT_OPS_FUNCTOR_H_
#include <tuple>
@@ -277,4 +277,4 @@ struct FakeQuantWithMinMaxVarsPerChannelGradientFunctor {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_FAKE_QUANT_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_FAKE_QUANT_OPS_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/fill_functor.h b/tensorflow/core/kernels/fill_functor.h
index 4c8b3f01a7..46bffa5173 100644
--- a/tensorflow/core/kernels/fill_functor.h
+++ b/tensorflow/core/kernels/fill_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_FILL_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_FILL_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_FILL_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_FILL_FUNCTOR_H_
#define EIGEN_USE_THREADS
@@ -89,4 +89,4 @@ struct SetOneFunctor<Eigen::ThreadPoolDevice, string> {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_FILL_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_FILL_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/fractional_pool_common.h b/tensorflow/core/kernels/fractional_pool_common.h
index 2d7a230fc0..55a959f3c3 100644
--- a/tensorflow/core/kernels/fractional_pool_common.h
+++ b/tensorflow/core/kernels/fractional_pool_common.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
-#define TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
+#ifndef TENSORFLOW_CORE_KERNELS_FRACTIONAL_POOL_COMMON_H_
+#define TENSORFLOW_CORE_KERNELS_FRACTIONAL_POOL_COMMON_H_
#include <algorithm>
#include <vector>
@@ -75,4 +75,4 @@ std::vector<int64> GeneratePoolingSequence(int input_length, int output_length,
bool pseudo_random);
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_FRACTIONAL_POOL_COMMON_H_
+#endif // TENSORFLOW_CORE_KERNELS_FRACTIONAL_POOL_COMMON_H_
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.h b/tensorflow/core/kernels/fused_batch_norm_op.h
index d6c68df986..c45b6f79e3 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.h
+++ b/tensorflow/core/kernels/fused_batch_norm_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_FUSED_BATCH_NORM_OP_H_
-#define TENSORFLOW_KERNELS_FUSED_BATCH_NORM_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_
+#define TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
@@ -128,4 +128,4 @@ struct FusedBatchNormFreezeGrad {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_FUSED_BATCH_NORM_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_FUSED_BATCH_NORM_OP_H_
diff --git a/tensorflow/core/kernels/fuzzing/BUILD b/tensorflow/core/kernels/fuzzing/BUILD
index 8bfa40304e..f2e0b2558f 100644
--- a/tensorflow/core/kernels/fuzzing/BUILD
+++ b/tensorflow/core/kernels/fuzzing/BUILD
@@ -43,4 +43,6 @@ tf_ops_fuzz_target_lib("example_proto_fast_parsing")
tf_ops_fuzz_target_lib("parse_tensor_op")
+tf_ops_fuzz_target_lib("decode_compressed")
+
tf_ops_fuzz_target_lib("decode_json_example")
diff --git a/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc
new file mode 100644
index 0000000000..0a56f4b63f
--- /dev/null
+++ b/tensorflow/core/kernels/fuzzing/decode_compressed_fuzz.cc
@@ -0,0 +1,45 @@
+/* Copyright 2018 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
+
+namespace tensorflow {
+namespace fuzzing {
+
+class FuzzDecodeCompressed : public FuzzStringInputOp {
+ void BuildGraph(const Scope& scope) override {
+ auto input =
+ tensorflow::ops::Placeholder(scope.WithOpName("input1"), DT_STRING);
+ auto d1 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d1"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType(""));
+ auto d2 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d2"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType("ZLIB"));
+ auto d3 = tensorflow::ops::DecodeCompressed(
+ scope.WithOpName("d3"), input,
+ tensorflow::ops::DecodeCompressed::CompressionType("GZIP"));
+ Scope grouper =
+ scope.WithControlDependencies(std::vector<tensorflow::Operation>{
+ d1.output.op(), d2.output.op(), d3.output.op()});
+ (void)tensorflow::ops::NoOp(grouper.WithOpName("output"));
+ }
+};
+
+STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeCompressed);
+
+} // namespace fuzzing
+} // namespace tensorflow
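
The fuzzer above follows the FuzzStringInputOp pattern: build a small graph around the op under test, fan the single string placeholder ("input1") into each interesting attribute combination, and group the results behind a NoOp named "output". As a rough illustration, a hypothetical fuzzer for another string-consuming op could be wired up the same way (FuzzDecodeBase64 below is not part of this change and assumes the DecodeBase64 wrapper from standard_ops.h):

  // Hypothetical sketch, not part of this change; reuses the same
  // FuzzStringInputOp pattern for a different string-consuming op.
  class FuzzDecodeBase64 : public FuzzStringInputOp {
    void BuildGraph(const Scope& scope) override {
      auto input =
          tensorflow::ops::Placeholder(scope.WithOpName("input1"), DT_STRING);
      auto d = tensorflow::ops::DecodeBase64(scope.WithOpName("d"), input);
      Scope grouper = scope.WithControlDependencies(
          std::vector<tensorflow::Operation>{d.output.op()});
      (void)tensorflow::ops::NoOp(grouper.WithOpName("output"));
    }
  };
  STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeBase64);
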
diff --git a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
index c90ad2cfeb..ada1235449 100644
--- a/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
+++ b/tensorflow/core/kernels/fuzzing/parse_tensor_op_fuzz.cc
@@ -31,9 +31,37 @@ class FuzzParseTensor : public FuzzSession {
}
void FuzzImpl(const uint8_t* data, size_t size) final {
+ // We need to be sure that we don't request too many elements (i.e., we
+    // don't make ASAN OOM). In theory, a tensor shape can have an arbitrarily
+    // large number of elements, up to the limit of the memory available to the
+    // OS.
+ // However, due to the tracing done in ASAN, after 2^32 bytes of requested
+ // memory we would get a crash in the fuzzer (see b/34190148). Hence, let's
+ // try parsing the proto here, check that the size (if valid) is below a
+ // maximum threshold (using 2^20 for convenience), and then run the
+ // remainder of the fuzzer testing. Of course, this duplicates some work
+ // but it's better than repeating the investigation whenever Autofuzz
+ // detects another similar OOM.
+ string as_string = string(reinterpret_cast<const char*>(data), size);
+ TensorProto proto;
+ if (!ParseProtoUnlimited(&proto, as_string)) {
+ LOG(WARNING) << "Unable to parse proto of tensor\n";
+ return;
+ }
+ if (!TensorShape::IsValid(proto.tensor_shape())) {
+ LOG(WARNING) << "Invalid tensor shape\n";
+ return;
+ }
+ TensorShape shape(proto.tensor_shape());
+ const int64 num_elements = shape.num_elements();
+ const int64 max_num_elements = 1 << 20;
+ if (num_elements > max_num_elements) {
+ LOG(WARNING) << "Requiring a tensor with too many elements\n";
+ return;
+ }
+
+ // Now we can do the actual fuzz implementation
Tensor input_tensor(tensorflow::DT_STRING, TensorShape({}));
- input_tensor.scalar<string>()() =
- string(reinterpret_cast<const char*>(data), size);
+ input_tensor.scalar<string>()() = as_string;
// TODO(b/32704451): Don't just ignore the ::tensorflow::Status object!
RunOneInput(input_tensor).IgnoreError();
}
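
The guard added above is self-contained: parse the proto, validate the shape, and bail out when the element count exceeds the 2^20 cap before the input ever reaches the graph. A compact sketch of the same checks factored into a hypothetical helper (the name and factoring are illustrative only; they mirror the inline code in FuzzImpl):

  // Hypothetical helper, illustrative only; mirrors the checks done inline in
  // FuzzImpl above.
  static bool IsSafeTensorProto(const string& serialized) {
    TensorProto proto;
    if (!ParseProtoUnlimited(&proto, serialized)) return false;     // unparseable
    if (!TensorShape::IsValid(proto.tensor_shape())) return false;  // bad shape
    // 2^20 elements keeps requested memory well below ASAN's 2^32-byte limit.
    const int64 kMaxNumElements = 1 << 20;
    return TensorShape(proto.tensor_shape()).num_elements() <= kMaxNumElements;
  }
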
diff --git a/tensorflow/core/kernels/gather_functor.h b/tensorflow/core/kernels/gather_functor.h
index 2c6e8bf3bc..7710cf93d6 100644
--- a/tensorflow/core/kernels/gather_functor.h
+++ b/tensorflow/core/kernels/gather_functor.h
@@ -13,14 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_GATHER_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_GATHER_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
+#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
@@ -176,4 +177,4 @@ struct GatherFunctor<GPUDevice, Variant, Index> {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_GATHER_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/gather_nd_op.h b/tensorflow/core/kernels/gather_nd_op.h
index 60780fb50c..003badb74d 100644
--- a/tensorflow/core/kernels/gather_nd_op.h
+++ b/tensorflow/core/kernels/gather_nd_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_GATHER_ND_OP_H_
-#define TENSORFLOW_KERNELS_GATHER_ND_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_
+#define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_
// Functor definition for GatherOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -47,4 +47,4 @@ Status DoGatherNd(OpKernelContext* c, const Tensor& params,
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_GATHER_ND_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_H_
diff --git a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
index dc028c2f1e..1c78de253e 100644
--- a/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
+++ b/tensorflow/core/kernels/gather_nd_op_cpu_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
-#define TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
// Specialization of GatherNdSlice to CPU
@@ -113,10 +113,25 @@ struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
#endif
generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
slice_size, Tindices, Tparams, Tout, &error_loc);
+
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
+// Eigen implementation below is not highly performant. gather_nd_generator
+// does not seem to be called in parallel, leading to very poor performance.
+// Additionally, since it uses scalar (Tscratch) to invoke 'generate', it
+// needs to go through redundant operations like 'reshape', 'broadcast' and
+// 'sum'. OpenMP loop below essentially does same thing as Eigen code, but
+// is considerably more efficient.
+#pragma omp parallel for
+ for (Eigen::DenseIndex i = 0; i < batch_size; i++) {
+ const Eigen::array<Eigen::DenseIndex, 1> loc{i};
+ gather_nd_generator(loc);
+ }
+#else // INTEL_MKL && ENABLE_MKL
Tscratch.device(d) = Tscratch.reshape(reshape_dims)
.broadcast(broadcast_dims)
.generate(gather_nd_generator)
.sum();
+#endif // INTEL_MKL && ENABLE_MKL
// error_loc() returns -1 if there's no out-of-bounds index,
// otherwise it returns the location of an OOB index in Tindices.
@@ -142,4 +157,4 @@ TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
+#endif // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
diff --git a/tensorflow/core/kernels/gemm_functors.h b/tensorflow/core/kernels/gemm_functors.h
index 4b30c1f17f..1c80844085 100644
--- a/tensorflow/core/kernels/gemm_functors.h
+++ b/tensorflow/core/kernels/gemm_functors.h
@@ -24,6 +24,9 @@ limitations under the License.
#error "EIGEN_USE_THREADS must be enabled by all .cc files including this."
#endif // EIGEN_USE_THREADS
+#ifndef TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_
+#define TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_
+
#include <string.h>
#include <map>
#include <vector>
@@ -116,3 +119,5 @@ class FastGemmFunctor<float, float, float> {
}
};
#endif // USE_CBLAS_GEMM
+
+#endif // TENSORFLOW_CORE_KERNELS_GEMM_FUNCTORS_H_
diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h
index c7dbefa0b4..86146f75f4 100644
--- a/tensorflow/core/kernels/gpu_utils.h
+++ b/tensorflow/core/kernels/gpu_utils.h
@@ -123,8 +123,7 @@ class AutoTuneMap {
string GetActionSummary(StringPiece action, const Parameters& params,
const Config& config) {
return strings::Printf("autotune_map %s %s: %s -> (%s)", name_.c_str(),
- std::string(action).c_str(),
- params.ToString().c_str(),
+ string(action).c_str(), params.ToString().c_str(),
config.ToString().c_str());
}
diff --git a/tensorflow/core/kernels/hexagon/graph_transfer_utils.h b/tensorflow/core/kernels/hexagon/graph_transfer_utils.h
index ada96ae4ea..d0d5c3e018 100644
--- a/tensorflow/core/kernels/hexagon/graph_transfer_utils.h
+++ b/tensorflow/core/kernels/hexagon/graph_transfer_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_HEXAGON_GRAPH_TRANSFER_UTILS_H_
-#define TENSORFLOW_PLATFORM_HEXAGON_GRAPH_TRANSFER_UTILS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFER_UTILS_H_
+#define TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFER_UTILS_H_
#include <queue>
#include <utility>
@@ -56,4 +56,4 @@ class GraphTransferUtils {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_HEXAGON_GRAPH_TRANSFER_UTILS_H_
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFER_UTILS_H_
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.cc b/tensorflow/core/kernels/hexagon/graph_transferer.cc
index e05de3fe8e..477e729dcb 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.cc
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.cc
@@ -161,7 +161,7 @@ Status GraphTransferer::LoadGraphFromProto(
for (const string& output_node_name : output_node_names) {
const TensorId tid = ParseTensorName(output_node_name);
- const string node_name = std::string(tid.first);
+ const string node_name(tid.first);
const int port = tid.second;
const int node_id = node_name_to_id_cache_map_.at(node_name);
const Node* node = node_name_cache_list_.at(node_id);
diff --git a/tensorflow/core/kernels/hexagon/graph_transferer.h b/tensorflow/core/kernels/hexagon/graph_transferer.h
index 86c1c5625f..4328d51916 100644
--- a/tensorflow/core/kernels/hexagon/graph_transferer.h
+++ b/tensorflow/core/kernels/hexagon/graph_transferer.h
@@ -228,4 +228,4 @@ class GraphTransferer {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
index 1580b72605..cc469f6dba 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
@@ -168,7 +168,7 @@ bool HexagonControlWrapper::SetupGraph() {
new_output_node_info.set_output_count(0);
const TensorId tid = ParseTensorName(graph_output.name());
- const string node_name = std::string(tid.first);
+ const string node_name(tid.first);
const int port = tid.second;
// Register node input for the new output node
const GraphTransferNodeInfo* node_info =
diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
index 132cfde2db..1b382996f8 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
+++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
-#define TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_
+#define TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_
#include <unordered_map>
#include <vector>
@@ -88,4 +88,4 @@ class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_CONTROL_WRAPPER_H_
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_
diff --git a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h
index b9328c8e0e..270d697e96 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h
+++ b/tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h
@@ -55,4 +55,4 @@ class HexagonOpsDefinitions final : public IRemoteFusedGraphOpsDefinitions {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_
diff --git a/tensorflow/core/kernels/hexagon/soc_interface.h b/tensorflow/core/kernels/hexagon/soc_interface.h
index 062103ed98..d1a41d47c8 100644
--- a/tensorflow/core/kernels/hexagon/soc_interface.h
+++ b/tensorflow/core/kernels/hexagon/soc_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_HEXAGON_SOC_INTERFACE_H_
-#define TENSORFLOW_PLATFORM_HEXAGON_SOC_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_SOC_INTERFACE_H_
+#define TENSORFLOW_CORE_KERNELS_HEXAGON_SOC_INTERFACE_H_
// Declaration of APIs provided by hexagon shared library. This header is shared
// with both hexagon library built with qualcomm SDK and tensorflow.
@@ -111,4 +111,4 @@ void soc_interface_SetDebugFlag(uint64_t flag);
}
#endif // __cplusplus
-#endif // TENSORFLOW_PLATFORM_HEXAGON_SOC_INTERFACE_H_
+#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_SOC_INTERFACE_H_
diff --git a/tensorflow/core/kernels/hinge-loss.h b/tensorflow/core/kernels/hinge-loss.h
index d303e9c877..b12910d27d 100644
--- a/tensorflow/core/kernels/hinge-loss.h
+++ b/tensorflow/core/kernels/hinge-loss.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_HINGE_LOSS_H_
-#define TENSORFLOW_KERNELS_HINGE_LOSS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HINGE_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_HINGE_LOSS_H_
#include <algorithm>
#include <limits>
@@ -123,4 +123,4 @@ class HingeLossUpdater : public DualLossUpdater {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_HINGE_LOSS_H_
+#endif // TENSORFLOW_CORE_KERNELS_HINGE_LOSS_H_
diff --git a/tensorflow/core/kernels/histogram_op.h b/tensorflow/core/kernels/histogram_op.h
index 1b253f7fed..b14fc2bee3 100644
--- a/tensorflow/core/kernels/histogram_op.h
+++ b/tensorflow/core/kernels/histogram_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_HISTOGRAM_OP_H_
-#define TENSORFLOW_HISTOGRAM_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_HISTOGRAM_OP_H_
+#define TENSORFLOW_CORE_KERNELS_HISTOGRAM_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -35,4 +35,4 @@ struct HistogramFixedWidthFunctor {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_HISTOGRAM_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_HISTOGRAM_OP_H_
diff --git a/tensorflow/core/kernels/histogram_op_gpu.cu.cc b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
index a88e9b0ddc..374a05850e 100644
--- a/tensorflow/core/kernels/histogram_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/histogram_op_gpu.cu.cc
@@ -18,7 +18,7 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_histogram.cuh"
+#include "third_party/cub/device/device_histogram.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/host_constant_op.cc b/tensorflow/core/kernels/host_constant_op.cc
new file mode 100644
index 0000000000..d08a7c9bd2
--- /dev/null
+++ b/tensorflow/core/kernels/host_constant_op.cc
@@ -0,0 +1,78 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/kernels/host_constant_op.h"
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+_HostConstantOp::_HostConstantOp(OpKernelConstruction* ctx)
+ : OpKernel(ctx), tensor_(ctx->output_type(0)) {
+ const TensorProto* proto = nullptr;
+ AllocatorAttributes alloc_attr;
+ alloc_attr.set_on_host(true);
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("value", &proto));
+ OP_REQUIRES_OK(
+ ctx, ctx->device()->MakeTensorFromProto(*proto, alloc_attr, &tensor_));
+ OP_REQUIRES(
+ ctx, ctx->output_type(0) == tensor_.dtype(),
+ errors::InvalidArgument("Type mismatch between value (",
+ DataTypeString(tensor_.dtype()), ") and dtype (",
+ DataTypeString(ctx->output_type(0)), ")"));
+}
+
+void _HostConstantOp::Compute(OpKernelContext* ctx) {
+ ctx->set_output(0, tensor_);
+}
+
+#if GOOGLE_CUDA
+// A special GPU kernel for int32.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+REGISTER_KERNEL_BUILDER(Name("Const")
+ .Device(DEVICE_GPU)
+ .HostMemory("output")
+ .TypeConstraint<int32>("dtype"),
+ _HostConstantOp);
+#endif
+
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("Const")
+ .Device(DEVICE_SYCL)
+ .HostMemory("output")
+ .TypeConstraint<int32>("dtype"),
+ _HostConstantOp);
+#endif // TENSORFLOW_USE_SYCL
+
+// HostConst: forced to generate output on the host.
+// Only used in tests; no op is registered for this kernel
+// externally (i.e., in array_ops.cc)
+REGISTER_KERNEL_BUILDER(Name("HostConst").Device(DEVICE_CPU), _HostConstantOp);
+REGISTER_KERNEL_BUILDER(
+ Name("HostConst").Device(DEVICE_GPU).HostMemory("output"), _HostConstantOp);
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(
+ Name("HostConst").Device(DEVICE_SYCL).HostMemory("output"),
+ _HostConstantOp);
+#endif // TENSORFLOW_USE_SYCL
+
+} // end namespace tensorflow
+
diff --git a/tensorflow/core/kernels/host_constant_op.h b/tensorflow/core/kernels/host_constant_op.h
new file mode 100644
index 0000000000..1b887ea1aa
--- /dev/null
+++ b/tensorflow/core/kernels/host_constant_op.h
@@ -0,0 +1,42 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_HOST_CONSTANT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_HOST_CONSTANT_OP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace tensorflow {
+
+// HostConstantOp differs from ConstantOp in that its output is always
+// in host memory.
+class _HostConstantOp : public OpKernel {
+ public:
+ explicit _HostConstantOp(OpKernelConstruction* ctx);
+ void Compute(OpKernelContext* ctx) override;
+ bool IsExpensive() override { return false; }
+ ~_HostConstantOp() override {}
+
+ private:
+ Tensor tensor_;
+ TF_DISALLOW_COPY_AND_ASSIGN(_HostConstantOp);
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_HOST_CONSTANT_OP_H_
diff --git a/tensorflow/core/kernels/i_remote_fused_graph_executor.h b/tensorflow/core/kernels/i_remote_fused_graph_executor.h
index 6072412689..b2329f4b61 100644
--- a/tensorflow/core/kernels/i_remote_fused_graph_executor.h
+++ b/tensorflow/core/kernels/i_remote_fused_graph_executor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
-#define TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_EXECUTOR_H_
+#define TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_EXECUTOR_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
@@ -74,4 +74,4 @@ class IRemoteFusedGraphExecutor {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_I_REMOTE_FUSED_GRAPH_EXECUTOR_H_
diff --git a/tensorflow/core/kernels/identity_n_op.h b/tensorflow/core/kernels/identity_n_op.h
index 490bbf456c..7339cbbe29 100644
--- a/tensorflow/core/kernels/identity_n_op.h
+++ b/tensorflow/core/kernels/identity_n_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_IDENTITY_N_OP_H_
-#define TENSORFLOW_KERNELS_IDENTITY_N_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_IDENTITY_N_OP_H_
+#define TENSORFLOW_CORE_KERNELS_IDENTITY_N_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
@@ -41,4 +41,4 @@ class IdentityNOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_IDENTITY_N_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_IDENTITY_N_OP_H_
diff --git a/tensorflow/core/kernels/identity_op.h b/tensorflow/core/kernels/identity_op.h
index f8856a1b9b..6b74868ad4 100644
--- a/tensorflow/core/kernels/identity_op.h
+++ b/tensorflow/core/kernels/identity_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_IDENTITY_OP_H_
-#define TENSORFLOW_KERNELS_IDENTITY_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_IDENTITY_OP_H_
+#define TENSORFLOW_CORE_KERNELS_IDENTITY_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
@@ -37,4 +37,4 @@ class IdentityOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_IDENTITY_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_IDENTITY_OP_H_
diff --git a/tensorflow/core/kernels/image_resizer_state.h b/tensorflow/core/kernels/image_resizer_state.h
index 8dcb5977c6..1d4fa1a7db 100644
--- a/tensorflow/core/kernels/image_resizer_state.h
+++ b/tensorflow/core/kernels/image_resizer_state.h
@@ -18,8 +18,8 @@ limitations under the License.
// reduce code duplication and ensure consistency across the different
// resizers, it performs the input validation.
-#ifndef TENSORFLOW_KERNELS_IMAGE_RESIZER_STATE_H_
-#define TENSORFLOW_KERNELS_IMAGE_RESIZER_STATE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_IMAGE_RESIZER_STATE_H_
+#define TENSORFLOW_CORE_KERNELS_IMAGE_RESIZER_STATE_H_
#define EIGEN_USE_THREADS
@@ -191,4 +191,4 @@ struct ImageResizerGradientState {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_IMAGE_RESIZER_STATE_H_
+#endif // TENSORFLOW_CORE_KERNELS_IMAGE_RESIZER_STATE_H_
diff --git a/tensorflow/core/kernels/immutable_constant_op.h b/tensorflow/core/kernels/immutable_constant_op.h
index 795331b4b2..97af8c7dc5 100644
--- a/tensorflow/core/kernels/immutable_constant_op.h
+++ b/tensorflow/core/kernels/immutable_constant_op.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_IMMUTABLE_CONSTANT_OP_H_
-#define TENSORFLOW_KERNELS_IMMUTABLE_CONSTANT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_IMMUTABLE_CONSTANT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_IMMUTABLE_CONSTANT_OP_H_
#include <memory>
@@ -46,4 +46,4 @@ class ImmutableConstantOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_IMMUTABLE_CONSTANT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_IMMUTABLE_CONSTANT_OP_H_
diff --git a/tensorflow/core/kernels/initializable_lookup_table.cc b/tensorflow/core/kernels/initializable_lookup_table.cc
index 06d53eba30..fcf468f5a8 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.cc
+++ b/tensorflow/core/kernels/initializable_lookup_table.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/initializable_lookup_table.h"
-
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -32,6 +31,13 @@ Status InitializableLookupTable::Find(OpKernelContext* ctx, const Tensor& keys,
return DoFind(keys, values, default_value);
}
+Status InitializableLookupTable::ImportValues(OpKernelContext* ctx,
+ const Tensor& keys,
+ const Tensor& values) {
+ lookup::KeyValueTensorIterator iter(&keys, &values);
+ return Initialize(iter);
+}
+
Status InitializableLookupTable::Initialize(InitTableIterator& iter) {
if (!iter.Valid()) {
return iter.status();
diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h
index b4f81d9a70..424fe5df3c 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.h
+++ b/tensorflow/core/kernels/initializable_lookup_table.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
-#define TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
+#define TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
#include "tensorflow/core/framework/lookup_interface.h"
#include "tensorflow/core/platform/macros.h"
@@ -58,11 +58,7 @@ class InitializableLookupTable : public LookupInterface {
}
Status ImportValues(OpKernelContext* ctx, const Tensor& keys,
- const Tensor& values) final {
- return errors::Unimplemented(
- "ImportValues not supported by InitializableLookupTable "
- "implementations");
- }
+ const Tensor& values) final;
TensorShape key_shape() const final { return TensorShape(); }
@@ -155,7 +151,58 @@ class InitializableLookupTable : public LookupInterface {
bool is_initialized_ = false;
};
+// Iterator to initialize tables given 'keys' and 'values' tensors.
+//
+// The two tensors are returned in the first iteration. It doesn't loop
+// over each element of the tensor since insertions in the lookup table can
+// process batches.
+class KeyValueTensorIterator
+ : public InitializableLookupTable::InitTableIterator {
+ public:
+ // keys and values are not owned by the iterator.
+ explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values)
+ : keys_(keys), values_(values), valid_(true), status_(Status::OK()) {
+ TensorShape key_shape = keys_->shape();
+ if (!key_shape.IsSameSize(values_->shape())) {
+ valid_ = false;
+ status_ = errors::InvalidArgument(
+ "keys and values should have the same dimension.",
+ key_shape.DebugString(), " vs ", values_->shape().DebugString());
+ }
+ if (key_shape.num_elements() == 0) {
+ valid_ = false;
+ status_ =
+ errors::InvalidArgument("keys and values cannot be empty tensors.");
+ }
+ }
+
+ bool Valid() const override { return valid_; }
+
+ void Next() override {
+ valid_ = false;
+ status_ = errors::OutOfRange("No more data.");
+ }
+
+ const Tensor& keys() const override { return *keys_; }
+
+ const Tensor& values() const override { return *values_; }
+
+ Status status() const override { return status_; }
+
+ int64 total_size() const override {
+ return keys_ == nullptr ? -1 : keys_->NumElements();
+ }
+
+ private:
+ TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator);
+
+ const Tensor* keys_; // Doesn't own it.
+ const Tensor* values_; // Doesn't own it.
+ bool valid_; // true if the iterator points to an existing range.
+ Status status_;
+};
+
} // namespace lookup
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
+#endif // TENSORFLOW_CORE_KERNELS_INITIALIZABLE_LOOKUP_TABLE_H_
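
KeyValueTensorIterator hands the whole keys/values pair to the table as a single batch: the first Valid()/keys()/values() round yields both tensors, and the first Next() ends the iteration with an OutOfRange status. A minimal caller-side sketch of that protocol, assuming `keys` and `values` are same-shaped tensors; the function name is hypothetical, and ImportValues() above achieves the same effect by simply forwarding the iterator to Initialize():

  // Illustrative sketch of the single-batch protocol; not part of this change.
  Status DescribeIteratorProtocol(const Tensor& keys, const Tensor& values) {
    KeyValueTensorIterator iter(&keys, &values);
    while (iter.Valid()) {
      const Tensor& batch_keys = iter.keys();      // whole keys tensor at once
      const Tensor& batch_values = iter.values();  // whole values tensor at once
      // ... hand the batch to the table's insertion routine ...
      iter.Next();  // flips Valid() to false and records an OutOfRange status
    }
    // OutOfRange here just means "no more batches", not a failure.
    return errors::IsOutOfRange(iter.status()) ? Status::OK() : iter.status();
  }
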
diff --git a/tensorflow/core/kernels/inplace_ops_functor.h b/tensorflow/core/kernels/inplace_ops_functor.h
index b806787e91..2023869f49 100644
--- a/tensorflow/core/kernels/inplace_ops_functor.h
+++ b/tensorflow/core/kernels/inplace_ops_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_INPLACE_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_INPLACE_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_INPLACE_OPS_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_INPLACE_OPS_FUNCTOR_H_
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
@@ -46,4 +46,4 @@ Status DoCopy(const Device& device, const Tensor& x, Tensor* y);
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_INPLACE_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_INPLACE_OPS_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/l2loss_op.h b/tensorflow/core/kernels/l2loss_op.h
index 4953aa237c..465ef96a51 100644
--- a/tensorflow/core/kernels/l2loss_op.h
+++ b/tensorflow/core/kernels/l2loss_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_L2LOSS_OP_H_
-#define TENSORFLOW_KERNELS_L2LOSS_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_L2LOSS_OP_H_
+#define TENSORFLOW_CORE_KERNELS_L2LOSS_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -30,4 +30,4 @@ struct L2LossOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_L2LOSS_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_L2LOSS_OP_H_
diff --git a/tensorflow/core/kernels/linalg_ops_common.h b/tensorflow/core/kernels/linalg_ops_common.h
index f7c3f1950b..692f916439 100644
--- a/tensorflow/core/kernels/linalg_ops_common.h
+++ b/tensorflow/core/kernels/linalg_ops_common.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
-#define TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
+#ifndef TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
+#define TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
// Classes to support linear algebra functionality, similar to the numpy.linalg
// module. Supports batch computation on several matrices at once, sharding the
@@ -194,4 +194,4 @@ extern template class LinearAlgebraOp<complex128>;
#define REGISTER_LINALG_OP(OpName, OpClass, Scalar) \
REGISTER_LINALG_OP_CPU(OpName, OpClass, Scalar)
-#endif // TENSORFLOW_KERNELS_LINALG_OPS_COMMON_H_
+#endif // TENSORFLOW_CORE_KERNELS_LINALG_OPS_COMMON_H_
diff --git a/tensorflow/core/kernels/list_kernels.cc b/tensorflow/core/kernels/list_kernels.cc
index 84fa63fc00..2088c13586 100644
--- a/tensorflow/core/kernels/list_kernels.cc
+++ b/tensorflow/core/kernels/list_kernels.cc
@@ -77,9 +77,9 @@ static Status TensorListDeviceCopy(
return Status::OK();
}
-#define REGISTER_LIST_COPY(DIRECTION) \
- INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
- TensorList, DIRECTION, TensorList::kTypeName, TensorListDeviceCopy)
+#define REGISTER_LIST_COPY(DIRECTION) \
+ INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(TensorList, DIRECTION, \
+ TensorListDeviceCopy)
REGISTER_LIST_COPY(VariantDeviceCopyDirection::HOST_TO_DEVICE);
REGISTER_LIST_COPY(VariantDeviceCopyDirection::DEVICE_TO_HOST);
@@ -92,8 +92,7 @@ Status TensorListShape(const TensorList& t, TensorShape* s) {
return Status::OK();
}
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorList::kTypeName,
- TensorListShape);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(TensorList, TensorListShape);
bool TensorList::Decode(const VariantTensorData& data) {
tensors = data.tensors();
@@ -588,7 +587,11 @@ REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_CPU(bfloat16);
REGISTER_KERNEL_BUILDER(Name("TensorListStack") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_CPU), \
- TensorListStack<CPUDevice, T>)
+ TensorListStack<CPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListGather") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_CPU), \
+ TensorListGather<CPUDevice, T>)
TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_STACK_CPU);
REGISTER_TENSOR_LIST_STACK_CPU(quint8);
@@ -604,7 +607,11 @@ REGISTER_TENSOR_LIST_STACK_CPU(bfloat16);
REGISTER_KERNEL_BUILDER(Name("TensorListFromTensor") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_CPU), \
- TensorListFromTensor<CPUDevice, T>)
+ TensorListFromTensor<CPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_CPU), \
+ TensorListScatter<CPUDevice, T>)
TF_CALL_POD_STRING_TYPES(REGISTER_TENSOR_LIST_FROM_TENSOR_CPU);
REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(quint8);
@@ -617,12 +624,11 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_CPU(bfloat16);
#undef REGISTER_TENSOR_LIST_FROM_TENSOR_CPU
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
- TensorList, TensorList::kTypeName,
+ TensorList,
TensorListBinaryAdd<CPUDevice>);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_CPU, TensorList,
- TensorList::kTypeName,
TensorListZerosLike<CPUDevice>);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/list_kernels.cu.cc b/tensorflow/core/kernels/list_kernels.cu.cc
index 0ea9362cbe..a00bf700ca 100644
--- a/tensorflow/core/kernels/list_kernels.cu.cc
+++ b/tensorflow/core/kernels/list_kernels.cu.cc
@@ -40,7 +40,12 @@ typedef Eigen::GpuDevice GPUDevice;
REGISTER_KERNEL_BUILDER(Name("TensorListStack") \
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_GPU), \
- TensorListStack<GPUDevice, T>)
+ TensorListStack<GPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListGather") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("indices"), \
+ TensorListGather<GPUDevice, T>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_STACK_GPU);
REGISTER_TENSOR_LIST_STACK_GPU(bfloat16);
@@ -71,7 +76,13 @@ REGISTER_TENSOR_LIST_PUSH_BACK_BATCH_GPU(bool);
.TypeConstraint<T>("element_dtype") \
.Device(DEVICE_GPU) \
.HostMemory("element_shape"), \
- TensorListFromTensor<GPUDevice, T>)
+ TensorListFromTensor<GPUDevice, T>) \
+ REGISTER_KERNEL_BUILDER(Name("TensorListScatter") \
+ .TypeConstraint<T>("element_dtype") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("element_shape") \
+ .HostMemory("indices"), \
+ TensorListScatter<GPUDevice, T>)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_TENSOR_LIST_FROM_TENSOR_GPU);
REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(bfloat16);
@@ -83,11 +94,10 @@ REGISTER_TENSOR_LIST_FROM_TENSOR_GPU(bool);
#undef REGISTER_TENSOR_LIST_FROM_TENSOR_GPU
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_GPU,
- TensorList, TensorList::kTypeName,
+ TensorList,
TensorListBinaryAdd<GPUDevice>);
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
DEVICE_GPU, TensorList,
- TensorList::kTypeName,
TensorListZerosLike<GPUDevice>);
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/list_kernels.h b/tensorflow/core/kernels/list_kernels.h
index 42871c6113..72581c9293 100644
--- a/tensorflow/core/kernels/list_kernels.h
+++ b/tensorflow/core/kernels/list_kernels.h
@@ -134,6 +134,74 @@ class TensorListStack : public OpKernel {
};
template <typename Device, typename T>
+class TensorListGather : public OpKernel {
+ public:
+ typedef std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>
+ ConstMatrixVector;
+ explicit TensorListGather(OpKernelConstruction* c) : OpKernel(c) {
+ OP_REQUIRES_OK(c, c->GetAttr("element_dtype", &element_dtype_));
+ }
+
+ void Compute(OpKernelContext* c) override {
+ const TensorList* l = c->input(0).scalar<Variant>()().get<TensorList>();
+ OP_REQUIRES(c, l != nullptr,
+ errors::InvalidArgument(
+ "Input handle is not a list. Saw: '",
+ c->input(0).scalar<Variant>()().DebugString(), "'"));
+ OP_REQUIRES(c, element_dtype_ == l->element_dtype,
+ errors::InvalidArgument("Invalid data types; op elements ",
+ DataTypeString(element_dtype_),
+ " but list elements ",
+ DataTypeString(l->element_dtype)));
+ OP_REQUIRES(c, l->element_shape.IsFullyDefined(),
+ errors::InvalidArgument("Tried to stack elements from a list "
+ "with non-fully-defined shape: ",
+ l->element_shape.DebugString()));
+ Tensor indices = c->input(1);
+ TensorShape resulting_shape;
+ resulting_shape.AddDim(indices.NumElements());
+ for (TensorShapeDim s : l->element_shape) {
+ resulting_shape.AddDim(s.size);
+ }
+ Tensor* output;
+ OP_REQUIRES_OK(c, c->allocate_output(0, resulting_shape, &output));
+ if (output->NumElements() == 0) {
+ return;
+ }
+
+ ConstMatrixVector inputs_flat;
+ inputs_flat.reserve(l->tensors.size());
+ for (int index = 0; index < indices.NumElements(); ++index) {
+ const int i = indices.flat<int32>()(index);
+ OP_REQUIRES(
+ c, i < l->tensors.size(),
+ errors::InvalidArgument("Index ", i, " out o range; list only has ",
+ l->tensors.size(), " elements."));
+ const Tensor& t = l->tensors[i];
+ OP_REQUIRES(c, l->element_shape.IsCompatibleWith(t.shape()),
+ errors::InvalidArgument(
+ "Tensor with invalid shape in list. List element shape: ",
+ l->element_shape.DebugString(),
+ " and tensor shape: ", t.shape().DebugString()));
+ inputs_flat.emplace_back(new typename TTypes<T, 2>::ConstMatrix(
+ t.shaped<T, 2>({1, t.NumElements()})));
+ }
+ auto output_flat = output->shaped<T, 2>({1, output->NumElements()});
+
+#if GOOGLE_CUDA
+ if (std::is_same<Device, Eigen::GpuDevice>::value) {
+ ConcatGPU<T>(c, inputs_flat, output, &output_flat);
+ return;
+ }
+#endif // GOOGLE_CUDA
+ ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
+ }
+
+ private:
+ DataType element_dtype_;
+};
+
+template <typename Device, typename T>
class TensorListFromTensor : public OpKernel {
public:
TensorListFromTensor(OpKernelConstruction* c) : OpKernel(c) {}
@@ -178,6 +246,59 @@ class TensorListFromTensor : public OpKernel {
}
};
+template <typename Device, typename T>
+class TensorListScatter : public OpKernel {
+ public:
+ TensorListScatter(OpKernelConstruction* c) : OpKernel(c) {}
+
+ void Compute(OpKernelContext* c) override {
+ Tensor* output_tensor;
+ AllocatorAttributes attr;
+ attr.set_on_host(true);
+ OP_REQUIRES_OK(c, c->allocate_output(0, {}, &output_tensor, attr));
+ Tensor indices = c->input(1);
+ PartialTensorShape element_shape;
+ OP_REQUIRES_OK(c, TensorShapeFromTensor(c->input(2), &element_shape));
+ TensorList output_list;
+ const Tensor& t = c->input(0);
+ output_list.element_dtype = t.dtype();
+ OP_REQUIRES(c, TensorShapeUtils::IsVectorOrHigher(t.shape()),
+ errors::InvalidArgument(
+ "Tensor must be at least a vector, but saw shape: ",
+ t.shape().DebugString()));
+ TensorShape output_shape(t.shape());
+ output_shape.RemoveDim(0);
+ OP_REQUIRES(c, element_shape.IsCompatibleWith(output_shape),
+ errors::InvalidArgument(
+ "Specified a list with shape ", element_shape.DebugString(),
+ " from a tensor with shape ", output_shape.DebugString()));
+ output_list.element_shape = element_shape;
+ output_list.tensors.reserve(indices.NumElements());
+ for (int index = 0; index < indices.NumElements(); ++index) {
+ const int i = indices.flat<int32>()(index);
+ OP_REQUIRES(c, i < t.shape().dim_size(0),
+ errors::InvalidArgument("Trying to scatter index ", i,
+ " from tensor with ",
+ t.shape().dim_size(0), " rows."));
+ Tensor tmp = t.Slice(i, i + 1);
+ TensorShape tmp_shape = tmp.shape();
+ tmp_shape.RemoveDim(0);
+ OP_REQUIRES(c, tmp.CopyFrom(tmp, tmp_shape),
+ errors::Unknown("Unexpected shape error."));
+ // TODO(apassos) maybe not always align; but weird compiler bugs seem to
+ // prevent this.
+ Tensor aligned;
+ OP_REQUIRES_OK(c, c->allocate_temp(tmp.dtype(), tmp.shape(), &aligned));
+ // TODO(apassos) do all slices in a single kernel invocation instead of
+      // many small ones.
+ aligned.flat<T>().device(c->eigen_device<Device>()) =
+ tmp.unaligned_flat<T>();
+ output_list.tensors.push_back(aligned);
+ }
+ output_tensor->scalar<Variant>()() = std::move(output_list);
+ }
+};
+
template <typename Device>
Status TensorListBinaryAdd(OpKernelContext* c, const TensorList& a,
const TensorList& b, TensorList* out) {
@@ -253,7 +374,12 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
y->tensors.reserve(x.tensors.size());
for (const Tensor& t : x.tensors) {
Tensor out_tensor;
- TF_RETURN_IF_ERROR(c->allocate_temp(t.dtype(), t.shape(), &out_tensor));
+ AllocatorAttributes attr;
+ if (t.dtype() == DT_VARIANT) {
+ attr.set_on_host(true);
+ }
+ TF_RETURN_IF_ERROR(
+ c->allocate_temp(t.dtype(), t.shape(), &out_tensor, attr));
switch (out_tensor.dtype()) {
#define DTYPE_CASE(dtype) \
case DataTypeToEnum<dtype>::value: \
@@ -261,14 +387,29 @@ Status TensorListZerosLike(OpKernelContext* c, const TensorList& x,
out_tensor.flat<dtype>().constant(dtype(0)); \
break;
- TF_CALL_NUMBER_TYPES(DTYPE_CASE)
+ TF_CALL_POD_TYPES(DTYPE_CASE)
#undef DTYPE_CASE
+
+ case DataTypeToEnum<Variant>::value: {
+ const TensorList* inner_x = t.scalar<Variant>()().get<TensorList>();
+ if (inner_x == nullptr) {
+ return errors::InvalidArgument("Input handle is not a list. Saw: '",
+ t.scalar<Variant>()().DebugString(),
+ "'");
+ }
+ TensorList inner_y;
+ TF_RETURN_IF_ERROR(TensorListZerosLike<Device>(c, *inner_x, &inner_y));
+ out_tensor.scalar<Variant>()() = std::move(inner_y);
+ break;
+ }
+
default:
return errors::InvalidArgument(
- "Trying to compute zeros_like for unsupported dtype",
- out_tensor.dtype());
+ "Trying to compute zeros_like for unsupported dtype ",
+ DataTypeString(out_tensor.dtype()));
}
+ y->tensors.emplace_back(out_tensor);
}
return Status::OK();
}
diff --git a/tensorflow/core/kernels/logging_ops.cc b/tensorflow/core/kernels/logging_ops.cc
index 6b6a14e9a7..1ded012f3c 100644
--- a/tensorflow/core/kernels/logging_ops.cc
+++ b/tensorflow/core/kernels/logging_ops.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <iostream>
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -90,6 +91,59 @@ class PrintOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("Print").Device(DEVICE_CPU), PrintOp);
+class PrintV2Op : public OpKernel {
+ public:
+ explicit PrintV2Op(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("output_stream", &output_stream_));
+
+ auto output_stream_index =
+ std::find(std::begin(valid_output_streams_),
+ std::end(valid_output_streams_), output_stream_);
+
+ if (output_stream_index == std::end(valid_output_streams_)) {
+ string error_msg = strings::StrCat(
+ "Unknown output stream: ", output_stream_, ", Valid streams are:");
+ for (auto valid_stream : valid_output_streams_) {
+ strings::StrAppend(&error_msg, " ", valid_stream);
+ }
+ OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+ }
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_));
+ const string& msg = input_->scalar<string>()();
+
+ if (output_stream_ == "stdout") {
+ std::cout << msg << std::endl;
+ } else if (output_stream_ == "stderr") {
+ std::cerr << msg << std::endl;
+ } else if (output_stream_ == "log(info)") {
+ LOG(INFO) << msg << std::endl;
+ } else if (output_stream_ == "log(warning)") {
+ LOG(WARNING) << msg << std::endl;
+ } else if (output_stream_ == "log(error)") {
+ LOG(ERROR) << msg << std::endl;
+ } else {
+ string error_msg = strings::StrCat(
+ "Unknown output stream: ", output_stream_, ", Valid streams are:");
+ for (auto valid_stream : valid_output_streams_) {
+ strings::StrAppend(&error_msg, " ", valid_stream);
+ }
+ OP_REQUIRES(ctx, false, errors::InvalidArgument(error_msg));
+ }
+ }
+
+  const char* valid_output_streams_[5] = {"stdout", "stderr", "log(info)",
+ "log(warning)", "log(error)"};
+
+ private:
+ string output_stream_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("PrintV2").Device(DEVICE_CPU), PrintV2Op);
+
class TimestampOp : public OpKernel {
public:
explicit TimestampOp(OpKernelConstruction* context) : OpKernel(context) {}
diff --git a/tensorflow/core/kernels/logging_ops_test.cc b/tensorflow/core/kernels/logging_ops_test.cc
index 5e6958f364..a259d995fa 100644
--- a/tensorflow/core/kernels/logging_ops_test.cc
+++ b/tensorflow/core/kernels/logging_ops_test.cc
@@ -23,11 +23,33 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
namespace tensorflow {
namespace {
+class PrintingV2GraphTest : public OpsTestBase {
+ protected:
+ Status Init(const string& output_stream = "log(warning)") {
+ TF_CHECK_OK(NodeDefBuilder("op", "PrintV2")
+ .Input(FakeInput(DT_STRING))
+ .Attr("output_stream", output_stream)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(PrintingV2GraphTest, StringSuccess) {
+ TF_ASSERT_OK(Init());
+ AddInputFromArray<string>(TensorShape({}), {"bar"});
+ TF_ASSERT_OK(RunOpKernel());
+}
+
+TEST_F(PrintingV2GraphTest, InvalidOutputStream) {
+ ASSERT_NE(::tensorflow::Status::OK(), (Init("invalid_output_stream")));
+}
+
class PrintingGraphTest : public OpsTestBase {
protected:
Status Init(DataType input_type1, DataType input_type2, string msg = "",
diff --git a/tensorflow/core/kernels/logistic-loss.h b/tensorflow/core/kernels/logistic-loss.h
index 6479e6f5dc..9198a98e47 100644
--- a/tensorflow/core/kernels/logistic-loss.h
+++ b/tensorflow/core/kernels/logistic-loss.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_LOGISTIC_LOSS_H_
-#define TENSORFLOW_KERNELS_LOGISTIC_LOSS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_LOGISTIC_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_LOGISTIC_LOSS_H_
#include <cmath>
@@ -86,7 +86,7 @@ class LogisticLossUpdater : public DualLossUpdater {
} else {
inverse_exp_term = 1 / (1 + exp(label * wx));
}
- return inverse_exp_term * label * example_weight;
+ return -inverse_exp_term * label * example_weight;
}
// The smoothness constant is 4 since the derivative of logistic loss, which
@@ -131,4 +131,4 @@ class LogisticLossUpdater : public DualLossUpdater {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_LOGISTIC_LOSS_H_
+#endif // TENSORFLOW_CORE_KERNELS_LOGISTIC_LOSS_H_
diff --git a/tensorflow/core/kernels/lookup_table_init_op.cc b/tensorflow/core/kernels/lookup_table_init_op.cc
index b352dd257c..6e77e1ee01 100644
--- a/tensorflow/core/kernels/lookup_table_init_op.cc
+++ b/tensorflow/core/kernels/lookup_table_init_op.cc
@@ -74,13 +74,11 @@ class InitializeTableOp : public OpKernel {
"Keys and values must have the same size ",
keys.NumElements(), " vs ", values.NumElements()));
- lookup::KeyValueTensorIterator iter(&keys, &values);
-
int memory_used_before = 0;
if (ctx->track_allocations()) {
memory_used_before = table->MemoryUsed();
}
- OP_REQUIRES_OK(ctx, table->Initialize(iter));
+ OP_REQUIRES_OK(ctx, table->ImportValues(ctx, keys, values));
if (ctx->track_allocations()) {
ctx->record_persistent_memory_allocation(table->MemoryUsed() -
memory_used_before);
diff --git a/tensorflow/core/kernels/lookup_table_init_op.h b/tensorflow/core/kernels/lookup_table_init_op.h
index 177a26daa8..101e528659 100644
--- a/tensorflow/core/kernels/lookup_table_init_op.h
+++ b/tensorflow/core/kernels/lookup_table_init_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_LOOKUP_TABLE_INIT_OP_H_
-#define TENSORFLOW_KERNELS_LOOKUP_TABLE_INIT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_INIT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_INIT_OP_H_
#include "tensorflow/core/kernels/initializable_lookup_table.h"
@@ -30,4 +30,4 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
} // namespace lookup
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_LOOKUP_TABLE_INIT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_INIT_OP_H_
diff --git a/tensorflow/core/kernels/lookup_table_op.cc b/tensorflow/core/kernels/lookup_table_op.cc
index 07e754a6ef..a495758861 100644
--- a/tensorflow/core/kernels/lookup_table_op.cc
+++ b/tensorflow/core/kernels/lookup_table_op.cc
@@ -50,7 +50,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
MutableHashTableOfScalars(OpKernelContext* ctx, OpKernel* kernel) {}
size_t size() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return table_.size();
}
@@ -60,7 +60,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
const auto key_values = key.flat<K>();
auto value_values = value->flat<V>();
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
value_values(i) = gtl::FindWithDefault(
table_, SubtleMustCopyIfIntegral(key_values(i)), default_val);
@@ -95,7 +95,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
int64 size = table_.size();
Tensor* keys;
@@ -125,7 +125,7 @@ class MutableHashTableOfScalars final : public LookupInterface {
int64 MemoryUsed() const override {
int64 ret = 0;
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (unsigned i = 0; i < table_.bucket_count(); ++i) {
size_t bucket_size = table_.bucket_size(i);
if (bucket_size == 0) {
@@ -138,7 +138,6 @@ class MutableHashTableOfScalars final : public LookupInterface {
}
private:
- // TODO(andreasst): consider using a read/write lock or a concurrent map
mutable mutex mu_;
std::unordered_map<K, V> table_ GUARDED_BY(mu_);
};
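
The mutex_lock -> tf_shared_lock changes in the hunks above let read-only members (size, Find, ExportValues, MemoryUsed) proceed concurrently, while mutating paths keep exclusive locking. With standard C++ primitives the same pattern looks roughly like this (std::shared_mutex standing in for TensorFlow's mutex):

    #include <cstddef>
    #include <mutex>
    #include <shared_mutex>
    #include <unordered_map>

    class SharedLockTable {
     public:
      // Readers share the lock, so concurrent lookups do not serialize.
      std::size_t size() const {
        std::shared_lock<std::shared_mutex> l(mu_);
        return table_.size();
      }
      // Writers take the lock exclusively.
      void Insert(int key, int value) {
        std::unique_lock<std::shared_mutex> l(mu_);
        table_[key] = value;
      }

     private:
      mutable std::shared_mutex mu_;
      std::unordered_map<int, int> table_;
    };
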
@@ -158,7 +157,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
}
size_t size() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return table_.size();
}
@@ -169,7 +168,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
auto value_values = value->flat_inner_dims<V, 2>();
int64 value_dim = value_shape_.dim_size(0);
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (int64 i = 0; i < key_values.size(); ++i) {
ValueArray* value_vec =
gtl::FindOrNull(table_, SubtleMustCopyIfIntegral(key_values(i)));
@@ -219,7 +218,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
int64 size = table_.size();
int64 value_dim = value_shape_.dim_size(0);
@@ -254,7 +253,7 @@ class MutableHashTableOfTensors final : public LookupInterface {
int64 MemoryUsed() const override {
int64 ret = 0;
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
for (unsigned i = 0; i < table_.bucket_count(); ++i) {
size_t bucket_size = table_.bucket_size(i);
if (bucket_size == 0) {
@@ -268,7 +267,6 @@ class MutableHashTableOfTensors final : public LookupInterface {
private:
TensorShape value_shape_;
- // TODO(andreasst): consider using a read/write lock or a concurrent map
mutable mutex mu_;
typedef gtl::InlinedVector<V, 4> ValueArray;
std::unordered_map<K, ValueArray> table_ GUARDED_BY(mu_);
@@ -335,13 +333,13 @@ class MutableDenseHashTable final : public LookupInterface {
}
size_t size() const override LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return num_entries_;
}
Status Find(OpKernelContext* ctx, const Tensor& key, Tensor* value,
const Tensor& default_value) override LOCKS_EXCLUDED(mu_) {
- const int64 num_elements = key.dim_size(0);
+ const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
const int64 key_size = key_shape_.num_elements();
const int64 value_size = value_shape_.num_elements();
if (key.NumElements() != num_elements * key_size) {
@@ -355,7 +353,7 @@ class MutableDenseHashTable final : public LookupInterface {
auto value_matrix = value->shaped<V, 2>({num_elements, value_size});
const auto default_flat = default_value.flat<V>();
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
const auto key_buckets_matrix =
key_buckets_.AccessTensor(ctx)->template matrix<K>();
const auto value_buckets_matrix =
@@ -403,8 +401,9 @@ class MutableDenseHashTable final : public LookupInterface {
Status Insert(OpKernelContext* ctx, const Tensor& key,
const Tensor& value) override LOCKS_EXCLUDED(mu_) {
- if (key.NumElements() != key.dim_size(0) * key_shape_.num_elements()) {
- TensorShape expected_shape({key.dim_size(0)});
+ const int64 batch_size = (key.dims() == 0) ? 1 : key.dim_size(0);
+ if (key.NumElements() != batch_size * key_shape_.num_elements()) {
+ TensorShape expected_shape({batch_size});
expected_shape.AppendShape(key_shape_);
return errors::InvalidArgument("Expected key shape ",
expected_shape.DebugString(), " got ",
@@ -415,7 +414,7 @@ class MutableDenseHashTable final : public LookupInterface {
// rather than updates. That means we may grow the table even though we
// don't need to. As long as the number of keys inserted in one call is
// small compared to the size of the map, the impact of this is minimal.
- const int64 pending_num_entries = num_entries_ + key.dim_size(0);
+ const int64 pending_num_entries = num_entries_ + batch_size;
if (pending_num_entries > num_buckets_ * max_load_factor_) {
int64 new_num_buckets = num_buckets_;
do {
@@ -450,7 +449,7 @@ class MutableDenseHashTable final : public LookupInterface {
}
Status ExportValues(OpKernelContext* ctx) override LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
Tensor key_buckets_tensor = *key_buckets_.AccessTensor(ctx);
Tensor value_buckets_tensor = *value_buckets_.AccessTensor(ctx);
TF_RETURN_IF_ERROR(ctx->set_output("keys", key_buckets_tensor));
@@ -492,7 +491,7 @@ class MutableDenseHashTable final : public LookupInterface {
TensorShape value_shape() const override { return value_shape_; }
int64 MemoryUsed() const override {
- mutex_lock l(mu_);
+ tf_shared_lock l(mu_);
return sizeof(MutableDenseHashTable) + key_buckets_.AllocatedBytes() +
value_buckets_.AllocatedBytes() + empty_key_.AllocatedBytes();
}
@@ -500,7 +499,7 @@ class MutableDenseHashTable final : public LookupInterface {
private:
Status DoInsert(OpKernelContext* ctx, const Tensor& key, const Tensor& value,
bool ignore_empty_key) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
- const int64 num_elements = key.dim_size(0);
+ const int64 num_elements = (key.dims() == 0) ? 1 : key.dim_size(0);
const int64 value_size = value_shape_.num_elements();
const int64 key_size = key_shape_.num_elements();
const auto key_matrix = key.shaped<K, 2>({num_elements, key_size});
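
The (key.dims() == 0) ? 1 : key.dim_size(0) expression repeated in Find, Insert, and DoInsert above is what lets a scalar key be treated as a batch of one. A self-contained illustration with a toy shape type (ToyShape is hypothetical, not TensorShape):

    #include <cstdint>
    #include <vector>

    // Toy shape: dims() == 0 means a scalar.
    struct ToyShape {
      std::vector<int64_t> sizes;
      int dims() const { return static_cast<int>(sizes.size()); }
      int64_t dim_size(int i) const { return sizes[i]; }
    };

    // Mirrors the kernel's batch-size computation: a scalar key counts as one.
    int64_t BatchSize(const ToyShape& key) {
      return (key.dims() == 0) ? 1 : key.dim_size(0);
    }
    // BatchSize(ToyShape{})        == 1   (scalar key)
    // BatchSize(ToyShape{{5, 3}})  == 5   (batch of 5 keys)
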
@@ -812,17 +811,21 @@ REGISTER_KERNEL_BUILDER(Name("LookupTableImportV2").Device(DEVICE_CPU),
LookupTableOp<lookup::HashTable<key_dtype, value_dtype>, key_dtype, \
value_dtype>)
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int32, string);
+REGISTER_KERNEL(int64, double);
+REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
+REGISTER_KERNEL(int64, string);
+REGISTER_KERNEL(string, bool);
REGISTER_KERNEL(string, double);
REGISTER_KERNEL(string, float);
REGISTER_KERNEL(string, int32);
REGISTER_KERNEL(string, int64);
-REGISTER_KERNEL(int64, string);
-REGISTER_KERNEL(int64, int64);
-REGISTER_KERNEL(int64, float);
REGISTER_KERNEL(string, string);
-REGISTER_KERNEL(string, bool);
-REGISTER_KERNEL(int32, int32);
-REGISTER_KERNEL(int32, string);
#undef REGISTER_KERNEL
@@ -843,12 +846,20 @@ REGISTER_KERNEL(int32, string);
LookupTableOp<lookup::MutableHashTableOfScalars<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
-REGISTER_KERNEL(string, float);
-REGISTER_KERNEL(string, int64);
-REGISTER_KERNEL(int64, string);
-REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int64, double);
REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
+REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(int64, Variant);
+REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(string, double);
+REGISTER_KERNEL(string, float);
+REGISTER_KERNEL(string, int32);
+REGISTER_KERNEL(string, int64);
#undef REGISTER_KERNEL
@@ -869,10 +880,19 @@ REGISTER_KERNEL(int64, Variant);
LookupTableOp<lookup::MutableHashTableOfTensors<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
-REGISTER_KERNEL(string, float);
-REGISTER_KERNEL(string, int64);
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
+REGISTER_KERNEL(int64, double);
+REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, string);
REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(string, double);
+REGISTER_KERNEL(string, float);
+REGISTER_KERNEL(string, int32);
+REGISTER_KERNEL(string, int64);
#undef REGISTER_KERNEL
@@ -893,13 +913,20 @@ REGISTER_KERNEL(string, bool);
LookupTableOp<lookup::MutableDenseHashTable<key_dtype, value_dtype>, \
key_dtype, value_dtype>)
-REGISTER_KERNEL(int64, int64);
-REGISTER_KERNEL(int64, float);
-REGISTER_KERNEL(int64, double);
-REGISTER_KERNEL(string, float);
-REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(int32, double);
+REGISTER_KERNEL(int32, float);
+REGISTER_KERNEL(int32, int32);
REGISTER_KERNEL(int64, bool);
+REGISTER_KERNEL(int64, double);
+REGISTER_KERNEL(int64, float);
+REGISTER_KERNEL(int64, int32);
+REGISTER_KERNEL(int64, int64);
REGISTER_KERNEL(int64, Variant);
+REGISTER_KERNEL(string, bool);
+REGISTER_KERNEL(string, double);
+REGISTER_KERNEL(string, float);
+REGISTER_KERNEL(string, int32);
+REGISTER_KERNEL(string, int64);
#undef REGISTER_KERNEL
diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h
index 3977f16299..9451247f26 100644
--- a/tensorflow/core/kernels/lookup_table_op.h
+++ b/tensorflow/core/kernels/lookup_table_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
-#define TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
#include "tensorflow/core/framework/lookup_interface.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -102,9 +102,12 @@ class LookupTableOp : public OpKernel {
~LookupTableOp() override {
// If the table object was not shared, delete it.
if (table_handle_set_ && cinfo_.resource_is_private_to_kernel()) {
- TF_CHECK_OK(
- cinfo_.resource_manager()->template Delete<lookup::LookupInterface>(
- cinfo_.container(), cinfo_.name()));
+ if (!cinfo_.resource_manager()
+ ->template Delete<lookup::LookupInterface>(cinfo_.container(),
+ cinfo_.name())
+ .ok()) {
+ // Do nothing; the resource may have been deleted by session resets.
+ }
}
}
@@ -272,4 +275,4 @@ class HashTable : public InitializableLookupTable {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_LOOKUP_TABLE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_LOOKUP_TABLE_OP_H_
diff --git a/tensorflow/core/kernels/lookup_util.h b/tensorflow/core/kernels/lookup_util.h
index 894769960a..ec28cf9fa7 100644
--- a/tensorflow/core/kernels/lookup_util.h
+++ b/tensorflow/core/kernels/lookup_util.h
@@ -46,57 +46,6 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
int32 value_index, Env* env,
InitializableLookupTable* table);
-// Iterator to initialize tables given 'keys' and 'values' tensors.
-//
-// The two tensors are returned in the first iteration. It doesn't loop
-// over each element of the tensor since insertions in the lookup table can
-// process batches.
-class KeyValueTensorIterator
- : public InitializableLookupTable::InitTableIterator {
- public:
- // keys and values are not owned by the iterator.
- explicit KeyValueTensorIterator(const Tensor* keys, const Tensor* values)
- : keys_(keys), values_(values), valid_(true), status_(Status::OK()) {
- TensorShape key_shape = keys_->shape();
- if (!key_shape.IsSameSize(values_->shape())) {
- valid_ = false;
- status_ = errors::InvalidArgument(
- "keys and values should have the same dimension.",
- key_shape.DebugString(), " vs ", values_->shape().DebugString());
- }
- if (key_shape.num_elements() == 0) {
- valid_ = false;
- status_ =
- errors::InvalidArgument("keys and values cannot be empty tensors.");
- }
- }
-
- bool Valid() const override { return valid_; }
-
- void Next() override {
- valid_ = false;
- status_ = errors::OutOfRange("No more data.");
- }
-
- const Tensor& keys() const override { return *keys_; }
-
- const Tensor& values() const override { return *values_; }
-
- Status status() const override { return status_; }
-
- int64 total_size() const override {
- return keys_ == nullptr ? -1 : keys_->NumElements();
- }
-
- private:
- TF_DISALLOW_COPY_AND_ASSIGN(KeyValueTensorIterator);
-
- const Tensor* keys_; // Doesn't own it.
- const Tensor* values_; // Doesn't own it.
- bool valid_; // true if the iterator points to an existing range.
- Status status_;
-};
-
} // namespace lookup
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/loss.h b/tensorflow/core/kernels/loss.h
index a77aa7587b..7db348800e 100644
--- a/tensorflow/core/kernels/loss.h
+++ b/tensorflow/core/kernels/loss.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_LOSS_H_
-#define TENSORFLOW_KERNELS_LOSS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_LOSS_H_
#include "tensorflow/core/lib/core/status.h"
@@ -56,4 +56,4 @@ class DualLossUpdater {
};
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_LOSS_H_
+#endif // TENSORFLOW_CORE_KERNELS_LOSS_H_
diff --git a/tensorflow/core/kernels/loss_test.cc b/tensorflow/core/kernels/loss_test.cc
index 460d65c5c2..9209ed2ab7 100644
--- a/tensorflow/core/kernels/loss_test.cc
+++ b/tensorflow/core/kernels/loss_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/kernels/hinge-loss.h"
#include "tensorflow/core/kernels/logistic-loss.h"
+#include "tensorflow/core/kernels/poisson-loss.h"
#include "tensorflow/core/kernels/smooth-hinge-loss.h"
#include "tensorflow/core/kernels/squared-loss.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -30,6 +31,24 @@ namespace {
// TODO(sibyl-Aix6ihai): add a test to show the improvements of the Newton
// modification detailed in readme.md
+// This test checks that the dual value after the update is optimal.
+// At the optimum the dual value should be the opposite of the primal gradient.
+// This does not hold at a point where the primal is not differentiable.
+void TestComputeUpdatedDual(const DualLossUpdater &loss_updater,
+ const int num_loss_partitions, const double label,
+ const double example_weight,
+ const double current_dual, const double wx,
+ const double weighted_example_norm) {
+ double new_dual = loss_updater.ComputeUpdatedDual(
+ num_loss_partitions, label, example_weight, current_dual, wx,
+ weighted_example_norm);
+ // The primal gradient needs to be computed after the weight update.
+ double new_wx = wx + (new_dual - current_dual) * num_loss_partitions *
+ weighted_example_norm * example_weight;
+ EXPECT_NEAR(new_dual, -loss_updater.PrimalLossDerivative(new_wx, label, 1.0),
+ 1e-5);
+}
+
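
The assertion in this helper encodes the SDCA optimality condition already described in its comment; spelled out with the same variable names:

    \text{new\_wx} = wx + (\text{new\_dual} - \text{current\_dual})
                     \cdot \text{num\_loss\_partitions}
                     \cdot \text{weighted\_example\_norm}
                     \cdot \text{example\_weight},
    \qquad
    \text{new\_dual} \;\approx\; -\left.\frac{\partial \ell}{\partial (wx)}\right|_{wx = \text{new\_wx}}.
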
TEST(LogisticLoss, ComputePrimalLoss) {
LogisticLossUpdater loss_updater;
EXPECT_NEAR(0.693147,
@@ -65,19 +84,12 @@ TEST(LogisticLoss, ComputeDualLoss) {
TEST(LogisticLoss, ComputeUpdatedDual) {
LogisticLossUpdater loss_updater;
- EXPECT_NEAR(0.479,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, 0.5 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
-
- EXPECT_NEAR(-0.031,
- loss_updater.ComputeUpdatedDual(
- 2 /* num partitions */, -1.0 /* label */,
- 1.0 /* example weight */, 0.1 /* current_dual */,
- -0.8 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, 0.5 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 2 /* num partitions */, -1.0 /* label */,
+ 1.0 /* example weight */, 0.1 /* current_dual */,
+ -0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
TEST(SquaredLoss, ComputePrimalLoss) {
@@ -126,19 +138,12 @@ TEST(SquaredLoss, ComputeDualLoss) {
TEST(SquaredLoss, ComputeUpdatedDual) {
SquaredLossUpdater loss_updater;
- EXPECT_NEAR(0.336,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, 0.3 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
-
- EXPECT_NEAR(-0.427,
- loss_updater.ComputeUpdatedDual(
- 5 /* num partitions */, -1.0 /* label */,
- 1.0 /* example weight */, -0.4 /* current_dual */,
- 0.8 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, 0.3 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 5 /* num partitions */, -1.0 /* label */,
+ 1.0 /* example weight */, -0.4 /* current_dual */,
+ 0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
TEST(HingeLoss, ComputePrimalLoss) {
@@ -207,48 +212,27 @@ TEST(HingeLoss, ConvertLabel) {
TEST(HingeLoss, ComputeUpdatedDual) {
HingeLossUpdater loss_updater;
- // When label=1.0, example_weight=1.0, current_dual=0.5, wx=0.3 and
- // weighted_example_norm=100.0, it turns out that the optimal value to update
- // the dual to is 0.507 which is within the permitted range and thus should be
- // the value returned.
+ // For the two tests below, y*wx=1 after the update, which is a
+ // non-differentiable point of the hinge loss, so TestComputeUpdatedDual
+ // cannot be used. Check the value of the dual variable instead.
EXPECT_NEAR(0.507,
loss_updater.ComputeUpdatedDual(
1 /* num partitions */, 1.0 /* label */,
1.0 /* example weight */, 0.5 /* current_dual */,
0.3 /* wx */, 100.0 /* weighted_example_norm */),
1e-3);
- // When label=-1.0, example_weight=1.0, current_dual=0.4, wx=0.6,
- // weighted_example_norm=10.0 and num_loss_partitions=10, it turns out that
- // the optimal value to update the dual to is 0.384 which is within the
- // permitted range and thus should be the value returned.
EXPECT_NEAR(-0.416,
loss_updater.ComputeUpdatedDual(
10 /* num partitions */, -1.0 /* label */,
1.0 /* example weight */, -0.4 /* current_dual */,
0.6 /* wx */, 10.0 /* weighted_example_norm */),
1e-3);
- // When label=1.0, example_weight=1.0, current_dual=-0.5, wx=0.3 and
- // weighted_example_norm=10.0, it turns out that the optimal value to update
- // the dual to is -0.43. However, this is outside the allowed [0.0, 1.0] range
- // and hence the closest permitted value (0.0) should be returned instead.
- EXPECT_NEAR(0.0,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, -0.5 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
-
- // When label=-1.0, example_weight=2.0, current_dual=-1.0, wx=0.3 and
- // weighted_example_norm=10.0, it turns out that the optimal value to update
- // the dual to is -1.065. However, this is outside the allowed [-1.0, 0.0]
- // range and hence the closest permitted value (-1.0) should be returned
- // instead.
- EXPECT_NEAR(-1.0,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, -1.0 /* label */,
- 2.0 /* example weight */, -1.0 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, -0.5 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, -1.0 /* label */,
+ 2.0 /* example weight */, -1.0 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
}
TEST(SmoothHingeLoss, ComputePrimalLoss) {
@@ -297,19 +281,75 @@ TEST(SmoothHingeLoss, ComputeDualLoss) {
TEST(SmoothHingeLoss, ComputeUpdatedDual) {
SmoothHingeLossUpdater loss_updater;
- EXPECT_NEAR(0.336,
- loss_updater.ComputeUpdatedDual(
- 1 /* num partitions */, 1.0 /* label */,
- 1.0 /* example weight */, 0.3 /* current_dual */,
- 0.3 /* wx */, 10.0 /* weighted_example_norm */),
- 1e-3);
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 1.0 /* label */,
+ 1.0 /* example weight */, 0.3 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 5 /* num partitions */, -1.0 /* label */,
+ 1.0 /* example weight */, -0.4 /* current_dual */,
+ 0.8 /* wx */, 10.0 /* weighted_example_norm */);
+}
- EXPECT_NEAR(-0.427,
- loss_updater.ComputeUpdatedDual(
- 5 /* num partitions */, -1.0 /* label */,
- 1.0 /* example weight */, -0.4 /* current_dual */,
- 0.8 /* wx */, 10.0 /* weighted_example_norm */),
+TEST(PoissonLoss, ComputePrimalLoss) {
+ PoissonLossUpdater loss_updater;
+ EXPECT_NEAR(1.0,
+ loss_updater.ComputePrimalLoss(0.0 /* wx */, 3.0 /* label */,
+ 1.0 /* example weight */),
1e-3);
+ EXPECT_NEAR(21996.0,
+ loss_updater.ComputePrimalLoss(10.0 /* wx */, 3.0 /* label */,
+ 1.0 /* example weight */),
+ 1.0);
+ EXPECT_NEAR(0.606,
+ loss_updater.ComputePrimalLoss(-0.5 /* wx */, 0.0 /* label */,
+ 1.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(6.64,
+ loss_updater.ComputePrimalLoss(1.2 /* wx */, 0.0 /* label */,
+ 2.0 /* example weight */),
+ 1e-2);
+}
+
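
As a sanity check (assuming the usual Poisson log-loss with the log(label!) term dropped, an editorial assumption rather than a quote of poisson-loss.h), the expected values above are consistent with

    \ell(wx, y) = \text{example\_weight}\cdot\bigl(e^{wx} - y\cdot wx\bigr):
    \quad e^{0} - 0 = 1,\;\; e^{10} - 30 \approx 21996,\;\;
    e^{-0.5} \approx 0.606,\;\; 2\,(e^{1.2} - 0) \approx 6.64.
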
+TEST(PoissonLoss, ComputeDualLoss) {
+ PoissonLossUpdater loss_updater;
+ // Dual is undefined.
+ EXPECT_NEAR(
+ std::numeric_limits<double>::max(),
+ loss_updater.ComputeDualLoss(1.0 /* current dual */, 0.0 /* label */,
+ 1.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(
+ 0.0,
+ loss_updater.ComputeDualLoss(0.0 /* current dual */, 0.0 /* label */,
+ 3.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(
+ -0.847,
+ loss_updater.ComputeDualLoss(1.5 /* current dual */, 2.0 /* label */,
+ 1.0 /* example weight */),
+ 1e-3);
+ EXPECT_NEAR(
+ -2.675,
+ loss_updater.ComputeDualLoss(0.5 /* current dual */, 2.0 /* label */,
+ 3.0 /* example weight */),
+ 1e-3);
+}
+
+TEST(PoissonLoss, ConvertLabel) {
+ PoissonLossUpdater loss_updater;
+ float example_label = -1.0;
+ // A negative label should return an error.
+ Status status = loss_updater.ConvertLabel(&example_label);
+ EXPECT_FALSE(status.ok());
+}
+
+TEST(PoissonLoss, ComputeUpdatedDual) {
+ PoissonLossUpdater loss_updater;
+ TestComputeUpdatedDual(loss_updater, 1 /* num partitions */, 2.0 /* label */,
+ 1.0 /* example weight */, 0.5 /* current_dual */,
+ 0.3 /* wx */, 10.0 /* weighted_example_norm */);
+ TestComputeUpdatedDual(loss_updater, 2 /* num partitions */, 0.0 /* label */,
+ 1.0 /* example weight */, 0.0 /* current_dual */,
+ -0.8 /* wx */, 10.0 /* weighted_example_norm */);
}
} // namespace
diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc
index bdc3b5778f..dd89597369 100644
--- a/tensorflow/core/kernels/map_stage_op.cc
+++ b/tensorflow/core/kernels/map_stage_op.cc
@@ -410,8 +410,9 @@ class StagingMap : public ResourceBase {
copy_or_move_tensors(&it->second, *key, *indices, tuple));
// Remove entry if all the values have been consumed
- if (!std::any_of(it->second.begin(), it->second.end(),
- std::mem_fn(&OptionalTensor::has_value))) {
+ if (!std::any_of(
+ it->second.begin(), it->second.end(),
+ [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
map_.erase(it);
}
@@ -444,8 +445,9 @@ class StagingMap : public ResourceBase {
*key = it->first;
// Remove entry if all the values have been consumed
- if (!std::any_of(it->second.begin(), it->second.end(),
- std::mem_fn(&OptionalTensor::has_value))) {
+ if (!std::any_of(
+ it->second.begin(), it->second.end(),
+ [](const OptionalTensor& tensor) { return tensor.has_value(); })) {
map_.erase(it);
}
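
The lambdas above replace std::mem_fn with an explicit predicate over the stored optional tensors. The same erase-when-fully-consumed idiom with standard types (std::optional standing in for OptionalTensor):

    #include <algorithm>
    #include <optional>
    #include <vector>

    // True when every slot of the tuple has been consumed (no value left).
    bool AllConsumed(const std::vector<std::optional<int>>& tuple) {
      return !std::any_of(
          tuple.begin(), tuple.end(),
          [](const std::optional<int>& t) { return t.has_value(); });
    }
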
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index 79967aab38..4ad390a411 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -578,7 +578,7 @@ struct MatMulFunctor<SYCLDevice, T> {
.Label("cublas"), \
MatMulOp<GPUDevice, T, true /* cublas */>)
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
// MKL does not support half, bfloat16 and int32 types for
// matrix-multiplication, so register the kernel to use default Eigen based
@@ -606,9 +606,9 @@ TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU_EIGEN);
TF_CALL_complex128(REGISTER_CPU_EIGEN);
TF_CALL_double(REGISTER_CPU_EIGEN);
-#endif
+#endif // INTEL_MKL_DNN_ONLY
-#else // INTEL MKL
+#else // INTEL_MKL && ENABLE_MKL
TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_half(REGISTER_CPU);
@@ -616,7 +616,7 @@ TF_CALL_bfloat16(REGISTER_CPU);
TF_CALL_int32(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
#if GOOGLE_CUDA
TF_CALL_float(REGISTER_GPU);
diff --git a/tensorflow/core/kernels/matmul_op.h b/tensorflow/core/kernels/matmul_op.h
index 628895ca86..4b74a64025 100644
--- a/tensorflow/core/kernels/matmul_op.h
+++ b/tensorflow/core/kernels/matmul_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MATMUL_OP_H_
-#define TENSORFLOW_KERNELS_MATMUL_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
@@ -117,4 +117,4 @@ typedef Eigen::GpuDevice GPUDevice;
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MATMUL_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MATMUL_OP_H_
diff --git a/tensorflow/core/kernels/matrix_band_part_op.h b/tensorflow/core/kernels/matrix_band_part_op.h
index 97cc950793..b04e36db8e 100644
--- a/tensorflow/core/kernels/matrix_band_part_op.h
+++ b/tensorflow/core/kernels/matrix_band_part_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
-#define TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -34,4 +34,4 @@ struct MatrixBandPartFunctor {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MATRIX_BAND_PART_OP_H_
diff --git a/tensorflow/core/kernels/matrix_diag_op.h b/tensorflow/core/kernels/matrix_diag_op.h
index 14095845b8..108ba0f56b 100644
--- a/tensorflow/core/kernels/matrix_diag_op.h
+++ b/tensorflow/core/kernels/matrix_diag_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
-#define TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
// Generator definition for MatrixDiagOp, must be compilable by nvcc.
@@ -91,4 +91,4 @@ struct MatrixDiag {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MATRIX_DIAG_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MATRIX_DIAG_OP_H_
diff --git a/tensorflow/core/kernels/matrix_exponential_op.cc b/tensorflow/core/kernels/matrix_exponential_op.cc
index 99db898301..01d4894438 100644
--- a/tensorflow/core/kernels/matrix_exponential_op.cc
+++ b/tensorflow/core/kernels/matrix_exponential_op.cc
@@ -49,6 +49,7 @@ class MatrixExponentialOp : public LinearAlgebraOp<Scalar> {
TF_DISALLOW_COPY_AND_ASSIGN(MatrixExponentialOp);
};
+// Deprecated kernels (2018/08/21).
REGISTER_LINALG_OP("MatrixExponential", (MatrixExponentialOp<float>), float);
REGISTER_LINALG_OP("MatrixExponential", (MatrixExponentialOp<double>), double);
REGISTER_LINALG_OP("MatrixExponential", (MatrixExponentialOp<complex64>),
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.h b/tensorflow/core/kernels/matrix_set_diag_op.h
index aeb144559f..341ef12e97 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op.h
+++ b/tensorflow/core/kernels/matrix_set_diag_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MATRIX_SET_DIAG_OP_H_
-#define TENSORFLOW_KERNELS_MATRIX_SET_DIAG_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_SET_DIAG_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MATRIX_SET_DIAG_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -34,4 +34,4 @@ struct MatrixSetDiag {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MATRIX_SET_DIAG_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MATRIX_SET_DIAG_OP_H_
diff --git a/tensorflow/core/kernels/matrix_solve_ls_op_impl.h b/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
index 0e09078365..00a05a87a3 100644
--- a/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
+++ b/tensorflow/core/kernels/matrix_solve_ls_op_impl.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
+
// See docs in ../ops/linalg_ops.cc.
#include "third_party/eigen3/Eigen/Cholesky"
@@ -159,3 +162,5 @@ class MatrixSolveLsOp : public LinearAlgebraOp<Scalar> {
};
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_MATRIX_SOLVE_LS_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/maxpooling_op.h b/tensorflow/core/kernels/maxpooling_op.h
index f82e57d44c..2adb8081ce 100644
--- a/tensorflow/core/kernels/maxpooling_op.h
+++ b/tensorflow/core/kernels/maxpooling_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
-#define TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_H_
// Functor definition for MaxPoolingOp, must be compilable by nvcc.
#include "tensorflow/core/framework/numeric_types.h"
@@ -51,4 +51,4 @@ struct SpatialMaxPooling<Device, qint8> {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MAXPOOLING_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_H_
diff --git a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
index 10e468ce46..693ed8a8f0 100644
--- a/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
+++ b/tensorflow/core/kernels/merge_v2_checkpoints_op_test.cc
@@ -114,9 +114,7 @@ class MergeV2CheckpointsOpTest : public OpsTestBase {
// Exercises "delete_old_dirs".
for (int i = 0; i < 2; ++i) {
int directory_found =
- Env::Default()
- ->IsDirectory(std::string(io::Dirname(prefixes[i])))
- .code();
+ Env::Default()->IsDirectory(string(io::Dirname(prefixes[i]))).code();
if (delete_old_dirs) {
EXPECT_EQ(error::NOT_FOUND, directory_found);
} else {
diff --git a/tensorflow/core/kernels/mirror_pad_op.h b/tensorflow/core/kernels/mirror_pad_op.h
index 81150a9e79..62aa7d5c29 100644
--- a/tensorflow/core/kernels/mirror_pad_op.h
+++ b/tensorflow/core/kernels/mirror_pad_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MIRROR_PAD_OP_H_
-#define TENSORFLOW_KERNELS_MIRROR_PAD_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -103,6 +103,7 @@ struct TensorEvaluator<const TensorMirrorPadOp<PaddingDimensions, ArgType>,
IsAligned = false,
PacketAccess = TensorEvaluator<ArgType, Device>::PacketAccess,
BlockAccess = false,
+ PreferBlockAccess = false,
Layout = TensorEvaluator<ArgType, Device>::Layout,
CoordAccess = true,
RawAccess = false
@@ -437,4 +438,4 @@ struct MirrorPadGrad {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MIRROR_PAD_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_H_
diff --git a/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h b/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h
index f27ca139c9..98e3be082d 100644
--- a/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h
+++ b/tensorflow/core/kernels/mirror_pad_op_cpu_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_
-#define TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_CPU_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_CPU_IMPL_H_
#define EIGEN_USE_THREADS
@@ -42,4 +42,4 @@ TF_CALL_NUMBER_TYPES(DEFINE_CPU_SPECS);
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_MIRROR_PAD_OP_CPU_IMPL_H_
+#endif // TENSORFLOW_CORE_KERNELS_MIRROR_PAD_OP_CPU_IMPL_H_
diff --git a/tensorflow/core/kernels/mkl_aggregate_ops.cc b/tensorflow/core/kernels/mkl_aggregate_ops.cc
index 28edf51546..20aa1f7ea1 100644
--- a/tensorflow/core/kernels/mkl_aggregate_ops.cc
+++ b/tensorflow/core/kernels/mkl_aggregate_ops.cc
@@ -392,16 +392,28 @@ class MklAddNOp : public OpKernel {
memory::format src1_mkl_data_format = src1_mkl_shape.GetTfDataFormat();
auto src1_tf_data_format =
MklDnnDataFormatToTFDataFormat(src1_mkl_data_format);
- auto src2_dims =
- TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(), src1_tf_data_format);
+ memory::dims src2_dims;
+ if (src2_tensor.dims() == 4) {
+ src2_dims = TFShapeToMklDnnDimsInNCHW(src2_tensor.shape(),
+ src1_tf_data_format);
+ } else {
+ src2_dims = TFShapeToMklDnnDimsInNCDHW(src2_tensor.shape(),
+ src1_tf_data_format);
+ }
md2 = memory::desc(src2_dims, MklDnnType<T>(), src1_mkl_data_format);
} else if (input2_in_mkl_format && !input1_in_mkl_format) {
// Same comment as above.
memory::format src2_mkl_data_format = src2_mkl_shape.GetTfDataFormat();
auto src2_tf_data_format =
MklDnnDataFormatToTFDataFormat(src2_mkl_data_format);
- auto src1_dims =
- TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(), src2_tf_data_format);
+ memory::dims src1_dims;
+ if (src1_tensor.dims() == 4) {
+ src1_dims = TFShapeToMklDnnDimsInNCHW(src1_tensor.shape(),
+ src2_tf_data_format);
+ } else {
+ src1_dims = TFShapeToMklDnnDimsInNCDHW(src1_tensor.shape(),
+ src2_tf_data_format);
+ }
md1 = memory::desc(src1_dims, MklDnnType<T>(), src2_mkl_data_format);
md2 = src2_mkl_shape.GetMklLayout();
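
The 4-D/5-D branches above are needed because 2-D inputs are described to MKL-DNN in NCHW order while 3-D inputs use NCDHW. A toy version of that reordering, assuming TensorFlow-style channels-last input layouts (hypothetical helper, not the real TFShapeToMklDnnDimsInNCHW/NCDHW utilities):

    #include <cstdint>
    #include <vector>

    // Reorders an NHWC shape {N,H,W,C} to NCHW {N,C,H,W}, and an NDHWC shape
    // {N,D,H,W,C} to NCDHW {N,C,D,H,W}, the layout MKL-DNN descriptors expect.
    std::vector<int64_t> ToMklOrder(const std::vector<int64_t>& shape) {
      if (shape.size() == 4) {
        return {shape[0], shape[3], shape[1], shape[2]};          // NHWC -> NCHW
      }
      return {shape[0], shape[4], shape[1], shape[2], shape[3]};  // NDHWC -> NCDHW
    }
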
diff --git a/tensorflow/core/kernels/mkl_avgpooling_op.cc b/tensorflow/core/kernels/mkl_avgpooling_op.cc
index 969baecc51..2409f7e9dc 100644
--- a/tensorflow/core/kernels/mkl_avgpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_avgpooling_op.cc
@@ -453,6 +453,8 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
// initialize variables for the pooling op
MklPoolParameters pool_params;
+ // check whether pooling is 2D or 3D
+ bool is_pool2d = (this->ksize_.size() == 4);
// Get the input tensor and initialize the pooling parameters
TensorShape input_tensor_shape = input_tensor.shape();
this->InitMklPoolParameters(context, &pool_params, dnn_shape_input,
@@ -473,23 +475,22 @@ class MklAvgPoolingOp : public MklPoolingForwardOpBase<T> {
}
memory::dims filter_dims, strides, padding_left, padding_right;
+ // Get src/filter/stride/padding information
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
- &padding_left, &padding_right);
+ &padding_left, &padding_right, is_pool2d);
// Get the input memory descriptor
- memory::desc input_md =
- dnn_shape_input.IsMklTensor()
- ? dnn_shape_input.GetMklLayout()
- : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
- this->data_format_tf_),
- MklDnnType<T>(), this->data_format_mkldnn_);
-
- // Get src/filter/stride/padding information
memory::dims src_dims =
dnn_shape_input.IsMklTensor()
? dnn_shape_input.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
- this->data_format_tf_);
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(input_tensor.shape(),
+ this->data_format_tf_);
+ memory::desc input_md = dnn_shape_input.IsMklTensor()
+ ? dnn_shape_input.GetMklLayout()
+ : memory::desc(src_dims, MklDnnType<T>(),
+ this->data_format_mkldnn_);
// Get an average pooling primitive from the op pool
MklPoolingFwdPrimitive<T>* pooling_fwd = nullptr;
@@ -562,24 +563,30 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
for (int i = 0; i < orig_input_tensor.NumElements(); i++) {
orig_input_shape.AddDim(shape_vec(i));
}
+
+ bool is_pool2d = (this->ksize_.size() == 4);
this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape,
orig_input_shape);
memory::dims filter_dims, strides, padding_left, padding_right;
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
- &padding_left, &padding_right);
+ &padding_left, &padding_right, is_pool2d);
memory::dims orig_input_dims_mkl_order =
orig_input_mkl_shape.IsMklTensor()
? orig_input_mkl_shape.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(orig_input_shape,
- this->data_format_tf_);
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(orig_input_shape,
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(orig_input_shape,
+ this->data_format_tf_);
memory::dims diff_dst_dims =
grad_mkl_shape.IsMklTensor()
? grad_mkl_shape.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
- this->data_format_tf_);
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(grad_tensor.shape(),
+ this->data_format_tf_);
memory::dims output_dims_mkl_order;
this->GetOutputDims(pool_params, &output_dims_mkl_order);
@@ -664,6 +671,18 @@ class MklAvgPoolingGradOp : public MklPoolingBackwardOpBase<T> {
}
}; // MklAvgPoolingGradOp
+REGISTER_KERNEL_BUILDER(Name("_MklAvgPool3D")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .Label(mkl_op_registry::kMklOpLabel),
+ MklAvgPoolingOp<CPUDevice, float>);
+
+REGISTER_KERNEL_BUILDER(Name("_MklAvgPool3DGrad")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .Label(mkl_op_registry::kMklOpLabel),
+ MklAvgPoolingGradOp<CPUDevice, float>);
+
#endif // INTEL_MKL_ML_ONLY
REGISTER_KERNEL_BUILDER(Name("_MklAvgPool")
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 0841395dc3..bc135de11e 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -223,10 +223,12 @@ class BatchMatMulMkl : public OpKernel {
Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMulMkl<CPUDevice, TYPE>)
+#ifdef ENABLE_MKL
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
+#endif // ENABLE_MKL
} // end namespace tensorflow
#endif
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 50c25e1da7..f406ad2ab5 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -82,11 +82,11 @@ struct MklConvBwdFilterParams {
};
template <typename T>
-class MklConv2DBwdFilterPrimitive : public MklPrimitive {
+class MklConvBwdFilterPrimitive : public MklPrimitive {
public:
- explicit MklConv2DBwdFilterPrimitive(
- const MklConvBwdFilterParams& convBwdFilterDims) :
- cpu_engine_(engine::cpu, 0) {
+ explicit MklConvBwdFilterPrimitive(
+ const MklConvBwdFilterParams& convBwdFilterDims)
+ : cpu_engine_(engine::cpu, 0) {
context_.bwd_filter_stream.reset(new stream(stream::kind::eager));
// create conv primitive
if (context_.conv_bwd_filter == nullptr) {
@@ -94,7 +94,7 @@ class MklConv2DBwdFilterPrimitive : public MklPrimitive {
}
}
- ~MklConv2DBwdFilterPrimitive() {}
+ ~MklConvBwdFilterPrimitive() {}
// Convolution backward weights with bias
// src_data: input data buffer of src
@@ -297,38 +297,41 @@ class MklConv2DBwdFilterPrimitive : public MklPrimitive {
};
template <typename T>
-class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
- static MklConv2DBwdFilterPrimitive<T>* Get(
- const MklConvBwdFilterParams& convBwdFilterDims) {
- MklConv2DBwdFilterPrimitive<T>* conv2d_bwd_filter = nullptr;
-
- // look into the pool for reusable primitive
- conv2d_bwd_filter = dynamic_cast<MklConv2DBwdFilterPrimitive<T>*> (
- MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().GetConv2dBwdFilter(
- convBwdFilterDims));
-
- if (conv2d_bwd_filter == nullptr) {
- conv2d_bwd_filter = new MklConv2DBwdFilterPrimitive<T>(
- convBwdFilterDims);
- MklConv2DBwdFilterPrimitiveFactory<T>::GetInstance().SetConv2dBwdFilter(
- convBwdFilterDims, conv2d_bwd_filter);
+ static MklConvBwdFilterPrimitive<T>* Get(
+ const MklConvBwdFilterParams& convBwdFilterDims, bool do_not_cache) {
+ MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
+
+ if (do_not_cache) { /* Create new primitive always */
+ conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
+ } else {
+ // look into the pool for reusable primitive
+ conv_bwd_filter = dynamic_cast<MklConvBwdFilterPrimitive<T>*> (
+ MklConvBwdFilterPrimitiveFactory<T>::GetInstance().GetConvBwdFilter(
+ convBwdFilterDims));
+
+ if (conv_bwd_filter == nullptr) {
+ conv_bwd_filter = new MklConvBwdFilterPrimitive<T>(convBwdFilterDims);
+ MklConvBwdFilterPrimitiveFactory<T>::GetInstance().SetConvBwdFilter(
+ convBwdFilterDims, conv_bwd_filter);
+ }
}
- return conv2d_bwd_filter;
- }
+ return conv_bwd_filter;
+ }
private:
- MklConv2DBwdFilterPrimitiveFactory() {}
- ~MklConv2DBwdFilterPrimitiveFactory() {}
+ MklConvBwdFilterPrimitiveFactory() {}
+ ~MklConvBwdFilterPrimitiveFactory() {}
- static MklConv2DBwdFilterPrimitiveFactory& GetInstance() {
- static MklConv2DBwdFilterPrimitiveFactory instance_;
+ static MklConvBwdFilterPrimitiveFactory& GetInstance() {
+ static MklConvBwdFilterPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvBwdFilterParams& convBwdFilterDims) {
- string prefix = "conv2d_bwd_filter";
+ string prefix = "conv_bwd_filter";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdFilterDims.src_dims);
@@ -342,14 +345,14 @@ class MklConv2DBwdFilterPrimitiveFactory : public MklPrimitiveFactory<T> {
return key_creator.GetKey();
}
- MklPrimitive* GetConv2dBwdFilter(
+ MklPrimitive* GetConvBwdFilter(
const MklConvBwdFilterParams& convBwdFilterDims) {
string key = CreateKey(convBwdFilterDims);
return this->GetOp(key);
}
- void SetConv2dBwdFilter(
- const MklConvBwdFilterParams& convBwdFilterDims, MklPrimitive* op) {
+ void SetConvBwdFilter(const MklConvBwdFilterParams& convBwdFilterDims,
+ MklPrimitive* op) {
string key = CreateKey(convBwdFilterDims);
this->SetOp(key, op);
}
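
The do_not_cache parameter threaded through Get() above boils down to a simple pattern: allocate a throwaway primitive, or fetch/insert one from a keyed pool. Stripped of the MKL-DNN specifics, it is roughly (illustrative types only):

    #include <memory>
    #include <string>
    #include <unordered_map>

    struct Primitive { /* expensive-to-build state lives here */ };

    class PrimitiveFactory {
     public:
      // When do_not_cache is true the caller owns (and must delete) the result;
      // otherwise the factory keeps the primitive alive for reuse.
      Primitive* Get(const std::string& key, bool do_not_cache) {
        if (do_not_cache) return new Primitive();
        auto it = pool_.find(key);
        if (it == pool_.end()) {
          it = pool_.emplace(key, std::make_unique<Primitive>()).first;
        }
        return it->second.get();
      }

     private:
      std::unordered_map<std::string, std::unique_ptr<Primitive>> pool_;
    };
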
@@ -738,14 +741,13 @@ TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
#else
template <typename Device, class T, bool biasEnabled>
-class MklConv2DCustomBackpropFilterOp
- : public MklConv2DBackpropCommonOp<Device, T> {
+class MklConvCustomBackpropFilterOp
+ : public MklConvBackpropCommonOp<Device, T> {
public:
- explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context)
- : MklConv2DBackpropCommonOp<Device, T>(context) {
- }
+ explicit MklConvCustomBackpropFilterOp(OpKernelConstruction* context)
+ : MklConvBackpropCommonOp<Device, T>(context) {}
- ~MklConv2DCustomBackpropFilterOp() {}
+ ~MklConvCustomBackpropFilterOp() {}
void Compute(OpKernelContext* context) {
try {
@@ -753,6 +755,9 @@ class MklConv2DCustomBackpropFilterOp
MklDnnData<T> diff_dst(&cpu_engine_);
MklDnnData<T> diff_filter(&cpu_engine_); // output
+ // This flag indicates Conv2D or Conv3D
+ bool isConv2D = (this->strides_.size() == 4);
+
// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
@@ -813,7 +818,10 @@ class MklConv2DCustomBackpropFilterOp
&fwd_dst_dims, &padding_left, &padding_right);
if (!context->status().ok()) return;
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+ auto tf_fmt = isConv2D
+ ? TFDataFormatToMklDnnDataFormat(this->data_format_)
+ : TFDataFormatToMklDnn3DDataFormat(this->data_format_);
+
auto fwd_src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
@@ -832,21 +840,24 @@ class MklConv2DCustomBackpropFilterOp
if (biasEnabled) {
TensorShape obp_tf_shape = GetTfShape(context, 2);
depth = (this->data_format_ == FORMAT_NCHW)
- ? obp_tf_shape.dim_size(1)
- : obp_tf_shape.dim_size(3);
+ ? obp_tf_shape.dim_size(1)
+ : obp_tf_shape.dim_size(isConv2D ? 3 : 4);
diff_bias_dims = {static_cast<int>(depth)};
}
+ for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
-
- MklConv2DBwdFilterPrimitive<T> *conv2d_bwd_filter = nullptr;
+ MklConvBwdFilterPrimitive<T>* conv_bwd_filter = nullptr;
MklConvBwdFilterParams convBwdFilterDims(fwd_src_dims, fwd_filter_dims,
diff_bias_dims, diff_dst_dims, strides, dilations, padding_left,
padding_right, TFPaddingToMklDnnPadding(this->padding_));
- conv2d_bwd_filter = MklConv2DBwdFilterPrimitiveFactory<T>::Get(
- convBwdFilterDims);
- auto bwd_filter_pd = conv2d_bwd_filter->GetPrimitiveDesc();
+
+ // MKL DNN allocates large buffers when a conv gradient filter primitive is
+ // created. So we don't cache conv backward primitives when the env
+ // variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true.
+ bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();
+ conv_bwd_filter = MklConvBwdFilterPrimitiveFactory<T>::Get(
+ convBwdFilterDims, do_not_cache);
+ auto bwd_filter_pd = conv_bwd_filter->GetPrimitiveDesc();
// Allocate output tensors: diff_filter and diff_bias (with bias)
auto bwd_output_dims = GetOutputDims(fwd_src_dims, fwd_filter_dims);
@@ -854,14 +865,26 @@ class MklConv2DCustomBackpropFilterOp
// diff_filter
MklDnnShape diff_filter_mkl_shape;
diff_filter_mkl_shape.SetMklTensor(false);
- // output_dims_mkl_order is in OIHW format.
- TensorShape diff_filter_tf_shape(
- {bwd_output_dims[MklDnnDims::Dim_H],
- bwd_output_dims[MklDnnDims::Dim_W],
- bwd_output_dims[MklDnnDims::Dim_I],
- bwd_output_dims[MklDnnDims::Dim_O]});
- AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
- diff_filter_tf_shape, diff_filter_mkl_shape);
+
+ if (isConv2D) {
+ // Conv2D: output_dims_mkl_order is in OIHW format.
+ TensorShape diff_filter_tf_shape({bwd_output_dims[MklDnnDims::Dim_H],
+ bwd_output_dims[MklDnnDims::Dim_W],
+ bwd_output_dims[MklDnnDims::Dim_I],
+ bwd_output_dims[MklDnnDims::Dim_O]});
+ AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
+ diff_filter_tf_shape, diff_filter_mkl_shape);
+ } else {
+ // Conv3D: output_dims_mkl_order is in OIDHW format.
+ TensorShape diff_filter_tf_shape(
+ {bwd_output_dims[MklDnnDims3D::Dim3d_D],
+ bwd_output_dims[MklDnnDims3D::Dim3d_H],
+ bwd_output_dims[MklDnnDims3D::Dim3d_W],
+ bwd_output_dims[MklDnnDims3D::Dim3d_I],
+ bwd_output_dims[MklDnnDims3D::Dim3d_O]});
+ AllocateOutputSetMklShape(context, 0, &diff_filter_tensor,
+ diff_filter_tf_shape, diff_filter_mkl_shape);
+ }
Tensor* diff_bias_tensor = nullptr;
if (biasEnabled) {
@@ -871,7 +894,7 @@ class MklConv2DCustomBackpropFilterOp
// check if src and diff_dst need reorder
T *src_data = nullptr;
- if (fwd_src_md.data.format != conv2d_bwd_filter->GetSrcMemoryFormat()) {
+ if (fwd_src_md.data.format != conv_bwd_filter->GetSrcMemoryFormat()) {
src.SetUsrMem(fwd_src_md, &src_tensor);
src.CheckReorderToOpMem(bwd_filter_pd->src_primitive_desc());
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
@@ -882,7 +905,7 @@ class MklConv2DCustomBackpropFilterOp
T *diff_dst_data = nullptr;
if (diff_dst_md.data.format !=
- conv2d_bwd_filter->GetDiffDstMemoryFormat()) {
+ conv_bwd_filter->GetDiffDstMemoryFormat()) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_filter_pd->diff_dst_primitive_desc());
diff_dst_data = static_cast<T*>(
@@ -897,7 +920,7 @@ class MklConv2DCustomBackpropFilterOp
bool diff_filter_reorder_required = false;
T *diff_filter_data = nullptr;
if (GetOutputFormat(tf_fmt) !=
- conv2d_bwd_filter->GetDiffFilterMemoryFormat()) {
+ conv_bwd_filter->GetDiffFilterMemoryFormat()) {
// Allocate diff filter tensor as Tensorflow layout
diff_filter.SetUsrMem(bwd_output_dims, GetOutputFormat(tf_fmt),
diff_filter_tensor);
@@ -915,16 +938,19 @@ class MklConv2DCustomBackpropFilterOp
if (biasEnabled) {
T* diff_bias_data = static_cast<T*>(const_cast<T*>(
diff_bias_tensor->flat<T>().data()));
- conv2d_bwd_filter->Execute(src_data, diff_filter_data,
- diff_bias_data, diff_dst_data);
+ conv_bwd_filter->Execute(src_data, diff_filter_data, diff_bias_data,
+ diff_dst_data);
} else {
- conv2d_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
+ conv_bwd_filter->Execute(src_data, diff_filter_data, diff_dst_data);
}
// Reorder diff_filter back to Tensorflow layout if necessary
if (diff_filter_reorder_required) {
diff_filter.InsertReorderToUserMem();
}
+
+ // Delete the primitive since it is not cached.
+ if (do_not_cache) delete conv_bwd_filter;
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -947,7 +973,7 @@ class MklConv2DCustomBackpropFilterOp
const MklDnnShape& filter_mkl_shape,
const MklDnnShape& obp_mkl_shape) {
CHECK(!filter_mkl_shape.IsMklTensor())
- << "Conv2DBackpropFilter: filter should not be in MKL Layout";
+ << "ConvBackpropFilter: filter should not be in MKL Layout";
}
// Get TensorFlow shape of input tensor.
@@ -983,9 +1009,11 @@ class MklConv2DCustomBackpropFilterOp
return fwd_filter_dims;
}
- // Output layout is Tensorflow's filter layout (HWIO).
+ // Output layout is Tensorflow's filter layout
+ // Conv2D: HWIO; Conv3D: DHWIO
memory::format GetOutputFormat(const memory::format data_format) {
- return memory::format::hwio;
+ return (this->strides_.size() == 4) ? memory::format::hwio
+ : memory::format::dhwio;
}
// Allocate output tensor.
@@ -1027,24 +1055,27 @@ class MklConv2DCustomBackpropFilterOp
}
};
-#define REGISTER_MKL_FILTER_KERNELS(T) \
- REGISTER_KERNEL_BUILDER( \
- Name("_MklConv2DBackpropFilter") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropFilterOp<CPUDevice, T, false>); \
- REGISTER_KERNEL_BUILDER( \
- Name("_MklConv2DBackpropFilterWithBias") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropFilterOp<CPUDevice, T, true>); \
- REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklDummyOp<CPUDevice, T>);
+#define REGISTER_MKL_FILTER_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropFilterOp<CPUDevice, T, false>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilterWithBias") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropFilterOp<CPUDevice, T, true>); \
+ REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DBackpropFilterWithBias") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklDummyOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv3DBackpropFilterV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropFilterOp<CPUDevice, T, false>);
TF_CALL_float(REGISTER_MKL_FILTER_KERNELS);
#undef REGISTER_MKL_FILTER_KERNELS
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index 38e014d68e..a501ce2c93 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -59,7 +59,7 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
#ifndef INTEL_MKL_ML_ONLY
-/// utility classes enabling primitive reuse for backward conv2d ops.
+/// utility classes enabling primitive reuse for backward conv ops.
struct MklConvBwdInputParams {
memory::dims diff_src_dims;
memory::dims filter_dims;
@@ -83,11 +83,11 @@ struct MklConvBwdInputParams {
};
template <typename T>
-class MklConv2DBwdInputPrimitive : public MklPrimitive {
+class MklConvBwdInputPrimitive : public MklPrimitive {
public:
- explicit MklConv2DBwdInputPrimitive(
- const MklConvBwdInputParams& convBwdInputDims) :
- cpu_engine_(engine::cpu, 0) {
+ explicit MklConvBwdInputPrimitive(
+ const MklConvBwdInputParams& convBwdInputDims)
+ : cpu_engine_(engine::cpu, 0) {
context_.bwd_input_stream.reset(new stream(stream::kind::eager));
// create conv primitive
@@ -95,7 +95,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
Setup(convBwdInputDims);
}
}
- ~MklConv2DBwdInputPrimitive() {}
+ ~MklConvBwdInputPrimitive() {}
// Convolution backward filter (weights)
// diff_src_data: output data buffer of diff_src
@@ -134,7 +134,7 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
}
private:
- // Primitive reuse context for Conv2D Bwd Input op
+ // Primitive reuse context for Conv Bwd Input op
struct ConvBwdInputContext {
// expected memory format for this primitive instance
memory::format filter_fmt;
@@ -174,7 +174,6 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
}
};
-
void Setup(const MklConvBwdInputParams& convBwdInputDims) {
// create memory descriptors for convolution data w/ no specified format
context_.diff_src_md.reset(new memory::desc(
@@ -235,38 +234,41 @@ class MklConv2DBwdInputPrimitive : public MklPrimitive {
};
template <typename T>
-class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
private:
- MklConv2DBwdInputPrimitiveFactory() {}
- ~MklConv2DBwdInputPrimitiveFactory() {}
+ MklConvBwdInputPrimitiveFactory() {}
+ ~MklConvBwdInputPrimitiveFactory() {}
public:
- static MklConv2DBwdInputPrimitive<T>* Get(
- const MklConvBwdInputParams& convBwdInputDims) {
- MklConv2DBwdInputPrimitive<T>* conv2d_bwd_input = nullptr;
-
- // look into the pool for reusable primitive
- conv2d_bwd_input = dynamic_cast<MklConv2DBwdInputPrimitive<T>*> (
- MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().GetConv2dBwdInput(
- convBwdInputDims));
-
- if (conv2d_bwd_input == nullptr) {
- conv2d_bwd_input = new MklConv2DBwdInputPrimitive<T>(
- convBwdInputDims);
- MklConv2DBwdInputPrimitiveFactory<T>::GetInstance().SetConv2dBwdInput(
- convBwdInputDims, conv2d_bwd_input);
+ static MklConvBwdInputPrimitive<T>* Get(
+ const MklConvBwdInputParams& convBwdInputDims, bool do_not_cache) {
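+ // When do_not_cache is true, the returned primitive is owned by the
+ // caller, which must delete it after Execute() (see the op's Compute()).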
+ MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
+
+ if (do_not_cache) { /* Always allocate primitive */
+ conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
+ } else {
+ // look into the pool for reusable primitive
+ conv_bwd_input = dynamic_cast<MklConvBwdInputPrimitive<T>*>(
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().GetConvBwdInput(
+ convBwdInputDims));
+ if (conv_bwd_input == nullptr) {
+ conv_bwd_input = new MklConvBwdInputPrimitive<T>(convBwdInputDims);
+ MklConvBwdInputPrimitiveFactory<T>::GetInstance().SetConvBwdInput(
+ convBwdInputDims, conv_bwd_input);
+ }
}
- return conv2d_bwd_input;
+
+ return conv_bwd_input;
}
private:
- static MklConv2DBwdInputPrimitiveFactory& GetInstance() {
- static MklConv2DBwdInputPrimitiveFactory instance_;
+ static MklConvBwdInputPrimitiveFactory& GetInstance() {
+ static MklConvBwdInputPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvBwdInputParams& convBwdInputDims) {
- string prefix = "conv2d_bwd_input";
+ string prefix = "conv_bwd_input";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convBwdInputDims.diff_src_dims);
@@ -279,14 +281,13 @@ class MklConv2DBwdInputPrimitiveFactory : public MklPrimitiveFactory<T> {
return key_creator.GetKey();
}
- MklPrimitive* GetConv2dBwdInput(
- const MklConvBwdInputParams& convBwdInputDims) {
+ MklPrimitive* GetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims) {
string key = CreateKey(convBwdInputDims);
return this->GetOp(key);
}
- void SetConv2dBwdInput(
- const MklConvBwdInputParams& convBwdInputDims, MklPrimitive *op) {
+ void SetConvBwdInput(const MklConvBwdInputParams& convBwdInputDims,
+ MklPrimitive* op) {
string key = CreateKey(convBwdInputDims);
this->SetOp(key, op);
}
@@ -594,23 +595,34 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
TensorFormat data_format;
};
+#define REGISTER_MKL_CPU_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConv2DCustomBackpropInputOp<CPUDevice, T>);
+
+TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
+#undef REGISTER_MKL_CPU_KERNELS
+
#else
template <typename Device, class T>
-class MklConv2DCustomBackpropInputOp
- : public MklConv2DBackpropCommonOp<Device, T> {
+class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
public:
- explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
- : MklConv2DBackpropCommonOp<Device, T>(context) {
- }
+ explicit MklConvCustomBackpropInputOp(OpKernelConstruction* context)
+ : MklConvBackpropCommonOp<Device, T>(context) {}
- ~MklConv2DCustomBackpropInputOp() {}
+ ~MklConvCustomBackpropInputOp() {}
void Compute(OpKernelContext* context) {
try {
MklDnnData<T> filter(&cpu_engine);
MklDnnData<T> diff_dst(&cpu_engine);
+ // This flag indicates whether the op is Conv2D or Conv3D.
+ bool isConv2D = (this->strides_.size() == 4);
+
// Input tensors
const int kInputIdx = 0, kFilterIdx = 1, kOutbpropIdx = 2;
const Tensor& src_tensor = MklGetInput(context, kInputIdx);
@@ -626,7 +638,7 @@ class MklConv2DCustomBackpropInputOp
diff_dst_mkl_shape);
// Allow operator-specific generation of shapes.
- // E.g., Conv2DBackpropFilter gets filter as filter_sizes. It is a
+ // E.g., ConvBackpropFilter gets filter as filter_sizes. It is a
// tensor containing shape of filter. So filter.shape() is not
// a correct way to get filter shape. These operator-specific calls
// allow this class to handle this case.
@@ -655,6 +667,7 @@ class MklConv2DCustomBackpropInputOp
}
return;
}
+
// By default, all dims are in MKL order. Only dims in TF order
// are those with postfix tf_order.
memory::dims diff_dst_dims, fwd_src_dims, fwd_filter_dims;
@@ -673,15 +686,18 @@ class MklConv2DCustomBackpropInputOp
// Create Convolution forward descriptor since Convolution backward
// API needs it. For that, we first need to create input, filter
// and output memory descriptors.
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(this->data_format_);
+ auto tf_fmt = isConv2D
+ ? TFDataFormatToMklDnnDataFormat(this->data_format_)
+ : TFDataFormatToMklDnn3DDataFormat(this->data_format_);
// If filter is in MKL layout, then simply grab filter layout;
// otherwise, construct filter in TF layout.
// For TF layout, filter is in HWIO format.
auto fwd_filter_md = filter_mkl_shape.IsMklTensor()
- ? filter_mkl_shape.GetMklLayout()
- : memory::desc(fwd_filter_dims, MklDnnType<T>(),
- memory::format::hwio);
+ ? filter_mkl_shape.GetMklLayout()
+ : memory::desc(fwd_filter_dims, MklDnnType<T>(),
+ isConv2D ? memory::format::hwio
+ : memory::format::dhwio);
conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
if (!context->status().ok()) return;
@@ -689,18 +705,25 @@ class MklConv2DCustomBackpropInputOp
? diff_dst_mkl_shape.GetMklLayout()
: memory::desc(diff_dst_dims,
MklDnnType<T>(), tf_fmt);
+ for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
-
- MklConv2DBwdInputPrimitive<T> *conv2d_bwd_input = nullptr;
- conv_utl.GetInputSizeInMklOrder(diff_dst_tf_shape, &diff_dst_dims);
+ MklConvBwdInputPrimitive<T>* conv_bwd_input = nullptr;
MklConvBwdInputParams convBwdInputDims(fwd_src_dims, fwd_filter_dims,
diff_dst_dims, strides, dilations, padding_left, padding_right,
TFPaddingToMklDnnPadding(this->padding_));
- conv2d_bwd_input = MklConv2DBwdInputPrimitiveFactory<T>::Get(
- convBwdInputDims);
- auto bwd_input_pd = conv2d_bwd_input->GetPrimitiveDesc();
+
+ // We don't cache those primitives if the env variable
+ // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and the primitive descriptor
+ // includes potentially large buffers. MKL-DNN allocates buffers
+ // in the following cases:
+ // 1. Legacy CPU without AVX512/AVX2, or
+ // 2. 1x1 convolution with stride != 1.
+ bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
+ (MklPrimitiveFactory<T>::IsLegacyPlatform() ||
+ IsConv1x1StrideNot1(fwd_filter_dims, strides));
+ conv_bwd_input = MklConvBwdInputPrimitiveFactory<T>::Get(convBwdInputDims,
+ do_not_cache);
+ auto bwd_input_pd = conv_bwd_input->GetPrimitiveDesc();
// allocate output tensor
auto diff_src_pd = bwd_input_pd->diff_src_primitive_desc();
@@ -723,7 +746,7 @@ class MklConv2DCustomBackpropInputOp
// check if filter and diff_dst need reorder
T* filter_data = nullptr;
if (fwd_filter_md.data.format !=
- conv2d_bwd_input->GetFilterMemoryFormat()) {
+ conv_bwd_input->GetFilterMemoryFormat()) {
filter.SetUsrMem(fwd_filter_md, &filter_tensor);
filter.CheckReorderToOpMem(bwd_input_pd->weights_primitive_desc());
filter_data = static_cast<T*>(filter.GetOpMem().get_data_handle());
@@ -733,8 +756,7 @@ class MklConv2DCustomBackpropInputOp
}
T* diff_dst_data = nullptr;
- if (diff_dst_md.data.format !=
- conv2d_bwd_input->GetDiffDstMemoryFormat()) {
+ if (diff_dst_md.data.format != conv_bwd_input->GetDiffDstMemoryFormat()) {
diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
diff_dst.CheckReorderToOpMem(bwd_input_pd->diff_dst_primitive_desc());
diff_dst_data = static_cast<T*>(
@@ -745,7 +767,12 @@ class MklConv2DCustomBackpropInputOp
}
// execute convolution input bwd
- conv2d_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+ conv_bwd_input->Execute(diff_src_data, filter_data, diff_dst_data);
+
+ // delete primitive since it is not cached.
+ if (do_not_cache) {
+ delete conv_bwd_input;
+ }
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -770,7 +797,7 @@ class MklConv2DCustomBackpropInputOp
// of the Tensor and never an actual tensor. So it will never be in MKL
// layout.
CHECK(!input_mkl_shape.IsMklTensor())
- << "Conv2DBackpropInput: input should not be in MKL Layout";
+ << "ConvBackpropInput: input should not be in MKL Layout";
}
// Get TensorFlow shape of input tensor.
@@ -778,10 +805,10 @@ class MklConv2DCustomBackpropInputOp
const Tensor& input_tensor) {
TensorShape input_tf_shape;
CHECK_EQ(TensorShapeUtils::IsVector(input_tensor.shape()), true);
- CHECK_EQ(
- TensorShapeUtils::MakeShape(input_tensor.vec<int32>(), &input_tf_shape)
- .ok(),
- true);
+ // Conv[2D|3D]BackpropInputV2 supports both DT_INT32 and DT_INT64 for
+ // output_shape; MakeShape is able to handle both DT_INT32 and DT_INT64
+ // for input_tensor.
+ CHECK_EQ(this->MakeShape(input_tensor, &input_tf_shape).ok(), true);
return input_tf_shape;
}
@@ -792,7 +819,7 @@ class MklConv2DCustomBackpropInputOp
}
// Get the Tensorflow shape of Output (diff_src),
- // which is same as shape of Conv2D 'input'.
+ // which is same as shape of Conv 'input'.
TensorShape GetOutputTfShape(const TensorShape& input_shape,
const TensorShape& filter_shape,
const TensorShape& outbprop_shape) {
@@ -800,7 +827,7 @@ class MklConv2DCustomBackpropInputOp
}
// Get the Tensorflow shape of Output (diff_src),
- // which is same as shape of Conv2D 'input'.
+ // which is same as shape of Conv 'input'.
const memory::dims& GetOutputDims(const memory::dims& fwd_input_dims,
const memory::dims& fwd_filter_dims) {
return fwd_input_dims;
@@ -839,17 +866,22 @@ class MklConv2DCustomBackpropInputOp
}
};
-#endif // INTEL_MKL_ML_ONLY
-
-#define REGISTER_MKL_CPU_KERNELS(T) \
- REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<T>("T") \
- .Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DCustomBackpropInputOp<CPUDevice, T>);
+#define REGISTER_MKL_CPU_KERNELS(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropInputOp<CPUDevice, T>); \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv3DBackpropInputV2") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvCustomBackpropInputOp<CPUDevice, T>);
TF_CALL_float(REGISTER_MKL_CPU_KERNELS);
#undef REGISTER_MKL_CPU_KERNELS
+#endif // INTEL_MKL_ML_ONLY
+
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index bca1aa21a8..b332edad0a 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -85,9 +85,9 @@ struct MklConvFwdParams {
};
template <typename T>
-class MklConv2DFwdPrimitive : public MklPrimitive {
+class MklConvFwdPrimitive : public MklPrimitive {
public:
- explicit MklConv2DFwdPrimitive(const MklConvFwdParams& convFwdDims)
+ explicit MklConvFwdPrimitive(const MklConvFwdParams& convFwdDims)
: cpu_engine_(engine::cpu, 0) {
context_.fwd_stream.reset(new stream(stream::kind::eager));
// create conv primitive
@@ -96,7 +96,7 @@ class MklConv2DFwdPrimitive : public MklPrimitive {
}
}
- ~MklConv2DFwdPrimitive() {}
+ ~MklConvFwdPrimitive() {}
// Convolution forward execute with bias
// src_data: input data buffer of src
@@ -269,37 +269,41 @@ class MklConv2DFwdPrimitive : public MklPrimitive {
};
template <typename T>
-class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+class MklConvFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
public:
- static MklConv2DFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims) {
- MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
-
- // try to find a suitable one in pool
- conv2d_fwd = dynamic_cast<MklConv2DFwdPrimitive<T>*>(
- MklConv2DFwdPrimitiveFactory<T>::GetInstance().GetConv2DFwd(
- convFwdDims));
-
- if (conv2d_fwd == nullptr) {
- conv2d_fwd = new MklConv2DFwdPrimitive<T>(convFwdDims);
- MklConv2DFwdPrimitiveFactory<T>::GetInstance().SetConv2DFwd(convFwdDims,
- conv2d_fwd);
+ static MklConvFwdPrimitive<T>* Get(const MklConvFwdParams& convFwdDims,
+ bool do_not_cache) {
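+ // As with the backward-input factory, a primitive returned with
+ // do_not_cache == true is owned by the caller and deleted in Compute().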
+ MklConvFwdPrimitive<T>* conv_fwd = nullptr;
+
+ if (do_not_cache) { /* Always create new primitive */
+ conv_fwd = new MklConvFwdPrimitive<T>(convFwdDims);
+ } else {
+ // try to find a suitable one in pool
+ conv_fwd = dynamic_cast<MklConvFwdPrimitive<T>*>(
+ MklConvFwdPrimitiveFactory<T>::GetInstance().GetConvFwd(convFwdDims));
+ if (conv_fwd == nullptr) {
+ conv_fwd = new MklConvFwdPrimitive<T>(convFwdDims);
+ MklConvFwdPrimitiveFactory<T>::GetInstance().SetConvFwd(convFwdDims,
+ conv_fwd);
+ }
}
- return conv2d_fwd;
+
+ return conv_fwd;
}
private:
- MklConv2DFwdPrimitiveFactory() {}
- ~MklConv2DFwdPrimitiveFactory() {}
+ MklConvFwdPrimitiveFactory() {}
+ ~MklConvFwdPrimitiveFactory() {}
static const int kDilationH = 0, kDilationW = 1;
- static MklConv2DFwdPrimitiveFactory& GetInstance() {
- static MklConv2DFwdPrimitiveFactory instance_;
+ static MklConvFwdPrimitiveFactory& GetInstance() {
+ static MklConvFwdPrimitiveFactory instance_;
return instance_;
}
static string CreateKey(const MklConvFwdParams& convFwdDims) {
- string prefix = "conv2d_fwd_";
+ string prefix = "conv_fwd_";
FactoryKeyCreator key_creator;
key_creator.AddAsKey(prefix);
key_creator.AddAsKey(convFwdDims.src_dims);
@@ -313,12 +317,12 @@ class MklConv2DFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
return key_creator.GetKey();
}
- MklPrimitive* GetConv2DFwd(const MklConvFwdParams& convFwdDims) {
+ MklPrimitive* GetConvFwd(const MklConvFwdParams& convFwdDims) {
string key = CreateKey(convFwdDims);
return this->GetOp(key);
}
- void SetConv2DFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
+ void SetConvFwd(const MklConvFwdParams& convFwdDims, MklPrimitive* op) {
string key = CreateKey(convFwdDims);
this->SetOp(key, op);
}
@@ -331,11 +335,11 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
// For now, MKL-ML is default. So making MKL-DNN not a default choice.
#ifdef INTEL_MKL_ML_ONLY
template <typename Device, typename T, bool biasEnabled>
-class MklConv2DOp : public OpKernel {
+class MklConvOp : public OpKernel {
public:
- ~MklConv2DOp() {}
+ ~MklConvOp() {}
- explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
+ explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
@@ -755,21 +759,22 @@ class MklConv2DOp : public OpKernel {
#else
+// Base class for convolution forward operations
template <typename Device, typename T, bool biasEnabled>
-class MklConv2DOp : public OpKernel {
+class MklConvOp : public OpKernel {
public:
- ~MklConv2DOp() {}
+ ~MklConvOp() {}
- explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
+ explicit MklConvOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
- OP_REQUIRES(context, strides_.size() == 4,
+ OP_REQUIRES(context, (strides_.size() == 4 || strides_.size() == 5),
errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
+ "specify 4 or 5 dimensions"));
const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
@@ -778,20 +783,39 @@ class MklConv2DOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ if (strides_.size() == 4) {
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+ } else if (strides_.size() == 5) {
+ OP_REQUIRES(context, dilations_.size() == 5,
+ errors::InvalidArgument("Dilation rates field must "
+ "specify 5 dimensions"));
+ OP_REQUIRES(context,
+ (GetTensorDim(dilations_, data_format_, 'N') == 1 &&
+ GetTensorDim(dilations_, data_format_, 'C') == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations rates in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context,
+ (GetTensorDim(dilations_, data_format_, '0') > 0 &&
+ GetTensorDim(dilations_, data_format_, '1') > 0 &&
+ GetTensorDim(dilations_, data_format_, '2') > 0),
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+ }
}
void Compute(OpKernelContext* context) override {
@@ -837,7 +861,8 @@ class MklConv2DOp : public OpKernel {
AllocateOutputSetMklShape(context, kOutputIndex_Dst,
&dst_tensor, src_tf_shape, dst_mkl_shape);
- // MklConv2D also outputs converted filter as 2nd output of Conv2D.
+ // MklConv2D/3D also outputs the converted filter
+ // as the 2nd output of Conv2D/3D.
filter_mkl_shape.SetMklTensor(false);
Tensor* output_filter_tensor = nullptr;
AllocateOutputSetMklShape(context, kOutputIndex_Filter,
@@ -846,15 +871,20 @@ class MklConv2DOp : public OpKernel {
return;
}
+ bool isConv2D = (strides_.size() == 4);
+
// Create memory for user data.
// Describe how the inputs and outputs of Convolution look like. Also
// specify buffers containing actual input and output data.
- auto tf_fmt = TFDataFormatToMklDnnDataFormat(data_format_);
+ auto tf_fmt = isConv2D ? TFDataFormatToMklDnnDataFormat(data_format_)
+ : TFDataFormatToMklDnn3DDataFormat(data_format_);
// If input is in MKL layout, then simply grab input layout; otherwise,
// construct input Tf layout. For TF layout, although input shape
// (src_dims) required is in MKL-DNN order, the layout is Tensorflow's
- // layout (NHWC or NCHW depending on data format).
+ // layout depending on data format:
+ // Conv2D: NHWC or NCHW
+ // Conv3D: NDHWC or NCDHW
auto src_md = src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
: memory::desc(src_dims, MklDnnType<T>(), tf_fmt);
@@ -864,31 +894,43 @@ class MklConv2DOp : public OpKernel {
auto filter_md = filter_mkl_shape.IsMklTensor() // Should NEVER be true
? filter_mkl_shape.GetMklLayout()
: memory::desc(filter_dims, MklDnnType<T>(),
- memory::format::hwio);
-
+ isConv2D ? memory::format::hwio
+ : memory::format::dhwio);
// MKLDNN dilation starts from 0.
- dilations[kDilationH] -= 1;
- dilations[kDilationW] -= 1;
+ for (int i = 0; i < dilations.size(); i++) dilations[i] -= 1;
+
+ // In some cases, the primitive descriptor includes potentially large
+ // buffers; we don't cache those primitives if the env variable
+ // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true. MKL-DNN allocates buffers
+ // in the following cases:
+ // 1. Legacy CPU without AVX512/AVX2, or
+ // 2. 1x1 convolution with stride != 1.
+ bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled() &&
+ (src_dims[MklDnnDims::Dim_N] > kSmallBatchSize) &&
+ (MklPrimitiveFactory<T>::IsLegacyPlatform() ||
+ IsConv1x1StrideNot1(filter_dims, strides));
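+ // (Caching is also retained for small batches: do_not_cache additionally
+ // requires the src batch size to exceed kSmallBatchSize.)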
// get a conv2d fwd from primitive pool
- MklConv2DFwdPrimitive<T>* conv2d_fwd = nullptr;
+ MklConvFwdPrimitive<T>* conv_fwd = nullptr;
if (biasEnabled) {
memory::dims bias_dims = {};
conv_utl.GetBiasSizeInMklOrder(kInputIndex_Bias, &bias_dims);
MklConvFwdParams convFwdDims(src_dims, filter_dims, bias_dims,
dst_dims_mkl_order, strides, dilations,
padding_left, padding_right);
- conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
+ conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(
+ convFwdDims, do_not_cache);
} else {
MklConvFwdParams convFwdDims(src_dims, filter_dims, NONE_DIMS,
dst_dims_mkl_order, strides, dilations,
padding_left, padding_right);
- conv2d_fwd = MklConv2DFwdPrimitiveFactory<T>::Get(convFwdDims);
+ conv_fwd = MklConvFwdPrimitiveFactory<T>::Get(
+ convFwdDims, do_not_cache);
}
// allocate output tensors output_tensor and filter_out_tensor
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_fwd_pd =
- conv2d_fwd->GetPrimitiveDesc();
+ conv_fwd->GetPrimitiveDesc();
AllocateOutputTensor(context, *conv_fwd_pd,
dst_dims_mkl_order, tf_fmt, &dst_tensor);
Tensor* filter_out_tensor = nullptr;
@@ -900,7 +942,7 @@ class MklConv2DOp : public OpKernel {
// check whether src/filter need reorder
T *src_data = nullptr;
- if (src_md.data.format != conv2d_fwd->GetSrcMemoryFormat()) {
+ if (src_md.data.format != conv_fwd->GetSrcMemoryFormat()) {
src.SetUsrMem(src_md, &src_tensor);
src.CheckReorderToOpMem(conv_fwd_pd.get()->src_primitive_desc());
src_data = static_cast<T*>(src.GetOpMem().get_data_handle());
@@ -908,7 +950,7 @@ class MklConv2DOp : public OpKernel {
src_data = static_cast<T*>(const_cast<T*>(src_tensor.flat<T>().data()));
}
T* filter_data = nullptr;
- if (filter_md.data.format != conv2d_fwd->GetFilterMemoryFormat()) {
+ if (filter_md.data.format != conv_fwd->GetFilterMemoryFormat()) {
filter.SetUsrMem(filter_md, &filter_tensor);
filter.CheckReorderToOpMem(conv_fwd_pd.get()->weights_primitive_desc(),
filter.GetTensorBuffer(filter_out_tensor));
@@ -918,17 +960,19 @@ class MklConv2DOp : public OpKernel {
static_cast<T*>(const_cast<T*>(filter_tensor.flat<T>().data()));
}
-
// execute convolution
if (biasEnabled) {
const Tensor& bias_tensor = MklGetInput(context, kInputIndex_Bias);
T* bias_data = static_cast<T*>(const_cast<T*>(
bias_tensor.flat<T>().data()));
- conv2d_fwd->Execute(src_data, filter_data, bias_data, dst_data);
+ conv_fwd->Execute(src_data, filter_data, bias_data, dst_data);
} else {
- conv2d_fwd->Execute(src_data, filter_data, dst_data);
+ conv_fwd->Execute(src_data, filter_data, dst_data);
}
+
+ // delete primitive since it is not cached.
+ if (do_not_cache) delete conv_fwd;
} catch (mkldnn::error &e) {
string error_msg = tensorflow::strings::StrCat(
"Status: ", e.status, ", message: ", string(e.message), ", in file ",
@@ -1038,24 +1082,34 @@ class MklConv2DOp : public OpKernel {
#endif
-#define REGISTER_MKL_CPU(T) \
+// Register 2D operations
+#define REGISTER_MKL_CPU_2D(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DOp<CPUDevice, T, false>); \
+ MklConvOp<CPUDevice, T, false>); \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
- MklConv2DOp<CPUDevice, T, true>); \
+ MklConvOp<CPUDevice, T, true>); \
REGISTER_KERNEL_BUILDER(Name("__MklDummyConv2DWithBias") \
.Device(DEVICE_CPU) \
.TypeConstraint<T>("T") \
.Label(mkl_op_registry::kMklOpLabel), \
MklDummyOp<CPUDevice, T>);
-TF_CALL_float(REGISTER_MKL_CPU);
+TF_CALL_float(REGISTER_MKL_CPU_2D);
+
+// Register 3D operations
+#define REGISTER_MKL_CPU_3D(T) \
+ REGISTER_KERNEL_BUILDER(Name("_MklConv3D") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<T>("T") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklConvOp<CPUDevice, T, false>);
+TF_CALL_float(REGISTER_MKL_CPU_3D);
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_conv_ops.h b/tensorflow/core/kernels/mkl_conv_ops.h
index 838c06f49d..01cc606f41 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.h
+++ b/tensorflow/core/kernels/mkl_conv_ops.h
@@ -79,9 +79,16 @@ class MklDnnConvUtil {
// For now we take the stride from the second and third dimensions only
// (we do not support striding on the batch or depth dimension).
CHECK_NOTNULL(strides);
- int stride_rows = GetTensorDim(strides_, data_format_, 'H');
- int stride_cols = GetTensorDim(strides_, data_format_, 'W');
- *strides = {stride_rows, stride_cols};
+ if (strides_.size() == 4) {
+ int stride_rows = GetTensorDim(strides_, data_format_, 'H');
+ int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+ *strides = {stride_rows, stride_cols};
+ } else if (strides_.size() == 5) {
+ int stride_planes = GetTensorDim(strides_, data_format_, '0');
+ int stride_rows = GetTensorDim(strides_, data_format_, '1');
+ int stride_cols = GetTensorDim(strides_, data_format_, '2');
+ *strides = {stride_planes, stride_rows, stride_cols};
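+ // ('0', '1' and '2' passed to GetTensorDim select the planes, rows and
+ // cols spatial dimensions of the 5-D format, respectively.)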
+ }
}
// Calculate Convolution dilations
@@ -89,13 +96,20 @@ class MklDnnConvUtil {
// For now we take the dilation from the second and third dimensions only
// (we do not support dilation on the batch or depth dimension).
CHECK_NOTNULL(dilations);
- int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
- int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
- *dilations = {dilations_rows, dilations_cols};
+ if (dilations_.size() == 4) {
+ int dilations_rows = GetTensorDim(dilations_, data_format_, 'H');
+ int dilations_cols = GetTensorDim(dilations_, data_format_, 'W');
+ *dilations = {dilations_rows, dilations_cols};
+ } else if (dilations_.size() == 5) {
+ int dilations_planes = GetTensorDim(dilations_, data_format_, '0');
+ int dilations_rows = GetTensorDim(dilations_, data_format_, '1');
+ int dilations_cols = GetTensorDim(dilations_, data_format_, '2');
+ *dilations = {dilations_planes, dilations_rows, dilations_cols};
+ }
}
// Calculate Convolution input size in MKL-DNN order. MKL-DNN
- // requires input in NCHW format. Function does not return anything.
+ // requires input in NCHW/NCDHW format. Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status.
virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape,
@@ -113,40 +127,62 @@ class MklDnnConvUtil {
int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
int input_depth = static_cast<int>(input_depth_raw);
- // Input rows/height
- int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
- CHECK_BOUNDS(input_rows_raw, "Input rows too large");
- int input_rows = static_cast<int>(input_rows_raw);
-
- // Input columns/width
- int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
- CHECK_BOUNDS(input_cols_raw, "Input cols too large");
- int input_cols = static_cast<int>(input_cols_raw);
-
// Input batch
int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
CHECK_BOUNDS(input_batch_raw, "Input batch too large");
int input_batch = static_cast<int>(input_batch_raw);
+ if (strides_.size() == 4) { // NCHW format for Conv2D
+ // Input rows/height
+ int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
+ CHECK_BOUNDS(input_rows_raw, "Input rows too large");
+ int input_rows = static_cast<int>(input_rows_raw);
+
+ // Input columns/width
+ int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
+ CHECK_BOUNDS(input_cols_raw, "Input cols too large");
+ int input_cols = static_cast<int>(input_cols_raw);
+
+ // MKL-DNN always requires input in NCHW format for Conv2D.
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
+ mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
+ mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;
+
+ *input_dims = mkldnn_sizes;
+ } else if (strides_.size() == 5) { // NCDHW format for Conv3D
+ // Input planes/third-dimension
+ int64 input_planes_raw = GetTensorDim(input_shape, data_format_, '0');
+ CHECK_BOUNDS(input_planes_raw, "Input depth too large");
+ int input_planes = static_cast<int>(input_planes_raw);
+
+ // Input rows/height
+ int64 input_rows_raw = GetTensorDim(input_shape, data_format_, '1');
+ CHECK_BOUNDS(input_rows_raw, "Input rows too large");
+ int input_rows = static_cast<int>(input_rows_raw);
+
+ // Input columns/width
+ int64 input_cols_raw = GetTensorDim(input_shape, data_format_, '2');
+ CHECK_BOUNDS(input_cols_raw, "Input cols too large");
+ int input_cols = static_cast<int>(input_cols_raw);
+
+ // MKL-DNN always requires input in NCDHW format for Conv3D.
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_batch;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_planes;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_rows;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_cols;
+
+ *input_dims = mkldnn_sizes;
+ }
#undef CHECK_BOUNDS
-
- // MKL-DNN always requires input in NCHW format.
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_N] = input_batch;
- mkldnn_sizes[MklDnnDims::Dim_C] = input_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = input_rows;
- mkldnn_sizes[MklDnnDims::Dim_W] = input_cols;
-
- *input_dims = mkldnn_sizes;
}
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
- // But errors arising from sanity checks are returned in context's
- // status.
- //
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
+ // Calculate Convolution filter size in MKL-DNN order.
+ // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
+ // Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status. This function differs from GetConvFilterSizeInMklOrder in
// parameter for input - it accepts src_shape since Convolution Backward
@@ -159,11 +195,13 @@ class MklDnnConvUtil {
memory::dims* filter_dims) {
CHECK_NOTNULL(filter_dims);
- OP_REQUIRES(context_, filter_shape.dims() == 4,
- errors::InvalidArgument("filter must be 4-dimensional: ",
+ OP_REQUIRES(context_, filter_shape.dims() == strides_.size(),
+ errors::InvalidArgument((strides_.size() == 4)
+ ? "filter must be 4-dimensional: "
+ : "filter must be 5-dimensional: ",
filter_shape.DebugString()));
- for (int i = 0; i < 3; i++) {
+ for (int i = 0; i < ((strides_.size() == 4) ? 3 : 5); i++) {
OP_REQUIRES(context_,
FastBoundsCheck(filter_shape.dim_size(i),
std::numeric_limits<int>::max()),
@@ -172,32 +210,57 @@ class MklDnnConvUtil {
int input_depth = GetTensorDim(input_shape, data_format_, 'C');
- OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
- errors::InvalidArgument(
- "input and filter must have the same depth: ", input_depth,
- " vs ", filter_shape.dim_size(2)));
-
- // TF filter is always in (rows, cols, in_depth, out_depth) order.
- int filter_rows = static_cast<int>(filter_shape.dim_size(0));
- int filter_cols = static_cast<int>(filter_shape.dim_size(1));
- int in_depth = static_cast<int>(filter_shape.dim_size(2));
- int out_depth = static_cast<int>(filter_shape.dim_size(3));
-
- // MKL-DNN always needs filter in OIHW format.
- // OIHW = (out_depth, in_depth, rows, cols)
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_O] = out_depth;
- mkldnn_sizes[MklDnnDims::Dim_I] = in_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
- mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;
-
- *filter_dims = mkldnn_sizes;
+ if (strides_.size() == 4) { // Conv2D
+ OP_REQUIRES(context_, input_depth == filter_shape.dim_size(2),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ",
+ input_depth, " vs ", filter_shape.dim_size(2)));
+
+ // TF filter is always in (rows, cols, in_depth, out_depth) order.
+ int filter_rows = static_cast<int>(filter_shape.dim_size(0));
+ int filter_cols = static_cast<int>(filter_shape.dim_size(1));
+ int in_depth = static_cast<int>(filter_shape.dim_size(2));
+ int out_depth = static_cast<int>(filter_shape.dim_size(3));
+
+ // MKL-DNN always needs filter in OIHW format.
+ // OIHW = (out_depth, in_depth, rows, cols)
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_O] = out_depth;
+ mkldnn_sizes[MklDnnDims::Dim_I] = in_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = filter_rows;
+ mkldnn_sizes[MklDnnDims::Dim_W] = filter_cols;
+
+ *filter_dims = mkldnn_sizes;
+ } else { // Conv3D
+ OP_REQUIRES(context_, input_depth == filter_shape.dim_size(3),
+ errors::InvalidArgument(
+ "input and filter must have the same depth: ",
+ input_depth, " vs ", filter_shape.dim_size(3)));
+
+ // TF filter is always in (planes, rows, cols, in_depth, out_depth) order.
+ int filter_planes = static_cast<int>(filter_shape.dim_size(0));
+ int filter_rows = static_cast<int>(filter_shape.dim_size(1));
+ int filter_cols = static_cast<int>(filter_shape.dim_size(2));
+ int in_depth = static_cast<int>(filter_shape.dim_size(3));
+ int out_depth = static_cast<int>(filter_shape.dim_size(4));
+
+ // MKL-DNN always needs filter in OIDHW format.
+ // OIDHW = (out_depth, in_depth, planes, rows, cols)
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_O] = out_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_I] = in_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = filter_planes;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = filter_rows;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = filter_cols;
+
+ *filter_dims = mkldnn_sizes;
+ }
}
- // Calculate Convolution filter size in MKL-DNN order. MKL-DNN
- // requires filter in OIHW format. Function does not return anything.
- // But errors arising from sanity checks are returned in context's
- // status.
+ // Calculate Convolution filter size in MKL-DNN order.
+ // MKL-DNN requires filter in OIHW (Conv2D) or OIDHW (Conv3D) format.
+ // Function does not return anything. But errors arising from sanity
+ // checks are returned in context's status.
virtual inline void GetFilterSizeInMklOrder(size_t src_index,
size_t filter_index,
memory::dims* filter_dims) {
@@ -206,8 +269,8 @@ class MklDnnConvUtil {
GetTfShape(context_, filter_index), filter_dims);
}
- // Calculate Bias size for 2D Convolution. Function does not return
- // anything, but sets error in context status.
+ // Calculate Bias size for 2D or 3D Convolution. Function does not
+ // return anything, but may set an error in context status.
virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
memory::dims* bias_dims) {
const Tensor& bias = MklGetInput(context_, bias_index);
@@ -218,73 +281,142 @@ class MklDnnConvUtil {
*bias_dims = {static_cast<int>(bias.dim_size(0))};
}
- // Function to calculate output and padding size for 2D convolution.
+ // Function to calculate output and padding size for 2D/3D convolution.
//
// Calculate output shape of Convolution in MKL-DNN and TensorFlow order.
- // MKL-DNN uses NCHW for output order. But TensorFlow output will be in
- // NHWC or NCHW format depending on data format. Function also calculates
- // left, right, top and bottom pads. Function does not return any status -
- // status is returned via context status.
+ // MKL-DNN uses NCHW (Conv2D) or NCDHW (Conv3D) for the output order.
+ // But the TensorFlow output will be in NHWC/NCHW (Conv2D) or
+ // NDHWC/NCDHW (Conv3D) format depending on data format.
+ // Function also calculates left, right, top and bottom pads.
+ // Function does not return any status - status is returned via context status.
//
// TODO(nhasabni): Add similar function for input and filter in MklShape.
virtual inline void GetOutputAndPadSizeInMklOrder(
const TensorShape& input_shape, const TensorShape& filter_shape,
const memory::dims& strides, const memory::dims& dilations,
- memory::dims* output_dims_tf_order,
- memory::dims* output_dims_mkl_order, memory::dims* pad_l,
- memory::dims* pad_r) {
+ memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
+ memory::dims* pad_l, memory::dims* pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
CHECK_NOTNULL(pad_r);
- int input_rows = GetTensorDim(input_shape, data_format_, 'H');
- int input_cols = GetTensorDim(input_shape, data_format_, 'W');
+ bool isConv2D = (strides_.size() == 4);
+ int input_planes, input_rows, input_cols;
+ if (isConv2D) {
+ input_rows = GetTensorDim(input_shape, data_format_, 'H');
+ input_cols = GetTensorDim(input_shape, data_format_, 'W');
+ } else {
+ input_planes = GetTensorDim(input_shape, data_format_, '0');
+ input_rows = GetTensorDim(input_shape, data_format_, '1');
+ input_cols = GetTensorDim(input_shape, data_format_, '2');
+ }
- // The first dimension for filter is rows/height.
- int filter_rows = filter_shape.dim_size(0);
- // The second dimension for filter is cols/width.
- int filter_cols = filter_shape.dim_size(1);
+ // Filter dimension
+ // Conv2D:
+ // First dimension: rows/height.
+ // Second dimension: cols/width.
+ // Conv3D:
+ // First dimension: planes/depth.
+ // Second dimension: rows/height.
+ // Third dimension: cols/width.
+
+ int filter_planes, filter_rows, filter_cols;
+ if (isConv2D) {
+ filter_rows = filter_shape.dim_size(0);
+ filter_cols = filter_shape.dim_size(1);
+ } else {
+ filter_planes = filter_shape.dim_size(0);
+ filter_rows = filter_shape.dim_size(1);
+ filter_cols = filter_shape.dim_size(2);
+ }
- // Stride is vector of 2 elements: {s_r, s_c}
- int stride_rows = strides[0];
- int stride_cols = strides[1];
- int dilation_rows = dilations[0];
- int dilation_cols = dilations[1];
+ int stride_planes, stride_rows, stride_cols;
+ int dilation_planes, dilation_rows, dilation_cols;
+ if (isConv2D) {
+ // Conv2D stride is a vector of 2 elements: {s_r, s_c}
+ stride_rows = strides[0];
+ stride_cols = strides[1];
+ dilation_rows = dilations[0];
+ dilation_cols = dilations[1];
+ } else {
+ // Conv3D stride is a vector of 3 elements: {s_d, s_r, s_c}
+ stride_planes = strides[0];
+ stride_rows = strides[1];
+ stride_cols = strides[2];
+ dilation_planes = dilations[0];
+ dilation_rows = dilations[1];
+ dilation_cols = dilations[2];
+ }
// Output batch is same as input batch.
int out_batch = GetTensorDim(input_shape, data_format_, 'N');
+
// Output depth is same as last dimension for filter.
- int out_depth = filter_shape.dim_size(3);
+ int out_depth = filter_shape.dim_size(isConv2D ? 3 : 4);
- int64 out_rows = 0, out_cols = 0;
+ int64 out_rows = 0, out_cols = 0, out_planes = 0;
int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right;
+ int64 pad_D1, pad_D2;
+
+ if (isConv2D) {
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerboseV2(
+ input_rows, filter_rows, dilation_rows, stride_rows,
+ padding_, &out_rows, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(context_,
+ GetWindowedOutputSizeVerboseV2(
+ input_cols, filter_cols, dilation_cols, stride_cols,
+ padding_, &out_cols, &pad_left, &pad_right));
+ } else {
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_planes, filter_planes, stride_planes,
+ padding_, &out_planes, &pad_D1, &pad_D2));
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_rows, filter_rows, stride_rows,
+ padding_, &out_rows, &pad_top, &pad_bottom));
+ OP_REQUIRES_OK(context_, GetWindowedOutputSizeVerbose(
+ input_cols, filter_cols, stride_cols,
+ padding_, &out_cols, &pad_left, &pad_right));
+ }
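+ // The Conv3D branch above uses GetWindowedOutputSizeVerbose, which takes
+ // no dilation argument, so dilation rates do not enter the 3D output-size
+ // computation here.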
- OP_REQUIRES_OK(context_,
- GetWindowedOutputSizeVerboseV2(input_rows, filter_rows,
- dilation_rows, stride_rows, padding_,
- &out_rows, &pad_top, &pad_bottom));
- OP_REQUIRES_OK(context_,
- GetWindowedOutputSizeVerboseV2(input_cols, filter_cols,
- dilation_cols, stride_cols, padding_,
- &out_cols, &pad_left, &pad_right));
-
- // Tensorflow output is in data_format order. (NHWC or NCHW)
+ // Tensorflow output is in data_format order.
+ // Conv2D: NHWC or NCHW
+ // Conv3D: NDHWC or NCDHW
+ // MKL-DNN uses asymmetric padding.
TensorShape out_shape =
- ShapeFromFormat(data_format_, out_batch, out_rows, out_cols, out_depth);
+ isConv2D
+ ? ShapeFromFormat(data_format_, out_batch, out_rows, out_cols,
+ out_depth)
+ : ShapeFromFormat(data_format_, out_batch,
+ {{out_planes, out_rows, out_cols}}, out_depth);
*output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
- // MKL-DNN always needs output in NCHW format.
- std::vector<int> mkldnn_sizes(4, -1);
- mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
- mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
- mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
- mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
- *output_dims_mkl_order = mkldnn_sizes;
-
- // Now handle padding. MKL-DNN uses asymetric padding.
- *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
- *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
+ if (isConv2D) {
+ // For Conv2D, MKL-DNN always needs output in NCHW format.
+ std::vector<int> mkldnn_sizes(4, -1);
+ mkldnn_sizes[MklDnnDims::Dim_N] = out_batch;
+ mkldnn_sizes[MklDnnDims::Dim_C] = out_depth;
+ mkldnn_sizes[MklDnnDims::Dim_H] = static_cast<int>(out_rows);
+ mkldnn_sizes[MklDnnDims::Dim_W] = static_cast<int>(out_cols);
+ *output_dims_mkl_order = mkldnn_sizes;
+
+ *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
+ *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
+ } else {
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_N] = out_batch;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_C] = out_depth;
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = static_cast<int>(out_planes);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = static_cast<int>(out_rows);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = static_cast<int>(out_cols);
+ *output_dims_mkl_order = mkldnn_sizes;
+
+ *pad_l = {static_cast<int>(pad_D1), static_cast<int>(pad_top),
+ static_cast<int>(pad_left)};
+ *pad_r = {static_cast<int>(pad_D2), static_cast<int>(pad_bottom),
+ static_cast<int>(pad_right)};
+ }
}
// Calculate output and pad size of forward Convolution operator.
@@ -292,10 +424,10 @@ class MklDnnConvUtil {
//
// Function does not return anything, but sets error in context status.
inline void GetOutputAndPadSizeInMklOrder(
- size_t src_index, size_t filter_index,
- const memory::dims& strides, const memory::dims& dilations,
- memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
- memory::dims* pad_l, memory::dims* pad_r) {
+ size_t src_index, size_t filter_index, const memory::dims& strides,
+ const memory::dims& dilations, memory::dims* output_dims_tf_order,
+ memory::dims* output_dims_mkl_order, memory::dims* pad_l,
+ memory::dims* pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
@@ -304,9 +436,17 @@ class MklDnnConvUtil {
auto input_tf_shape = GetTfShape(context_, src_index);
auto filter_tf_shape = GetTfShape(context_, filter_index);
- OP_REQUIRES(context_, input_tf_shape.dims() == 4,
- errors::InvalidArgument("input must be 4-dimensional",
- input_tf_shape.DebugString()));
+ if (strides_.size() == 4) {
+ // Conv2D
+ OP_REQUIRES(context_, input_tf_shape.dims() == 4,
+ errors::InvalidArgument("input must be 4-dimensional",
+ input_tf_shape.DebugString()));
+ } else {
+ // Conv3D
+ OP_REQUIRES(context_, input_tf_shape.dims() == 5,
+ errors::InvalidArgument("input must be 5-dimensional",
+ input_tf_shape.DebugString()));
+ }
GetOutputAndPadSizeInMklOrder(input_tf_shape, filter_tf_shape,
strides, dilations, output_dims_tf_order,
@@ -314,9 +454,11 @@ class MklDnnConvUtil {
}
// Wrapper function to calculate input, filter, and output sizes of
- // 2D Convolution in MKL order (NCHW for input and output; OIHW for filter.)
- // Function also calculates output shape in Tensorflow order. Additionally, it
- // also calculates strides and paddings for 2D Convolution.
+ // Conv2D/Conv3D in MKL order:
+ // Conv2D: NCHW for input and output; OIHW for filter.
+ // Conv3D: NCDHW for input and output; OIDHW for filter.
+ // Function also calculates output shape in Tensorflow order.
+ // Additionally, it also calculates strides and paddings.
//
// Function does not return anything, but sets error in context status.
inline void GetConvFwdSizesInMklOrder(
@@ -349,16 +491,15 @@ class MklDnnConvUtil {
}
};
-
/////////////////////////////////////////////////////////////////////
-/// Common class that implements Conv2DBackpropFilter and Input
+/// Common class that implements ConvBackpropFilter and Input
/////////////////////////////////////////////////////////////////////
template <typename Device, class T>
-class MklConv2DBackpropCommonOp : public OpKernel {
+class MklConvBackpropCommonOp : public OpKernel {
public:
- ~MklConv2DBackpropCommonOp() {}
- explicit MklConv2DBackpropCommonOp(OpKernelConstruction* context)
+ ~MklConvBackpropCommonOp() {}
+ explicit MklConvBackpropCommonOp(OpKernelConstruction* context)
: OpKernel(context) {
string data_format_str;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
@@ -372,20 +513,25 @@ class MklConv2DBackpropCommonOp : public OpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
- OP_REQUIRES(context, dilations_.size() == 4,
- errors::InvalidArgument("Sliding window dilations field must "
- "specify 4 dimensions"));
- int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
- int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
- int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
- int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
- OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
- errors::InvalidArgument(
- "Current implementation does not yet support "
- "dilations in the batch and depth dimensions."));
- OP_REQUIRES(
- context, dilation_h > 0 && dilation_w > 0,
- errors::InvalidArgument("Dilated rates should be larger than 0."));
+
+ if (strides_.size() == 4) {
+ // Check Conv2D dilations
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
+ }
+
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
diff --git a/tensorflow/core/kernels/mkl_conv_ops_test.cc b/tensorflow/core/kernels/mkl_conv_ops_test.cc
new file mode 100644
index 0000000000..a055351337
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_conv_ops_test.cc
@@ -0,0 +1,407 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/const_op.h"
+#include "tensorflow/cc/ops/nn_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/public/session.h"
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#include "third_party/intel_mkl_dnn/include/mkldnn.h"
+#include "tensorflow/core/util/mkl_util.h"
+#endif
+
+// TODO(ezhulenev): Add numerical tests that will compare results of default
+// (aka Eigen) convolutions with MKL convolutions.
+
+// -------------------------------------------------------------------------- //
+// Performance Benchmarks. //
+// -------------------------------------------------------------------------- //
+
+// Compare performance of default Tensorflow convolution kernels (Eigen) with
+// MKL kernels on CPU.
+
+// Before running these benchmarks, configure the OpenMP environment variables:
+// export KMP_BLOCKTIME=0
+// export OMP_NUM_THREADS=${num_threads}
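+//
+// (KMP_BLOCKTIME=0 puts OpenMP worker threads to sleep immediately after a
+// parallel region instead of busy-waiting, which helps keep the MKL runs
+// from skewing the Eigen baselines.)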
+
+namespace tensorflow {
+
+struct Conv2DDimensions {
+ Conv2DDimensions(int n, int h, int w, int c, int fc, int fh, int fw)
+ : input_batches(n),
+ input_height(h),
+ input_width(w),
+ input_depth(c),
+ filter_count(fc),
+ filter_height(fh),
+ filter_width(fw) {}
+
+ int input_batches;
+ int input_height;
+ int input_width;
+ int input_depth;
+ int filter_count;
+ int filter_height;
+ int filter_width;
+};
+
+static Tensor GetRandomTensor(const TensorShape& shape) {
+ Tensor tensor(DT_FLOAT, TensorShape(shape));
+ tensor.flat<float>() = tensor.flat<float>().setRandom();
+ return tensor;
+}
+
+// Get a random Tensor for the Conv2D input.
+static Tensor GetRandomInputTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.input_batches, dims.input_height,
+ dims.input_width, dims.input_depth});
+}
+
+// Get a random Tensor for the Conv2D filter.
+static Tensor GetRandomFilterTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.filter_height, dims.filter_width,
+ dims.input_depth, dims.filter_count});
+}
+
+// Get a random Tensor for the Conv2D output (assuming SAME padding).
+static Tensor GetRandomOutputTensor(const Conv2DDimensions& dims) {
+ return GetRandomTensor({dims.input_batches, dims.input_height,
+ dims.input_width, dims.filter_count});
+}
+
+// Get a Tensor encoding Conv2D input shape.
+static Tensor GetInputSizesTensor(const Conv2DDimensions& dims) {
+ return test::AsTensor<int32>({dims.input_batches, dims.input_height,
+ dims.input_width, dims.input_depth});
+}
+
+// Get a Tensor encoding Conv2D filter shape.
+static Tensor GetFilterSizesTensor(const Conv2DDimensions& dims) {
+ return test::AsTensor<int32>({dims.filter_height, dims.filter_width,
+ dims.input_depth, dims.filter_count});
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Tensor NonMklTensor() {
+ MklDnnShape non_mkl_shape;
+ non_mkl_shape.SetMklTensor(false);
+
+ auto size = static_cast<int64>(non_mkl_shape.GetSerializeBufferSize());
+ Tensor tensor(DT_UINT8, {size});
+
+ non_mkl_shape.SerializeMklDnnShape(tensor.flat<uint8>().data(),
+ size * sizeof(uint8));
+ return tensor;
+}
+#endif
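+// _Mkl* kernels expect an extra MklDnnShape metadata tensor for each data
+// input; NonMklTensor() serializes a "not an MKL tensor" shape so that plain
+// TF-layout tensors can be fed to the MKL graphs below.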
+
+static Graph* DefaultConv2D(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+
+ Node* conv2d;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d"), "Conv2D")
+ .Input(input)
+ .Input(filter)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2D(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("mkl_conv_2d"), "_MklConv2D")
+ .Input(input)
+ .Input(filter)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d));
+
+ return graph;
+}
+#endif
+
+static Graph* DefaultConv2DBwdInput(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_sizes_t = GetInputSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input_sizes =
+ test::graph::Constant(graph, input_sizes_t, "input_sizes");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* conv2d_bwd_input;
+ TF_CHECK_OK(
+ NodeBuilder(graph->NewName("conv_2d_bwd_input"), "Conv2DBackpropInput")
+ .Input(input_sizes)
+ .Input(filter)
+ .Input(out_backprop)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d_bwd_input));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2DBwdInput(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_sizes_t = GetInputSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input_sizes =
+ test::graph::Constant(graph, input_sizes_t, "input_sizes");
+ Node* filter = test::graph::Constant(graph, filter_t, "filter");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d_bwd_input;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_input"),
+ "_MklConv2DBackpropInput")
+ .Input(input_sizes)
+ .Input(filter)
+ .Input(out_backprop)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d_bwd_input));
+
+ return graph;
+}
+#endif
+
+static Graph* DefaultConv2DBwdFilter(const Conv2DDimensions& dims) {
+ auto* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_sizes_t = GetFilterSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter_sizes =
+ test::graph::Constant(graph, filter_sizes_t, "filter_sizes");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* conv2d_bwd_filter;
+ TF_CHECK_OK(
+ NodeBuilder(graph->NewName("conv_2d_bwd_filter"), "Conv2DBackpropFilter")
+ .Input(input)
+ .Input(filter_sizes)
+ .Input(out_backprop)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Finalize(graph, &conv2d_bwd_filter));
+
+ return graph;
+}
+
+#if defined(INTEL_MKL_DNN_ONLY)
+static Graph* MklConv2DBwdFilter(const Conv2DDimensions& dims) {
+ Graph* graph = new Graph(OpRegistry::Global());
+
+ Tensor input_t = GetRandomInputTensor(dims);
+ Tensor filter_sizes_t = GetFilterSizesTensor(dims);
+ Tensor filter_t = GetRandomFilterTensor(dims);
+ Tensor out_backprop_t = GetRandomOutputTensor(dims); // assuming SAME padding
+
+ Node* input = test::graph::Constant(graph, input_t, "input");
+ Node* filter_sizes =
+ test::graph::Constant(graph, filter_sizes_t, "filter_sizes");
+ Node* out_backprop =
+ test::graph::Constant(graph, out_backprop_t, "out_backprop");
+
+ Node* not_mkl_shape = test::graph::Constant(graph, NonMklTensor(), "not_mkl");
+
+ Node* conv2d_bwd_filter;
+ TF_CHECK_OK(NodeBuilder(graph->NewName("conv_2d_bwd_filter"),
+ "_MklConv2DBackpropFilter")
+ .Input(input)
+ .Input(filter_sizes)
+ .Input(out_backprop)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Input(not_mkl_shape)
+ .Attr("T", DT_FLOAT)
+ .Attr("strides", {1, 1, 1, 1})
+ .Attr("padding", "SAME")
+ .Attr("_kernel", "MklOp")
+ .Finalize(graph, &conv2d_bwd_filter));
+
+ return graph;
+}
+#endif
+
+// Macro argument names: ---------------------------------------------------- //
+// N: batch size
+// H: height
+// W: width
+// C: channels
+// FC: filter count
+// FH: filter height
+// FW: filter width
+
+#define BM_CONCAT(a, b) a##b
+
+#define BM_NAME(p, type, N, H, W, C, FC, FH, FW) \
+ BM_CONCAT(BM_##p##_##type##_in_##N##_##H##_##W##_##C, _f_##FC##_##FH##_##FW)
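+
+// For example (illustrative only), the instantiation
+// BM_Conv2D(32, 28, 28, 96, 128, 3, 3, cpu, ...) used below expands to
+// benchmark functions named
+//   BM_Conv2D_Default_cpu_in_32_28_28_96_f_128_3_3
+// and, when INTEL_MKL_DNN_ONLY is defined,
+//   BM_Conv2D_Mkl_cpu_in_32_28_28_96_f_128_3_3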
+
+// Flops computation in these benchmarks is the same as in
+// eigen_benchmark_cpu_test.cc.
+
+#define BM_Conv2DT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (N) * (H) * (W) * (FC); \
+ int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+ Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2D)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2D_##kind, type, N, H, W, C, FC, FH, FW))
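+
+// Worked example (illustrative only): for the "conv3a_00_3x3" configuration
+// below (N=32, H=W=28, C=96, FC=128, FH=FW=3), BM_Conv2DT reports
+//   num_computed_elements = 32 * 28 * 28 * 128 = 3211264
+//   flops_per_iter = 3211264 * (96 * 3 * 3) = 2774532096
+// i.e. roughly 2.77 billion multiply-accumulates per iteration.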
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2D(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdInputT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (N) * (H) * (W) * (C); \
+ int64 flops_per_iter = num_computed_elements * ((C) * (FH) * (FW)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+ Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdInput)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2DBwdInput_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DBwdInputT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2DBwdInput(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdInputT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+#define BM_Conv2DBwdFilterT(kind, N, H, W, C, FC, FH, FW, type, LABEL) \
+ static void BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, \
+ FW)(int iters) { \
+ testing::SetLabel(LABEL); \
+ \
+ int64 num_computed_elements = (FH) * (FW) * (C) * (FC); \
+ int64 flops_per_iter = num_computed_elements * ((N) * (H) * (W)); \
+ testing::ItemsProcessed(static_cast<int64>(iters) * flops_per_iter); \
+ \
+ Conv2DDimensions dims(N, H, W, C, FC, FW, FH); \
+ test::Benchmark(#type, BM_CONCAT(kind, Conv2DBwdFilter)(dims)).Run(iters); \
+ } \
+ BENCHMARK(BM_NAME(Conv2DBwdFilter_##kind, type, N, H, W, C, FC, FH, FW))
+
+#if defined(INTEL_MKL_DNN_ONLY)
+#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL); \
+ BM_Conv2DBwdFilterT(Mkl, N, H, W, C, FC, FH, FW, type, LABEL);
+#else
+#define BM_Conv2DBwdFilter(N, H, W, C, FC, FH, FW, type, LABEL) \
+ BM_Conv2DBwdFilterT(Default, N, H, W, C, FC, FH, FW, type, LABEL);
+#endif
+
+// ImageNet Convolutions ---------------------------------------------------- //
+
+BM_Conv2D(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2D(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2D(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2D(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2D(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2D(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2D(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+BM_Conv2DBwdInput(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2DBwdInput(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2DBwdInput(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2DBwdInput(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2DBwdInput(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2DBwdInput(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2DBwdInput(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+BM_Conv2DBwdFilter(32, 28, 28, 96, 128, 3, 3, cpu, "conv3a_00_3x3");
+BM_Conv2DBwdFilter(32, 28, 28, 16, 32, 5, 5, cpu, "conv3a_00_5x5");
+BM_Conv2DBwdFilter(32, 28, 28, 128, 192, 3, 3, cpu, "conv3_00_3x3");
+BM_Conv2DBwdFilter(32, 28, 28, 32, 96, 5, 5, cpu, "conv3_00_5x5");
+BM_Conv2DBwdFilter(32, 14, 14, 96, 204, 3, 3, cpu, "conv4a_00_3x3");
+BM_Conv2DBwdFilter(32, 14, 14, 16, 48, 5, 5, cpu, "conv4a_00_5x5");
+BM_Conv2DBwdFilter(32, 14, 14, 112, 224, 3, 3, cpu, "conv4b_00_3x3");
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
index 06ce820ae9..84ee241b8e 100644
--- a/tensorflow/core/kernels/mkl_input_conversion_op.cc
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -296,7 +296,9 @@ class MklInputConversionOp : public OpKernel {
// implementation.
TensorShape tf_shape0 = input_shape_0.GetTfShape();
TensorShape tf_shape1 = input_shape_1.GetTfShape();
- if (tf_shape0 == tf_shape1) {
+ TensorShape tensor_shape0 = input_tensor_0.shape();
+ TensorShape tensor_shape1 = input_tensor_1.shape();
+ if (tf_shape0 == tf_shape1 && tensor_shape0 == tensor_shape1) {
auto input0_md = input_shape_0.GetMklLayout();
auto input1_md = input_shape_1.GetMklLayout();
@@ -350,7 +352,8 @@ class MklInputConversionOp : public OpKernel {
}
// Sanity check
- bool mkl_shapes_are_same = input_shape_0 == input_shape_1;
+ bool mkl_shapes_are_same = ((input_shape_0 == input_shape_1) &&
+ (tensor_shape0 == tensor_shape1));
if (mkl_shapes_are_same) {
CHECK(false) << "MklInputConversionOp: Unexpected: TF shapes are "
"different but MKL shapes are same";
@@ -403,7 +406,8 @@ class MklInputConversionOp : public OpKernel {
}
// Broadcast is needed if the shapes are not the same
- if (mkl_shape->GetTfShape().num_elements() == tf_tensor->shape().num_elements() ) {
+ if (mkl_shape->GetTfShape().num_elements() ==
+ tf_tensor->shape().num_elements()) {
// Both shapes are same, convert the TF input to MKL
VLOG(1) << "MklInputConversionOp: No broadcast needed.";
VLOG(1) << "MklInputConversionOp: Converting input " << tf_tensor_index
@@ -437,16 +441,17 @@ class MklInputConversionOp : public OpKernel {
bool reordered = tf_input.CheckReorderToOpMem(
memory::primitive_desc(output_mkl_md, cpu_engine),
tensor_out, &net);
- if(!reordered) {
+
+ if (!reordered) {
          // This is the case where the TF tensor has the same shape and format
          // as the MKL tensor. However, tf_tensor cannot simply be forwarded to
          // the output tensor, since an MKL data tensor is always
          // one-dimensional. Tensor::CopyFrom shares the buffer of the other
          // tensor while setting its shape to that of the other tensor.
CHECK(tensor_out->CopyFrom(*tf_tensor, tensor_out->shape()));
- }
- else
+ } else {
stream(stream::kind::eager).submit(net).wait();
+ }
// -- The tensor in MKL format passes through --
ForwardMklTensorInToOut(context, mkl_tensor_index, mkl_tensor_index);
diff --git a/tensorflow/core/kernels/mkl_matmul_op.cc b/tensorflow/core/kernels/mkl_matmul_op.cc
index 077d62ce32..f4788f4851 100644
--- a/tensorflow/core/kernels/mkl_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_matmul_op.cc
@@ -217,7 +217,7 @@ class MklMatMulOp : public OpKernel {
reinterpret_cast<const MKL_Complex16*>(b), ldb, &beta,
reinterpret_cast<MKL_Complex16*>(c), ldc);
}
-#endif
+#endif // !INTEL_MKL_DNN_ONLY
};
#define REGISTER_CPU(T) \
@@ -225,6 +225,7 @@ class MklMatMulOp : public OpKernel {
Name("MatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
MklMatMulOp<CPUDevice, T, false /* cublas, ignored for CPU */>);
+#ifdef ENABLE_MKL
// TODO(inteltf) Consider template specialization when adding/removing
// additional types
TF_CALL_float(REGISTER_CPU);
@@ -233,7 +234,8 @@ TF_CALL_float(REGISTER_CPU);
TF_CALL_double(REGISTER_CPU);
TF_CALL_complex64(REGISTER_CPU);
TF_CALL_complex128(REGISTER_CPU);
-#endif
+#endif // !INTEL_MKL_DNN_ONLY
+#endif // ENABLE_MKL
} // namespace tensorflow
#endif // INTEL_MKL
diff --git a/tensorflow/core/kernels/mkl_maxpooling_op.cc b/tensorflow/core/kernels/mkl_maxpooling_op.cc
index e149f003e5..256d48f4d5 100644
--- a/tensorflow/core/kernels/mkl_maxpooling_op.cc
+++ b/tensorflow/core/kernels/mkl_maxpooling_op.cc
@@ -524,6 +524,8 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
// initialize variables for the pooling op
MklPoolParameters pool_params;
+ // check whether pooling is 2D or 3D
+ bool is_pool2d = (this->ksize_.size() == 4);
// Get the input tensor and initialize the pooling parameters
TensorShape input_tensor_shape = input_tensor.shape();
this->InitMklPoolParameters(context, &pool_params, dnn_shape_input,
@@ -547,20 +549,26 @@ class MklMaxPoolingOp : public MklPoolingForwardOpBase<T> {
memory::desc input_md =
dnn_shape_input.IsMklTensor()
? dnn_shape_input.GetMklLayout()
- : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
- this->data_format_tf_),
- MklDnnType<T>(), this->data_format_mkldnn_);
+ : is_pool2d ? memory::desc(
+ TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
+ this->data_format_tf_),
+ MklDnnType<T>(), this->data_format_mkldnn_)
+ : memory::desc(
+ TFShapeToMklDnnDimsInNCDHW(
+ input_tensor_shape, this->data_format_tf_),
+ MklDnnType<T>(), this->data_format_mkldnn_);
// Get src/filter/stride/padding information
memory::dims src_dims =
dnn_shape_input.IsMklTensor()
? dnn_shape_input.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
- this->data_format_tf_);
-
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(input_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(input_tensor.shape(),
+ this->data_format_tf_);
memory::dims filter_dims, strides, padding_left, padding_right;
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
- &padding_left, &padding_right);
+ &padding_left, &padding_right, is_pool2d);
// Get a pooling op from the cached pool
MklPoolingFwdPrimitive<T>* pooling_fwd = nullptr;
@@ -663,23 +671,30 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
MklPoolParameters pool_params;
TensorShape orig_input_shape = orig_input_tensor.shape();
+
+ bool is_pool2d = (this->ksize_.size() == 4);
this->InitMklPoolParameters(context, &pool_params, orig_input_mkl_shape,
orig_input_shape);
memory::dims filter_dims, strides, padding_left, padding_right;
this->PoolParamsToDims(&pool_params, &filter_dims, &strides,
- &padding_left, &padding_right);
+ &padding_left, &padding_right, is_pool2d);
- memory::dims diff_dst_dims =
- grad_mkl_shape.IsMklTensor()
- ? grad_mkl_shape.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
- this->data_format_tf_);
memory::dims orig_input_dims_mkl_order =
orig_input_mkl_shape.IsMklTensor()
? orig_input_mkl_shape.GetSizesAsMklDnnDims()
- : TFShapeToMklDnnDimsInNCHW(orig_input_shape,
- this->data_format_tf_);
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(orig_input_shape,
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(orig_input_shape,
+ this->data_format_tf_);
+
+ memory::dims diff_dst_dims =
+ grad_mkl_shape.IsMklTensor()
+ ? grad_mkl_shape.GetSizesAsMklDnnDims()
+ : is_pool2d ? TFShapeToMklDnnDimsInNCHW(grad_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(grad_tensor.shape(),
+ this->data_format_tf_);
memory::dims output_dims_mkl_order;
this->GetOutputDims(pool_params, &output_dims_mkl_order);
@@ -715,7 +730,7 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
void* ws_data = static_cast<void*>(
const_cast<uint8*>(workspace_tensor.flat<uint8>().data()));
- ;
+
auto ws_md =
pooling_bwd->GetPoolingFwdPd()->workspace_primitive_desc().desc();
if (ws_md.data.format != pooling_bwd->GetWorkspaceFormat()) {
@@ -817,6 +832,18 @@ class MklMaxPoolingGradOp : public MklPoolingBackwardOpBase<T> {
}
}; // MklMaxPoolingGradOp
+REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3D")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .Label(mkl_op_registry::kMklOpLabel),
+ MklMaxPoolingOp<CPUDevice, float>);
+
+REGISTER_KERNEL_BUILDER(Name("_MklMaxPool3DGrad")
+ .Device(DEVICE_CPU)
+ .TypeConstraint<float>("T")
+ .Label(mkl_op_registry::kMklOpLabel),
+ MklMaxPoolingGradOp<CPUDevice, float>);
+
#endif // INTEL_MKL_ML_ONLY
REGISTER_KERNEL_BUILDER(Name("_MklMaxPool")
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.cc b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
index d7ad3f9dcd..5398e6113f 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.cc
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.cc
@@ -24,7 +24,7 @@ limitations under the License.
namespace tensorflow {
-#ifndef INTEL_MKL_ML
+#ifndef INTEL_MKL_ML_ONLY
using mkldnn::pooling_avg;
using mkldnn::pooling_avg_exclude_padding;
@@ -34,11 +34,11 @@ using mkldnn::prop_kind;
template <typename T>
void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
- if (fwdParams.alg_kind != pooling_max && fwdParams.alg_kind != pooling_avg &&
- fwdParams.alg_kind != pooling_avg_include_padding &&
- fwdParams.alg_kind != pooling_avg_exclude_padding) {
- assert("Pooling algorithm kind is not supported\n");
- }
+ DCHECK(fwdParams.alg_kind == pooling_max ||
+ fwdParams.alg_kind == pooling_avg ||
+ fwdParams.alg_kind == pooling_avg_include_padding ||
+ fwdParams.alg_kind == pooling_avg_exclude_padding)
+ << "Pooling algorithm kind is not supported";
context_.alg_kind = fwdParams.alg_kind;
// create memory desc
@@ -46,9 +46,10 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
// so src format is currently hard-coded.
// A utility function is used to do this,
// which may be broken with future CPU architectures
+ bool is_2d = (fwdParams.src_dims.size() == 4);
context_.src_md.reset(
new memory::desc({fwdParams.src_dims}, MklDnnType<T>(),
- get_desired_format(fwdParams.src_dims[1])));
+ get_desired_format(fwdParams.src_dims[1], is_2d)));
context_.dst_md.reset(new memory::desc({fwdParams.dst_dims}, MklDnnType<T>(),
memory::format::any));
@@ -61,7 +62,7 @@ void MklPoolingFwdPrimitive<T>::Setup(const MklPoolingParams& fwdParams) {
new pooling_forward::primitive_desc(*context_.fwd_desc, cpu_engine_));
// store expected primitive format
- context_.src_fmt = get_desired_format(fwdParams.src_dims[1]);
+ context_.src_fmt = get_desired_format(fwdParams.src_dims[1], is_2d);
context_.dst_fmt = static_cast<mkldnn::memory::format>(
context_.fwd_pd.get()->dst_primitive_desc().desc().data.format);
@@ -101,7 +102,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
static_cast<void*>(const_cast<T*>(src_data)));
context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
if (context_.alg_kind == pooling_max) { // max pooling must have ws
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(ws_data);
}
context_.fwd_stream->submit(context_.fwd_primitives);
@@ -110,7 +111,7 @@ void MklPoolingFwdPrimitive<T>::Execute(const T* src_data, T* dst_data,
context_.src_mem->set_data_handle(DummyData);
context_.dst_mem->set_data_handle(DummyData);
if (context_.alg_kind == pooling_max) { // max pooling must have ws
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(DummyData);
}
}
@@ -119,19 +120,21 @@ template class MklPoolingFwdPrimitive<float>;
template <typename T>
void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
- if (bwdParams.alg_kind != pooling_max && bwdParams.alg_kind != pooling_avg &&
- bwdParams.alg_kind != pooling_avg_include_padding &&
- bwdParams.alg_kind != pooling_avg_exclude_padding) {
- assert("Pooling algorithm kind is not supported\n");
- }
+ DCHECK(bwdParams.alg_kind == pooling_max ||
+ bwdParams.alg_kind == pooling_avg ||
+ bwdParams.alg_kind == pooling_avg_include_padding ||
+ bwdParams.alg_kind == pooling_avg_exclude_padding)
+ << "Pooling algorithm kind is not supported";
context_.alg_kind = bwdParams.alg_kind;
+ // check whether it is 2d or 3d
+ bool is_2d = (bwdParams.dst_dims.size() == 4);
// Create memory desc
context_.diff_src_md.reset(new memory::desc(
{bwdParams.src_dims}, MklDnnType<T>(), memory::format::any));
context_.diff_dst_md.reset(
new memory::desc({bwdParams.dst_dims}, MklDnnType<T>(),
- get_desired_format(bwdParams.dst_dims[1])));
+ get_desired_format(bwdParams.dst_dims[1], is_2d)));
context_.bwd_desc.reset(new pooling_backward::desc(
bwdParams.alg_kind, *context_.diff_src_md, *context_.diff_dst_md,
bwdParams.strides, bwdParams.filter_dims, bwdParams.padding_left,
@@ -151,7 +154,7 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
// store expected primitive format
context_.diff_src_fmt = static_cast<mkldnn::memory::format>(
context_.bwd_pd.get()->diff_src_primitive_desc().desc().data.format);
- context_.diff_dst_fmt = get_desired_format(bwdParams.dst_dims[1]);
+ context_.diff_dst_fmt = get_desired_format(bwdParams.dst_dims[1], is_2d);
// create MKL-DNN internal memory object with dummy data
context_.diff_src_mem.reset(
@@ -165,7 +168,7 @@ void MklPoolingBwdPrimitive<T>::Setup(const MklPoolingParams& bwdParams) {
if (bwdParams.alg_kind == pooling_max) {
auto ws_pd = context_.fwd_pd.get()->workspace_primitive_desc().desc().data;
context_.ws_dims.assign(ws_pd.dims, ws_pd.dims + ws_pd.ndims);
- context_.ws_fmt = get_desired_format(context_.ws_dims[1]);
+ context_.ws_fmt = get_desired_format(context_.ws_dims[1], is_2d);
context_.ws_dt = static_cast<mkldnn::memory::data_type>(ws_pd.data_type);
context_.ws_mem.reset(new memory(
{{{context_.ws_dims}, context_.ws_dt, context_.ws_fmt}, cpu_engine},
@@ -187,7 +190,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
static_cast<void*>(const_cast<T*>(diff_dst_data)));
context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
if (context_.alg_kind == pooling_max) {
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(const_cast<void*>(ws_data));
}
@@ -196,7 +199,7 @@ void MklPoolingBwdPrimitive<T>::Execute(const T* diff_dst_data,
context_.diff_dst_mem->set_data_handle(DummyData);
context_.diff_src_mem->set_data_handle(DummyData);
if (context_.alg_kind == pooling_max) {
- assert(ws_data != nullptr);
+ DCHECK(ws_data != nullptr);
context_.ws_mem->set_data_handle(DummyData);
}
}
@@ -211,13 +214,22 @@ void MklPoolParameters::Init(OpKernelContext* context,
const std::vector<int32>& stride, Padding padding,
TensorFormat data_format,
const TensorShape& tensor_in_shape) {
- // For maxpooling, tensor_in should have 4 dimensions.
- OP_REQUIRES(context, tensor_in_shape.dims() == 4,
- errors::InvalidArgument("tensor_in must be 4-dimensional"));
+ // For maxpooling, tensor_in should have 4 or 5 dimensions.
+ OP_REQUIRES(context,
+ tensor_in_shape.dims() == 4 || tensor_in_shape.dims() == 5,
+ errors::InvalidArgument("tensor_in must be 4 or 5-dimensional"));
depth = GetTensorDim(tensor_in_shape, data_format, 'C');
- tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
- tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
+ if (tensor_in_shape.dims() == 4) {
+ // Pool2D
+ tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
+ tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
+ } else {
+ // Pool3D
+ tensor_in_planes = GetTensorDim(tensor_in_shape, data_format, '0');
+ tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, '1');
+ tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, '2');
+ }
tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');
Init(context, ksize, stride, padding, data_format);
@@ -246,10 +258,20 @@ void MklPoolParameters::Init(OpKernelContext* context,
TensorFormat data_format,
const MklDnnShape* mklInputShape) {
// Get the input sizes
- depth = mklInputShape->GetDimension('C');
- tensor_in_cols = mklInputShape->GetDimension('W');
- tensor_in_rows = mklInputShape->GetDimension('H');
- tensor_in_batch = mklInputShape->GetDimension('N');
+ if (ksize.size() == 4) {
+ // Pool2D
+ depth = mklInputShape->GetDimension('C');
+ tensor_in_cols = mklInputShape->GetDimension('W');
+ tensor_in_rows = mklInputShape->GetDimension('H');
+ tensor_in_batch = mklInputShape->GetDimension('N');
+ } else {
+ // Pool3D
+ depth = mklInputShape->GetDimension3D('C');
+ tensor_in_cols = mklInputShape->GetDimension3D('W');
+ tensor_in_rows = mklInputShape->GetDimension3D('H');
+ tensor_in_planes = mklInputShape->GetDimension3D('D');
+ tensor_in_batch = mklInputShape->GetDimension3D('N');
+ }
Init(context, ksize, stride, padding, data_format);
}
@@ -262,25 +284,58 @@ void MklPoolParameters::Init(OpKernelContext* context,
// Get the data format
this->data_format = data_format;
- // Get the output sizes
- window_rows = GetTensorDim(ksize, data_format, 'H');
- window_cols = GetTensorDim(ksize, data_format, 'W');
- depth_window = GetTensorDim(ksize, data_format, 'C');
-
- // Get the strides
- row_stride = GetTensorDim(stride, data_format, 'H');
- col_stride = GetTensorDim(stride, data_format, 'W');
- depth_stride = GetTensorDim(stride, data_format, 'C');
+ bool is_pool2d = (ksize.size() == 4);
+ if (is_pool2d) {
+ // Pool2D
+ // Get the output sizes
+ window_rows = GetTensorDim(ksize, data_format, 'H');
+ window_cols = GetTensorDim(ksize, data_format, 'W');
+ depth_window = GetTensorDim(ksize, data_format, 'C');
+
+ // Get the strides
+ row_stride = GetTensorDim(stride, data_format, 'H');
+ col_stride = GetTensorDim(stride, data_format, 'W');
+ depth_stride = GetTensorDim(stride, data_format, 'C');
+
+ // We only support 2D pooling across width/height and depthwise
+ // pooling, not a combination.
+ OP_REQUIRES(context,
+ (depth_window == 1 || (window_rows == 1 && window_cols == 1)),
+ errors::Unimplemented(
+ "MaxPooling supports exactly one of pooling across depth "
+ "or pooling across width/height."));
+ } else {
+ // Pool3D
+ // Get the output sizes
+ window_planes = GetTensorDim(ksize, data_format, '0');
+ window_rows = GetTensorDim(ksize, data_format, '1');
+ window_cols = GetTensorDim(ksize, data_format, '2');
+ depth_window = GetTensorDim(ksize, data_format, 'C');
+
+ // Get the strides
+ planes_stride = GetTensorDim(stride, data_format, '0');
+ row_stride = GetTensorDim(stride, data_format, '1');
+ col_stride = GetTensorDim(stride, data_format, '2');
+ depth_stride = GetTensorDim(stride, data_format, 'C');
+
+ // We only support 3D pooling across depth/width/height and depthwise
+ // pooling, not a combination.
+ OP_REQUIRES(context,
+ (depth_window == 1 ||
+ (window_rows == 1 && window_cols == 1 && window_planes == 1)),
+ errors::Unimplemented(
+ "AvgPooling3D supports exactly one of pooling across depth "
+ "or pooling across depth/width/height."));
+ }
- // We only support 2D pooling across width/height and depthwise
- // pooling, not a combination.
- OP_REQUIRES(context,
- (depth_window == 1 || (window_rows == 1 && window_cols == 1)),
- errors::Unimplemented(
- "MaxPooling supports exactly one of pooling across depth "
- "or pooling across width/height."));
+ if (depth_window == 1) { // we are pooling in the D (Pool3D only), H and W
+ if (!is_pool2d) {
+ OP_REQUIRES_OK(
+ context, GetWindowedOutputSizeVerbose(tensor_in_planes, window_planes,
+ planes_stride, padding,
+ &out_planes, &pad_P1, &pad_P2));
+ }
- if (depth_window == 1) { // we are pooling in the H and W
OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
tensor_in_rows, window_rows, row_stride,
padding, &out_height, &pad_top, &pad_bottom));
@@ -290,7 +345,14 @@ void MklPoolParameters::Init(OpKernelContext* context,
padding, &out_width, &pad_left, &pad_right));
#ifndef INTEL_MKL_ML_ONLY
// TF can work with int64, but mkldnn only supports int32
- // Fail if the height or width are greater than MAX_INT
+ // Fail if the depth, height or width are greater than MAX_INT
+ // We check depth only for 3D pooling case
+
+ if (!is_pool2d) {
+ OP_REQUIRES(context,
+ FastBoundsCheck(out_planes, std::numeric_limits<int>::max()),
+ errors::InvalidArgument("output depth/planes is too large"));
+ }
OP_REQUIRES(context,
FastBoundsCheck(out_height, std::numeric_limits<int>::max()),
@@ -299,7 +361,6 @@ void MklPoolParameters::Init(OpKernelContext* context,
OP_REQUIRES(context,
FastBoundsCheck(out_width, std::numeric_limits<int>::max()),
errors::InvalidArgument("output width is too large"));
-
#endif
out_depth = depth; // output will have the same depth as the input
} else { // we are pooling in the depth dimension
diff --git a/tensorflow/core/kernels/mkl_pooling_ops_common.h b/tensorflow/core/kernels/mkl_pooling_ops_common.h
index ec7af5092d..49f799d7ba 100644
--- a/tensorflow/core/kernels/mkl_pooling_ops_common.h
+++ b/tensorflow/core/kernels/mkl_pooling_ops_common.h
@@ -19,6 +19,7 @@ limitations under the License.
#ifdef INTEL_MKL
#include <memory>
#include <vector>
+#include <string>
#include "tensorflow/core/util/mkl_util.h"
#include "tensorflow/core/util/padding.h"
@@ -32,7 +33,7 @@ using mkldnn::stream;
namespace tensorflow {
-#ifndef INTEL_MKL_ML
+#ifndef INTEL_MKL_ML_ONLY
using mkldnn::memory;
using mkldnn::pooling_avg;
@@ -357,22 +358,28 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
struct MklPoolParameters {
int depth;
+ int tensor_in_planes; // Pool3D
int tensor_in_cols;
int tensor_in_rows;
int tensor_in_batch;
+ int window_planes; // Pool3D
int window_rows;
int window_cols;
int depth_window;
+ int planes_stride; // Pool3D
int row_stride;
int col_stride;
int depth_stride;
+ int64 out_planes; // Pool3D
int64 out_height;
int64 out_width;
int out_depth;
+ int64 pad_P1; // Pool3D
+ int64 pad_P2; // Pool3D
int64 pad_left;
int64 pad_right;
int64 pad_top;
@@ -382,18 +389,24 @@ struct MklPoolParameters {
TensorFormat data_format;
MklPoolParameters()
: depth(0),
+ tensor_in_planes(0),
tensor_in_cols(0),
tensor_in_rows(0),
tensor_in_batch(0),
+ window_planes(0),
window_rows(0),
window_cols(0),
depth_window(0),
+ planes_stride(0),
row_stride(0),
col_stride(0),
depth_stride(0),
+ out_planes(0),
out_height(0),
out_width(0),
out_depth(0),
+ pad_P1(0),
+ pad_P2(0),
pad_left(0),
pad_right(0),
pad_top(0),
@@ -433,20 +446,22 @@ class MklPoolingOpBase : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
OP_REQUIRES(context, FormatFromString(data_format, &this->data_format_tf_),
errors::InvalidArgument("Invalid data format"));
- this->data_format_mkldnn_ =
- TFDataFormatToMklDnnDataFormat(this->data_format_tf_);
OP_REQUIRES_OK(context, context->GetAttr("ksize", &this->ksize_));
- OP_REQUIRES(context, this->ksize_.size() == 4,
+ OP_REQUIRES(context, this->ksize_.size() == 4 || this->ksize_.size() == 5,
errors::InvalidArgument("Sliding window ksize field must "
- "specify 4 dimensions"));
+ "specify 4 or 5 dimensions"));
OP_REQUIRES_OK(context, context->GetAttr("strides", &this->stride_));
- OP_REQUIRES(context, this->stride_.size() == 4,
+ OP_REQUIRES(context, this->stride_.size() == 4 || this->stride_.size() == 5,
errors::InvalidArgument("Sliding window strides field must "
- "specify 4 dimensions"));
+ "specify 4 or 5 dimensions"));
OP_REQUIRES_OK(context, context->GetAttr("padding", &this->padding_));
OP_REQUIRES(context, this->ksize_[0] == 1 && this->stride_[0] == 1,
errors::Unimplemented("Pooling is not yet supported on the "
"batch dimension."));
+ bool is_pool2d = (this->ksize_.size() == 4);
+ this->data_format_mkldnn_ =
+ is_pool2d ? TFDataFormatToMklDnnDataFormat(this->data_format_tf_)
+ : TFDataFormatToMklDnn3DDataFormat(this->data_format_tf_);
// We may not get this attribute for this node if it does not go through
// graph rewrite pass. So we do not check for error while retrieving this
@@ -457,17 +472,26 @@ class MklPoolingOpBase : public OpKernel {
protected:
// Calculate output shape of pooling op in MKL-DNN and TensorFlow order.
- // MKL-DNN uses NCHW for output order. But TensorFlow output will be in
- // NHWC or NCHW format depending on data format. Function expects
- // output height and output width to have already been int32
- // bounds-checked
+ // MKL-DNN uses NCHW(Pool2D) or NCDHW(Pool3D) for output order.
+ // But TensorFlow output will be in NHWC/NCHW(Pool2D) or
+ // NDHWC/NCDHW(Pool3D) format depending on data format. Function expects
+ // output height and width to have already been int32 bounds-checked.
void GetOutputDims(const MklPoolParameters& mkl_pool_params,
memory::dims* output_dims_mkl_order) {
- // MKL-DNN always needs output in NCHW format.
- *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
- mkl_pool_params.out_depth,
- static_cast<int>(mkl_pool_params.out_height),
- static_cast<int>(mkl_pool_params.out_width)};
+ if (this->ksize_.size() == 4) {
+ // Pooling2D: MKL-DNN always needs output in NCHW format.
+ *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
+ mkl_pool_params.out_depth,
+ static_cast<int>(mkl_pool_params.out_height),
+ static_cast<int>(mkl_pool_params.out_width)};
+ } else {
+ // Pooling3D: MKL-DNN always needs output in NCDHW format.
+ *output_dims_mkl_order = {mkl_pool_params.tensor_in_batch,
+ mkl_pool_params.out_depth,
+ static_cast<int>(mkl_pool_params.out_planes),
+ static_cast<int>(mkl_pool_params.out_height),
+ static_cast<int>(mkl_pool_params.out_width)};
+ }
}
void InitMklPoolParameters(OpKernelContext* context,
@@ -485,14 +509,34 @@ class MklPoolingOpBase : public OpKernel {
void PoolParamsToDims(const MklPoolParameters* pool_params,
memory::dims* filter_dims, memory::dims* strides,
- memory::dims* padding_left,
- memory::dims* padding_right) {
- *filter_dims = {pool_params->window_rows, pool_params->window_cols};
- *strides = {pool_params->row_stride, pool_params->col_stride};
- *padding_left = {static_cast<int>(pool_params->pad_top),
- static_cast<int>(pool_params->pad_left)};
- *padding_right = {static_cast<int>(pool_params->pad_bottom),
- static_cast<int>(pool_params->pad_right)};
+ memory::dims* padding_left, memory::dims* padding_right,
+ bool is_pool2d) {
+ if (is_pool2d) {
+ // Pool2D
+ *filter_dims =
+ memory::dims({pool_params->window_rows, pool_params->window_cols});
+ *strides =
+ memory::dims({pool_params->row_stride, pool_params->col_stride});
+ *padding_left = memory::dims({static_cast<int>(pool_params->pad_top),
+ static_cast<int>(pool_params->pad_left)});
+ *padding_right = memory::dims({static_cast<int>(pool_params->pad_bottom),
+ static_cast<int>(pool_params->pad_right)});
+ } else {
+ // Pool3D
+ *filter_dims =
+ memory::dims({pool_params->window_planes, pool_params->window_rows,
+ pool_params->window_cols});
+ *strides =
+ memory::dims({pool_params->planes_stride, pool_params->row_stride,
+ pool_params->col_stride});
+
+ *padding_left = memory::dims({static_cast<int>(pool_params->pad_P1),
+ static_cast<int>(pool_params->pad_top),
+ static_cast<int>(pool_params->pad_left)});
+ *padding_right = memory::dims({static_cast<int>(pool_params->pad_P2),
+ static_cast<int>(pool_params->pad_bottom),
+ static_cast<int>(pool_params->pad_right)});
+ }
}
void AllocateEmptyOutputTensor(OpKernelContext* context,
@@ -556,12 +600,27 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
TensorShape input_tensor_shape = input_tensor.shape();
if (input_tensor.NumElements() != 0) {
memory::desc input_md =
- input_mkl_shape.IsMklTensor()
- ? input_mkl_shape.GetMklLayout()
- : memory::desc(TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
+ input_mkl_shape.IsMklTensor()
+ ? input_mkl_shape.GetMklLayout()
+ : memory::desc(
+ (this->ksize_.size() == 4)
+ ? TFShapeToMklDnnDimsInNCHW(input_tensor_shape,
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(input_tensor_shape,
this->data_format_tf_),
- MklDnnType<T>(), this->data_format_mkldnn_);
+ MklDnnType<T>(), this->data_format_mkldnn_);
dnn_data_input->SetUsrMem(input_md, &input_tensor);
+
+ if (this->ksize_.size() == 5) {
+ // Pool3D
+ std::vector<int> mkldnn_sizes(5, -1);
+ mkldnn_sizes[MklDnnDims3D::Dim3d_N] = input_md.data.dims[0];
+ mkldnn_sizes[MklDnnDims3D::Dim3d_C] = input_md.data.dims[1];
+ mkldnn_sizes[MklDnnDims3D::Dim3d_D] = input_md.data.dims[2];
+ mkldnn_sizes[MklDnnDims3D::Dim3d_H] = input_md.data.dims[3];
+ mkldnn_sizes[MklDnnDims3D::Dim3d_W] = input_md.data.dims[4];
+ dnn_data_input->SetOpMemDesc(mkldnn_sizes, this->data_format_mkldnn_);
+ }
}
this->InitMklPoolParameters(context, pool_params, input_mkl_shape,
input_tensor_shape);
@@ -593,12 +652,13 @@ class MklPoolingForwardOpBase : public MklPoolingOpBase<T> {
void SanityCheckInput(OpKernelContext* context, const Tensor& input_tensor,
const MklDnnShape& input_mkl_shape) {
if (!input_mkl_shape.IsMklTensor()) {
- OP_REQUIRES(context, input_tensor.dims() == 4,
- errors::InvalidArgument("Input must be 4-dimensional"));
+ OP_REQUIRES(context, input_tensor.dims() == 4 || input_tensor.dims() == 5,
+ errors::InvalidArgument("Input must be 4 or 5-dimensional"));
} else {
- OP_REQUIRES(context, input_mkl_shape.GetDimension() == 4,
+ OP_REQUIRES(context, input_mkl_shape.GetDimension() == 4 ||
+ input_mkl_shape.GetDimension() == 5,
errors::InvalidArgument("Input shape must be "
- "4-dimensional"));
+ "4 or 5-dimensional"));
}
}
// .Input("value: T")
@@ -649,8 +709,12 @@ class MklPoolingBackwardOpBase : public MklPoolingOpBase<T> {
input_gradient_mkl_shape.IsMklTensor()
? input_gradient_mkl_shape.GetMklLayout()
: memory::desc(
- TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
- this->data_format_tf_),
+ (this->ksize_.size() == 4)
+ ? TFShapeToMklDnnDimsInNCHW(input_gradient_tensor.shape(),
+ this->data_format_tf_)
+ : TFShapeToMklDnnDimsInNCDHW(
+ input_gradient_tensor.shape(),
+ this->data_format_tf_),
MklDnnType<T>(), this->data_format_mkldnn_);
input_gradient_dnn_data->SetUsrMem(original_input_grad_md,
diff --git a/tensorflow/core/kernels/mkl_relu_op.cc b/tensorflow/core/kernels/mkl_relu_op.cc
index 05034894e5..84385356e1 100644
--- a/tensorflow/core/kernels/mkl_relu_op.cc
+++ b/tensorflow/core/kernels/mkl_relu_op.cc
@@ -30,10 +30,12 @@ using mkldnn::algorithm;
using mkldnn::eltwise_elu;
using mkldnn::eltwise_relu;
using mkldnn::eltwise_tanh;
+using mkldnn::memory;
using mkldnn::prop_kind;
using mkldnn::relu_backward;
using mkldnn::relu_forward;
using mkldnn::stream;
#else
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
@@ -42,6 +44,415 @@ using mkldnn::stream;
namespace tensorflow {
+#ifndef INTEL_MKL_ML_ONLY
+
+template <typename T>
+class MklEltwiseFwdParams {
+ public:
+ memory::dims src_dims; // check if this is needed
+ memory::desc src_md;
+ algorithm alg_kind;
+ T alpha;
+ T beta;
+
+ MklEltwiseFwdParams(memory::dims src_dims, memory::desc src_md,
+ algorithm alg_kind, T alpha, T beta)
+ : src_dims(src_dims),
+ src_md(src_md),
+ alg_kind(alg_kind),
+ alpha(alpha),
+ beta(beta) {}
+};
+
+template <typename T>
+class MklEltwiseFwdPrimitive : public MklPrimitive {
+ public:
+ explicit MklEltwiseFwdPrimitive(const MklEltwiseFwdParams<T>& fwdParams)
+ : cpu_engine_(engine::cpu, 0) {
+ // store expected format
+ context_.src_fmt =
+ static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
+ context_.fwd_stream.reset(new stream(stream::kind::eager));
+
+ // create eltwise primitive
+ if (context_.eltwise_fwd == nullptr) {
+ Setup(fwdParams);
+ }
+ }
+
+ ~MklEltwiseFwdPrimitive() {}
+
+ // Eltwise forward execute
+ // src_data: input data buffer of src
+ // dst_data: output data buffer of dst
+ void Execute(const T* src_data, T* dst_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.dst_mem->set_data_handle(static_cast<void*>(dst_data));
+ context_.fwd_stream->submit(context_.fwd_primitives);
+
+ // after execution, set data handle back
+ context_.src_mem->set_data_handle(DummyData);
+ context_.dst_mem->set_data_handle(DummyData);
+ }
+
+ std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> GetEltwiseFwdPd() {
+ return context_.fwd_pd;
+ }
+
+ memory::format GetSrcMemoryFormat() { return context_.src_fmt; }
+
+ private:
+ // Primitive reuse context for eltwise Fwd ops: Relu, Elu, Tanh
+ struct EltwiseFwdContext {
+ // expected memory format for this primitive instance
+ mkldnn::memory::format src_fmt;
+
+ // MKLDNN memory
+ std::shared_ptr<memory> src_mem;
+ std::shared_ptr<memory> dst_mem;
+
+    // desc & primitive desc
+ std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
+ std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd;
+
+ // memory desc
+ std::shared_ptr<memory::desc> src_md;
+ std::shared_ptr<memory::desc> dst_md;
+
+ // memory primitive desc
+ std::shared_ptr<memory::primitive_desc> src_mpd;
+
+ // Eltwise primitive
+ std::shared_ptr<mkldnn::primitive> eltwise_fwd;
+
+ std::shared_ptr<stream> fwd_stream;
+ std::vector<mkldnn::primitive> fwd_primitives;
+
+ EltwiseFwdContext()
+ : src_fmt(memory::format::any),
+ src_mem(nullptr),
+ dst_mem(nullptr),
+ fwd_desc(nullptr),
+ fwd_pd(nullptr),
+ src_md(nullptr),
+ dst_md(nullptr),
+ src_mpd(nullptr),
+ eltwise_fwd(nullptr),
+ fwd_stream(nullptr) {}
+ };
+
+ // Eltwise forward primitive setup
+ void Setup(const MklEltwiseFwdParams<T>& fwdParams) {
+ // create memory descriptors for eltwise data with specified format
+ context_.src_md.reset(new memory::desc(fwdParams.src_md.data));
+ context_.src_mpd.reset(
+ new memory::primitive_desc(*context_.src_md, cpu_engine_));
+
+    // create an eltwise forward descriptor and primitive descriptor
+ context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
+ prop_kind::forward, fwdParams.alg_kind, *context_.src_md,
+ fwdParams.alpha, fwdParams.beta));
+ context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc(
+ *context_.fwd_desc, cpu_engine_));
+
+ // create memory primitive based on dummy data
+ context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
+ context_.dst_mem.reset(
+ new memory(context_.fwd_pd.get()->dst_primitive_desc(), DummyData));
+
+ // create eltwise primitive and add it to net
+ context_.eltwise_fwd.reset(new mkldnn::eltwise_forward(
+ *context_.fwd_pd, *context_.src_mem, *context_.dst_mem));
+
+ context_.fwd_primitives.push_back(*context_.eltwise_fwd);
+ }
+
+ struct EltwiseFwdContext context_;
+ engine cpu_engine_;
+};
+
+template <typename T>
+class MklEltwiseFwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ public:
+ static MklEltwiseFwdPrimitive<T>* Get(
+ const MklEltwiseFwdParams<T>& fwdParams) {
+ MklEltwiseFwdPrimitive<T>* eltwise_forward = nullptr;
+
+ auto src_fmt =
+ static_cast<mkldnn::memory::format>(fwdParams.src_md.data.format);
+
+    // Get an eltwise fwd primitive from the cached pool
+ eltwise_forward = static_cast<MklEltwiseFwdPrimitive<T>*>(
+ MklEltwiseFwdPrimitiveFactory<T>::GetInstance().GetEltwiseFwd(fwdParams,
+ src_fmt));
+ if (eltwise_forward == nullptr) {
+ eltwise_forward = new MklEltwiseFwdPrimitive<T>(fwdParams);
+ MklEltwiseFwdPrimitiveFactory<T>::GetInstance().SetEltwiseFwd(
+ fwdParams, src_fmt, eltwise_forward);
+ }
+ return eltwise_forward;
+ }
+
+ static MklEltwiseFwdPrimitiveFactory& GetInstance() {
+ static MklEltwiseFwdPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ private:
+ MklEltwiseFwdPrimitiveFactory() {}
+ ~MklEltwiseFwdPrimitiveFactory() {}
+
+ static string CreateKey(const MklEltwiseFwdParams<T>& fwdParams,
+ memory::format src_fmt) {
+ string prefix = "eltwise_fwd";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(fwdParams.src_dims);
+ key_creator.AddAsKey<int>(static_cast<int>(fwdParams.alg_kind));
+ key_creator.AddAsKey<float>(static_cast<float>(fwdParams.alpha));
+ key_creator.AddAsKey<float>(static_cast<float>(fwdParams.beta));
+ key_creator.AddAsKey<int>(static_cast<int>(src_fmt));
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
+ memory::format src_fmt) {
+ string key = CreateKey(fwdParams, src_fmt);
+ return this->GetOp(key);
+ }
+
+ void SetEltwiseFwd(const MklEltwiseFwdParams<T>& fwdParams,
+ memory::format src_fmt, MklPrimitive* op) {
+ string key = CreateKey(fwdParams, src_fmt);
+ this->SetOp(key, op);
+ }
+};
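+
+// Typical usage (sketch only; this mirrors what the Relu/Elu/Tanh ops below
+// do, with src_dims, src_md, src_data and dst_data prepared by the caller):
+//
+//   MklEltwiseFwdParams<float> fwdParams(src_dims, src_md, eltwise_relu,
+//                                        0.0f /* alpha */, 0.0f /* beta */);
+//   MklEltwiseFwdPrimitive<float>* eltwise_fwd =
+//       MklEltwiseFwdPrimitiveFactory<float>::Get(fwdParams);
+//   eltwise_fwd->Execute(src_data, dst_data);
+//
+// Get() returns a cached primitive when one with a matching key (dims,
+// algorithm kind, alpha, beta, source format) exists, and creates and caches
+// a new one otherwise.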
+
+template <typename T>
+class MklEltwiseBwdParams {
+ public:
+ memory::dims src_dims;
+ memory::desc common_md;
+ algorithm alg_kind;
+ T alpha;
+ T beta;
+
+ MklEltwiseBwdParams(const memory::dims& src_dims,
+ const memory::desc& common_md, algorithm alg_kind,
+ T alpha, T beta)
+ : src_dims(src_dims),
+ common_md(common_md),
+ alg_kind(alg_kind),
+ alpha(alpha),
+ beta(beta) {}
+};
+
+template <typename T>
+class MklEltwiseBwdPrimitive : public MklPrimitive {
+ public:
+ explicit MklEltwiseBwdPrimitive(const MklEltwiseBwdParams<T>& bwdParams)
+ : cpu_engine_(engine::cpu, 0) {
+ context_.src_fmt =
+ static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
+ context_.diff_dst_fmt =
+ static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
+ context_.bwd_stream.reset(new stream(stream::kind::eager));
+ // create eltwise primitive
+ if (context_.eltwise_bwd == nullptr) {
+ Setup(bwdParams);
+ }
+ }
+
+ ~MklEltwiseBwdPrimitive() {}
+
+ // Eltwise backward execute
+ // src_data: input data buffer of src
+ // diff_dst_data: input data buffer of diff_dst
+ // diff_src_data: output data buffer of diff_src
+ void Execute(const T* src_data, const T* diff_dst_data, T* diff_src_data) {
+ context_.src_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(src_data)));
+ context_.diff_dst_mem->set_data_handle(
+ static_cast<void*>(const_cast<T*>(diff_dst_data)));
+ context_.diff_src_mem->set_data_handle(static_cast<void*>(diff_src_data));
+ context_.bwd_stream->submit(context_.bwd_primitives);
+
+ // after execution, set data handle back
+ context_.src_mem->set_data_handle(DummyData);
+ context_.diff_dst_mem->set_data_handle(DummyData);
+ context_.diff_src_mem->set_data_handle(DummyData);
+ }
+
+ std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> GetEltwiseBwdPd() {
+ return context_.bwd_pd;
+ }
+
+ memory::format GetSrcMemoryFormat() { return context_.src_fmt; }
+
+ memory::format GetDiffDstMemoryFormat() { return context_.diff_dst_fmt; }
+
+ private:
+ // Primitive reuse context for eltwise Bwd ops: Relu, Elu, Tanh
+ struct EltwiseBwdContext {
+ // expected memory format for this primitive instance
+ memory::format src_fmt;
+ memory::format diff_dst_fmt;
+
+ // MKLDNN memory
+ std::shared_ptr<memory> src_mem;
+ std::shared_ptr<memory> diff_dst_mem;
+ std::shared_ptr<memory> diff_src_mem;
+
+    // desc & primitive desc
+ std::shared_ptr<mkldnn::eltwise_backward::desc> bwd_desc;
+
+ // memory desc
+ std::shared_ptr<memory::desc> src_md;
+ std::shared_ptr<memory::desc> diff_dst_md;
+ std::shared_ptr<memory::desc> common_md;
+
+ // memory primitive desc
+ std::shared_ptr<memory::primitive_desc> src_mpd;
+ std::shared_ptr<memory::primitive_desc> diff_dst_mpd;
+
+ // fwd primitive desc
+ std::shared_ptr<mkldnn::eltwise_forward::desc> fwd_desc;
+ std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> fwd_pd;
+ std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> bwd_pd;
+
+ // Eltwise primitive
+ std::shared_ptr<mkldnn::primitive> eltwise_bwd;
+
+ std::shared_ptr<stream> bwd_stream;
+ std::vector<mkldnn::primitive> bwd_primitives;
+
+ EltwiseBwdContext()
+ : src_fmt(memory::format::any),
+ diff_dst_fmt(memory::format::any),
+ src_mem(nullptr),
+ diff_dst_mem(nullptr),
+ diff_src_mem(nullptr),
+ src_md(nullptr),
+ diff_dst_md(nullptr),
+ common_md(nullptr),
+ src_mpd(nullptr),
+ diff_dst_mpd(nullptr),
+ fwd_desc(nullptr),
+ fwd_pd(nullptr),
+ bwd_pd(nullptr),
+ eltwise_bwd(nullptr),
+ bwd_stream(nullptr) {}
+ };
+
+ // Eltwise backward primitive setup
+ void Setup(const MklEltwiseBwdParams<T>& bwdParams) {
+ // create memory descriptors for eltwise data w/ no specified format
+ context_.src_md.reset(new memory::desc(bwdParams.common_md.data));
+ context_.diff_dst_md.reset(new memory::desc(bwdParams.common_md.data));
+
+ context_.src_mpd.reset(
+ new memory::primitive_desc(*context_.src_md, cpu_engine_));
+ context_.diff_dst_mpd.reset(
+ new memory::primitive_desc(*context_.diff_dst_md, cpu_engine_));
+
+ // create forward eltwise primitive
+ context_.fwd_desc.reset(new mkldnn::eltwise_forward::desc(
+ prop_kind::forward_training, bwdParams.alg_kind, *context_.src_md,
+ bwdParams.alpha, bwdParams.beta));
+ context_.fwd_pd.reset(new mkldnn::eltwise_forward::primitive_desc(
+ *context_.fwd_desc, cpu_engine_));
+ context_.bwd_desc.reset(new mkldnn::eltwise_backward::desc(
+ bwdParams.alg_kind, *context_.diff_dst_md, *context_.src_md,
+ bwdParams.alpha, bwdParams.beta));
+ context_.bwd_pd.reset(new mkldnn::eltwise_backward::primitive_desc(
+ *context_.bwd_desc, cpu_engine_, *context_.fwd_pd));
+
+ // create memory primitive based on dummy data
+ context_.src_mem.reset(new memory(*context_.src_mpd, DummyData));
+ context_.diff_dst_mem.reset(new memory(*context_.diff_dst_mpd, DummyData));
+ context_.diff_src_mem.reset(new memory(
+ context_.bwd_pd.get()->diff_src_primitive_desc(), DummyData));
+
+ // create eltwise primitive and add it to net
+ context_.eltwise_bwd.reset(new mkldnn::eltwise_backward(
+ *context_.bwd_pd, *context_.src_mem, *context_.diff_dst_mem,
+ *context_.diff_src_mem));
+
+ context_.bwd_primitives.push_back(*context_.eltwise_bwd);
+ }
+
+ struct EltwiseBwdContext context_;
+ engine cpu_engine_;
+};
+
+template <typename T>
+class MklEltwiseBwdPrimitiveFactory : public MklPrimitiveFactory<T> {
+ private:
+ MklEltwiseBwdPrimitiveFactory() {}
+ ~MklEltwiseBwdPrimitiveFactory() {}
+
+ public:
+ static MklEltwiseBwdPrimitive<T>* Get(
+ const MklEltwiseBwdParams<T>& bwdParams) {
+ MklEltwiseBwdPrimitive<T>* eltwise_backward = nullptr;
+
+ auto src_fmt =
+ static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
+ auto diff_dst_fmt =
+ static_cast<mkldnn::memory::format>(bwdParams.common_md.data.format);
+
+ // try to find a suitable one in pool
+ eltwise_backward = static_cast<MklEltwiseBwdPrimitive<T>*>(
+ MklEltwiseBwdPrimitiveFactory<T>::GetInstance().GetEltwiseBwd(
+ bwdParams, src_fmt, diff_dst_fmt));
+
+ if (eltwise_backward == nullptr) {
+ eltwise_backward = new MklEltwiseBwdPrimitive<T>(bwdParams);
+ MklEltwiseBwdPrimitiveFactory<T>::GetInstance().SetEltwiseBwd(
+ bwdParams, src_fmt, diff_dst_fmt, eltwise_backward);
+ }
+ return eltwise_backward;
+ }
+
+ static MklEltwiseBwdPrimitiveFactory& GetInstance() {
+ static MklEltwiseBwdPrimitiveFactory instance_;
+ return instance_;
+ }
+
+ private:
+ static string CreateKey(const MklEltwiseBwdParams<T>& bwdParams,
+ const memory::format& src_fmt,
+ const memory::format& diff_dst_fmt) {
+ string prefix = "eltwise_bwd";
+ FactoryKeyCreator key_creator;
+ key_creator.AddAsKey(prefix);
+ key_creator.AddAsKey(bwdParams.src_dims);
+ key_creator.AddAsKey(static_cast<int>(bwdParams.alg_kind));
+ key_creator.AddAsKey(static_cast<float>(bwdParams.alpha));
+ key_creator.AddAsKey(static_cast<float>(bwdParams.beta));
+ key_creator.AddAsKey(static_cast<int>(src_fmt));
+ key_creator.AddAsKey(static_cast<int>(diff_dst_fmt));
+ return key_creator.GetKey();
+ }
+
+ MklPrimitive* GetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
+ const memory::format& src_fmt,
+ const memory::format& diff_dst_fmt) {
+ string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
+ return this->GetOp(key);
+ }
+
+ void SetEltwiseBwd(const MklEltwiseBwdParams<T>& bwdParams,
+ const memory::format& src_fmt,
+ const memory::format& diff_dst_fmt, MklPrimitive* op) {
+ string key = CreateKey(bwdParams, src_fmt, diff_dst_fmt);
+ this->SetOp(key, op);
+ }
+};
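+
+// Usage is analogous to the forward factory (sketch only): build an
+// MklEltwiseBwdParams<T> from src_dims and the common memory descriptor
+// shared by src and diff_dst, fetch a primitive with
+// MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams), and call
+// Execute(src_data, diff_dst_data, diff_src_data), as MklReluGradOpBase
+// does below.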
+
+#endif
+
typedef Eigen::ThreadPoolDevice CPUDevice;
struct MklReluHelpers {
@@ -375,55 +786,63 @@ class MklReluOpBase : public OpKernel {
~MklReluOpBase() {}
explicit MklReluOpBase(OpKernelConstruction* context) : OpKernel(context) {}
-
virtual void Compute_Scalar(OpKernelContext* context) = 0;
void Compute(OpKernelContext* context) override {
try {
- auto cpu_engine = engine(engine::cpu, 0);
const size_t src_index = 0; // index of src input tensor
const size_t dst_index = 0; // index of dst output tensor
const Tensor& src_tensor = MklGetInput(context, src_index);
MklDnnShape dnn_shape_src;
GetMklShape(context, src_index, &dnn_shape_src);
- Tensor* dst_tensor = nullptr;
if (src_tensor.dims() == 0) {
- Compute_Scalar(context); // scalar case doesn't use in-place operation
+ Compute_Scalar(context);
return;
}
- // Create relu primitive.
- MklDnnData<T> src(&cpu_engine);
- MklDnnData<T> dst(&cpu_engine);
-
// Set DNN primitive - src
+ MklDnnData<T> src(&cpu_engine);
+ memory::dims src_dims;
memory::desc src_md({}, memory::data_undef, memory::format_undef);
if (dnn_shape_src.IsMklTensor()) {
src_md = dnn_shape_src.GetMklLayout();
+ src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
} else {
- auto src_dims = TFShapeToMklDnnDims(src_tensor.shape());
+ src_dims = TFShapeToMklDnnDims(src_tensor.shape());
auto src_strides = CalculateTFStrides(src_dims);
// Create blocked memory descriptor
src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
}
- src.SetUsrMem(src_md, &src_tensor);
T alpha = 0, beta = 0;
- std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
- auto relu_fwd_desc = relu_forward::desc(
- prop_kind::forward_training,
- // Operator memory descriptor is same as user memory descriptor.
- alg_kind, src.GetUsrMemDesc(), alpha, beta);
- relu_fwd_pd.reset(
- new relu_forward::primitive_desc(relu_fwd_desc, cpu_engine));
-
- // allocate dst tensor
+
+      // get an eltwise fwd primitive from the primitive pool
+ MklEltwiseFwdParams<T> fwdParams(src_dims, src_md, alg_kind, alpha, beta);
+ MklEltwiseFwdPrimitive<T>* eltwise_fwd =
+ MklEltwiseFwdPrimitiveFactory<T>::Get(fwdParams);
+
+      // prepare for execution
+ const T* src_data = src_tensor.flat<T>().data();
+      // check whether src needs to be reordered
+ if (src_md.data.format != eltwise_fwd->GetSrcMemoryFormat()) {
+ src.SetUsrMem(src_md, &src_tensor);
+ auto src_target_pd = memory::primitive_desc(
+ {{src_dims}, MklDnnType<T>(), eltwise_fwd->GetSrcMemoryFormat()},
+ cpu_engine);
+ src.CheckReorderToOpMem(src_target_pd);
+ src_data = const_cast<T*>(
+ reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
+ }
+
+      // allocate dst tensor; give it MKL-DNN layout when src is an MKL tensor
+ std::shared_ptr<mkldnn::eltwise_forward::primitive_desc> eltwise_fwd_pd =
+ eltwise_fwd->GetEltwiseFwdPd();
MklDnnShape dnn_shape_dst;
TensorShape tf_shape_dst;
if (dnn_shape_src.IsMklTensor()) {
dnn_shape_dst.SetMklTensor(true);
- auto dst_pd = relu_fwd_pd->dst_primitive_desc();
+ auto dst_pd = eltwise_fwd_pd->dst_primitive_desc();
dnn_shape_dst.SetMklLayout(&dst_pd);
dnn_shape_dst.SetElemType(MklDnnType<T>());
dnn_shape_dst.SetTfLayout(dnn_shape_src.GetDimension(),
@@ -434,34 +853,32 @@ class MklReluOpBase : public OpKernel {
dnn_shape_dst.SetMklTensor(false);
tf_shape_dst = src_tensor.shape();
}
-
- // Allocate output and MklDnnShape tensors separately for possible
- // in-place operation
+
+ Tensor* dst_tensor = nullptr;
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
{static_cast<const int>(src_index)},
static_cast<const int>(dst_index),
tf_shape_dst, &dst_tensor));
AllocateOutputSetMklShape(context, dst_index, dnn_shape_dst);
- // Destination memory descriptor is same as source memory descriptor.
- auto &dst_md = src_md;
- dst.SetUsrMem(dst_md, dst_tensor);
+ T* dst_data = dst_tensor->flat<T>().data();
- // execute net
- std::vector<primitive> net;
- auto relu_fwd =
- relu_forward(*relu_fwd_pd, src.GetOpMem(), dst.GetOpMem());
- net.push_back(relu_fwd);
- stream(stream::kind::eager).submit(net).wait();
+ // execute eltwise
+ eltwise_fwd->Execute(src_data, dst_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
- ", message: " + string(e.message) + ", in file " +
- string(__FILE__) + ":" + std::to_string(__LINE__);
- OP_REQUIRES_OK(
- context,
- errors::Aborted("Operation received an exception:", error_msg));
+ ", message: " + string(e.message) +
+ ", in file " + string(__FILE__) + ":" +
+ std::to_string(__LINE__);
+ OP_REQUIRES_OK(context,
+ errors::Aborted("Operation received an exception:",
+ error_msg));
}
}
+
+ private:
+ engine cpu_engine = engine(engine::cpu, 0);
+ std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
};
template <typename Device, typename T, algorithm alg_kind>
@@ -470,16 +887,15 @@ class MklReluGradOpBase : public OpKernel {
~MklReluGradOpBase() {}
explicit MklReluGradOpBase(OpKernelConstruction* context)
- : OpKernel(context) {}
+ : OpKernel(context) {
+ }
virtual void Compute_Scalar(OpKernelContext* context) = 0;
void Compute(OpKernelContext* context) {
try {
- auto cpu_engine = engine(engine::cpu, 0);
MklDnnData<T> src(&cpu_engine);
MklDnnData<T> diff_dst(&cpu_engine);
- MklDnnData<T> diff_src(&cpu_engine);
const size_t diff_dst_index = 0; // index of diff_dst input tensor
const size_t src_index = 1; // index of src input tensor
@@ -495,37 +911,23 @@ class MklReluGradOpBase : public OpKernel {
int src_dims_size = src_tensor.dims();
if (src_dims_size == 0) {
- Compute_Scalar(context); // scalar case doesn't use in-place operation
+ Compute_Scalar(context);
return;
}
- // Set DNN primitives for src & diff_dst
+      // get an eltwise bwd primitive from the primitive pool
+ memory::dims src_dims = {};
memory::desc src_md({}, memory::data_undef, memory::format_undef);
memory::desc diff_dst_md({}, memory::data_undef, memory::format_undef);
-
- // For creating Sum primitive, we need to ensure that all inputs are in
- // same format. What that means is if we have a mixed input case - where
- // one input is in Tensorflow format and one input is in MKL format -,
- // then we need to ensure that all inputs are in same format for
- // primitive construction. For performance reason, we say that all inputs
- // are in MKL format in such case, and insert reorder for input that is
- // in Tensorflow format into MKL format. On the other hand, if both the
- // inputs are in MKL format or both are in Tensorflow format, then we
- // dont need reorder.
if (!dnn_shape_src.IsMklTensor() && !dnn_shape_diff_dst.IsMklTensor()) {
- // If both the inputs are in Tensorflow format, we create blocked memory
- // descriptor.
- auto src_dims = TFShapeToMklDnnDims(src_tensor.shape());
+ src_dims = TFShapeToMklDnnDims(src_tensor.shape());
auto src_strides = CalculateTFStrides(src_dims);
src_md = MklDnnData<T>::CreateBlockedMemDesc(src_dims, src_strides);
diff_dst_md = src_md;
} else if (dnn_shape_src.IsMklTensor() &&
!dnn_shape_diff_dst.IsMklTensor()) {
- // If one input is in MKL format and other is in Tensorflow, then
- // create respective descriptors describing the actual case. For input
- // in Mkl format, we just get Mkl layout from MklDnnShape. For input in
- // Tensorflow format, we create memory descriptor using data format.
src_md = dnn_shape_src.GetMklLayout();
+ src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
memory::format src_mkl_data_format = dnn_shape_src.GetTfDataFormat();
auto src_tf_data_format =
@@ -536,26 +938,27 @@ class MklReluGradOpBase : public OpKernel {
memory::desc(diff_dst_dims, MklDnnType<T>(), src_mkl_data_format);
} else if (!dnn_shape_src.IsMklTensor() &&
dnn_shape_diff_dst.IsMklTensor()) {
- // Same comment as above.
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
memory::format diff_dst_mkl_data_format =
dnn_shape_diff_dst.GetTfDataFormat();
auto diff_dst_tf_data_format =
MklDnnDataFormatToTFDataFormat(diff_dst_mkl_data_format);
- auto src_dims = TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
- diff_dst_tf_data_format);
+
+ src_dims = (src_tensor.dims() == 4)
+ ? TFShapeToMklDnnDimsInNCHW(src_tensor.shape(),
+ diff_dst_tf_data_format)
+ : TFShapeToMklDnnDimsInNCDHW(src_tensor.shape(),
+ diff_dst_tf_data_format);
src_md =
memory::desc(src_dims, MklDnnType<T>(), diff_dst_mkl_data_format);
} else {
- // If both the inputs are in MKL format, we use Mkl layout of the input
- // tensors.
src_md = dnn_shape_src.GetMklLayout();
diff_dst_md = dnn_shape_diff_dst.GetMklLayout();
+ src_dims = dnn_shape_src.GetSizesAsMklDnnDims();
}
- src.SetUsrMem(src_md, &src_tensor);
- diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+ T alpha = 0, beta = 0;
// As per comment above, we tell MKLDNN that both the inputs are in same
// format. So we set common memory descriptor in MKL format, if any of the
@@ -570,24 +973,38 @@ class MklReluGradOpBase : public OpKernel {
common_md = src_md;
}
- T alpha = 0, beta = 0;
- std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
- auto relu_fwd_desc = relu_forward::desc(prop_kind::forward_training,
- alg_kind, src_md, alpha, beta);
- relu_fwd_pd.reset(
- new relu_forward::primitive_desc(relu_fwd_desc, cpu_engine));
- auto relu_bwd_desc =
- relu_backward::desc(alg_kind, common_md, common_md, alpha, beta);
- auto relu_bwd_pd = relu_backward::primitive_desc(
- relu_bwd_desc, cpu_engine, *relu_fwd_pd);
+ MklEltwiseBwdParams<T> bwdParams(src_dims, common_md, alg_kind, alpha,
+ beta);
+ MklEltwiseBwdPrimitive<T>* eltwise_bwd =
+ MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams);
+ auto eltwise_bwd_pd = eltwise_bwd->GetEltwiseBwdPd();
+
+ // Check whether src / diff_dst need to be reordered.
+ const T* src_data = src_tensor.flat<T>().data();
+ if (src_md.data.format != eltwise_bwd->GetSrcMemoryFormat()) {
+ src.SetUsrMem(src_md, &src_tensor);
+ src.CheckReorderToOpMem(
+ eltwise_bwd_pd.get()->diff_src_primitive_desc());
+ src_data = const_cast<T*>(
+ reinterpret_cast<T*>(src.GetOpMem().get_data_handle()));
+ }
+
+ const T* diff_dst_data = diff_dst_tensor.flat<T>().data();
+ if (diff_dst_md.data.format != eltwise_bwd->GetDiffDstMemoryFormat()) {
+ diff_dst.SetUsrMem(diff_dst_md, &diff_dst_tensor);
+ diff_dst.CheckReorderToOpMem(
+ eltwise_bwd_pd.get()->diff_src_primitive_desc());
+ diff_dst_data = const_cast<T*>(
+ reinterpret_cast<T*>(diff_dst.GetOpMem().get_data_handle()));
+ }
// allocate diff_src tensor
MklDnnShape dnn_shape_diff_src;
TensorShape tf_shape_diff_src;
if (dnn_shape_src.IsMklTensor() ||
dnn_shape_diff_dst.IsMklTensor()) {
+ auto diff_src_pd = eltwise_bwd_pd->diff_src_primitive_desc();
dnn_shape_diff_src.SetMklTensor(true);
- auto diff_src_pd = relu_bwd_pd.diff_src_primitive_desc();
dnn_shape_diff_src.SetMklLayout(&diff_src_pd);
dnn_shape_diff_src.SetElemType(MklDnnType<T>());
if (dnn_shape_src.IsMklTensor()) {
@@ -602,25 +1019,18 @@ class MklReluGradOpBase : public OpKernel {
tf_shape_diff_src.AddDim(diff_src_pd.get_size() / sizeof(T));
} else {
dnn_shape_diff_src.SetMklTensor(false);
- // both src and diff_dst are TensorFlow layout,
- // so it is ok to get TensorFlow shape.
tf_shape_diff_src = src_tensor.shape();
}
- // Allocate diff_src and MklDnnShape tensors separately for possible
- // in-place operation
OP_REQUIRES_OK(context, context->forward_input_or_allocate_output(
- {static_cast<const int>(diff_dst_index)},
- static_cast<const int>(diff_src_index),
- tf_shape_diff_src,
- &diff_src_tensor));
+ {diff_dst_index}, diff_src_index,
+ tf_shape_diff_src, &diff_src_tensor));
AllocateOutputSetMklShape(context, diff_src_index, dnn_shape_diff_src);
- // diff_src memory descriptor is same as memory descriptor for both
- // inputs.
- diff_src.SetUsrMem(common_md, diff_src_tensor);
+ T* diff_src_data = diff_src_tensor->flat<T>().data();
- PrepareAndExecuteNet(relu_bwd_pd, &src, &diff_src, &diff_dst);
+ // Execute the eltwise backward primitive.
+ eltwise_bwd->Execute(src_data, diff_dst_data, diff_src_data);
} catch (mkldnn::error& e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) + ", in file " +
@@ -631,22 +1041,9 @@ class MklReluGradOpBase : public OpKernel {
}
}
- void PrepareAndExecuteNet(const relu_backward::primitive_desc& relu_prim_desc,
- MklDnnData<T>* src, MklDnnData<T>* diff_src,
- MklDnnData<T>* diff_dst) {
- std::vector<primitive> net;
-
- // Check if we need to reorder original input tensors into common_md layout
- // that we set for primitive creation. diff_src_primitive_desc is same as
- // common_md.
- src->CheckReorderToOpMem(relu_prim_desc.diff_src_primitive_desc(), &net);
- diff_dst->CheckReorderToOpMem(relu_prim_desc.diff_src_primitive_desc(),
- &net);
-
- net.push_back(relu_backward(relu_prim_desc, src->GetOpMem(),
- diff_dst->GetOpMem(), diff_src->GetOpMem()));
- stream(stream::kind::eager).submit(net).wait();
- }
+ private:
+ engine cpu_engine = engine(engine::cpu, 0);
+ std::shared_ptr<relu_forward::primitive_desc> relu_fwd_pd;
};
template <typename Device, typename T>
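
The hunk above replaces per-call construction of relu_forward/relu_backward primitives with a lookup in a primitive pool (`MklEltwiseBwdPrimitiveFactory<T>::Get(bwdParams)`), so a primitive built for a given shape/algorithm is reused across Compute() calls. Below is a minimal sketch of that caching pattern; `EltwiseBwdParams`, `CachedPrimitive` and `EltwiseBwdFactory` are illustrative stand-ins, not TensorFlow's actual classes, and locking for thread safety is omitted.

```cpp
#include <map>
#include <memory>
#include <sstream>
#include <string>
#include <vector>

// Parameters that uniquely determine a backward eltwise primitive.
struct EltwiseBwdParams {
  std::vector<int> src_dims;
  int alg_kind;
  float alpha, beta;
};

// Stand-in for a compiled MKL-DNN primitive (expensive to build).
struct CachedPrimitive {
  explicit CachedPrimitive(const EltwiseBwdParams& p) : params(p) {}
  EltwiseBwdParams params;
};

class EltwiseBwdFactory {
 public:
  // Returns the cached primitive for these parameters, building it on first use.
  static CachedPrimitive* Get(const EltwiseBwdParams& p) {
    static std::map<std::string, std::unique_ptr<CachedPrimitive>> pool;
    const std::string key = MakeKey(p);
    auto it = pool.find(key);
    if (it == pool.end()) {
      it = pool.emplace(key, std::make_unique<CachedPrimitive>(p)).first;
    }
    return it->second.get();
  }

 private:
  // Serializes the parameters into a lookup key.
  static std::string MakeKey(const EltwiseBwdParams& p) {
    std::ostringstream key;
    for (int d : p.src_dims) key << d << 'x';
    key << "alg" << p.alg_kind << "a" << p.alpha << "b" << p.beta;
    return key.str();
  }
};
```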
diff --git a/tensorflow/core/kernels/mkl_slice_op.cc b/tensorflow/core/kernels/mkl_slice_op.cc
new file mode 100644
index 0000000000..d63e14adf6
--- /dev/null
+++ b/tensorflow/core/kernels/mkl_slice_op.cc
@@ -0,0 +1,358 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// See docs in ../ops/array_ops.cc.
+
+#ifdef INTEL_MKL
+#ifndef INTEL_MKL_ML_ONLY
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/platform/prefetch.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+
+#include "mkldnn.hpp"
+#include "tensorflow/core/util/mkl_util.h"
+
+using mkldnn::stream;
+using mkldnn::view;
+
+namespace tensorflow {
+
+namespace {
+
+gtl::InlinedVector<int64, 4> IntTensorToInt64Vec(const Tensor& tensor) {
+ gtl::InlinedVector<int64, 4> out;
+ if (tensor.dtype() == DT_INT32) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int32>()(i));
+ }
+ } else if (tensor.dtype() == DT_INT64) {
+ for (int64 i = 0; i < tensor.NumElements(); ++i) {
+ out.push_back(tensor.flat<int64>()(i));
+ }
+ } else {
+ // tensor must be either int32 or int64
+ DCHECK(false);
+ }
+ return out;
+}
+
+} // namespace
+
+typedef Eigen::ThreadPoolDevice CPUDevice;
+
+// A version of SharedValidation (slice_op.h) written for input that is in
+// either Mkl layout or Tensorflow layout.
+// Shared code to validate input shapes and check for identity, which does not
+// depend on the type T. We do this to reduce code size by not duplicating this
+// logic for every T (float, double, int32, etc.).
+static void ValidateMklInputs(OpKernelContext* context, bool* is_identity,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size) {
+ const int kInputTensorIndex = 0;
+ const int kInputBeginIndex = 1;
+ const int kInputSizeIndex = 2;
+ const Tensor& input = MklGetInput(context, kInputTensorIndex);
+ const Tensor& begin_tensor = MklGetInput(context, kInputBeginIndex);
+ const Tensor& size_tensor = MklGetInput(context, kInputSizeIndex);
+
+ MklDnnShape input_mkl_shape, begin_mkl_shape, size_mkl_shape;
+ GetMklShape(context, kInputTensorIndex, &input_mkl_shape);
+ GetMklShape(context, kInputBeginIndex, &begin_mkl_shape);
+ GetMklShape(context, kInputSizeIndex, &size_mkl_shape);
+
+ // Begin and size tensors cannot be in MklDnn layout.
+ DCHECK_EQ(begin_mkl_shape.IsMklTensor(), false);
+ DCHECK_EQ(size_mkl_shape.IsMklTensor(), false);
+
+ TensorShape input_tf_shape = input_mkl_shape.IsMklTensor()
+ ? input_mkl_shape.GetTfShape()
+ : input.shape();
+ const int input_dims = input_tf_shape.dims();
+
+ OP_REQUIRES(
+ context, context->op_kernel().IsLegacyVector(begin_tensor.shape()) &&
+ context->op_kernel().IsLegacyVector(size_tensor.shape()) &&
+ begin_tensor.NumElements() == input_dims &&
+ size_tensor.NumElements() == input_dims,
+ errors::InvalidArgument(
+ "Expected begin and size arguments to be 1-D tensors of size ",
+ input_dims, ", but got shapes ", begin_tensor.shape().DebugString(),
+ " and ", size_tensor.shape().DebugString(), " instead."));
+
+ *begin = IntTensorToInt64Vec(begin_tensor);
+ *size = IntTensorToInt64Vec(size_tensor);
+ for (int i = 0; i < input_dims; ++i) {
+ if ((*size)[i] == -1) {
+ // A size[i] of -1 means "all elements from begin[i] to dim_size(i)".
+ (*size)[i] = input_tf_shape.dim_size(i) - (*begin)[i];
+ }
+ }
+
+ *is_identity = true;
+ for (int i = 0; i < input_dims; ++i) {
+ int64 b = (*begin)[i];
+ int64 s = (*size)[i];
+ if (input_tf_shape.dim_size(i) == 0) {
+ OP_REQUIRES(
+ context, b == 0 && s == 0,
+ errors::InvalidArgument("Expected begin[", i, "] == 0 (got ", b,
+ ") and size[", i, "] == 0 ", "(got ", s,
+ ") when ", "input.dim_size(", i, ") == 0"));
+ } else {
+ OP_REQUIRES(context, 0 <= b && b <= input_tf_shape.dim_size(i),
+ errors::InvalidArgument("Expected begin[", i, "] in [0, ",
+ input_tf_shape.dim_size(i),
+ "], but got ", b));
+ OP_REQUIRES(context, 0 <= s && b + s <= input_tf_shape.dim_size(i),
+ errors::InvalidArgument("Expected size[", i, "] in [0, ",
+ input_tf_shape.dim_size(i) - b,
+ "], but ", "got ", s));
+ }
+ const bool take_all = (b == 0) && (s == input_tf_shape.dim_size(i));
+ (*is_identity) &= take_all;
+ }
+}
+
+// A version of SharedSliceCommonCases function written for input tensor
+// that may be in MklDnn layout or in Tensorflow layout.
+template <typename T>
+static void CheckCommonCasesForMklInputs(OpKernelContext* context,
+ gtl::InlinedVector<int64, 4>* begin,
+ gtl::InlinedVector<int64, 4>* size,
+ bool* done) {
+ bool is_identity = true;
+ *done = false;
+
+ ValidateMklInputs(context, &is_identity, begin, size);
+ if (!context->status().ok()) return;
+
+ const Tensor& input = MklGetInput(context, 0);
+ MklDnnShape input_mkl_shape;
+ GetMklShape(context, 0, &input_mkl_shape);
+
+ if (is_identity) {
+ VLOG(1) << "Slice identity";
+ context->set_output(0, input);
+ // Mkl metadata tensor in this case can just be forwarded from input to
+ // output.
+ AllocateOutputSetMklShape(context, 0, input_mkl_shape);
+ *done = true;
+ }
+}
+
+// MKL-DNN implementation of Slice
+template <typename Device, typename T>
+class MklDnnSliceOp : public OpKernel {
+ public:
+ explicit MklDnnSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ ~MklDnnSliceOp() {}
+
+ void Compute(OpKernelContext* context) override {
+ gtl::InlinedVector<int64, 4> begin;
+ gtl::InlinedVector<int64, 4> size;
+ bool done = false;
+
+ CheckCommonCasesForMklInputs<T>(context, &begin, &size, &done);
+ if (!context->status().ok() || done == true) return;
+
+ // Although MKL-DNN supports tensors with more than 8 (and fewer than 12)
+ // dimensions, we mimic the functionality of the Eigen Slice op for CPU and
+ // reject inputs with 8 or more dimensions.
+ if (begin.size() >= 8) {
+ OP_REQUIRES(
+ context, false,
+ errors::Unimplemented("MklDnnSliceOp : Unhandled input dimensions"));
+ }
+
+ ComputeMklDnnSlice(context, begin, size);
+ }
+
+ private:
+ // Slice op implemented using MKL-DNN APIs.
+ void ComputeMklDnnSlice(OpKernelContext* context,
+ const gtl::InlinedVector<int64, 4>& begin,
+ const gtl::InlinedVector<int64, 4>& size) {
+ try {
+ // MKL-DNN API usage below is guided by description at:
+ // https://github.com/01org/mkl-dnn/issues/69
+ //
+ // Relevant part of the description is copied below:
+ //
+ // Let's say you want to copy a part of memory into another buffer (and
+ // probably change the format). Then your steps are:
+ //
+ // 1. create memory primitive descriptor in_mem_pd and memory primitive
+ // in_mem_p for the entire source data.
+ // 2. create view primitive descriptor in_submem_pd based on in_mem_pd,
+ // initial offsets, and sub-sizes
+ // 3. create memory primitive descriptor out_mem_pd and memory primitive
+ // out_mem_p for the output (the logical sizes should match sub-sizes
+ // used in step 2, but the format might be arbitrary)
+ // 4. create reorder primitive descriptor reorder_pd based on in_submem_pd
+ // and out_mem_pd
+ // 5. create reorder primitive itself based on reorder_pd, in_mem_p, and
+ // out_mem_p.
+ //
+ // Please notice that there is no view primitive. There is only view
+ // primitive descriptor. And the reorder uses source memory as input but
+ // traverses it according to a view in_submem_pd.
+
+ auto cpu_engine = engine(engine::cpu, 0);
+ MklDnnData<T> src(&cpu_engine);
+ MklDnnData<T> output(&cpu_engine);
+
+ // Populate offsets and sizes in memory::dims format based on vector.
+ memory::dims begin_dims = {};
+ begin_dims.resize(begin.size());
+ for (size_t i = 0; i < begin.size(); ++i) begin_dims[i] = begin[i];
+ memory::dims size_dims = {};
+ bool empty = false;
+ size_dims.resize(size.size());
+ for (size_t i = 0; i < size.size(); ++i) {
+ size_dims[i] = size[i];
+ if (size_dims[i] == 0) empty = true;
+ }
+
+ Tensor* output_tensor = nullptr;
+ MklDnnShape output_mkl_shape;
+
+ // If no dimension is selected in slice, the result should be empty.
+ // Just return an empty output tensor, and a dummy Mkl-shape tensor.
+ if (empty) { // for empty dims
+ auto shape_to = MklDnnDimsToTFShape(size_dims);
+ AllocateOutputSetMklShape(context, 0, &output_tensor, shape_to,
+ output_mkl_shape);
+ return;
+ }
+
+ // Step 1 (as per above description) - Create memory for user data.
+ // We use blocked format here to describe input tensor.
+ const Tensor& input_tensor = MklGetInput(context, 0);
+ MklDnnShape input_mkl_shape;
+ GetMklShape(context, 0, &input_mkl_shape);
+
+ if (input_mkl_shape.IsMklTensor()) {
+ auto input_mkl_format = input_mkl_shape.GetTfDataFormat();
+ auto input_tf_format = MklDnnDataFormatToTFDataFormat(input_mkl_format);
+ begin_dims = MklDnnDimsInNCHW(begin_dims, input_tf_format);
+ size_dims = MklDnnDimsInNCHW(size_dims, input_tf_format);
+ auto input_md = input_mkl_shape.GetMklLayout();
+ src.SetUsrMem(input_md, &input_tensor);
+ } else {
+ // Initialize input dimensions and strides to be used when input is not
+ // in MklDnn layout.
+ memory::dims input_dims, input_strides;
+ input_dims = TFShapeToMklDnnDims(input_tensor.shape());
+ input_strides = CalculateTFStrides(input_dims);
+ // Create input memory descriptor.
+ auto input_md =
+ MklDnnData<T>::CreateBlockedMemDesc(input_dims, input_strides);
+ src.SetUsrMem(input_md, &input_tensor);
+ }
+
+ // Step 2 - create view primitive descriptor
+ auto view_pd =
+ view::primitive_desc(src.GetUsrMemPrimDesc(), size_dims, begin_dims)
+ .dst_primitive_desc();
+ auto output_strides = CalculateTFStrides(size_dims);
+ auto output_md =
+ MklDnnData<T>::CreateBlockedMemDesc(size_dims, output_strides);
+ auto output_pd = memory::primitive_desc(output_md, cpu_engine);
+
+ // Step 3 - Create memory for output. If input is in MklDnn layout, then
+ // output is also in MklDnn layout. Otherwise, output is in Tensorflow
+ // layout.
+ AllocateOutputTensor(context, input_mkl_shape, &output_pd, size_dims,
+ &output_tensor, &output_mkl_shape);
+ DCHECK(output_tensor);
+ DCHECK_EQ(input_mkl_shape.IsMklTensor(), output_mkl_shape.IsMklTensor());
+ output.SetUsrMem(output_md, output_tensor);
+
+ std::vector<primitive> net;
+ // Step 4 - create reorder primitive desc between view_pd and output_pd.
+ auto reorder_pd =
+ reorder::primitive_desc(view_pd, output.GetUsrMemPrimDesc());
+ // Step 5 - create reorder primitive itself.
+ net.push_back(reorder(reorder_pd, *src.GetUsrMem(), *output.GetUsrMem()));
+ // Execute the reorder primitive.
+ stream(stream::kind::eager).submit(net).wait();
+ } catch (mkldnn::error& e) {
+ string error_msg = "Status: " + std::to_string(e.status) + ", message: " +
+ string(e.message) + ", in file " + string(__FILE__) +
+ ":" + std::to_string(__LINE__);
+ OP_REQUIRES_OK(
+ context,
+ errors::Aborted("Operation received an exception:", error_msg));
+ }
+ }
+
+ private:
+ void AllocateOutputTensor(OpKernelContext* context,
+ const MklDnnShape& input_mkl_shape,
+ memory::primitive_desc* output_pd,
+ const memory::dims& output_dims,
+ Tensor** output_tensor,
+ MklDnnShape* output_mkl_shape) {
+ DCHECK(output_tensor);
+ DCHECK(output_mkl_shape);
+
+ TensorShape output_tf_shape;
+
+ if (input_mkl_shape.IsMklTensor()) {
+ // Since input tensor is in Mkl layout, output tensor will be in Mkl
+ // layout.
+
+ // Allocate shape of Mkl tensor.
+ output_mkl_shape->SetMklTensor(true);
+ output_mkl_shape->SetMklLayout(output_pd);
+ output_mkl_shape->SetElemType(MklDnnType<T>());
+ output_mkl_shape->SetTfLayout(input_mkl_shape.GetDimension(), output_dims,
+ input_mkl_shape.GetTfDataFormat());
+
+ output_tf_shape.AddDim((output_pd->get_size() / sizeof(T)) + 1);
+ } else {
+ // If input is not in Mkl layout, then output won't be in Mkl layout.
+ output_mkl_shape->SetMklTensor(false);
+ output_tf_shape = MklDnnDimsToTFShape(output_dims);
+ }
+
+ AllocateOutputSetMklShape(context, 0, output_tensor, output_tf_shape,
+ *output_mkl_shape);
+ }
+};
+
+// MKL-DNN Slice registration
+#define REGISTER_MKL_SLICE(type) \
+ REGISTER_KERNEL_BUILDER(Name("_MklSlice") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("begin") \
+ .HostMemory("size") \
+ .Label(mkl_op_registry::kMklOpLabel), \
+ MklDnnSliceOp<CPUDevice, type>);
+
+TF_CALL_float(REGISTER_MKL_SLICE);
+#undef REGISTER_MKL_SLICE
+
+} // namespace tensorflow
+
+#endif // INTEL_MKL_ML_ONLY
+#endif // INTEL_MKL
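
ValidateMklInputs above normalizes a size[i] of -1 to "everything from begin[i] to the end of dimension i" and flags the slice as an identity when every dimension is taken whole. A small standalone sketch of just that bookkeeping, using plain vectors instead of TF tensors (names are illustrative):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Normalizes (begin, size) against the input shape the same way the
// validation code above does: size[i] == -1 means "from begin[i] to the end
// of dimension i". Returns true when the slice covers the whole input.
bool NormalizeSliceAndCheckIdentity(const std::vector<int64_t>& dims,
                                    const std::vector<int64_t>& begin,
                                    std::vector<int64_t>* size) {
  bool is_identity = true;
  for (std::size_t i = 0; i < dims.size(); ++i) {
    if ((*size)[i] == -1) (*size)[i] = dims[i] - begin[i];
    is_identity &= (begin[i] == 0 && (*size)[i] == dims[i]);
  }
  return is_identity;
}

// Example: a 2x3x4 input with begin = {0, 1, 0}, size = {-1, 2, -1}
// normalizes size to {2, 2, 4} and is not an identity slice.
```

In the identity case the kernel simply forwards the input tensor (and its Mkl metadata) to the output instead of running any MKL-DNN primitive.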
diff --git a/tensorflow/core/kernels/mkl_softmax_op.cc b/tensorflow/core/kernels/mkl_softmax_op.cc
index 8bde966be9..cfab529662 100644
--- a/tensorflow/core/kernels/mkl_softmax_op.cc
+++ b/tensorflow/core/kernels/mkl_softmax_op.cc
@@ -50,6 +50,7 @@ class MklSoftmaxOp : public OpKernel {
// src_tensor now points to the 0-th input of global data struct "context"
size_t src_idx = 0;
const Tensor& src_tensor = MklGetInput(context, src_idx);
+ const int input_dims = src_tensor.dims();
// Add: get MklShape
MklDnnShape src_mkl_shape;
@@ -62,7 +63,33 @@ class MklSoftmaxOp : public OpKernel {
: src_tensor.shape();
auto src_dims = TFShapeToMklDnnDims(src_tf_shape);
auto output_dims = src_dims;
-
+ memory::format layout_type;
+ // In MKL-DNN, the data format passed to the softmax op depends on the rank
+ // of the input tensor: "x" is used for a 1-D tensor, "nc" for 2-D, "tnc"
+ // for 3-D, "nchw" for 4-D, and "ncdhw" for 5-D.
+ // The symbols have the following meaning:
+ // n = batch, c = channels, t = sequence length, h = height,
+ // w = width, d = depth.
+ switch (input_dims) {
+ case 1:
+ layout_type = memory::format::x;
+ break;
+ case 2:
+ layout_type = memory::format::nc;
+ break;
+ case 3:
+ layout_type = memory::format::tnc;
+ break;
+ case 4:
+ layout_type = memory::format::nchw;
+ break;
+ case 5:
+ layout_type = memory::format::ncdhw;
+ break;
+ default:
+ OP_REQUIRES_OK(context, errors::Aborted("Input dims must be >= 1 and <= 5"));
+ return;
+ }
// Create softmax memory for src, dst: both are defined in mkl_util.h,
// they are wrapper
MklDnnData<T> src(&cpu_engine);
@@ -75,7 +102,7 @@ class MklSoftmaxOp : public OpKernel {
auto src_md =
src_mkl_shape.IsMklTensor()
? src_mkl_shape.GetMklLayout()
- : memory::desc(src_dims, MklDnnType<T>(), memory::format::nc);
+ : memory::desc(src_dims, MklDnnType<T>(), layout_type);
// src: setting memory descriptor and op memory descriptor
// Basically following two functions maps the TF "src_tensor" to mkl
@@ -84,10 +111,11 @@ class MklSoftmaxOp : public OpKernel {
// data format is "nc" for src and dst; since the src and dst buffer is
// always in 2D shape
src.SetUsrMem(src_md, &src_tensor);
- src.SetOpMemDesc(src_dims, memory::format::nc);
+ src.SetOpMemDesc(src_dims, layout_type);
// creating a memory descriptor
- int axis = 1; // axis to which softmax will be applied
+ // By default, apply softmax over the last dimension of the input.
+ int axis = input_dims - 1;
auto softmax_fwd_desc = softmax_forward::desc(prop_kind::forward_scoring,
src.GetOpMemDesc(), axis);
auto softmax_fwd_pd =
@@ -107,7 +135,7 @@ class MklSoftmaxOp : public OpKernel {
output_mkl_shape.SetMklLayout(&dst_pd);
output_mkl_shape.SetElemType(MklDnnType<T>());
output_mkl_shape.SetTfLayout(output_dims.size(), output_dims,
- memory::format::nc);
+ layout_type);
output_tf_shape.AddDim((dst_pd.get_size() / sizeof(T)));
} else { // then output is also TF shape
output_mkl_shape.SetMklTensor(false);
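
The softmax hunk above picks the MKL-DNN memory format from the input rank and applies softmax over the last dimension (axis = input_dims - 1). For reference, a plain-C++ sketch of what "softmax over the innermost axis" computes on a row-major [rows, cols] buffer, independent of MKL-DNN:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Softmax along the innermost axis of a row-major [rows, cols] buffer,
// e.g. a [batch, channels] tensor with axis = 1 (axis = input_dims - 1).
// Assumes rows >= 1 and cols >= 1.
std::vector<float> SoftmaxLastAxis(const std::vector<float>& x, int rows,
                                   int cols) {
  std::vector<float> y(x.size());
  for (int r = 0; r < rows; ++r) {
    const float* in = &x[r * cols];
    float* out = &y[r * cols];
    float max_val = in[0];
    for (int c = 1; c < cols; ++c) max_val = std::max(max_val, in[c]);
    float sum = 0.f;
    for (int c = 0; c < cols; ++c) {
      out[c] = std::exp(in[c] - max_val);  // subtract max for numerical stability
      sum += out[c];
    }
    for (int c = 0; c < cols; ++c) out[c] /= sum;
  }
  return y;
}
```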
diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc
index 7a64788448..82dfece4a2 100644
--- a/tensorflow/core/kernels/multinomial_op.cc
+++ b/tensorflow/core/kernels/multinomial_op.cc
@@ -75,7 +75,7 @@ struct MultinomialFunctor<CPUDevice, T, OutputType> {
// lambda. Since we want to let each worker have its own copy, we pass
// "gen" by reference and explicitly do a copy assignment here.
random::PhiloxRandom gen_copy = gen;
- // Skip takes units of 128 bytes. +3 is so rounding doesn't lead to
+ // Skip takes units of 128 bits. +3 is so rounding doesn't lead to
// us using the same state in different batches.
gen_copy.Skip(start_row * (num_samples + 3) / 4);
random::SimplePhilox simple_philox(&gen_copy);
diff --git a/tensorflow/core/kernels/multinomial_op.h b/tensorflow/core/kernels/multinomial_op.h
index 6e41060aa4..34e2123613 100644
--- a/tensorflow/core/kernels/multinomial_op.h
+++ b/tensorflow/core/kernels/multinomial_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_MULTINOMIAL_OP_H_
-#define TENSORFLOW_KERNELS_MULTINOMIAL_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_MULTINOMIAL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_MULTINOMIAL_OP_H_
namespace tensorflow {
@@ -27,4 +27,4 @@ struct MultinomialFunctor;
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MULTINOMIAL_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_MULTINOMIAL_OP_H_
diff --git a/tensorflow/core/kernels/neon/depthwiseconv_float.h b/tensorflow/core/kernels/neon/depthwiseconv_float.h
index 11f5be7c03..0d5a42bf10 100644
--- a/tensorflow/core/kernels/neon/depthwiseconv_float.h
+++ b/tensorflow/core/kernels/neon/depthwiseconv_float.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_
-#define TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_
+#ifndef TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_FLOAT_H_
+#define TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_FLOAT_H_
#include "public/gemmlowp.h"
#include "tensorflow/core/kernels/neon/types.h"
@@ -722,4 +722,4 @@ void DepthwiseConv(const float* input_data, const Dims<4>& input_dims,
} // end namespace neon
} // end namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_H_
+#endif // TENSORFLOW_CORE_KERNELS_NEON_DEPTHWISECONV_FLOAT_H_
diff --git a/tensorflow/core/kernels/no_op.h b/tensorflow/core/kernels/no_op.h
index 29ea46aed6..9e16d06978 100644
--- a/tensorflow/core/kernels/no_op.h
+++ b/tensorflow/core/kernels/no_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_NO_OP_H_
-#define TENSORFLOW_KERNELS_NO_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_NO_OP_H_
+#define TENSORFLOW_CORE_KERNELS_NO_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
@@ -29,4 +29,4 @@ class NoOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_NO_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_NO_OP_H_
diff --git a/tensorflow/core/kernels/non_max_suppression_op.cc b/tensorflow/core/kernels/non_max_suppression_op.cc
index c7d0d4de0d..81ce6d6e95 100644
--- a/tensorflow/core/kernels/non_max_suppression_op.cc
+++ b/tensorflow/core/kernels/non_max_suppression_op.cc
@@ -75,28 +75,28 @@ static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
}
// Return intersection-over-union overlap between boxes i and j
-static inline float IOUGreaterThanThreshold(
- typename TTypes<float, 2>::ConstTensor boxes, int i, int j,
- float iou_threshold) {
- const float ymin_i = std::min<float>(boxes(i, 0), boxes(i, 2));
- const float xmin_i = std::min<float>(boxes(i, 1), boxes(i, 3));
- const float ymax_i = std::max<float>(boxes(i, 0), boxes(i, 2));
- const float xmax_i = std::max<float>(boxes(i, 1), boxes(i, 3));
- const float ymin_j = std::min<float>(boxes(j, 0), boxes(j, 2));
- const float xmin_j = std::min<float>(boxes(j, 1), boxes(j, 3));
- const float ymax_j = std::max<float>(boxes(j, 0), boxes(j, 2));
- const float xmax_j = std::max<float>(boxes(j, 1), boxes(j, 3));
- const float area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
- const float area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
- if (area_i <= 0 || area_j <= 0) return 0.0;
- const float intersection_ymin = std::max<float>(ymin_i, ymin_j);
- const float intersection_xmin = std::max<float>(xmin_i, xmin_j);
- const float intersection_ymax = std::min<float>(ymax_i, ymax_j);
- const float intersection_xmax = std::min<float>(xmax_i, xmax_j);
- const float intersection_area =
- std::max<float>(intersection_ymax - intersection_ymin, 0.0) *
- std::max<float>(intersection_xmax - intersection_xmin, 0.0);
- const float iou = intersection_area / (area_i + area_j - intersection_area);
+template <typename T>
+static inline bool IOUGreaterThanThreshold(
+ typename TTypes<T, 2>::ConstTensor boxes, int i, int j, T iou_threshold) {
+ const T ymin_i = std::min<T>(boxes(i, 0), boxes(i, 2));
+ const T xmin_i = std::min<T>(boxes(i, 1), boxes(i, 3));
+ const T ymax_i = std::max<T>(boxes(i, 0), boxes(i, 2));
+ const T xmax_i = std::max<T>(boxes(i, 1), boxes(i, 3));
+ const T ymin_j = std::min<T>(boxes(j, 0), boxes(j, 2));
+ const T xmin_j = std::min<T>(boxes(j, 1), boxes(j, 3));
+ const T ymax_j = std::max<T>(boxes(j, 0), boxes(j, 2));
+ const T xmax_j = std::max<T>(boxes(j, 1), boxes(j, 3));
+ const T area_i = (ymax_i - ymin_i) * (xmax_i - xmin_i);
+ const T area_j = (ymax_j - ymin_j) * (xmax_j - xmin_j);
+ if (area_i <= static_cast<T>(0) || area_j <= static_cast<T>(0)) return false;
+ const T intersection_ymin = std::max<T>(ymin_i, ymin_j);
+ const T intersection_xmin = std::max<T>(xmin_i, xmin_j);
+ const T intersection_ymax = std::min<T>(ymax_i, ymax_j);
+ const T intersection_xmax = std::min<T>(xmax_i, xmax_j);
+ const T intersection_area =
+ std::max<T>(intersection_ymax - intersection_ymin, static_cast<T>(0.0)) *
+ std::max<T>(intersection_xmax - intersection_xmin, static_cast<T>(0.0));
+ const T iou = intersection_area / (area_i + area_j - intersection_area);
return iou > iou_threshold;
}
@@ -106,11 +106,13 @@ static inline bool OverlapsGreaterThanThreshold(
return overlaps(i, j) > overlap_threshold;
}
+template <typename T>
static inline std::function<bool(int, int)> CreateIOUSuppressCheckFn(
const Tensor& boxes, float threshold) {
- typename TTypes<float, 2>::ConstTensor boxes_data = boxes.tensor<float, 2>();
- return std::bind(&IOUGreaterThanThreshold, boxes_data, std::placeholders::_1,
- std::placeholders::_2, threshold);
+ typename TTypes<T, 2>::ConstTensor boxes_data = boxes.tensor<T, 2>();
+ return std::bind(&IOUGreaterThanThreshold<T>, boxes_data,
+ std::placeholders::_1, std::placeholders::_2,
+ static_cast<T>(threshold));
}
static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
@@ -121,20 +123,21 @@ static inline std::function<bool(int, int)> CreateOverlapsSuppressCheckFn(
std::placeholders::_1, std::placeholders::_2, threshold);
}
+template <typename T>
void DoNonMaxSuppressionOp(
OpKernelContext* context, const Tensor& scores, int num_boxes,
const Tensor& max_output_size, const float score_threshold,
const std::function<bool(int, int)>& suppress_check_fn,
bool pad_to_max_output_size = false, int* ptr_num_valid_outputs = nullptr) {
- const int output_size = std::min(max_output_size.scalar<int>()(), num_boxes);
+ const int output_size = max_output_size.scalar<int>()();
- std::vector<float> scores_data(num_boxes);
- std::copy_n(scores.flat<float>().data(), num_boxes, scores_data.begin());
+ std::vector<T> scores_data(num_boxes);
+ std::copy_n(scores.flat<T>().data(), num_boxes, scores_data.begin());
// Data structure for selection candidate in NMS.
struct Candidate {
int box_index;
- float score;
+ T score;
};
auto cmp = [](const Candidate bs_i, const Candidate bs_j) {
@@ -143,13 +146,13 @@ void DoNonMaxSuppressionOp(
std::priority_queue<Candidate, std::deque<Candidate>, decltype(cmp)>
candidate_priority_queue(cmp);
for (int i = 0; i < scores_data.size(); ++i) {
- if (scores_data[i] > score_threshold) {
+ if (static_cast<float>(scores_data[i]) > score_threshold) {
candidate_priority_queue.emplace(Candidate({i, scores_data[i]}));
}
}
std::vector<int> selected;
- std::vector<float> selected_scores;
+ std::vector<T> selected_scores;
Candidate next_candidate;
while (selected.size() < output_size && !candidate_priority_queue.empty()) {
@@ -176,7 +179,7 @@ void DoNonMaxSuppressionOp(
int num_valid_outputs = selected.size();
if (pad_to_max_output_size) {
selected.resize(output_size, 0);
- selected_scores.resize(output_size, 0);
+ selected_scores.resize(output_size, static_cast<T>(0));
}
if (ptr_num_valid_outputs) {
*ptr_num_valid_outputs = num_valid_outputs;
@@ -221,18 +224,19 @@ class NonMaxSuppressionOp : public OpKernel {
if (!context->status().ok()) {
return;
}
- auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_);
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn<float>(boxes, iou_threshold_);
const float score_threshold_val = std::numeric_limits<float>::lowest();
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
private:
float iou_threshold_;
};
-template <typename Device>
+template <typename Device, typename T>
class NonMaxSuppressionV2Op : public OpKernel {
public:
explicit NonMaxSuppressionV2Op(OpKernelConstruction* context)
@@ -264,11 +268,12 @@ class NonMaxSuppressionV2Op : public OpKernel {
if (!context->status().ok()) {
return;
}
- auto suppress_check_fn = CreateIOUSuppressCheckFn(boxes, iou_threshold_val);
+ auto suppress_check_fn =
+ CreateIOUSuppressCheckFn<T>(boxes, iou_threshold_val);
const float score_threshold_val = std::numeric_limits<float>::lowest();
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoNonMaxSuppressionOp<T>(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
};
@@ -325,7 +330,7 @@ class NonMaxSuppressionV3V4Base : public OpKernel {
float score_threshold_val_;
};
-template <typename Device>
+template <typename Device, typename T>
class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base {
public:
explicit NonMaxSuppressionV3Op(OpKernelConstruction* context)
@@ -334,14 +339,14 @@ class NonMaxSuppressionV3Op : public NonMaxSuppressionV3V4Base {
protected:
void DoComputeAndPostProcess(OpKernelContext* context) override {
auto suppress_check_fn =
- CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+ CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_);
- DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
- score_threshold_val_, suppress_check_fn);
+ DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn);
}
};
-template <typename Device>
+template <typename Device, typename T>
class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base {
public:
explicit NonMaxSuppressionV4Op(OpKernelConstruction* context)
@@ -353,12 +358,12 @@ class NonMaxSuppressionV4Op : public NonMaxSuppressionV3V4Base {
protected:
void DoComputeAndPostProcess(OpKernelContext* context) override {
auto suppress_check_fn =
- CreateIOUSuppressCheckFn(boxes_, iou_threshold_val_);
+ CreateIOUSuppressCheckFn<T>(boxes_, iou_threshold_val_);
int num_valid_outputs;
- DoNonMaxSuppressionOp(context, scores_, num_boxes_, max_output_size_,
- score_threshold_val_, suppress_check_fn,
- pad_to_max_output_size_, &num_valid_outputs);
+ DoNonMaxSuppressionOp<T>(context, scores_, num_boxes_, max_output_size_,
+ score_threshold_val_, suppress_check_fn,
+ pad_to_max_output_size_, &num_valid_outputs);
// Allocate scalar output tensor for number of indices computed.
Tensor* num_outputs_t = nullptr;
@@ -413,22 +418,37 @@ class NonMaxSuppressionWithOverlapsOp : public OpKernel {
auto suppress_check_fn =
CreateOverlapsSuppressCheckFn(overlaps, overlap_threshold_val);
- DoNonMaxSuppressionOp(context, scores, num_boxes, max_output_size,
- score_threshold_val, suppress_check_fn);
+ DoNonMaxSuppressionOp<float>(context, scores, num_boxes, max_output_size,
+ score_threshold_val, suppress_check_fn);
}
};
REGISTER_KERNEL_BUILDER(Name("NonMaxSuppression").Device(DEVICE_CPU),
NonMaxSuppressionOp<CPUDevice>);
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2").Device(DEVICE_CPU),
- NonMaxSuppressionV2Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionV2").TypeConstraint<float>("T").Device(DEVICE_CPU),
+ NonMaxSuppressionV2Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV2")
+ .TypeConstraint<Eigen::half>("T")
+ .Device(DEVICE_CPU),
+ NonMaxSuppressionV2Op<CPUDevice, Eigen::half>);
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3").Device(DEVICE_CPU),
- NonMaxSuppressionV3Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionV3").TypeConstraint<float>("T").Device(DEVICE_CPU),
+ NonMaxSuppressionV3Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV3")
+ .TypeConstraint<Eigen::half>("T")
+ .Device(DEVICE_CPU),
+ NonMaxSuppressionV3Op<CPUDevice, Eigen::half>);
-REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4").Device(DEVICE_CPU),
- NonMaxSuppressionV4Op<CPUDevice>);
+REGISTER_KERNEL_BUILDER(
+ Name("NonMaxSuppressionV4").TypeConstraint<float>("T").Device(DEVICE_CPU),
+ NonMaxSuppressionV4Op<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("NonMaxSuppressionV4")
+ .TypeConstraint<Eigen::half>("T")
+ .Device(DEVICE_CPU),
+ NonMaxSuppressionV4Op<CPUDevice, Eigen::half>);
REGISTER_KERNEL_BUILDER(
Name("NonMaxSuppressionWithOverlaps").Device(DEVICE_CPU),
diff --git a/tensorflow/core/kernels/nth_element_op.h b/tensorflow/core/kernels/nth_element_op.h
index e7d25daecc..7a5ec3d0b5 100644
--- a/tensorflow/core/kernels/nth_element_op.h
+++ b/tensorflow/core/kernels/nth_element_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_NTH_ELEMENT_OP_H_
-#define TENSORFLOW_NTH_ELEMENT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_NTH_ELEMENT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_NTH_ELEMENT_OP_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -34,4 +34,4 @@ struct NthElementFunctor {
} // namespace tensorflow
-#endif // TENSORFLOW_NTH_ELEMENT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_NTH_ELEMENT_OP_H_
diff --git a/tensorflow/core/kernels/one_hot_op.h b/tensorflow/core/kernels/one_hot_op.h
index db59f0f0d4..879df2b59b 100644
--- a/tensorflow/core/kernels/one_hot_op.h
+++ b/tensorflow/core/kernels/one_hot_op.h
@@ -15,8 +15,8 @@ limitations under the License.
// See docs in ../ops/array_ops.cc
-#ifndef TENSORFLOW_KERNELS_ONE_HOT_OP_H_
-#define TENSORFLOW_KERNELS_ONE_HOT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_ONE_HOT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_ONE_HOT_OP_H_
// Generator definition for OneHotOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -69,4 +69,4 @@ struct OneHot {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_ONE_HOT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_ONE_HOT_OP_H_
diff --git a/tensorflow/core/kernels/ops_testutil.h b/tensorflow/core/kernels/ops_testutil.h
index 2c195beb7f..5d607b9044 100644
--- a/tensorflow/core/kernels/ops_testutil.h
+++ b/tensorflow/core/kernels/ops_testutil.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
-#define TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_
+#define TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_
#include <memory>
#include <vector>
@@ -252,4 +252,4 @@ class OpsTestBase : public ::testing::Test {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
+#endif // TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_
diff --git a/tensorflow/core/kernels/ops_util.h b/tensorflow/core/kernels/ops_util.h
index 93ef512778..a496487d1b 100644
--- a/tensorflow/core/kernels/ops_util.h
+++ b/tensorflow/core/kernels/ops_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_OPS_UTIL_H_
-#define TENSORFLOW_KERNELS_OPS_UTIL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_OPS_UTIL_H_
+#define TENSORFLOW_CORE_KERNELS_OPS_UTIL_H_
// This file contains utilities for various operations.
@@ -113,4 +113,4 @@ gtl::InlinedVector<T, 8> ComputeEigenStrides(const EigenDimensions& shape) {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_OPS_UTIL_H_
+#endif // TENSORFLOW_CORE_KERNELS_OPS_UTIL_H_
diff --git a/tensorflow/core/kernels/pad_op.h b/tensorflow/core/kernels/pad_op.h
index ee9e0f0330..ae79f515d9 100644
--- a/tensorflow/core/kernels/pad_op.h
+++ b/tensorflow/core/kernels/pad_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_PAD_OP_H_
-#define TENSORFLOW_KERNELS_PAD_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_PAD_OP_H_
+#define TENSORFLOW_CORE_KERNELS_PAD_OP_H_
// Functor definition for PadOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -54,4 +54,4 @@ struct Pad<Device, T, Tpadding, 0> {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_PAD_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_PAD_OP_H_
diff --git a/tensorflow/core/kernels/padding_fifo_queue.h b/tensorflow/core/kernels/padding_fifo_queue.h
index 9d7c935068..b86b03c8f0 100644
--- a/tensorflow/core/kernels/padding_fifo_queue.h
+++ b/tensorflow/core/kernels/padding_fifo_queue.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_
-#define TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_
+#define TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_
#include <deque>
#include <vector>
@@ -86,4 +86,4 @@ class PaddingFIFOQueue : public FIFOQueue {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_PADDING_FIFO_QUEUE_H_
+#endif // TENSORFLOW_CORE_KERNELS_PADDING_FIFO_QUEUE_H_
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
index 0ab9ff9f65..aa70ee06f5 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.cc
@@ -47,7 +47,7 @@ using random::PhiloxRandom;
template <typename T>
struct TruncatedNormalFunctor<CPUDevice, T> {
- static const int kMaxIterations = 100;
+ static const int kMaxIterations = 1000;
void operator()(OpKernelContext* ctx, const CPUDevice& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
@@ -124,6 +124,7 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
(normMin * (normMin - sqrtFactor)) / T(4)) /
(normMin + sqrtFactor);
const T diff = normMax - normMin;
+
if (diff < cutoff) {
// Sample from a uniform distribution on [normMin, normMax].
@@ -143,15 +144,20 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
const auto u = dist(&gen_copy);
for (int i = 0; i < size; i++) {
- if (u[i] <= Eigen::numext::exp(g[i]) ||
- numIterations + 1 >= kMaxIterations) {
+ auto accept = u[i] <= Eigen::numext::exp(g[i]);
+ if (accept || numIterations + 1 >= kMaxIterations) {
// Accept the sample z.
// If we run out of iterations, just use the current uniform
- // sample. Emperically, the probability of accepting each sample
- // is at least 50% for typical inputs, so we will always accept
- // by 100 iterations.
- // This introduces a slight inaccuracy when at least one bound
- // is large, minval is negative and maxval is positive.
+ // sample, but emit a warning.
+ // TODO(jjhunt) For small entropies (relative to the bounds),
+ // this sampler is poor and may take many iterations since
+ // the proposal distribution is the uniform distribution
+ // U(lower_bound, upper_bound).
+ if (!accept) {
+ LOG(WARNING) << "TruncatedNormal uniform rejection sampler "
+ << "exceeded max iterations. Sample may contain "
+ << "outliers.";
+ }
output(sample) = z[i] * stddev + mean;
sample++;
if (sample >= limit_sample) {
@@ -181,8 +187,13 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
const T g = Eigen::numext::exp(-x * x / T(2.0));
const T u = rand[i];
i++;
- if ((u <= g && z < normMax) ||
- numIterations + 1 >= kMaxIterations) {
+ auto accept = (u <= g && z < normMax);
+ if (accept || numIterations + 1 >= kMaxIterations) {
+ if (!accept) {
+ LOG(WARNING) << "TruncatedNormal exponential distribution "
+ << "rejection sampler exceeds max iterations. "
+ << "Sample may contain outliers.";
+ }
output(sample) = z * stddev + mean;
sample++;
if (sample >= limit_sample) {
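
The truncated-normal changes above raise kMaxIterations from 100 to 1000 and log a warning when the uniform-proposal rejection sampler still has not accepted. A standalone sketch of that uniform rejection scheme for a standard normal truncated to [a, b], written against <random> rather than TF's Philox generator:

```cpp
#include <cmath>
#include <random>

// Samples a standard normal truncated to [a, b] by uniform rejection:
// propose z ~ U(a, b) and accept with probability exp((c - z*z) / 2), where
// c is the minimum of z*z over [a, b]. Like the kernel above, it gives up
// after max_iters proposals rather than looping forever.
double TruncatedStdNormalUniformRejection(double a, double b, std::mt19937& rng,
                                          int max_iters = 1000) {
  std::uniform_real_distribution<double> proposal(a, b);
  std::uniform_real_distribution<double> unit(0.0, 1.0);
  // Smallest value of z*z on [a, b]; keeps the acceptance ratio <= 1.
  const double c = (a > 0.0) ? a * a : (b < 0.0 ? b * b : 0.0);
  double z = 0.0;
  for (int i = 0; i < max_iters; ++i) {
    z = proposal(rng);
    if (unit(rng) <= std::exp((c - z * z) / 2.0)) return z;
  }
  // Mirrors the kernel's fallback: return the last proposal (the kernel also
  // logs a warning that the sample may contain outliers).
  return z;
}
```

As in the kernel, the result is then shifted and scaled as mean + stddev * z. The uniform proposal is efficient only when [a, b] is short, which is why the kernel switches to an exponential proposal otherwise.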
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op.h b/tensorflow/core/kernels/parameterized_truncated_normal_op.h
index cc801eb810..2e54db31fe 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op.h
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
-#define TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/random/random_distributions.h"
@@ -49,4 +49,4 @@ struct TruncatedNormalFunctor {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_PARAMETERIZED_TRUNCATED_NORMAL_OP_H_
diff --git a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
index 661d47d925..5b80a962bc 100644
--- a/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/parameterized_truncated_normal_op_gpu.cu.cc
@@ -190,7 +190,7 @@ __global__ void __launch_bounds__(1024)
// Partial specialization for GPU
template <typename T>
struct TruncatedNormalFunctor<GPUDevice, T> {
- static const int kMaxIterations = 100;
+ static const int kMaxIterations = 1000;
void operator()(OpKernelContext* ctx, const GPUDevice& d, int64 num_batches,
int64 samples_per_batch, int64 num_elements,
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
index 8db78f9784..3979e4b53a 100644
--- a/tensorflow/core/kernels/partitioned_function_ops.cc
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/optimization_registry.h"
#include "tensorflow/core/common_runtime/placer.h"
#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
#include "tensorflow/core/framework/function.h"
@@ -96,22 +97,27 @@ class PartitionedCallOp : public AsyncOpKernel {
OP_REQUIRES_ASYNC(ctx, fbody != nullptr,
errors::Internal("Could not find handle ", handle),
done);
+ OP_REQUIRES_ASYNC(
+ ctx, args.size() == fbody->arg_nodes.size(),
+ errors::InvalidArgument(
+ "Wrong number of arguments to the op; function expects ",
+ fbody->arg_nodes.size(), " but PartitionedCall received ",
+ args.size()),
+ done);
+ // We need to pass the global op registry as the default registry when
+ // creating the graph, so that graph optimization passes can look up all
+ // possible ops by name.
auto graph = tensorflow::MakeUnique<Graph>(fbody->graph->flib_def());
+ FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
+ TF_CHECK_OK(
+ graph.get()->AddFunctionLibrary(global_flib.ToProto()));
CopyGraph(*fbody->graph, graph.get());
- OP_REQUIRES_OK_ASYNC(ctx, PropagateInheritedDevices(graph.get(), args),
- done);
+ OP_REQUIRES_OK_ASYNC(ctx, PinResourceArgs(graph.get(), args), done);
DeviceSet device_set;
for (auto d : lib->device_mgr()->ListDevices()) {
device_set.AddDevice(d);
}
- Placer placer(graph.get(), &device_set);
- OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
-
- std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
- OP_REQUIRES_OK_ASYNC(
- ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
- done);
// The FunctionLibraryRuntime's library cannot be mutated from within
// an OpKernel, so functions are instantiated in an overlay library.
@@ -125,6 +131,47 @@ class PartitionedCallOp : public AsyncOpKernel {
new FunctionLibraryDefinition(*lib->GetFunctionLibraryDefinition());
overlay_libs_.emplace(lib, overlay_lib);
+ GraphOptimizationPassOptions optimization_options;
+ // TODO(akshayka): Thread SessionOptions (if any) into this kernel, or
+ // make it possible to specify the relevant options via attributes.
+ SessionOptions session_options;
+ session_options.env = ctx->env();
+ optimization_options.session_options = &session_options;
+ optimization_options.graph = &graph;
+ optimization_options.flib_def = overlay_lib;
+ optimization_options.device_set = &device_set;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::PRE_PLACEMENT, optimization_options),
+ done);
+ Placer placer(graph.get(), &device_set);
+ OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_PLACEMENT, optimization_options),
+ done);
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_REWRITE_FOR_EXEC,
+ optimization_options),
+ done);
+
+ std::unordered_map<string, std::unique_ptr<Graph>> subgraphs;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, PartitionHelper(device_set, std::move(graph), &subgraphs),
+ done);
+ optimization_options.graph = nullptr;
+ optimization_options.device_set = nullptr;
+ optimization_options.partition_graphs = &subgraphs;
+ OP_REQUIRES_OK_ASYNC(ctx,
+ OptimizationPassRegistry::Global()->RunGrouping(
+ OptimizationPassRegistry::POST_PARTITIONING,
+ optimization_options),
+ done);
+
auto handles = tensorflow::MakeUnique<gtl::FlatMap<string, FHandle>>();
for (const auto& pair : subgraphs) {
// TODO(akshayka): Fail gracefully if the set of devices corresponds
@@ -163,15 +210,10 @@ class PartitionedCallOp : public AsyncOpKernel {
std::vector<AllocatorAttributes>>
ArgAndRetAllocAttrs;
- // Propagates device annotations from the outer graph to the function body.
- //
// Pins each arg that emits a `DT_RESOURCE` tensor to the device on which the
// corresponding resource lives. This ensures that the Placer assigns ops that
- // access these resources to the appropriate devices. Additionally, places
- // nodes that are unadorned with device annotations onto PartitiondCallOp's
- // device. This lets call-site device annotations influence the execution
- // of the function.
- Status PropagateInheritedDevices(Graph* graph, const OpInputList& args) {
+ // access these resources to the appropriate devices.
+ Status PinResourceArgs(Graph* graph, const OpInputList& args) {
for (Node* node : graph->op_nodes()) {
string node_type = node->type_string();
if (node_type == FunctionLibraryDefinition::kArgOp) {
@@ -181,21 +223,9 @@ class PartitionedCallOp : public AsyncOpKernel {
TF_RETURN_IF_ERROR(node->attrs().Find("T", &attr_value));
DataType dtype = attr_value->type();
if (dtype == DT_RESOURCE) {
- ResourceHandle handle = args[index].flat<ResourceHandle>()(0);
+ const ResourceHandle& handle = args[index].flat<ResourceHandle>()(0);
node->set_assigned_device_name(handle.device());
}
- } else if (node_type != FunctionLibraryDefinition::kRetOp) {
- // All non-RetVal nodes that weren't explicitly placed by the user
- // inherit PartitionedCallOp's device. RetVal placement is inferred by
- // the placer, to avoid forcing the function's outputs through a single
- // device.
- //
- // TODO(b/112166045): Plumb the original requested device into this
- // OpKernel (this->requested_device() isn't reliable), and merge it
- // with node->requested_device() if possible.
- if (node->requested_device().empty()) {
- node->set_requested_device(local_device_name_);
- }
}
}
return Status::OK();
@@ -233,9 +263,11 @@ class PartitionedCallOp : public AsyncOpKernel {
VLOG(3) << "Partitioned function '" << func_.name() << "', yielding "
<< partitions.size() << " shards.";
- const FunctionLibraryDefinition* flib_def = &graph->flib_def();
for (const auto& partition : partitions) {
- std::unique_ptr<Graph> subgraph(new Graph(flib_def));
+ std::unique_ptr<Graph> subgraph(new Graph(graph->flib_def()));
+ FunctionLibraryDefinition global_flib(OpRegistry::Global(), {});
+ TF_CHECK_OK(
+ subgraph.get()->AddFunctionLibrary(global_flib.ToProto()));
GraphConstructorOptions opts;
opts.allow_internal_ops = true;
opts.expect_device_spec = true;
diff --git a/tensorflow/core/kernels/poisson-loss.h b/tensorflow/core/kernels/poisson-loss.h
new file mode 100644
index 0000000000..f91244454e
--- /dev/null
+++ b/tensorflow/core/kernels/poisson-loss.h
@@ -0,0 +1,109 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
+
+#include <cmath>
+
+#include "tensorflow/core/kernels/loss.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+class PoissonLossUpdater : public DualLossUpdater {
+ public:
+ // Update is found by a Newton algorithm (see readme.md).
+ double ComputeUpdatedDual(const int num_loss_partitions, const double label,
+ const double example_weight,
+ const double current_dual, const double wx,
+ const double weighted_example_norm) const final {
+ // The Newton algorithm converges quadratically, so 10 steps are more than
+ // enough to achieve very good precision.
+ static const int newton_total_steps = 10;
+ // Initialize the Newton optimization at x such that
+ // exp(x) = label - current_dual
+ const double y_minus_a = label - current_dual;
+ double x = (y_minus_a > 0) ? log(y_minus_a) : 0;
+ for (int i = 0; i < newton_total_steps; ++i) {
+ x = NewtonStep(x, num_loss_partitions, label, wx, example_weight,
+ weighted_example_norm, current_dual);
+ }
+ return label - exp(x);
+ }
+
+ // Dual of the Poisson loss function.
+ // https://en.wikipedia.org/wiki/Convex_conjugate
+ double ComputeDualLoss(const double current_dual, const double example_label,
+ const double example_weight) const final {
+ // The dual of the Poisson loss function is
+ // (y-a)*(log(y-a)-1), where a is the dual variable.
+ // It is only defined for a < y.
+ const double y_minus_a = example_label - current_dual;
+ if (y_minus_a == 0.0) {
+ // (y-a)*(log(y-a)-1) approaches 0 as y-a approaches 0.
+ return 0.0;
+ }
+ if (y_minus_a < 0.0) {
+ return std::numeric_limits<double>::max();
+ }
+ return y_minus_a * (log(y_minus_a) - 1) * example_weight;
+ }
+
+ double ComputePrimalLoss(const double wx, const double example_label,
+ const double example_weight) const final {
+ return (exp(wx) - wx * example_label) * example_weight;
+ }
+
+ double PrimalLossDerivative(const double wx, const double label,
+ const double example_weight) const final {
+ return (exp(wx) - label) * example_weight;
+ }
+
+ // TODO(chapelle): We need to introduce a maximum_prediction parameter,
+ // expose that parameter to the user and have this method return
+ // 1.0/maximum_prediction.
+ // Setting this at 1 for now, it only impacts the adaptive sampling.
+ double SmoothnessConstant() const final { return 1; }
+
+ Status ConvertLabel(float* const example_label) const final {
+ if (*example_label < 0.0) {
+ return errors::InvalidArgument(
+ "Only non-negative labels can be used with the Poisson log loss. "
+ "Found example with label: ", *example_label);
+ }
+ return Status::OK();
+ }
+
+ private:
+ // One Newton step (see readme.md).
+ double NewtonStep(const double x, const int num_loss_partitions,
+ const double label, const double wx,
+ const double example_weight,
+ const double weighted_example_norm,
+ const double current_dual) const {
+ const double expx = exp(x);
+ const double numerator =
+ x - wx - num_loss_partitions * weighted_example_norm *
+ example_weight * (label - current_dual - expx);
+ const double denominator =
+ 1 + num_loss_partitions * weighted_example_norm * example_weight * expx;
+ return x - numerator / denominator;
+ }
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_POISSON_LOSS_H_
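
The comments in poisson-loss.h state that the dual of the Poisson loss is (y - a) * (log(y - a) - 1), defined only for a < y. That follows from the convex conjugate of the per-example primal loss used in ComputePrimalLoss (taking example_weight = 1):

```latex
% Per-example primal Poisson loss and its convex conjugate:
\ell(u) = e^{u} - y\,u, \qquad
\ell^{*}(v) = \sup_{u}\bigl(u v - \ell(u)\bigr)
            = \sup_{u}\bigl(u\,(v + y) - e^{u}\bigr).

% The supremum is attained where e^{u} = v + y (which requires v + y > 0):
\ell^{*}(v) = (v + y)\log(v + y) - (v + y).

% With the dual variable a = -v this is exactly ComputeDualLoss:
\ell^{*}(-a) = (y - a)\bigl(\log(y - a) - 1\bigr), \qquad a < y.
```

This also explains the two special cases in ComputeDualLoss: the expression tends to 0 as y - a approaches 0, and the conjugate is unbounded (returned as numeric_limits<double>::max()) when a > y.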
diff --git a/tensorflow/core/kernels/pooling_ops_3d.h b/tensorflow/core/kernels/pooling_ops_3d.h
index d1be3ba407..319b17397e 100644
--- a/tensorflow/core/kernels/pooling_ops_3d.h
+++ b/tensorflow/core/kernels/pooling_ops_3d.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_POOLING_OPS_3D_H_
-#define TENSORFLOW_KERNELS_POOLING_OPS_3D_H_
+#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_H_
+#define TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/padding.h"
@@ -77,4 +77,4 @@ struct Pool3dParameters {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_POOLING_OPS_3D_H_
+#endif // TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_H_
diff --git a/tensorflow/core/kernels/pooling_ops_3d_gpu.h b/tensorflow/core/kernels/pooling_ops_3d_gpu.h
index 350b1b6732..2c3681455e 100644
--- a/tensorflow/core/kernels/pooling_ops_3d_gpu.h
+++ b/tensorflow/core/kernels/pooling_ops_3d_gpu.h
@@ -17,8 +17,8 @@ limitations under the License.
#error This file must only be included when building with Cuda support
#endif
-#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OP_3D_GPU_H_
-#define TENSORFLOW_CORE_KERNELS_POOLING_OP_3D_GPU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_GPU_H_
+#define TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_GPU_H_
#define EIGEN_USE_GPU
@@ -45,4 +45,4 @@ struct MaxPool3dGradBackward {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_KERNELS_POOLING_OP_3D_H_
+#endif // TENSORFLOW_CORE_KERNELS_POOLING_OPS_3D_GPU_H_
diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h
index e9265551e3..dda2c80c49 100644
--- a/tensorflow/core/kernels/pooling_ops_common.h
+++ b/tensorflow/core/kernels/pooling_ops_common.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_
-#define TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_
+#ifndef TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_H_
+#define TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_H_
#include <vector>
@@ -605,4 +605,4 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output,
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_POOLING_OPS_COMMON_H_
+#endif // TENSORFLOW_CORE_KERNELS_POOLING_OPS_COMMON_H_
diff --git a/tensorflow/core/kernels/priority_queue.h b/tensorflow/core/kernels/priority_queue.h
index ff168df449..8e69b5b699 100644
--- a/tensorflow/core/kernels/priority_queue.h
+++ b/tensorflow/core/kernels/priority_queue.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_PRIORITY_QUEUE_H_
-#define TENSORFLOW_KERNELS_PRIORITY_QUEUE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_PRIORITY_QUEUE_H_
+#define TENSORFLOW_CORE_KERNELS_PRIORITY_QUEUE_H_
#include <deque>
#include <queue>
@@ -90,4 +90,4 @@ class PriorityQueue
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_PRIORITY_QUEUE_H_
+#endif // TENSORFLOW_CORE_KERNELS_PRIORITY_QUEUE_H_
diff --git a/tensorflow/core/kernels/qr_op_complex128.cc b/tensorflow/core/kernels/qr_op_complex128.cc
index c5b73139bb..8a3e3dc0a9 100644
--- a/tensorflow/core/kernels/qr_op_complex128.cc
+++ b/tensorflow/core/kernels/qr_op_complex128.cc
@@ -20,7 +20,17 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex128>), complex128);
#if GOOGLE_CUDA
-REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex128>), complex128);
+// We temporarily disable QR on GPU due to a bug in the QR implementation in
+// cuSolver affecting older hardware. The cuSolver team is tracking the issue
+// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
+// this feature when a fix is available.
+REGISTER_KERNEL_BUILDER(Name("Qr")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<complex128>("T")
+ .HostMemory("input")
+ .HostMemory("q")
+ .HostMemory("r"),
+ QrOp<complex128>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_complex64.cc b/tensorflow/core/kernels/qr_op_complex64.cc
index 4e14f2639c..467fa6c2d6 100644
--- a/tensorflow/core/kernels/qr_op_complex64.cc
+++ b/tensorflow/core/kernels/qr_op_complex64.cc
@@ -20,7 +20,11 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<complex64>), complex64);
#if GOOGLE_CUDA
-REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex64>), complex64);
+// We temporarily disable QR on GPU due to a bug in the QR implementation in
+// cuSolver affecting older hardware. The cuSolver team is tracking the issue
+// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
+// this feature when a fix is available.
+// REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<complex64>), complex64);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_double.cc b/tensorflow/core/kernels/qr_op_double.cc
index 51885eb355..05537a0eaa 100644
--- a/tensorflow/core/kernels/qr_op_double.cc
+++ b/tensorflow/core/kernels/qr_op_double.cc
@@ -20,7 +20,17 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<double>), double);
#if GOOGLE_CUDA
-REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<double>), double);
+// We temporarily disable QR on GPU due to a bug in the QR implementation in
+// cuSolver affecting older hardware. The cuSolver team is tracking the issue
+// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
+// this feature when a fix is available.
+REGISTER_KERNEL_BUILDER(Name("Qr")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<double>("T")
+ .HostMemory("input")
+ .HostMemory("q")
+ .HostMemory("r"),
+ QrOp<double>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_float.cc b/tensorflow/core/kernels/qr_op_float.cc
index d0a1dd4204..6aebd98186 100644
--- a/tensorflow/core/kernels/qr_op_float.cc
+++ b/tensorflow/core/kernels/qr_op_float.cc
@@ -20,7 +20,17 @@ namespace tensorflow {
REGISTER_LINALG_OP("Qr", (QrOp<float>), float);
#if GOOGLE_CUDA
-REGISTER_LINALG_OP_GPU("Qr", (QrOpGpu<float>), float);
+// We temporarily disable QR on GPU due to a bug in the QR implementation in
+// cuSolver affecting older hardware. The cuSolver team is tracking the issue
+// (https://partners.nvidia.com/bug/viewbug/2171459) and we will re-enable
+// this feature when a fix is available.
+REGISTER_KERNEL_BUILDER(Name("Qr")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .HostMemory("input")
+ .HostMemory("q")
+ .HostMemory("r"),
+ QrOp<float>);
#endif
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/qr_op_impl.h b/tensorflow/core/kernels/qr_op_impl.h
index 0552c034d2..535df9d160 100644
--- a/tensorflow/core/kernels/qr_op_impl.h
+++ b/tensorflow/core/kernels/qr_op_impl.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
+
// See docs in ../ops/linalg_ops.cc.
//
// This header file is used by the individual qr_*op*.cc files for registering
@@ -292,6 +295,8 @@ class QrOpGpu : public AsyncOpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(QrOpGpu);
};
-#endif
+#endif // GOOGLE_CUDA
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_QR_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/queue_base.h b/tensorflow/core/kernels/queue_base.h
index 5fb1c92f94..272aa3b4f5 100644
--- a/tensorflow/core/kernels/queue_base.h
+++ b/tensorflow/core/kernels/queue_base.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <deque>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/queue_interface.h"
#include "tensorflow/core/framework/tensor.h"
@@ -82,6 +83,9 @@ class QueueBase : public QueueInterface {
// NOTE(mrry): This method is deprecated. Use
// `tensorflow::batch_util::CopySliceToElement()` defined in
// "./batch_util.h" instead.
+ ABSL_DEPRECATED(
+ "Use `tensorflow::batch_util::CopySliceToElement()` defined in "
+ "\"./batch_util.h\" instead.")
static Status CopyElementToSlice(const Tensor& element, Tensor* parent,
int64 index);
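The newly included absl/base/macros.h supplies ABSL_DEPRECATED, which is placed in front of a declaration and, on compilers that support the deprecation attribute, makes callers emit a warning carrying the given message. A minimal sketch with a hypothetical function name, for illustration only:

    #include "absl/base/macros.h"

    // Hypothetical API used only to show where the annotation goes.
    ABSL_DEPRECATED("Use NewCopyHelper() instead.")
    void OldCopyHelper();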
diff --git a/tensorflow/core/kernels/queue_ops.cc b/tensorflow/core/kernels/queue_ops.cc
index c4d404259b..97ddc852f7 100644
--- a/tensorflow/core/kernels/queue_ops.cc
+++ b/tensorflow/core/kernels/queue_ops.cc
@@ -65,7 +65,7 @@ class FakeQueueOp : public OpKernel {
}
void Compute(OpKernelContext* context) override {
- ResourceHandle ref = context->input(0).flat<ResourceHandle>()(0);
+ const ResourceHandle& ref = context->input(0).flat<ResourceHandle>()(0);
handle_.AccessTensor(context)->flat<string>()(0) = ref.container();
handle_.AccessTensor(context)->flat<string>()(1) = ref.name();
context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
diff --git a/tensorflow/core/kernels/random_op.cc b/tensorflow/core/kernels/random_op.cc
index e37232539f..04a53697c0 100644
--- a/tensorflow/core/kernels/random_op.cc
+++ b/tensorflow/core/kernels/random_op.cc
@@ -231,7 +231,13 @@ class RandomUniformIntOp : public OpKernel {
errors::InvalidArgument("maxval must be 0-D, got shape ",
maxval.shape().DebugString()));
- // Verify that minval < maxval
+ // Allocate output, and exit early if possible
+ Tensor* output;
+ OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
+ if (output->NumElements() == 0) return;
+
+ // Verify that minval < maxval. This check intentionally happens after the
+ // early exit for empty output: drawing zero samples from an invalid range
+ // is not an error.
IntType lo = minval.scalar<IntType>()();
IntType hi = maxval.scalar<IntType>()();
OP_REQUIRES(
@@ -243,8 +249,6 @@ class RandomUniformIntOp : public OpKernel {
Distribution;
Distribution dist(lo, hi);
- Tensor* output;
- OP_REQUIRES_OK(ctx, AllocateOutputWithShape(ctx, shape, 0, &output));
auto output_flat = output->flat<IntType>();
functor::FillPhiloxRandom<Device, Distribution>()(
ctx, ctx->eigen_device<Device>(),
diff --git a/tensorflow/core/kernels/random_op.h b/tensorflow/core/kernels/random_op.h
index 97bcaf1a49..d313a021dd 100644
--- a/tensorflow/core/kernels/random_op.h
+++ b/tensorflow/core/kernels/random_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_RANDOM_OP_H_
-#define TENSORFLOW_KERNELS_RANDOM_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_OP_H_
+#define TENSORFLOW_CORE_KERNELS_RANDOM_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/lib/random/random_distributions.h"
@@ -69,4 +69,4 @@ struct FillPhiloxRandom<SYCLDevice, Distribution> {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_RANDOM_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_RANDOM_OP_H_
diff --git a/tensorflow/core/kernels/random_poisson_op.h b/tensorflow/core/kernels/random_poisson_op.h
index 4e9fd62520..62ae01c16c 100644
--- a/tensorflow/core/kernels/random_poisson_op.h
+++ b/tensorflow/core/kernels/random_poisson_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
-#define TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RANDOM_POISSON_OP_H_
+#define TENSORFLOW_CORE_KERNELS_RANDOM_POISSON_OP_H_
namespace tensorflow {
@@ -28,4 +28,4 @@ struct PoissonFunctor;
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_RANDOM_POISSON_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_RANDOM_POISSON_OP_H_
diff --git a/tensorflow/core/kernels/range_sampler.h b/tensorflow/core/kernels/range_sampler.h
index 3010666598..ed160adfb4 100644
--- a/tensorflow/core/kernels/range_sampler.h
+++ b/tensorflow/core/kernels/range_sampler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
-#define TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_
+#define TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_
#include <vector>
@@ -249,4 +249,4 @@ class FixedUnigramSampler : public RangeSampler {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_RANGE_SAMPLER_H_
+#endif // TENSORFLOW_CORE_KERNELS_RANGE_SAMPLER_H_
diff --git a/tensorflow/core/kernels/range_sampler_test.cc b/tensorflow/core/kernels/range_sampler_test.cc
index 9020121169..3d49af7cb1 100644
--- a/tensorflow/core/kernels/range_sampler_test.cc
+++ b/tensorflow/core/kernels/range_sampler_test.cc
@@ -45,7 +45,7 @@ class RangeSamplerTest : public ::testing::Test {
// Using a fixed random seed to make the test deterministic.
random::PhiloxRandom philox(123, 17);
random::SimplePhilox rnd(&philox);
- sampler_->SampleBatch(&rnd, false, &a);
+ sampler_->SampleBatch(&rnd, false, absl::MakeSpan(a));
for (int i = 0; i < num_samples; i++) {
int64 val = a[i];
ASSERT_GE(val, 0);
@@ -251,8 +251,9 @@ TEST_F(RangeSamplerTest, All) {
extras[0] = 0;
extras[1] = batch_size - 1;
sampler_->SampleBatchGetExpectedCount(nullptr, // no random numbers needed
- false, &batch, &batch_expected, extras,
- &extras_expected);
+ false, absl::MakeSpan(batch),
+ absl::MakeSpan(batch_expected), extras,
+ absl::MakeSpan(extras_expected));
for (int i = 0; i < batch_size; i++) {
EXPECT_EQ(i, batch[i]);
EXPECT_EQ(1, batch_expected[i]);
@@ -281,17 +282,18 @@ TEST_F(RangeSamplerTest, Unique) {
std::vector<float> expected(range);
// Sample one batch and get the expected counts of all values
- sampler_->SampleBatchGetExpectedCount(
- &rnd, true, &batch, MutableArraySlice<float>(), all_values, &expected);
+ sampler_->SampleBatchGetExpectedCount(&rnd, true, absl::MakeSpan(batch),
+ MutableArraySlice<float>(), all_values,
+ absl::MakeSpan(expected));
// Check that all elements are unique
std::set<int64> s(batch.begin(), batch.end());
CHECK_EQ(batch_size, s.size());
for (int trial = 0; trial < num_batches; trial++) {
std::vector<float> trial_expected(range);
- sampler_->SampleBatchGetExpectedCount(&rnd, true, &batch,
- MutableArraySlice<float>(),
- all_values, &trial_expected);
+ sampler_->SampleBatchGetExpectedCount(
+ &rnd, true, absl::MakeSpan(batch), MutableArraySlice<float>(),
+ all_values, absl::MakeSpan(trial_expected));
for (int i = 0; i < range; i++) {
EXPECT_NEAR(expected[i], trial_expected[i], expected[i] * 0.5);
}
@@ -318,8 +320,8 @@ TEST_F(RangeSamplerTest, Avoid) {
// We expect to pick all elements of [0, 100) except the avoided two.
sampler_->SampleBatchGetExpectedCountAvoid(
- &rnd, true, &batch, MutableArraySlice<float>(), ArraySlice<int64>(),
- MutableArraySlice<float>(), avoided);
+ &rnd, true, absl::MakeSpan(batch), MutableArraySlice<float>(),
+ ArraySlice<int64>(), MutableArraySlice<float>(), avoided);
int sum = 0;
for (auto val : batch) {
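The test changes above replace pointer-to-vector output parameters with mutable views built by absl::MakeSpan. A minimal sketch of the pattern, using illustrative names rather than the RangeSampler API:

    #include <cstdint>
    #include <vector>
    #include "absl/types/span.h"

    // The callee receives a bounded, mutable view instead of a raw vector pointer.
    void FillWithZeros(absl::Span<int64_t> out) {
      for (int64_t& v : out) v = 0;
    }

    void Example() {
      std::vector<int64_t> batch(16);
      FillWithZeros(absl::MakeSpan(batch));  // was: FillWithZeros(&batch)
    }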
diff --git a/tensorflow/core/kernels/record_yielder.h b/tensorflow/core/kernels/record_yielder.h
index 34817ad51b..159b43b4cd 100644
--- a/tensorflow/core/kernels/record_yielder.h
+++ b/tensorflow/core/kernels/record_yielder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_RECORD_YIELDER_H_
-#define TENSORFLOW_KERNELS_RECORD_YIELDER_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
+#define TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
#include <atomic>
#include <random>
@@ -157,4 +157,4 @@ class RecordYielder {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_RECORD_YIELDER_H_
+#endif // TENSORFLOW_CORE_KERNELS_RECORD_YIELDER_H_
diff --git a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
index 9af4cc23b6..bb8254eaac 100644
--- a/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
+++ b/tensorflow/core/kernels/reduction_gpu_kernels.cu.h
@@ -13,16 +13,19 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
+#define TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
+
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_segmented_reduce.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
-#include "external/cub_archive/cub/warp/warp_reduce.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_segmented_reduce.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/warp/warp_reduce.cuh"
#include "cuda/include/cuComplex.h"
#include "tensorflow/core/kernels/reduction_ops.h"
#include "tensorflow/core/lib/core/bits.h"
@@ -1058,4 +1061,6 @@ struct ReduceFunctor<GPUDevice, Eigen::internal::OrReducer> {
} // namespace functor
} // namespace tensorflow
-#endif
+#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_GPU_KERNELS_CU_H_
diff --git a/tensorflow/core/kernels/reduction_ops.h b/tensorflow/core/kernels/reduction_ops.h
index e43d2828f3..eb264e0e5a 100644
--- a/tensorflow/core/kernels/reduction_ops.h
+++ b/tensorflow/core/kernels/reduction_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_REDUCTION_OPS_H_
-#define TENSORFLOW_KERNELS_REDUCTION_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
// Functor definitions for Reduction ops, must be compilable by nvcc.
@@ -79,4 +79,4 @@ struct ReduceFunctor {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_REDUCTION_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_H_
diff --git a/tensorflow/core/kernels/reduction_ops_common.h b/tensorflow/core/kernels/reduction_ops_common.h
index 03d6e82e01..d83e1c7d15 100644
--- a/tensorflow/core/kernels/reduction_ops_common.h
+++ b/tensorflow/core/kernels/reduction_ops_common.h
@@ -18,8 +18,8 @@ limitations under the License.
// is a header file because we split the various reduction ops into their
// own compilation units to get more parallelism in compilation.
-#ifndef TENSORFLOW_KERNELS_REDUCTION_OPS_COMMON_H_
-#define TENSORFLOW_KERNELS_REDUCTION_OPS_COMMON_H_
+#ifndef TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_H_
+#define TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_H_
#define EIGEN_USE_THREADS
@@ -277,4 +277,4 @@ struct ReduceFunctor<SYCLDevice, Reducer>
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_REDUCTION_OPS_COMMON_H_
+#endif // TENSORFLOW_CORE_KERNELS_REDUCTION_OPS_COMMON_H_
diff --git a/tensorflow/core/kernels/reduction_ops_max.cc b/tensorflow/core/kernels/reduction_ops_max.cc
index 9cf953f4bf..8bfa44b2d0 100644
--- a/tensorflow/core/kernels/reduction_ops_max.cc
+++ b/tensorflow/core/kernels/reduction_ops_max.cc
@@ -50,6 +50,8 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, Eigen::internal::MaxReducer<type>>);
+
+REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
REGISTER_GPU_KERNELS(int64);
diff --git a/tensorflow/core/kernels/regex_full_match_op.cc b/tensorflow/core/kernels/regex_full_match_op.cc
index 5863a2c8e4..7edaaad8f7 100644
--- a/tensorflow/core/kernels/regex_full_match_op.cc
+++ b/tensorflow/core/kernels/regex_full_match_op.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -56,4 +57,36 @@ class RegexFullMatchOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("RegexFullMatch").Device(DEVICE_CPU),
RegexFullMatchOp);
+class StaticRegexFullMatchOp : public OpKernel {
+ public:
+ explicit StaticRegexFullMatchOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string pattern;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
+ re_ = MakeUnique<RE2>(pattern);
+ OP_REQUIRES(ctx, re_->ok(),
+ errors::InvalidArgument("Invalid pattern: ", pattern,
+ ", error: ", re_->error()));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
+ const auto& input_flat = input_tensor->flat<string>();
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<bool>();
+ for (size_t i = 0; i < input_flat.size(); ++i) {
+ output_flat(i) = RE2::FullMatch(input_flat(i), *re_);
+ }
+ }
+
+ private:
+ std::unique_ptr<RE2> re_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StaticRegexFullMatch").Device(DEVICE_CPU),
+ StaticRegexFullMatchOp);
+
} // namespace tensorflow
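StaticRegexFullMatchOp takes the pattern as a node attribute, so the RE2 object is compiled once in the kernel constructor instead of on every Compute call, which is what a tensor-valued pattern forces RegexFullMatchOp to do. A rough sketch of that cost difference with plain RE2 (illustrative helpers, not TF code):

    #include <string>
    #include "re2/re2.h"

    // Compiles `pattern` into a temporary RE2 on every call.
    bool MatchPerCall(const std::string& s, const std::string& pattern) {
      return RE2::FullMatch(s, pattern);
    }

    // Reuses a regex compiled once up front.
    bool MatchPrecompiled(const std::string& s, const RE2& re) {
      return RE2::FullMatch(s, re);
    }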
diff --git a/tensorflow/core/kernels/regex_replace_op.cc b/tensorflow/core/kernels/regex_replace_op.cc
index 59ec854a79..a1b948891d 100644
--- a/tensorflow/core/kernels/regex_replace_op.cc
+++ b/tensorflow/core/kernels/regex_replace_op.cc
@@ -20,8 +20,43 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
+namespace {
+
+// Execute the specified regex using the given context.
+// Context requirements:
+// - "input" string Tensor at input_index=0
+// - "output" string Tensor at output_index=0
+Status InternalCompute(const RE2& match, const string& rewrite,
+ const bool replace_global, OpKernelContext* ctx) {
+ const Tensor* input_tensor;
+ TF_RETURN_IF_ERROR(ctx->input("input", &input_tensor));
+ Tensor* output_tensor;
+ std::unique_ptr<Tensor> maybe_forwarded =
+ ctx->forward_input(0 /*input_index*/, 0 /*output_index*/,
+ tensorflow::DT_STRING, input_tensor->shape(),
+ ctx->input_memory_type(0), ctx->input_alloc_attr(0));
+ if (maybe_forwarded) {
+ output_tensor = maybe_forwarded.get();
+ TF_RETURN_IF_ERROR(ctx->set_output("output", *output_tensor));
+ } else {
+ TF_RETURN_IF_ERROR(
+ ctx->allocate_output("output", input_tensor->shape(), &output_tensor));
+ output_tensor->flat<string>() = input_tensor->flat<string>();
+ }
+ auto output_flat = output_tensor->flat<string>();
+ for (size_t i = 0; i < output_flat.size(); ++i) {
+ if (replace_global) {
+ RE2::GlobalReplace(&output_flat(i), match, rewrite);
+ } else {
+ RE2::Replace(&output_flat(i), match, rewrite);
+ }
+ }
+ return Status::OK();
+}
+} // namespace
class RegexReplaceOp : public OpKernel {
public:
@@ -30,10 +65,6 @@ class RegexReplaceOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- const Tensor* input_tensor;
- OP_REQUIRES_OK(ctx, ctx->input("input", &input_tensor));
- const auto& input_flat = input_tensor->flat<string>();
-
const Tensor* pattern_tensor;
OP_REQUIRES_OK(ctx, ctx->input("pattern", &pattern_tensor));
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(pattern_tensor->shape()),
@@ -51,19 +82,7 @@ class RegexReplaceOp : public OpKernel {
errors::InvalidArgument("Rewrite must be scalar, but received ",
rewrite_tensor->shape().DebugString()));
const string rewrite = rewrite_tensor->flat<string>()(0);
-
- Tensor* output_tensor = nullptr;
- OP_REQUIRES_OK(ctx, ctx->allocate_output("output", input_tensor->shape(),
- &output_tensor));
- auto output_flat = output_tensor->flat<string>();
- for (size_t i = 0; i < input_flat.size(); ++i) {
- output_flat(i) = input_flat(i);
- if (replace_global_) {
- RE2::GlobalReplace(&output_flat(i), match, rewrite);
- } else {
- RE2::Replace(&output_flat(i), match, rewrite);
- }
- }
+ OP_REQUIRES_OK(ctx, InternalCompute(match, rewrite, replace_global_, ctx));
}
private:
@@ -73,4 +92,31 @@ class RegexReplaceOp : public OpKernel {
REGISTER_KERNEL_BUILDER(Name("RegexReplace").Device(DEVICE_CPU),
RegexReplaceOp);
+class StaticRegexReplaceOp : public OpKernel {
+ public:
+ explicit StaticRegexReplaceOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string pattern;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("pattern", &pattern));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("rewrite", &rewrite_str_));
+ re_ = MakeUnique<RE2>(pattern);
+ OP_REQUIRES(ctx, re_->ok(),
+ errors::InvalidArgument("Invalid pattern: ", pattern,
+ ", error: ", re_->error()));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("replace_global", &replace_global_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ OP_REQUIRES_OK(ctx,
+ InternalCompute(*re_, rewrite_str_, replace_global_, ctx));
+ }
+
+ private:
+ string rewrite_str_;
+ std::unique_ptr<RE2> re_;
+ bool replace_global_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StaticRegexReplace").Device(DEVICE_CPU),
+ StaticRegexReplaceOp);
+
} // namespace tensorflow
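InternalCompute above forwards the input buffer when it can and then rewrites each element in place; the replace_global flag selects between RE2's two replacement entry points. Their difference in isolation, shown as standalone RE2 usage rather than TF code:

    #include <string>
    #include "re2/re2.h"

    void Demo() {
      std::string first_only = "a-b-c";
      RE2::Replace(&first_only, "-", "+");   // "a+b-c": first match only
      std::string all = "a-b-c";
      RE2::GlobalReplace(&all, "-", "+");    // "a+b+c": every non-overlapping match
    }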
diff --git a/tensorflow/core/kernels/regex_replace_op_test.cc b/tensorflow/core/kernels/regex_replace_op_test.cc
new file mode 100644
index 0000000000..9691d4a89f
--- /dev/null
+++ b/tensorflow/core/kernels/regex_replace_op_test.cc
@@ -0,0 +1,137 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+// Test data from the TensorFlow README.md.
+const char* lines[] = {
+ "**TensorFlow** is an open source software library for numerical "
+ "computation using data flow graphs.",
+ "The graph nodes represent mathematical operations, while the graph edges "
+ "represent the multidimensional data arrays (tensors) that flow between "
+ "them.",
+ "This flexible architecture enables you to deploy computation to one or "
+ "more CPUs or GPUs in a desktop, server, or mobile device without "
+ "rewriting code.",
+ "TensorFlow also includes "
+ "[TensorBoard](https://www.tensorflow.org/guide/"
+ "summaries_and_tensorboard), a data visualization toolkit.",
+ "TensorFlow was originally developed by researchers and engineers working "
+ "on the Google Brain team within Google's Machine Intelligence Research "
+ "organization for the purposes of conducting machine learning and deep "
+ "neural networks research.",
+ "The system is general enough to be applicable in a wide variety of other "
+ "domains, as well.",
+ "TensorFlow provides stable Python API and C APIs as well as without API "
+ "backwards compatibility guarantee like C++, Go, Java, JavaScript and "
+ "Swift."};
+
+const char kRegExPattern[] = "\\p{P}";
+const char kRewrite[] = " ";
+
+Tensor GetTestTensor(int batch) {
+ const int sz = TF_ARRAYSIZE(lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = lines[i % sz];
+ }
+ return t;
+}
+
+Graph* SetupRegexReplaceGraph(const Tensor& input, const string& input_pattern,
+ const string& input_rewrite) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor pattern(DT_STRING, TensorShape({}));
+ pattern.flat<string>().setConstant(input_pattern);
+ Tensor rewrite(DT_STRING, TensorShape({}));
+ rewrite.flat<string>().setConstant(input_rewrite);
+
+ TF_CHECK_OK(NodeBuilder("regex_replace_op", "RegexReplace")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, pattern))
+ .Input(test::graph::Constant(g, rewrite))
+ .Attr("replace_global", true)
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_RegexReplace(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupRegexReplaceGraph(input, kRegExPattern, kRewrite);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_RegexReplace)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+
+Graph* SetupStaticGraph(const Tensor& input, const string& input_pattern,
+ const string& rewrite) {
+ Graph* g = new Graph(OpRegistry::Global());
+
+ TF_CHECK_OK(NodeBuilder("static_regex_replace_op", "StaticRegexReplace")
+ .Attr("pattern", input_pattern)
+ .Attr("rewrite", rewrite)
+ .Input(test::graph::Constant(g, input))
+ .Attr("replace_global", true)
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+void BM_StaticRegexReplace(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupStaticGraph(input, kRegExPattern, kRewrite);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_StaticRegexReplace)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/relu_op.cc b/tensorflow/core/kernels/relu_op.cc
index d52358737f..173fea37ed 100644
--- a/tensorflow/core/kernels/relu_op.cc
+++ b/tensorflow/core/kernels/relu_op.cc
@@ -124,6 +124,12 @@ namespace functor {
typename TTypes<T>::Tensor backprops); \
extern template struct SeluGrad<GPUDevice, T>;
+template <>
+void Relu<GPUDevice, qint8>::operator()(
+ const GPUDevice& d, typename TTypes<qint8>::ConstTensor features,
+ typename TTypes<qint8>::Tensor activations);
+extern template struct Relu<GPUDevice, qint8>;
+
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
} // namespace functor
@@ -157,6 +163,27 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
+template <typename Device>
+class ReluOp<Device, qint8>
+ : public UnaryElementWiseOp<qint8, ReluOp<Device, qint8>> {
+ public:
+ using UnaryElementWiseOp<qint8, ReluOp<Device, qint8>>::UnaryElementWiseOp;
+
+ void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
+ auto flat_input = input.flat<qint8>();
+ OP_REQUIRES(context, (flat_input.size() % 4) == 0,
+ errors::InvalidArgument(
+ "Tensor size must be a multiple of 4 for Relu<qint8>. Got ",
+ flat_input.size()));
+ functor::Relu<Device, qint8> func;
+ func(context->eigen_device<Device>(), flat_input, output->flat<qint8>());
+ }
+};
+
+REGISTER_KERNEL_BUILDER(
+ Name("Relu").Device(DEVICE_GPU).TypeConstraint<qint8>("T"),
+ ReluOp<GPUDevice, qint8>);
+
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
diff --git a/tensorflow/core/kernels/relu_op.h b/tensorflow/core/kernels/relu_op.h
index e712b02bd7..4775deeb61 100644
--- a/tensorflow/core/kernels/relu_op.h
+++ b/tensorflow/core/kernels/relu_op.h
@@ -15,8 +15,8 @@ limitations under the License.
// See docs in ../ops/nn_ops.cc.
-#ifndef TENSORFLOW_KERNELS_RELU_OP_H_
-#define TENSORFLOW_KERNELS_RELU_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RELU_OP_H_
+#define TENSORFLOW_CORE_KERNELS_RELU_OP_H_
#define EIGEN_USE_THREADS
@@ -219,4 +219,4 @@ void SeluGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
#undef EIGEN_USE_THREADS
-#endif // TENSORFLOW_KERNELS_RELU_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_RELU_OP_H_
diff --git a/tensorflow/core/kernels/relu_op_functor.h b/tensorflow/core/kernels/relu_op_functor.h
index 3bc5ba8a50..e564da335a 100644
--- a/tensorflow/core/kernels/relu_op_functor.h
+++ b/tensorflow/core/kernels/relu_op_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_RELU_OP_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_RELU_OP_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RELU_OP_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_RELU_OP_FUNCTOR_H_
// Functor definition for ReluOp and ReluGradOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -168,4 +168,4 @@ struct SeluGrad {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_RELU_OP_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_RELU_OP_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/relu_op_gpu.cu.cc b/tensorflow/core/kernels/relu_op_gpu.cu.cc
index 089ca8ed27..b9391517c1 100644
--- a/tensorflow/core/kernels/relu_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/relu_op_gpu.cu.cc
@@ -103,7 +103,7 @@ struct ReluGrad<Device, Eigen::half> {
int32 count = gradient.size();
if (count == 0) return;
int32 half2_count = Eigen::divup(count, 2);
- const int32 kThreadInBlock = 512;
+ constexpr int32 kThreadInBlock = 512;
CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
half2_count, d, ReluGradHalfKernel, 0, kThreadInBlock);
ReluGradHalfKernel<<<config.block_count, config.thread_per_block, 0,
@@ -111,6 +111,37 @@ struct ReluGrad<Device, Eigen::half> {
backprop.data(), count);
}
};
+
+__global__ void Relu_int8x4_kernel(int vect_count, const int32* input,
+ int32* output) {
+ CUDA_1D_KERNEL_LOOP(index, vect_count) {
+ output[index] = __vmaxs4(input[index], 0);
+ }
+}
+
+// Functor used by ReluOp to do the computations.
+template <typename Device>
+struct Relu<Device, qint8> {
+ // Computes Relu activation of 'input' containing int8 elements, whose buffer
+ // size should be a multiple of 4, and aligned to an int32* boundary.
+ // (Alignment should be guaranteed by the GPU tensor allocator).
+ // 'output' should have the same size as 'input'.
+ void operator()(const Device& d, typename TTypes<qint8>::ConstTensor input,
+ typename TTypes<qint8>::Tensor output) {
+ int32 count = input.size();
+ if (count == 0) return;
+
+ int32 vect_count = Eigen::divup(count, 4);
+ constexpr int32 kThreadInBlock = 512;
+ CudaLaunchConfig config = GetCudaLaunchConfigFixedBlockSize(
+ vect_count, d, Relu_int8x4_kernel, 0, kThreadInBlock);
+ Relu_int8x4_kernel<<<config.block_count, config.thread_per_block, 0,
+ d.stream()>>>(
+ vect_count, reinterpret_cast<const int32*>(input.data()),
+ reinterpret_cast<int32*>(output.data()));
+ }
+};
+
} // namespace functor
// Definition of the GPU implementations declared in relu_op.cc.
@@ -126,6 +157,8 @@ struct ReluGrad<Device, Eigen::half> {
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
+template struct functor::Relu<GPUDevice, qint8>;
+
} // end namespace tensorflow
#endif // GOOGLE_CUDA
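Relu_int8x4_kernel relies on the CUDA SIMD intrinsic __vmaxs4, which takes the per-byte signed maximum of two 32-bit words, so a single instruction applies ReLU to four packed qint8 lanes; this is also why the op requires the element count to be a multiple of 4. A scalar sketch of what the intrinsic computes for a zero second operand, for illustration only:

    #include <cstdint>

    // Equivalent of __vmaxs4(packed, 0): ReLU on four packed signed 8-bit lanes.
    int32_t ReluInt8x4Scalar(int32_t packed) {
      const uint32_t u = static_cast<uint32_t>(packed);
      uint32_t out = 0;
      for (int lane = 0; lane < 4; ++lane) {
        const int8_t v = static_cast<int8_t>((u >> (8 * lane)) & 0xFF);
        const uint8_t r = static_cast<uint8_t>(v > 0 ? v : 0);
        out |= static_cast<uint32_t>(r) << (8 * lane);
      }
      return static_cast<int32_t>(out);
    }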
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
index 194a711d98..26f107f940 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
@@ -47,7 +47,7 @@ std::unordered_set<string> BuildNodeSetFromNodeNamesAndPorts(
std::unordered_set<string> retval;
for (const string& node_name_and_port : node_names_and_ports) {
const TensorId tid = ParseTensorName(node_name_and_port);
- retval.emplace(std::string(tid.first));
+ retval.emplace(tid.first);
}
return retval;
}
@@ -64,7 +64,7 @@ Node* FindMutableNodeByName(const string& name, Graph* graph) {
const NodeDef* FindNodeDefByName(const string& input,
const GraphDef& graph_def) {
const TensorId tid = ParseTensorName(input);
- const string name = std::string(tid.first);
+ const string name = string(tid.first);
for (const NodeDef& node_def : graph_def.node()) {
if (node_def.name() == name) {
return &node_def;
@@ -423,7 +423,7 @@ RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap(
std::vector<DataType> data_types;
std::vector<TensorShape> shapes;
const TensorId tid = ParseTensorName(name_and_port);
- const string node_name = std::string(tid.first);
+ const string node_name(tid.first);
const int port = tid.second;
const NodeDef* node_def = FindNodeDefByName(node_name, graph_def);
CHECK_NOTNULL(node_def);
@@ -522,8 +522,7 @@ RemoteFusedGraphExecuteUtils::GetTensorShapeType(
const TensorShapeMap& tensor_shape_map, const string& node_name) {
if (node_name.find(':') != string::npos) {
const TensorId tid = ParseTensorName(node_name);
- return GetTensorShapeType(tensor_shape_map, std::string(tid.first),
- tid.second);
+ return GetTensorShapeType(tensor_shape_map, string(tid.first), tid.second);
} else {
return GetTensorShapeType(tensor_shape_map, node_name, 0);
}
@@ -570,7 +569,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto(
const TensorId tid = ParseTensorName(name);
CHECK_EQ(tensor_shape_map->count(name), 0);
tensor_shape_map->emplace(
- std::string(tid.first),
+ string(tid.first),
std::make_pair(tid.second,
std::make_pair(tensor.dtype(), tensor.shape())));
}
@@ -692,7 +691,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
std::vector<NodeBuilder::NodeOut> node_out_list;
for (const string& input : inputs) {
const TensorId tid = ParseTensorName(input);
- Node* node = FindMutableNodeByName(std::string(tid.first), graph);
+ Node* node = FindMutableNodeByName(string(tid.first), graph);
CHECK_NOTNULL(node);
node_out_list.emplace_back(node, tid.second);
}
@@ -848,7 +847,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
for (const string& subgraph_input : std::get<1>(cluster)) {
const TensorId tid = ParseTensorName(subgraph_input);
- const string subgraph_input_name = std::string(tid.first);
+ const string subgraph_input_name(tid.first);
const int subgraph_input_port = tid.second;
const NodeDef* node_def = FindNodeDefByName(subgraph_input_name, graph_def);
CHECK_NOTNULL(node_def);
@@ -895,7 +894,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
std::deque<const Node*> queue;
for (const string& output : border_outputs) {
const TensorId tid = ParseTensorName(output);
- const string& output_node_name = std::string(tid.first);
+ const string output_node_name(tid.first);
for (const Node* node : graph.nodes()) {
if (output_node_name == node->name()) {
queue.push_back(node);
@@ -975,7 +974,7 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
for (int j = 0; j < border_outputs.size(); ++j) {
const string& output = border_outputs.at(j);
const TensorId tid = ParseTensorName(output);
- const string output_name = std::string(tid.first);
+ const string output_name(tid.first);
Node* src_node = edge->src();
if (src_node != nullptr && src_node->name() == output_name &&
edge->src_output() == tid.second) {
@@ -995,12 +994,11 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
// RemoteFusedGraphExecuteOpNode
for (const string& output : outputs) {
const TensorId output_tid = ParseTensorName(output);
- const string output_name = std::string(output_tid.first);
+ const string output_name(output_tid.first);
for (size_t i = 0; i < border_outputs.size(); ++i) {
const TensorId subgraph_output_tid =
ParseTensorName(border_outputs.at(i));
- const string& subgraph_output_name =
- std::string(subgraph_output_tid.first);
+ const string subgraph_output_name(subgraph_output_tid.first);
if (output_name == subgraph_output_name) {
LOG(INFO) << "As graph output and subgraph output are same, "
<< "the graph output node is replaced by identity node";
@@ -1435,7 +1433,7 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
GraphDef* graph_def) {
const TensorId tid = ParseTensorName(input);
CHECK_EQ(0, tid.second);
- const string node_name = std::string(tid.first);
+ const string node_name(tid.first);
for (NodeDef& node : *graph_def->mutable_node()) {
if (node.name() != node_name) {
continue;
diff --git a/tensorflow/core/kernels/reshape_op.h b/tensorflow/core/kernels/reshape_op.h
index 5db2d148b9..7458ac75ca 100644
--- a/tensorflow/core/kernels/reshape_op.h
+++ b/tensorflow/core/kernels/reshape_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_RESHAPE_OP_H_
-#define TENSORFLOW_KERNELS_RESHAPE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_
#include <memory>
#include "tensorflow/core/framework/op_kernel.h"
@@ -121,4 +121,4 @@ class ReshapeOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_RESHAPE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_RESHAPE_OP_H_
diff --git a/tensorflow/core/kernels/resize_bilinear_op.cc b/tensorflow/core/kernels/resize_bilinear_op.cc
index dde59e8e74..f10c9a19a7 100644
--- a/tensorflow/core/kernels/resize_bilinear_op.cc
+++ b/tensorflow/core/kernels/resize_bilinear_op.cc
@@ -277,13 +277,13 @@ struct ResizeBilinearGrad<CPUDevice, T> {
typename TTypes<float, 4>::ConstTensor input_grad,
const float height_scale, const float width_scale,
typename TTypes<T, 4>::Tensor output_grad) {
- const int batch = output_grad.dimension(0);
- const int64 original_height = output_grad.dimension(1);
- const int64 original_width = output_grad.dimension(2);
- const int channels = output_grad.dimension(3);
+ const Eigen::Index batch = output_grad.dimension(0);
+ const Eigen::Index original_height = output_grad.dimension(1);
+ const Eigen::Index original_width = output_grad.dimension(2);
+ const Eigen::Index channels = output_grad.dimension(3);
- const int64 resized_height = input_grad.dimension(1);
- const int64 resized_width = input_grad.dimension(2);
+ const Eigen::Index resized_height = input_grad.dimension(1);
+ const Eigen::Index resized_width = input_grad.dimension(2);
output_grad.setZero();
@@ -294,22 +294,24 @@ struct ResizeBilinearGrad<CPUDevice, T> {
// + top_right * (1 - y) * x
// + bottom_left * y * (1 - x)
// + bottom_right * y * x
- for (int64 b = 0; b < batch; ++b) {
- for (int64 y = 0; y < resized_height; ++y) {
+ for (Eigen::Index b = 0; b < batch; ++b) {
+ for (Eigen::Index y = 0; y < resized_height; ++y) {
const float in_y = y * height_scale;
- const int64 top_y_index = static_cast<int64>(floorf(in_y));
- const int64 bottom_y_index =
- std::min(static_cast<int64>(ceilf(in_y)), original_height - 1);
+ const Eigen::Index top_y_index =
+ static_cast<Eigen::Index>(floorf(in_y));
+ const Eigen::Index bottom_y_index = std::min(
+ static_cast<Eigen::Index>(ceilf(in_y)), original_height - 1);
const float y_lerp = in_y - top_y_index;
const float inverse_y_lerp = (1.0f - y_lerp);
- for (int64 x = 0; x < resized_width; ++x) {
+ for (Eigen::Index x = 0; x < resized_width; ++x) {
const float in_x = x * width_scale;
- const int64 left_x_index = static_cast<int64>(floorf(in_x));
- const int64 right_x_index =
- std::min(static_cast<int64>(ceilf(in_x)), original_width - 1);
+ const Eigen::Index left_x_index =
+ static_cast<Eigen::Index>(floorf(in_x));
+ const Eigen::Index right_x_index = std::min(
+ static_cast<Eigen::Index>(ceilf(in_x)), original_width - 1);
const float x_lerp = in_x - left_x_index;
const float inverse_x_lerp = (1.0f - x_lerp);
- for (int64 c = 0; c < channels; ++c) {
+ for (Eigen::Index c = 0; c < channels; ++c) {
output_grad(b, top_y_index, left_x_index, c) +=
T(input_grad(b, y, x, c) * inverse_y_lerp * inverse_x_lerp);
output_grad(b, top_y_index, right_x_index, c) +=
diff --git a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
index 8ec526c2b2..e985d3e5a5 100644
--- a/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
+++ b/tensorflow/core/kernels/resize_nearest_neighbor_op.cc
@@ -88,25 +88,27 @@ struct ResizeNearestNeighbor<CPUDevice, T, align_corners> {
bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
const float height_scale, const float width_scale,
typename TTypes<T, 4>::Tensor output) {
- const int batch_size = input.dimension(0);
- const int64 in_height = input.dimension(1);
- const int64 in_width = input.dimension(2);
- const int channels = input.dimension(3);
-
- const int64 out_height = output.dimension(1);
- const int64 out_width = output.dimension(2);
-
- for (int b = 0; b < batch_size; ++b) {
- for (int y = 0; y < out_height; ++y) {
- const int64 in_y = std::min(
- (align_corners) ? static_cast<int64>(roundf(y * height_scale))
- : static_cast<int64>(floorf(y * height_scale)),
- in_height - 1);
- for (int x = 0; x < out_width; ++x) {
- const int64 in_x = std::min(
- (align_corners) ? static_cast<int64>(roundf(x * width_scale))
- : static_cast<int64>(floorf(x * width_scale)),
- in_width - 1);
+ const Eigen::Index batch_size = input.dimension(0);
+ const Eigen::Index in_height = input.dimension(1);
+ const Eigen::Index in_width = input.dimension(2);
+ const Eigen::Index channels = input.dimension(3);
+
+ const Eigen::Index out_height = output.dimension(1);
+ const Eigen::Index out_width = output.dimension(2);
+
+ for (Eigen::Index b = 0; b < batch_size; ++b) {
+ for (Eigen::Index y = 0; y < out_height; ++y) {
+ const Eigen::Index in_y =
+ std::min((align_corners)
+ ? static_cast<Eigen::Index>(roundf(y * height_scale))
+ : static_cast<Eigen::Index>(floorf(y * height_scale)),
+ in_height - 1);
+ for (Eigen::Index x = 0; x < out_width; ++x) {
+ const Eigen::Index in_x =
+ std::min((align_corners)
+ ? static_cast<Eigen::Index>(roundf(x * width_scale))
+ : static_cast<Eigen::Index>(floorf(x * width_scale)),
+ in_width - 1);
std::copy_n(&input(b, in_y, in_x, 0), channels, &output(b, y, x, 0));
}
}
@@ -199,28 +201,29 @@ struct ResizeNearestNeighborGrad<CPUDevice, T, align_corners> {
bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor input,
const float height_scale, const float width_scale,
typename TTypes<T, 4>::Tensor output) {
- const int batch_size = input.dimension(0);
- const int64 in_height = input.dimension(1);
- const int64 in_width = input.dimension(2);
- const int channels = input.dimension(3);
+ const Eigen::Index batch_size = input.dimension(0);
+ const Eigen::Index in_height = input.dimension(1);
+ const Eigen::Index in_width = input.dimension(2);
+ const Eigen::Index channels = input.dimension(3);
- const int64 out_height = output.dimension(1);
- const int64 out_width = output.dimension(2);
+ const Eigen::Index out_height = output.dimension(1);
+ const Eigen::Index out_width = output.dimension(2);
output.setZero();
- for (int y = 0; y < in_height; ++y) {
- const int64 out_y = std::min(
- (align_corners) ? static_cast<int64>(roundf(y * height_scale))
- : static_cast<int64>(floorf(y * height_scale)),
+ for (Eigen::Index y = 0; y < in_height; ++y) {
+ const Eigen::Index out_y = std::min(
+ (align_corners) ? static_cast<Eigen::Index>(roundf(y * height_scale))
+ : static_cast<Eigen::Index>(floorf(y * height_scale)),
out_height - 1);
- for (int x = 0; x < in_width; ++x) {
- const int64 out_x = std::min(
- (align_corners) ? static_cast<int64>(roundf(x * width_scale))
- : static_cast<int64>(floorf(x * width_scale)),
- out_width - 1);
- for (int b = 0; b < batch_size; ++b) {
- for (int c = 0; c < channels; ++c) {
+ for (Eigen::Index x = 0; x < in_width; ++x) {
+ const Eigen::Index out_x =
+ std::min((align_corners)
+ ? static_cast<Eigen::Index>(roundf(x * width_scale))
+ : static_cast<Eigen::Index>(floorf(x * width_scale)),
+ out_width - 1);
+ for (Eigen::Index b = 0; b < batch_size; ++b) {
+ for (Eigen::Index c = 0; c < channels; ++c) {
output(b, out_y, out_x, c) += input(b, y, x, c);
}
}
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 115a8eb251..678d675c4a 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -51,7 +51,9 @@ limitations under the License.
#define EIGEN_USE_GPU
#endif
-#include "tensorflow/core/kernels/resource_variable_ops.h"
+#include <memory>
+#include <vector>
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
@@ -60,10 +62,12 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/gather_functor.h"
+#include "tensorflow/core/kernels/resource_variable_ops.h"
#include "tensorflow/core/kernels/scatter_functor.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -72,6 +76,8 @@ limitations under the License.
namespace tensorflow {
REGISTER_RESOURCE_HANDLE_KERNEL(Var);
+REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp").Device(DEVICE_CPU),
+ ResourceHandlesOp<Var>);
ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
@@ -79,7 +85,7 @@ ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {
void ReadVariableOp::Compute(OpKernelContext* ctx) {
Var* variable = nullptr;
- ResourceHandle handle = HandleFromInput(ctx, 0);
+ const ResourceHandle& handle = HandleFromInput(ctx, 0);
const auto status = LookupResource(ctx, handle, &variable);
OP_REQUIRES(ctx, status.ok(),
errors::FailedPrecondition(
@@ -101,13 +107,58 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) {
ctx->set_output(0, t);
}
+ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) {
+ int n;
+ OP_REQUIRES_OK(c, c->GetAttr("N", &n));
+ OP_REQUIRES_OK(c, c->GetAttr("dtypes", &dtypes_));
+ OP_REQUIRES(c, n == dtypes_.size(),
+ errors::InvalidArgument(
+ "Mismatched number of arguments to ReadVariablesOp (", n,
+ " vs. ", dtypes_.size(), ")"));
+}
+
+void ReadVariablesOp::Compute(OpKernelContext* ctx) {
+ std::vector<std::unique_ptr<Var, core::RefCountDeleter>> variables(
+ dtypes_.size());
+ std::vector<const ResourceHandle*> handles(dtypes_.size());
+ for (size_t i = 0; i < dtypes_.size(); ++i) {
+ handles[i] = &HandleFromInput(ctx, i);
+ }
+ const auto status = LookupResources(ctx, handles, &variables);
+ OP_REQUIRES(ctx, status.ok(),
+ errors::FailedPrecondition(
+ "Error while reading resource variable. This could mean that "
+ "the variable was uninitialized. ",
+ status.ToString()));
+
+ for (size_t i = 0; i < dtypes_.size(); ++i) {
+ // We're acquiring a reference to the underlying buffer while
+ // holding a shared lock to guarantee ordering of reads and
+ // writes.
+ tf_shared_lock ml(*variables[i]->mu());
+ const Tensor& t = *variables[i]->tensor();
+ OP_REQUIRES(ctx, dtypes_[i] == t.dtype(),
+ errors::InvalidArgument(
+ "Trying to read variable ", handles[i]->name(),
+ " from Container: ", handles[i]->container(),
+ " with wrong dtype. Expected ", DataTypeString(dtypes_[i]),
+ " got ", DataTypeString(t.dtype())));
+ ctx->set_output(i, t);
+ }
+}
+
REGISTER_KERNEL_BUILDER(Name("ReadVariableOp").Device(DEVICE_CPU),
ReadVariableOp);
+REGISTER_KERNEL_BUILDER(Name("_ReadVariablesOp").Device(DEVICE_CPU),
+ ReadVariablesOp);
#if GOOGLE_CUDA
REGISTER_KERNEL_BUILDER(
Name("ReadVariableOp").Device(DEVICE_GPU).HostMemory("resource"),
ReadVariableOp);
+REGISTER_KERNEL_BUILDER(
+ Name("_ReadVariablesOp").Device(DEVICE_GPU).HostMemory("resources"),
+ ReadVariablesOp);
#define REGISTER_GPU_KERNELS(type) \
namespace functor { \
@@ -122,11 +173,20 @@ REGISTER_KERNEL_BUILDER(
.HostMemory("resource") \
.TypeConstraint<type>("dtype"), \
ResourceHandleOp<Var>)
-
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
+
+REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
+ .Device(DEVICE_GPU)
+ .HostMemory("resources")
+ .TypeConstraint("dtypes",
+ {DT_INT64, DT_COMPLEX64,
+ DT_COMPLEX128, DT_HALF, DT_FLOAT,
+ DT_DOUBLE, DT_BOOL, DT_VARIANT}),
+ ResourceHandlesOp<Var>);
+
#endif // GOOGLE_CUDA
template <typename T>
@@ -232,12 +292,12 @@ class AssignVariableOp : public OpKernel {
return Status::OK();
}));
core::ScopedUnref s(variable);
+ mutex_lock ml(*variable->mu());
OP_REQUIRES(context, variable->tensor()->dtype() == dtype_,
errors::InvalidArgument(
"Trying to assign variable with wrong dtype. Expected ",
DataTypeString(variable->tensor()->dtype()), " got ",
DataTypeString(dtype_)));
- mutex_lock ml(*variable->mu());
variable->is_initialized = true;
*variable->tensor() = value;
}
@@ -268,11 +328,6 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
return Status::OK();
}));
core::ScopedUnref s(variable);
- OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
- errors::InvalidArgument(
- "Trying to assign variable with wrong dtype. Expected ",
- DataTypeString(variable->tensor()->dtype()), " got ",
- DataTypeString(DT_VARIANT)));
// For purposes of forwarding DT_VARIANT, we want the least
// restrictive attr; we already know the input is on host.
@@ -293,6 +348,11 @@ class AssignVariableOp<Device, Variant> : public OpKernel {
attr);
mutex_lock ml(*variable->mu());
+ OP_REQUIRES(context, variable->tensor()->dtype() == DT_VARIANT,
+ errors::InvalidArgument(
+ "Trying to assign variable with wrong dtype. Expected ",
+ DataTypeString(variable->tensor()->dtype()), " got ",
+ DataTypeString(DT_VARIANT)));
variable->is_initialized = true;
*variable->tensor() = Tensor(DT_VARIANT, value.shape());
@@ -366,6 +426,12 @@ class AssignUpdateVariableOp : public OpKernel {
// ADD if value's refcount was 1.
mutex_lock ml(*variable->mu());
Tensor* var_tensor = variable->tensor();
+ OP_REQUIRES(context, var_tensor->shape().IsSameSize(value.shape()),
+ errors::InvalidArgument("Cannot update variable with shape ",
+ var_tensor->shape().DebugString(),
+ " using a Tensor with shape ",
+ value.shape().DebugString(),
+ ", shapes must be equal."));
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, var_tensor));
functor::DenseUpdate<Device, T, Op> update_functor;
diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h
index 9b60106f13..cffb732c38 100644
--- a/tensorflow/core/kernels/resource_variable_ops.h
+++ b/tensorflow/core/kernels/resource_variable_ops.h
@@ -28,6 +28,16 @@ class ReadVariableOp : public OpKernel {
DataType dtype_;
};
+class ReadVariablesOp : public OpKernel {
+ public:
+ explicit ReadVariablesOp(OpKernelConstruction* c);
+ void Compute(OpKernelContext* ctx) override;
+ bool IsExpensive() override { return false; }
+
+ private:
+ DataTypeVector dtypes_;
+};
+
class DestroyResourceOp : public OpKernel {
public:
explicit DestroyResourceOp(OpKernelConstruction* ctx);
diff --git a/tensorflow/core/kernels/reverse_op.h b/tensorflow/core/kernels/reverse_op.h
index 934f0277a9..44e7967c5d 100644
--- a/tensorflow/core/kernels/reverse_op.h
+++ b/tensorflow/core/kernels/reverse_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_REVERSE_OP_H_
-#define TENSORFLOW_KERNELS_REVERSE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_REVERSE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_REVERSE_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -45,4 +45,4 @@ struct Reverse<Device, T, 0> {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_MIRROR_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_REVERSE_OP_H_
diff --git a/tensorflow/core/kernels/reverse_sequence_op.cc b/tensorflow/core/kernels/reverse_sequence_op.cc
index 15a707a9c6..cded417986 100644
--- a/tensorflow/core/kernels/reverse_sequence_op.cc
+++ b/tensorflow/core/kernels/reverse_sequence_op.cc
@@ -64,7 +64,7 @@ void CheckErrors(OpKernelContext* context, int batch_dim, int seq_dim) {
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
"), ", "(", seq_lens.NumElements(),
- " vs. ", input.dim_size(batch_dim)));
+ " vs. ", input.dim_size(batch_dim), ")"));
for (size_t d = 0; d < seq_lens_vec.size(); ++d) {
OP_REQUIRES(context, seq_lens_vec[d] >= 0,
@@ -91,7 +91,7 @@ void CheckErrorsGPU(OpKernelContext* context, int batch_dim, int seq_dim) {
OP_REQUIRES(context, seq_lens.NumElements() == input.dim_size(batch_dim),
errors::InvalidArgument("len(seq_lens) != input.dims(", batch_dim,
"), ", "(", seq_lens.NumElements(),
- " vs. ", input.dim_size(batch_dim)));
+ " vs. ", input.dim_size(batch_dim), ")"));
}
template <>
@@ -127,6 +127,7 @@ class ReverseSequenceOp : public OpKernel {
auto seq_lens_t = seq_lens.vec<Tlen>();
CheckErrors<Device, Tlen>(context, batch_dim_, seq_dim_);
+ if (!context->status().ok()) return;
const int input_dims = input.dims();
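The added `if (!context->status().ok()) return;` matters because OP_REQUIRES only returns from the function it appears in: when validation lives in a helper such as CheckErrors, a failure is recorded on the context but the caller keeps running unless it re-checks. The pattern in rough outline (a sketch, not the actual macro expansion):

    // Sketch: a validation helper can only abort itself, so Compute() must
    // re-check the context status before touching the unvalidated inputs.
    void ValidateArgs(OpKernelContext* context) {
      OP_REQUIRES(context, /* some precondition */ true,
                  errors::InvalidArgument("bad argument"));
      // A failed OP_REQUIRES returns from ValidateArgs, not from its caller.
    }

    void Compute(OpKernelContext* context) {
      ValidateArgs(context);
      if (!context->status().ok()) return;  // propagate the helper's failure
      // ... safe to proceed ...
    }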
diff --git a/tensorflow/core/kernels/reverse_sequence_op.h b/tensorflow/core/kernels/reverse_sequence_op.h
index 8ccd32ea16..d6ba2781a9 100644
--- a/tensorflow/core/kernels/reverse_sequence_op.h
+++ b/tensorflow/core/kernels/reverse_sequence_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_
-#define TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_REVERSE_SEQUENCE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_REVERSE_SEQUENCE_OP_H_
// Generator definition for ReverseSequenceOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -75,4 +75,4 @@ struct ReverseSequence {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_REVERSE_SEQUENCE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_REVERSE_SEQUENCE_OP_H_
diff --git a/tensorflow/core/kernels/save_restore_tensor.cc b/tensorflow/core/kernels/save_restore_tensor.cc
index e335e38bdc..82546d581a 100644
--- a/tensorflow/core/kernels/save_restore_tensor.cc
+++ b/tensorflow/core/kernels/save_restore_tensor.cc
@@ -161,9 +161,12 @@ void RestoreTensor(OpKernelContext* context,
// If we cannot find a cached reader we will allocate our own.
std::unique_ptr<checkpoint::TensorSliceReader> allocated_reader;
- const checkpoint::TensorSliceReader* reader =
- context->slice_reader_cache()->GetReader(file_pattern, open_func,
- preferred_shard);
+ const checkpoint::TensorSliceReader* reader = nullptr;
+
+ if (context->slice_reader_cache()) {
+ reader = context->slice_reader_cache()->GetReader(file_pattern, open_func,
+ preferred_shard);
+ }
if (!reader) {
allocated_reader.reset(new checkpoint::TensorSliceReader(
file_pattern, open_func, preferred_shard));
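// A minimal standalone sketch of the fallback pattern above: prefer a reader
// obtained from an optional cache, and only allocate (and own) one when no
// cache is available. ReaderCache and Reader are illustrative stand-ins, not
// TensorFlow's checkpoint classes.
#include <iostream>
#include <memory>
#include <string>

struct Reader {
  explicit Reader(std::string p) : pattern(std::move(p)) {}
  std::string pattern;
};

struct ReaderCache {
  Reader* Get(const std::string& pattern) {
    if (!cached) cached = std::make_unique<Reader>(pattern);
    return cached.get();
  }
  std::unique_ptr<Reader> cached;
};

void Restore(ReaderCache* cache, const std::string& pattern) {
  std::unique_ptr<Reader> allocated;  // Owns the reader only if we create it.
  Reader* reader = nullptr;
  if (cache != nullptr) {             // The null check the hunk introduces.
    reader = cache->Get(pattern);
  }
  if (reader == nullptr) {
    allocated = std::make_unique<Reader>(pattern);
    reader = allocated.get();
  }
  std::cout << "reading from " << reader->pattern << "\n";
}

int main() {
  Restore(nullptr, "ckpt-*");  // Works even when no cache is supplied.
  ReaderCache cache;
  Restore(&cache, "ckpt-*");   // Reuses the cached reader.
}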
diff --git a/tensorflow/core/kernels/save_restore_tensor.h b/tensorflow/core/kernels/save_restore_tensor.h
index 5b74b586e8..be7f4b889e 100644
--- a/tensorflow/core/kernels/save_restore_tensor.h
+++ b/tensorflow/core/kernels/save_restore_tensor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SAVE_RESTORE_TENSOR_H_
-#define TENSORFLOW_KERNELS_SAVE_RESTORE_TENSOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_
+#define TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_
#include "tensorflow/core/util/tensor_slice_reader.h"
#include "tensorflow/core/util/tensor_slice_writer.h"
@@ -70,4 +70,4 @@ Status RestoreTensorsV2(OpKernelContext* context, const Tensor& prefix,
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SAVE_RESTORE_TENSOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_SAVE_RESTORE_TENSOR_H_
diff --git a/tensorflow/core/kernels/save_restore_v2_ops.cc b/tensorflow/core/kernels/save_restore_v2_ops.cc
index ab4de6c815..180eb3ca34 100644
--- a/tensorflow/core/kernels/save_restore_v2_ops.cc
+++ b/tensorflow/core/kernels/save_restore_v2_ops.cc
@@ -220,9 +220,9 @@ class MergeV2Checkpoints : public OpKernel {
context, tensorflow::MergeBundles(env, input_prefixes, merged_prefix));
if (delete_old_dirs_) {
- const string& merged_dir = std::string(io::Dirname(merged_prefix));
+ const string merged_dir(io::Dirname(merged_prefix));
for (const string& input_prefix : input_prefixes) {
- const string& dirname = std::string(io::Dirname(input_prefix));
+ const string dirname(io::Dirname(input_prefix));
if (dirname == merged_dir) continue;
Status status = env->DeleteDir(dirname);
// For sharded save, only the first delete will go through and all
diff --git a/tensorflow/core/kernels/scan_ops.h b/tensorflow/core/kernels/scan_ops.h
index 1a1f71d722..13831bb377 100644
--- a/tensorflow/core/kernels/scan_ops.h
+++ b/tensorflow/core/kernels/scan_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SCAN_OPS_H_
-#define TENSORFLOW_KERNELS_SCAN_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -43,4 +43,4 @@ struct Scan {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SCAN_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_SCAN_OPS_H_
diff --git a/tensorflow/core/kernels/scatter_functor.h b/tensorflow/core/kernels/scatter_functor.h
index ebaa2bd9c6..2d43bde23f 100644
--- a/tensorflow/core/kernels/scatter_functor.h
+++ b/tensorflow/core/kernels/scatter_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SCATTER_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_SCATTER_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_
#include <type_traits>
@@ -488,4 +488,4 @@ struct ScatterScalarFunctorSYCL {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SCATTER_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/scatter_functor_gpu.cu.h b/tensorflow/core/kernels/scatter_functor_gpu.cu.h
index 70809e4dcf..057755a05c 100644
--- a/tensorflow/core/kernels/scatter_functor_gpu.cu.h
+++ b/tensorflow/core/kernels/scatter_functor_gpu.cu.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_
-#define TENSORFLOW_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_
+#define TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_
#if GOOGLE_CUDA
@@ -161,4 +161,4 @@ struct ScatterScalarFunctor<GPUDevice, T, Index, op> {
#endif // GOOGLE_CUDA
-#endif // TENSORFLOW_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_
+#endif // TENSORFLOW_CORE_KERNELS_SCATTER_FUNCTOR_GPU_CU_H_
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index e0194605ce..2f8aede427 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -145,6 +145,7 @@ class ScatterNdUpdateOp : public OpKernel {
if (dtype_ == DT_RESOURCE) {
Var* v;
OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
+ core::ScopedUnref scoped_unref(v);
mutex_lock m(*v->mu());
DoCompute(c);
} else if (use_exclusive_lock_) {
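// A minimal standalone sketch of the ScopedUnref pattern used above: a
// resource lookup hands back an owned reference, and an RAII guard guarantees
// the matching Unref() on every exit path. RefCounted and ScopedUnref here
// are simplified stand-ins for the TensorFlow classes of the same idea.
#include <cassert>

class RefCounted {
 public:
  void Ref() { ++count_; }
  void Unref() {
    assert(count_ > 0);
    if (--count_ == 0) delete this;
  }

 protected:
  virtual ~RefCounted() = default;

 private:
  int count_ = 1;  // Starts owned by whoever created it.
};

class ScopedUnref {
 public:
  explicit ScopedUnref(RefCounted* obj) : obj_(obj) {}
  ~ScopedUnref() {
    if (obj_ != nullptr) obj_->Unref();
  }
  ScopedUnref(const ScopedUnref&) = delete;
  ScopedUnref& operator=(const ScopedUnref&) = delete;

 private:
  RefCounted* obj_;
};

class Var : public RefCounted {};

int main() {
  Var* v = new Var;        // Stands in for LookupResource handing out a ref.
  {
    ScopedUnref unref(v);  // Releases the reference even on early returns.
    // ... mutate the variable under its lock ...
  }                        // Reference dropped here (and the object deleted).
}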
diff --git a/tensorflow/core/kernels/sdca_internal.cc b/tensorflow/core/kernels/sdca_internal.cc
index 1c071d3d41..a8e9b3261c 100644
--- a/tensorflow/core/kernels/sdca_internal.cc
+++ b/tensorflow/core/kernels/sdca_internal.cc
@@ -251,7 +251,7 @@ Status Examples::SampleAdaptiveProbabilities(
num_weight_vectors);
const double kappa = example_state_data(example_id, 0) +
loss_updater->PrimalLossDerivative(
- example_statistics.wx[0], label, example_weight);
+ example_statistics.wx[0], label, 1.0);
probabilities_[example_id] = example_weight *
sqrt(examples_[example_id].squared_norm_ +
regularization.symmetric_l2() *
diff --git a/tensorflow/core/kernels/sdca_ops.cc b/tensorflow/core/kernels/sdca_ops.cc
index 05c835ebc4..3bd4168dc7 100644
--- a/tensorflow/core/kernels/sdca_ops.cc
+++ b/tensorflow/core/kernels/sdca_ops.cc
@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/kernels/hinge-loss.h"
#include "tensorflow/core/kernels/logistic-loss.h"
#include "tensorflow/core/kernels/loss.h"
+#include "tensorflow/core/kernels/poisson-loss.h"
#include "tensorflow/core/kernels/sdca_internal.h"
#include "tensorflow/core/kernels/smooth-hinge-loss.h"
#include "tensorflow/core/kernels/squared-loss.h"
@@ -75,6 +76,8 @@ struct ComputeOptions {
loss_updater.reset(new HingeLossUpdater);
} else if (loss_type == "smooth_hinge_loss") {
loss_updater.reset(new SmoothHingeLossUpdater);
+ } else if (loss_type == "poisson_loss") {
+ loss_updater.reset(new PoissonLossUpdater);
} else {
OP_REQUIRES(
context, false,
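// A minimal standalone sketch of the loss-type dispatch the hunk above
// extends, written as a string-to-factory map rather than the if/else chain
// in the kernel; adding "poisson_loss" becomes one more entry. The class
// names below are illustrative, not the TensorFlow loss updaters themselves.
#include <functional>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>

struct DualLossUpdater {
  virtual ~DualLossUpdater() = default;
  virtual const char* Name() const = 0;
};
struct LogisticLoss : DualLossUpdater {
  const char* Name() const override { return "logistic_loss"; }
};
struct PoissonLoss : DualLossUpdater {
  const char* Name() const override { return "poisson_loss"; }
};

std::unique_ptr<DualLossUpdater> MakeLossUpdater(const std::string& loss_type) {
  static const std::unordered_map<
      std::string, std::function<std::unique_ptr<DualLossUpdater>()>>
      factories = {
          {"logistic_loss", [] { return std::make_unique<LogisticLoss>(); }},
          {"poisson_loss", [] { return std::make_unique<PoissonLoss>(); }},
      };
  auto it = factories.find(loss_type);
  if (it == factories.end()) {
    throw std::invalid_argument("Unsupported loss type: " + loss_type);
  }
  return it->second();
}

int main() { std::cout << MakeLossUpdater("poisson_loss")->Name() << "\n"; }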
diff --git a/tensorflow/core/kernels/searchsorted_op.cc b/tensorflow/core/kernels/searchsorted_op.cc
new file mode 100644
index 0000000000..dc627ac77a
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op.cc
@@ -0,0 +1,249 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/kernels/searchsorted_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+typedef Eigen::ThreadPoolDevice CPUDevice;
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace functor {
+template <typename T, typename OutType>
+struct UpperBoundFunctor<CPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+ for (int b = 0; b < batch_size; ++b) {
+ const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
+ OutType* output_ptr = output->data() + b * num_values;
+ for (int i = 0; i < num_values; ++i) {
+ output_ptr[i] =
+ std::upper_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+ values(i + b * num_values)) -
+ sorted_inputs_ptr;
+ }
+ }
+
+ return Status::OK();
+ }
+};
+
+template <typename T, typename OutType>
+struct LowerBoundFunctor<CPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ // TODO(eriche): If anyone ever needs this to be faster, we can multithread.
+ for (int b = 0; b < batch_size; ++b) {
+ const T* sorted_inputs_ptr = sorted_inputs.data() + b * num_inputs;
+ OutType* output_ptr = output->data() + b * num_values;
+ for (int i = 0; i < num_values; ++i) {
+ output_ptr[i] =
+ std::lower_bound(sorted_inputs_ptr, sorted_inputs_ptr + num_inputs,
+ values(i + b * num_values)) -
+ sorted_inputs_ptr;
+ }
+ }
+
+ return Status::OK();
+ }
+};
+} // namespace functor
+
+template <typename Device, typename T, typename OutType>
+class UpperBoundOp : public OpKernel {
+ public:
+ explicit UpperBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& sorted_inputs_t = ctx->input(0);
+ const Tensor& values_t = ctx->input(1);
+
+ // must have same batch dim_size for both
+ OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+ Status(error::INVALID_ARGUMENT,
+ "Leading dim_size of both tensors must match."));
+
+ // this is required because we do indexing in int32 on the GPU
+ OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+ Status(error::INVALID_ARGUMENT,
+ "values tensor size must less than INT_MAX"));
+
+ Tensor* output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+ if (output_t->dtype() == DT_INT32) {
+ OP_REQUIRES(ctx,
+ FastBoundsCheck(sorted_inputs_t.dim_size(1),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("trailing dim_size must less than "
+ "INT_MAX for int32 output type, was ",
+ sorted_inputs_t.dim_size(1)));
+ }
+
+ auto output = output_t->template flat<OutType>();
+ const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+ const auto values = values_t.template flat<T>();
+ OP_REQUIRES_OK(
+ ctx, functor::UpperBoundFunctor<Device, T, OutType>::Compute(
+ ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+ sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+ }
+};
+
+template <typename Device, typename T, typename OutType>
+class LowerBoundOp : public OpKernel {
+ public:
+ explicit LowerBoundOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+
+ void Compute(OpKernelContext* ctx) override {
+ const Tensor& sorted_inputs_t = ctx->input(0);
+ const Tensor& values_t = ctx->input(1);
+
+ // must have same batch dim_size for both
+ OP_REQUIRES(ctx, sorted_inputs_t.dim_size(0) == values_t.dim_size(0),
+ Status(error::INVALID_ARGUMENT,
+ "Leading dim_size of both tensors must match."));
+
+ // this is required because we do indexing in int32 on the GPU
+ OP_REQUIRES(ctx, values_t.NumElements() < std::numeric_limits<int>::max(),
+ Status(error::INVALID_ARGUMENT,
+ "values tensor size must less than INT_MAX"));
+
+ Tensor* output_t;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, values_t.shape(), &output_t));
+
+ if (output_t->dtype() == DT_INT32) {
+ OP_REQUIRES(ctx,
+ FastBoundsCheck(sorted_inputs_t.dim_size(1),
+ std::numeric_limits<int>::max()),
+ errors::InvalidArgument("trailing dim_size must less than "
+ "INT_MAX for int32 output type, was ",
+ sorted_inputs_t.dim_size(1)));
+ }
+
+ auto output = output_t->template flat<OutType>();
+ const auto sorted_inputs = sorted_inputs_t.template flat<T>();
+ const auto values = values_t.template flat<T>();
+ OP_REQUIRES_OK(
+ ctx, functor::LowerBoundFunctor<Device, T, OutType>::Compute(
+ ctx, sorted_inputs, values, sorted_inputs_t.dim_size(0),
+ sorted_inputs_t.dim_size(1), values_t.dim_size(1), &output));
+ }
+};
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ UpperBoundOp<CPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ UpperBoundOp<CPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ UpperBoundOp<GPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("UpperBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ UpperBoundOp<GPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#endif // GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ LowerBoundOp<CPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ LowerBoundOp<CPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#if GOOGLE_CUDA
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("out_type"), \
+ LowerBoundOp<GPUDevice, type, int32>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#define REGISTER_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("LowerBound") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int64>("out_type"), \
+ LowerBoundOp<GPUDevice, type, int64>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+#undef REGISTER_KERNELS
+
+#endif // GOOGLE_CUDA
+} // namespace tensorflow
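// A minimal standalone sketch of what the CPU functors above compute: for
// each row of a flattened [batch, num_inputs] sorted array, std::lower_bound
// (or std::upper_bound) gives the insertion index of every query value.
// Plain standard-library code; no TensorFlow types are involved.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// sorted_inputs: batch_size * num_inputs values, each row sorted ascending.
// values:        batch_size * num_values queries against the matching row.
std::vector<int32_t> BatchedLowerBound(const std::vector<float>& sorted_inputs,
                                       const std::vector<float>& values,
                                       int batch_size, int num_inputs,
                                       int num_values) {
  std::vector<int32_t> out(values.size());
  for (int b = 0; b < batch_size; ++b) {
    const float* row = sorted_inputs.data() + b * num_inputs;
    for (int i = 0; i < num_values; ++i) {
      const float v = values[b * num_values + i];
      out[b * num_values + i] = static_cast<int32_t>(
          std::lower_bound(row, row + num_inputs, v) - row);
    }
  }
  return out;
}

int main() {
  // One batch row {1, 3, 3, 7}; queries {0, 3, 8} -> indices {0, 1, 4}.
  std::vector<float> sorted = {1, 3, 3, 7};
  std::vector<float> queries = {0, 3, 8};
  for (int32_t idx : BatchedLowerBound(sorted, queries, 1, 4, 3)) {
    std::cout << idx << " ";
  }
  std::cout << "\n";
}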
diff --git a/tensorflow/core/kernels/searchsorted_op.h b/tensorflow/core/kernels/searchsorted_op.h
new file mode 100644
index 0000000000..f075bf0fa2
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op.h
@@ -0,0 +1,52 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+namespace functor {
+
+template <typename Device, typename T, typename OutType>
+struct UpperBoundFunctor {
+ // Searches for values in sorted_inputs and returns the greatest possible
+ // index where they maintain sorted order.
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output);
+};
+
+template <typename Device, typename T, typename OutType>
+struct LowerBoundFunctor {
+ // Searches for values in sorted_inputs and returns the lowest possible
+ // index where they maintain sorted order.
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output);
+};
+} // namespace functor
+
+} // end namespace tensorflow
+#endif // TENSORFLOW_CORE_KERNELS_SEARCHSORTED_OP_H_
diff --git a/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
new file mode 100644
index 0000000000..263b5bf298
--- /dev/null
+++ b/tensorflow/core/kernels/searchsorted_op_gpu.cu.cc
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#if GOOGLE_CUDA
+
+#define EIGEN_USE_GPU
+
+#include "tensorflow/core/kernels/searchsorted_op.h"
+
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/cuda_kernel_helper.h"
+
+namespace tensorflow {
+typedef Eigen::GpuDevice GPUDevice;
+
+namespace {
+template <typename T, typename OutType>
+__global__ void UpperBoundKernel(const T* sorted_inputs, int batch_size,
+ int sorted_inputs_size, int values_size,
+ const T* values, OutType* outputs) {
+ CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
+ int bid = work_unit_id / values_size;
+ T value = values[work_unit_id];
+ outputs[work_unit_id] = cuda_helper::upper_bound<T, OutType>(
+ sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value);
+ }
+}
+
+template <typename T, typename OutType>
+__global__ void LowerBoundKernel(const T* sorted_inputs, int batch_size,
+ int sorted_inputs_size, int values_size,
+ const T* values, OutType* outputs) {
+ CUDA_1D_KERNEL_LOOP(work_unit_id, values_size * batch_size) {
+ int bid = work_unit_id / values_size;
+ T value = values[work_unit_id];
+ outputs[work_unit_id] = cuda_helper::lower_bound<T, OutType>(
+ sorted_inputs + bid * sorted_inputs_size, sorted_inputs_size, value);
+ }
+}
+} // namespace
+
+namespace functor {
+template <typename T, typename OutType>
+struct UpperBoundFunctor<GPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ const cudaStream_t& stream = GetCudaStream(context);
+ CudaLaunchConfig config =
+ GetCudaLaunchConfig(values.size(), context->eigen_gpu_device());
+
+ UpperBoundKernel<T>
+ <<<config.block_count, config.thread_per_block, 0, stream>>>(
+ sorted_inputs.data(), batch_size, num_inputs, num_values,
+ values.data(), output->data());
+
+ return Status::OK();
+ }
+};
+
+template <typename T, typename OutType>
+struct LowerBoundFunctor<GPUDevice, T, OutType> {
+ static Status Compute(OpKernelContext* context,
+ const typename TTypes<T, 1>::ConstTensor& sorted_inputs,
+ const typename TTypes<T, 1>::ConstTensor& values,
+ int batch_size, int num_inputs, int num_values,
+ typename TTypes<OutType, 1>::Tensor* output) {
+ const cudaStream_t& stream = GetCudaStream(context);
+ CudaLaunchConfig config =
+ GetCudaLaunchConfig(values.size(), context->eigen_gpu_device());
+
+ LowerBoundKernel<T>
+ <<<config.block_count, config.thread_per_block, 0, stream>>>(
+ sorted_inputs.data(), batch_size, num_inputs, num_values,
+ values.data(), output->data());
+
+ return Status::OK();
+ }
+};
+} // namespace functor
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::UpperBoundFunctor<GPUDevice, type, int32>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::UpperBoundFunctor<GPUDevice, type, int64>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::LowerBoundFunctor<GPUDevice, type, int32>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+
+#define REGISTER_GPU_SPEC(type) \
+ template struct functor::LowerBoundFunctor<GPUDevice, type, int64>;
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_GPU_SPEC);
+#undef REGISTER_GPU_SPEC
+} // namespace tensorflow
+
+#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h b/tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h
index 271dd2c485..b5274f8788 100644
--- a/tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h
+++ b/tensorflow/core/kernels/self_adjoint_eig_v2_op_impl.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
+
// See docs in ../ops/linalg_ops.cc.
#include "third_party/eigen3/Eigen/Core"
@@ -85,3 +88,5 @@ class SelfAdjointEigV2Op : public LinearAlgebraOp<Scalar> {
};
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_SELF_ADJOINT_EIG_V2_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/sendrecv_ops.h b/tensorflow/core/kernels/sendrecv_ops.h
index 1ff8eff13f..223854de13 100644
--- a/tensorflow/core/kernels/sendrecv_ops.h
+++ b/tensorflow/core/kernels/sendrecv_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SENDRECV_OPS_H_
-#define TENSORFLOW_KERNELS_SENDRECV_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SENDRECV_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_SENDRECV_OPS_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/macros.h"
@@ -49,4 +49,4 @@ class RecvOp : public AsyncOpKernel {
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SENDRECV_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_SENDRECV_OPS_H_
diff --git a/tensorflow/core/kernels/set_kernels.cc b/tensorflow/core/kernels/set_kernels.cc
index f893d4e945..0428909145 100644
--- a/tensorflow/core/kernels/set_kernels.cc
+++ b/tensorflow/core/kernels/set_kernels.cc
@@ -269,7 +269,7 @@ void SetSizeOp<T>::Compute(OpKernelContext* ctx) {
// Group by all but last dimension, create a set of group values, and add set
// size to output.
- VarDimArray group_ix(set_st.order(), 0, set_st.order().size() - 1);
+ VarDimArray group_ix = set_st.order().subspan(0, set_st.order().size() - 1);
std::set<T> group_set;
for (const auto& group : set_st.group(group_ix)) {
PopulateFromSparseGroup<T>(ctx, group, set_st.shape(), &group_set);
@@ -500,8 +500,8 @@ void SetOperationOp<T>::ComputeDenseToSparse(OpKernelContext* ctx) const {
std::set<T> set1_group_set;
std::set<T> set2_group_set;
- auto set2_grouper = set2_st.group(
- VarDimArray(set2_st.order(), 0, set2_st.order().size() - 1));
+ auto set2_grouper =
+ set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
auto set2_group_it = set2_grouper.begin();
std::vector<int64> group_indices;
int64 num_elements;
@@ -621,11 +621,11 @@ void SetOperationOp<T>::ComputeSparseToSparse(OpKernelContext* ctx) const {
std::set<T> set1_group_set;
std::set<T> set2_group_set;
- auto set1_grouper = set1_st.group(
- VarDimArray(set1_st.order(), 0, set1_st.order().size() - 1));
+ auto set1_grouper =
+ set1_st.group(set1_st.order().subspan(0, set1_st.order().size() - 1));
auto set1_group_it = set1_grouper.begin();
- auto set2_grouper = set2_st.group(
- VarDimArray(set2_st.order(), 0, set2_st.order().size() - 1));
+ auto set2_grouper =
+ set2_st.group(set2_st.order().subspan(0, set2_st.order().size() - 1));
auto set2_group_it = set2_grouper.begin();
// Group by rows, and iterate over rows of both sets in parallel, creating a
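// A minimal standalone sketch of the subspan change above: grouping "by all
// but the last dimension" is just a view over the first size()-1 entries of
// the ordered dimension list. C++20 std::span stands in for gtl::ArraySlice.
#include <cstdint>
#include <iostream>
#include <span>
#include <vector>

int main() {
  std::vector<int64_t> order = {0, 1, 2, 3};  // Dimension order of the set.
  std::span<const int64_t> group_ix =
      std::span<const int64_t>(order).subspan(0, order.size() - 1);
  for (int64_t d : group_ix) std::cout << d << " ";  // Prints: 0 1 2
  std::cout << "\n";
}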
diff --git a/tensorflow/core/kernels/shape_op_test.cc b/tensorflow/core/kernels/shape_op_test.cc
index 9cd590ae61..30cb1e0a7f 100644
--- a/tensorflow/core/kernels/shape_op_test.cc
+++ b/tensorflow/core/kernels/shape_op_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/abi.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -60,8 +61,7 @@ Status GetShapeFromKnownVecSize(const KnownVecSize& ks, TensorShape* s) {
REGISTER_UNARY_VARIANT_DECODE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE");
-REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, "KNOWN VECTOR SIZE TYPE",
- GetShapeFromKnownVecSize);
+REGISTER_UNARY_VARIANT_SHAPE_FUNCTION(KnownVecSize, GetShapeFromKnownVecSize);
static void ExpectHasError(const Status& s, StringPiece substr) {
EXPECT_TRUE(str_util::StrContains(s.ToString(), substr))
@@ -94,9 +94,9 @@ TEST_F(ShapeOpTest, Simple) {
Status s = session.Run({{input, variant_tensor}}, {shape_output}, &outputs);
EXPECT_FALSE(s.ok());
ExpectHasError(
- s,
- "No unary variant shape function found for Variant type_name: "
- "NO KNOWN SHAPE");
+ s, strings::StrCat(
+ "No unary variant shape function found for Variant type_index: ",
+ port::MaybeAbiDemangle(MakeTypeIndex<NoKnownShape>().name())));
}
{
diff --git a/tensorflow/core/kernels/shape_ops.cc b/tensorflow/core/kernels/shape_ops.cc
index 28a39bae3f..ab1ce0f9c8 100644
--- a/tensorflow/core/kernels/shape_ops.cc
+++ b/tensorflow/core/kernels/shape_ops.cc
@@ -16,6 +16,7 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
#include "tensorflow/core/kernels/shape_ops.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/register_types.h"
namespace tensorflow {
@@ -460,4 +461,96 @@ REGISTER_KERNEL_BUILDER(Name("Squeeze")
SqueezeOp);
#endif // TENSORFLOW_USE_SYCL
+class EnsureShapeOp : public OpKernel {
+ public:
+ explicit EnsureShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &expected_shape_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx,
+ shape_op_helpers::GetRegularOrVariantShape(ctx, 0, &shape));
+
+ if (!expected_shape_.IsCompatibleWith(shape)) {
+ ctx->SetStatus(errors::InvalidArgument(
+ "Shape of tensor ", this->def().input(0), " ", shape.DebugString(),
+ " is not compatible with expected shape ",
+ expected_shape_.DebugString(), "."));
+ }
+
+ // If shape matches, outputs the tensor.
+ if (IsRefType(ctx->input_dtype(0))) {
+ ctx->forward_ref_input_to_ref_output(0, 0);
+ } else {
+ ctx->set_output(0, ctx->input(0));
+ }
+ }
+
+ bool IsExpensive() override { return false; }
+
+ private:
+ PartialTensorShape expected_shape_;
+};
+
+// NOTE(rachelim): The kernel registrations for EnsureShapeOp are identical to
+// those of the identity op, since the ops have the same device type
+// constraints.
+REGISTER_KERNEL_BUILDER(Name("EnsureShape").Device(DEVICE_CPU), EnsureShapeOp);
+
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SYCL_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EnsureShape").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ EnsureShapeOp)
+
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL);
+
+#undef REGISTER_SYCL_KERNEL
+
+#define REGISTER_SYCL_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("EnsureShape") \
+ .Device(DEVICE_SYCL) \
+ .HostMemory("input") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ EnsureShapeOp)
+
+REGISTER_SYCL_HOST_KERNEL(int32);
+REGISTER_SYCL_HOST_KERNEL(bool);
+
+#undef REGISTER_SYCL_HOST_KERNEL
+
+#endif // TENSORFLOW_USE_SYCL
+
+#define REGISTER_GPU_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("EnsureShape").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
+ EnsureShapeOp)
+
+TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_GPU_KERNEL);
+REGISTER_GPU_KERNEL(Variant);
+
+#undef REGISTER_GPU_KERNEL
+
+#if GOOGLE_CUDA
+// A special GPU kernel for int32 and bool.
+// TODO(b/25387198): Also enable int32 in device memory. This kernel
+// registration requires all int32 inputs and outputs to be in host memory.
+#define REGISTER_GPU_HOST_KERNEL(type) \
+ REGISTER_KERNEL_BUILDER(Name("EnsureShape") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("input") \
+ .HostMemory("output") \
+ .TypeConstraint<type>("T"), \
+ EnsureShapeOp)
+
+REGISTER_GPU_HOST_KERNEL(int32);
+REGISTER_GPU_HOST_KERNEL(bool);
+REGISTER_GPU_HOST_KERNEL(string);
+REGISTER_GPU_HOST_KERNEL(ResourceHandle);
+
+#undef REGISTER_GPU_HOST_KERNEL
+
+#endif  // GOOGLE_CUDA
} // namespace tensorflow
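// A minimal standalone sketch of the check EnsureShapeOp performs above: a
// partially specified expected shape (unknown dimensions written as -1) is
// compatible with a concrete runtime shape when the ranks match and every
// known dimension agrees. IsCompatibleWith here is an illustrative stand-in
// for PartialTensorShape::IsCompatibleWith, not the real implementation.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

using Shape = std::vector<int64_t>;  // Concrete shape, all dims known.

bool IsCompatibleWith(const Shape& expected /* -1 = unknown */,
                      const Shape& actual) {
  if (expected.size() != actual.size()) return false;
  for (size_t i = 0; i < expected.size(); ++i) {
    if (expected[i] != -1 && expected[i] != actual[i]) return false;
  }
  return true;
}

int main() {
  std::cout << IsCompatibleWith({-1, 28, 28}, {32, 28, 28}) << "\n";  // 1
  std::cout << IsCompatibleWith({-1, 28, 28}, {32, 28, 27}) << "\n";  // 0
}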
diff --git a/tensorflow/core/kernels/shape_ops.h b/tensorflow/core/kernels/shape_ops.h
index 55be308901..7a50f158af 100644
--- a/tensorflow/core/kernels/shape_ops.h
+++ b/tensorflow/core/kernels/shape_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SHAPE_OPS_H_
-#define TENSORFLOW_KERNELS_SHAPE_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
#include <limits>
#include <unordered_set>
@@ -154,6 +154,9 @@ class ExpandDimsOp : public OpKernel {
OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
errors::InvalidArgument("ExpandDims on Variant not supported"));
+ OP_REQUIRES(
+ ctx, (ctx->input(1).NumElements() == 1),
+ errors::InvalidArgument("'dim' must be a tensor with a single value"));
Tdim dim = ctx->input(1).flat<Tdim>()(0);
OP_REQUIRES(
ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()),
@@ -236,9 +239,8 @@ class SqueezeOp : public OpKernel {
if (wrapped_squeeze_dims.count(i) > 0) {
OP_REQUIRES(ctx, existing_dim == 1,
errors::InvalidArgument(
- "Tried to explicitly squeeze "
- "dimension ",
- i, " but dimension was not 1: ", existing_dim));
+ "Can not squeeze dim[", i,
+ "], expected a dimension of 1, got ", existing_dim));
} else {
// This dimension is not being squeezed.
new_shape.push_back(existing_dim);
@@ -272,4 +274,4 @@ class SqueezeOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SHAPE_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 77594479cb..a006c69297 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -228,191 +228,6 @@ class SliceOp : public OpKernel {
}
};
-#ifdef INTEL_MKL
-template <typename Device, typename T>
-class MklSliceOp : public OpKernel {
- public:
- explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- TensorShape output_shape;
- gtl::InlinedVector<int64, 4> begin;
- gtl::InlinedVector<int64, 4> size;
- Tensor* result = nullptr;
- bool done = false;
- SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result,
- &done);
- if (!context->status().ok() || done == true) return;
-
- const Tensor& input = context->input(0);
- const int input_dims = input.dims();
-
- if (output_shape.num_elements() > 0) {
- if (std::is_same<Device, CPUDevice>::value && input_dims == 2 &&
- DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
- auto input = context->input(0).tensor<T, 2>();
- auto output = result->tensor<T, 2>();
- // TODO(agarwal): Consider multi-threading this loop for cases where
- // size[0] is very large.
- for (int i = 0; i < size[0]; ++i) {
- const int64 row = begin[0] + i;
- if (i + 1 < size[0]) {
- port::prefetch<port::PREFETCH_HINT_T0>(&output(i + 1, 0));
- port::prefetch<port::PREFETCH_HINT_T0>(&input(row + 1, begin[1]));
- }
- memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T));
- }
- return;
- }
-#define HANDLE_DIM(NDIM) \
- if (input_dims == NDIM) { \
- HandleCase<NDIM>(context, begin, size, result); \
- return; \
- }
-
- HANDLE_DIM(1);
- HANDLE_DIM(2);
- HANDLE_DIM(3);
- HANDLE_DIM(4);
- HANDLE_DIM(5);
- HANDLE_DIM(6);
- HANDLE_DIM(7);
-
-#undef HANDLE_DIM
-
- OP_REQUIRES(
- context, false,
- errors::Unimplemented("SliceOp : Unhandled input dimensions"));
- }
- }
-
- private:
- // Helper function for DoesSliceShapeDifferInOnly1D. Checks if the following
- // criteria matches for slice_dim: if indices for slice are 0 in all dims
- // except slice_dim and if sizes of all the dimensions of the slice are same
- // as the sizes of all the dimensions of the input except slice_dim, then
- // returns True. Otherwise, returns False.
- bool DoesSliceShapeDifferInOnly1DHelper(const TensorShape& input_shape,
- const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size,
- int slice_dim) {
- for (int dim = 0; dim < 4; dim++) {
- if (dim != slice_dim &&
- (begin[dim] != 0 || size[dim] != input_shape.dim_size(dim))) {
- return false;
- }
- }
- return true;
- }
-
- // Is 'input' tensor being sliced over a single dimension out of 4?
- //
- // This check is applicable in the context of Slice of a 4-D tensor in
- // NHWC or NCHW format over channel dimension.
- //
- // If indices for slice are 0 in all dims except one dimension and if sizes of
- // all dimensions of slice are same as sizes of all dimensions of inputs
- // except that dimension, then we are slicing over a single dimension.
- //
- // Returns True if Slicing over a single dimension, and sets slice_dim
- // to the number of the dimension that satisfies criteria.
- bool DoesSliceShapeDifferInOnly1D(const TensorShape& input_shape,
- const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size,
- int* slice_dim) {
- for (int dim = 0; dim < 4; dim++) {
- if (DoesSliceShapeDifferInOnly1DHelper(input_shape, begin, size, dim)) {
- *slice_dim = dim;
- return true;
- }
- }
- return false;
- }
-
- template <int NDIM>
- void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size, Tensor* result) {
- int slice_dim = -1;
- TensorShape in_shape = context->input(0).shape();
- // Special case for handling 4-D tensor slice when shape of the slice
- // differs from the input tensor in only 1 out of 4 dimensions.
- // This case arises in the context of Slice of 4-D tensor in NHWC or NCHW
- // format over channel dimension.
- if (NDIM == 4 &&
- DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
- size_t in_strides[4] = {
- (size_t)in_shape.dim_size(1) * in_shape.dim_size(2) *
- in_shape.dim_size(3),
- (size_t)in_shape.dim_size(2) * in_shape.dim_size(3),
- (size_t)in_shape.dim_size(3), (size_t)1};
-
- size_t out_strides[4] = {(size_t)size[1] * size[2] * size[3],
- (size_t)size[2] * size[3], (size_t)size[3],
- (size_t)1};
-
- T* in_buf = const_cast<T*>(
- const_cast<const T*>(context->input(0).flat<T>().data()));
- T* op_buf = result->flat<T>().data();
-
- if (slice_dim == 1) {
- /* data format = NCHW */
-
-#pragma omp parallel for
- for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
- T* ip = in_buf + (d0 * in_strides[0]);
- T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
-#pragma omp parallel for
- for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
- T* ip1 = ip + (d1 * in_strides[1]);
- T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
- // For NCHW, H and W will be contiguous. So we can copy
- // both with one memcpy.
- memcpy(static_cast<void*>(op1), static_cast<void*>(ip1),
- sizeof(T) * in_strides[1]);
- }
- }
- return;
- } else if (slice_dim == 3) {
- /* data_format = NHWC */
-
-#pragma omp parallel for
- for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
- T* ip = in_buf + (d0 * in_strides[0]);
- T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
-#pragma omp parallel for
- for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
- T* ip1 = ip + (d1 * in_strides[1]);
- T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
-#pragma omp parallel for
- for (ssize_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
- T* ip2 = ip1 + (d2 * in_strides[2]);
- T* ip3 = ip2 + begin[3];
- T* op2 = op1 + ((d2 - begin[2]) * out_strides[2]);
- T* op3 = op2;
- memcpy(static_cast<void*>(op3), static_cast<void*>(ip3),
- sizeof(T) * size[3]);
- }
- }
- }
- return;
- }
- // slice_dim is not 1 or 3, then we fallback to Eigen implementation.
- }
-
- Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
- Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
- for (int i = 0; i < NDIM; ++i) {
- indices[i] = begin[i];
- sizes[i] = size[i];
- }
-
- functor::Slice<Device, T, NDIM>()(
- context->eigen_device<Device>(), result->tensor<T, NDIM>(),
- context->input(0).tensor<T, NDIM>(), indices, sizes);
- }
-};
-#endif
-
// Forward declarations of the functor specializations for declared in the
// sharded source files.
namespace functor {
@@ -440,7 +255,6 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N);
#undef DECLARE_CPU_SPEC
} // namespace functor
-#ifndef INTEL_MKL
#define REGISTER_SLICE(type) \
REGISTER_KERNEL_BUILDER(Name("Slice") \
.Device(DEVICE_CPU) \
@@ -452,19 +266,6 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N);
TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
#undef REGISTER_SLICE
-#else
-#define REGISTER_SLICE(type) \
- REGISTER_KERNEL_BUILDER(Name("Slice") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .HostMemory("begin") \
- .HostMemory("size"), \
- MklSliceOp<CPUDevice, type>)
-
-TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
-TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
-#undef REGISTER_SLICE
-#endif // INTEL_MKL
#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
diff --git a/tensorflow/core/kernels/slice_op.h b/tensorflow/core/kernels/slice_op.h
index db7eded745..1d662f6362 100644
--- a/tensorflow/core/kernels/slice_op.h
+++ b/tensorflow/core/kernels/slice_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SLICE_OP_H_
-#define TENSORFLOW_KERNELS_SLICE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SLICE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SLICE_OP_H_
// Functor definition for SliceOp, must be compilable by nvcc.
@@ -51,4 +51,4 @@ struct Slice {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SLICE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SLICE_OP_H_
diff --git a/tensorflow/core/kernels/smooth-hinge-loss.h b/tensorflow/core/kernels/smooth-hinge-loss.h
index 5074ad0795..d51f5c130e 100644
--- a/tensorflow/core/kernels/smooth-hinge-loss.h
+++ b/tensorflow/core/kernels/smooth-hinge-loss.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SMOOTH_HINGE_LOSS_H_
-#define TENSORFLOW_KERNELS_SMOOTH_HINGE_LOSS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_
#include <limits>
@@ -110,5 +110,5 @@ class SmoothHingeLossUpdater : public DualLossUpdater {
} // namespace tensorflow
-#endif
+#endif // TENSORFLOW_CORE_KERNELS_SMOOTH_HINGE_LOSS_H_
// TENSORFLOW_KERNELS_SMOOTH_HINGE_LOSS_H_
diff --git a/tensorflow/core/kernels/snapshot_op.h b/tensorflow/core/kernels/snapshot_op.h
index a18065d42b..02d492988e 100644
--- a/tensorflow/core/kernels/snapshot_op.h
+++ b/tensorflow/core/kernels/snapshot_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SNAPSHOT_OP_H_
-#define TENSORFLOW_KERNELS_SNAPSHOT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SNAPSHOT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SNAPSHOT_OP_H_
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
@@ -41,4 +41,4 @@ struct Snapshot {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SNAPSHOT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SNAPSHOT_OP_H_
diff --git a/tensorflow/core/kernels/softmax_op_functor.h b/tensorflow/core/kernels/softmax_op_functor.h
index d3a267ed87..c8bc1ad3bb 100644
--- a/tensorflow/core/kernels/softmax_op_functor.h
+++ b/tensorflow/core/kernels/softmax_op_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SOFTMAX_OP_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_SOFTMAX_OP_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SOFTMAX_OP_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_SOFTMAX_OP_FUNCTOR_H_
// Functor definition for SoftmaxOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -98,4 +98,4 @@ struct SoftmaxEigenImpl {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SOFTMAX_OP_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_SOFTMAX_OP_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/softplus_op.cc b/tensorflow/core/kernels/softplus_op.cc
index 494a83ed14..d3fc0e1461 100644
--- a/tensorflow/core/kernels/softplus_op.cc
+++ b/tensorflow/core/kernels/softplus_op.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/kernels/warn_about_ints.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -35,9 +34,7 @@ template <typename Device, typename T>
class SoftplusOp : public UnaryElementWiseOp<T, SoftplusOp<Device, T>> {
public:
explicit SoftplusOp(OpKernelConstruction* context)
- : UnaryElementWiseOp<T, SoftplusOp<Device, T>>(context) {
- WarnAboutInts(context);
- }
+ : UnaryElementWiseOp<T, SoftplusOp<Device, T>>(context) {}
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
functor::Softplus<Device, T> functor;
@@ -51,9 +48,7 @@ class SoftplusGradOp
: public BinaryElementWiseOp<T, SoftplusGradOp<Device, T>> {
public:
explicit SoftplusGradOp(OpKernelConstruction* context)
- : BinaryElementWiseOp<T, SoftplusGradOp<Device, T>>(context) {
- WarnAboutInts(context);
- }
+ : BinaryElementWiseOp<T, SoftplusGradOp<Device, T>>(context) {}
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);
@@ -89,7 +84,7 @@ void SoftplusGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
Name("SoftplusGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SoftplusGradOp<CPUDevice, type>);
-TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/softplus_op.h b/tensorflow/core/kernels/softplus_op.h
index e17e175d41..8c083ba158 100644
--- a/tensorflow/core/kernels/softplus_op.h
+++ b/tensorflow/core/kernels/softplus_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SOFTPLUS_OP_H_
-#define TENSORFLOW_KERNELS_SOFTPLUS_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SOFTPLUS_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SOFTPLUS_OP_H_
// Functor definition for SoftplusOp and SoftplusGradOp, must be compilable by
// nvcc.
@@ -73,4 +73,4 @@ struct SoftplusGrad {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SOFTPLUS_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SOFTPLUS_OP_H_
diff --git a/tensorflow/core/kernels/softsign_op.cc b/tensorflow/core/kernels/softsign_op.cc
index 00ee649b17..d691f15651 100644
--- a/tensorflow/core/kernels/softsign_op.cc
+++ b/tensorflow/core/kernels/softsign_op.cc
@@ -23,7 +23,6 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/kernels/warn_about_ints.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -35,9 +34,7 @@ template <typename Device, typename T>
class SoftsignOp : public UnaryElementWiseOp<T, SoftsignOp<Device, T>> {
public:
explicit SoftsignOp(OpKernelConstruction* context)
- : UnaryElementWiseOp<T, SoftsignOp<Device, T>>(context) {
- WarnAboutInts(context);
- }
+ : UnaryElementWiseOp<T, SoftsignOp<Device, T>>(context) {}
void Operate(OpKernelContext* context, const Tensor& input, Tensor* output) {
functor::Softsign<Device, T> functor;
@@ -51,9 +48,7 @@ class SoftsignGradOp
: public BinaryElementWiseOp<T, SoftsignGradOp<Device, T>> {
public:
explicit SoftsignGradOp(OpKernelConstruction* context)
- : BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>(context) {
- WarnAboutInts(context);
- }
+ : BinaryElementWiseOp<T, SoftsignGradOp<Device, T>>(context) {}
void OperateNoTemplate(OpKernelContext* context, const Tensor& g,
const Tensor& a, Tensor* output);
@@ -90,7 +85,7 @@ void SoftsignGradOp<Device, T>::OperateNoTemplate(OpKernelContext* context,
Name("SoftsignGrad").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
SoftsignGradOp<CPUDevice, type>);
-TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNELS);
+TF_CALL_FLOAT_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/softsign_op.h b/tensorflow/core/kernels/softsign_op.h
index c2ababf697..61ff6eeede 100644
--- a/tensorflow/core/kernels/softsign_op.h
+++ b/tensorflow/core/kernels/softsign_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SOFTSIGN_OP_H_
-#define TENSORFLOW_KERNELS_SOFTSIGN_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SOFTSIGN_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SOFTSIGN_OP_H_
// Functor definition for SoftsignOp and SoftsignGradOp, must be compilable by
// nvcc.
@@ -57,4 +57,4 @@ struct SoftsignGrad {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SOFTSIGN_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SOFTSIGN_OP_H_
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator.h b/tensorflow/core/kernels/sparse_conditional_accumulator.h
index 2c1bffbee4..a4453bd7ab 100644
--- a/tensorflow/core/kernels/sparse_conditional_accumulator.h
+++ b/tensorflow/core/kernels/sparse_conditional_accumulator.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
-#define TENSORFLOW_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
#include "tensorflow/core/kernels/typed_conditional_accumulator_base.h"
@@ -50,10 +50,10 @@ class SparseConditionalAccumulator
public:
SparseConditionalAccumulator(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name)
+ const string& name, const string& reduction_type)
: TypedConditionalAccumulatorBase<
std::tuple<const Tensor*, const Tensor*, const Tensor*>>(
- dtype, shape, name) {
+ dtype, shape, name, reduction_type) {
accum_idx_vec_ = nullptr;
count_element_ = nullptr;
accum_val_ = nullptr;
@@ -459,4 +459,4 @@ class SparseConditionalAccumulator
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_SPARSE_CONDITIONAL_ACCUMULATOR_H_
diff --git a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
index 80bc1f1934..1e542a26a7 100644
--- a/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
+++ b/tensorflow/core/kernels/sparse_conditional_accumulator_op.cc
@@ -34,8 +34,8 @@ class SparseConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
Creator GetCreator() const override {
return [this](ConditionalAccumulatorBase** ret) {
SparseConditionalAccumulator<Device, T>* accumulator =
- new SparseConditionalAccumulator<Device, T>(dtype_, shape_,
- cinfo_.name());
+ new SparseConditionalAccumulator<Device, T>(
+ dtype_, shape_, cinfo_.name(), reduction_type_);
*ret = accumulator;
return Status::OK();
};
diff --git a/tensorflow/core/kernels/sparse_matmul_op.h b/tensorflow/core/kernels/sparse_matmul_op.h
index e89280724e..6b9db8f471 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_matmul_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SPARSE_MATMUL_OP_H_
-#define TENSORFLOW_KERNELS_SPARSE_MATMUL_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/platform/byte_order.h"
@@ -465,4 +465,4 @@ EIGEN_DEVICE_FUNC inline Packet16f pexpand_bf16_u(const Packet16f& from) {
#endif
} // namespace internal
} // namespace Eigen
-#endif
+#endif // TENSORFLOW_CORE_KERNELS_SPARSE_MATMUL_OP_H_
diff --git a/tensorflow/core/kernels/sparse_softmax_op.cc b/tensorflow/core/kernels/sparse_softmax_op.cc
index dc3119bba4..37664fe8df 100644
--- a/tensorflow/core/kernels/sparse_softmax_op.cc
+++ b/tensorflow/core/kernels/sparse_softmax_op.cc
@@ -90,7 +90,7 @@ class SparseSoftmaxOp : public OpKernel {
// { 0, ..., rank-1 }.
const ArraySlice<int64> kReorderDims(dims);
// All but the last dim -- the class dimension to be max-reduced along.
- const ArraySlice<int64> kGroupByDims(kReorderDims, 0, rank - 1);
+ const ArraySlice<int64> kGroupByDims = kReorderDims.subspan(0, rank - 1);
st.Reorder<T>(kReorderDims);
int count = 0;
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_add_op.h b/tensorflow/core/kernels/sparse_tensor_dense_add_op.h
index 353cf0e519..c26ed5e874 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_add_op.h
+++ b/tensorflow/core/kernels/sparse_tensor_dense_add_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_
-#define TENSORFLOW_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -39,4 +39,4 @@ struct ScatterNdFunctor {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_ADD_OP_H_
diff --git a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
index da13190494..d6dd2deca5 100644
--- a/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_
-#define TENSORFLOW_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -71,4 +71,4 @@ class MaybeAdjoint<MATRIX, true> {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_
diff --git a/tensorflow/core/kernels/sparse_xent_op.h b/tensorflow/core/kernels/sparse_xent_op.h
index b5587aa9d7..6ba7931ab5 100644
--- a/tensorflow/core/kernels/sparse_xent_op.h
+++ b/tensorflow/core/kernels/sparse_xent_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_XENT_OP_H_
-#define TENSORFLOW_KERNELS_XENT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_
// Functor definition for SparseXentOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -224,4 +224,4 @@ struct SparseXentEigenImpl {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_XENT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_SPARSE_XENT_OP_H_
diff --git a/tensorflow/core/kernels/split_lib.h b/tensorflow/core/kernels/split_lib.h
index bc1fa28f8f..9d43a00822 100644
--- a/tensorflow/core/kernels/split_lib.h
+++ b/tensorflow/core/kernels/split_lib.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SPLIT_LIB_H_
-#define TENSORFLOW_KERNELS_SPLIT_LIB_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SPLIT_LIB_H_
+#define TENSORFLOW_CORE_KERNELS_SPLIT_LIB_H_
// Functor definition for SplitOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -62,4 +62,4 @@ struct Split<Eigen::SyclDevice, T> {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SPLIT_LIB_H_
+#endif // TENSORFLOW_CORE_KERNELS_SPLIT_LIB_H_
diff --git a/tensorflow/core/kernels/split_lib_gpu.cu.cc b/tensorflow/core/kernels/split_lib_gpu.cu.cc
index 393818730b..a4a59dbcbc 100644
--- a/tensorflow/core/kernels/split_lib_gpu.cu.cc
+++ b/tensorflow/core/kernels/split_lib_gpu.cu.cc
@@ -54,6 +54,7 @@ void SplitCustom<Device, T>::operator()(
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_complex64(DEFINE_GPU_KERNELS);
TF_CALL_complex128(DEFINE_GPU_KERNELS);
+TF_CALL_int64(DEFINE_GPU_KERNELS);
TF_CALL_bfloat16(DEFINE_GPU_KERNELS);
#undef DEFINE_GPU_KERNELS
diff --git a/tensorflow/core/kernels/split_op.cc b/tensorflow/core/kernels/split_op.cc
index 7cc3c532c9..11db72bfa3 100644
--- a/tensorflow/core/kernels/split_op.cc
+++ b/tensorflow/core/kernels/split_op.cc
@@ -49,7 +49,12 @@ class SplitOpBase : public OpKernel {
void ComputeEasyCases(OpKernelContext* context, bool* done) {
const Tensor& input = context->input(1);
const TensorShape& input_shape = input.shape();
- const int32 split_dim_orig = context->input(0).flat<int32>()(0);
+ const Tensor& split_dim_tensor = context->input(0);
+ OP_REQUIRES(
+ context, split_dim_tensor.shape().dims() == 0,
+ errors::InvalidArgument("split_dim must be a scalar but has rank ",
+ split_dim_tensor.shape().dims()));
+ const int32 split_dim_orig = split_dim_tensor.flat<int32>()(0);
const int32 split_dim =
split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig;
const int32 num_split = num_outputs();
diff --git a/tensorflow/core/kernels/squared-loss.h b/tensorflow/core/kernels/squared-loss.h
index 49e6db406e..d256a69350 100644
--- a/tensorflow/core/kernels/squared-loss.h
+++ b/tensorflow/core/kernels/squared-loss.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_SQUARED_LOSS_H_
-#define TENSORFLOW_KERNELS_SQUARED_LOSS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_SQUARED_LOSS_H_
+#define TENSORFLOW_CORE_KERNELS_SQUARED_LOSS_H_
#include "tensorflow/core/kernels/loss.h"
@@ -70,4 +70,4 @@ class SquaredLossUpdater : public DualLossUpdater {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SQUARED_LOSS_H_
+#endif // TENSORFLOW_CORE_KERNELS_SQUARED_LOSS_H_
diff --git a/tensorflow/core/kernels/stack_ops.cc b/tensorflow/core/kernels/stack_ops.cc
index 65296f61fd..add4afafc9 100644
--- a/tensorflow/core/kernels/stack_ops.cc
+++ b/tensorflow/core/kernels/stack_ops.cc
@@ -131,10 +131,8 @@ class Stack : public ResourceBase {
};
Status GetStack(OpKernelContext* ctx, Stack** stack) {
- string key;
if (ctx->input_dtype(0) == DT_RESOURCE) {
- auto resource = ctx->input(0).flat<ResourceHandle>()(0);
- key = resource.name();
+ return LookupResource(ctx, HandleFromInput(ctx, 0), stack);
} else {
Tensor Tstack_handle = ctx->mutable_input(0, false);
if (Tstack_handle.NumElements() != 2) {
@@ -144,18 +142,18 @@ Status GetStack(OpKernelContext* ctx, Stack** stack) {
}
const string& container = Tstack_handle.flat<string>()(0);
const string& stack_name = Tstack_handle.flat<string>()(1);
- key = strings::StrCat(container, stack_name);
- }
- ResourceMgr* rm = ctx->resource_manager();
- if (rm == nullptr) {
- return errors::Internal("No resource manager.");
- }
- auto* step_container = ctx->step_container();
- if (step_container == nullptr) {
- return errors::Internal("No step container.");
+ string key = strings::StrCat(container, stack_name);
+ ResourceMgr* rm = ctx->resource_manager();
+ if (rm == nullptr) {
+ return errors::Internal("No resource manager.");
+ }
+ auto* step_container = ctx->step_container();
+ if (step_container == nullptr) {
+ return errors::Internal("No step container.");
+ }
+ TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
+ return Status::OK();
}
- TF_RETURN_IF_ERROR(rm->Lookup(step_container->name(), key, stack));
- return Status::OK();
}
std::atomic<int64> Stack::stack_counter{0};
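
A note on the GetStack change above: resource handles now go straight through LookupResource, while legacy string handles still build a container/name key and query the step container's ResourceMgr. The sketch below is a minimal standalone illustration of that two-path dispatch, with a plain std::map standing in for ResourceMgr; the Registry and Handle types are invented for the example and are not TensorFlow APIs.

  #include <map>
  #include <string>
  #include <utility>
  #include <variant>

  struct Stack { int id; };

  // Stand-in for ResourceMgr: maps "container/name" keys to stacks.
  using Registry = std::map<std::string, Stack*>;

  // A handle is either a direct pointer (resource-handle path) or a
  // container/name pair (legacy string-handle path).
  using Handle = std::variant<Stack*, std::pair<std::string, std::string>>;

  // Returns the stack for `h`, or nullptr if the keyed lookup misses.
  Stack* GetStack(const Handle& h, const Registry& reg) {
    if (auto* direct = std::get_if<Stack*>(&h)) {
      return *direct;  // resource handle: no registry lookup needed
    }
    const auto& key = std::get<std::pair<std::string, std::string>>(h);
    auto it = reg.find(key.first + "/" + key.second);
    return it == reg.end() ? nullptr : it->second;
  }

  int main() {
    Stack s{42};
    Registry reg{{"step/my_stack", &s}};
    Stack* a = GetStack(Handle{&s}, reg);  // direct path
    Stack* b = GetStack(Handle{std::make_pair(std::string("step"),
                                              std::string("my_stack"))},
                        reg);              // keyed path
    return (a == &s && b == &s && b->id == 42) ? 0 : 1;
  }
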
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 59fdc2262a..3e8a4c5b72 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -149,7 +149,7 @@ class StridedSliceOp : public OpKernel {
// NDIM and T
if (is_simple_slice && std::is_same<Device, CPUDevice>::value &&
input_dims == 2 && processing_shape.dims() == 2 &&
- final_shape.dims() == 2) {
+ final_shape.dims() == 2 && new_axis_mask == 0) {
MemCpyFunctor<T> functor;
if (functor.Copy(input, begin, end, result)) {
return;
@@ -300,37 +300,40 @@ class StridedSliceAssignOp : public OpKernel {
gtl::InlinedVector<int64, 4> end;
gtl::InlinedVector<int64, 4> strides;
- Tensor old_lhs;
+ Tensor* old_lhs = nullptr;
+ Tensor tmp;
if (context->input_dtype(0) == DT_RESOURCE) {
Var* v;
OP_REQUIRES_OK(context,
LookupResource(context, HandleFromInput(context, 0), &v));
+ core::ScopedUnref scoped_unref(v);
mutex_lock ml(*v->mu());
OP_REQUIRES_OK(context,
PrepareToUpdateVariable<Device, T>(context, v->tensor()));
- old_lhs = *v->tensor();
- OP_REQUIRES(context, old_lhs.dtype() == DataTypeToEnum<T>::value,
+ old_lhs = v->tensor();
+ OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum<T>::value,
errors::InvalidArgument(
- "l-value dtype ", DataTypeString(old_lhs.dtype()),
+ "l-value dtype ", DataTypeString(old_lhs->dtype()),
" does not match r-value dtype ",
DataTypeString(DataTypeToEnum<T>::value)));
} else {
context->forward_ref_input_to_ref_output(0, 0);
- old_lhs = context->mutable_input(0, true);
+ tmp = context->mutable_input(0, true);
+ old_lhs = &tmp;
}
OP_REQUIRES_OK(
- context,
- ValidateStridedSliceOp(
- &context->input(1), &context->input(2), context->input(3),
- old_lhs.shape(), begin_mask, end_mask, ellipsis_mask, new_axis_mask,
- shrink_axis_mask, &processing_shape, &final_shape, &is_identity,
- &is_simple_slice, &slice_dim0, &begin, &end, &strides));
+ context, ValidateStridedSliceOp(
+ &context->input(1), &context->input(2), context->input(3),
+ old_lhs->shape(), begin_mask, end_mask, ellipsis_mask,
+ new_axis_mask, shrink_axis_mask, &processing_shape,
+ &final_shape, &is_identity, &is_simple_slice, &slice_dim0,
+ &begin, &end, &strides));
if (processing_shape.num_elements()) {
const Tensor& input = context->input(4);
TensorShape input_shape = input.shape();
- TensorShape original_shape = old_lhs.shape();
+ TensorShape original_shape = old_lhs->shape();
// TODO(aselle): This check is too strong, we only should need
// input_shape to be broadcastable to final_shape
OP_REQUIRES(
@@ -345,12 +348,12 @@ class StridedSliceAssignOp : public OpKernel {
// scalar shape
// Handle general dimensions
-#define HANDLE_DIM(NDIM) \
- if (processing_dims == NDIM) { \
- HandleStridedSliceAssignCase<Device, T, NDIM>()( \
- context, begin, end, strides, processing_shape, is_simple_slice, \
- &old_lhs); \
- return; \
+#define HANDLE_DIM(NDIM) \
+ if (processing_dims == NDIM) { \
+ HandleStridedSliceAssignCase<Device, T, NDIM>()(context, begin, end, \
+ strides, processing_shape, \
+ is_simple_slice, old_lhs); \
+ return; \
}
HANDLE_DIM(0);
HANDLE_DIM(1);
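
One detail worth calling out in the StridedSliceAssignOp change: old_lhs is now a pointer into the variable's tensor, and the looked-up Var is guarded by core::ScopedUnref, so the reference is released on every exit path, including early OP_REQUIRES failures. The sketch below shows the general RAII-unref pattern with a hand-rolled refcounted type; ScopedUnref here is a simplified stand-in written for the example, not TensorFlow's class.

  #include <cassert>

  struct RefCounted {
    int refs = 1;
    void Ref() { ++refs; }
    void Unref() { --refs; }  // a real implementation would delete at zero
  };

  // Releases exactly one reference when the guard leaves scope.
  class ScopedUnref {
   public:
    explicit ScopedUnref(RefCounted* obj) : obj_(obj) {}
    ~ScopedUnref() { if (obj_ != nullptr) obj_->Unref(); }
    ScopedUnref(const ScopedUnref&) = delete;
    ScopedUnref& operator=(const ScopedUnref&) = delete;

   private:
    RefCounted* obj_;
  };

  int main() {
    RefCounted var;
    var.Ref();                  // a lookup hands back an extra reference
    {
      ScopedUnref guard(&var);  // released even if this scope exits early
      // ... mutate the variable's payload in place here ...
    }
    assert(var.refs == 1);
    return 0;
  }
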
diff --git a/tensorflow/core/kernels/strided_slice_op.h b/tensorflow/core/kernels/strided_slice_op.h
index 2b58632298..86d105391d 100644
--- a/tensorflow/core/kernels/strided_slice_op.h
+++ b/tensorflow/core/kernels/strided_slice_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_STRIDED_SLICE_OP_H_
-#define TENSORFLOW_KERNELS_STRIDED_SLICE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_H_
// Functor definition for StridedSliceOp, must be compilable by nvcc.
@@ -137,4 +137,4 @@ struct StridedSliceAssignScalar {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_SLICE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_H_
diff --git a/tensorflow/core/kernels/strided_slice_op_impl.h b/tensorflow/core/kernels/strided_slice_op_impl.h
index 1c4472bb1a..099083b2ff 100644
--- a/tensorflow/core/kernels/strided_slice_op_impl.h
+++ b/tensorflow/core/kernels/strided_slice_op_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_STRIDED_SLICE_OP_IMPL_H_
-#define TENSORFLOW_KERNELS_STRIDED_SLICE_OP_IMPL_H_
+#ifndef TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_
// Functor definition for StridedSliceOp, must be compilable by nvcc.
@@ -313,4 +313,4 @@ DECLARE_FOR_N_SYCL(int64);
} // end namespace tensorflow
#endif // END STRIDED_SLICE_INSTANTIATE_DIM
-#endif // TENSORFLOW_KERNELS_SLICE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_STRIDED_SLICE_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/string_format_op.cc b/tensorflow/core/kernels/string_format_op.cc
new file mode 100644
index 0000000000..e4a1887f8d
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op.cc
@@ -0,0 +1,65 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iostream>
+#include "absl/strings/str_split.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+class StringFormatOp : public OpKernel {
+ public:
+ explicit StringFormatOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string template_;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("template", &template_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("placeholder", &placeholder_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("summarize", &summarize_));
+
+ split_template_ = absl::StrSplit(template_, placeholder_);
+ int64 num_placeholders = split_template_.size() - 1;
+ OP_REQUIRES(ctx, ctx->num_inputs() == num_placeholders,
+ errors::InvalidArgument(strings::StrCat(
+ "num placeholders in template and num inputs must match: ",
+ num_placeholders, " vs. ", ctx->num_inputs())));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ Tensor* formatted_string = nullptr;
+ OP_REQUIRES_OK(ctx,
+ ctx->allocate_output(0, TensorShape({}), &formatted_string));
+
+ string msg;
+ strings::StrAppend(&msg, split_template_[0].c_str());
+ for (int i = 0; i < ctx->num_inputs(); ++i) {
+ strings::StrAppend(&msg, ctx->input(i).SummarizeValue(summarize_, true));
+ strings::StrAppend(&msg, split_template_[i + 1].c_str());
+ }
+
+ formatted_string->scalar<string>()() = msg;
+ }
+
+ private:
+ int32 summarize_ = 0;
+ string placeholder_;
+ std::vector<std::string> split_template_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StringFormat").Device(DEVICE_CPU),
+ StringFormatOp);
+
+} // end namespace tensorflow
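
For readers skimming the new kernel: StringFormatOp splits the template on the placeholder once at construction, then at compute time interleaves the N summarized inputs between the resulting N+1 template pieces. The standalone sketch below reproduces just that interleave step with the standard library; SplitTemplate and Format are illustrative helpers written for the example, not the op's real dependencies.

  #include <cstddef>
  #include <iostream>
  #include <string>
  #include <vector>

  // Splits `tmpl` on every occurrence of `placeholder`; a template with N
  // placeholders yields N + 1 pieces (possibly empty).
  std::vector<std::string> SplitTemplate(const std::string& tmpl,
                                         const std::string& placeholder) {
    std::vector<std::string> pieces;
    std::string::size_type start = 0, pos;
    while ((pos = tmpl.find(placeholder, start)) != std::string::npos) {
      pieces.push_back(tmpl.substr(start, pos - start));
      start = pos + placeholder.size();
    }
    pieces.push_back(tmpl.substr(start));
    return pieces;
  }

  // Interleaves one rendered value per placeholder between the pieces.
  std::string Format(const std::vector<std::string>& pieces,
                     const std::vector<std::string>& values) {
    std::string out = pieces[0];
    for (std::size_t i = 0; i < values.size(); ++i) {
      out += values[i];
      out += pieces[i + 1];
    }
    return out;
  }

  int main() {
    auto pieces = SplitTemplate("x: %s, y: %s", "%s");
    std::cout << Format(pieces, {"[1 2 3]", "[4 5 6]"}) << "\n";
    // prints: x: [1 2 3], y: [4 5 6]
  }
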
diff --git a/tensorflow/core/kernels/string_format_op_test.cc b/tensorflow/core/kernels/string_format_op_test.cc
new file mode 100644
index 0000000000..13130a5797
--- /dev/null
+++ b/tensorflow/core/kernels/string_format_op_test.cc
@@ -0,0 +1,66 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace tensorflow {
+namespace {
+
+class StringFormatGraphTest : public OpsTestBase {
+ protected:
+ Status Init(int num_inputs, DataType input_type,
+ const string& template_ = "%s", const string& placeholder = "%s",
+ int summarize = 3) {
+ TF_CHECK_OK(NodeDefBuilder("op", "StringFormat")
+ .Input(FakeInput(num_inputs, input_type))
+ .Attr("template", template_)
+ .Attr("placeholder", placeholder)
+ .Attr("summarize", summarize)
+ .Finalize(node_def()));
+ return InitOp();
+ }
+};
+
+TEST_F(StringFormatGraphTest, Int32Success_7) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s"));
+
+ AddInputFromArray<int32>(TensorShape({7}), {1, 2, 3, 4, 5, 6, 7});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [1 2 3 ... 5 6 7]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+TEST_F(StringFormatGraphTest, Int32Success_3_3) {
+ TF_ASSERT_OK(Init(1, DT_INT32, "First tensor: %s", "%s", 1));
+
+ AddInputFromArray<int32>(TensorShape({3, 3}), {1, 2, 3, 4, 5, 6, 7, 8, 9});
+ TF_ASSERT_OK(RunOpKernel());
+ Tensor expected(allocator(), DT_STRING, TensorShape({}));
+ test::FillValues<string>(&expected, {"First tensor: [[1 ... 3]\n ..."
+ "\n [7 ... 9]]"});
+ test::ExpectTensorEqual<string>(expected, *GetOutput(0));
+}
+
+} // end namespace
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/string_length_op.cc b/tensorflow/core/kernels/string_length_op.cc
new file mode 100644
index 0000000000..435a7abdca
--- /dev/null
+++ b/tensorflow/core/kernels/string_length_op.cc
@@ -0,0 +1,62 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/string_util.h"
+
+namespace tensorflow {
+namespace {
+
+class StringLengthOp : public OpKernel {
+ public:
+ explicit StringLengthOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string unit;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
+ OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
+ }
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor& input = context->input(0);
+
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, input.shape(), &output));
+
+ auto src = input.flat<string>();
+ auto dst = output->flat<int32>();
+
+ switch (unit_) {
+ case CharUnit::BYTE:
+ for (int n = 0; n < src.size(); ++n) {
+ dst(n) = src(n).size();
+ }
+ break;
+ case CharUnit::UTF8_CHAR:
+ for (int n = 0; n < src.size(); ++n) {
+ dst(n) = UTF8StrLen(src(n));
+ }
+ break;
+ }
+ }
+
+ private:
+ CharUnit unit_ = CharUnit::BYTE;
+};
+
+REGISTER_KERNEL_BUILDER(Name("StringLength").Device(DEVICE_CPU),
+ StringLengthOp);
+
+} // namespace
+} // namespace tensorflow
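
When unit is UTF8_CHAR, the kernel defers to UTF8StrLen (added in string_util.cc further down), which counts characters instead of bytes by ignoring UTF-8 continuation bytes. A small self-contained illustration of the byte-versus-character distinction, using only the standard library:

  #include <cstdio>
  #include <string>

  // Continuation bytes of a UTF-8 sequence look like 10xxxxxx, i.e. they
  // compare less than -0x40 when read as a signed char.
  inline bool IsTrailByte(char c) {
    return static_cast<signed char>(c) < -0x40;
  }

  int Utf8Chars(const std::string& s) {
    int chars = 0;
    for (char c : s) {
      if (!IsTrailByte(c)) ++chars;  // count only lead bytes
    }
    return chars;
  }

  int main() {
    const std::string s = "caf\xc3\xa9";  // "café": 5 bytes, 4 characters
    std::printf("bytes=%zu chars=%d\n", s.size(), Utf8Chars(s));
  }
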
diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc
index 26ab72f12e..3884370a6c 100644
--- a/tensorflow/core/kernels/string_split_op.cc
+++ b/tensorflow/core/kernels/string_split_op.cc
@@ -26,25 +26,81 @@ limitations under the License.
#include "tensorflow/core/lib/strings/str_util.h"
namespace tensorflow {
-
namespace {
+// Split input string `str` based on a character delimiter.
+// Returns a vector of StringPieces which are valid as long as input `str`
+// is valid.
+// Note: The single character delimiter is a common case and is implemented as
+// a series of finds in the input string, making it much more efficient than
+// SplitOnCharSet.
+template <typename Predicate>
+std::vector<StringPiece> SplitOnChar(const string& str, const char delim,
+ Predicate p) {
+ std::vector<StringPiece> result;
+ StringPiece text(str);
+ auto f = text.find(delim);
+ while (f != StringPiece::npos) {
+ StringPiece token = text.substr(0, f);
+ if (p(token)) {
+ result.emplace_back(token);
+ }
+ text.remove_prefix(f + 1);
+ f = text.find(delim);
+ }
+ if (p(text)) {
+ result.push_back(text);
+ }
+ return result;
+}
-std::vector<string> Split(const string& str, const string& delimiter,
- const bool skipEmpty) {
- if (!delimiter.empty()) {
- if (skipEmpty) {
- return str_util::Split(str, delimiter, str_util::SkipEmpty());
+// Split input string `str` based on a set of character delimiters.
+// Returns a vector of StringPieces which are valid as long as input `str`
+// is valid.
+// Based on str_util::Split.
+template <typename Predicate>
+std::vector<StringPiece> SplitOnCharSet(const string& str,
+ const string& delim_set, Predicate p) {
+ std::vector<StringPiece> result;
+ StringPiece text(str);
+ StringPiece delims(delim_set);
+ size_t token_start = 0;
+ for (size_t i = 0; i < text.size() + 1; i++) {
+ if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) {
+ StringPiece token(text.data() + token_start, i - token_start);
+ if (p(token)) {
+ result.emplace_back(token);
+ }
+ token_start = i + 1;
}
- return str_util::Split(str, delimiter);
}
- std::vector<string> char_vector(str.size());
- for (size_t i = 0; i < str.size(); ++i) {
- char_vector[i] = str[i];
+ return result;
+}
+
+// Split input string `str` based on given delimiter.
+// Returns a vector of StringPieces which are valid as long as input `str`
+// is valid.
+template <typename Predicate>
+std::vector<StringPiece> Split(const string& str, const string& delimiter,
+ Predicate predicate) {
+ if (str.empty()) {
+ return std::vector<StringPiece>();
+ }
+ if (delimiter.empty()) {
+ std::vector<StringPiece> result;
+ result.resize(str.size());
+ for (size_t i = 0; i < str.size(); ++i) {
+ result[i] = StringPiece(str.data() + i, 1);
+ }
+ return result;
}
- return char_vector;
+ if (delimiter.size() == 1) {
+ return SplitOnChar(str, delimiter[0], predicate);
+ }
+ return SplitOnCharSet(str, delimiter, predicate);
}
-std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
+std::vector<StringPiece> SplitV2(const string& str, StringPiece sep,
+ int maxsplit) {
// This SplitV2 method matches the behavior of python's str.split:
// If sep is given, consecutive delimiters are not grouped together
// and are deemed to delimit empty strings (for example, '1,,2'.split(',')
@@ -59,11 +115,11 @@ std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
// splitting an empty string or a string consisting of just whitespace
// with a None separator returns [].
- std::vector<string> result;
+ std::vector<StringPiece> result;
StringPiece text(str);
if (maxsplit == 0) {
- result.emplace_back(std::string(text));
+ result.emplace_back(text);
return result;
}
@@ -73,11 +129,11 @@ std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
str_util::RemoveLeadingWhitespace(&text);
int split = 0;
while (str_util::ConsumeNonWhitespace(&text, &token)) {
- result.emplace_back(std::string(token));
+ result.push_back(token);
str_util::RemoveLeadingWhitespace(&text);
++split;
if (maxsplit > 0 && split == maxsplit) {
- result.emplace_back(std::string(text));
+ result.push_back(text);
return result;
}
}
@@ -87,17 +143,17 @@ std::vector<string> SplitV2(const string& str, StringPiece sep, int maxsplit) {
int split = 0;
while (p != text.end()) {
StringPiece token = text.substr(0, p - text.begin());
- result.emplace_back(std::string(token));
+ result.push_back(token);
text.remove_prefix(token.size());
text.remove_prefix(sep.size());
++split;
if (maxsplit > 0 && split == maxsplit) {
- result.emplace_back(std::string(text));
+ result.push_back(StringPiece(text));
return result;
}
p = std::search(text.begin(), text.end(), sep.begin(), sep.end());
}
- result.emplace_back(std::string(text));
+ result.push_back(text);
return result;
}
@@ -134,7 +190,7 @@ class StringSplitOp : public OpKernel {
const auto delimiter_vec = delimiter_tensor->flat<string>();
const string& delimiter = delimiter_vec(0);
// Empty delimiter means split the input character by character.
- std::vector<string> tokens;
+ std::vector<StringPiece> tokens;
// Guess that we'll be unpacking a handful of tokens per example.
static constexpr int kReserveSize = 4;
tokens.reserve(batch_size * kReserveSize);
@@ -143,12 +199,15 @@ class StringSplitOp : public OpKernel {
int64 max_num_entries = 0;
std::vector<int64> num_indices(batch_size);
for (int64 i = 0; i < batch_size; ++i) {
- std::vector<string> parts = Split(input_vec(i), delimiter, skip_empty_);
+ std::vector<StringPiece> parts =
+ skip_empty_ ? Split(input_vec(i), delimiter, str_util::SkipEmpty())
+ : Split(input_vec(i), delimiter, str_util::AllowEmpty());
int64 n_entries = parts.size();
num_indices[i] = n_entries;
output_size += n_entries;
max_num_entries = std::max(max_num_entries, n_entries);
- tokens.insert(tokens.end(), parts.begin(), parts.end());
+ tokens.insert(tokens.end(), std::make_move_iterator(parts.begin()),
+ std::make_move_iterator(parts.end()));
}
Tensor* sp_indices_t;
@@ -170,7 +229,7 @@ class StringSplitOp : public OpKernel {
for (size_t j = 0; j < num_indices[i]; ++j) {
sp_indices(c, 0) = i;
sp_indices(c, 1) = j;
- sp_tokens(c) = tokens[c];
+ sp_tokens(c).assign(tokens[c].data(), tokens[c].size());
++c;
}
}
@@ -204,7 +263,7 @@ class StringSplitV2Op : public OpKernel {
sep_tensor->shape().DebugString()));
const auto sep_vec = sep_tensor->flat<string>();
StringPiece sep(sep_vec(0));
- std::vector<string> tokens;
+ std::vector<StringPiece> tokens;
// Guess that we'll be unpacking a handful of tokens per example.
static constexpr int kReserveSize = 4;
tokens.reserve(batch_size * kReserveSize);
@@ -213,7 +272,7 @@ class StringSplitV2Op : public OpKernel {
int64 max_num_entries = 0;
std::vector<int64> num_indices(batch_size);
for (int64 i = 0; i < batch_size; ++i) {
- std::vector<string> parts = SplitV2(input_vec(i), sep, maxsplit_);
+ std::vector<StringPiece> parts = SplitV2(input_vec(i), sep, maxsplit_);
int64 n_entries = parts.size();
num_indices[i] = n_entries;
output_size += n_entries;
@@ -240,7 +299,7 @@ class StringSplitV2Op : public OpKernel {
for (size_t j = 0; j < num_indices[i]; ++j) {
sp_indices(c, 0) = i;
sp_indices(c, 1) = j;
- sp_tokens(c) = tokens[c];
+ sp_tokens(c).assign(tokens[c].data(), tokens[c].size());
++c;
}
}
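
The net effect of the string_split_op.cc rewrite is that tokens are StringPieces into the original input and strings are only materialized when the output tensor is filled, saving one copy per token. Below is a condensed standalone version of the single-character fast path with the skip-empty predicate folded in as a bool, written against std::string_view; the names are illustrative and the real kernel uses StringPiece plus str_util::SkipEmpty/AllowEmpty.

  #include <iostream>
  #include <string>
  #include <string_view>
  #include <vector>

  // Splits on a single character; tokens view into `text` and stay valid
  // only while the underlying string is alive. Empty tokens are dropped
  // when skip_empty is true.
  std::vector<std::string_view> SplitOnChar(std::string_view text, char delim,
                                            bool skip_empty) {
    std::vector<std::string_view> out;
    std::string_view::size_type pos;
    while ((pos = text.find(delim)) != std::string_view::npos) {
      std::string_view token = text.substr(0, pos);
      if (!skip_empty || !token.empty()) out.push_back(token);
      text.remove_prefix(pos + 1);
    }
    if (!skip_empty || !text.empty()) out.push_back(text);
    return out;
  }

  int main() {
    const std::string line = "a,,b,c";
    for (auto t : SplitOnChar(line, ',', /*skip_empty=*/true)) {
      std::cout << t << "\n";  // prints a, b, c -- the empty token is dropped
    }
  }
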
diff --git a/tensorflow/core/kernels/string_split_op_test.cc b/tensorflow/core/kernels/string_split_op_test.cc
new file mode 100644
index 0000000000..58ad61adc8
--- /dev/null
+++ b/tensorflow/core/kernels/string_split_op_test.cc
@@ -0,0 +1,129 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+namespace tensorflow {
+
+// Test data from the TensorFlow README.md.
+const char* lines[] = {
+ "**TensorFlow** is an open source software library for numerical "
+ "computation using data flow graphs.",
+ "The graph nodes represent mathematical operations, while the graph edges "
+ "represent the multidimensional data arrays (tensors) that flow between "
+ "them.",
+ "This flexible architecture enables you to deploy computation to one or "
+ "more CPUs or GPUs in a desktop, server, or mobile device without "
+ "rewriting code.",
+ "TensorFlow also includes "
+ "[TensorBoard](https://www.tensorflow.org/guide/"
+ "summaries_and_tensorboard), a data visualization toolkit.",
+ "TensorFlow was originally developed by researchers and engineers working "
+ "on the Google Brain team within Google's Machine Intelligence Research "
+ "organization for the purposes of conducting machine learning and deep "
+ "neural networks research.",
+ "The system is general enough to be applicable in a wide variety of other "
+ "domains, as well.",
+ "TensorFlow provides stable Python API and C APIs as well as without API "
+ "backwards compatibility guarantee like C++, Go, Java, JavaScript and "
+ "Swift."};
+
+Tensor GetTestTensor(int batch) {
+ const int sz = TF_ARRAYSIZE(lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = lines[i % sz];
+ }
+ return t;
+}
+
+Graph* SetupStringSplitGraph(const Tensor& input) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor delim(DT_STRING, TensorShape({}));
+ delim.flat<string>().setConstant(" ");
+
+ TF_CHECK_OK(NodeBuilder("string_split_op", "StringSplit")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, delim))
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_StringSplit(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupStringSplitGraph(input);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_StringSplit)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+
+Graph* SetupStringSplitV2Graph(const Tensor& input) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor sep(DT_STRING, TensorShape({}));
+ sep.flat<string>().setConstant(" ");
+
+ TF_CHECK_OK(NodeBuilder("string_split_op", "StringSplitV2")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, sep))
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_StringSplitV2(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupStringSplitV2Graph(input);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_StringSplitV2)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/string_strip_op.cc b/tensorflow/core/kernels/string_strip_op.cc
index 2aeafa28c4..544dca96ba 100644
--- a/tensorflow/core/kernels/string_strip_op.cc
+++ b/tensorflow/core/kernels/string_strip_op.cc
@@ -43,7 +43,7 @@ class StringStripOp : public OpKernel {
for (int64 i = 0; i < input.size(); ++i) {
StringPiece entry(input(i));
str_util::RemoveWhitespaceContext(&entry);
- output(i) = std::string(entry);
+ output(i) = string(entry);
}
}
};
diff --git a/tensorflow/core/kernels/string_util.cc b/tensorflow/core/kernels/string_util.cc
new file mode 100644
index 0000000000..92c73220d8
--- /dev/null
+++ b/tensorflow/core/kernels/string_util.cc
@@ -0,0 +1,59 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/string_util.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+// Sets `encoding` based on `str`.
+Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding) {
+ if (str == "UTF8") {
+ *encoding = UnicodeEncoding::UTF8;
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+        "Invalid encoding \"", str, "\": Should be one of: UTF8"));
+ }
+ return Status::OK();
+}
+
+// Sets unit value based on str.
+Status ParseCharUnit(const string& str, CharUnit* unit) {
+ if (str == "BYTE") {
+ *unit = CharUnit::BYTE;
+ } else if (str == "UTF8_CHAR") {
+ *unit = CharUnit::UTF8_CHAR;
+ } else {
+ return errors::InvalidArgument(strings::StrCat(
+ "Invalid unit \"", str, "\": Should be one of: BYTE, UTF8_CHAR"));
+ }
+ return Status::OK();
+}
+
+// Return the number of Unicode characters in a UTF-8 string.
+// Result may be incorrect if the input string is not valid UTF-8.
+int32 UTF8StrLen(const string& string) {
+ const int32 byte_size = string.size();
+ const char* const end = string.data() + byte_size;
+ const char* ptr = string.data();
+ int32 skipped_count = 0;
+ while (ptr < end) {
+ skipped_count += IsTrailByte(*ptr++) ? 1 : 0;
+ }
+ const int32 result = byte_size - skipped_count;
+ return result;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/string_util.h b/tensorflow/core/kernels/string_util.h
new file mode 100644
index 0000000000..d40e93ea33
--- /dev/null
+++ b/tensorflow/core/kernels/string_util.h
@@ -0,0 +1,89 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
+#define TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
+
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// Enumeration for unicode encodings. Used by ops such as
+// tf.strings.unicode_encode and tf.strings.unicode_decode.
+// TODO(edloper): Add support for:
+// UTF16, UTF32, UTF16BE, UTF32BE, UTF16LE, UTF32LE
+enum class UnicodeEncoding { UTF8 };
+
+// Enumeration for character units. Used by string ops such as
+// tf.strings.length and tf.substr.
+// TODO(edloper): Add support for: UTF32_CHAR, etc.
+enum class CharUnit { BYTE, UTF8_CHAR };
+
+// Whether or not the given byte is the trailing byte of a UTF-8/16/32 char.
+inline bool IsTrailByte(char x) { return static_cast<signed char>(x) < -0x40; }
+
+// Sets `encoding` based on `str`.
+Status ParseUnicodeEncoding(const string& str, UnicodeEncoding* encoding);
+
+// Sets `unit` value based on `str`.
+Status ParseCharUnit(const string& str, CharUnit* unit);
+
+// Returns the number of Unicode characters in a UTF-8 string.
+// Result may be incorrect if the input string is not valid UTF-8.
+int32 UTF8StrLen(const string& string);
+
+// Get the next UTF8 character position starting at the given position and
+// skipping the given number of characters. Position is a byte offset, and
+// should never be `null`. The function returns true if successful. However, if
+// the end of the string is reached before the requested characters, then the
+// position will point to the end of the string and this function returns false.
+template <typename T>
+bool ForwardNUTF8CharPositions(const StringPiece in,
+ const T num_utf8_chars_to_shift, T* pos) {
+ const size_t size = in.size();
+ T utf8_chars_counted = 0;
+ while (utf8_chars_counted < num_utf8_chars_to_shift && *pos < size) {
+ // move forward one utf-8 character
+ do {
+ ++*pos;
+ } while (IsTrailByte(in[*pos]) && *pos < size);
+ ++utf8_chars_counted;
+ }
+ return utf8_chars_counted == num_utf8_chars_to_shift;
+}
+
+// Get the previous UTF8 character position starting at the given position and
+// skipping the given number of characters. Position is a byte offset with a
+// positive value, relative to the beginning of the string, and should never be
+// `null`. The function returns true if successful. However, if the beginning of
+// the string is reached before the requested character, then the position will
+// point to the beginning of the string and this function will return false.
+template <typename T>
+bool BackNUTF8CharPositions(const StringPiece in,
+ const T num_utf8_chars_to_shift, T* pos) {
+ const size_t start = 0;
+ T utf8_chars_counted = 0;
+ while (utf8_chars_counted < num_utf8_chars_to_shift && (*pos > start)) {
+ // move back one utf-8 character
+ do {
+ --*pos;
+ } while (IsTrailByte(in[*pos]) && *pos > start);
+ ++utf8_chars_counted;
+ }
+ return utf8_chars_counted == num_utf8_chars_to_shift;
+}
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_STRING_UTIL_H_
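
ForwardNUTF8CharPositions and BackNUTF8CharPositions convert character offsets into byte offsets by hopping from lead byte to lead byte. The sketch below is a simplified standalone rewrite of the forward direction with the same contract (it returns false and leaves the position at the end of the string when fewer characters remain than requested); it is not the header's template, and the bounds check is ordered slightly more defensively.

  #include <cassert>
  #include <cstddef>
  #include <string>

  inline bool IsTrailByte(char c) {
    return static_cast<signed char>(c) < -0x40;
  }

  // Advances `*pos` (a byte offset into `in`) by `n` UTF-8 characters.
  // Returns false if the string ends first; *pos then equals in.size().
  bool ForwardNChars(const std::string& in, std::size_t n, std::size_t* pos) {
    std::size_t counted = 0;
    while (counted < n && *pos < in.size()) {
      do {
        ++*pos;
      } while (*pos < in.size() && IsTrailByte(in[*pos]));
      ++counted;
    }
    return counted == n;
  }

  int main() {
    const std::string s = "\xc3\xa9tat";  // "état": 'é' occupies 2 bytes
    std::size_t pos = 0;
    assert(ForwardNChars(s, 1, &pos) && pos == 2);  // one char -> byte 2
    assert(!ForwardNChars(s, 10, &pos));            // runs off the end
    return 0;
  }
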
diff --git a/tensorflow/core/kernels/substr_op.cc b/tensorflow/core/kernels/substr_op.cc
index 22e45918a0..93c427039d 100644
--- a/tensorflow/core/kernels/substr_op.cc
+++ b/tensorflow/core/kernels/substr_op.cc
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <cstddef>
+#include <cstdlib>
#include <string>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -24,7 +26,10 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/string_util.h"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/bcast.h"
namespace tensorflow {
@@ -33,7 +38,11 @@ namespace tensorflow {
template <typename T>
class SubstrOp : public OpKernel {
public:
- using OpKernel::OpKernel;
+ explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ string unit;
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
+ OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
+ }
void Compute(OpKernelContext* context) override {
// Get inputs
@@ -64,26 +73,52 @@ class SubstrOp : public OpKernel {
const T len =
tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
- OP_REQUIRES(
- context, FastBoundsCheck(pos, in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ StringPiece in(input(i));
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
} else {
// Perform Op element-wise with tensor pos/len
auto pos_flat = pos_tensor.flat<T>();
auto len_flat = len_tensor.flat<T>();
for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
- string in = input(i);
+ StringPiece in(input(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
- OP_REQUIRES(
- context, FastBoundsCheck(pos, in.size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
}
} else {
@@ -142,14 +177,28 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
- string in = input_bcast(i);
+ StringPiece in(input_bcast(i));
const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
- OP_REQUIRES(
- context, FastBoundsCheck(pos, input_bcast(i).size() + 1),
- errors::InvalidArgument("pos ", pos, " out of range for string",
- "b'", in, "' at index ", i));
- output(i) = in.substr(pos, len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context,
+ FastBoundsCheck(byte_pos, input_bcast(i).size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index ", i));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
+ output(i).assign(sub_in.data(), sub_in.size());
}
break;
}
@@ -192,16 +241,30 @@ class SubstrOp : public OpKernel {
// Iterate through broadcasted tensors and perform substr
for (int i = 0; i < output_shape.dim_size(0); ++i) {
for (int j = 0; j < output_shape.dim_size(1); ++j) {
- string in = input_bcast(i, j);
+ StringPiece in(input_bcast(i, j));
const T pos =
tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
const T len =
tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
- OP_REQUIRES(context, FastBoundsCheck(pos, in.size() + 1),
- errors::InvalidArgument(
- "pos ", pos, " out of range for ", "string b'",
- in, "' at index (", i, ", ", j, ")"));
- output(i, j) = in.substr(pos, len);
+ T byte_pos = pos;
+ T byte_len = len;
+ switch (unit_) {
+ case CharUnit::UTF8_CHAR:
+ OP_REQUIRES(
+ context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string at index ", i));
+ break;
+ case CharUnit::BYTE:
+ byte_pos = AdjustedPosIndex(byte_pos, in);
+ OP_REQUIRES(
+ context, FastBoundsCheck(byte_pos, in.size() + 1),
+ errors::InvalidArgument("pos ", pos, " out of range for ",
+ "string b'", in, "' at index (",
+ i, ", ", j, ")"));
+ }
+ StringPiece sub_in = in.substr(byte_pos, byte_len);
+ output(i, j).assign(sub_in.data(), sub_in.size());
}
}
break;
@@ -213,6 +276,77 @@ class SubstrOp : public OpKernel {
}
}
}
+
+ private:
+  // This adjusts the requested position. Note it does not perform any bounds
+  // checks.
+ static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
+ if (pos_requested < 0) {
+ return s.size() + pos_requested;
+ }
+ return pos_requested;
+ }
+
+ // Return true if successful; otherwise, return false if the `pos` argument
+ // is out of range in the string.
+ static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos,
+ T* len) {
+ if (*pos >= 0) {
+ return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len);
+ } else {
+ return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len);
+ }
+ }
+
+ static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos,
+ const T len, T* char_pos,
+ T* char_len) {
+ *char_pos = 0;
+ // Determine byte position of the substring start.
+ if (!ForwardNUTF8CharPositions(in, pos, char_pos)) {
+ return false;
+ }
+ // Determine position of the end of the substring.
+ // The length will be capped at the end of the string, and we ignore whether
+ // the string had enough characters to handle it or not.
+ *char_len = *char_pos;
+ ForwardNUTF8CharPositions(in, len, char_len);
+    // The length in bytes is the end position of the substring minus the start.
+ *char_len = *char_len - *char_pos;
+ return true;
+ }
+
+ // This function expects a negative position relative to the end of the
+ // string, but will update the character position to a positive number
+ // relative to the beginning of the string.
+ static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos,
+ const T len, T* char_pos,
+ T* char_len) {
+    // Initially treat the length as the position of the end of the substring.
+ *char_len = in.size();
+    // This is the number of characters to skip from the end of the string to
+ // arrive at the position where the substring should end.
+ T utf8_chars_to_skip = -pos - len;
+ if (utf8_chars_to_skip < 0) {
+ utf8_chars_to_skip = 0;
+ }
+ // Find the byte position where the substring should end using the computed
+ // number of characters to skip.
+ if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) {
+ return false;
+ }
+ // Next, determine where the substring should begin. The number of chars to
+ // skip is the requested position minus the chars we've previously skipped.
+ *char_pos = *char_len;
+ if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) {
+ return false;
+ }
+    // The length in bytes is the end position of the substring minus the start.
+ *char_len = *char_len - *char_pos;
+ return true;
+ }
+
+ CharUnit unit_ = CharUnit::BYTE;
};
#define REGISTER_SUBSTR(type) \
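
To make the UTF8_CHAR path of SubstrOp concrete: the requested (pos, len) is converted from characters to a byte range before slicing, so multi-byte characters are never split. The sketch below covers only the non-negative-position case in plain C++; negative positions, which the kernel handles via BackNUTF8CharPositions, are left out, and CharToByte/Utf8Substr are names invented for the example.

  #include <cstddef>
  #include <iostream>
  #include <string>

  inline bool IsTrailByte(char c) {
    return static_cast<signed char>(c) < -0x40;
  }

  // Byte offset of the `n`-th UTF-8 character, clamped to s.size().
  std::size_t CharToByte(const std::string& s, std::size_t n) {
    std::size_t byte = 0, chars = 0;
    while (chars < n && byte < s.size()) {
      do {
        ++byte;
      } while (byte < s.size() && IsTrailByte(s[byte]));
      ++chars;
    }
    return byte;
  }

  // Substring of `s` starting at character `pos`, spanning `len` characters.
  std::string Utf8Substr(const std::string& s, std::size_t pos,
                         std::size_t len) {
    const std::size_t start = CharToByte(s, pos);
    const std::size_t end = CharToByte(s, pos + len);  // capped at s.size()
    return s.substr(start, end - start);
  }

  int main() {
    const std::string s = "na\xc3\xafve";      // "naïve"
    std::cout << Utf8Substr(s, 1, 3) << "\n";  // prints: aïv
  }
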
diff --git a/tensorflow/core/kernels/substr_op_test.cc b/tensorflow/core/kernels/substr_op_test.cc
new file mode 100644
index 0000000000..ea6b1ed500
--- /dev/null
+++ b/tensorflow/core/kernels/substr_op_test.cc
@@ -0,0 +1,189 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <string>
+
+#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/graph/testlib.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+
+// Test data from the TensorFlow README.md.
+const char* ascii_lines[] = {
+ "**TensorFlow** is an open source software library for numerical "
+ "computation using data flow graphs.",
+ "The graph nodes represent mathematical operations, while the graph edges "
+ "represent the multidimensional data arrays (tensors) that flow between "
+ "them.",
+ "This flexible architecture enables you to deploy computation to one or "
+ "more CPUs or GPUs in a desktop, server, or mobile device without "
+ "rewriting code.",
+ "TensorFlow also includes "
+ "[TensorBoard](https://www.tensorflow.org/guide/"
+ "summaries_and_tensorboard), a data visualization toolkit.",
+ "TensorFlow was originally developed by researchers and engineers working "
+ "on the Google Brain team within Google's Machine Intelligence Research "
+ "organization for the purposes of conducting machine learning and deep "
+ "neural networks research.",
+ "The system is general enough to be applicable in a wide variety of other "
+ "domains, as well.",
+ "TensorFlow provides stable Python API and C APIs as well as without API "
+ "backwards compatibility guarantee like C++, Go, Java, JavaScript and "
+ "Swift."};
+
+const char* unicode_lines[] = {
+ "TensorFlow\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe4\xbd\xbf\xe7\x94\xa8\xe6"
+ "\x95\xb0\xe6\x8d\xae\xe6\xb5\x81\xe5\x9b\xbe\xe8\xbf\x9b\xe8\xa1\x8c\xe6"
+ "\x95\xb0\xe5\x80\xbc\xe8\xae\xa1\xe7\xae\x97\xe7\x9a\x84\xe5\xbc\x80\xe6"
+ "\xba\x90\xe8\xbd\xaf\xe4\xbb\xb6\xe5\xba\x93\xe3\x80\x82",
+ "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\x8a\x82\xe7\x82\xb9\xe8\xa1\xa8\xe7\xa4\xba"
+ "\xe6\x95\xb0\xe5\xad\xa6\xe8\xbf\x90\xe7\xae\x97\xef\xbc\x8c\xe8\x80\x8c"
+ "\xe5\x9b\xbe\xe5\xbd\xa2\xe8\xbe\xb9\xe7\xbc\x98\xe8\xa1\xa8\xe7\xa4\xba"
+ "\xe5\x9c\xa8\xe5\xae\x83\xe4\xbb\xac\xe4\xb9\x8b\xe9\x97\xb4\xe6\xb5\x81"
+ "\xe5\x8a\xa8\xe7\x9a\x84\xe5\xa4\x9a\xe7\xbb\xb4\xe6\x95\xb0\xe6\x8d\xae"
+ "\xe9\x98\xb5\xe5\x88\x97\xef\xbc\x88\xe5\xbc\xa0\xe9\x87\x8f\xef\xbc\x89"
+ "\xe3\x80\x82",
+ "\xe8\xbf\x99\xe7\xa7\x8d\xe7\x81\xb5\xe6\xb4\xbb\xe7\x9a\x84\xe4\xbd\x93"
+ "\xe7\xb3\xbb\xe7\xbb\x93\xe6\x9e\x84\xe4\xbd\xbf\xe6\x82\xa8\xe5\x8f\xaf"
+ "\xe4\xbb\xa5\xe5\xb0\x86\xe8\xae\xa1\xe7\xae\x97\xe9\x83\xa8\xe7\xbd\xb2"
+ "\xe5\x88\xb0\xe6\xa1\x8c\xe9\x9d\xa2\xef\xbc\x8c\xe6\x9c\x8d\xe5\x8a\xa1"
+ "\xe5\x99\xa8\xe6\x88\x96\xe7\xa7\xbb\xe5\x8a\xa8\xe8\xae\xbe\xe5\xa4\x87"
+ "\xe4\xb8\xad\xe7\x9a\x84\xe4\xb8\x80\xe4\xb8\xaa\xe6\x88\x96\xe5\xa4\x9a"
+ "\xe4\xb8\xaa CPU\xe6\x88\x96GPU\xef\xbc\x8c\xe8\x80\x8c\xe6\x97\xa0\xe9"
+ "\x9c\x80\xe9\x87\x8d\xe5\x86\x99\xe4\xbb\xa3\xe7\xa0\x81\xe3\x80\x82",
+ "TensorFlow\xe8\xbf\x98\xe5\x8c\x85\xe6\x8b\xac[TensorBoard]\xef\xbc\x88"
+ "https://www.tensorflow.org/guide/summaries_and_tensorboard\xef\xbc\x89\xef"
+ "\xbc\x8c\xe8\xbf\x99\xe6\x98\xaf\xe4\xb8\x80\xe4\xb8\xaa\xe6\x95\xb0\xe6"
+ "\x8d\xae\xe5\x8f\xaf\xe8\xa7\x86\xe5\x8c\x96\xe5\xb7\xa5\xe5\x85\xb7\xe5"
+ "\x8c\x85\xe3\x80\x82",
+ "TensorFlow\xe6\x9c\x80\xe5\x88\x9d\xe6\x98\xaf\xe7\x94\xb1\xe7\xa0\x94\xe7"
+ "\xa9\xb6\xe4\xba\xba\xe5\x91\x98\xe5\x92\x8c\xe5\xb7\xa5\xe7\xa8\x8b\xe5"
+ "\xb8\x88\xe5\x9c\xa8Google\xe6\x9c\xba\xe5\x99\xa8\xe6\x99\xba\xe8\x83\xbd"
+ "\xe7\xa0\x94\xe7\xa9\xb6\xe7\xbb\x84\xe7\xbb\x87\xe7\x9a\x84Google Brain"
+ "\xe5\x9b\xa2\xe9\x98\x9f\xe5\xbc\x80\xe5\x8f\x91\xe7\x9a\x84\xef\xbc\x8c"
+ "\xe7\x9b\xae\xe7\x9a\x84\xe6\x98\xaf\xe8\xbf\x9b\xe8\xa1\x8c\xe6\x9c\xba"
+ "\xe5\x99\xa8\xe5\xad\xa6\xe4\xb9\xa0\xe5\x92\x8c\xe6\xb7\xb1\xe5\xba\xa6"
+ "\xe7\xa5\x9e\xe7\xbb\x8f\xe7\xbd\x91\xe7\xbb\x9c\xe7\xa0\x94\xe7\xa9\xb6"
+ "\xe3\x80\x82",
+ "\xe8\xaf\xa5\xe7\xb3\xbb\xe7\xbb\x9f\xe8\xb6\xb3\xe4\xbb\xa5\xe9\x80\x82"
+ "\xe7\x94\xa8\xe4\xba\x8e\xe5\x90\x84\xe7\xa7\x8d\xe5\x85\xb6\xe4\xbb\x96"
+ "\xe9\xa2\x86\xe5\x9f\x9f\xe4\xb9\x9f\xe6\x98\xaf\xe5\xa6\x82\xe6\xad\xa4"
+ "\xe3\x80\x82",
+ "TensorFlow\xe6\x8f\x90\xe4\xbe\x9b\xe7\xa8\xb3\xe5\xae\x9a\xe7\x9a\x84"
+ "Python API\xe5\x92\x8c C API\xef\xbc\x8c\xe4\xbb\xa5\xe5\x8f\x8a\xe6\xb2"
+ "\xa1\xe6\x9c\x89 API\xe5\x90\x91\xe5\x90\x8e\xe5\x85\xbc\xe5\xae\xb9\xe6"
+ "\x80\xa7\xe4\xbf\x9d\xe8\xaf\x81\xef\xbc\x8c\xe5\xa6\x82 C ++\xef\xbc\x8c"
+ "Go\xef\xbc\x8cJava\xef\xbc\x8cJavaScript\xe5\x92\x8cSwift\xe3\x80\x82",
+};
+
+const char* const kByteUnit = "BYTE";
+const char* const kUTF8Unit = "UTF8_CHAR";
+
+Tensor GetTestTensor(int batch) {
+ const int sz = TF_ARRAYSIZE(ascii_lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = ascii_lines[i % sz];
+ }
+ return t;
+}
+
+Tensor GetTestUTF8Tensor(int batch) {
+ const int sz = TF_ARRAYSIZE(unicode_lines);
+ Tensor t(DT_STRING, {batch});
+ auto s = t.flat<string>();
+ for (int i = 0; i < batch; ++i) {
+ s(i) = unicode_lines[i % sz];
+ }
+ return t;
+}
+
+Graph* SetupSubstrGraph(const Tensor& input, const int32 pos, const int32 len,
+ const char* const unit) {
+ Graph* g = new Graph(OpRegistry::Global());
+ Tensor position(DT_INT32, TensorShape({}));
+ position.flat<int32>().setConstant(pos);
+ Tensor length(DT_INT32, TensorShape({}));
+ length.flat<int32>().setConstant(len);
+
+ TF_CHECK_OK(NodeBuilder("substr_op", "Substr")
+ .Input(test::graph::Constant(g, input))
+ .Input(test::graph::Constant(g, position))
+ .Input(test::graph::Constant(g, length))
+ .Attr("unit", unit)
+ .Finalize(g, nullptr /* node */));
+ return g;
+}
+
+void BM_SubstrByte(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestTensor(batch_size);
+ Graph* g = SetupSubstrGraph(input, 3, 30, kByteUnit);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+void BM_SubstrUTF8(int iters, int batch_size) {
+ testing::StopTiming();
+ testing::ItemsProcessed(static_cast<int64>(iters));
+ testing::UseRealTime();
+ Tensor input = GetTestUTF8Tensor(batch_size);
+ Graph* g = SetupSubstrGraph(input, 3, 30, kUTF8Unit);
+ testing::StartTiming();
+ test::Benchmark("cpu", g).Run(iters);
+}
+
+BENCHMARK(BM_SubstrByte)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+BENCHMARK(BM_SubstrUTF8)
+ ->Arg(1)
+ ->Arg(8)
+ ->Arg(16)
+ ->Arg(32)
+ ->Arg(64)
+ ->Arg(128)
+ ->Arg(256);
+
+} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/svd_op_impl.h b/tensorflow/core/kernels/svd_op_impl.h
index a996b67c62..2a67700c12 100644
--- a/tensorflow/core/kernels/svd_op_impl.h
+++ b/tensorflow/core/kernels/svd_op_impl.h
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
+
// See docs in ../ops/linalg_ops.cc.
//
// This header file is used by the individual svd_*op*.cc files for registering
@@ -101,3 +104,5 @@ class SvdOp : public LinearAlgebraOp<Scalar> {
};
} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_KERNELS_SVD_OP_IMPL_H_
diff --git a/tensorflow/core/kernels/tensor_array.cc b/tensorflow/core/kernels/tensor_array.cc
index 765467bc1e..0e6c0ddccc 100644
--- a/tensorflow/core/kernels/tensor_array.cc
+++ b/tensorflow/core/kernels/tensor_array.cc
@@ -62,7 +62,8 @@ TF_CALL_complex128(TENSOR_ARRAY_WRITE_OR_ADD_GPU);
}
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
-TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
+TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
+TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU
#if GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/tensor_array.h b/tensorflow/core/kernels/tensor_array.h
index 68fab85770..384a63e945 100644
--- a/tensorflow/core/kernels/tensor_array.h
+++ b/tensorflow/core/kernels/tensor_array.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_TENSOR_ARRAY_H_
-#define TENSORFLOW_KERNELS_TENSOR_ARRAY_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_
+#define TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_
#include <limits.h>
#include <vector>
@@ -81,7 +81,8 @@ Status TensorSetZero(OpKernelContext* ctx, Tensor* value) {
Status TensorSetZero<Device, T>(OpKernelContext * ctx, Tensor * value);
#define TENSOR_ARRAY_SET_ZERO_CPU(T) TENSOR_ARRAY_SET_ZERO(CPUDevice, T)
-TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU)
+TF_CALL_NUMBER_TYPES(TENSOR_ARRAY_SET_ZERO_CPU);
+TF_CALL_bool(TENSOR_ARRAY_SET_ZERO_CPU);
#undef TENSOR_ARRAY_SET_ZERO_CPU
#if GOOGLE_CUDA
@@ -629,4 +630,4 @@ Status TensorArray::LockedRead(OpKernelContext* ctx, const int32 index,
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_TENSOR_ARRAY_H_
+#endif // TENSORFLOW_CORE_KERNELS_TENSOR_ARRAY_H_
diff --git a/tensorflow/core/kernels/tensor_array_ops.cc b/tensorflow/core/kernels/tensor_array_ops.cc
index b368ffc875..a97a71b344 100644
--- a/tensorflow/core/kernels/tensor_array_ops.cc
+++ b/tensorflow/core/kernels/tensor_array_ops.cc
@@ -259,6 +259,7 @@ REGISTER_KERNEL_BUILDER(Name("TensorArrayV3").Device(DEVICE_CPU),
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU
@@ -290,14 +291,14 @@ class TensorArrayGradOp : public TensorArrayCreationOp {
}
} else {
container = "_tensor_arrays";
- auto resource = ctx->input(0).flat<ResourceHandle>()(0);
+ const auto& resource = ctx->input(0).flat<ResourceHandle>()(0);
if (StringPiece(resource.name()).substr(0, container.size()) !=
container) {
return errors::InvalidArgument("Wrong input container. ",
resource.name());
}
tensor_array_name =
- std::string(StringPiece(resource.name()).substr(container.size()));
+ string(StringPiece(resource.name()).substr(container.size()));
}
auto output_handle = tensor_array_output_handle->flat<string>();
@@ -576,6 +577,7 @@ TF_CALL_ALL_TYPES(REGISTER_READ)
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
REGISTER_GPU(bfloat16);
#undef REGISTER_GPU
@@ -1119,8 +1121,8 @@ class TensorArrayUnpackOrScatterOp : public OpKernel {
{1, num_values, element_shape.num_elements()});
Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, 0, 0};
- Eigen::DSizes<Eigen::DenseIndex, 3> sizes{1, 1,
- element_shape.num_elements()};
+ Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
+ 1, 1, static_cast<Eigen::DenseIndex>(element_shape.num_elements())};
std::vector<PersistentTensor> write_values;
write_values.reserve(num_values);
@@ -1218,6 +1220,7 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_AND_UNPACK);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
#undef REGISTER_GPU
#endif // GOOGLE_CUDA
@@ -1315,9 +1318,11 @@ class TensorArraySplitOp : public OpKernel {
PersistentTensor persistent_tensor;
int64 previous_length = (i == 0) ? 0 : cumulative_lengths[i - 1];
- Eigen::DSizes<Eigen::DenseIndex, 3> indices{0, previous_length, 0};
- Eigen::DSizes<Eigen::DenseIndex, 3> sizes{1, tensor_lengths_t(i),
- elements_per_row};
+ Eigen::DSizes<Eigen::DenseIndex, 3> indices{
+ 0, static_cast<Eigen::DenseIndex>(previous_length), 0};
+ Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
+ 1, static_cast<Eigen::DenseIndex>(tensor_lengths_t(i)),
+ static_cast<Eigen::DenseIndex>(elements_per_row)};
OP_REQUIRES_OK(ctx, ctx->allocate_persistent(
tensor_array->ElemType(), element_shapes[i],
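
Note: the added static_cast<Eigen::DenseIndex> calls make the index type explicit where TensorFlow's int64 and Eigen::DenseIndex may differ, so the brace-initialized Eigen::DSizes no longer relies on an implicit narrowing conversion. A minimal standalone sketch of the pattern (illustrative only, not TensorFlow code):

#include <cstdint>
#include "unsupported/Eigen/CXX11/Tensor"

int main() {
  const std::int64_t num_elements = 128;  // e.g. element_shape.num_elements()
  // Passing the int64 directly can trigger narrowing diagnostics on toolchains
  // where Eigen::DenseIndex is a different integer type; the cast is portable.
  Eigen::DSizes<Eigen::DenseIndex, 3> sizes{
      1, 1, static_cast<Eigen::DenseIndex>(num_elements)};
  return sizes[2] == 128 ? 0 : 1;
}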
diff --git a/tensorflow/core/kernels/tile_functor.h b/tensorflow/core/kernels/tile_functor.h
index 189be9239b..95986af8b7 100644
--- a/tensorflow/core/kernels/tile_functor.h
+++ b/tensorflow/core/kernels/tile_functor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_TILE_FUNCTOR_H_
-#define TENSORFLOW_KERNELS_TILE_FUNCTOR_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_H_
+#define TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -106,4 +106,4 @@ struct Tile {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_TILE_FUNCTOR_H_
+#endif // TENSORFLOW_CORE_KERNELS_TILE_FUNCTOR_H_
diff --git a/tensorflow/core/kernels/tile_ops_impl.h b/tensorflow/core/kernels/tile_ops_impl.h
index 9861717a0b..6a9de388c6 100644
--- a/tensorflow/core/kernels/tile_ops_impl.h
+++ b/tensorflow/core/kernels/tile_ops_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_TILE_IMPL_OPS_H_
-#define TENSORFLOW_KERNELS_TILE_IMPL_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TILE_OPS_IMPL_H_
+#define TENSORFLOW_CORE_KERNELS_TILE_OPS_IMPL_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -68,4 +68,4 @@ struct ReduceAndReshape {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_TILE_OPS_IMPL_H_
+#endif // TENSORFLOW_CORE_KERNELS_TILE_OPS_IMPL_H_
diff --git a/tensorflow/core/kernels/topk_op.h b/tensorflow/core/kernels/topk_op.h
index a53e3ec8d4..1fdbc5b15f 100644
--- a/tensorflow/core/kernels/topk_op.h
+++ b/tensorflow/core/kernels/topk_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_TOPK_OP_H_
-#define TENSORFLOW_TOPK_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TOPK_OP_H_
+#define TENSORFLOW_CORE_KERNELS_TOPK_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
@@ -39,4 +39,4 @@ struct TopKFunctor {
} // end namespace tensorflow
-#endif // TENSORFLOW_TOPK_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_TOPK_OP_H_
diff --git a/tensorflow/core/kernels/topk_op_gpu.cu.cc b/tensorflow/core/kernels/topk_op_gpu.cu.cc
index ca296d5aa0..2fbe1fe7cb 100644
--- a/tensorflow/core/kernels/topk_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/topk_op_gpu.cu.cc
@@ -20,9 +20,9 @@ limitations under the License.
#include <cmath>
#include <vector>
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_segmented_radix_sort.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_segmented_radix_sort.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc
index d3c4f62071..4262a5404b 100644
--- a/tensorflow/core/kernels/training_op_helpers.cc
+++ b/tensorflow/core/kernels/training_op_helpers.cc
@@ -15,13 +15,16 @@ limitations under the License.
#include "tensorflow/core/kernels/training_op_helpers.h"
+#include "tensorflow/core/util/ptr_util.h"
+
namespace tensorflow {
-mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
+mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
+ Var** maybe_resource) {
+ *maybe_resource = nullptr;
if (ctx->input_dtype(input) == DT_RESOURCE) {
- Var* var;
- if (LookupResource(ctx, HandleFromInput(ctx, input), &var).ok()) {
- return var->mu();
+ if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) {
+ return (*maybe_resource)->mu();
} else {
ctx->CtxFailureWithWarning(
errors::Internal("Invalid variable reference."));
@@ -32,12 +35,13 @@ mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input) {
}
// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes
-// in address order to mitigate deadlock. Returns a vector of acquired mutexes.
-// Safe to pass duplicates - will only lock each distinct mutex once. If
-// do_lock is false, returns immediately. Note that this silently doesn't lock
-// mutexes for invalid variable references; in all usages this is followed by
-// GetInputTensor which will signal a failure.
-std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
+// in address order to mitigate deadlock. Returns a structure that, when
+// deleted, will release the acquired mutexes. Safe to pass duplicates - will
+// only lock each distinct mutex once. If do_lock is false, returns
+// immediately. Note that this silently doesn't lock mutexes for invalid
+// variable references; in all usages this is followed by GetInputTensor which
+// will signal a failure.
+VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids) {
bool any_resource = false;
for (auto i : input_ids) {
@@ -46,14 +50,16 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
break;
}
}
- std::vector<mutex_lock> locks;
if (!do_lock && !any_resource) {
- return locks;
+ return VariableInputLockHolder({}, {});
}
+ std::vector<Var*> vars;
std::vector<mutex*> mutexes;
std::vector<int> acquire_order;
for (auto input : input_ids) {
- mutex* mutex = GetTrainingVariableMutex(ctx, input);
+ Var* var;
+ mutex* mutex = GetTrainingVariableMutex(ctx, input, &var);
+ if (var) vars.push_back(var);
// Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3).
if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) {
acquire_order.push_back(mutexes.size());
@@ -63,13 +69,19 @@ std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
std::sort(acquire_order.begin(), acquire_order.end(),
[&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; });
+ std::unique_ptr<std::vector<mutex_lock>> locks =
+ MakeUnique<std::vector<mutex_lock>>();
+ locks->reserve(acquire_order.size());
+
for (auto input : acquire_order) {
- mutex* mu = GetTrainingVariableMutex(ctx, input);
+ Var* var;
+ mutex* mu = GetTrainingVariableMutex(ctx, input, &var);
+ core::ScopedUnref scoped_unref(var);
if (mu != nullptr) {
- locks.emplace_back(*mu);
+ locks->emplace_back(*mu);
}
}
- return locks;
+ return VariableInputLockHolder(std::move(vars), std::move(locks));
}
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
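
Note: MaybeLockVariableInputMutexesInOrder above deduplicates the mutexes and sorts them by address before locking, so any two kernels that overlap on some variables acquire the shared mutexes in the same global order and cannot deadlock each other. A standalone sketch of that idiom with std::mutex (assumed names, not the TensorFlow helper):

#include <algorithm>
#include <mutex>
#include <vector>

// Locks a set of mutexes (duplicates allowed) in address order; the returned
// std::unique_lock objects release them when the vector is destroyed.
std::vector<std::unique_lock<std::mutex>> LockInAddressOrder(
    std::vector<std::mutex*> mutexes) {
  std::sort(mutexes.begin(), mutexes.end());  // global order: by address
  mutexes.erase(std::unique(mutexes.begin(), mutexes.end()), mutexes.end());
  std::vector<std::unique_lock<std::mutex>> locks;
  locks.reserve(mutexes.size());
  for (std::mutex* mu : mutexes) locks.emplace_back(*mu);  // lock each once
  return locks;
}

TensorFlow's mutex_lock is not movable on every platform, which is why the real helper keeps its locks behind a std::unique_ptr inside VariableInputLockHolder (see the note in the header below) rather than returning them by value.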
diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h
index 765335d3a0..9f173a80f7 100644
--- a/tensorflow/core/kernels/training_op_helpers.h
+++ b/tensorflow/core/kernels/training_op_helpers.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
-#define TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
+#define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/variant_op_registry.h"
@@ -23,9 +23,42 @@ limitations under the License.
namespace tensorflow {
-mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input);
+// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`.
+//
+// If `input` corresponds to a `DT_RESOURCE`-type variable input,
+// `*maybe_resource` will be updated to contain the underlying resource, and the
+// caller will be responsible for calling `Unref()` on that resource.
+mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input,
+ Var** maybe_resource);
-std::vector<mutex_lock> MaybeLockVariableInputMutexesInOrder(
+// Utility structure that releases a sequence of borrowed mutexes when it is
+// deleted.
+struct VariableInputLockHolder {
+ public:
+ VariableInputLockHolder(std::vector<Var*> vars,
+ std::unique_ptr<std::vector<mutex_lock>> locks)
+ : vars_(std::move(vars)), locks_(std::move(locks)) {}
+
+ VariableInputLockHolder(VariableInputLockHolder&& other)
+ : vars_(std::move(other.vars_)), locks_(std::move(other.locks_)) {}
+
+ ~VariableInputLockHolder() {
+ // Release the locks before unreffing the Vars, because each lock
+ // is potentially borrowed from a Var in vars_.
+ locks_.reset();
+ for (Var* var : vars_) {
+ var->Unref();
+ }
+ }
+
+ private:
+ std::vector<Var*> vars_;
+ // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly,
+ // because a `std::vector<mutex_lock>` is not movable on all platforms.
+ std::unique_ptr<std::vector<mutex_lock>> locks_;
+};
+
+VariableInputLockHolder MaybeLockVariableInputMutexesInOrder(
OpKernelContext* ctx, bool do_lock, const std::vector<int>& input_ids);
void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input,
@@ -90,4 +123,4 @@ Status GetInputTensorFromVariable(OpKernelContext* ctx, int input,
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_TRAINING_OP_HELPERS_H_
+#endif // TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_
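
A hedged sketch of how a kernel is expected to use the new return type, following the declarations above (the input indices and the use_exclusive_lock_ member are placeholders):

// Inside some OpKernel::Compute(OpKernelContext* ctx) with variable inputs
// at indices 0 and 1:
auto locks =
    MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1});
// ... read and update the variables while `locks` is in scope ...
// When `locks` (a VariableInputLockHolder) goes out of scope it first releases
// the mutexes and then Unrefs the underlying Var resources.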
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 271329599f..acf162deec 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#define EIGEN_USE_THREADS
-
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
#include <algorithm>
@@ -201,7 +200,7 @@ struct ApplyFtrlV2<CPUDevice, T> {
typename TTypes<T>::ConstScalar l2_shrinkage,
typename TTypes<T>::ConstScalar lr_power) {
auto grad_with_shrinkage = grad + static_cast<T>(2) * l2_shrinkage() * var;
- auto new_accum = accum + grad_with_shrinkage.square();
+ auto new_accum = accum + grad * grad;
// special case for which lr_power=-0.5.
if (lr_power() == static_cast<T>(-0.5)) {
linear.device(d) +=
@@ -226,7 +225,7 @@ struct ApplyFtrlV2<CPUDevice, T> {
var.device(d) = (linear.abs() > linear.constant(l1()))
.select(pre_shrink, var.constant(static_cast<T>(0)));
}
- accum.device(d) += grad_with_shrinkage.square();
+ accum.device(d) += grad * grad;
}
};
@@ -562,7 +561,9 @@ class ApplyAdadeltaOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- mutex* mu = GetTrainingVariableMutex(ctx, 0);
+ Var* resource;
+ mutex* mu = GetTrainingVariableMutex(ctx, 0, &resource);
+ core::ScopedUnref scoped_unref(resource);
if (use_exclusive_lock_ && mu != nullptr) {
mutex_lock l1(*mu);
// Don't try to acquire a lock on the second ref as they share the same
@@ -711,7 +712,9 @@ class SparseApplyAdadeltaOp : public OpKernel {
}
void Compute(OpKernelContext* ctx) override {
- mutex* mu = GetTrainingVariableMutex(ctx, 0);
+ Var* var;
+ mutex* mu = GetTrainingVariableMutex(ctx, 0, &var);
+ core::ScopedUnref scoped_unref(var);
// mu_accum is actually the same mutex as mu_var since currently we use a
// global mutex.
//
@@ -2167,15 +2170,15 @@ class SparseApplyFtrlOp : public OpKernel {
// Use a macro to implement the computation here due to the templating of the
// eigen tensor library.
-#define COMPUTE_FTRL(grad_to_use) \
- auto new_accum = accum + grad_to_use.square(); \
+#define COMPUTE_FTRL(grad, grad_maybe_with_shrinkage) \
+ auto new_accum = accum + grad.square(); \
if (lr_power_scalar == static_cast<T>(-0.5)) { \
- linear += \
- grad_to_use - (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; \
+ linear += grad_maybe_with_shrinkage - \
+ (new_accum.sqrt() - accum.sqrt()) / lr_scalar * var; \
} else { \
- linear += grad_to_use - (new_accum.pow(-lr_power_scalar) - \
- accum.pow(-lr_power_scalar)) / \
- lr_scalar * var; \
+ linear += grad_maybe_with_shrinkage - (new_accum.pow(-lr_power_scalar) - \
+ accum.pow(-lr_power_scalar)) / \
+ lr_scalar * var; \
} \
auto l1_reg_adjust = linear.cwiseMin(l1_scalar).cwiseMax(-l1_scalar); \
auto x = l1_reg_adjust - linear; \
@@ -2188,14 +2191,14 @@ class SparseApplyFtrlOp : public OpKernel {
linear.constant(static_cast<T>(2) * l2_scalar); \
var = x / y; \
} \
- accum += grad_to_use.square();
+ accum += grad.square();
if (has_l2_shrinkage) {
auto grad_with_shrinkage =
grad + static_cast<T>(2) * l2_shrinkage_scalar * var;
- COMPUTE_FTRL(grad_with_shrinkage);
+ COMPUTE_FTRL(grad, grad_with_shrinkage);
} else {
- COMPUTE_FTRL(grad);
+ COMPUTE_FTRL(grad, grad);
}
}
#undef COMPUTE_FTRL
@@ -2228,12 +2231,12 @@ class SparseApplyFtrlOp : public OpKernel {
T g;
if (has_l2_shrinkage) {
g = grad_flat(i) +
- (static_cast<T>(2) * l2_shrinkage_scalar * var_flat(i));
+ (static_cast<T>(2) * l2_shrinkage_scalar * var_flat(index));
} else {
g = grad_flat(i);
}
- T updated_a = a + g * g;
+ T updated_a = a + grad_flat(i) * grad_flat(i);
using Eigen::numext::pow;
T sigma = pow(updated_a, -lr_power_scalar) - pow(a, -lr_power_scalar);
sigma /= lr_scalar;
@@ -2856,9 +2859,8 @@ class ApplyAdaMaxOp : public OpKernel {
const Device& device = ctx->template eigen_device<Device>();
functor::ApplyAdaMax<Device, T>()(
device, var.flat<T>(), m.flat<T>(), v.flat<T>(),
- beta1_power.scalar<T>(), lr.scalar<T>(),
- beta1.scalar<T>(), beta2.scalar<T>(), epsilon.scalar<T>(),
- grad.flat<T>());
+ beta1_power.scalar<T>(), lr.scalar<T>(), beta1.scalar<T>(),
+ beta2.scalar<T>(), epsilon.scalar<T>(), grad.flat<T>());
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
@@ -2867,16 +2869,16 @@ class ApplyAdaMaxOp : public OpKernel {
bool use_exclusive_lock_;
};
-#define REGISTER_KERNELS(D, T) \
- REGISTER_KERNEL_BUILDER( \
+#define REGISTER_KERNELS(D, T) \
+ REGISTER_KERNEL_BUILDER( \
Name("ApplyAdaMax").Device(DEVICE_##D).TypeConstraint<T>("T"), \
ApplyAdaMaxOp<D##Device, T>); \
REGISTER_KERNEL_BUILDER(Name("ResourceApplyAdaMax") \
- .HostMemory("var") \
- .HostMemory("m") \
- .HostMemory("v") \
- .Device(DEVICE_##D) \
- .TypeConstraint<T>("T"), \
+ .HostMemory("var") \
+ .HostMemory("m") \
+ .HostMemory("v") \
+ .Device(DEVICE_##D) \
+ .TypeConstraint<T>("T"), \
ApplyAdaMaxOp<D##Device, T>);
#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T);
@@ -2889,7 +2891,7 @@ TF_CALL_double(REGISTER_CPU_KERNELS);
namespace functor {
#define DECLARE_GPU_SPEC(T) \
template <> \
- void ApplyAdaMax<GPUDevice, T>::operator()( \
+ void ApplyAdaMax<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::Flat var, \
typename TTypes<T>::Flat m, typename TTypes<T>::Flat v, \
typename TTypes<T>::ConstScalar beta1_power, \
@@ -2897,7 +2899,7 @@ namespace functor {
typename TTypes<T>::ConstScalar beta1, \
typename TTypes<T>::ConstScalar beta2, \
typename TTypes<T>::ConstScalar epsilon, \
- typename TTypes<T>::ConstFlat grad); \
+ typename TTypes<T>::ConstFlat grad); \
extern template struct ApplyAdaMax<GPUDevice, T>;
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
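
Note: the FTRL changes above separate the two uses of the gradient: the shrinkage-adjusted value grad + 2 * l2_shrinkage * var only feeds the linear term, while the accumulator always adds the raw gradient squared. A scalar transcription of the corrected update for the lr_power == -0.5 branch (illustrative only; it mirrors the COMPUTE_FTRL macro above):

#include <algorithm>
#include <cmath>

void FtrlV2Step(double& var, double& accum, double& linear, double grad,
                double lr, double l1, double l2, double l2_shrinkage) {
  const double grad_with_shrinkage = grad + 2.0 * l2_shrinkage * var;
  const double new_accum = accum + grad * grad;  // raw gradient, not shrinkage
  linear += grad_with_shrinkage -
            (std::sqrt(new_accum) - std::sqrt(accum)) / lr * var;
  const double l1_reg_adjust = std::max(std::min(linear, l1), -l1);
  const double quadratic = std::sqrt(new_accum) / lr + 2.0 * l2;
  var = (l1_reg_adjust - linear) / quadratic;  // zero when |linear| <= l1
  accum = new_accum;                           // i.e. accum += grad * grad
}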
diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h
index 495a94f1a1..e10a4cb125 100644
--- a/tensorflow/core/kernels/training_ops.h
+++ b/tensorflow/core/kernels/training_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_TRAINING_OPS_H_
-#define TENSORFLOW_KERNELS_TRAINING_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
@@ -199,4 +199,4 @@ struct ApplyPowerSign {
} // end namespace functor
} // end namespace tensorflow
-#endif // TENSORFLOW_KERNELS_TRAINING_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_
diff --git a/tensorflow/core/kernels/transpose_op.cc b/tensorflow/core/kernels/transpose_op.cc
index 0f0f65c5a3..48e392c070 100644
--- a/tensorflow/core/kernels/transpose_op.cc
+++ b/tensorflow/core/kernels/transpose_op.cc
@@ -218,7 +218,7 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
perm, out);
}
-#if defined(INTEL_MKL)
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
@@ -230,11 +230,8 @@ Status ConjugateTransposeCpuOp::DoTranspose(OpKernelContext* ctx,
.TypeConstraint<T>("T") \
.HostMemory("perm"), \
MklConjugateTransposeCpuOp);
-TF_CALL_ALL_TYPES(REGISTER);
-#undef REGISTER
-
-#else // INTEL_MKL
+#else // INTEL_MKL && ENABLE_MKL
#define REGISTER(T) \
REGISTER_KERNEL_BUILDER(Name("Transpose") \
.Device(DEVICE_CPU) \
@@ -246,9 +243,10 @@ TF_CALL_ALL_TYPES(REGISTER);
.TypeConstraint<T>("T") \
.HostMemory("perm"), \
ConjugateTransposeCpuOp);
+#endif // INTEL_MKL && ENABLE_MKL
+
TF_CALL_ALL_TYPES(REGISTER)
#undef REGISTER
-#endif // INTEL_MKL
#if GOOGLE_CUDA
Status TransposeGpuOp::DoTranspose(OpKernelContext* ctx, const Tensor& in,
diff --git a/tensorflow/core/kernels/typed_conditional_accumulator_base.h b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
index 1980f758fc..ca341e511e 100644
--- a/tensorflow/core/kernels/typed_conditional_accumulator_base.h
+++ b/tensorflow/core/kernels/typed_conditional_accumulator_base.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
-#define TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
+#ifndef TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
+#define TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
#include "tensorflow/core/kernels/conditional_accumulator_base.h"
@@ -35,8 +35,9 @@ class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
public:
TypedConditionalAccumulatorBase(const DataType& dtype,
const PartialTensorShape& shape,
- const string& name)
- : ConditionalAccumulatorBase(dtype, shape, name) {}
+ const string& name,
+ const string& reduction_type)
+ : ConditionalAccumulatorBase(dtype, shape, name, reduction_type) {}
/**
* Attempts to add a gradient to the accumulator. An ApplyGrad attempt is
@@ -91,4 +92,4 @@ class TypedConditionalAccumulatorBase : public ConditionalAccumulatorBase {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
+#endif // TENSORFLOW_CORE_KERNELS_TYPED_CONDITIONAL_ACCUMULATOR_BASE_H_
diff --git a/tensorflow/core/kernels/unicode_script_op.cc b/tensorflow/core/kernels/unicode_script_op.cc
new file mode 100644
index 0000000000..085e397eba
--- /dev/null
+++ b/tensorflow/core/kernels/unicode_script_op.cc
@@ -0,0 +1,53 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "unicode/errorcode.h" // TF:icu
+#include "unicode/uscript.h" // TF:icu
+#include "tensorflow/core/framework/op_kernel.h"
+
+namespace tensorflow {
+
+class UnicodeScriptOp : public OpKernel {
+ public:
+ explicit UnicodeScriptOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ const Tensor* input_tensor;
+ OP_REQUIRES_OK(context, context->input("input", &input_tensor));
+ const auto& input_flat = input_tensor->flat<int32>();
+
+ Tensor* output_tensor = nullptr;
+ OP_REQUIRES_OK(context,
+ context->allocate_output("output", input_tensor->shape(),
+ &output_tensor));
+ auto output_flat = output_tensor->flat<int32>();
+
+ icu::ErrorCode status;
+ for (int i = 0; i < input_flat.size(); i++) {
+ UScriptCode script_code = uscript_getScript(input_flat(i), status);
+ if (status.isSuccess()) {
+ output_flat(i) = script_code;
+ } else {
+ output_flat(i) = -1;
+ status.reset();
+ }
+ }
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("UnicodeScript").Device(DEVICE_CPU),
+ UnicodeScriptOp);
+
+} // namespace tensorflow
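
The new kernel maps each int32 input code point to the integer value of ICU's UScriptCode enum, writing -1 whenever ICU reports an error. A standalone sketch of the per-element call it wraps (same ICU API as above):

#include "unicode/errorcode.h"  // icu::ErrorCode
#include "unicode/uscript.h"    // uscript_getScript, UScriptCode

// Returns the script code for one code point, or -1 on error, matching the
// per-element behavior of UnicodeScriptOp::Compute.
int ScriptOf(UChar32 codepoint) {
  icu::ErrorCode status;
  const UScriptCode script = uscript_getScript(codepoint, status);
  if (!status.isSuccess()) return -1;
  return static_cast<int>(script);  // e.g. U'A' yields USCRIPT_LATIN
}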
diff --git a/tensorflow/core/kernels/unravel_index_op.cc b/tensorflow/core/kernels/unravel_index_op.cc
index 62e814ff77..8d839ba85a 100644
--- a/tensorflow/core/kernels/unravel_index_op.cc
+++ b/tensorflow/core/kernels/unravel_index_op.cc
@@ -97,10 +97,12 @@ class UnravelIndexOp : public OpKernel {
auto output = output_tensor->matrix<Tidx>();
- Eigen::array<int64, 2> reshape{{dims_tensor.NumElements(), 1}};
- Eigen::array<int64, 2> bcast({1, indices_tensor.NumElements()});
- Eigen::array<int64, 2> indices_reshape{{1, indices_tensor.NumElements()}};
- Eigen::array<int64, 2> indices_bcast({dims_tensor.NumElements(), 1});
+ Eigen::array<Eigen::Index, 2> reshape{{dims_tensor.NumElements(), 1}};
+ Eigen::array<Eigen::Index, 2> bcast({1, indices_tensor.NumElements()});
+ Eigen::array<Eigen::Index, 2> indices_reshape{
+ {1, indices_tensor.NumElements()}};
+ Eigen::array<Eigen::Index, 2> indices_bcast(
+ {dims_tensor.NumElements(), 1});
output = indices_tensor.vec<Tidx>()
.reshape(indices_reshape)
diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h
index f27dab4ddd..4742e429ed 100644
--- a/tensorflow/core/kernels/variable_ops.h
+++ b/tensorflow/core/kernels/variable_ops.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_VARIABLE_OPS_H_
-#define TENSORFLOW_KERNELS_VARIABLE_OPS_H_
+#ifndef TENSORFLOW_CORE_KERNELS_VARIABLE_OPS_H_
+#define TENSORFLOW_CORE_KERNELS_VARIABLE_OPS_H_
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -46,4 +46,4 @@ class VariableOp : public OpKernel {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_VARIABLE_OPS_H_
+#endif // TENSORFLOW_CORE_KERNELS_VARIABLE_OPS_H_
diff --git a/tensorflow/core/kernels/where_op.h b/tensorflow/core/kernels/where_op.h
index d26849c8bd..e63b3ba8cd 100644
--- a/tensorflow/core/kernels/where_op.h
+++ b/tensorflow/core/kernels/where_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_WHERE_OP_H_
-#define TENSORFLOW_KERNELS_WHERE_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_WHERE_OP_H_
+#define TENSORFLOW_CORE_KERNELS_WHERE_OP_H_
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
@@ -63,4 +63,4 @@ struct Where {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_WHERE_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_WHERE_OP_H_
diff --git a/tensorflow/core/kernels/where_op_gpu.cu.h b/tensorflow/core/kernels/where_op_gpu.cu.h
index 57f51889de..2255597651 100644
--- a/tensorflow/core/kernels/where_op_gpu.cu.h
+++ b/tensorflow/core/kernels/where_op_gpu.cu.h
@@ -13,15 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#ifndef TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_
+#define TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_
+
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
-#include "external/cub_archive/cub/device/device_reduce.cuh"
-#include "external/cub_archive/cub/device/device_select.cuh"
-#include "external/cub_archive/cub/iterator/counting_input_iterator.cuh"
-#include "external/cub_archive/cub/iterator/transform_input_iterator.cuh"
+#include "third_party/cub/device/device_reduce.cuh"
+#include "third_party/cub/device/device_select.cuh"
+#include "third_party/cub/iterator/counting_input_iterator.cuh"
+#include "third_party/cub/iterator/transform_input_iterator.cuh"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/kernels/bounds_check.h"
@@ -346,3 +349,5 @@ TF_CALL_WHERE_GPU_TYPES(DECLARE_GPU_SPEC);
} // namespace tensorflow
#endif // GOOGLE_CUDA
+
+#endif // TENSORFLOW_CORE_KERNELS_WHERE_OP_GPU_CU_H_
diff --git a/tensorflow/core/kernels/whole_file_read_ops.cc b/tensorflow/core/kernels/whole_file_read_ops.cc
index ed2bf3e8e2..1bf46b5e46 100644
--- a/tensorflow/core/kernels/whole_file_read_ops.cc
+++ b/tensorflow/core/kernels/whole_file_read_ops.cc
@@ -134,7 +134,7 @@ class WriteFileOp : public OpKernel {
"Contents tensor must be scalar, but had shape: ",
contents_input->shape().DebugString()));
const string& filename = filename_input->scalar<string>()();
- const string dir = std::string(io::Dirname(filename));
+ const string dir(io::Dirname(filename));
if (!context->env()->FileExists(dir).ok()) {
OP_REQUIRES_OK(context, context->env()->RecursivelyCreateDir(dir));
}
diff --git a/tensorflow/core/kernels/xent_op.h b/tensorflow/core/kernels/xent_op.h
index 87be17fca9..23d3ad39a8 100644
--- a/tensorflow/core/kernels/xent_op.h
+++ b/tensorflow/core/kernels/xent_op.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_XENT_OP_H_
-#define TENSORFLOW_KERNELS_XENT_OP_H_
+#ifndef TENSORFLOW_CORE_KERNELS_XENT_OP_H_
+#define TENSORFLOW_CORE_KERNELS_XENT_OP_H_
// Functor definition for XentOp, must be compilable by nvcc.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
@@ -125,4 +125,4 @@ struct XentEigenImpl {
} // namespace functor
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_XENT_OP_H_
+#endif // TENSORFLOW_CORE_KERNELS_XENT_OP_H_
diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h
index d6f3f26cd5..5c917e80c1 100644
--- a/tensorflow/core/lib/bfloat16/bfloat16.h
+++ b/tensorflow/core/lib/bfloat16/bfloat16.h
@@ -61,9 +61,7 @@ struct bfloat16 {
}
B16_DEVICE_FUNC explicit bfloat16(const float v) {
- // TODO(asabne) : change the below line to
- // value = round_to_bfloat16(v).value;
- value = truncate_to_bfloat16(v).value;
+ value = round_to_bfloat16(v).value;
}
B16_DEVICE_FUNC explicit bfloat16(const double val)
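
Note: the float constructor above now rounds instead of truncating the low 16 mantissa bits. For reference, a standalone sketch of the usual round-to-nearest-even bit trick (assumed semantics; the real round_to_bfloat16 also special-cases NaN, which is omitted here):

#include <cstdint>
#include <cstring>

// Returns the 16 high bits of an IEEE-754 float, rounded to nearest even
// rather than truncated.
inline std::uint16_t RoundToBfloat16Bits(float v) {
  std::uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  const std::uint32_t lsb = (bits >> 16) & 1u;  // last bit that will be kept
  bits += 0x7FFFu + lsb;                        // ties go to the even value
  return static_cast<std::uint16_t>(bits >> 16);
}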
diff --git a/tensorflow/core/lib/core/arena.h b/tensorflow/core/lib/core/arena.h
index 5698303247..624ee77027 100644
--- a/tensorflow/core/lib/core/arena.h
+++ b/tensorflow/core/lib/core/arena.h
@@ -15,8 +15,8 @@ limitations under the License.
// TODO(vrv): Switch this to an open-sourced version of Arena.
-#ifndef TENSORFLOW_LIB_CORE_ARENA_H_
-#define TENSORFLOW_LIB_CORE_ARENA_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_ARENA_H_
+#define TENSORFLOW_CORE_LIB_CORE_ARENA_H_
#include <assert.h>
@@ -107,4 +107,4 @@ class Arena {
} // namespace core
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_ARENA_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_ARENA_H_
diff --git a/tensorflow/core/lib/core/bits.h b/tensorflow/core/lib/core/bits.h
index 1110ef5c2a..86e539a266 100644
--- a/tensorflow/core/lib/core/bits.h
+++ b/tensorflow/core/lib/core/bits.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_CORE_BITS_H_
-#define TENSORFLOW_LIB_CORE_BITS_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_BITS_H_
+#define TENSORFLOW_CORE_LIB_CORE_BITS_H_
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -106,4 +106,4 @@ inline uint64 NextPowerOfTwo64(uint64 value) {
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_BITS_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_BITS_H_
diff --git a/tensorflow/core/lib/core/casts.h b/tensorflow/core/lib/core/casts.h
index 0f925c6051..7546d4edc5 100644
--- a/tensorflow/core/lib/core/casts.h
+++ b/tensorflow/core/lib/core/casts.h
@@ -20,8 +20,8 @@ limitations under the License.
// any changes here, make sure that you're not breaking any platforms.
//
-#ifndef TENSORFLOW_LIB_CORE_CASTS_H_
-#define TENSORFLOW_LIB_CORE_CASTS_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_CASTS_H_
+#define TENSORFLOW_CORE_LIB_CORE_CASTS_H_
#include <string.h> // for memcpy
@@ -97,4 +97,4 @@ inline Dest bit_cast(const Source& source) {
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_CASTS_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_CASTS_H_
diff --git a/tensorflow/core/lib/core/coding.h b/tensorflow/core/lib/core/coding.h
index 8265aec870..4a70ffa619 100644
--- a/tensorflow/core/lib/core/coding.h
+++ b/tensorflow/core/lib/core/coding.h
@@ -18,8 +18,8 @@ limitations under the License.
// * In addition we support variable length "varint" encoding
// * Strings are encoded prefixed by their length in varint format
-#ifndef TENSORFLOW_LIB_CORE_CODING_H_
-#define TENSORFLOW_LIB_CORE_CODING_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_CODING_H_
+#define TENSORFLOW_CORE_LIB_CORE_CODING_H_
#include "tensorflow/core/lib/core/raw_coding.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -76,4 +76,4 @@ extern int VarintLength(uint64_t v);
} // namespace core
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_CODING_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_CODING_H_
diff --git a/tensorflow/core/lib/core/errors.h b/tensorflow/core/lib/core/errors.h
index a631d9815a..d5cbe6c616 100644
--- a/tensorflow/core/lib/core/errors.h
+++ b/tensorflow/core/lib/core/errors.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_CORE_ERRORS_H_
-#define TENSORFLOW_LIB_CORE_ERRORS_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
+#define TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
#include <sstream>
@@ -131,11 +131,23 @@ inline string FormatNodeNameForError(const string& name) {
// LINT.ThenChange(//tensorflow/python/client/session.py)
template <typename T>
string FormatNodeNamesForError(const T& names) {
- ::tensorflow::str_util::Formatter<string> f(
- [](string* output, const string& s) {
+ return ::tensorflow::str_util::Join(
+ names, ", ", [](string* output, const string& s) {
::tensorflow::strings::StrAppend(output, FormatNodeNameForError(s));
});
- return ::tensorflow::str_util::Join(names, ", ", f);
+}
+// LINT.IfChange
+inline string FormatColocationNodeForError(const string& name) {
+ return strings::StrCat("{{colocation_node ", name, "}}");
+}
+// LINT.ThenChange(//tensorflow/python/framework/error_interpolation.py)
+template <typename T>
+string FormatColocationNodeForError(const T& names) {
+ return ::tensorflow::str_util::Join(
+ names, ", ", [](string* output, const string& s) {
+ ::tensorflow::strings::StrAppend(output,
+ FormatColocationNodeForError(s));
+ });
}
// The CanonicalCode() for non-errors.
@@ -144,4 +156,4 @@ using ::tensorflow::error::OK;
} // namespace errors
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_ERRORS_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_ERRORS_H_
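
The new FormatColocationNodeForError overloads mirror the node-name formatters directly above: the scalar form wraps one name in the {{colocation_node ...}} interpolation marker and the templated form joins the wrapped names with ", ". For example (hypothetical names):

// FormatColocationNodeForError("v1")
//     == "{{colocation_node v1}}"
// FormatColocationNodeForError(std::vector<string>{"v1", "v2"})
//     == "{{colocation_node v1}}, {{colocation_node v2}}"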
diff --git a/tensorflow/core/lib/core/notification.h b/tensorflow/core/lib/core/notification.h
index b3e515e28f..5def958e6b 100644
--- a/tensorflow/core/lib/core/notification.h
+++ b/tensorflow/core/lib/core/notification.h
@@ -13,11 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_NOTIFICATION_H_
-#define TENSORFLOW_UTIL_NOTIFICATION_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_NOTIFICATION_H_
+#define TENSORFLOW_CORE_LIB_CORE_NOTIFICATION_H_
// Notification implementation is platform-dependent, to support
// alternative synchronization primitives.
#include "tensorflow/core/platform/notification.h"
-#endif // TENSORFLOW_UTIL_NOTIFICATION_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_NOTIFICATION_H_
diff --git a/tensorflow/core/lib/core/raw_coding.h b/tensorflow/core/lib/core/raw_coding.h
index 37201b755d..f49214939b 100644
--- a/tensorflow/core/lib/core/raw_coding.h
+++ b/tensorflow/core/lib/core/raw_coding.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_CORE_RAW_CODING_H_
-#define TENSORFLOW_LIB_CORE_RAW_CODING_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_RAW_CODING_H_
+#define TENSORFLOW_CORE_LIB_CORE_RAW_CODING_H_
#include <string.h>
#include "tensorflow/core/platform/byte_order.h"
@@ -68,4 +68,4 @@ inline uint64 DecodeFixed64(const char* ptr) {
} // namespace core
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_RAW_CODING_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_RAW_CODING_H_
diff --git a/tensorflow/core/lib/core/status.cc b/tensorflow/core/lib/core/status.cc
index 12dfcd284f..cb2a06e620 100644
--- a/tensorflow/core/lib/core/status.cc
+++ b/tensorflow/core/lib/core/status.cc
@@ -22,7 +22,7 @@ Status::Status(tensorflow::error::Code code, StringPiece msg) {
assert(code != tensorflow::error::OK);
state_ = std::unique_ptr<State>(new State);
state_->code = code;
- state_->msg = msg.ToString();
+ state_->msg = string(msg);
}
void Status::Update(const Status& new_status) {
diff --git a/tensorflow/core/lib/core/status.h b/tensorflow/core/lib/core/status.h
index 49f74ff47f..eb0ff555a5 100644
--- a/tensorflow/core/lib/core/status.h
+++ b/tensorflow/core/lib/core/status.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
diff --git a/tensorflow/core/lib/core/status_test_util.h b/tensorflow/core/lib/core/status_test_util.h
index b35633c9da..c695caa8d1 100644
--- a/tensorflow/core/lib/core/status_test_util.h
+++ b/tensorflow/core/lib/core/status_test_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
-#define TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_
+#define TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/test.h"
@@ -31,4 +31,4 @@ limitations under the License.
// If you want to check for particular errors, a better alternative is:
// EXPECT_EQ(..expected tensorflow::error::Code..., status.code());
-#endif // TENSORFLOW_LIB_CORE_STATUS_TEST_UTIL_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_STATUS_TEST_UTIL_H_
diff --git a/tensorflow/core/lib/core/stringpiece.cc b/tensorflow/core/lib/core/stringpiece.cc
deleted file mode 100644
index 4c488066e4..0000000000
--- a/tensorflow/core/lib/core/stringpiece.cc
+++ /dev/null
@@ -1,54 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/core/stringpiece.h"
-
-#include <algorithm>
-#include <iostream>
-
-namespace tensorflow {
-
-std::ostream& operator<<(std::ostream& o, StringPiece piece) {
- o.write(piece.data(), piece.size());
- return o;
-}
-
-size_t StringPiece::find(char c, size_t pos) const {
- if (pos >= size_) {
- return npos;
- }
- const char* result =
- reinterpret_cast<const char*>(memchr(data_ + pos, c, size_ - pos));
- return result != nullptr ? result - data_ : npos;
-}
-
-// Search range is [0..pos] inclusive. If pos == npos, search everything.
-size_t StringPiece::rfind(char c, size_t pos) const {
- if (size_ == 0) return npos;
- for (const char* p = data_ + std::min(pos, size_ - 1); p >= data_; p--) {
- if (*p == c) {
- return p - data_;
- }
- }
- return npos;
-}
-
-StringPiece StringPiece::substr(size_t pos, size_t n) const {
- if (pos > size_) pos = size_;
- if (n > size_ - pos) n = size_ - pos;
- return StringPiece(data_ + pos, n);
-}
-
-} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/stringpiece.h b/tensorflow/core/lib/core/stringpiece.h
index d7ecc44e50..6edff139ae 100644
--- a/tensorflow/core/lib/core/stringpiece.h
+++ b/tensorflow/core/lib/core/stringpiece.h
@@ -23,129 +23,16 @@ limitations under the License.
// non-const method, all threads accessing the same StringPiece must use
// external synchronization.
-#ifndef TENSORFLOW_LIB_CORE_STRINGPIECE_H_
-#define TENSORFLOW_LIB_CORE_STRINGPIECE_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
+#define TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
-#include <assert.h>
-#include <stddef.h>
-#include <string.h>
-#include <iosfwd>
-#include <string>
-#include "tensorflow/core/platform/types.h"
+#include "absl/strings/string_view.h"
namespace tensorflow {
-class StringPiece {
- public:
- typedef size_t size_type;
-
- // Create an empty slice.
- StringPiece() : data_(nullptr), size_(0) {}
-
- // Create a slice that refers to d[0,n-1].
- StringPiece(const char* d, size_t n) : data_(d), size_(n) {}
-
- // Create a slice that refers to the contents of "s"
- StringPiece(const string& s) : data_(s.data()), size_(s.size()) {}
-
- // Create a slice that refers to s[0,strlen(s)-1]
- StringPiece(const char* s) : data_(s), size_(strlen(s)) {}
-
- // Return a pointer to the beginning of the referenced data
- const char* data() const { return data_; }
-
- // Return the length (in bytes) of the referenced data
- size_t size() const { return size_; }
-
- // Return true iff the length of the referenced data is zero
- bool empty() const { return size_ == 0; }
-
- typedef const char* const_iterator;
- typedef const char* iterator;
- iterator begin() const { return data_; }
- iterator end() const { return data_ + size_; }
-
- static const size_t npos = size_type(-1);
-
- // Return the ith byte in the referenced data.
- // REQUIRES: n < size()
- char operator[](size_t n) const {
- assert(n < size());
- return data_[n];
- }
-
- // Drop the first "n" bytes from this slice.
- void remove_prefix(size_t n) {
- assert(n <= size());
- data_ += n;
- size_ -= n;
- }
-
- void remove_suffix(size_t n) {
- assert(size_ >= n);
- size_ -= n;
- }
-
- size_t find(char c, size_t pos = 0) const;
- size_t rfind(char c, size_t pos = npos) const;
-
- StringPiece substr(size_t pos, size_t n = npos) const;
-
- // Return a string that contains the copy of the referenced data.
- // DEPRECATED: use std::string(sv) instead.
- std::string ToString() const { return std::string(data_, size_); }
-
- // Three-way comparison. Returns value:
- // < 0 iff "*this" < "b",
- // == 0 iff "*this" == "b",
- // > 0 iff "*this" > "b"
- int compare(StringPiece b) const;
-
- // Converts to `std::basic_string`.
- template <typename A>
- explicit operator std::basic_string<char, std::char_traits<char>, A>() const {
- if (!data()) return {};
- return std::basic_string<char, std::char_traits<char>, A>(data(), size());
- }
-
- private:
- const char* data_;
- size_t size_;
-
- // Intentionally copyable
-};
-
-inline bool operator==(StringPiece x, StringPiece y) {
- return ((x.size() == y.size()) &&
- (memcmp(x.data(), y.data(), x.size()) == 0));
-}
-
-inline bool operator!=(StringPiece x, StringPiece y) { return !(x == y); }
-
-inline bool operator<(StringPiece x, StringPiece y) { return x.compare(y) < 0; }
-inline bool operator>(StringPiece x, StringPiece y) { return x.compare(y) > 0; }
-inline bool operator<=(StringPiece x, StringPiece y) {
- return x.compare(y) <= 0;
-}
-inline bool operator>=(StringPiece x, StringPiece y) {
- return x.compare(y) >= 0;
-}
-
-inline int StringPiece::compare(StringPiece b) const {
- const size_t min_len = (size_ < b.size_) ? size_ : b.size_;
- int r = memcmp(data_, b.data_, min_len);
- if (r == 0) {
- if (size_ < b.size_)
- r = -1;
- else if (size_ > b.size_)
- r = +1;
- }
- return r;
-}
-
-// allow StringPiece to be logged
-extern std::ostream& operator<<(std::ostream& o, tensorflow::StringPiece piece);
+// Deprecated: please use absl::string_view directly.
+using StringPiece = absl::string_view;
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_STRINGPIECE_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_STRINGPIECE_H_
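
With StringPiece reduced to an alias for absl::string_view, the removed members map onto the absl API, and TF-only extras such as ToString() are gone, which is why call sites elsewhere in this change switch to explicit string construction. A short illustration (hypothetical values):

#include <string>
#include "absl/strings/string_view.h"

void Demo() {
  absl::string_view piece = "_tensor_arrays/foo";  // tensorflow::StringPiece is now this type
  std::string copy(piece);                 // replaces the removed piece.ToString()
  absl::string_view name = piece.substr(piece.find('/') + 1);  // same spelling as before
  (void)copy;
  (void)name;
}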
diff --git a/tensorflow/core/lib/core/stringpiece_test.cc b/tensorflow/core/lib/core/stringpiece_test.cc
index 952b9eaaaa..e4b489fe17 100644
--- a/tensorflow/core/lib/core/stringpiece_test.cc
+++ b/tensorflow/core/lib/core/stringpiece_test.cc
@@ -56,8 +56,8 @@ TEST(StringPiece, Ctor) {
}
TEST(StringPiece, ConversionToString) {
- EXPECT_EQ("", std::string(StringPiece("")));
- EXPECT_EQ("foo", std::string(StringPiece("foo")));
+ EXPECT_EQ("", string(StringPiece("")));
+ EXPECT_EQ("foo", string(StringPiece("foo")));
}
} // namespace tensorflow
diff --git a/tensorflow/core/lib/core/threadpool.cc b/tensorflow/core/lib/core/threadpool.cc
index 99684ae47b..9ccd911b0e 100644
--- a/tensorflow/core/lib/core/threadpool.cc
+++ b/tensorflow/core/lib/core/threadpool.cc
@@ -17,6 +17,7 @@ limitations under the License.
#define EIGEN_USE_THREADS
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/platform/context.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
@@ -120,6 +121,54 @@ void ThreadPool::Schedule(std::function<void()> fn) {
impl_->Schedule(std::move(fn));
}
+int ThreadPool::NumShardsUsedByTransformRangeConcurrently(
+ const int64 block_size, const int64 total) {
+ if (block_size <= 0 || total <= 1 || total <= block_size ||
+ NumThreads() == 1) {
+ return 1;
+ }
+ return (total + block_size - 1) / block_size;
+}
+
+// This functionality is similar to parallelFor, except that reasoning about
+// the number of shards used is significantly easier.
+void ThreadPool::TransformRangeConcurrently(
+ const int64 block_size, const int64 total,
+ const std::function<void(int64, int64)>& fn) {
+ const int num_shards_used =
+ NumShardsUsedByTransformRangeConcurrently(block_size, total);
+ if (num_shards_used == 1) {
+ fn(0, total);
+ return;
+ }
+
+ // Adapted from Eigen's parallelFor implementation.
+ BlockingCounter counter(num_shards_used);
+ std::function<void(int64, int64)> handle_range =
+ [=, &handle_range, &counter, &fn](int64 first, int64 last) {
+ while (last - first > block_size) {
+ // Find something near the midpoint which is a multiple of block size.
+ const int64 mid = first + ((last - first) / 2 + block_size - 1) /
+ block_size * block_size;
+ Schedule([=, &handle_range]() { handle_range(mid, last); });
+ last = mid;
+ }
+ // Single block or less, execute directly.
+ fn(first, last);
+ counter.DecrementCount(); // The shard is done.
+ };
+ if (num_shards_used <= NumThreads()) {
+ // Avoid a thread hop by running the root of the tree and one block on the
+ // main thread.
+ handle_range(0, total);
+ } else {
+ // Execute the root in the thread pool to avoid running work on more than
+ // numThreads() threads.
+ Schedule([=, &handle_range]() { handle_range(0, total); });
+ }
+ counter.Wait();
+}
+
void ThreadPool::ParallelFor(int64 total, int64 cost_per_unit,
std::function<void(int64, int64)> fn) {
impl_->ParallelFor(total, cost_per_unit, std::move(fn));
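
NumShardsUsedByTransformRangeConcurrently above is a guarded ceiling division: it returns 1 whenever sharding cannot help (non-positive block size, total <= 1, total <= block_size, or a single-thread pool) and ceil(total / block_size) otherwise, which is exactly what the unit test below asserts. For example:

// block_size = 3: total = 3 -> 1 shard, total = 4 -> 2, total = 7 -> 3
// block_size = 1, total = 7 -> 7 shards
// block_size = 0, total = 7 -> 1 shard (invalid block size falls back to 1)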
diff --git a/tensorflow/core/lib/core/threadpool.h b/tensorflow/core/lib/core/threadpool.h
index b89b74b8de..e14ad7ac64 100644
--- a/tensorflow/core/lib/core/threadpool.h
+++ b/tensorflow/core/lib/core/threadpool.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_CORE_THREADPOOL_H_
-#define TENSORFLOW_LIB_CORE_THREADPOOL_H_
+#ifndef TENSORFLOW_CORE_LIB_CORE_THREADPOOL_H_
+#define TENSORFLOW_CORE_LIB_CORE_THREADPOOL_H_
#include <functional>
#include <memory>
@@ -59,6 +59,20 @@ class ThreadPool {
// Schedules fn() for execution in the pool of threads.
void Schedule(std::function<void()> fn);
+ // Requires 0 < block_size <= total.
+ // Spawns k threads and calls fn(i*block_size, (i+1)*block_size) from the
+ // ith thread (i>=0). When (i+1)*block_size > total, fn(i*block_size, total)
+ // is called instead. k = NumShardsUsedByTransformRangeConcurrently(...).
+ // Note that when there aren't enough threads in the pool to achieve full
+ // parallelism, function calls will be automatically queued.
+ void TransformRangeConcurrently(const int64 block_size, const int64 total,
+ const std::function<void(int64, int64)>& fn);
+
+ // Returns the number of threads spawned by calling TransformRangeConcurrently
+ // with these parameters.
+ int NumShardsUsedByTransformRangeConcurrently(const int64 block_size,
+ const int64 total);
+
// ParallelFor shards the "total" units of work assuming each unit of work
// having roughly "cost_per_unit" cost, in cycles. Each unit of work is
// indexed 0, 1, ..., total - 1. Each shard contains 1 or more units of work
@@ -108,4 +122,4 @@ class ThreadPool {
} // namespace thread
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_CORE_THREADPOOL_H_
+#endif // TENSORFLOW_CORE_LIB_CORE_THREADPOOL_H_
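
A hedged usage sketch of the new API, assuming a pool like the one constructed in the test below; TransformRangeConcurrently blocks until every block has run:

#include <atomic>
#include "tensorflow/core/lib/core/threadpool.h"

void SumWithPool(tensorflow::thread::ThreadPool* pool) {
  const tensorflow::int64 total = 1000, block_size = 64;
  std::atomic<tensorflow::int64> sum{0};
  pool->TransformRangeConcurrently(
      block_size, total,
      [&sum](tensorflow::int64 first, tensorflow::int64 last) {
        tensorflow::int64 local = 0;
        for (tensorflow::int64 i = first; i < last; ++i) local += i;
        sum += local;  // each block [first, last) is handled exactly once
      });
  // Here sum == total * (total - 1) / 2.
}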
diff --git a/tensorflow/core/lib/core/threadpool_test.cc b/tensorflow/core/lib/core/threadpool_test.cc
index 320f3ebb83..db996b783f 100644
--- a/tensorflow/core/lib/core/threadpool_test.cc
+++ b/tensorflow/core/lib/core/threadpool_test.cc
@@ -61,6 +61,67 @@ TEST(ThreadPool, DoWork) {
}
}
+void RunSharding(int64 block_size, int64 total, ThreadPool* threads) {
+ mutex mu;
+ int64 num_shards = 0;
+ int64 num_done_work = 0;
+ std::vector<bool> work(total, false);
+ threads->TransformRangeConcurrently(
+ block_size, total,
+ [=, &mu, &num_shards, &num_done_work, &work](int64 start, int64 end) {
+ VLOG(1) << "Shard [" << start << "," << end << ")";
+ EXPECT_GE(start, 0);
+ EXPECT_LE(end, total);
+ mutex_lock l(mu);
+ ++num_shards;
+ for (; start < end; ++start) {
+ EXPECT_FALSE(work[start]); // No duplicate
+ ++num_done_work;
+ work[start] = true;
+ }
+ });
+ LOG(INFO) << block_size << " " << total;
+ const int64 num_workers = (total + block_size - 1) / block_size;
+ EXPECT_EQ(num_done_work, total);
+ if (num_workers < threads->NumThreads()) {
+ // If the intention is to limit the parallelism explicitly, we'd
+ // better honor it. Ideally, even if per_thread_max_parallelism >
+ // num_workers, we should expect that the Shard() implementation does
+ // not over-shard. Unfortunately, ThreadPoolDevice::parallelFor
+ // tends to over-shard.
+ EXPECT_LE(num_shards, 1 + num_workers);
+ }
+}
+
+// Adapted from work_sharder_test.cc
+TEST(SparseUtilsTest, TransformRangeConcurrently) {
+ ThreadPool threads(Env::Default(), "test", 16);
+ for (auto block_size : {1, 7, 10, 64, 100, 256, 1000, 9999}) {
+ for (auto diff : {0, 1, 11, 102, 1003, 10005, 1000007}) {
+ const int64 total = block_size + diff;
+ RunSharding(block_size, total, &threads);
+ }
+ }
+}
+
+TEST(SparseUtilsTest, NumShardsUsedByTransformRangeConcurrently) {
+ ThreadPool threads(Env::Default(), "test", 16);
+ EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 3 /* total */));
+ EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 4 /* total */));
+ EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 5 /* total */));
+ EXPECT_EQ(2, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 6 /* total */));
+ EXPECT_EQ(3, threads.NumShardsUsedByTransformRangeConcurrently(
+ 3 /* block_size */, 7 /* total */));
+ EXPECT_EQ(7, threads.NumShardsUsedByTransformRangeConcurrently(
+ 1 /* block_size */, 7 /* total */));
+ EXPECT_EQ(1, threads.NumShardsUsedByTransformRangeConcurrently(
+ 0 /* block_size */, 7 /* total */));
+}
+
TEST(ThreadPool, ParallelFor) {
Context outer_context(ContextKind::kThread);
// Make ParallelFor use as many threads as possible.
diff --git a/tensorflow/core/lib/gtl/array_slice.h b/tensorflow/core/lib/gtl/array_slice.h
index 002d166c72..8f47faf89e 100644
--- a/tensorflow/core/lib/gtl/array_slice.h
+++ b/tensorflow/core/lib/gtl/array_slice.h
@@ -13,302 +13,24 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// An ArraySlice<T> represents an immutable array of elements of type
-// T. It has a length "length", and a base pointer "ptr", and the
-// array it represents contains the elements "ptr[0] .. ptr[len-1]".
-// The backing store for the array is *not* owned by the ArraySlice
-// object, and clients must arrange for the backing store to remain
-// live while the ArraySlice object is in use.
-//
-// An ArraySlice<T> is somewhat analogous to a StringPiece, but for
-// array elements of type T.
-//
-// Implicit conversion operations are provided from types such as
-// std::vector<T> and util::gtl::InlinedVector<T, N>. Note that ArraySlice
-// objects constructed from types in this way may be invalidated by
-// any operations that mutate the underlying vector.
-//
-// One common use for ArraySlice is when passing arguments to a
-// routine where you want to be able to accept a variety of array
-// types (e.g. a vector, a util::gtl::InlinedVector, a C-style array,
-// etc.). The usual approach here is to have the client explicitly
-// pass in a pointer and a length, as in:
-//
-// void MyRoutine(const int* elems, int N) {
-// for (int i = 0; i < N; i++) { .. do something with elems[i] .. }
-// }
-//
-// Unfortunately, this leads to ugly and error-prone code at the call site:
-//
-// std::vector<int> my_vector;
-// MyRoutine(vector_as_array(&my_vector), my_vector.size());
-//
-// util::gtl::InlinedVector<int, 4> my_inline_vector;
-// MyRoutine(my_inline_vector.array(), my_inline_vector.size());
-//
-// int my_array[10];
-// MyRoutine(my_array, 10);
-//
-// Instead, you can use an ArraySlice as the argument to the routine:
-//
-// void MyRoutine(ArraySlice<int> a) {
-// for (int i = 0; i < a.size(); i++) { .. do something with a[i] .. }
-// }
-//
-// This makes the call sites cleaner, for the most part:
-//
-// std::vector<int> my_vector;
-// MyRoutine(my_vector);
-//
-// util::gtl::InlinedVector<int, 4> my_inline_vector;
-// MyRoutine(my_inline_vector);
-//
-// int my_array[10];
-// MyRoutine(my_array);
-//
-// int* my_array = new int[10];
-// MyRoutine(gtl::ArraySlice<int>(my_array, 10));
-//
-// MutableArraySlice<T> represents a mutable array of elements, and, like
-// ArraySlice, does not own the backing store. The implicit constructors it
-// provides allow functions not to worry about whether their mutable arguments
-// refer to vectors, arrays, proto2::RepeatedFields, etc.:
-//
-// void MyMutatingRoutine(MutableArraySlice<int> a) {
-// for (int i = 0; i < a.size(); i++) { .. mutate a[i] .. }
-// }
-//
-// std::vector<int> my_vector;
-// MyMutatingRoutine(&my_vector);
-//
-// int my_array[10];
-// MyMutatingRoutine(my_array);
-//
-// int* my_array = new int[10];
-// MyMutatingRoutine(gtl::MutableArraySlice<int>(my_array, 10));
-//
-// MyProto my_proto;
-// for (int i = 0; i < 10; ++i) { my_proto.add_value(i); }
-// MyMutatingRoutine(my_proto.mutable_value());
+#ifndef TENSORFLOW_CORE_LIB_GTL_ARRAY_SLICE_H_
+#define TENSORFLOW_CORE_LIB_GTL_ARRAY_SLICE_H_
-#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_
-#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_
-
-#include <initializer_list>
-#include <type_traits>
-#include <vector>
-
-#include "tensorflow/core/lib/gtl/array_slice_internal.h"
+#include "absl/types/span.h"
+// TODO(timshen): This is kept only because lots of targets transitively depend
+// on it. Remove all targets' dependencies.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
namespace tensorflow {
namespace gtl {
template <typename T>
-class ArraySlice {
- private:
- typedef array_slice_internal::ArraySliceImpl<T> Impl;
-
- public:
- typedef T value_type;
- typedef typename Impl::pointer pointer;
- typedef typename Impl::const_pointer const_pointer;
- typedef typename Impl::reference reference;
- typedef typename Impl::const_reference const_reference;
- typedef typename Impl::iterator iterator;
- typedef typename Impl::const_iterator const_iterator;
- typedef typename Impl::reverse_iterator reverse_iterator;
- typedef typename Impl::const_reverse_iterator const_reverse_iterator;
- typedef typename Impl::size_type size_type;
- typedef typename Impl::difference_type difference_type;
-
- static const size_type npos = Impl::npos;
-
- ArraySlice() : impl_(nullptr, 0) {}
- ArraySlice(const_pointer array, size_type length) : impl_(array, length) {}
-
- // Implicit conversion constructors
- ArraySlice(const std::vector<value_type>& v) // NOLINT(runtime/explicit)
- : impl_(v.data(), v.size()) {}
-
- template <size_t N>
- ArraySlice(const value_type (&a)[N]) // NOLINT(runtime/explicit)
- : impl_(a, N) {}
-
- template <int N>
- ArraySlice(const InlinedVector<value_type, N>& v) // NOLINT(runtime/explicit)
- : impl_(v.data(), v.size()) {}
-
- // The constructor for any class supplying 'data() const' that returns either
- // const T* or a less const-qualified version of it, and 'some_integral_type
- // size() const'. proto2::RepeatedField<T>, string and (since C++11)
- // std::vector<T,A> and std::array<T, N> are examples of this. See
- // array_slice_internal.h for details.
- template <typename V,
- typename = typename Impl::template EnableIfConvertibleFrom<V>>
- ArraySlice(const V& v) // NOLINT(runtime/explicit)
- : impl_(v) {}
-
- // Implicitly constructs an ArraySlice from an initializer list. This makes it
- // possible to pass a brace-enclosed initializer list to a function expecting
- // an ArraySlice:
- // void Process(ArraySlice<int> x);
- // Process({1, 2, 3});
- // The data referenced by the initializer_list must outlive this
- // ArraySlice. For example, "ArraySlice<int> s={1,2};" and "return
- // ArraySlice<int>({3,4});" are errors, as the resulting ArraySlice may
- // reference data that is no longer valid.
- ArraySlice(std::initializer_list<value_type> v) // NOLINT(runtime/explicit)
- : impl_(v.begin(), v.size()) {}
+using ArraySlice = absl::Span<const T>;
- // Substring of another ArraySlice.
- // pos must be non-negative and <= x.length().
- // len must be non-negative and will be pinned to at most x.length() - pos.
- // If len==npos, the substring continues till the end of x.
- ArraySlice(const ArraySlice& x, size_type pos, size_type len)
- : impl_(x.impl_, pos, len) {}
-
- const_pointer data() const { return impl_.data(); }
- size_type size() const { return impl_.size(); }
- size_type length() const { return size(); }
- bool empty() const { return size() == 0; }
-
- void clear() { impl_.clear(); }
-
- const_reference operator[](size_type i) const { return impl_[i]; }
- const_reference at(size_type i) const { return impl_.at(i); }
- const_reference front() const { return impl_.front(); }
- const_reference back() const { return impl_.back(); }
-
- const_iterator begin() const { return impl_.begin(); }
- const_iterator end() const { return impl_.end(); }
- const_reverse_iterator rbegin() const { return impl_.rbegin(); }
- const_reverse_iterator rend() const { return impl_.rend(); }
-
- void remove_prefix(size_type n) { impl_.remove_prefix(n); }
- void remove_suffix(size_type n) { impl_.remove_suffix(n); }
- void pop_back() { remove_suffix(1); }
- void pop_front() { remove_prefix(1); }
-
- // These relational operators have the same semantics as the
- // std::vector<T> relational operators: they do deep (element-wise)
- // comparisons. Array slices are equal iff their size is the same
- // and all their elements are equal.
- bool operator==(ArraySlice<T> other) const { return impl_ == other.impl_; }
- bool operator!=(ArraySlice<T> other) const { return impl_ != other.impl_; }
-
- private:
- Impl impl_;
-};
-
-// Mutable version of ArraySlice, which allows the clients to mutate the
-// underlying data. It is implicitly convertible to ArraySlice since it provides
-// the data() and size() methods with correct signatures. When a
-// MutableArraySlice is created from a pointer to a container (as opposed to raw
-// memory pointer), the pointer must not be null.
-//
-// A note on const-ness: "mutable" here refers to the mutability of the
-// underlying data, not of the slice itself. It is perfectly reasonable to have
-// a variable of type "const MutableArraySlice<T>"; this means that the bounds
-// of the view on the array cannot be changed, but the underlying data in the
-// array still may be modified. This is akin to a "T* const" pointer, as opposed
-// to a "const T*" pointer (corresponding to a non-const ArraySlice<T>).
-template <typename T>
-class MutableArraySlice {
- private:
- typedef array_slice_internal::MutableArraySliceImpl<T> Impl;
-
- public:
- typedef T value_type;
- typedef typename Impl::pointer pointer;
- typedef typename Impl::const_pointer const_pointer;
- typedef typename Impl::reference reference;
- typedef typename Impl::const_reference const_reference;
- typedef typename Impl::iterator iterator;
- typedef typename Impl::const_iterator const_iterator;
- typedef typename Impl::reverse_iterator reverse_iterator;
- typedef typename Impl::const_reverse_iterator const_reverse_iterator;
- typedef typename Impl::size_type size_type;
- typedef typename Impl::difference_type difference_type;
-
- static const size_type npos = Impl::npos;
-
- MutableArraySlice() : impl_(nullptr, 0) {}
- MutableArraySlice(pointer array, size_type length) : impl_(array, length) {}
-
- // Implicit conversion constructors
- MutableArraySlice(std::vector<value_type>* v) // NOLINT(runtime/explicit)
- : impl_(v->data(), v->size()) {}
-
- template <size_t N>
- MutableArraySlice(value_type (&a)[N]) // NOLINT(runtime/explicit)
- : impl_(a, N) {}
-
- template <int N>
- MutableArraySlice(
- InlinedVector<value_type, N>* v) // NOLINT(runtime/explicit)
- : impl_(v->data(), v->size()) {}
-
- // The constructor for any class supplying 'T* data()' or 'T* mutable_data()'
- // (the former is called if both exist), and 'some_integral_type size()
- // const'. proto2::RepeatedField is an example of this. Also supports string
- // arguments, when T==char. The appropriate ctor is selected using SFINAE. See
- // array_slice_internal.h for details.
- template <typename V,
- typename = typename Impl::template EnableIfConvertibleFrom<V>>
- MutableArraySlice(V* v) // NOLINT(runtime/explicit)
- : impl_(v) {}
-
- // Substring of another MutableArraySlice.
- // pos must be non-negative and <= x.length().
- // len must be non-negative and will be pinned to at most x.length() - pos.
- // If len==npos, the substring continues till the end of x.
- MutableArraySlice(const MutableArraySlice& x, size_type pos, size_type len)
- : impl_(x.impl_, pos, len) {}
-
- // Accessors.
- pointer data() const { return impl_.data(); }
- size_type size() const { return impl_.size(); }
- size_type length() const { return size(); }
- bool empty() const { return size() == 0; }
-
- void clear() { impl_.clear(); }
-
- reference operator[](size_type i) const { return impl_[i]; }
- reference at(size_type i) const { return impl_.at(i); }
- reference front() const { return impl_.front(); }
- reference back() const { return impl_.back(); }
-
- iterator begin() const { return impl_.begin(); }
- iterator end() const { return impl_.end(); }
- reverse_iterator rbegin() const { return impl_.rbegin(); }
- reverse_iterator rend() const { return impl_.rend(); }
-
- void remove_prefix(size_type n) { impl_.remove_prefix(n); }
- void remove_suffix(size_type n) { impl_.remove_suffix(n); }
- void pop_back() { remove_suffix(1); }
- void pop_front() { remove_prefix(1); }
-
- bool operator==(ArraySlice<T> other) const {
- return ArraySlice<T>(*this) == other;
- }
- bool operator!=(ArraySlice<T> other) const {
- return ArraySlice<T>(*this) != other;
- }
-
- // DEPRECATED(jacobsa): Please use data() instead.
- pointer mutable_data() const { return impl_.data(); }
-
- private:
- Impl impl_;
-};
-
-template <typename T>
-const typename ArraySlice<T>::size_type ArraySlice<T>::npos;
template <typename T>
-const typename MutableArraySlice<T>::size_type MutableArraySlice<T>::npos;
+using MutableArraySlice = absl::Span<T>;
} // namespace gtl
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_ARRAY_SLICE_H_
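(Aside, not part of the diff: with gtl::ArraySlice<T> now an alias for absl::Span<const T>, read-only call sites should keep compiling unchanged, since absl::Span provides the same implicit conversions from containers and initializer lists. A minimal sketch under that assumption:)

    #include <vector>

    #include "tensorflow/core/lib/gtl/array_slice.h"

    // Accepts anything contiguous: std::vector, C arrays, initializer lists.
    int Sum(tensorflow::gtl::ArraySlice<int> values) {
      int total = 0;
      for (int v : values) total += v;
      return total;
    }

    void Example() {
      std::vector<int> vec = {1, 2, 3};
      int a = Sum(vec);        // implicit conversion from std::vector
      int b = Sum({4, 5, 6});  // implicit conversion from an initializer list
      (void)a;
      (void)b;
    }
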
diff --git a/tensorflow/core/lib/gtl/array_slice_internal.h b/tensorflow/core/lib/gtl/array_slice_internal.h
deleted file mode 100644
index 689dd8a646..0000000000
--- a/tensorflow/core/lib/gtl/array_slice_internal.h
+++ /dev/null
@@ -1,269 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-// NOT FOR INCLUSION BY CLIENT CODE. This file is only to be included by
-// array_slice.h.
-
-// Helper functions and templates for ArraySlice.
-
-#ifndef TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
-#define TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
-
-#include <stddef.h>
-#include <algorithm>
-#include <iterator>
-#include <memory>
-#include <string>
-#include <type_traits>
-#include <utility>
-#include <vector>
-#include "tensorflow/core/platform/logging.h"
-
-namespace tensorflow {
-namespace gtl {
-namespace array_slice_internal {
-
-// Template logic for generic constructors.
-
-// Wrappers whose Get() delegates to the appropriate method of a container, and
-// is defined when this method exists. Delegates to the const method if C is a
-// const type.
-struct Data {
- template <typename C>
- static decltype(std::declval<C>().data()) Get(C* v) {
- return v->data();
- }
-};
-
-struct MutableData {
- template <typename C>
- static decltype(std::declval<C>().mutable_data()) Get(C* v) {
- return v->mutable_data();
- }
-};
-
-struct Size {
- template <typename C>
- static decltype(std::declval<C>().size()) Get(C* v) {
- return v->size();
- }
-};
-
-struct MutableStringData {
- // Defined only for string.
- static char* Get(string* v) { return v->empty() ? nullptr : &*v->begin(); }
-};
-
-// Checks whether M::Get(C*) is defined and has a return type R such that
-// Checker::valid<R>()==true.
-template <typename M, typename Checker, typename C>
-struct HasGetHelper : public M {
- private:
- struct None {};
- // M::Get is selected when it is viable. Get(...) is selected otherwise.
- using M::Get;
- static None Get(...);
-
- public:
- static constexpr bool HasGet() {
- using Result = decltype(Get(std::declval<C*>()));
- return !std::is_same<Result, None>() && Checker::template valid<Result>();
- }
-};
-
-// Defines HasGet() for a particular method, container, and checker. If
-// HasGet()==true, provides Get() that delegates to the method.
-template <typename M, typename Checker, typename C,
- bool /*has_get*/ = HasGetHelper<M, Checker, C>::HasGet()>
-struct Wrapper {
- static constexpr bool HasGet() { return false; }
-};
-
-template <typename M, typename Checker, typename C>
-struct Wrapper<M, Checker, C, true> {
- static constexpr bool HasGet() { return true; }
- static decltype(M::Get(std::declval<C*>())) Get(C* v) { return M::Get(v); }
-};
-
-// Type checker for a method returning an integral value.
-struct SizeChecker {
- template <typename R>
- static constexpr bool valid() {
- return std::is_integral<R>::value;
- }
-};
-
-// Type checker for a method returning either a pointer to T or a less const
-// version of that.
-template <typename T>
-struct DataChecker {
- // We want to enable conversion from std::vector<T*> to ArraySlice<const T*>
- // but disable conversion from std::vector<Derived> to ArraySlice<Base>. Here
- // we use the fact that U** is convertible to Q* const* if and only if Q is
- // the same type or a more cv-qualified version of U.
- template <typename R>
- static constexpr bool valid() {
- return std::is_convertible<R*, T* const*>::value;
- }
-};
-
-// Aliases to A if A::HasGet()==true, or to B otherwise.
-template <typename A, typename B>
-using FirstWithGet = typename std::conditional<A::HasGet(), A, B>::type;
-
-// Wraps C::data() const, returning a pointer to const data.
-template <typename T, typename C>
-using ContainerData = Wrapper<Data, DataChecker<const T>, const C>;
-
-// Wraps a method returning a pointer to mutable data. Prefers data() over
-// mutable_data(), and handles strings when T==char. If data() returns a pointer
-// to mutable data, it is most likely overloaded, but may also be a single
-// method 'T* C::data() const' in a non-STL-compliant container.
-template <typename T, typename C>
-using ContainerMutableData =
- FirstWithGet<Wrapper<Data, DataChecker<T>, C>,
- FirstWithGet<Wrapper<MutableData, DataChecker<T>, C>,
- Wrapper<MutableStringData, DataChecker<T>, C>>>;
-
-// Wraps C::size() const.
-template <typename C>
-using ContainerSize = Wrapper<Size, SizeChecker, const C>;
-
-// Implementation class for ArraySlice and MutableArraySlice. In the case of
-// ArraySlice, T will be a const type; for MutableArraySlice, T will be a
-// mutable type.
-template <typename T>
-class ArraySliceImplBase {
- public:
- typedef T* pointer;
- typedef const T* const_pointer;
- typedef T& reference;
- typedef const T& const_reference;
- typedef pointer iterator;
- typedef const_pointer const_iterator;
- typedef std::reverse_iterator<iterator> reverse_iterator;
- typedef std::reverse_iterator<const_iterator> const_reverse_iterator;
- typedef size_t size_type;
- typedef ptrdiff_t difference_type;
-
- static const size_type npos = static_cast<size_type>(-1);
-
- ArraySliceImplBase(pointer array, size_type length)
- : ptr_(array), length_(length) {}
-
- // Substring of another ArraySlice.
- // pos must be non-negative and <= x.length().
- // len must be non-negative and will be pinned to at most x.length() - pos.
- ArraySliceImplBase(const ArraySliceImplBase& x, size_type pos, size_type len)
- : ptr_(x.ptr_ + pos), length_(std::min(x.length_ - pos, len)) {}
-
- // Some of the const methods below return pointers and references to mutable
- // data. This is only the case in this internal class; ArraySlice and
- // MutableArraySlice provide deep-constness.
-
- pointer data() const { return ptr_; }
- size_type size() const { return length_; }
-
- void clear() {
- ptr_ = nullptr;
- length_ = 0;
- }
-
- reference operator[](size_type i) const { return ptr_[i]; }
- reference at(size_type i) const {
- DCHECK_LT(i, length_);
- return ptr_[i];
- }
- reference front() const {
- DCHECK_GT(length_, 0);
- return ptr_[0];
- }
- reference back() const {
- DCHECK_GT(length_, 0);
- return ptr_[length_ - 1];
- }
-
- void remove_prefix(size_type n) {
- DCHECK_GE(length_, n);
- ptr_ += n;
- length_ -= n;
- }
- void remove_suffix(size_type n) {
- DCHECK_GE(length_, n);
- length_ -= n;
- }
-
- iterator begin() const { return ptr_; }
- iterator end() const { return ptr_ + length_; }
- reverse_iterator rbegin() const { return reverse_iterator(end()); }
- reverse_iterator rend() const { return reverse_iterator(begin()); }
-
- bool operator==(const ArraySliceImplBase& other) const {
- if (size() != other.size()) return false;
- if (data() == other.data()) return true;
- return std::equal(data(), data() + size(), other.data());
- }
- bool operator!=(const ArraySliceImplBase& other) const {
- return !(*this == other);
- }
-
- private:
- pointer ptr_;
- size_type length_;
-};
-
-template <typename T>
-class ArraySliceImpl : public ArraySliceImplBase<const T> {
- public:
- using ArraySliceImplBase<const T>::ArraySliceImplBase;
-
- // Defined iff the data and size accessors for the container C have been
- // defined.
- template <typename C>
- using EnableIfConvertibleFrom =
- typename std::enable_if<ContainerData<T, C>::HasGet() &&
- ContainerSize<C>::HasGet()>::type;
-
- // Constructs from a container when EnableIfConvertibleFrom is
- // defined. std::addressof handles types with overloaded operator&.
- template <typename C>
- explicit ArraySliceImpl(const C& v)
- : ArraySliceImplBase<const T>(ContainerData<T, C>::Get(std::addressof(v)),
- ContainerSize<C>::Get(std::addressof(v))) {}
-};
-
-template <typename T>
-class MutableArraySliceImpl : public ArraySliceImplBase<T> {
- public:
- using ArraySliceImplBase<T>::ArraySliceImplBase;
-
- template <typename C>
- using EnableIfConvertibleFrom =
- typename std::enable_if<ContainerMutableData<T, C>::HasGet() &&
- ContainerSize<C>::HasGet()>::type;
-
- template <typename C>
- explicit MutableArraySliceImpl(C* v)
- : ArraySliceImplBase<T>(ContainerMutableData<T, C>::Get(v),
- ContainerSize<C>::Get(v)) {}
-};
-
-} // namespace array_slice_internal
-} // namespace gtl
-} // namespace tensorflow
-
-#endif // TENSORFLOW_LIB_GTL_ARRAY_SLICE_INTERNAL_H_
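(Aside, not part of the diff: the deleted header's HasGetHelper/Wrapper machinery is essentially a pre-C++17 way of asking "does this container expose data() and size()?". A compact modern equivalent of that detection idiom, for illustration only; the names below are hypothetical:)

    #include <type_traits>
    #include <utility>
    #include <vector>

    // True when a const C exposes both data() and size().
    template <typename C, typename = void>
    struct HasDataAndSize : std::false_type {};

    template <typename C>
    struct HasDataAndSize<
        C, std::void_t<decltype(std::declval<const C&>().data()),
                       decltype(std::declval<const C&>().size())>>
        : std::true_type {};

    static_assert(HasDataAndSize<std::vector<int>>::value,
                  "std::vector qualifies");
    static_assert(!HasDataAndSize<int>::value, "int does not");
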
diff --git a/tensorflow/core/lib/gtl/array_slice_test.cc b/tensorflow/core/lib/gtl/array_slice_test.cc
deleted file mode 100644
index 4d3da85b88..0000000000
--- a/tensorflow/core/lib/gtl/array_slice_test.cc
+++ /dev/null
@@ -1,666 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/array_slice.h"
-
-#include <algorithm>
-#include <array>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-#include "tensorflow/core/lib/gtl/stl_util.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace gtl {
-namespace {
-
-typedef ArraySlice<int> IntSlice;
-typedef ArraySlice<char> CharSlice;
-typedef MutableArraySlice<int> MutableIntSlice;
-typedef MutableArraySlice<char> MutableCharSlice;
-typedef std::vector<int> IntVec;
-
-// Append 0..len-1 to *v
-template <typename Vector>
-static void Fill(Vector* v, int len, int offset = 0) {
- for (int i = 0; i < len; i++) {
- v->push_back(i + offset);
- }
-}
-
-static void TestHelper(const IntSlice& vorig, const IntVec& vec) {
- IntSlice other; // To test the assignment return value.
- IntSlice v = other = vorig;
- const int len = vec.size();
- EXPECT_EQ(v.size(), vec.size());
-
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(v[i], vec[i]);
- EXPECT_EQ(v.at(i), vec[i]);
- }
- EXPECT_EQ(v.begin(), gtl::vector_as_array(&vec));
-
- int counter = 0;
- for (IntSlice::iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(counter, *it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- counter = 0;
- for (IntSlice::const_iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(counter, *it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- if (len > 0) {
- EXPECT_EQ(0, v.front());
- EXPECT_EQ(len - 1, v.back());
- v.pop_back();
- EXPECT_EQ(len - 1, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(i, v[i]);
- }
- if (len > 1) {
- v.pop_front();
- EXPECT_EQ(len - 2, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(i + 1, v[i]);
- }
- }
- }
-}
-
-// The element access test that is applicable both when MutableArraySlice is
-// const and when it's not.
-template <class V>
-void MutableTestHelperTemplated(V v, int* ptr, const int len) {
- CHECK_EQ(v.size(), len);
-
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(ptr + i, &v[i]);
- EXPECT_EQ(ptr + i, &v.at(i));
- }
- EXPECT_EQ(ptr, v.begin());
- EXPECT_EQ(ptr + len, v.end());
- EXPECT_EQ(ptr, v.data());
-
- int counter = 0;
- for (MutableIntSlice::const_iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(ptr + counter, &*it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- EXPECT_EQ(len, std::distance(v.rbegin(), v.rend()));
-
- if (len > 0) {
- EXPECT_EQ(ptr, &v.front());
- EXPECT_EQ(ptr + len - 1, &v.back());
- EXPECT_EQ(ptr + len - 1, &*v.rbegin());
- EXPECT_EQ(ptr, &*(v.rend() - 1));
- }
-}
-
-static void MutableTestHelper(const MutableIntSlice& vorig, int* ptr,
- const int len) {
- // Test the data accessors both when the MutableArraySlice is declared const,
- // and when it is not.
- MutableTestHelperTemplated<const MutableIntSlice&>(vorig, ptr, len);
- MutableTestHelperTemplated<MutableIntSlice>(vorig, ptr, len);
-
- MutableIntSlice other; // To test the assignment return value.
- MutableIntSlice v = other = vorig;
- EXPECT_EQ(ptr, v.mutable_data());
-
- int counter = 0;
- for (MutableIntSlice::iterator it = v.begin(); it != v.end(); ++it) {
- EXPECT_EQ(ptr + counter, &*it);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- if (len > 0) {
- // Test that elements are assignable.
- v[0] = 1;
- v.front() = 2;
- v.back() = 5;
- *v.mutable_data() = 4;
- std::fill(v.begin(), v.end(), 5);
- std::fill(v.rbegin(), v.rend(), 6);
- // Test size-changing methods.
- v.pop_back();
- EXPECT_EQ(len - 1, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(ptr + i, &v[i]);
- }
- if (len > 1) {
- v.pop_front();
- EXPECT_EQ(len - 2, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(ptr + i + 1, &v[i]);
- }
- }
- }
-}
-
-template <typename Vector>
-static void TestImplicitConversion(const IntSlice& v, const Vector& vec) {
- EXPECT_EQ(v.size(), vec.size());
- for (size_t i = 0; i < v.size(); i++) {
- EXPECT_EQ(v[i], vec[i]);
- }
-}
-
-template <typename Vector>
-static void TestImplicitConversion(const CharSlice& v, const Vector& vec) {
- TestImplicitConversion(IntVec(v.begin(), v.end()), vec);
-}
-
-static void TestImplicitConversion(const MutableIntSlice& v, const int* data,
- int size) {
- EXPECT_EQ(size, v.size());
- for (size_t i = 0; i < v.size(); i++) {
- EXPECT_EQ(data + i, &v[i]);
- }
-}
-
-static void TestImplicitConversion(const MutableCharSlice& v, const char* data,
- int size) {
- EXPECT_EQ(size, v.size());
- for (size_t i = 0; i < v.size(); i++) {
- EXPECT_EQ(data + i, &v[i]);
- }
-}
-// A struct supplying the data(), mutable_data() and size() methods, just like
-// e.g. proto2::RepeatedField.
-struct RepeatedField {
- std::vector<int> storage;
- const int* data() const { return storage.data(); }
- int* mutable_data() { return storage.data(); }
- int size() const { return storage.size(); }
-};
-
-// A struct supplying the data() (both mutable and const versions) and
-// size(). It also supplies mutable_data() but we test that data() is selected
-// instead.
-struct ContainerWithOverloads {
- std::vector<int> storage;
- std::vector<int> wrong_storage;
- const int* data() const { return storage.data(); }
- int* data() { return storage.data(); }
- // MutableArraySlice should not call mutable_data(), preferring data()
- // instead.
- int* mutable_data() { return wrong_storage.data(); }
- int size() const { return storage.size(); }
-};
-
-// A struct supplying data() and size() methods.
-struct ContainerWithShallowConstData {
- std::vector<int> storage;
- int* data() const { return const_cast<int*>(storage.data()); }
- int size() const { return storage.size(); }
-};
-
-TEST(IntSlice, Simple) {
- for (int len = 0; len < 20; len++) {
- IntVec vec;
- Fill(&vec, len);
- TestHelper(IntSlice(vec), vec);
- TestHelper(IntSlice(vec.data(), vec.size()), vec);
- }
-}
-
-TEST(IntSlice, WithPosAndLen) {
- IntVec vec;
- Fill(&vec, 20);
- for (size_t len = 0; len < vec.size(); len++) {
- IntVec subvec(vec.begin(), vec.begin() + len);
- TestImplicitConversion(IntSlice(vec, 0, len), subvec);
- TestImplicitConversion(IntSlice(IntSlice(vec), 0, len), subvec);
- }
- EXPECT_EQ(0, IntSlice(vec, 0, 0).size());
- EXPECT_EQ(0, IntSlice(IntSlice(vec), 0, 0).size());
- TestImplicitConversion(IntSlice(vec, 0, IntSlice::npos), vec);
-}
-
-TEST(IntSlice, Clear) {
- for (int len = 0; len < 20; len++) {
- IntVec vec;
- Fill(&vec, len);
- IntSlice v(vec);
- v.clear();
- EXPECT_EQ(0, v.size());
- EXPECT_EQ(v.begin(), v.end());
- }
-}
-
-TEST(IntSlice, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- for (int l2 = 0; l2 < 20; l2++) {
- IntVec avec, bvec;
- Fill(&avec, l1);
- Fill(&bvec, l2, 100);
- IntSlice a(avec), b(bvec);
- using std::swap;
- swap(a, b);
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- EXPECT_EQ(i, b[i]);
- }
- for (int i = 0; i < l2; i++) {
- EXPECT_EQ(100 + i, a[i]);
- }
- }
- }
-}
-
-TEST(IntSlice, ImplicitConversion) {
- for (int len = 0; len < 20; len++) {
- IntVec vec;
- Fill(&vec, len);
- IntSlice slice;
- slice = vec;
- TestImplicitConversion(vec, vec);
- TestImplicitConversion(slice, vec);
- TestImplicitConversion(IntSlice(vec.data(), vec.size()), vec);
- }
-}
-
-TEST(IntSlice, InlinedVectorConversion) {
- for (int len = 0; len < 20; len++) {
- InlinedVector<int, 4> inline_vec;
- for (int i = 0; i < len; i++) {
- inline_vec.push_back(i);
- }
- IntVec vec;
- Fill(&vec, len);
- IntSlice v = inline_vec; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(inline_vec, vec);
- }
-}
-
-TEST(IntSlice, StaticArrayConversion) {
- int array[20];
- IntVec vec;
- Fill(&vec, TF_ARRAYSIZE(array));
- std::copy(vec.begin(), vec.end(), array);
- IntSlice v = array; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(array, vec);
-}
-
-TEST(IntSlice, StdArrayConversion) {
- std::array<int, 20> array;
- IntVec vec;
- Fill(&vec, array.size());
- std::copy(vec.begin(), vec.end(), array.begin());
-
- // Check assignment.
- {
- IntSlice v = array;
- static_cast<void>(v);
- }
-
- // Check sub-slice initialization.
- {
- IntSlice v = {array, 10, 15};
- static_cast<void>(v);
- }
-
- TestImplicitConversion(array, vec);
-}
-
-// Values according to the Fill function.
-static const int test_const_array[] = {0, 1, 2};
-
-TEST(IntSlice, ConstStaticArrayConversion) {
- IntVec vec;
- Fill(&vec, TF_ARRAYSIZE(test_const_array));
- IntSlice v = test_const_array; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(test_const_array, vec);
-}
-
-TEST(IntSlice, RepeatedFieldConversion) {
- RepeatedField repeated_field;
- IntVec vec;
- Fill(&vec, 20);
- repeated_field.storage = vec;
- IntSlice v = repeated_field; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(repeated_field, vec);
-}
-
-TEST(IntSlice, ContainerWithOverloadsConversion) {
- ContainerWithOverloads container;
- Fill(&container.storage, 20);
- container.wrong_storage.resize(container.size());
- IntSlice v = container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(container, container.storage);
-}
-
-TEST(IntSlice, ContainerWithShallowConstDataConversion) {
- ContainerWithShallowConstData container;
- Fill(&container.storage, 20);
- IntSlice v = container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(container, container.storage);
-}
-
-TEST(IntSlice, MutableIntSliceConversion) {
- IntVec vec(20);
- IntSlice slice = MutableIntSlice(&vec);
- EXPECT_EQ(vec.size(), slice.size());
- EXPECT_EQ(vec.data(), slice.data());
-}
-
-TEST(IntSlice, Equality) {
- IntVec vec1(20);
- IntVec vec2(20);
- // These two slices are from different vectors, but have the same
- // size and have the same elements (right now). They should
- // compare equal.
- const IntSlice from1(vec1);
- const IntSlice from2(vec2);
- EXPECT_EQ(from1, from1);
- EXPECT_EQ(from1, from2);
-
- // This verifies that MutableArraySlices can be compared freely with
- // ArraySlices.
- const MutableIntSlice mutable_from1(&vec1);
- const MutableIntSlice mutable_from2(&vec2);
- EXPECT_EQ(from1, mutable_from1);
- EXPECT_EQ(mutable_from1, from1);
- EXPECT_EQ(mutable_from1, mutable_from2);
- EXPECT_EQ(mutable_from2, mutable_from1);
-
- // With a different size, the array slices should not be equal.
- EXPECT_NE(from1, IntSlice(from1, 0, from1.size() - 1));
-
- // With different contents, the array slices should not be equal.
- ++vec2.back();
- EXPECT_NE(from1, from2);
-}
-
-// Compile-asserts that the argument has the expected type.
-template <typename Expected, typename T>
-void CheckType(const T& value) {
- ::testing::StaticAssertTypeEq<Expected, T>();
-}
-
-TEST(IntSlice, ExposesContainerTypesAndConsts) {
- IntSlice slice;
- const IntSlice const_slice;
- CheckType<IntSlice::iterator>(slice.begin());
- CheckType<IntSlice::const_iterator>(const_slice.end());
- CheckType<IntSlice::const_reverse_iterator>(const_slice.rbegin());
- CheckType<IntSlice::reverse_iterator>(slice.rend());
- ::testing::StaticAssertTypeEq<int, IntSlice::value_type>();
- ::testing::StaticAssertTypeEq<const int*, IntSlice::pointer>();
- ::testing::StaticAssertTypeEq<const int&, IntSlice::const_reference>();
- EXPECT_EQ(static_cast<IntSlice::size_type>(-1), IntSlice::npos);
-}
-
-void TestEmpty(IntSlice slice) { ASSERT_TRUE(slice.empty()); }
-
-void TestRange(IntSlice slice, int from, int to) {
- ASSERT_EQ(to - from + 1, slice.size());
- for (size_t i = 0; i < slice.size(); ++i) {
- EXPECT_EQ(from + i, slice[i]);
- }
-}
-
-TEST(IntSlice, InitializerListConversion) {
- TestEmpty({});
- TestRange({1}, 1, 1);
- TestRange({10, 11, 12, 13}, 10, 13);
-}
-
-TEST(CharSlice, StringConversion) {
- IntVec vec;
- Fill(&vec, 20);
- string str(vec.begin(), vec.end());
- CharSlice v = str; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(str, vec);
-}
-
-TEST(IntPtrSlice, ConstConversion) {
- int one = 1;
- int two = 2;
- std::vector<int*> vec;
- vec.push_back(&one);
- vec.push_back(&two);
- ArraySlice<const int*> v = vec;
- ASSERT_EQ(2, v.size());
- EXPECT_EQ(&one, v[0]);
- EXPECT_EQ(&two, v[1]);
-}
-
-TEST(MutableIntSlice, Simple) {
- for (int len = 0; len < 20; len++) {
- IntVec vec(len);
- MutableTestHelper(MutableIntSlice(&vec), vec.data(), len);
- MutableTestHelper(MutableIntSlice(vec.data(), vec.size()), vec.data(), len);
- }
-}
-
-TEST(MutableIntSlice, WithPosAndLen) {
- IntVec vec(20);
- for (size_t len = 0; len < vec.size(); len++) {
- TestImplicitConversion(MutableIntSlice(&vec, 0, len), vec.data(), len);
- TestImplicitConversion(MutableIntSlice(MutableIntSlice(&vec), 0, len),
- vec.data(), len);
- }
- EXPECT_EQ(0, MutableIntSlice(&vec, 0, 0).size());
- EXPECT_EQ(0, MutableIntSlice(MutableIntSlice(&vec), 0, 0).size());
- TestImplicitConversion(MutableIntSlice(&vec, 0, MutableIntSlice::npos),
- vec.data(), vec.size());
-}
-
-TEST(MutableIntSlice, Clear) {
- for (int len = 0; len < 20; len++) {
- IntVec vec(len);
- MutableIntSlice v(&vec);
- v.clear();
- EXPECT_EQ(0, v.size());
- EXPECT_EQ(v.begin(), v.end());
- }
-}
-
-TEST(MutableIntSlice, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- for (int l2 = 0; l2 < 20; l2++) {
- IntVec avec(l1), bvec(l2);
- MutableIntSlice a(&avec), b(&bvec);
- using std::swap;
- swap(a, b);
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- EXPECT_EQ(&avec[i], &b[i]);
- }
- for (int i = 0; i < l2; i++) {
- EXPECT_EQ(&bvec[i], &a[i]);
- }
- }
- }
-}
-
-TEST(MutableIntSlice, ImplicitConversion) {
- for (int len = 0; len < 20; len++) {
- IntVec vec(len);
- MutableIntSlice slice;
- slice = &vec;
- TestImplicitConversion(&vec, vec.data(), len);
- TestImplicitConversion(slice, vec.data(), len);
- TestImplicitConversion(MutableIntSlice(vec.data(), vec.size()), vec.data(),
- len);
- }
-}
-
-TEST(MutableIntSlice, InlinedVectorConversion) {
- for (int len = 0; len < 20; len++) {
- InlinedVector<int, 4> inline_vec;
- for (int i = 0; i < len; i++) {
- inline_vec.push_back(i);
- }
- MutableIntSlice v = &inline_vec; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&inline_vec, inline_vec.data(), inline_vec.size());
- }
-}
-
-TEST(MutableIntSlice, StaticArrayConversion) {
- int array[20];
- MutableIntSlice v = array; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(array, array, TF_ARRAYSIZE(array));
-}
-
-TEST(MutableIntSlice, StdArrayConversion) {
- std::array<int, 20> array;
-
- // Check assignment.
- {
- MutableIntSlice v = &array;
- static_cast<void>(v);
- }
-
- // Check sub-slice initialization.
- {
- MutableIntSlice v = {&array, 10, 15};
- static_cast<void>(v);
- }
-
- TestImplicitConversion(&array, &array[0], array.size());
-}
-
-TEST(MutableIntSlice, RepeatedFieldConversion) {
- RepeatedField repeated_field;
- Fill(&repeated_field.storage, 20);
- MutableIntSlice v = &repeated_field; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&repeated_field, repeated_field.storage.data(),
- repeated_field.storage.size());
-}
-
-TEST(MutableIntSlice, ContainerWithOverloadsConversion) {
- ContainerWithOverloads container;
- Fill(&container.storage, 20);
- container.wrong_storage.resize(container.size());
- MutableIntSlice v = &container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&container, container.storage.data(),
- container.storage.size());
-}
-
-TEST(MutableIntSlice, ContainerWithShallowConstDataConversion) {
- ContainerWithShallowConstData container;
- Fill(&container.storage, 20);
- MutableIntSlice v = &container; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(&container, container.storage.data(),
- container.storage.size());
-}
-
-TEST(MutableIntSlice, TypedefsAndConstants) {
- ::testing::StaticAssertTypeEq<int, MutableIntSlice::value_type>();
- ::testing::StaticAssertTypeEq<int*, MutableIntSlice::pointer>();
- ::testing::StaticAssertTypeEq<const int*, MutableIntSlice::const_pointer>();
- ::testing::StaticAssertTypeEq<int&, MutableIntSlice::reference>();
- ::testing::StaticAssertTypeEq<const int&, MutableIntSlice::const_reference>();
-
- EXPECT_EQ(static_cast<MutableIntSlice::size_type>(-1), MutableIntSlice::npos);
-}
-
-TEST(MutableIntSlice, IteratorsAndReferences) {
- auto accept_pointer = [](int* x) {};
- auto accept_reference = [](int& x) {};
- auto accept_iterator = [](MutableIntSlice::iterator x) {};
- auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {};
-
- int a[1];
- MutableIntSlice s = a;
-
- accept_pointer(s.data());
- accept_pointer(s.mutable_data());
- accept_iterator(s.begin());
- accept_iterator(s.end());
- accept_reverse_iterator(s.rbegin());
- accept_reverse_iterator(s.rend());
-
- accept_reference(s[0]);
- accept_reference(s.at(0));
- accept_reference(s.front());
- accept_reference(s.back());
-}
-
-TEST(MutableIntSlice, IteratorsAndReferences_Const) {
- auto accept_pointer = [](int* x) {};
- auto accept_reference = [](int& x) {};
- auto accept_iterator = [](MutableIntSlice::iterator x) {};
- auto accept_reverse_iterator = [](MutableIntSlice::reverse_iterator x) {};
-
- int a[1];
- const MutableIntSlice s = a;
-
- accept_pointer(s.data());
- accept_pointer(s.mutable_data());
- accept_iterator(s.begin());
- accept_iterator(s.end());
- accept_reverse_iterator(s.rbegin());
- accept_reverse_iterator(s.rend());
-
- accept_reference(s[0]);
- accept_reference(s.at(0));
- accept_reference(s.front());
- accept_reference(s.back());
-}
-
-bool TestMutableOverload(MutableIntSlice slice) { return false; }
-
-bool TestMutableOverload(MutableCharSlice slice) { return true; }
-
-TEST(MutableCharSlice, StringConversion) {
- for (int len = 0; len < 20; len++) {
- string str(len, '\0');
- MutableCharSlice v = &str; // Test assignment
- static_cast<void>(v);
- TestImplicitConversion(v, str.data(), str.size());
- }
- // Verify that only the correct overload is feasible. Note that this would
- // fail if the string ctor was declared simply as MutableArraySlice(string*),
- // since in that case both overloads would be feasible.
- string str;
- EXPECT_TRUE(TestMutableOverload(&str));
-
- // Avoid warning "unused function 'TestMutableOverload'"
- int a[1];
- EXPECT_FALSE(TestMutableOverload(a));
-}
-
-} // namespace
-} // namespace gtl
-} // namespace tensorflow
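(Note in passing, outside the diff: several gtl-specific members exercised by the deleted tests, such as the (slice, pos, len) constructor, clear(), pop_front()/pop_back(), and mutable_data(), have no direct counterpart on absl::Span, so affected call sites are presumably expected to move to the Span spellings. A rough sketch of the equivalents, under that assumption:)

    #include <vector>

    #include "absl/types/span.h"

    void SpanEquivalents() {
      std::vector<int> vec = {0, 1, 2, 3, 4};
      absl::Span<const int> s(vec);                 // was IntSlice(vec)
      absl::Span<const int> sub = s.subspan(1, 3);  // was IntSlice(vec, 1, 3)
      s.remove_prefix(1);                           // was pop_front()
      s.remove_suffix(1);                           // was pop_back()
      s = absl::Span<const int>();                  // was clear()
      absl::Span<int> m = absl::MakeSpan(vec);      // was MutableIntSlice(&vec)
      int* p = m.data();                            // was mutable_data()
      (void)sub;
      (void)p;
    }
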
diff --git a/tensorflow/core/lib/gtl/cleanup.h b/tensorflow/core/lib/gtl/cleanup.h
index 6bd60ca482..8c73dc6aa9 100644
--- a/tensorflow/core/lib/gtl/cleanup.h
+++ b/tensorflow/core/lib/gtl/cleanup.h
@@ -39,8 +39,8 @@ limitations under the License.
//
// You can call 'release()' on a Cleanup object to cancel the cleanup.
-#ifndef TENSORFLOW_LIB_GTL_CLEANUP_H_
-#define TENSORFLOW_LIB_GTL_CLEANUP_H_
+#ifndef TENSORFLOW_CORE_LIB_GTL_CLEANUP_H_
+#define TENSORFLOW_CORE_LIB_GTL_CLEANUP_H_
#include <type_traits>
#include <utility>
@@ -110,4 +110,4 @@ TF_MUST_USE_RESULT Cleanup<DecayF> MakeCleanup(F&& f) {
} // namespace gtl
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_GTL_CLEANUP_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_CLEANUP_H_
diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h
index 2011f7d4a1..2d622dc229 100644
--- a/tensorflow/core/lib/gtl/inlined_vector.h
+++ b/tensorflow/core/lib/gtl/inlined_vector.h
@@ -13,676 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// An InlinedVector<T,N,A> is like a std::vector<T,A>, except that storage
-// for sequences of length <= N are provided inline without requiring
-// any heap allocation. Typically N is very small (e.g., 4) so that
-// sequences that are expected to be short do not require allocations.
-//
-// Only some of the std::vector<> operations are currently implemented.
-// Other operations may be added as needed to facilitate migrating
-// code that uses std::vector<> to InlinedVector<>.
-//
-// NOTE: If you want an inlined version to replace use of a
-// std::vector<bool>, consider using util::bitmap::InlinedBitVector<NBITS>
-// in util/bitmap/inlined_bitvector.h
-//
-// TODO(billydonahue): change size_t to size_type where appropriate.
+#ifndef TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_
+#define TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_
-#ifndef TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
-#define TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
-
-#include <stddef.h>
-#include <stdlib.h>
-#include <string.h>
-#include <sys/types.h>
-#include <algorithm>
-#include <cstddef>
-#include <iterator>
-#include <memory>
-#include <type_traits>
-#include <vector>
-
-#include "tensorflow/core/lib/gtl/manual_constructor.h"
-#include "tensorflow/core/platform/byte_order.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mem.h"
+#include "absl/container/inlined_vector.h"
+// TODO(kramerb): This is kept only because lots of targets transitively depend
+// on it. Remove all targets' dependencies.
+#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
-#include <initializer_list> // NOLINT(build/include_order)
-
namespace tensorflow {
namespace gtl {
-template <typename T, int N>
-class InlinedVector {
- public:
- typedef T value_type;
- typedef T* pointer;
- typedef const T* const_pointer;
- typedef T& reference;
- typedef const T& const_reference;
- typedef size_t size_type;
- typedef std::ptrdiff_t difference_type;
- typedef pointer iterator;
- typedef const_pointer const_iterator;
-
- // Create an empty vector
- InlinedVector();
-
- // Create a vector with n copies of value_type().
- explicit InlinedVector(size_t n);
-
- // Create a vector with n copies of elem
- InlinedVector(size_t n, const value_type& elem);
-
- // Create and initialize with the elements [range_start .. range_end).
- // The unused enable_if argument restricts this constructor so that it is
- // elided when value_type is an integral type. This prevents ambiguous
- // interpretation between a call to this constructor with two integral
- // arguments and a call to the preceding (n, elem) constructor.
- template <typename InputIterator>
- InlinedVector(
- InputIterator range_start, InputIterator range_end,
- typename std::enable_if<!std::is_integral<InputIterator>::value>::type* =
- NULL) {
- InitRep();
- AppendRange(range_start, range_end);
- }
-
- InlinedVector(std::initializer_list<value_type> init) {
- InitRep();
- AppendRange(init.begin(), init.end());
- }
-
- InlinedVector(const InlinedVector& v);
-
- ~InlinedVector() { clear(); }
-
- InlinedVector& operator=(const InlinedVector& v) {
- // Optimized to avoid reallocation.
- // Prefer reassignment to copy construction for elements.
- const size_t s = size();
- const size_t vs = v.size();
- if (s < vs) { // grow
- reserve(vs);
- if (s) std::copy(v.begin(), v.begin() + s, begin());
- std::copy(v.begin() + s, v.end(), std::back_inserter(*this));
- } else { // maybe shrink
- erase(begin() + vs, end());
- std::copy(v.begin(), v.end(), begin());
- }
- return *this;
- }
-
- size_t size() const { return size_internal(); }
-
- bool empty() const { return (size() == 0); }
-
- // Return number of elements that can be stored in vector
- // without requiring a reallocation of underlying memory
- size_t capacity() const {
- if (is_inline()) {
- return kFit;
- } else {
- return static_cast<size_t>(1) << u_.data[kSize - 2];
- }
- }
-
- // Return a pointer to the underlying array.
- // Only result[0,size()-1] are defined.
- pointer data() {
- if (is_inline()) {
- return reinterpret_cast<T*>(u_.data);
- } else {
- return outofline_pointer();
- }
- }
- const_pointer data() const {
- return const_cast<InlinedVector<T, N>*>(this)->data();
- }
-
- // Remove all elements
- void clear() {
- DiscardStorage();
- u_.data[kSize - 1] = 0;
- }
-
- // Return the ith element
- // REQUIRES: 0 <= i < size()
- const value_type& at(size_t i) const {
- DCHECK_LT(i, size());
- return data()[i];
- }
- const value_type& operator[](size_t i) const {
- DCHECK_LT(i, size());
- return data()[i];
- }
-
- // Return a non-const reference to the ith element
- // REQUIRES: 0 <= i < size()
- value_type& at(size_t i) {
- DCHECK_LT(i, size());
- return data()[i];
- }
- value_type& operator[](size_t i) {
- DCHECK_LT(i, size());
- return data()[i];
- }
-
- value_type& back() {
- DCHECK(!empty());
- return at(size() - 1);
- }
-
- const value_type& back() const {
- DCHECK(!empty());
- return at(size() - 1);
- }
-
- value_type& front() {
- DCHECK(!empty());
- return at(0);
- }
-
- const value_type& front() const {
- DCHECK(!empty());
- return at(0);
- }
-
- // Append a T constructed with args to the vector.
- // Increases size() by one.
- // Amortized complexity: O(1)
- // Worst-case complexity: O(size())
- template <typename... Args>
- void emplace_back(Args&&... args) {
- size_t s = size();
- DCHECK_LE(s, capacity());
- if (s < capacity()) {
- new (data() + s) T(std::forward<Args>(args)...);
- set_size_internal(s + 1);
- } else {
- EmplaceBackSlow(std::forward<Args>(args)...);
- }
- }
-
- // Append t to the vector.
- // Increases size() by one.
- // Amortized complexity: O(1)
- // Worst-case complexity: O(size())
- void push_back(const value_type& t) { emplace_back(t); }
- void push_back(value_type&& t) { emplace_back(std::move(t)); }
-
- inline void pop_back() {
- DCHECK(!empty());
- const size_t s = size();
- Destroy(data() + s - 1, 1);
- set_size_internal(s - 1);
- }
-
- // Resizes the vector to contain "n" elements.
- // If "n" is smaller than the initial size, extra elements are destroyed.
- // If "n" is larger than the initial size, enough copies of "elem"
- // are appended to increase the size to "n". If "elem" is omitted,
- // new elements are value-initialized.
- void resize(size_t n) { Resize<ValueInit>(n, nullptr); }
- void resize(size_t n, const value_type& elem) { Resize<Fill>(n, &elem); }
-
- iterator begin() { return data(); }
- const_iterator begin() const { return data(); }
-
- iterator end() { return data() + size(); }
- const_iterator end() const { return data() + size(); }
-
- iterator insert(iterator pos, const value_type& v);
-
- iterator erase(iterator pos) {
- DCHECK_LT(pos, end());
- DCHECK_GE(pos, begin());
- std::copy(pos + 1, end(), pos);
- pop_back();
- return pos;
- }
-
- iterator erase(iterator first, iterator last);
-
- // Enlarges the underlying representation so it can hold at least
- // "n" elements without reallocation.
- // Does not change size() or the actual contents of the vector.
- void reserve(size_t n) {
- if (n > capacity()) {
- // Make room for new elements
- Grow<Move>(n);
- }
- }
-
- // Swap the contents of *this with other.
- // REQUIRES: value_type is swappable and copyable.
- void swap(InlinedVector& other);
-
- private:
- // Representation can either be inlined or out-of-line.
- // In either case, at least sizeof(void*) + 8 bytes are available.
- //
- // Inlined:
- // Last byte holds the length.
- // First (length*sizeof(T)) bytes stores the elements.
- // Outlined:
- // Last byte holds kSentinel.
- // Second-last byte holds lg(capacity)
- // Preceding 6 bytes hold size.
- // First sizeof(T*) bytes hold pointer.
-
- // Compute rep size.
- static const size_t kSizeUnaligned = N * sizeof(T) + 1; // Room for tag
- static const size_t kSize = ((kSizeUnaligned + 15) / 16) * 16; // Align
-
- // See how many T we can fit inside kSize, but no more than 254
- // since 255 is used as sentinel tag for out-of-line allocation.
- static const unsigned int kSentinel = 255;
- static const size_t kFit1 = (kSize - 1) / sizeof(T);
- static const size_t kFit = (kFit1 >= kSentinel) ? (kSentinel - 1) : kFit1;
-
- union {
- unsigned char data[kSize];
- // Force data to be aligned enough for a pointer.
- T* unused_aligner;
- } u_;
-
- inline void InitRep() { u_.data[kSize - 1] = 0; }
- inline bool is_inline() const { return u_.data[kSize - 1] != kSentinel; }
-
- inline T* outofline_pointer() const {
- T* ptr;
- memcpy(&ptr, &u_.data[0], sizeof(ptr));
- return ptr;
- }
-
- inline void set_outofline_pointer(T* p) {
- memcpy(&u_.data[0], &p, sizeof(p));
- }
-
- inline uint64_t outofline_word() const {
- uint64_t word;
- memcpy(&word, &u_.data[kSize - 8], sizeof(word));
- return word;
- }
-
- inline void set_outofline_word(uint64_t w) {
- memcpy(&u_.data[kSize - 8], &w, sizeof(w));
- }
-
- inline size_t size_internal() const {
- uint8_t s = static_cast<uint8_t>(u_.data[kSize - 1]);
- if (s != kSentinel) {
- return static_cast<size_t>(s);
- } else {
- const uint64_t word = outofline_word();
- if (port::kLittleEndian) {
- // The sentinel and capacity bits are most-significant bits in word.
- return static_cast<size_t>(word & 0xffffffffffffull);
- } else {
- // The sentinel and capacity bits are least-significant bits in word.
- return static_cast<size_t>(word >> 16);
- }
- }
- }
-
- void set_size_internal(size_t n) {
- if (is_inline()) {
- DCHECK_LT(n, kSentinel);
- u_.data[kSize - 1] = static_cast<unsigned char>(n);
- } else {
- uint64_t word;
- if (port::kLittleEndian) {
- // The sentinel and capacity bits are most-significant bits in word.
- word = (static_cast<uint64_t>(n) |
- (static_cast<uint64_t>(u_.data[kSize - 2]) << 48) |
- (static_cast<uint64_t>(kSentinel) << 56));
- } else {
- // The sentinel and capacity bits are least-significant bits in word.
- word = ((static_cast<uint64_t>(n) << 16) |
- (static_cast<uint64_t>(u_.data[kSize - 2]) << 8) |
- (static_cast<uint64_t>(kSentinel)));
- }
- set_outofline_word(word);
- DCHECK_EQ(u_.data[kSize - 1], kSentinel) << n;
- }
- }
-
- void DiscardStorage() {
- T* base = data();
- size_t n = size();
- Destroy(base, n);
- if (!is_inline()) {
- port::Free(base);
- }
- }
-
- template <typename... Args>
- void EmplaceBackSlow(Args&&... args) {
- const size_t s = size();
- DCHECK_EQ(s, capacity());
- Grow<Move, Construct>(s + 1, std::forward<Args>(args)...);
- set_size_internal(s + 1);
- }
-
- // Movers for Grow
- // Does nothing.
- static void Nop(T* src, size_t n, T* dst) {}
-
- // Moves src[0,n-1] contents to dst[0,n-1].
- static void Move(T* src, size_t n, T* dst) {
- for (size_t i = 0; i < n; i++) {
- new (dst + i) T(std::move(*(src + i)));
- }
- }
-
- // Initializers for Resize.
- // Initializes dst[0,n-1] with empty constructor.
- static void ValueInit(const T*, size_t n, T* dst) {
- for (size_t i = 0; i < n; i++) {
- new (dst + i) T();
- }
- }
-
- // Initializes dst[0,n-1] with copies of *src.
- static void Fill(const T* src, size_t n, T* dst) {
- for (size_t i = 0; i < n; i++) {
- new (dst + i) T(*src);
- }
- }
-
- void Destroy(T* src, int n) {
- if (!std::is_trivially_destructible<T>::value) {
- for (int i = 0; i < n; i++) {
- (src + i)->~T();
- }
- }
- }
-
- // Initialization methods for Grow.
- // 1) Leave uninitialized memory.
- struct Uninitialized {
- void operator()(T*) const {}
- };
- // 2) Construct a T with args at not-yet-initialized memory pointed by dst.
- struct Construct {
- template <class... Args>
- void operator()(T* dst, Args&&... args) const {
- new (dst) T(std::forward<Args>(args)...);
- }
- };
-
- // Grow so that capacity >= n. Uses Mover to move existing elements
- // to new buffer, and possibly initialize the new element according
- // to InitType.
- // We pass the InitType and Mover as template arguments so that
- // this code compiles even if T does not support copying or default
- // construction.
- template <void(Mover)(T*, size_t, T*), class InitType = Uninitialized,
- class... Args>
- void Grow(size_t n, Args&&... args) {
- size_t s = size();
- DCHECK_LE(s, capacity());
-
- // Compute new capacity by repeatedly doubling current capacity
- size_t target = 1;
- size_t target_lg = 0;
- while (target < kFit || target < n) {
- // TODO(psrc): Check and avoid overflow?
- target_lg++;
- target <<= 1;
- }
-
- T* src = data();
- T* dst = static_cast<T*>(port::Malloc(target * sizeof(T)));
-
- // Need to copy elem before discarding src since it might alias src.
- InitType{}(dst + s, std::forward<Args>(args)...);
- Mover(src, s, dst);
- DiscardStorage();
-
- u_.data[kSize - 1] = kSentinel;
- u_.data[kSize - 2] = static_cast<unsigned char>(target_lg);
- set_size_internal(s);
- DCHECK_EQ(capacity(), target);
- set_outofline_pointer(dst);
- }
-
- // Resize to size n. Any new elements are initialized by passing
- // elem and the destination to Initializer. We pass the Initializer
- // as a template argument so that this code compiles even if T does
- // not support copying.
- template <void(Initializer)(const T*, size_t, T*)>
- void Resize(size_t n, const T* elem) {
- size_t s = size();
- if (n <= s) {
- Destroy(data() + n, s - n);
- set_size_internal(n);
- return;
- }
- reserve(n);
- DCHECK_GE(capacity(), n);
- set_size_internal(n);
- Initializer(elem, n - s, data() + s);
- }
-
- template <typename Iter>
- void AppendRange(Iter first, Iter last, std::input_iterator_tag);
-
- // Faster path for forward iterators.
- template <typename Iter>
- void AppendRange(Iter first, Iter last, std::forward_iterator_tag);
-
- template <typename Iter>
- void AppendRange(Iter first, Iter last);
-};
-
-// Provide linkage for constants.
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kSizeUnaligned;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kSize;
-template <typename T, int N>
-const unsigned int InlinedVector<T, N>::kSentinel;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kFit1;
-template <typename T, int N>
-const size_t InlinedVector<T, N>::kFit;
-
-template <typename T, int N>
-inline void swap(InlinedVector<T, N>& a, InlinedVector<T, N>& b) {
- a.swap(b);
-}
-
-template <typename T, int N>
-inline bool operator==(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin());
-}
-
-template <typename T, int N>
-inline bool operator!=(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return !(a == b);
-}
-
-template <typename T, int N>
-inline bool operator<(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return std::lexicographical_compare(a.begin(), a.end(), b.begin(), b.end());
-}
-
-template <typename T, int N>
-inline bool operator>(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return b < a;
-}
-
-template <typename T, int N>
-inline bool operator<=(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return !(b < a);
-}
-
-template <typename T, int N>
-inline bool operator>=(const InlinedVector<T, N>& a,
- const InlinedVector<T, N>& b) {
- return !(a < b);
-}
-
-// ========================================
-// Implementation
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector() {
- InitRep();
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(size_t n) {
- InitRep();
- if (n > capacity()) {
- Grow<Nop>(n); // Must use Nop in case T is not copyable
- }
- set_size_internal(n);
- ValueInit(nullptr, n, data());
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(size_t n, const value_type& elem) {
- InitRep();
- if (n > capacity()) {
- Grow<Nop>(n); // Can use Nop since we know we have nothing to copy
- }
- set_size_internal(n);
- Fill(&elem, n, data());
-}
-
-template <typename T, int N>
-inline InlinedVector<T, N>::InlinedVector(const InlinedVector& v) {
- InitRep();
- *this = v;
-}
-
-template <typename T, int N>
-typename InlinedVector<T, N>::iterator InlinedVector<T, N>::insert(
- iterator pos, const value_type& v) {
- DCHECK_GE(pos, begin());
- DCHECK_LE(pos, end());
- if (pos == end()) {
- push_back(v);
- return end() - 1;
- }
- size_t s = size();
- size_t idx = std::distance(begin(), pos);
- if (s == capacity()) {
- Grow<Move>(s + 1);
- }
- CHECK_LT(s, capacity());
- pos = begin() + idx; // Reset 'pos' into a post-enlarge iterator.
- Fill(data() + s - 1, 1, data() + s); // data[s] = data[s-1]
- std::copy_backward(pos, data() + s - 1, data() + s);
- *pos = v;
-
- set_size_internal(s + 1);
- return pos;
-}
-
-template <typename T, int N>
-typename InlinedVector<T, N>::iterator InlinedVector<T, N>::erase(
- iterator first, iterator last) {
- DCHECK_LE(begin(), first);
- DCHECK_LE(first, last);
- DCHECK_LE(last, end());
-
- size_t s = size();
- ptrdiff_t erase_gap = std::distance(first, last);
- std::copy(last, data() + s, first);
- Destroy(data() + s - erase_gap, erase_gap);
- set_size_internal(s - erase_gap);
- return first;
-}
-
-template <typename T, int N>
-void InlinedVector<T, N>::swap(InlinedVector& other) {
- using std::swap; // Augment ADL with std::swap.
- if (&other == this) {
- return;
- }
-
- InlinedVector* a = this;
- InlinedVector* b = &other;
-
- const bool a_inline = a->is_inline();
- const bool b_inline = b->is_inline();
-
- if (!a_inline && !b_inline) {
- // Just swap the top-level representations.
- T* aptr = a->outofline_pointer();
- T* bptr = b->outofline_pointer();
- a->set_outofline_pointer(bptr);
- b->set_outofline_pointer(aptr);
-
- uint64_t aword = a->outofline_word();
- uint64_t bword = b->outofline_word();
- a->set_outofline_word(bword);
- b->set_outofline_word(aword);
- return;
- }
-
- // Make a the larger of the two to reduce number of cases.
- size_t a_size = a->size();
- size_t b_size = b->size();
- if (a->size() < b->size()) {
- swap(a, b);
- swap(a_size, b_size);
- }
- DCHECK_GE(a_size, b_size);
-
- if (b->capacity() < a_size) {
- b->Grow<Move>(a_size);
- }
-
- // One is inline and one is not.
- // 'a' is larger. Swap the elements up to the smaller array size.
- std::swap_ranges(a->data(), a->data() + b_size, b->data());
- std::uninitialized_copy(a->data() + b_size, a->data() + a_size,
- b->data() + b_size);
- Destroy(a->data() + b_size, a_size - b_size);
- a->set_size_internal(b_size);
- b->set_size_internal(a_size);
- DCHECK_EQ(b->size(), a_size);
- DCHECK_EQ(a->size(), b_size);
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last,
- std::input_iterator_tag) {
- std::copy(first, last, std::back_inserter(*this));
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last,
- std::forward_iterator_tag) {
- typedef typename std::iterator_traits<Iter>::difference_type Length;
- Length length = std::distance(first, last);
- size_t s = size();
- reserve(s + length);
- std::uninitialized_copy_n(first, length, data() + s);
- set_size_internal(s + length);
-}
-
-template <typename T, int N>
-template <typename Iter>
-inline void InlinedVector<T, N>::AppendRange(Iter first, Iter last) {
- typedef typename std::iterator_traits<Iter>::iterator_category IterTag;
- AppendRange(first, last, IterTag());
-}
+using absl::InlinedVector;
} // namespace gtl
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_GTL_INLINED_VECTOR_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_INLINED_VECTOR_H_
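(Aside, not part of the diff: gtl::InlinedVector<T, N> is now absl::InlinedVector<T, N>, which keeps the same std::vector-like surface while storing the first N elements inline. A minimal usage sketch under that assumption:)

    #include <cassert>

    #include "tensorflow/core/lib/gtl/inlined_vector.h"

    void InlinedVectorExample() {
      // The first 4 elements live in inline storage; growing past 4 allocates.
      tensorflow::gtl::InlinedVector<int, 4> v;
      for (int i = 0; i < 6; ++i) v.push_back(i);
      assert(v.size() == 6);
      assert(v.capacity() >= 6);
      assert(v[5] == 5);
    }
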
diff --git a/tensorflow/core/lib/gtl/inlined_vector_test.cc b/tensorflow/core/lib/gtl/inlined_vector_test.cc
deleted file mode 100644
index 2721885c4a..0000000000
--- a/tensorflow/core/lib/gtl/inlined_vector_test.cc
+++ /dev/null
@@ -1,898 +0,0 @@
-/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
-
-#include <list>
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/test_benchmark.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-
-typedef tensorflow::gtl::InlinedVector<int, 8> IntVec;
-
-// A type that counts number of live occurrences of the type
-static int64 instances = 0;
-class Instance {
- public:
- int value_;
- explicit Instance(int x) : value_(x) { instances++; }
- Instance(const Instance& x) : value_(x.value_) { instances++; }
- ~Instance() { instances--; }
-
- friend inline void swap(Instance& a, Instance& b) {
- using std::swap;
- swap(a.value_, b.value_);
- }
-
- friend std::ostream& operator<<(std::ostream& o, const Instance& v) {
- return o << "[value:" << v.value_ << "]";
- }
-};
-
-typedef tensorflow::gtl::InlinedVector<Instance, 8> InstanceVec;
-
-// A simple reference counted class to make sure that the proper elements are
-// destroyed in the erase(begin, end) test.
-class RefCounted {
- public:
- RefCounted(int value, int* count) : value_(value), count_(count) { Ref(); }
-
- RefCounted(const RefCounted& v) : value_(v.value_), count_(v.count_) {
- VLOG(5) << "[RefCounted: copy"
- << " from count @" << v.count_ << "]";
- Ref();
- }
-
- ~RefCounted() {
- Unref();
- count_ = nullptr;
- }
-
- friend void swap(RefCounted& a, RefCounted& b) {
- using std::swap;
- swap(a.value_, b.value_);
- swap(a.count_, b.count_);
- }
-
- RefCounted& operator=(RefCounted v) {
- using std::swap;
- swap(*this, v);
- return *this;
- }
-
- void Ref() const {
- CHECK(count_ != nullptr);
- ++(*count_);
- VLOG(5) << "[Ref: refcount " << *count_ << " on count @" << count_ << "]";
- }
-
- void Unref() const {
- --(*count_);
- CHECK_GE(*count_, 0);
- VLOG(5) << "[Unref: refcount " << *count_ << " on count @" << count_ << "]";
- }
-
- int count() const { return *count_; }
-
- friend std::ostream& operator<<(std::ostream& o, const RefCounted& v) {
- return o << "[value:" << v.value_ << ", count:" << *v.count_ << "]";
- }
-
- int value_;
- int* count_;
-};
-
-typedef tensorflow::gtl::InlinedVector<RefCounted, 8> RefCountedVec;
-
-// A class with a vtable pointer
-class Dynamic {
- public:
- virtual ~Dynamic() {}
-
- friend std::ostream& operator<<(std::ostream& o, const Dynamic& v) {
- return o << "[Dynamic]";
- }
-};
-
-typedef tensorflow::gtl::InlinedVector<Dynamic, 8> DynamicVec;
-
-// Append 0..len-1 to *v
-static void Fill(IntVec* v, int len, int offset = 0) {
- for (int i = 0; i < len; i++) {
- v->push_back(i + offset);
- }
-}
-
-static IntVec Fill(int len, int offset = 0) {
- IntVec v;
- Fill(&v, len, offset);
- return v;
-}
-
-TEST(IntVec, SimpleOps) {
- for (int len = 0; len < 20; len++) {
- IntVec v;
- const IntVec& cv = v; // const alias
-
- Fill(&v, len);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
-
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(i, v[i]);
- }
- EXPECT_EQ(v.begin(), v.data());
- EXPECT_EQ(cv.begin(), cv.data());
-
- int counter = 0;
- for (IntVec::iterator iter = v.begin(); iter != v.end(); ++iter) {
- EXPECT_EQ(counter, *iter);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- counter = 0;
- for (IntVec::const_iterator iter = v.begin(); iter != v.end(); ++iter) {
- EXPECT_EQ(counter, *iter);
- counter++;
- }
- EXPECT_EQ(counter, len);
-
- if (len > 0) {
- EXPECT_EQ(0, v.front());
- EXPECT_EQ(len - 1, v.back());
- v.pop_back();
- EXPECT_EQ(len - 1, v.size());
- for (size_t i = 0; i < v.size(); ++i) {
- EXPECT_EQ(i, v[i]);
- }
- }
- }
-}
-
-TEST(IntVec, Erase) {
- for (int len = 1; len < 20; len++) {
- for (int i = 0; i < len; ++i) {
- IntVec v;
- Fill(&v, len);
- v.erase(v.begin() + i);
- EXPECT_EQ(len - 1, v.size());
- for (int j = 0; j < i; ++j) {
- EXPECT_EQ(j, v[j]);
- }
- for (int j = i; j < len - 1; ++j) {
- EXPECT_EQ(j + 1, v[j]);
- }
- }
- }
-}
-
-// At the end of this test loop, the elements between [erase_begin, erase_end)
-// should have reference counts == 0, and all other elements should have
-// reference counts == 1.
-TEST(RefCountedVec, EraseBeginEnd) {
- for (int len = 1; len < 20; ++len) {
- for (int erase_begin = 0; erase_begin < len; ++erase_begin) {
- for (int erase_end = erase_begin; erase_end <= len; ++erase_end) {
- std::vector<int> counts(len, 0);
- RefCountedVec v;
- for (int i = 0; i < len; ++i) {
- v.push_back(RefCounted(i, &counts[i]));
- }
-
- int erase_len = erase_end - erase_begin;
-
- v.erase(v.begin() + erase_begin, v.begin() + erase_end);
-
- EXPECT_EQ(len - erase_len, v.size());
-
- // Check the elements before the first element erased.
- for (int i = 0; i < erase_begin; ++i) {
- EXPECT_EQ(i, v[i].value_);
- }
-
- // Check the elements after the first element erased.
- for (size_t i = erase_begin; i < v.size(); ++i) {
- EXPECT_EQ(i + erase_len, v[i].value_);
- }
-
- // Check that the elements at the beginning are preserved.
- for (int i = 0; i < erase_begin; ++i) {
- EXPECT_EQ(1, counts[i]);
- }
-
- // Check that the erased elements are destroyed
- for (int i = erase_begin; i < erase_end; ++i) {
- EXPECT_EQ(0, counts[i]);
- }
-
- // Check that the elements at the end are preserved.
- for (int i = erase_end; i < len; ++i) {
- EXPECT_EQ(1, counts[i]);
- }
- }
- }
- }
-}
-
-struct NoDefaultCtor {
- explicit NoDefaultCtor(int) {}
-};
-struct NoCopy {
- NoCopy() {}
- NoCopy(const NoCopy&) = delete;
-};
-struct NoAssign {
- NoAssign() {}
- NoAssign& operator=(const NoAssign&) = delete;
-};
-struct MoveOnly {
- MoveOnly() {}
- MoveOnly(MoveOnly&&) = default;
- MoveOnly& operator=(MoveOnly&&) = default;
-};
-TEST(InlinedVectorTest, NoDefaultCtor) {
- tensorflow::gtl::InlinedVector<NoDefaultCtor, 1> v(10, NoDefaultCtor(2));
- (void)v;
-}
-TEST(InlinedVectorTest, NoCopy) {
- tensorflow::gtl::InlinedVector<NoCopy, 1> v(10);
- (void)v;
-}
-TEST(InlinedVectorTest, NoAssign) {
- tensorflow::gtl::InlinedVector<NoAssign, 1> v(10);
- (void)v;
-}
-TEST(InlinedVectorTest, MoveOnly) {
- gtl::InlinedVector<MoveOnly, 2> v;
- v.push_back(MoveOnly{});
- v.push_back(MoveOnly{});
- v.push_back(MoveOnly{});
-}
-
-TEST(IntVec, Insert) {
- for (int len = 0; len < 20; len++) {
- for (int pos = 0; pos <= len; pos++) {
- IntVec v;
- Fill(&v, len);
- v.insert(v.begin() + pos, 9999);
- EXPECT_EQ(v.size(), len + 1);
- for (int i = 0; i < pos; i++) {
- EXPECT_EQ(v[i], i);
- }
- EXPECT_EQ(v[pos], 9999);
- for (size_t i = pos + 1; i < v.size(); i++) {
- EXPECT_EQ(v[i], i - 1);
- }
- }
- }
-}
-
-TEST(RefCountedVec, InsertConstructorDestructor) {
- // Make sure the proper construction/destruction happen during insert
- // operations.
- for (int len = 0; len < 20; len++) {
- SCOPED_TRACE(len);
- for (int pos = 0; pos <= len; pos++) {
- SCOPED_TRACE(pos);
- std::vector<int> counts(len, 0);
- int inserted_count = 0;
- RefCountedVec v;
- for (int i = 0; i < len; ++i) {
- SCOPED_TRACE(i);
- v.push_back(RefCounted(i, &counts[i]));
- }
-
- for (auto elem : counts) {
- EXPECT_EQ(1, elem);
- }
-
- RefCounted insert_element(9999, &inserted_count);
- EXPECT_EQ(1, inserted_count);
- v.insert(v.begin() + pos, insert_element);
- EXPECT_EQ(2, inserted_count);
- // Check that the elements at the end are preserved.
- for (auto elem : counts) {
- EXPECT_EQ(1, elem);
- }
- EXPECT_EQ(2, inserted_count);
- }
- }
-}
-
-TEST(IntVec, Resize) {
- for (int len = 0; len < 20; len++) {
- IntVec v;
- Fill(&v, len);
-
- // Try resizing up and down by k elements
- static const int kResizeElem = 1000000;
- for (int k = 0; k < 10; k++) {
- // Enlarging resize
- v.resize(len + k, kResizeElem);
- EXPECT_EQ(len + k, v.size());
- EXPECT_LE(len + k, v.capacity());
- for (int i = 0; i < len + k; i++) {
- if (i < len) {
- EXPECT_EQ(i, v[i]);
- } else {
- EXPECT_EQ(kResizeElem, v[i]);
- }
- }
-
- // Shrinking resize
- v.resize(len, kResizeElem);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(i, v[i]);
- }
- }
- }
-}
-
-TEST(IntVec, InitWithLength) {
- for (int len = 0; len < 20; len++) {
- IntVec v(len, 7);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
- for (int i = 0; i < len; i++) {
- EXPECT_EQ(7, v[i]);
- }
- }
-}
-
-TEST(IntVec, CopyConstructorAndAssignment) {
- for (int len = 0; len < 20; len++) {
- IntVec v;
- Fill(&v, len);
- EXPECT_EQ(len, v.size());
- EXPECT_LE(len, v.capacity());
-
- IntVec v2(v);
- EXPECT_EQ(v, v2);
-
- for (int start_len = 0; start_len < 20; start_len++) {
- IntVec v3;
- Fill(&v3, start_len, 99); // Add dummy elements that should go away
- v3 = v;
- EXPECT_EQ(v, v3);
- }
- }
-}
-
-TEST(OverheadTest, Storage) {
- // Check for size overhead.
- using tensorflow::gtl::InlinedVector;
- EXPECT_EQ(2 * sizeof(int*), sizeof(InlinedVector<int*, 1>));
- EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 2>));
- EXPECT_EQ(4 * sizeof(int*), sizeof(InlinedVector<int*, 3>));
- EXPECT_EQ(6 * sizeof(int*), sizeof(InlinedVector<int*, 4>));
-
- EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 1>));
- EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 2>));
- EXPECT_EQ(2 * sizeof(char*), sizeof(InlinedVector<char, 3>));
- EXPECT_EQ(2 * sizeof(char*),
- sizeof(InlinedVector<char, 2 * sizeof(char*) - 1>));
- EXPECT_EQ(4 * sizeof(char*), sizeof(InlinedVector<char, 2 * sizeof(char*)>));
-}
-
-TEST(IntVec, Clear) {
- for (int len = 0; len < 20; len++) {
- SCOPED_TRACE(len);
- IntVec v;
- Fill(&v, len);
- v.clear();
- EXPECT_EQ(0, v.size());
- EXPECT_EQ(v.begin(), v.end());
- }
-}
-
-TEST(IntVec, Reserve) {
- for (size_t len = 0; len < 20; len++) {
- IntVec v;
- Fill(&v, len);
-
- for (size_t newlen = 0; newlen < 100; newlen++) {
- const int* start_rep = v.data();
- v.reserve(newlen);
- const int* final_rep = v.data();
- if (newlen <= len) {
- EXPECT_EQ(start_rep, final_rep);
- }
- EXPECT_LE(newlen, v.capacity());
-
- // Filling up to newlen should not change rep
- while (v.size() < newlen) {
- v.push_back(0);
- }
- EXPECT_EQ(final_rep, v.data());
- }
- }
-}
-
-template <typename T>
-static std::vector<typename T::value_type> Vec(const T& src) {
- std::vector<typename T::value_type> result;
- for (const auto& elem : src) {
- result.push_back(elem);
- }
- return result;
-}
-
-TEST(IntVec, SelfRefPushBack) {
- std::vector<string> std_v;
- tensorflow::gtl::InlinedVector<string, 4> v;
- const string s = "A quite long string to ensure heap.";
- std_v.push_back(s);
- v.push_back(s);
- for (int i = 0; i < 20; ++i) {
- EXPECT_EQ(std_v, Vec(v));
-
- v.push_back(v.back());
- std_v.push_back(std_v.back());
- }
- EXPECT_EQ(std_v, Vec(v));
-}
-
-TEST(IntVec, SelfRefPushBackWithMove) {
- std::vector<string> std_v;
- gtl::InlinedVector<string, 4> v;
- const string s = "A quite long string to ensure heap.";
- std_v.push_back(s);
- v.push_back(s);
- for (int i = 0; i < 20; ++i) {
- EXPECT_EQ(v.back(), std_v.back());
-
- v.push_back(std::move(v.back()));
- std_v.push_back(std::move(std_v.back()));
- }
- EXPECT_EQ(v.back(), std_v.back());
-}
-
-TEST(IntVec, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- SCOPED_TRACE(l1);
- for (int l2 = 0; l2 < 20; l2++) {
- SCOPED_TRACE(l2);
- IntVec a = Fill(l1, 0);
- IntVec b = Fill(l2, 100);
- {
- using std::swap;
- swap(a, b);
- }
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- SCOPED_TRACE(i);
- EXPECT_EQ(i, b[i]);
- }
- for (int i = 0; i < l2; i++) {
- SCOPED_TRACE(i);
- EXPECT_EQ(100 + i, a[i]);
- }
- }
- }
-}
-
-TEST(InstanceVec, Swap) {
- for (int l1 = 0; l1 < 20; l1++) {
- for (int l2 = 0; l2 < 20; l2++) {
- InstanceVec a, b;
- for (int i = 0; i < l1; i++) a.push_back(Instance(i));
- for (int i = 0; i < l2; i++) b.push_back(Instance(100 + i));
- EXPECT_EQ(l1 + l2, instances);
- {
- using std::swap;
- swap(a, b);
- }
- EXPECT_EQ(l1 + l2, instances);
- EXPECT_EQ(l1, b.size());
- EXPECT_EQ(l2, a.size());
- for (int i = 0; i < l1; i++) {
- EXPECT_EQ(i, b[i].value_);
- }
- for (int i = 0; i < l2; i++) {
- EXPECT_EQ(100 + i, a[i].value_);
- }
- }
- }
-}
-
-TEST(IntVec, EqualAndNotEqual) {
- IntVec a, b;
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
-
- a.push_back(3);
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- b.push_back(3);
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
-
- b.push_back(7);
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- a.push_back(6);
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- a.clear();
- b.clear();
- for (int i = 0; i < 100; i++) {
- a.push_back(i);
- b.push_back(i);
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
-
- b[i] = b[i] + 1;
- EXPECT_FALSE(a == b);
- EXPECT_TRUE(a != b);
-
- b[i] = b[i] - 1; // Back to before
- EXPECT_TRUE(a == b);
- EXPECT_FALSE(a != b);
- }
-}
-
-TEST(IntVec, RelationalOps) {
- IntVec a, b;
- EXPECT_FALSE(a < b);
- EXPECT_FALSE(b < a);
- EXPECT_FALSE(a > b);
- EXPECT_FALSE(b > a);
- EXPECT_TRUE(a <= b);
- EXPECT_TRUE(b <= a);
- EXPECT_TRUE(a >= b);
- EXPECT_TRUE(b >= a);
- b.push_back(3);
- EXPECT_TRUE(a < b);
- EXPECT_FALSE(b < a);
- EXPECT_FALSE(a > b);
- EXPECT_TRUE(b > a);
- EXPECT_TRUE(a <= b);
- EXPECT_FALSE(b <= a);
- EXPECT_FALSE(a >= b);
- EXPECT_TRUE(b >= a);
-}
-
-TEST(InstanceVec, CountConstructorsDestructors) {
- const int start = instances;
- for (int len = 0; len < 20; len++) {
- InstanceVec v;
- for (int i = 0; i < len; i++) {
- v.push_back(Instance(i));
- }
- EXPECT_EQ(start + len, instances);
-
- { // Copy constructor should create 'len' more instances.
- InstanceVec v_copy(v);
- EXPECT_EQ(start + len + len, instances);
- }
- EXPECT_EQ(start + len, instances);
-
- // Enlarging resize() must construct some objects
- v.resize(len + 10, Instance(100));
- EXPECT_EQ(start + len + 10, instances);
-
- // Shrinking resize() must destroy some objects
- v.resize(len, Instance(100));
- EXPECT_EQ(start + len, instances);
-
- // reserve() must not increase the number of initialized objects
- v.reserve(len + 1000);
- EXPECT_EQ(start + len, instances);
-
- // pop_back() and erase() must destroy one object
- if (len > 0) {
- v.pop_back();
- EXPECT_EQ(start + len - 1, instances);
- if (!v.empty()) {
- v.erase(v.begin());
- EXPECT_EQ(start + len - 2, instances);
- }
- }
- }
- EXPECT_EQ(start, instances);
-}
-
-TEST(InstanceVec, CountConstructorsDestructorsOnAssignment) {
- const int start = instances;
- for (int len = 0; len < 20; len++) {
- for (int longorshort = 0; longorshort <= 1; ++longorshort) {
- InstanceVec longer, shorter;
- for (int i = 0; i < len; i++) {
- longer.push_back(Instance(i));
- shorter.push_back(Instance(i));
- }
- longer.push_back(Instance(len));
- EXPECT_EQ(start + len + len + 1, instances);
-
- if (longorshort) {
- shorter = longer;
- EXPECT_EQ(start + (len + 1) + (len + 1), instances);
- } else {
- longer = shorter;
- EXPECT_EQ(start + len + len, instances);
- }
- }
- }
- EXPECT_EQ(start, instances);
-}
-
-TEST(RangedConstructor, SimpleType) {
- std::vector<int> source_v = {4, 5, 6, 7};
- // First try to fit in inline backing
- tensorflow::gtl::InlinedVector<int, 4> v(source_v.begin(), source_v.end());
- tensorflow::gtl::InlinedVector<int, 4> empty4;
- EXPECT_EQ(4, v.size());
- EXPECT_EQ(empty4.capacity(), v.capacity()); // Must still be inline
- EXPECT_EQ(4, v[0]);
- EXPECT_EQ(5, v[1]);
- EXPECT_EQ(6, v[2]);
- EXPECT_EQ(7, v[3]);
-
- // Now, force a re-allocate
- tensorflow::gtl::InlinedVector<int, 2> realloc_v(source_v.begin(),
- source_v.end());
- tensorflow::gtl::InlinedVector<int, 2> empty2;
- EXPECT_EQ(4, realloc_v.size());
- EXPECT_LT(empty2.capacity(), realloc_v.capacity());
- EXPECT_EQ(4, realloc_v[0]);
- EXPECT_EQ(5, realloc_v[1]);
- EXPECT_EQ(6, realloc_v[2]);
- EXPECT_EQ(7, realloc_v[3]);
-}
-
-TEST(RangedConstructor, ComplexType) {
- // We also use a list here to pass a different flavor of iterator (e.g. not
- // random-access).
- std::list<Instance> source_v = {Instance(0)};
-
- // First try to fit in inline backing
- tensorflow::gtl::InlinedVector<Instance, 1> v(source_v.begin(),
- source_v.end());
- tensorflow::gtl::InlinedVector<Instance, 1> empty1;
- EXPECT_EQ(1, v.size());
- EXPECT_EQ(empty1.capacity(), v.capacity()); // Must still be inline
- EXPECT_EQ(0, v[0].value_);
-
- std::list<Instance> source_v2 = {Instance(0), Instance(1), Instance(2),
- Instance(3)};
- // Now, force a re-allocate
- tensorflow::gtl::InlinedVector<Instance, 1> realloc_v(source_v2.begin(),
- source_v2.end());
- EXPECT_EQ(4, realloc_v.size());
- EXPECT_LT(empty1.capacity(), realloc_v.capacity());
- EXPECT_EQ(0, realloc_v[0].value_);
- EXPECT_EQ(1, realloc_v[1].value_);
- EXPECT_EQ(2, realloc_v[2].value_);
- EXPECT_EQ(3, realloc_v[3].value_);
-}
-
-TEST(RangedConstructor, ElementsAreConstructed) {
- std::vector<string> source_v = {"cat", "dog"};
-
- // Force expansion and re-allocation of v. Ensures that when the vector is
- // expanded that new elements are constructed.
- tensorflow::gtl::InlinedVector<string, 1> v(source_v.begin(), source_v.end());
- EXPECT_EQ("cat", v[0]);
- EXPECT_EQ("dog", v[1]);
-}
-
-TEST(InitializerListConstructor, SimpleTypeWithInlineBacking) {
- auto vec = tensorflow::gtl::InlinedVector<int, 3>{4, 5, 6};
- EXPECT_EQ(3, vec.size());
- EXPECT_EQ(3, vec.capacity());
- EXPECT_EQ(4, vec[0]);
- EXPECT_EQ(5, vec[1]);
- EXPECT_EQ(6, vec[2]);
-}
-
-TEST(InitializerListConstructor, SimpleTypeWithReallocationRequired) {
- auto vec = tensorflow::gtl::InlinedVector<int, 2>{4, 5, 6};
- EXPECT_EQ(3, vec.size());
- EXPECT_LE(3, vec.capacity());
- EXPECT_EQ(4, vec[0]);
- EXPECT_EQ(5, vec[1]);
- EXPECT_EQ(6, vec[2]);
-}
-
-TEST(InitializerListConstructor, DisparateTypesInList) {
- EXPECT_EQ((std::vector<int>{-7, 8}),
- Vec(tensorflow::gtl::InlinedVector<int, 2>{-7, 8ULL}));
-
- EXPECT_EQ(
- (std::vector<string>{"foo", "bar"}),
- Vec(tensorflow::gtl::InlinedVector<string, 2>{"foo", string("bar")}));
-}
-
-TEST(InitializerListConstructor, ComplexTypeWithInlineBacking) {
- tensorflow::gtl::InlinedVector<Instance, 1> empty;
- auto vec = tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0)};
- EXPECT_EQ(1, vec.size());
- EXPECT_EQ(empty.capacity(), vec.capacity());
- EXPECT_EQ(0, vec[0].value_);
-}
-
-TEST(InitializerListConstructor, ComplexTypeWithReallocationRequired) {
- auto vec =
- tensorflow::gtl::InlinedVector<Instance, 1>{Instance(0), Instance(1)};
- EXPECT_EQ(2, vec.size());
- EXPECT_LE(2, vec.capacity());
- EXPECT_EQ(0, vec[0].value_);
- EXPECT_EQ(1, vec[1].value_);
-}
-
-TEST(DynamicVec, DynamicVecCompiles) {
- DynamicVec v;
- (void)v;
-}
-
-static void BM_InlinedVectorFill(int iters, int len) {
- for (int i = 0; i < iters; i++) {
- IntVec v;
- for (int j = 0; j < len; j++) {
- v.push_back(j);
- }
- }
- testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_InlinedVectorFill)->Range(0, 1024);
-
-static void BM_InlinedVectorFillRange(int iters, int len) {
- std::unique_ptr<int[]> ia(new int[len]);
- for (int j = 0; j < len; j++) {
- ia[j] = j;
- }
- for (int i = 0; i < iters; i++) {
- IntVec TF_ATTRIBUTE_UNUSED v(ia.get(), ia.get() + len);
- }
- testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_InlinedVectorFillRange)->Range(0, 1024);
-
-static void BM_StdVectorFill(int iters, int len) {
- for (int i = 0; i < iters; i++) {
- std::vector<int> v;
- v.reserve(len);
- for (int j = 0; j < len; j++) {
- v.push_back(j);
- }
- }
- testing::BytesProcessed((int64{iters} * len) * sizeof(int));
-}
-BENCHMARK(BM_StdVectorFill)->Range(0, 1024);
-
-bool StringRepresentedInline(string s) {
- const char* chars = s.data();
- string s1 = std::move(s);
- return s1.data() != chars;
-}
-
-static void BM_InlinedVectorFillString(int iters, int len) {
- string strings[4] = {"a quite long string", "another long string",
- "012345678901234567", "to cause allocation"};
- for (int i = 0; i < iters; i++) {
- gtl::InlinedVector<string, 8> v;
- for (int j = 0; j < len; j++) {
- v.push_back(strings[j & 3]);
- }
- }
- testing::ItemsProcessed(int64{iters} * len);
-}
-BENCHMARK(BM_InlinedVectorFillString)->Range(0, 1024);
-
-static void BM_StdVectorFillString(int iters, int len) {
- string strings[4] = {"a quite long string", "another long string",
- "012345678901234567", "to cause allocation"};
- for (int i = 0; i < iters; i++) {
- std::vector<string> v;
- v.reserve(len);
- for (int j = 0; j < len; j++) {
- v.push_back(strings[j & 3]);
- }
- }
- testing::ItemsProcessed(int64{iters} * len);
- // The purpose of the benchmark is to verify that inlined vector is
- // efficient when moving is more efficient than copying. To do so, we
- // use strings that are larger than the small string optimization.
- CHECK(!StringRepresentedInline(strings[0]));
-}
-BENCHMARK(BM_StdVectorFillString)->Range(0, 1024);
-
-namespace {
-struct Buffer { // some arbitrary structure for benchmarking.
- char* base;
- int length;
- int capacity;
- void* user_data;
-};
-} // anonymous namespace
-
-static void BM_InlinedVectorTenAssignments(int iters, int len) {
- typedef tensorflow::gtl::InlinedVector<Buffer, 2> BufferVec;
-
- BufferVec src;
- src.resize(len);
-
- iters *= 10;
- BufferVec dst;
- for (int i = 0; i < iters; i++) {
- dst = src;
- }
-}
-BENCHMARK(BM_InlinedVectorTenAssignments)
- ->Arg(0)
- ->Arg(1)
- ->Arg(2)
- ->Arg(3)
- ->Arg(4)
- ->Arg(20);
-
-static void BM_CreateFromInitializerList(int iters) {
- for (; iters > 0; iters--) {
- tensorflow::gtl::InlinedVector<int, 4> x{1, 2, 3};
- (void)x[0];
- }
-}
-BENCHMARK(BM_CreateFromInitializerList);
-
-namespace {
-
-struct LargeSwappable {
- LargeSwappable() : d_(1024, 17) {}
- ~LargeSwappable() {}
- LargeSwappable(const LargeSwappable& o) : d_(o.d_) {}
-
- friend void swap(LargeSwappable& a, LargeSwappable& b) {
- using std::swap;
- swap(a.d_, b.d_);
- }
-
- LargeSwappable& operator=(LargeSwappable o) {
- using std::swap;
- swap(*this, o);
- return *this;
- }
-
- std::vector<int> d_;
-};
-
-} // namespace
-
-static void BM_LargeSwappableElements(int iters, int len) {
- typedef tensorflow::gtl::InlinedVector<LargeSwappable, 32> Vec;
- Vec a(len);
- Vec b;
- while (--iters >= 0) {
- using std::swap;
- swap(a, b);
- }
-}
-BENCHMARK(BM_LargeSwappableElements)->Range(0, 1024);
-
-} // namespace tensorflow
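With the local test file gone, coverage for this container presumably lives upstream in Abseil; if a local spot check against the alias were still wanted, it could look roughly like the following (hypothetical test, reusing the platform test macros the deleted file already depended on):

#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

TEST(GtlInlinedVectorAlias, PushBackAndIndex) {
  // Same shape as the IntVec typedef the deleted tests used.
  gtl::InlinedVector<int, 8> v;
  for (int i = 0; i < 4; ++i) v.push_back(i);
  EXPECT_EQ(4, v.size());
  EXPECT_EQ(2, v[2]);
}

}  // namespace
}  // namespace tensorflow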
diff --git a/tensorflow/core/lib/gtl/optional.h b/tensorflow/core/lib/gtl/optional.h
index 4ee3f88d18..238aa18e1e 100644
--- a/tensorflow/core/lib/gtl/optional.h
+++ b/tensorflow/core/lib/gtl/optional.h
@@ -13,864 +13,21 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_GTL_OPTIONAL_H_
-#define TENSORFLOW_LIB_GTL_OPTIONAL_H_
+#ifndef TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_
+#define TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_
-#include <assert.h>
-#include <functional>
-#include <initializer_list>
-#include <type_traits>
-#include <utility>
-
-#include "tensorflow/core/platform/logging.h"
+#include "absl/types/optional.h"
namespace tensorflow {
namespace gtl {
-// A value of type gtl::optional<T> holds either a value of T or an
-// "empty" value. When it holds a value of T, it stores it as a direct
-// subobject, so sizeof(optional<T>) is approximately sizeof(T)+1. The interface
-// is based on the upcoming std::optional<T>, and gtl::optional<T> is
-// designed to be cheaply drop-in replaceable by std::optional<T>, once it is
-// rolled out.
-//
-// This implementation is based on the specification in the latest draft as of
-// 2017-01-05, section 20.6.
-//
-// Differences between gtl::optional<T> and std::optional<T> include:
-// - constexpr not used for nonconst member functions.
-// (dependency on some differences between C++11 and C++14.)
-// - nullopt and in_place are not constexpr. We need the inline variable
-// support in C++17 for external linkage.
-// - CHECK instead of throwing std::bad_optional_access.
-// - optional::swap() and swap() relies on std::is_(nothrow_)swappable
-// which is introduced in C++17. So we assume is_swappable is always true
-// and is_nothrow_swappable is same as std::is_trivial.
-// - make_optional cannot be constexpr due to absence of guaranteed copy
-// elision.
-//
-// Synopsis:
-//
-// #include "tensorflow/core/lib/gtl/optional.h"
-//
-// tensorflow::gtl::optional<string> f() {
-// string result;
-// if (...) {
-// ...
-// result = ...;
-// return result;
-// } else {
-// ...
-// return tensorflow::gtl::nullopt;
-// }
-// }
-//
-// int main() {
-// tensorflow::gtl::optional<string> optstr = f();
-// if (optstr) {
-// // non-empty
-// print(optstr.value());
-// } else {
-// // empty
-// error();
-// }
-// }
-template <typename T>
-class optional;
-
-// The tag constant `in_place` is used as the first parameter of an optional<T>
-// constructor to indicate that the remaining arguments should be forwarded
-// to the underlying T constructor.
-struct in_place_t {};
-extern const in_place_t in_place;
-
-// The tag constant `nullopt` is used to indicate an empty optional<T> in
-// certain functions, such as construction or assignment.
-struct nullopt_t {
- struct init_t {};
- static init_t init;
- // It must not be default-constructible to avoid ambiguity for opt = {}.
- // Note the non-const reference, it is to eliminate ambiguity for code like:
- // struct S { int value; };
- //
- // void Test() {
- // optional<S> opt;
- // opt = {{}};
- // }
- explicit constexpr nullopt_t(init_t& /*unused*/) {} // NOLINT
-};
-extern const nullopt_t nullopt;
-
-namespace internal_optional {
-
-// define forward locally because std::forward is not constexpr until C++14
-template <typename T>
-constexpr T&& forward(typename std::remove_reference<T>::type&
- t) noexcept { // NOLINT(runtime/references)
- return static_cast<T&&>(t);
-}
-
-struct empty_struct {};
-// This class stores the data in optional<T>.
-// It is specialized based on whether T is trivially destructible.
-// This is the specialization for non trivially destructible type.
-template <typename T, bool = std::is_trivially_destructible<T>::value>
-class optional_data_dtor_base {
- protected:
- // Whether there is data or not.
- bool engaged_;
- // data storage
- union {
- empty_struct dummy_;
- T data_;
- };
-
- void destruct() noexcept {
- if (engaged_) {
- data_.~T();
- engaged_ = false;
- }
- }
-
- // dummy_ must be initialized for constexpr constructor
- constexpr optional_data_dtor_base() noexcept : engaged_(false), dummy_{} {}
-
- template <typename... Args>
- constexpr explicit optional_data_dtor_base(in_place_t, Args&&... args)
- : engaged_(true), data_(internal_optional::forward<Args>(args)...) {}
-
- ~optional_data_dtor_base() { destruct(); }
-};
-
-// Specialization for trivially destructible type.
-template <typename T>
-class optional_data_dtor_base<T, true> {
- protected:
- // Whether there is data or not.
- bool engaged_;
- // data storage
- union {
- empty_struct dummy_;
- T data_;
- };
- void destruct() noexcept { engaged_ = false; }
-
- // dummy_ must be initialized for constexpr constructor
- constexpr optional_data_dtor_base() noexcept : engaged_(false), dummy_{} {}
-
- template <typename... Args>
- constexpr explicit optional_data_dtor_base(in_place_t, Args&&... args)
- : engaged_(true), data_(internal_optional::forward<Args>(args)...) {}
-
- ~optional_data_dtor_base() = default;
-};
-
-template <typename T>
-class optional_data : public optional_data_dtor_base<T> {
- protected:
- using base = optional_data_dtor_base<T>;
- using base::base;
-
- T* pointer() { return &this->data_; }
-
- constexpr const T* pointer() const { return &this->data_; }
-
- template <typename... Args>
- void construct(Args&&... args) {
- new (pointer()) T(std::forward<Args>(args)...);
- this->engaged_ = true;
- }
-
- template <typename U>
- void assign(U&& u) {
- if (this->engaged_) {
- this->data_ = std::forward<U>(u);
- } else {
- construct(std::forward<U>(u));
- }
- }
-
- optional_data() = default;
-
- optional_data(const optional_data& rhs) {
- if (rhs.engaged_) {
- construct(rhs.data_);
- }
- }
-
- optional_data(optional_data&& rhs) noexcept(
- std::is_nothrow_move_constructible<T>::value) {
- if (rhs.engaged_) {
- construct(std::move(rhs.data_));
- }
- }
-
- optional_data& operator=(const optional_data& rhs) {
- if (rhs.engaged_) {
- assign(rhs.data_);
- } else {
- this->destruct();
- }
- return *this;
- }
-
- optional_data& operator=(optional_data&& rhs) noexcept(
- std::is_nothrow_move_assignable<T>::value&&
- std::is_nothrow_move_constructible<T>::value) {
- if (rhs.engaged_) {
- assign(std::move(rhs.data_));
- } else {
- this->destruct();
- }
- return *this;
- }
-};
-
-// ordered by level of restriction, from low to high.
-// copyable implies movable.
-enum class copy_traits { copyable = 0, movable = 1, non_movable = 2 };
-
-// base class for enabling/disabling copy/move constructor.
-template <copy_traits>
-class optional_ctor_base;
-
-template <>
-class optional_ctor_base<copy_traits::copyable> {
- public:
- constexpr optional_ctor_base() = default;
- optional_ctor_base(const optional_ctor_base&) = default;
- optional_ctor_base(optional_ctor_base&&) = default;
- optional_ctor_base& operator=(const optional_ctor_base&) = default;
- optional_ctor_base& operator=(optional_ctor_base&&) = default;
-};
-
-template <>
-class optional_ctor_base<copy_traits::movable> {
- public:
- constexpr optional_ctor_base() = default;
- optional_ctor_base(const optional_ctor_base&) = delete;
- optional_ctor_base(optional_ctor_base&&) = default;
- optional_ctor_base& operator=(const optional_ctor_base&) = default;
- optional_ctor_base& operator=(optional_ctor_base&&) = default;
-};
-
-template <>
-class optional_ctor_base<copy_traits::non_movable> {
- public:
- constexpr optional_ctor_base() = default;
- optional_ctor_base(const optional_ctor_base&) = delete;
- optional_ctor_base(optional_ctor_base&&) = delete;
- optional_ctor_base& operator=(const optional_ctor_base&) = default;
- optional_ctor_base& operator=(optional_ctor_base&&) = default;
-};
-
-// base class for enabling/disabling copy/move assignment.
-template <copy_traits>
-class optional_assign_base;
-
-template <>
-class optional_assign_base<copy_traits::copyable> {
- public:
- constexpr optional_assign_base() = default;
- optional_assign_base(const optional_assign_base&) = default;
- optional_assign_base(optional_assign_base&&) = default;
- optional_assign_base& operator=(const optional_assign_base&) = default;
- optional_assign_base& operator=(optional_assign_base&&) = default;
-};
-
-template <>
-class optional_assign_base<copy_traits::movable> {
- public:
- constexpr optional_assign_base() = default;
- optional_assign_base(const optional_assign_base&) = default;
- optional_assign_base(optional_assign_base&&) = default;
- optional_assign_base& operator=(const optional_assign_base&) = delete;
- optional_assign_base& operator=(optional_assign_base&&) = default;
-};
-
-template <>
-class optional_assign_base<copy_traits::non_movable> {
- public:
- constexpr optional_assign_base() = default;
- optional_assign_base(const optional_assign_base&) = default;
- optional_assign_base(optional_assign_base&&) = default;
- optional_assign_base& operator=(const optional_assign_base&) = delete;
- optional_assign_base& operator=(optional_assign_base&&) = delete;
-};
-
+// Deprecated: please use absl::optional directly.
+using absl::make_optional;
+using absl::nullopt;
template <typename T>
-constexpr copy_traits get_ctor_copy_traits() {
- return std::is_copy_constructible<T>::value
- ? copy_traits::copyable
- : std::is_move_constructible<T>::value ? copy_traits::movable
- : copy_traits::non_movable;
-}
-
-template <typename T>
-constexpr copy_traits get_assign_copy_traits() {
- return std::is_copy_assignable<T>::value &&
- std::is_copy_constructible<T>::value
- ? copy_traits::copyable
- : std::is_move_assignable<T>::value &&
- std::is_move_constructible<T>::value
- ? copy_traits::movable
- : copy_traits::non_movable;
-}
-
-// Whether T is constructible or convertible from optional<U>.
-template <typename T, typename U>
-struct is_constructible_convertible_from_optional
- : std::integral_constant<
- bool, std::is_constructible<T, optional<U>&>::value ||
- std::is_constructible<T, optional<U>&&>::value ||
- std::is_constructible<T, const optional<U>&>::value ||
- std::is_constructible<T, const optional<U>&&>::value ||
- std::is_convertible<optional<U>&, T>::value ||
- std::is_convertible<optional<U>&&, T>::value ||
- std::is_convertible<const optional<U>&, T>::value ||
- std::is_convertible<const optional<U>&&, T>::value> {};
-
-// Whether T is constructible or convertible or assignable from optional<U>.
-template <typename T, typename U>
-struct is_constructible_convertible_assignable_from_optional
- : std::integral_constant<
- bool, is_constructible_convertible_from_optional<T, U>::value ||
- std::is_assignable<T&, optional<U>&>::value ||
- std::is_assignable<T&, optional<U>&&>::value ||
- std::is_assignable<T&, const optional<U>&>::value ||
- std::is_assignable<T&, const optional<U>&&>::value> {};
-
-} // namespace internal_optional
-
-template <typename T>
-class optional : private internal_optional::optional_data<T>,
- private internal_optional::optional_ctor_base<
- internal_optional::get_ctor_copy_traits<T>()>,
- private internal_optional::optional_assign_base<
- internal_optional::get_assign_copy_traits<T>()> {
- using data_base = internal_optional::optional_data<T>;
-
- public:
- typedef T value_type;
-
- // [optional.ctor], constructors
-
- // A default constructed optional holds the empty value, NOT a default
- // constructed T.
- constexpr optional() noexcept {}
-
- // An optional initialized with `nullopt` holds the empty value.
- constexpr optional(nullopt_t) noexcept {} // NOLINT(runtime/explicit)
-
- // Copy constructor, standard semantics.
- optional(const optional& src) = default;
-
- // Move constructor, standard semantics.
- optional(optional&& src) = default;
-
- // optional<T>(in_place, arg1, arg2, arg3) constructs a non-empty optional
- // with an in-place constructed value of T(arg1,arg2,arg3).
- // TODO(b/34201852): Add std::is_constructible<T, Args&&...> SFINAE.
- template <typename... Args>
- constexpr explicit optional(in_place_t, Args&&... args)
- : data_base(in_place_t(), internal_optional::forward<Args>(args)...) {}
-
- // optional<T>(in_place, {arg1, arg2, arg3}) constructs a non-empty optional
- // with an in-place list-initialized value of T({arg1, arg2, arg3}).
- template <typename U, typename... Args,
- typename = typename std::enable_if<std::is_constructible<
- T, std::initializer_list<U>&, Args&&...>::value>::type>
- constexpr explicit optional(in_place_t, std::initializer_list<U> il,
- Args&&... args)
- : data_base(in_place_t(), il, internal_optional::forward<Args>(args)...) {
- }
-
- template <
- typename U = T,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !std::is_same<in_place_t, typename std::decay<U>::type>::value &&
- !std::is_same<optional<T>, typename std::decay<U>::type>::value &&
- std::is_convertible<U&&, T>::value,
- bool>::type = false>
- constexpr optional(U&& v) // NOLINT
- : data_base(in_place_t(), internal_optional::forward<U>(v)) {}
-
- template <
- typename U = T,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !std::is_same<in_place_t, typename std::decay<U>::type>::value &&
- !std::is_same<optional<T>, typename std::decay<U>::type>::value &&
- !std::is_convertible<U&&, T>::value,
- bool>::type = false>
- explicit constexpr optional(U&& v)
- : data_base(in_place_t(), internal_optional::forward<U>(v)) {}
-
- // Converting copy constructor (implicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, const U&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- std::is_convertible<const U&, T>::value,
- bool>::type = false>
- optional(const optional<U>& rhs) { // NOLINT
- if (rhs) {
- this->construct(*rhs);
- }
- }
-
- // Converting copy constructor (explicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, const U&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- !std::is_convertible<const U&, T>::value,
- bool>::type = false>
- explicit optional(const optional<U>& rhs) {
- if (rhs) {
- this->construct(*rhs);
- }
- }
-
- // Converting move constructor (implicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- std::is_convertible<U&&, T>::value,
- bool>::type = false>
- optional(optional<U>&& rhs) { // NOLINT
- if (rhs) {
- this->construct(std::move(*rhs));
- }
- }
-
- // Converting move constructor (explicit)
- template <
- typename U,
- typename std::enable_if<
- std::is_constructible<T, U&&>::value &&
- !internal_optional::is_constructible_convertible_from_optional<
- T, U>::value &&
- !std::is_convertible<U&&, T>::value,
- bool>::type = false>
- explicit optional(optional<U>&& rhs) {
- if (rhs) {
- this->construct(std::move(*rhs));
- }
- }
-
- // [optional.dtor], destructor, trivial if T is trivially destructible.
- ~optional() = default;
-
- // [optional.assign], assignment
-
- // Assignment from nullopt: opt = nullopt
- optional& operator=(nullopt_t) noexcept {
- this->destruct();
- return *this;
- }
-
- // Copy assignment, standard semantics.
- optional& operator=(const optional& src) = default;
-
- // Move assignment, standard semantics.
- optional& operator=(optional&& src) = default;
-
- // Value assignment
- template <
- typename U = T,
- typename = typename std::enable_if<
- !std::is_same<optional<T>, typename std::decay<U>::type>::value &&
- (!std::is_scalar<T>::value ||
- !std::is_same<T, typename std::decay<U>::type>::value) &&
- std::is_constructible<T, U>::value &&
- std::is_assignable<T&, U>::value>::type>
- optional& operator=(U&& v) {
- this->assign(std::forward<U>(v));
- return *this;
- }
-
- template <typename U,
- typename = typename std::enable_if<
- std::is_constructible<T, const U&>::value &&
- std::is_assignable<T&, const U&>::value &&
- !internal_optional::
- is_constructible_convertible_assignable_from_optional<
- T, U>::value>::type>
- optional& operator=(const optional<U>& rhs) {
- if (rhs) {
- this->assign(*rhs);
- } else {
- this->destruct();
- }
- return *this;
- }
-
- template <typename U,
- typename = typename std::enable_if<
- std::is_constructible<T, U>::value &&
- std::is_assignable<T&, U>::value &&
- !internal_optional::
- is_constructible_convertible_assignable_from_optional<
- T, U>::value>::type>
- optional& operator=(optional<U>&& rhs) {
- if (rhs) {
- this->assign(std::move(*rhs));
- } else {
- this->destruct();
- }
- return *this;
- }
-
- // [optional.mod], modifiers
- // Destroys the inner T value if one is present.
- void reset() noexcept { this->destruct(); }
-
- // Emplace reconstruction. (Re)constructs the underlying T in-place with the
- // given arguments forwarded:
- //
- // optional<Foo> opt;
- // opt.emplace(arg1,arg2,arg3); (Constructs Foo(arg1,arg2,arg3))
- //
- // If the optional is non-empty, and the `args` refer to subobjects of the
- // current object, then behavior is undefined. This is because the current
- // object will be destructed before the new object is constructed with `args`.
- //
- template <typename... Args,
- typename = typename std::enable_if<
- std::is_constructible<T, Args&&...>::value>::type>
- void emplace(Args&&... args) {
- this->destruct();
- this->construct(std::forward<Args>(args)...);
- }
-
- // Emplace reconstruction with initializer-list. See immediately above.
- template <class U, class... Args,
- typename = typename std::enable_if<std::is_constructible<
- T, std::initializer_list<U>&, Args&&...>::value>::type>
- void emplace(std::initializer_list<U> il, Args&&... args) {
- this->destruct();
- this->construct(il, std::forward<Args>(args)...);
- }
-
- // [optional.swap], swap
- // Swap, standard semantics.
- void swap(optional& rhs) noexcept(
- std::is_nothrow_move_constructible<T>::value&&
- std::is_trivial<T>::value) {
- if (*this) {
- if (rhs) {
- using std::swap;
- swap(**this, *rhs);
- } else {
- rhs.construct(std::move(**this));
- this->destruct();
- }
- } else {
- if (rhs) {
- this->construct(std::move(*rhs));
- rhs.destruct();
- } else {
- // no effect (swap(disengaged, disengaged))
- }
- }
- }
-
- // [optional.observe], observers
- // You may use `*opt`, and `opt->m`, to access the underlying T value and T's
- // member `m`, respectively. If the optional is empty, behavior is
- // undefined.
- constexpr const T* operator->() const { return this->pointer(); }
- T* operator->() {
- assert(this->engaged_);
- return this->pointer();
- }
- constexpr const T& operator*() const& { return reference(); }
- T& operator*() & {
- assert(this->engaged_);
- return reference();
- }
- constexpr const T&& operator*() const&& { return std::move(reference()); }
- T&& operator*() && {
- assert(this->engaged_);
- return std::move(reference());
- }
-
- // In a bool context an optional<T> will return false if and only if it is
- // empty.
- //
- // if (opt) {
- // // do something with opt.value();
- // } else {
- // // opt is empty
- // }
- //
- constexpr explicit operator bool() const noexcept { return this->engaged_; }
-
- // Returns false if and only if *this is empty.
- constexpr bool has_value() const noexcept { return this->engaged_; }
-
- // Use `opt.value()` to get a reference to underlying value. The constness
- // and lvalue/rvalue-ness of `opt` is preserved to the view of the T
- // subobject.
- const T& value() const& {
- CHECK(*this) << "Bad optional access";
- return reference();
- }
- T& value() & {
- CHECK(*this) << "Bad optional access";
- return reference();
- }
- T&& value() && { // NOLINT(build/c++11)
- CHECK(*this) << "Bad optional access";
- return std::move(reference());
- }
- const T&& value() const&& { // NOLINT(build/c++11)
- CHECK(*this) << "Bad optional access";
- return std::move(reference());
- }
-
- // Use `opt.value_or(val)` to get either the value of T or the given default
- // `val` in the empty case.
- template <class U>
- constexpr T value_or(U&& v) const& {
- return static_cast<bool>(*this) ? **this
- : static_cast<T>(std::forward<U>(v));
- }
- template <class U>
- T value_or(U&& v) && { // NOLINT(build/c++11)
- return static_cast<bool>(*this) ? std::move(**this)
- : static_cast<T>(std::forward<U>(v));
- }
-
- private:
- // Private accessors for internal storage viewed as reference to T.
- constexpr const T& reference() const { return *this->pointer(); }
- T& reference() { return *(this->pointer()); }
-
- // T constraint checks. You can't have an optional of nullopt_t, in_place_t
- // or a reference.
- static_assert(
- !std::is_same<nullopt_t, typename std::remove_cv<T>::type>::value,
- "optional<nullopt_t> is not allowed.");
- static_assert(
- !std::is_same<in_place_t, typename std::remove_cv<T>::type>::value,
- "optional<in_place_t> is not allowed.");
- static_assert(!std::is_reference<T>::value,
- "optional<reference> is not allowed.");
-};
-
-// [optional.specalg]
-// Swap, standard semantics.
-// This function shall not participate in overload resolution unless
-// is_move_constructible_v<T> is true and is_swappable_v<T> is true.
-// NOTE: we assume is_swappable is always true. There will be a compiling error
-// if T is actually not Swappable.
-template <typename T,
- typename std::enable_if<std::is_move_constructible<T>::value,
- bool>::type = false>
-void swap(optional<T>& a, optional<T>& b) noexcept(noexcept(a.swap(b))) {
- a.swap(b);
-}
-
-// NOTE: make_optional cannot be constexpr in C++11 because the copy/move
-// constructor is not constexpr and we don't have guaranteed copy elision
-// until C++17. But they are still declared constexpr for consistency with
-// the standard.
-
-// make_optional(v) creates a non-empty optional<T> where the type T is deduced
-// from v. Can also be explicitly instantiated as make_optional<T>(v).
-template <typename T>
-constexpr optional<typename std::decay<T>::type> make_optional(T&& v) {
- return optional<typename std::decay<T>::type>(std::forward<T>(v));
-}
-
-template <typename T, typename... Args>
-constexpr optional<T> make_optional(Args&&... args) {
- return optional<T>(in_place_t(), internal_optional::forward<Args>(args)...);
-}
-
-template <typename T, typename U, typename... Args>
-constexpr optional<T> make_optional(std::initializer_list<U> il,
- Args&&... args) {
- return optional<T>(in_place_t(), il,
- internal_optional::forward<Args>(args)...);
-}
-
-// Relational operators. Empty optionals are considered equal to each
-// other and less than non-empty optionals. Supports relations between
-// optional<T> and optional<T>, between optional<T> and T, and between
-// optional<T> and nullopt.
-// Note: We're careful to support T having non-bool relationals.
-
-// Relational operators [optional.relops]
-// The C++17 (N4606) "Returns:" statements are translated into code
-// in an obvious way here, and the original text retained as function docs.
-// Returns: If bool(x) != bool(y), false; otherwise if bool(x) == false, true;
-// otherwise *x == *y.
-template <class T>
-constexpr bool operator==(const optional<T>& x, const optional<T>& y) {
- return static_cast<bool>(x) != static_cast<bool>(y)
- ? false
- : static_cast<bool>(x) == false ? true : *x == *y;
-}
-// Returns: If bool(x) != bool(y), true; otherwise, if bool(x) == false, false;
-// otherwise *x != *y.
-template <class T>
-constexpr bool operator!=(const optional<T>& x, const optional<T>& y) {
- return static_cast<bool>(x) != static_cast<bool>(y)
- ? true
- : static_cast<bool>(x) == false ? false : *x != *y;
-}
-// Returns: If !y, false; otherwise, if !x, true; otherwise *x < *y.
-template <class T>
-constexpr bool operator<(const optional<T>& x, const optional<T>& y) {
- return !y ? false : !x ? true : *x < *y;
-}
-// Returns: If !x, false; otherwise, if !y, true; otherwise *x > *y.
-template <class T>
-constexpr bool operator>(const optional<T>& x, const optional<T>& y) {
- return !x ? false : !y ? true : *x > *y;
-}
-// Returns: If !x, true; otherwise, if !y, false; otherwise *x <= *y.
-template <class T>
-constexpr bool operator<=(const optional<T>& x, const optional<T>& y) {
- return !x ? true : !y ? false : *x <= *y;
-}
-// Returns: If !y, true; otherwise, if !x, false; otherwise *x >= *y.
-template <class T>
-constexpr bool operator>=(const optional<T>& x, const optional<T>& y) {
- return !y ? true : !x ? false : *x >= *y;
-}
-
-// Comparison with nullopt [optional.nullops]
-// The C++17 (N4606) "Returns:" statements are used directly here.
-template <class T>
-constexpr bool operator==(const optional<T>& x, nullopt_t) noexcept {
- return !x;
-}
-template <class T>
-constexpr bool operator==(nullopt_t, const optional<T>& x) noexcept {
- return !x;
-}
-template <class T>
-constexpr bool operator!=(const optional<T>& x, nullopt_t) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator!=(nullopt_t, const optional<T>& x) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator<(const optional<T>& x, nullopt_t) noexcept {
- return false;
-}
-template <class T>
-constexpr bool operator<(nullopt_t, const optional<T>& x) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator<=(const optional<T>& x, nullopt_t) noexcept {
- return !x;
-}
-template <class T>
-constexpr bool operator<=(nullopt_t, const optional<T>& x) noexcept {
- return true;
-}
-template <class T>
-constexpr bool operator>(const optional<T>& x, nullopt_t) noexcept {
- return static_cast<bool>(x);
-}
-template <class T>
-constexpr bool operator>(nullopt_t, const optional<T>& x) noexcept {
- return false;
-}
-template <class T>
-constexpr bool operator>=(const optional<T>& x, nullopt_t) noexcept {
- return true;
-}
-template <class T>
-constexpr bool operator>=(nullopt_t, const optional<T>& x) noexcept {
- return !x;
-}
-
-// Comparison with T [optional.comp_with_t]
-// The C++17 (N4606) "Equivalent to:" statements are used directly here.
-template <class T>
-constexpr bool operator==(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x == v : false;
-}
-template <class T>
-constexpr bool operator==(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v == *x : false;
-}
-template <class T>
-constexpr bool operator!=(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x != v : true;
-}
-template <class T>
-constexpr bool operator!=(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v != *x : true;
-}
-template <class T>
-constexpr bool operator<(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x < v : true;
-}
-template <class T>
-constexpr bool operator<(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v < *x : false;
-}
-template <class T>
-constexpr bool operator<=(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x <= v : true;
-}
-template <class T>
-constexpr bool operator<=(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v <= *x : false;
-}
-template <class T>
-constexpr bool operator>(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x > v : false;
-}
-template <class T>
-constexpr bool operator>(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v > *x : true;
-}
-template <class T>
-constexpr bool operator>=(const optional<T>& x, const T& v) {
- return static_cast<bool>(x) ? *x >= v : false;
-}
-template <class T>
-constexpr bool operator>=(const T& v, const optional<T>& x) {
- return static_cast<bool>(x) ? v >= *x : true;
-}
+using optional = absl::optional<T>;
} // namespace gtl
} // namespace tensorflow
-namespace std {
-
-// Normally std::hash specializations are not recommended in tensorflow code,
-// but we allow this as it is following a standard library component.
-template <class T>
-struct hash<::tensorflow::gtl::optional<T>> {
- size_t operator()(const ::tensorflow::gtl::optional<T>& opt) const {
- if (opt) {
- return hash<T>()(*opt);
- } else {
- return static_cast<size_t>(0x297814aaad196e6dULL);
- }
- }
-};
-
-} // namespace std
-
-#endif // TENSORFLOW_LIB_GTL_OPTIONAL_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_OPTIONAL_H_
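As with InlinedVector, this header now only re-exports Abseil names, so gtl::optional, gtl::nullopt, and gtl::make_optional keep working as spellings for their absl counterparts. A small illustrative sketch (the function and its name are hypothetical, not taken from this change):

#include <string>

#include "tensorflow/core/lib/gtl/optional.h"

// Mirrors the synopsis from the deleted header comment, but every name now
// resolves to absl::optional machinery via the using-declarations above.
tensorflow::gtl::optional<std::string> MaybeGreeting(bool ready) {
  if (ready) {
    return std::string("hello");
  }
  return tensorflow::gtl::nullopt;  // absl::nullopt re-exported by gtl
}

// Call sites test engagement exactly as before:
//   if (auto g = MaybeGreeting(true)) { use(*g); }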
diff --git a/tensorflow/core/lib/gtl/optional_test.cc b/tensorflow/core/lib/gtl/optional_test.cc
deleted file mode 100644
index 12b5bbc60b..0000000000
--- a/tensorflow/core/lib/gtl/optional_test.cc
+++ /dev/null
@@ -1,1098 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/core/lib/gtl/optional.h"
-
-#include <string>
-#include <utility>
-
-#include "tensorflow/core/platform/test.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace tensorflow {
-namespace {
-
-using tensorflow::gtl::in_place;
-using tensorflow::gtl::in_place_t;
-using tensorflow::gtl::make_optional;
-using tensorflow::gtl::nullopt;
-using tensorflow::gtl::nullopt_t;
-using tensorflow::gtl::optional;
-
-template <typename T>
-string TypeQuals(T&) {
- return "&";
-}
-template <typename T>
-string TypeQuals(T&&) {
- return "&&";
-}
-template <typename T>
-string TypeQuals(const T&) {
- return "c&";
-}
-template <typename T>
-string TypeQuals(const T&&) {
- return "c&&";
-}
-
-struct StructorListener {
- int construct0 = 0;
- int construct1 = 0;
- int construct2 = 0;
- int listinit = 0;
- int copy = 0;
- int move = 0;
- int copy_assign = 0;
- int move_assign = 0;
- int destruct = 0;
-};
-
-struct Listenable {
- static StructorListener* listener;
-
- Listenable() { ++listener->construct0; }
- Listenable(int /*unused*/) { ++listener->construct1; } // NOLINT
- Listenable(int /*unused*/, int /*unused*/) { ++listener->construct2; }
- Listenable(std::initializer_list<int> /*unused*/) { ++listener->listinit; }
- Listenable(const Listenable& /*unused*/) { ++listener->copy; }
- Listenable(Listenable&& /*unused*/) { ++listener->move; } // NOLINT
- Listenable& operator=(const Listenable& /*unused*/) {
- ++listener->copy_assign;
- return *this;
- }
- Listenable& operator=(Listenable&& /*unused*/) { // NOLINT
- ++listener->move_assign;
- return *this;
- }
- ~Listenable() { ++listener->destruct; }
-};
-
-StructorListener* Listenable::listener = nullptr;
-
-// clang on macos -- even the latest major version at time of writing (8.x) --
-// does not like much of our constexpr business. clang < 3.0 also has trouble.
-#if defined(__clang__) && defined(__APPLE__)
-#define SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
-#endif
-
-struct ConstexprType {
- constexpr ConstexprType() : x(0) {}
- constexpr explicit ConstexprType(int i) : x(i) {}
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
- constexpr ConstexprType(std::initializer_list<int> il) : x(il.size()) {}
-#endif
- constexpr ConstexprType(const char* s) : x(-1) {} // NOLINT
- int x;
-};
-
-struct Copyable {
- Copyable() {}
- Copyable(const Copyable&) {}
- Copyable& operator=(const Copyable&) { return *this; }
-};
-
-struct MoveableThrow {
- MoveableThrow() {}
- MoveableThrow(MoveableThrow&&) {}
- MoveableThrow& operator=(MoveableThrow&&) { return *this; }
-};
-
-struct MoveableNoThrow {
- MoveableNoThrow() {}
- MoveableNoThrow(MoveableNoThrow&&) noexcept {}
- MoveableNoThrow& operator=(MoveableNoThrow&&) noexcept { return *this; }
-};
-
-struct NonMovable {
- NonMovable() {}
- NonMovable(const NonMovable&) = delete;
- NonMovable& operator=(const NonMovable&) = delete;
- NonMovable(NonMovable&&) = delete;
- NonMovable& operator=(NonMovable&&) = delete;
-};
-
-TEST(optionalTest, DefaultConstructor) {
- optional<int> empty;
- EXPECT_FALSE(!!empty);
- constexpr optional<int> cempty;
- static_assert(!cempty.has_value(), "");
- EXPECT_TRUE(std::is_nothrow_default_constructible<optional<int>>::value);
-}
-
-TEST(optionalTest, NullOptConstructor) {
- optional<int> empty(nullopt);
- EXPECT_FALSE(!!empty);
- // Creating a temporary nullopt_t object instead of using nullopt because
- // nullopt cannot be constexpr and have external linkage at the same time.
- constexpr optional<int> cempty{nullopt_t(nullopt_t::init)};
- static_assert(!cempty.has_value(), "");
- EXPECT_TRUE((std::is_nothrow_constructible<optional<int>, nullopt_t>::value));
-}
-
-TEST(optionalTest, CopyConstructor) {
- optional<int> empty, opt42 = 42;
- optional<int> empty_copy(empty);
- EXPECT_FALSE(!!empty_copy);
- optional<int> opt42_copy(opt42);
- EXPECT_TRUE(!!opt42_copy);
- EXPECT_EQ(42, opt42_copy);
- // test copyability
- EXPECT_TRUE(std::is_copy_constructible<optional<int>>::value);
- EXPECT_TRUE(std::is_copy_constructible<optional<Copyable>>::value);
- EXPECT_FALSE(std::is_copy_constructible<optional<MoveableThrow>>::value);
- EXPECT_FALSE(std::is_copy_constructible<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_copy_constructible<optional<NonMovable>>::value);
-}
-
-TEST(optionalTest, MoveConstructor) {
- optional<int> empty, opt42 = 42;
- optional<int> empty_move(std::move(empty));
- EXPECT_FALSE(!!empty_move);
- optional<int> opt42_move(std::move(opt42));
- EXPECT_TRUE(!!opt42_move);
- EXPECT_EQ(42, opt42_move);
- // test movability
- EXPECT_TRUE(std::is_move_constructible<optional<int>>::value);
- EXPECT_TRUE(std::is_move_constructible<optional<Copyable>>::value);
- EXPECT_TRUE(std::is_move_constructible<optional<MoveableThrow>>::value);
- EXPECT_TRUE(std::is_move_constructible<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_move_constructible<optional<NonMovable>>::value);
- // test noexcept
- EXPECT_TRUE(std::is_nothrow_move_constructible<optional<int>>::value);
- EXPECT_FALSE(
- std::is_nothrow_move_constructible<optional<MoveableThrow>>::value);
- EXPECT_TRUE(
- std::is_nothrow_move_constructible<optional<MoveableNoThrow>>::value);
-}
-
-TEST(optionalTest, Destructor) {
- struct Trivial {};
-
- struct NonTrivial {
- ~NonTrivial() {}
- };
-
- EXPECT_TRUE(std::is_trivially_destructible<optional<int>>::value);
- EXPECT_TRUE(std::is_trivially_destructible<optional<Trivial>>::value);
- EXPECT_FALSE(std::is_trivially_destructible<optional<NonTrivial>>::value);
-}
-
-TEST(optionalTest, InPlaceConstructor) {
- constexpr optional<ConstexprType> opt0{in_place_t()};
- static_assert(opt0, "");
- static_assert(opt0->x == 0, "");
- constexpr optional<ConstexprType> opt1{in_place_t(), 1};
- static_assert(opt1, "");
- static_assert(opt1->x == 1, "");
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
- constexpr optional<ConstexprType> opt2{in_place_t(), {1, 2}};
- static_assert(opt2, "");
- static_assert(opt2->x == 2, "");
-#endif
-
- // TODO(b/34201852): uncomment these when std::is_constructible<T, Args&&...>
- // SFINAE is added to optional::optional(in_place_t, Args&&...).
- // struct I {
- // I(in_place_t);
- // };
-
- // EXPECT_FALSE((std::is_constructible<optional<I>, in_place_t>::value));
- // EXPECT_FALSE((std::is_constructible<optional<I>, const
- // in_place_t&>::value));
-}
-
-// template<U=T> optional(U&&);
-TEST(optionalTest, ValueConstructor) {
- constexpr optional<int> opt0(0);
- static_assert(opt0, "");
- static_assert(*opt0 == 0, "");
- EXPECT_TRUE((std::is_convertible<int, optional<int>>::value));
- // Copy initialization ( = "abc") won't work because optional(optional&&)
- // is not constexpr. Use list initialization instead. This invokes
- // optional<ConstexprType>::optional<U>(U&&), with U = const char (&) [4],
- // which direct-initializes the ConstexprType value held by the optional
- // via ConstexprType::ConstexprType(const char*).
- constexpr optional<ConstexprType> opt1 = {"abc"};
- static_assert(opt1, "");
- static_assert(-1 == opt1->x, "");
- EXPECT_TRUE(
- (std::is_convertible<const char*, optional<ConstexprType>>::value));
- // direct initialization
- constexpr optional<ConstexprType> opt2{2};
- static_assert(opt2, "");
- static_assert(2 == opt2->x, "");
- EXPECT_FALSE((std::is_convertible<int, optional<ConstexprType>>::value));
-
- // this invokes optional<int>::optional(int&&)
- // NOTE: this has different behavior than assignment, e.g.
- // "opt3 = {};" clears the optional rather than setting the value to 0
- constexpr optional<int> opt3({});
- static_assert(opt3, "");
- static_assert(*opt3 == 0, "");
-
- // this invokes the move constructor with a default-constructed optional
- // because a non-template function is a better match than a template one.
- optional<ConstexprType> opt4({});
- EXPECT_FALSE(!!opt4);
-}
-
-struct Implicit {};
-
-struct Explicit {};
-
-struct Convert {
- Convert(const Implicit&) // NOLINT(runtime/explicit)
- : implicit(true), move(false) {}
- Convert(Implicit&&) // NOLINT(runtime/explicit)
- : implicit(true), move(true) {}
- explicit Convert(const Explicit&) : implicit(false), move(false) {}
- explicit Convert(Explicit&&) : implicit(false), move(true) {}
-
- bool implicit;
- bool move;
-};
-
-struct ConvertFromOptional {
- ConvertFromOptional(const Implicit&) // NOLINT(runtime/explicit)
- : implicit(true), move(false), from_optional(false) {}
- ConvertFromOptional(Implicit&&) // NOLINT(runtime/explicit)
- : implicit(true), move(true), from_optional(false) {}
- ConvertFromOptional(const optional<Implicit>&) // NOLINT(runtime/explicit)
- : implicit(true), move(false), from_optional(true) {}
- ConvertFromOptional(optional<Implicit>&&) // NOLINT(runtime/explicit)
- : implicit(true), move(true), from_optional(true) {}
- explicit ConvertFromOptional(const Explicit&)
- : implicit(false), move(false), from_optional(false) {}
- explicit ConvertFromOptional(Explicit&&)
- : implicit(false), move(true), from_optional(false) {}
- explicit ConvertFromOptional(const optional<Explicit>&)
- : implicit(false), move(false), from_optional(true) {}
- explicit ConvertFromOptional(optional<Explicit>&&)
- : implicit(false), move(true), from_optional(true) {}
-
- bool implicit;
- bool move;
- bool from_optional;
-};
-
-TEST(optionalTest, ConvertingConstructor) {
- optional<Implicit> i_empty;
- optional<Implicit> i(in_place);
- optional<Explicit> e_empty;
- optional<Explicit> e(in_place);
- {
- // implicitly constructing optional<Convert> from optional<Implicit>
- optional<Convert> empty = i_empty;
- EXPECT_FALSE(!!empty);
- optional<Convert> opt_copy = i;
- EXPECT_TRUE(!!opt_copy);
- EXPECT_TRUE(opt_copy->implicit);
- EXPECT_FALSE(opt_copy->move);
- optional<Convert> opt_move = optional<Implicit>(in_place);
- EXPECT_TRUE(!!opt_move);
- EXPECT_TRUE(opt_move->implicit);
- EXPECT_TRUE(opt_move->move);
- }
- {
- // explicitly constructing optional<Convert> from optional<Explicit>
- optional<Convert> empty(e_empty);
- EXPECT_FALSE(!!empty);
- optional<Convert> opt_copy(e);
- EXPECT_TRUE(!!opt_copy);
- EXPECT_FALSE(opt_copy->implicit);
- EXPECT_FALSE(opt_copy->move);
- EXPECT_FALSE((std::is_convertible<const optional<Explicit>&,
- optional<Convert>>::value));
- optional<Convert> opt_move{optional<Explicit>(in_place)};
- EXPECT_TRUE(!!opt_move);
- EXPECT_FALSE(opt_move->implicit);
- EXPECT_TRUE(opt_move->move);
- EXPECT_FALSE(
- (std::is_convertible<optional<Explicit>&&, optional<Convert>>::value));
- }
- {
- // implicitly constructing optional<ConvertFromOptional> from
- // optional<Implicit> via ConvertFromOptional(optional<Implicit>&&)
- // check that ConvertFromOptional(Implicit&&) is NOT called
- static_assert(
- gtl::internal_optional::is_constructible_convertible_from_optional<
- ConvertFromOptional, Implicit>::value,
- "");
- optional<ConvertFromOptional> opt0 = i_empty;
- EXPECT_TRUE(!!opt0);
- EXPECT_TRUE(opt0->implicit);
- EXPECT_FALSE(opt0->move);
- EXPECT_TRUE(opt0->from_optional);
- optional<ConvertFromOptional> opt1 = optional<Implicit>();
- EXPECT_TRUE(!!opt1);
- EXPECT_TRUE(opt1->implicit);
- EXPECT_TRUE(opt1->move);
- EXPECT_TRUE(opt1->from_optional);
- }
- {
- // implicitly constructing optional<ConvertFromOptional> from
- // optional<Explicit> via ConvertFromOptional(optional<Explicit>&&)
- // check that ConvertFromOptional(Explicit&&) is NOT called
- optional<ConvertFromOptional> opt0(e_empty);
- EXPECT_TRUE(!!opt0);
- EXPECT_FALSE(opt0->implicit);
- EXPECT_FALSE(opt0->move);
- EXPECT_TRUE(opt0->from_optional);
- EXPECT_FALSE((std::is_convertible<const optional<Explicit>&,
- optional<ConvertFromOptional>>::value));
- optional<ConvertFromOptional> opt1{optional<Explicit>()};
- EXPECT_TRUE(!!opt1);
- EXPECT_FALSE(opt1->implicit);
- EXPECT_TRUE(opt1->move);
- EXPECT_TRUE(opt1->from_optional);
- EXPECT_FALSE((std::is_convertible<optional<Explicit>&&,
- optional<ConvertFromOptional>>::value));
- }
-}
-
-TEST(optionalTest, StructorBasic) {
- StructorListener listener;
- Listenable::listener = &listener;
- {
- optional<Listenable> empty;
- EXPECT_FALSE(!!empty);
- optional<Listenable> opt0(in_place);
- EXPECT_TRUE(!!opt0);
- optional<Listenable> opt1(in_place, 1);
- EXPECT_TRUE(!!opt1);
- optional<Listenable> opt2(in_place, 1, 2);
- EXPECT_TRUE(!!opt2);
- }
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.construct1);
- EXPECT_EQ(1, listener.construct2);
- EXPECT_EQ(3, listener.destruct);
-}
-
-TEST(optionalTest, CopyMoveStructor) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> original(in_place);
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(0, listener.copy);
- EXPECT_EQ(0, listener.move);
- optional<Listenable> copy(original);
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.copy);
- EXPECT_EQ(0, listener.move);
- optional<Listenable> move(std::move(original));
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.copy);
- EXPECT_EQ(1, listener.move);
-}
-
-TEST(optionalTest, ListInit) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> listinit1(in_place, {1});
- optional<Listenable> listinit2(in_place, {1, 2});
- EXPECT_EQ(2, listener.listinit);
-}
-
-TEST(optionalTest, AssignFromNullopt) {
- optional<int> opt(1);
- opt = nullopt;
- EXPECT_FALSE(!!opt);
-
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt1(in_place);
- opt1 = nullopt;
- EXPECT_FALSE(opt1);
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.destruct);
-
- EXPECT_TRUE((std::is_nothrow_assignable<optional<int>, nullopt_t>::value));
- EXPECT_TRUE(
- (std::is_nothrow_assignable<optional<Listenable>, nullopt_t>::value));
-}
-
-TEST(optionalTest, CopyAssignment) {
- const optional<int> empty, opt1 = 1, opt2 = 2;
- optional<int> empty_to_opt1, opt1_to_opt2, opt2_to_empty;
-
- EXPECT_FALSE(!!empty_to_opt1);
- empty_to_opt1 = empty;
- EXPECT_FALSE(!!empty_to_opt1);
- empty_to_opt1 = opt1;
- EXPECT_TRUE(!!empty_to_opt1);
- EXPECT_EQ(1, empty_to_opt1.value());
-
- EXPECT_FALSE(!!opt1_to_opt2);
- opt1_to_opt2 = opt1;
- EXPECT_TRUE(!!opt1_to_opt2);
- EXPECT_EQ(1, opt1_to_opt2.value());
- opt1_to_opt2 = opt2;
- EXPECT_TRUE(!!opt1_to_opt2);
- EXPECT_EQ(2, opt1_to_opt2.value());
-
- EXPECT_FALSE(!!opt2_to_empty);
- opt2_to_empty = opt2;
- EXPECT_TRUE(!!opt2_to_empty);
- EXPECT_EQ(2, opt2_to_empty.value());
- opt2_to_empty = empty;
- EXPECT_FALSE(!!opt2_to_empty);
-
- EXPECT_TRUE(std::is_copy_assignable<optional<Copyable>>::value);
- EXPECT_FALSE(std::is_copy_assignable<optional<MoveableThrow>>::value);
- EXPECT_FALSE(std::is_copy_assignable<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_copy_assignable<optional<NonMovable>>::value);
-}
-
-TEST(optionalTest, MoveAssignment) {
- StructorListener listener;
- Listenable::listener = &listener;
-
- optional<Listenable> empty1, empty2, set1(in_place), set2(in_place);
- EXPECT_EQ(2, listener.construct0);
- optional<Listenable> empty_to_empty, empty_to_set, set_to_empty(in_place),
- set_to_set(in_place);
- EXPECT_EQ(4, listener.construct0);
- empty_to_empty = std::move(empty1);
- empty_to_set = std::move(set1);
- set_to_empty = std::move(empty2);
- set_to_set = std::move(set2);
- EXPECT_EQ(0, listener.copy);
- EXPECT_EQ(1, listener.move);
- EXPECT_EQ(1, listener.destruct);
- EXPECT_EQ(1, listener.move_assign);
-
- EXPECT_TRUE(std::is_move_assignable<optional<Copyable>>::value);
- EXPECT_TRUE(std::is_move_assignable<optional<MoveableThrow>>::value);
- EXPECT_TRUE(std::is_move_assignable<optional<MoveableNoThrow>>::value);
- EXPECT_FALSE(std::is_move_assignable<optional<NonMovable>>::value);
-
- EXPECT_FALSE(std::is_nothrow_move_assignable<optional<MoveableThrow>>::value);
- EXPECT_TRUE(
- std::is_nothrow_move_assignable<optional<MoveableNoThrow>>::value);
-}
-
-struct NoConvertToOptional {
- // disable implicit conversion from const NoConvertToOptional&
- // to optional<NoConvertToOptional>.
- NoConvertToOptional(const NoConvertToOptional&) = delete;
-};
-
-struct CopyConvert {
- CopyConvert(const NoConvertToOptional&);
- CopyConvert& operator=(const CopyConvert&) = delete;
- CopyConvert& operator=(const NoConvertToOptional&);
-};
-
-struct CopyConvertFromOptional {
- CopyConvertFromOptional(const NoConvertToOptional&);
- CopyConvertFromOptional(const optional<NoConvertToOptional>&);
- CopyConvertFromOptional& operator=(const CopyConvertFromOptional&) = delete;
- CopyConvertFromOptional& operator=(const NoConvertToOptional&);
- CopyConvertFromOptional& operator=(const optional<NoConvertToOptional>&);
-};
-
-struct MoveConvert {
- MoveConvert(NoConvertToOptional&&);
- MoveConvert& operator=(const MoveConvert&) = delete;
- MoveConvert& operator=(NoConvertToOptional&&);
-};
-
-struct MoveConvertFromOptional {
- MoveConvertFromOptional(NoConvertToOptional&&);
- MoveConvertFromOptional(optional<NoConvertToOptional>&&);
- MoveConvertFromOptional& operator=(const MoveConvertFromOptional&) = delete;
- MoveConvertFromOptional& operator=(NoConvertToOptional&&);
- MoveConvertFromOptional& operator=(optional<NoConvertToOptional>&&);
-};
-
-// template <class U = T> optional<T>& operator=(U&& v);
-TEST(optionalTest, ValueAssignment) {
- optional<int> opt;
- EXPECT_FALSE(!!opt);
- opt = 42;
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(42, opt.value());
- opt = nullopt;
- EXPECT_FALSE(!!opt);
- opt = 42;
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(42, opt.value());
- opt = 43;
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(43, opt.value());
- opt = {}; // this should clear optional
- EXPECT_FALSE(!!opt);
-
- opt = {44};
- EXPECT_TRUE(!!opt);
- EXPECT_EQ(44, opt.value());
-
- // U = const NoConvertToOptional&
- EXPECT_TRUE((std::is_assignable<optional<CopyConvert>&,
- const NoConvertToOptional&>::value));
- // U = const optional<NoConvertToOptional>&
- EXPECT_TRUE((std::is_assignable<optional<CopyConvertFromOptional>&,
- const NoConvertToOptional&>::value));
- // U = const NoConvertToOptional& triggers SFINAE because
- // std::is_constructible_v<MoveConvert, const NoConvertToOptional&> is false
- EXPECT_FALSE((std::is_assignable<optional<MoveConvert>&,
- const NoConvertToOptional&>::value));
- // U = NoConvertToOptional
- EXPECT_TRUE((std::is_assignable<optional<MoveConvert>&,
- NoConvertToOptional&&>::value));
- // U = const NoConvertToOptional& triggers SFINAE because
- // std::is_constructible_v<MoveConvertFromOptional, const
- // NoConvertToOptional&> is false
- EXPECT_FALSE((std::is_assignable<optional<MoveConvertFromOptional>&,
- const NoConvertToOptional&>::value));
- // U = NoConvertToOptional
- EXPECT_TRUE((std::is_assignable<optional<MoveConvertFromOptional>&,
- NoConvertToOptional&&>::value));
- // U = const optional<NoConvertToOptional>&
- EXPECT_TRUE(
- (std::is_assignable<optional<CopyConvertFromOptional>&,
- const optional<NoConvertToOptional>&>::value));
- // U = optional<NoConvertToOptional>
- EXPECT_TRUE((std::is_assignable<optional<MoveConvertFromOptional>&,
- optional<NoConvertToOptional>&&>::value));
-}
-
-// template <class U> optional<T>& operator=(const optional<U>& rhs);
-// template <class U> optional<T>& operator=(optional<U>&& rhs);
-TEST(optionalTest, ConvertingAssignment) {
- optional<int> opt_i;
- optional<char> opt_c('c');
- opt_i = opt_c;
- EXPECT_TRUE(!!opt_i);
- EXPECT_EQ(*opt_c, *opt_i);
- opt_i = optional<char>();
- EXPECT_FALSE(!!opt_i);
- opt_i = optional<char>('d');
- EXPECT_TRUE(!!opt_i);
- EXPECT_EQ('d', *opt_i);
-
- optional<string> opt_str;
- optional<const char*> opt_cstr("abc");
- opt_str = opt_cstr;
- EXPECT_TRUE(!!opt_str);
- EXPECT_EQ(string("abc"), *opt_str);
- opt_str = optional<const char*>();
- EXPECT_FALSE(!!opt_str);
- opt_str = optional<const char*>("def");
- EXPECT_TRUE(!!opt_str);
- EXPECT_EQ(string("def"), *opt_str);
-
- // operator=(const optional<U>&) with U = NoConvertToOptional
- EXPECT_TRUE(
- (std::is_assignable<optional<CopyConvert>,
- const optional<NoConvertToOptional>&>::value));
- // operator=(const optional<U>&) with U = NoConvertToOptional
- // triggers SFINAE because
- // std::is_constructible_v<MoveConvert, const NoConvertToOptional&> is false
- EXPECT_FALSE(
- (std::is_assignable<optional<MoveConvert>&,
- const optional<NoConvertToOptional>&>::value));
- // operator=(optional<U>&&) with U = NoConvertToOptional
- EXPECT_TRUE((std::is_assignable<optional<MoveConvert>&,
- optional<NoConvertToOptional>&&>::value));
- // operator=(const optional<U>&) with U = NoConvertToOptional triggers SFINAE
- // because std::is_constructible_v<MoveConvertFromOptional,
- // const NoConvertToOptional&> is false.
- // operator=(U&&) with U = const optional<NoConvertToOptional>& triggers SFINAE
- // because std::is_constructible<MoveConvertFromOptional,
- // optional<NoConvertToOptional>&&> is true.
- EXPECT_FALSE(
- (std::is_assignable<optional<MoveConvertFromOptional>&,
- const optional<NoConvertToOptional>&>::value));
-}
-
-TEST(optionalTest, ResetAndHasValue) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt;
- EXPECT_FALSE(!!opt);
- EXPECT_FALSE(opt.has_value());
- opt.emplace();
- EXPECT_TRUE(!!opt);
- EXPECT_TRUE(opt.has_value());
- opt.reset();
- EXPECT_FALSE(!!opt);
- EXPECT_FALSE(opt.has_value());
- EXPECT_EQ(1, listener.destruct);
- opt.reset();
- EXPECT_FALSE(!!opt);
- EXPECT_FALSE(opt.has_value());
-
- constexpr optional<int> empty;
- static_assert(!empty.has_value(), "");
- constexpr optional<int> nonempty(1);
- static_assert(nonempty.has_value(), "");
-}
-
-TEST(optionalTest, Emplace) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt;
- EXPECT_FALSE(!!opt);
- opt.emplace(1);
- EXPECT_TRUE(!!opt);
- opt.emplace(1, 2);
- EXPECT_EQ(1, listener.construct1);
- EXPECT_EQ(1, listener.construct2);
- EXPECT_EQ(1, listener.destruct);
-}
-
-TEST(optionalTest, ListEmplace) {
- StructorListener listener;
- Listenable::listener = &listener;
- optional<Listenable> opt;
- EXPECT_FALSE(!!opt);
- opt.emplace({1});
- EXPECT_TRUE(!!opt);
- opt.emplace({1, 2});
- EXPECT_EQ(2, listener.listinit);
- EXPECT_EQ(1, listener.destruct);
-}
-
-TEST(optionalTest, Swap) {
- optional<int> opt_empty, opt1 = 1, opt2 = 2;
- EXPECT_FALSE(!!opt_empty);
- EXPECT_TRUE(!!opt1);
- EXPECT_EQ(1, opt1.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(2, opt2.value());
- swap(opt_empty, opt1);
- EXPECT_FALSE(!!opt1);
- EXPECT_TRUE(!!opt_empty);
- EXPECT_EQ(1, opt_empty.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(2, opt2.value());
- swap(opt_empty, opt1);
- EXPECT_FALSE(!!opt_empty);
- EXPECT_TRUE(!!opt1);
- EXPECT_EQ(1, opt1.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(2, opt2.value());
- swap(opt1, opt2);
- EXPECT_FALSE(!!opt_empty);
- EXPECT_TRUE(!!opt1);
- EXPECT_EQ(2, opt1.value());
- EXPECT_TRUE(!!opt2);
- EXPECT_EQ(1, opt2.value());
-
- EXPECT_TRUE(noexcept(opt1.swap(opt2)));
- EXPECT_TRUE(noexcept(swap(opt1, opt2)));
-}
-
-TEST(optionalTest, PointerStuff) {
- optional<string> opt(in_place, "foo");
- EXPECT_EQ("foo", *opt);
- const auto& opt_const = opt;
- EXPECT_EQ("foo", *opt_const);
- EXPECT_EQ(opt->size(), 3);
- EXPECT_EQ(opt_const->size(), 3);
-
- constexpr optional<ConstexprType> opt1(1);
- static_assert(opt1->x == 1, "");
-}
-
- // gcc before 4.9 has a bug where it doesn't do correct overload resolution
- // between rvalue-reference-qualified member methods. Skip that test to keep
- // the build green when using the old compiler.
-#if defined(__GNUC__) && !defined(__clang__)
-#if __GNUC__ < 4 || (__GNUC__ == 4 && __GNUC_MINOR__ < 9)
-#define SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG
-#endif
-#endif
-
-TEST(optionalTest, Value) {
- using O = optional<string>;
- using CO = const optional<string>;
- O lvalue(in_place, "lvalue");
- CO clvalue(in_place, "clvalue");
- EXPECT_EQ("lvalue", lvalue.value());
- EXPECT_EQ("clvalue", clvalue.value());
- EXPECT_EQ("xvalue", O(in_place, "xvalue").value());
-#ifndef SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG
- EXPECT_EQ("cxvalue", CO(in_place, "cxvalue").value());
- EXPECT_EQ("&", TypeQuals(lvalue.value()));
- EXPECT_EQ("c&", TypeQuals(clvalue.value()));
- EXPECT_EQ("&&", TypeQuals(O(in_place, "xvalue").value()));
- EXPECT_EQ("c&&", TypeQuals(CO(in_place, "cxvalue").value()));
-#endif
-}
-
-TEST(optionalTest, DerefOperator) {
- using O = optional<string>;
- using CO = const optional<string>;
- O lvalue(in_place, "lvalue");
- CO clvalue(in_place, "clvalue");
- EXPECT_EQ("lvalue", *lvalue);
- EXPECT_EQ("clvalue", *clvalue);
- EXPECT_EQ("xvalue", *O(in_place, "xvalue"));
-#ifndef SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG
- EXPECT_EQ("cxvalue", *CO(in_place, "cxvalue"));
- EXPECT_EQ("&", TypeQuals(*lvalue));
- EXPECT_EQ("c&", TypeQuals(*clvalue));
- EXPECT_EQ("&&", TypeQuals(*O(in_place, "xvalue")));
- EXPECT_EQ("c&&", TypeQuals(*CO(in_place, "cxvalue")));
-#endif
-
- constexpr optional<int> opt1(1);
- static_assert(*opt1 == 1, "");
-
-#if !defined(SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG) && \
- !defined(SKIP_OVERLOAD_TEST_DUE_TO_GCC_BUG)
- using COI = const optional<int>;
- static_assert(*COI(2) == 2, "");
-#endif
-}
-
-TEST(optionalTest, ValueOr) {
- optional<double> opt_empty, opt_set = 1.2;
- EXPECT_EQ(42.0, opt_empty.value_or(42));
- EXPECT_EQ(1.2, opt_set.value_or(42));
- EXPECT_EQ(42.0, optional<double>().value_or(42));
- EXPECT_EQ(1.2, optional<double>(1.2).value_or(42));
-
-#ifndef SKIP_CONSTEXPR_TEST_DUE_TO_CLANG_BUG
- constexpr optional<double> copt_empty;
- static_assert(42.0 == copt_empty.value_or(42), "");
-
- constexpr optional<double> copt_set = {1.2};
- static_assert(1.2 == copt_set.value_or(42), "");
-
- using COD = const optional<double>;
- static_assert(42.0 == COD().value_or(42), "");
- static_assert(1.2 == COD(1.2).value_or(42), "");
-#endif
-}
-
-// make_optional cannot be constexpr until C++17
-TEST(optionalTest, make_optional) {
- auto opt_int = make_optional(42);
- EXPECT_TRUE((std::is_same<decltype(opt_int), optional<int>>::value));
- EXPECT_EQ(42, opt_int);
-
- StructorListener listener;
- Listenable::listener = &listener;
-
- optional<Listenable> opt0 = make_optional<Listenable>();
- EXPECT_EQ(1, listener.construct0);
- optional<Listenable> opt1 = make_optional<Listenable>(1);
- EXPECT_EQ(1, listener.construct1);
- optional<Listenable> opt2 = make_optional<Listenable>(1, 2);
- EXPECT_EQ(1, listener.construct2);
- optional<Listenable> opt3 = make_optional<Listenable>({1});
- optional<Listenable> opt4 = make_optional<Listenable>({1, 2});
- EXPECT_EQ(2, listener.listinit);
-}
-
-TEST(optionalTest, Comparisons) {
- optional<int> ae, be, a2 = 2, b2 = 2, a4 = 4, b4 = 4;
-
-#define optionalTest_Comparisons_EXPECT_LESS(x, y) \
- EXPECT_FALSE((x) == (y)); \
- EXPECT_TRUE((x) != (y)); \
- EXPECT_TRUE((x) < (y)); \
- EXPECT_FALSE((x) > (y)); \
- EXPECT_TRUE((x) <= (y)); \
- EXPECT_FALSE((x) >= (y));
-
-#define optionalTest_Comparisons_EXPECT_SAME(x, y) \
- EXPECT_TRUE((x) == (y)); \
- EXPECT_FALSE((x) != (y)); \
- EXPECT_FALSE((x) < (y)); \
- EXPECT_FALSE((x) > (y)); \
- EXPECT_TRUE((x) <= (y)); \
- EXPECT_TRUE((x) >= (y));
-
-#define optionalTest_Comparisons_EXPECT_GREATER(x, y) \
- EXPECT_FALSE((x) == (y)); \
- EXPECT_TRUE((x) != (y)); \
- EXPECT_FALSE((x) < (y)); \
- EXPECT_TRUE((x) > (y)); \
- EXPECT_FALSE((x) <= (y)); \
- EXPECT_TRUE((x) >= (y));
-
- // LHS: nullopt, ae, a2, 3, a4
- // RHS: nullopt, be, b2, 3, b4
-
- // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(nullopt,nullopt);
- optionalTest_Comparisons_EXPECT_SAME(nullopt, be);
- optionalTest_Comparisons_EXPECT_LESS(nullopt, b2);
- // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(nullopt,3);
- optionalTest_Comparisons_EXPECT_LESS(nullopt, b4);
-
- optionalTest_Comparisons_EXPECT_SAME(ae, nullopt);
- optionalTest_Comparisons_EXPECT_SAME(ae, be);
- optionalTest_Comparisons_EXPECT_LESS(ae, b2);
- optionalTest_Comparisons_EXPECT_LESS(ae, 3);
- optionalTest_Comparisons_EXPECT_LESS(ae, b4);
-
- optionalTest_Comparisons_EXPECT_GREATER(a2, nullopt);
- optionalTest_Comparisons_EXPECT_GREATER(a2, be);
- optionalTest_Comparisons_EXPECT_SAME(a2, b2);
- optionalTest_Comparisons_EXPECT_LESS(a2, 3);
- optionalTest_Comparisons_EXPECT_LESS(a2, b4);
-
- // optionalTest_Comparisons_EXPECT_NOT_TO_WORK(3,nullopt);
- optionalTest_Comparisons_EXPECT_GREATER(3, be);
- optionalTest_Comparisons_EXPECT_GREATER(3, b2);
- optionalTest_Comparisons_EXPECT_SAME(3, 3);
- optionalTest_Comparisons_EXPECT_LESS(3, b4);
-
- optionalTest_Comparisons_EXPECT_GREATER(a4, nullopt);
- optionalTest_Comparisons_EXPECT_GREATER(a4, be);
- optionalTest_Comparisons_EXPECT_GREATER(a4, b2);
- optionalTest_Comparisons_EXPECT_GREATER(a4, 3);
- optionalTest_Comparisons_EXPECT_SAME(a4, b4);
-}
-
-TEST(optionalTest, SwapRegression) {
- StructorListener listener;
- Listenable::listener = &listener;
-
- {
- optional<Listenable> a;
- optional<Listenable> b(in_place);
- a.swap(b);
- }
-
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.move);
- EXPECT_EQ(2, listener.destruct);
-
- {
- optional<Listenable> a(in_place);
- optional<Listenable> b;
- a.swap(b);
- }
-
- EXPECT_EQ(2, listener.construct0);
- EXPECT_EQ(2, listener.move);
- EXPECT_EQ(4, listener.destruct);
-}
-
-TEST(optionalTest, BigStringLeakCheck) {
- constexpr size_t n = 1 << 16;
-
- using OS = optional<string>;
-
- OS a;
- OS b = nullopt;
- OS c = string(n, 'c');
- string sd(n, 'd');
- OS d = sd;
- OS e(in_place, n, 'e');
- OS f;
- f.emplace(n, 'f');
-
- OS ca(a);
- OS cb(b);
- OS cc(c);
- OS cd(d);
- OS ce(e);
-
- OS oa;
- OS ob = nullopt;
- OS oc = string(n, 'c');
- string sod(n, 'd');
- OS od = sod;
- OS oe(in_place, n, 'e');
- OS of;
- of.emplace(n, 'f');
-
- OS ma(std::move(oa));
- OS mb(std::move(ob));
- OS mc(std::move(oc));
- OS md(std::move(od));
- OS me(std::move(oe));
- OS mf(std::move(of));
-
- OS aa1;
- OS ab1 = nullopt;
- OS ac1 = string(n, 'c');
- string sad1(n, 'd');
- OS ad1 = sad1;
- OS ae1(in_place, n, 'e');
- OS af1;
- af1.emplace(n, 'f');
-
- OS aa2;
- OS ab2 = nullopt;
- OS ac2 = string(n, 'c');
- string sad2(n, 'd');
- OS ad2 = sad2;
- OS ae2(in_place, n, 'e');
- OS af2;
- af2.emplace(n, 'f');
-
- aa1 = af2;
- ab1 = ae2;
- ac1 = ad2;
- ad1 = ac2;
- ae1 = ab2;
- af1 = aa2;
-
- OS aa3;
- OS ab3 = nullopt;
- OS ac3 = string(n, 'c');
- string sad3(n, 'd');
- OS ad3 = sad3;
- OS ae3(in_place, n, 'e');
- OS af3;
- af3.emplace(n, 'f');
-
- aa3 = nullopt;
- ab3 = nullopt;
- ac3 = nullopt;
- ad3 = nullopt;
- ae3 = nullopt;
- af3 = nullopt;
-
- OS aa4;
- OS ab4 = nullopt;
- OS ac4 = string(n, 'c');
- string sad4(n, 'd');
- OS ad4 = sad4;
- OS ae4(in_place, n, 'e');
- OS af4;
- af4.emplace(n, 'f');
-
- aa4 = OS(in_place, n, 'a');
- ab4 = OS(in_place, n, 'b');
- ac4 = OS(in_place, n, 'c');
- ad4 = OS(in_place, n, 'd');
- ae4 = OS(in_place, n, 'e');
- af4 = OS(in_place, n, 'f');
-
- OS aa5;
- OS ab5 = nullopt;
- OS ac5 = string(n, 'c');
- string sad5(n, 'd');
- OS ad5 = sad5;
- OS ae5(in_place, n, 'e');
- OS af5;
- af5.emplace(n, 'f');
-
- string saa5(n, 'a');
- string sab5(n, 'a');
- string sac5(n, 'a');
- string sad52(n, 'a');
- string sae5(n, 'a');
- string saf5(n, 'a');
-
- aa5 = saa5;
- ab5 = sab5;
- ac5 = sac5;
- ad5 = sad52;
- ae5 = sae5;
- af5 = saf5;
-
- OS aa6;
- OS ab6 = nullopt;
- OS ac6 = string(n, 'c');
- string sad6(n, 'd');
- OS ad6 = sad6;
- OS ae6(in_place, n, 'e');
- OS af6;
- af6.emplace(n, 'f');
-
- aa6 = string(n, 'a');
- ab6 = string(n, 'b');
- ac6 = string(n, 'c');
- ad6 = string(n, 'd');
- ae6 = string(n, 'e');
- af6 = string(n, 'f');
-
- OS aa7;
- OS ab7 = nullopt;
- OS ac7 = string(n, 'c');
- string sad7(n, 'd');
- OS ad7 = sad7;
- OS ae7(in_place, n, 'e');
- OS af7;
- af7.emplace(n, 'f');
-
- aa7.emplace(n, 'A');
- ab7.emplace(n, 'B');
- ac7.emplace(n, 'C');
- ad7.emplace(n, 'D');
- ae7.emplace(n, 'E');
- af7.emplace(n, 'F');
-}
-
-TEST(optionalTest, MoveAssignRegression) {
- StructorListener listener;
- Listenable::listener = &listener;
-
- {
- optional<Listenable> a;
- Listenable b;
- a = std::move(b);
- }
-
- EXPECT_EQ(1, listener.construct0);
- EXPECT_EQ(1, listener.move);
- EXPECT_EQ(2, listener.destruct);
-}
-
-TEST(optionalTest, ValueType) {
- EXPECT_TRUE((std::is_same<optional<int>::value_type, int>::value));
- EXPECT_TRUE((std::is_same<optional<string>::value_type, string>::value));
- EXPECT_FALSE((std::is_same<optional<int>::value_type, nullopt_t>::value));
-}
-
-TEST(optionalTest, Hash) {
- std::hash<optional<int>> hash;
- std::set<size_t> hashcodes;
- hashcodes.insert(hash(nullopt));
- for (int i = 0; i < 100; ++i) {
- hashcodes.insert(hash(i));
- }
- EXPECT_GT(hashcodes.size(), 90);
-}
-
-struct MoveMeNoThrow {
- MoveMeNoThrow() : x(0) {}
- MoveMeNoThrow(const MoveMeNoThrow& other) : x(other.x) {
- LOG(FATAL) << "Should not be called.";
- }
- MoveMeNoThrow(MoveMeNoThrow&& other) noexcept : x(other.x) {}
- int x;
-};
-
-struct MoveMeThrow {
- MoveMeThrow() : x(0) {}
- MoveMeThrow(const MoveMeThrow& other) : x(other.x) {}
- MoveMeThrow(MoveMeThrow&& other) : x(other.x) {}
- int x;
-};
-
-TEST(optionalTest, NoExcept) {
- static_assert(
- std::is_nothrow_move_constructible<optional<MoveMeNoThrow>>::value, "");
- static_assert(
- !std::is_nothrow_move_constructible<optional<MoveMeThrow>>::value, "");
- std::vector<optional<MoveMeNoThrow>> v;
- v.reserve(10);
- for (int i = 0; i < 10; ++i) v.emplace_back();
-}
-
-} // namespace
-} // namespace tensorflow
diff --git a/tensorflow/core/lib/gtl/priority_queue_util.h b/tensorflow/core/lib/gtl/priority_queue_util.h
index 07311e3725..93bf3d3037 100644
--- a/tensorflow/core/lib/gtl/priority_queue_util.h
+++ b/tensorflow/core/lib/gtl/priority_queue_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_GTL_PRIORITY_QUEUE_UTIL_H_
-#define TENSORFLOW_LIB_GTL_PRIORITY_QUEUE_UTIL_H_
+#ifndef TENSORFLOW_CORE_LIB_GTL_PRIORITY_QUEUE_UTIL_H_
+#define TENSORFLOW_CORE_LIB_GTL_PRIORITY_QUEUE_UTIL_H_
#include <algorithm>
#include <queue>
@@ -52,4 +52,4 @@ T ConsumeTop(std::priority_queue<T, Container, Comparator>* q) {
} // namespace gtl
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_GTL_PRIORITY_QUEUE_UTIL_H_
+#endif // TENSORFLOW_CORE_LIB_GTL_PRIORITY_QUEUE_UTIL_H_
diff --git a/tensorflow/core/lib/hash/crc32c.h b/tensorflow/core/lib/hash/crc32c.h
index ee0bda93b1..2718cd31b3 100644
--- a/tensorflow/core/lib/hash/crc32c.h
+++ b/tensorflow/core/lib/hash/crc32c.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_HASH_CRC32C_H_
-#define TENSORFLOW_LIB_HASH_CRC32C_H_
+#ifndef TENSORFLOW_CORE_LIB_HASH_CRC32C_H_
+#define TENSORFLOW_CORE_LIB_HASH_CRC32C_H_
#include <stddef.h>
#include "tensorflow/core/platform/types.h"
@@ -51,4 +51,4 @@ inline uint32 Unmask(uint32 masked_crc) {
} // namespace crc32c
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_HASH_CRC32C_H_
+#endif // TENSORFLOW_CORE_LIB_HASH_CRC32C_H_
diff --git a/tensorflow/core/lib/hash/hash.h b/tensorflow/core/lib/hash/hash.h
index 737d23f699..675bab7191 100644
--- a/tensorflow/core/lib/hash/hash.h
+++ b/tensorflow/core/lib/hash/hash.h
@@ -15,8 +15,8 @@ limitations under the License.
// Simple hash functions used for internal data structures
-#ifndef TENSORFLOW_LIB_HASH_HASH_H_
-#define TENSORFLOW_LIB_HASH_HASH_H_
+#ifndef TENSORFLOW_CORE_LIB_HASH_HASH_H_
+#define TENSORFLOW_CORE_LIB_HASH_HASH_H_
#include <stddef.h>
#include <stdint.h>
@@ -110,4 +110,4 @@ struct hash<std::pair<T, U>> {
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_HASH_HASH_H_
+#endif // TENSORFLOW_CORE_LIB_HASH_HASH_H_
diff --git a/tensorflow/core/lib/histogram/histogram.h b/tensorflow/core/lib/histogram/histogram.h
index 65ce10786d..f882ee9abe 100644
--- a/tensorflow/core/lib/histogram/histogram.h
+++ b/tensorflow/core/lib/histogram/histogram.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_
-#define TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_
+#ifndef TENSORFLOW_CORE_LIB_HISTOGRAM_HISTOGRAM_H_
+#define TENSORFLOW_CORE_LIB_HISTOGRAM_HISTOGRAM_H_
#include <string>
#include <vector>
@@ -136,4 +136,4 @@ class ThreadSafeHistogram {
} // namespace histogram
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_HISTOGRAM_HISTOGRAM_H_
+#endif // TENSORFLOW_CORE_LIB_HISTOGRAM_HISTOGRAM_H_
diff --git a/tensorflow/core/lib/io/block_builder.h b/tensorflow/core/lib/io/block_builder.h
index e2927689d2..117b6a0bb8 100644
--- a/tensorflow/core/lib/io/block_builder.h
+++ b/tensorflow/core/lib/io/block_builder.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <stdint.h>
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace table {
diff --git a/tensorflow/core/lib/io/buffered_inputstream.h b/tensorflow/core/lib/io/buffered_inputstream.h
index 924619f40f..96a95b7ed9 100644
--- a/tensorflow/core/lib/io/buffered_inputstream.h
+++ b/tensorflow/core/lib/io/buffered_inputstream.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_BUFFERED_INPUTSTREAM_H_
-#define TENSORFLOW_LIB_IO_BUFFERED_INPUTSTREAM_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_BUFFERED_INPUTSTREAM_H_
+#define TENSORFLOW_CORE_LIB_IO_BUFFERED_INPUTSTREAM_H_
#include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/platform/file_system.h"
@@ -104,4 +104,4 @@ class BufferedInputStream : public InputStreamInterface {
} // namespace io
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_BUFFERED_INPUTSTREAM_H_
+#endif // TENSORFLOW_CORE_LIB_IO_BUFFERED_INPUTSTREAM_H_
diff --git a/tensorflow/core/lib/io/inputstream_interface.h b/tensorflow/core/lib/io/inputstream_interface.h
index 3083d20776..cbfc509d93 100644
--- a/tensorflow/core/lib/io/inputstream_interface.h
+++ b/tensorflow/core/lib/io/inputstream_interface.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_INPUTSTREAM_INTERFACE_H_
-#define TENSORFLOW_LIB_IO_INPUTSTREAM_INTERFACE_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_
+#define TENSORFLOW_CORE_LIB_IO_INPUTSTREAM_INTERFACE_H_
#include <string>
#include "tensorflow/core/lib/core/status.h"
diff --git a/tensorflow/core/lib/io/path.cc b/tensorflow/core/lib/io/path.cc
index b62206012c..b75dcecadf 100644
--- a/tensorflow/core/lib/io/path.cc
+++ b/tensorflow/core/lib/io/path.cc
@@ -42,7 +42,7 @@ string JoinPathImpl(std::initializer_list<StringPiece> paths) {
if (path.empty()) continue;
if (result.empty()) {
- result = std::string(path);
+ result = string(path);
continue;
}
@@ -124,7 +124,7 @@ StringPiece Extension(StringPiece path) {
}
string CleanPath(StringPiece unclean_path) {
- string path = std::string(unclean_path);
+ string path(unclean_path);
const char* src = path.c_str();
string::iterator dst = path.begin();
@@ -237,7 +237,7 @@ void ParseURI(StringPiece remaining, StringPiece* scheme, StringPiece* host,
string CreateURI(StringPiece scheme, StringPiece host, StringPiece path) {
if (scheme.empty()) {
- return std::string(path);
+ return string(path);
}
return strings::StrCat(scheme, "://", host, path);
}
diff --git a/tensorflow/core/lib/io/path.h b/tensorflow/core/lib/io/path.h
index 818ba99888..38fb0c5d86 100644
--- a/tensorflow/core/lib/io/path.h
+++ b/tensorflow/core/lib/io/path.h
@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_PATH_H_
-#define TENSORFLOW_LIB_IO_PATH_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_PATH_H_
+#define TENSORFLOW_CORE_LIB_IO_PATH_H_
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace io {
@@ -94,4 +95,4 @@ string GetTempFilename(const string& extension);
} // namespace io
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_PATH_H_
+#endif // TENSORFLOW_CORE_LIB_IO_PATH_H_
diff --git a/tensorflow/core/lib/io/path_test.cc b/tensorflow/core/lib/io/path_test.cc
index e3275b93b6..0090b9100c 100644
--- a/tensorflow/core/lib/io/path_test.cc
+++ b/tensorflow/core/lib/io/path_test.cc
@@ -104,9 +104,9 @@ TEST(PathTest, CleanPath) {
StringPiece u(uri); \
StringPiece s, h, p; \
ParseURI(u, &s, &h, &p); \
- EXPECT_EQ(scheme, s.ToString()); \
- EXPECT_EQ(host, h.ToString()); \
- EXPECT_EQ(path, p.ToString()); \
+ EXPECT_EQ(scheme, s); \
+ EXPECT_EQ(host, h); \
+ EXPECT_EQ(path, p); \
EXPECT_EQ(uri, CreateURI(scheme, host, path)); \
EXPECT_LE(u.begin(), s.begin()); \
EXPECT_GE(u.end(), s.begin()); \
diff --git a/tensorflow/core/lib/io/proto_encode_helper.h b/tensorflow/core/lib/io/proto_encode_helper.h
index f70e1cbaab..34905520f1 100644
--- a/tensorflow/core/lib/io/proto_encode_helper.h
+++ b/tensorflow/core/lib/io/proto_encode_helper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_PROTO_ENCODE_HELPER_H_
-#define TENSORFLOW_LIB_IO_PROTO_ENCODE_HELPER_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_PROTO_ENCODE_HELPER_H_
+#define TENSORFLOW_CORE_LIB_IO_PROTO_ENCODE_HELPER_H_
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -95,4 +95,4 @@ class ProtoEncodeHelper {
} // namespace io
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_PROTO_ENCODE_HELPER_H_
+#endif // TENSORFLOW_CORE_LIB_IO_PROTO_ENCODE_HELPER_H_
diff --git a/tensorflow/core/lib/io/random_inputstream.h b/tensorflow/core/lib/io/random_inputstream.h
index bdbdbd71ff..c822fe50e9 100644
--- a/tensorflow/core/lib/io/random_inputstream.h
+++ b/tensorflow/core/lib/io/random_inputstream.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_RANDOM_INPUTSTREAM_H_
-#define TENSORFLOW_LIB_IO_RANDOM_INPUTSTREAM_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_RANDOM_INPUTSTREAM_H_
+#define TENSORFLOW_CORE_LIB_IO_RANDOM_INPUTSTREAM_H_
#include "tensorflow/core/lib/io/inputstream_interface.h"
#include "tensorflow/core/platform/file_system.h"
@@ -54,4 +54,4 @@ class RandomAccessInputStream : public InputStreamInterface {
} // namespace io
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_RANDOM_INPUTSTREAM_H_
+#endif // TENSORFLOW_CORE_LIB_IO_RANDOM_INPUTSTREAM_H_
diff --git a/tensorflow/core/lib/io/record_reader.cc b/tensorflow/core/lib/io/record_reader.cc
index c24628be57..e22adcd569 100644
--- a/tensorflow/core/lib/io/record_reader.cc
+++ b/tensorflow/core/lib/io/record_reader.cc
@@ -108,10 +108,60 @@ Status RecordReader::ReadChecksummed(uint64 offset, size_t n, string* result) {
return Status::OK();
}
-Status RecordReader::ReadRecord(uint64* offset, string* record) {
- static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
- static const size_t kFooterSize = sizeof(uint32);
+Status RecordReader::GetMetadata(Metadata* md) {
+ if (!md) {
+ return errors::InvalidArgument(
+ "Metadata object call to GetMetadata() was null");
+ }
+
+ // Compute the metadata of the TFRecord file if not cached.
+ if (!cached_metadata_) {
+ TF_RETURN_IF_ERROR(input_stream_->Reset());
+
+ int64 data_size = 0;
+ int64 entries = 0;
+
+ // Within the loop, we always increment offset positively, so this
+ // loop is guaranteed to exit, either by reaching EOF or by
+ // encountering an error.
+ uint64 offset = 0;
+ string record;
+ while (true) {
+ // Read header, containing size of data.
+ Status s = ReadChecksummed(offset, sizeof(uint64), &record);
+ if (!s.ok()) {
+ if (errors::IsOutOfRange(s)) {
+ // We should reach out of range when the record file is complete.
+ break;
+ }
+ return s;
+ }
+
+ // Read the length of the data.
+ const uint64 length = core::DecodeFixed64(record.data());
+
+ // Skip reading the actual data since we just want the number
+ // of records and the size of the data.
+ TF_RETURN_IF_ERROR(input_stream_->SkipNBytes(length + kFooterSize));
+ offset += kHeaderSize + length + kFooterSize;
+
+ // Increment running stats.
+ data_size += length;
+ ++entries;
+ }
+ cached_metadata_.reset(new Metadata());
+ cached_metadata_->stats.entries = entries;
+ cached_metadata_->stats.data_size = data_size;
+ cached_metadata_->stats.file_size =
+ data_size + (kHeaderSize + kFooterSize) * entries;
+ }
+
+ md->stats = cached_metadata_->stats;
+ return Status::OK();
+}
+
+Status RecordReader::ReadRecord(uint64* offset, string* record) {
// Position the input stream.
int64 curr_pos = input_stream_->Tell();
int64 desired_pos = static_cast<int64>(*offset);
diff --git a/tensorflow/core/lib/io/record_reader.h b/tensorflow/core/lib/io/record_reader.h
index f6d587dfa0..17444660d4 100644
--- a/tensorflow/core/lib/io/record_reader.h
+++ b/tensorflow/core/lib/io/record_reader.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_RECORD_READER_H_
-#define TENSORFLOW_LIB_IO_RECORD_READER_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_RECORD_READER_H_
+#define TENSORFLOW_CORE_LIB_IO_RECORD_READER_H_
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -58,6 +58,26 @@ class RecordReaderOptions {
// Note: this class is not thread safe; external synchronization required.
class RecordReader {
public:
+ // Format of a single record:
+ // uint64 length
+ // uint32 masked crc of length
+ // byte data[length]
+ // uint32 masked crc of data
+ static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
+ static const size_t kFooterSize = sizeof(uint32);
+
+ // Statistics (sizes are in units of bytes)
+ struct Stats {
+ int64 file_size = -1;
+ int64 data_size = -1;
+ int64 entries = -1; // Number of values
+ };
+
+ // Metadata for the TFRecord file.
+ struct Metadata {
+ Stats stats;
+ };
+
// Create a reader that will return log records from "*file".
// "*file" must remain live while this Reader is in use.
explicit RecordReader(
@@ -71,6 +91,17 @@ class RecordReader {
// OUT_OF_RANGE for end of file, or something else for an error.
Status ReadRecord(uint64* offset, string* record);
+ // Return the metadata of the Record file.
+ //
+ // The current implementation scans the file to completion,
+ // skipping over the data regions, to extract the metadata once
+ // on the first call to GetMetadata(). An improved implementation
+ // would change RecordWriter to write the metadata into TFRecord
+ // so that GetMetadata() could be a const method.
+ //
+ // 'md' must not be nullptr.
+ Status GetMetadata(Metadata* md);
+
private:
Status ReadChecksummed(uint64 offset, size_t n, string* result);
@@ -78,6 +109,8 @@ class RecordReader {
std::unique_ptr<InputStreamInterface> input_stream_;
bool last_read_failed_;
+ std::unique_ptr<Metadata> cached_metadata_;
+
TF_DISALLOW_COPY_AND_ASSIGN(RecordReader);
};
@@ -122,4 +155,4 @@ class SequentialRecordReader {
} // namespace io
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_RECORD_READER_H_
+#endif // TENSORFLOW_CORE_LIB_IO_RECORD_READER_H_
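
For reference, a minimal sketch of how a caller might use the new RecordReader::GetMetadata() to collect TFRecord statistics without materializing the records. The helper name LogTFRecordStats and the surrounding file setup are illustrative assumptions, not part of this change.

// Sketch only; assumes record_reader.h, env.h, and a readable TFRecord at `fname`.
tensorflow::Status LogTFRecordStats(const std::string& fname) {
  std::unique_ptr<tensorflow::RandomAccessFile> file;
  TF_RETURN_IF_ERROR(
      tensorflow::Env::Default()->NewRandomAccessFile(fname, &file));
  tensorflow::io::RecordReader reader(file.get());
  tensorflow::io::RecordReader::Metadata md;
  TF_RETURN_IF_ERROR(reader.GetMetadata(&md));
  // Per the implementation above:
  // file_size == data_size + entries * (kHeaderSize + kFooterSize).
  LOG(INFO) << "records: " << md.stats.entries
            << ", payload bytes: " << md.stats.data_size
            << ", file bytes: " << md.stats.file_size;
  return tensorflow::Status::OK();
}
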
diff --git a/tensorflow/core/lib/io/record_reader_writer_test.cc b/tensorflow/core/lib/io/record_reader_writer_test.cc
index 13bea1f8f1..a88d34d293 100644
--- a/tensorflow/core/lib/io/record_reader_writer_test.cc
+++ b/tensorflow/core/lib/io/record_reader_writer_test.cc
@@ -147,6 +147,13 @@ TEST(RecordReaderWriterTest, TestBasics) {
EXPECT_EQ("abc", record);
TF_CHECK_OK(reader.ReadRecord(&offset, &record));
EXPECT_EQ("defg", record);
+
+ io::RecordReader::Metadata md;
+ TF_ASSERT_OK(reader.GetMetadata(&md));
+ EXPECT_EQ(2, md.stats.entries);
+ EXPECT_EQ(7, md.stats.data_size);
+ // Two entries have 16 bytes of header/footer each.
+ EXPECT_EQ(39, md.stats.file_size);
}
}
}
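
The 39-byte expectation follows from the record layout constants introduced in record_reader.h (kHeaderSize = 12, kFooterSize = 4); a quick check of the arithmetic, written out as comments:

// Payload:  "abc" (3 bytes) + "defg" (4 bytes)            -> data_size = 7
// Overhead: 2 records * (kHeaderSize 12 + kFooterSize 4)  -> 32 bytes
// Total:    7 + 32                                        -> file_size = 39
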
diff --git a/tensorflow/core/lib/io/record_writer.cc b/tensorflow/core/lib/io/record_writer.cc
index 6e71d23e71..2c6db2487e 100644
--- a/tensorflow/core/lib/io/record_writer.cc
+++ b/tensorflow/core/lib/io/record_writer.cc
@@ -88,10 +88,6 @@ RecordWriter::~RecordWriter() {
}
}
-static uint32 MaskedCrc(const char* data, size_t n) {
- return crc32c::Mask(crc32c::Value(data, n));
-}
-
Status RecordWriter::WriteRecord(StringPiece data) {
if (dest_ == nullptr) {
return Status(::tensorflow::error::FAILED_PRECONDITION,
@@ -102,13 +98,10 @@ Status RecordWriter::WriteRecord(StringPiece data) {
// uint32 masked crc of length
// byte data[length]
// uint32 masked crc of data
- char header[sizeof(uint64) + sizeof(uint32)];
- core::EncodeFixed64(header + 0, data.size());
- core::EncodeFixed32(header + sizeof(uint64),
- MaskedCrc(header, sizeof(uint64)));
- char footer[sizeof(uint32)];
- core::EncodeFixed32(footer, MaskedCrc(data.data(), data.size()));
-
+ char header[kHeaderSize];
+ char footer[kFooterSize];
+ PopulateHeader(header, data.data(), data.size());
+ PopulateFooter(footer, data.data(), data.size());
TF_RETURN_IF_ERROR(dest_->Append(StringPiece(header, sizeof(header))));
TF_RETURN_IF_ERROR(dest_->Append(data));
return dest_->Append(StringPiece(footer, sizeof(footer)));
diff --git a/tensorflow/core/lib/io/record_writer.h b/tensorflow/core/lib/io/record_writer.h
index daed809af3..1212e1fafb 100644
--- a/tensorflow/core/lib/io/record_writer.h
+++ b/tensorflow/core/lib/io/record_writer.h
@@ -13,11 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_RECORD_WRITER_H_
-#define TENSORFLOW_LIB_IO_RECORD_WRITER_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_
+#define TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_
+#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/hash/crc32c.h"
#if !defined(IS_SLIM_BUILD)
#include "tensorflow/core/lib/io/zlib_compression_options.h"
#include "tensorflow/core/lib/io/zlib_outputbuffer.h"
@@ -41,12 +43,20 @@ class RecordWriterOptions {
// Options specific to zlib compression.
#if !defined(IS_SLIM_BUILD)
- ZlibCompressionOptions zlib_options;
+ tensorflow::io::ZlibCompressionOptions zlib_options;
#endif // IS_SLIM_BUILD
};
class RecordWriter {
public:
+ // Format of a single record:
+ // uint64 length
+ // uint32 masked crc of length
+ // byte data[length]
+ // uint32 masked crc of data
+ static const size_t kHeaderSize = sizeof(uint64) + sizeof(uint32);
+ static const size_t kFooterSize = sizeof(uint32);
+
// Create a writer that will append data to "*dest".
// "*dest" must be initially empty.
// "*dest" must remain live while this Writer is in use.
@@ -72,14 +82,36 @@ class RecordWriter {
// are invalid.
Status Close();
+ // Utility method to populate TFRecord headers. Populates record-header in
+ // "header[0,kHeaderSize-1]". The record-header is based on data[0, n-1].
+ inline static void PopulateHeader(char* header, const char* data, size_t n);
+
+ // Utility method to populate TFRecord footers. Populates record-footer in
+ // "footer[0,kFooterSize-1]". The record-footer is based on data[0, n-1].
+ inline static void PopulateFooter(char* footer, const char* data, size_t n);
+
private:
WritableFile* dest_;
RecordWriterOptions options_;
+ inline static uint32 MaskedCrc(const char* data, size_t n) {
+ return crc32c::Mask(crc32c::Value(data, n));
+ }
+
TF_DISALLOW_COPY_AND_ASSIGN(RecordWriter);
};
+void RecordWriter::PopulateHeader(char* header, const char* data, size_t n) {
+ core::EncodeFixed64(header + 0, n);
+ core::EncodeFixed32(header + sizeof(uint64),
+ MaskedCrc(header, sizeof(uint64)));
+}
+
+void RecordWriter::PopulateFooter(char* footer, const char* data, size_t n) {
+ core::EncodeFixed32(footer, MaskedCrc(data, n));
+}
+
} // namespace io
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_RECORD_WRITER_H_
+#endif // TENSORFLOW_CORE_LIB_IO_RECORD_WRITER_H_
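
A minimal sketch of how the now-public framing helpers lay out one record in memory, mirroring what WriteRecord() appends to the file. FrameOneRecord is a hypothetical helper for illustration only; it assumes record_writer.h and <cstring> are included.

// Sketch only; buffer handling here is illustrative, not part of the change.
std::string FrameOneRecord(tensorflow::StringPiece data) {
  using tensorflow::io::RecordWriter;
  std::string record(
      RecordWriter::kHeaderSize + data.size() + RecordWriter::kFooterSize, '\0');
  // 12-byte header: fixed64 payload length + masked CRC32C of that length.
  RecordWriter::PopulateHeader(&record[0], data.data(), data.size());
  memcpy(&record[RecordWriter::kHeaderSize], data.data(), data.size());
  // 4-byte footer: masked CRC32C of the payload bytes.
  RecordWriter::PopulateFooter(&record[RecordWriter::kHeaderSize + data.size()],
                               data.data(), data.size());
  return record;  // FrameOneRecord("abc") yields the 19 bytes WriteRecord appends.
}
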
diff --git a/tensorflow/core/lib/io/recordio_test.cc b/tensorflow/core/lib/io/recordio_test.cc
index da514bd21c..946d7188d3 100644
--- a/tensorflow/core/lib/io/recordio_test.cc
+++ b/tensorflow/core/lib/io/recordio_test.cc
@@ -58,7 +58,7 @@ class StringDest : public WritableFile {
Status Close() override { return Status::OK(); }
Status Flush() override { return Status::OK(); }
Status Sync() override { return Status::OK(); }
- Status Append(const StringPiece& slice) override {
+ Status Append(StringPiece slice) override {
contents_->append(slice.data(), slice.size());
return Status::OK();
}
diff --git a/tensorflow/core/lib/io/table.h b/tensorflow/core/lib/io/table.h
index a1b78eae5b..b9c6b8d9d2 100644
--- a/tensorflow/core/lib/io/table.h
+++ b/tensorflow/core/lib/io/table.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_TABLE_H_
-#define TENSORFLOW_LIB_IO_TABLE_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_TABLE_H_
+#define TENSORFLOW_CORE_LIB_IO_TABLE_H_
#include <stdint.h>
#include "tensorflow/core/lib/io/iterator.h"
@@ -84,4 +84,4 @@ class Table {
} // namespace table
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_TABLE_H_
+#endif // TENSORFLOW_CORE_LIB_IO_TABLE_H_
diff --git a/tensorflow/core/lib/io/table_builder.h b/tensorflow/core/lib/io/table_builder.h
index 0202f90446..0e37e0a77f 100644
--- a/tensorflow/core/lib/io/table_builder.h
+++ b/tensorflow/core/lib/io/table_builder.h
@@ -21,8 +21,8 @@ limitations under the License.
// non-const method, all threads accessing the same TableBuilder must use
// external synchronization.
-#ifndef TENSORFLOW_LIB_IO_TABLE_BUILDER_H_
-#define TENSORFLOW_LIB_IO_TABLE_BUILDER_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_TABLE_BUILDER_H_
+#define TENSORFLOW_CORE_LIB_IO_TABLE_BUILDER_H_
#include <stdint.h>
#include "tensorflow/core/lib/core/status.h"
@@ -96,4 +96,4 @@ class TableBuilder {
} // namespace table
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_TABLE_BUILDER_H_
+#endif // TENSORFLOW_CORE_LIB_IO_TABLE_BUILDER_H_
diff --git a/tensorflow/core/lib/io/table_options.h b/tensorflow/core/lib/io/table_options.h
index fd8a9d4a78..9a36bf1631 100644
--- a/tensorflow/core/lib/io/table_options.h
+++ b/tensorflow/core/lib/io/table_options.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_
-#define TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_
+#ifndef TENSORFLOW_CORE_LIB_IO_TABLE_OPTIONS_H_
+#define TENSORFLOW_CORE_LIB_IO_TABLE_OPTIONS_H_
#include <stddef.h>
@@ -65,4 +65,4 @@ struct Options {
} // namespace table
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_IO_TABLE_OPTIONS_H_
+#endif // TENSORFLOW_CORE_LIB_IO_TABLE_OPTIONS_H_
diff --git a/tensorflow/core/lib/io/table_test.cc b/tensorflow/core/lib/io/table_test.cc
index 9e3309f0a7..9cebbf40c6 100644
--- a/tensorflow/core/lib/io/table_test.cc
+++ b/tensorflow/core/lib/io/table_test.cc
@@ -98,7 +98,7 @@ class StringSink : public WritableFile {
Status Flush() override { return Status::OK(); }
Status Sync() override { return Status::OK(); }
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
contents_.append(data.data(), data.size());
return Status::OK();
}
@@ -147,7 +147,7 @@ class Constructor {
virtual ~Constructor() {}
void Add(const string& key, const StringPiece& value) {
- data_[key] = std::string(value);
+ data_[key] = string(value);
}
// Finish constructing the data structure with all the keys that have
@@ -188,7 +188,7 @@ class BlockConstructor : public Constructor {
builder.Add(it->first, it->second);
}
// Open the block
- data_ = std::string(builder.Finish());
+ data_ = string(builder.Finish());
BlockContents contents;
contents.data = data_;
contents.cachable = false;
@@ -515,7 +515,7 @@ TEST_F(Harness, Randomized) {
for (int e = 0; e < num_entries; e++) {
string v;
Add(test::RandomKey(&rnd, rnd.Skewed(4)),
- std::string(test::RandomString(&rnd, rnd.Skewed(5), &v)));
+ string(test::RandomString(&rnd, rnd.Skewed(5), &v)));
}
Test(&rnd);
}
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.cc b/tensorflow/core/lib/io/zlib_outputbuffer.cc
index 84b47c171f..cba139e6ad 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.cc
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.cc
@@ -143,7 +143,7 @@ Status ZlibOutputBuffer::FlushOutputBufferToFile() {
return Status::OK();
}
-Status ZlibOutputBuffer::Append(const StringPiece& data) {
+Status ZlibOutputBuffer::Append(StringPiece data) {
// If there is sufficient free space in z_stream_input_ to fit data we
// add it there and return.
// If there isn't enough space we deflate the existing contents of
diff --git a/tensorflow/core/lib/io/zlib_outputbuffer.h b/tensorflow/core/lib/io/zlib_outputbuffer.h
index 3d86d89a99..ccad2fda44 100644
--- a/tensorflow/core/lib/io/zlib_outputbuffer.h
+++ b/tensorflow/core/lib/io/zlib_outputbuffer.h
@@ -62,7 +62,7 @@ class ZlibOutputBuffer : public WritableFile {
// to file when the buffer is full.
//
// To immediately write contents to file call `Flush()`.
- Status Append(const StringPiece& data) override;
+ Status Append(StringPiece data) override;
// Deflates any cached input and writes all output to file.
Status Flush() override;
diff --git a/tensorflow/core/lib/jpeg/jpeg_handle.h b/tensorflow/core/lib/jpeg/jpeg_handle.h
index 7d86be51da..86fa3ac5c2 100644
--- a/tensorflow/core/lib/jpeg/jpeg_handle.h
+++ b/tensorflow/core/lib/jpeg/jpeg_handle.h
@@ -16,8 +16,8 @@ limitations under the License.
// This file declares the functions and structures for memory I/O with libjpeg
// These functions are not meant to be used directly, see jpeg_mem.h instead.
-#ifndef TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
-#define TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
+#ifndef TENSORFLOW_CORE_LIB_JPEG_JPEG_HANDLE_H_
+#define TENSORFLOW_CORE_LIB_JPEG_JPEG_HANDLE_H_
#include "tensorflow/core/platform/jpeg.h"
#include "tensorflow/core/platform/types.h"
@@ -57,4 +57,4 @@ void SetDest(j_compress_ptr cinfo, void *buffer, int bufsize,
} // namespace jpeg
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_JPEG_JPEG_HANDLE_H_
+#endif // TENSORFLOW_CORE_LIB_JPEG_JPEG_HANDLE_H_
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.cc b/tensorflow/core/lib/jpeg/jpeg_mem.cc
index 50ed8bdb3b..f7a359eb5b 100644
--- a/tensorflow/core/lib/jpeg/jpeg_mem.cc
+++ b/tensorflow/core/lib/jpeg/jpeg_mem.cc
@@ -152,7 +152,9 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) {
cinfo.scale_denom = ratio;
cinfo.dct_method = flags.dct_method;
- jpeg_start_decompress(&cinfo);
+ // Determine the output image size before attempting to decompress, to
+ // prevent OOM'ing during the decompress.
+ jpeg_calc_output_dimensions(&cinfo);
int64 total_size = static_cast<int64>(cinfo.output_height) *
static_cast<int64>(cinfo.output_width);
@@ -170,6 +172,8 @@ uint8* UncompressLow(const void* srcdata, FewerArgsForCompiler* argball) {
return nullptr;
}
+ jpeg_start_decompress(&cinfo);
+
JDIMENSION target_output_width = cinfo.output_width;
JDIMENSION target_output_height = cinfo.output_height;
JDIMENSION skipped_scanlines = 0;
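
The reordering above follows the standard libjpeg pattern of sizing the output before committing to decompression. Roughly, inside a decode routine it looks like the sketch below; error handling is omitted and kMaxImagePixels is an illustrative bound, not a value taken from this change.

// Fragment of a decode routine; cinfo is a jpeg_decompress_struct already
// bound to the input via jpeg_read_header(&cinfo, TRUE).
jpeg_calc_output_dimensions(&cinfo);  // fills output_width/output_height only
const int64 total_pixels =
    static_cast<int64>(cinfo.output_height) * cinfo.output_width;
if (total_pixels > kMaxImagePixels) {  // hypothetical sanity bound
  jpeg_destroy_decompress(&cinfo);
  return nullptr;  // refuse to allocate a huge output buffer
}
jpeg_start_decompress(&cinfo);  // only now pay for full decompression state
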
diff --git a/tensorflow/core/lib/jpeg/jpeg_mem.h b/tensorflow/core/lib/jpeg/jpeg_mem.h
index 59342d28c0..03437a4e78 100644
--- a/tensorflow/core/lib/jpeg/jpeg_mem.h
+++ b/tensorflow/core/lib/jpeg/jpeg_mem.h
@@ -18,8 +18,8 @@ limitations under the License.
// (data array and size fields).
// Direct manipulation of JPEG strings are supplied: Flip, Rotate, Crop..
-#ifndef TENSORFLOW_LIB_JPEG_JPEG_MEM_H_
-#define TENSORFLOW_LIB_JPEG_JPEG_MEM_H_
+#ifndef TENSORFLOW_CORE_LIB_JPEG_JPEG_MEM_H_
+#define TENSORFLOW_CORE_LIB_JPEG_JPEG_MEM_H_
#include <functional>
#include <string>
@@ -159,4 +159,4 @@ bool Compress(const void* srcdata, int width, int height,
} // namespace jpeg
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_JPEG_JPEG_MEM_H_
+#endif // TENSORFLOW_CORE_LIB_JPEG_JPEG_MEM_H_
diff --git a/tensorflow/core/lib/math/math_util.h b/tensorflow/core/lib/math/math_util.h
index 41d486f2bd..502d741512 100644
--- a/tensorflow/core/lib/math/math_util.h
+++ b/tensorflow/core/lib/math/math_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_MATH_MATH_UTIL_H_
-#define TENSORFLOW_LIB_MATH_MATH_UTIL_H_
+#ifndef TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_
+#define TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_
#include <type_traits>
@@ -160,4 +160,4 @@ T MathUtil::IPow(T base, int exp) {
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_MATH_MATH_UTIL_H_
+#endif // TENSORFLOW_CORE_LIB_MATH_MATH_UTIL_H_
diff --git a/tensorflow/core/lib/monitoring/collection_registry.cc b/tensorflow/core/lib/monitoring/collection_registry.cc
index 8c28620ff9..fface033cb 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.cc
+++ b/tensorflow/core/lib/monitoring/collection_registry.cc
@@ -38,15 +38,15 @@ void Collector::CollectMetricDescriptor(
mutex_lock l(mu_);
return collected_metrics_->metric_descriptor_map
.insert(std::make_pair(
- std::string(metric_def->name()),
+ string(metric_def->name()),
std::unique_ptr<MetricDescriptor>(new MetricDescriptor())))
.first->second.get();
}();
- metric_descriptor->name = std::string(metric_def->name());
- metric_descriptor->description = std::string(metric_def->description());
+ metric_descriptor->name = string(metric_def->name());
+ metric_descriptor->description = string(metric_def->description());
for (const StringPiece label_name : metric_def->label_descriptions()) {
- metric_descriptor->label_names.push_back(std::string(label_name));
+ metric_descriptor->label_names.emplace_back(label_name);
}
metric_descriptor->metric_kind = metric_def->kind();
diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h
index 20f0444f8b..9e4e1989dd 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.h
+++ b/tensorflow/core/lib/monitoring/collection_registry.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace monitoring {
@@ -72,7 +73,7 @@ class MetricCollector {
registration_time_millis_(registration_time_millis),
collector_(collector),
point_set_(point_set) {
- point_set_->metric_name = std::string(metric_def->name());
+ point_set_->metric_name = string(metric_def->name());
}
const MetricDef<metric_kind, Value, NumLabels>* const metric_def_;
@@ -261,7 +262,7 @@ class Collector {
auto* const point_set = [&]() {
mutex_lock l(mu_);
return collected_metrics_->point_set_map
- .insert(std::make_pair(std::string(metric_def->name()),
+ .insert(std::make_pair(string(metric_def->name()),
std::unique_ptr<PointSet>(new PointSet())))
.first->second.get();
}();
diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h
index 6f94685665..bc4365e439 100644
--- a/tensorflow/core/lib/monitoring/metric_def.h
+++ b/tensorflow/core/lib/monitoring/metric_def.h
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace monitoring {
@@ -98,8 +99,8 @@ class AbstractMetricDef {
const std::vector<string>& label_descriptions)
: kind_(kind),
value_type_(value_type),
- name_(std::string(name)),
- description_(std::string(description)),
+ name_(name),
+ description_(description),
label_descriptions_(std::vector<string>(label_descriptions.begin(),
label_descriptions.end())) {}
diff --git a/tensorflow/core/lib/png/png_io.h b/tensorflow/core/lib/png/png_io.h
index bb5d20fb68..c876c5156a 100644
--- a/tensorflow/core/lib/png/png_io.h
+++ b/tensorflow/core/lib/png/png_io.h
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/png.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace png {
diff --git a/tensorflow/core/lib/random/distribution_sampler.h b/tensorflow/core/lib/random/distribution_sampler.h
index 25605d8ed4..7aa50ece03 100644
--- a/tensorflow/core/lib/random/distribution_sampler.h
+++ b/tensorflow/core/lib/random/distribution_sampler.h
@@ -28,8 +28,8 @@ limitations under the License.
//
// The algorithm used is Walker's Aliasing algorithm, described in Knuth, Vol 2.
-#ifndef TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
-#define TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+#ifndef TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+#define TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
#include <memory>
#include <utility>
@@ -91,4 +91,4 @@ class DistributionSampler {
} // namespace random
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
+#endif // TENSORFLOW_CORE_LIB_RANDOM_DISTRIBUTION_SAMPLER_H_
diff --git a/tensorflow/core/lib/random/philox_random.h b/tensorflow/core/lib/random/philox_random.h
index b2adb4462b..058ed95ffb 100644
--- a/tensorflow/core/lib/random/philox_random.h
+++ b/tensorflow/core/lib/random/philox_random.h
@@ -17,8 +17,8 @@ limitations under the License.
// Salmon et al. SC 2011. Parallel random numbers: as easy as 1, 2, 3.
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
-#ifndef TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
-#define TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
+#ifndef TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_
+#define TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_
#include <stdlib.h>
@@ -248,4 +248,4 @@ class PhiloxRandom {
} // namespace random
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_RANDOM_PHILOX_RANDOM_H_
+#endif // TENSORFLOW_CORE_LIB_RANDOM_PHILOX_RANDOM_H_
diff --git a/tensorflow/core/lib/random/random_distributions.h b/tensorflow/core/lib/random/random_distributions.h
index e963511f5c..c3801a0412 100644
--- a/tensorflow/core/lib/random/random_distributions.h
+++ b/tensorflow/core/lib/random/random_distributions.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
-#define TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
+#ifndef TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
+#define TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
#define _USE_MATH_DEFINES
#include <math.h>
@@ -744,4 +744,4 @@ PHILOX_DEVICE_INLINE double Uint64ToDouble(uint32 x0, uint32 x1) {
} // namespace random
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
+#endif // TENSORFLOW_CORE_LIB_RANDOM_RANDOM_DISTRIBUTIONS_H_
diff --git a/tensorflow/core/lib/random/simple_philox.h b/tensorflow/core/lib/random/simple_philox.h
index d529e08913..6464036856 100644
--- a/tensorflow/core/lib/random/simple_philox.h
+++ b/tensorflow/core/lib/random/simple_philox.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
-#define TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
+#ifndef TENSORFLOW_CORE_LIB_RANDOM_SIMPLE_PHILOX_H_
+#define TENSORFLOW_CORE_LIB_RANDOM_SIMPLE_PHILOX_H_
#include <math.h>
#include <string.h>
@@ -73,4 +73,4 @@ class SimplePhilox {
} // namespace random
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_RANDOM_SIMPLE_PHILOX_H_
+#endif // TENSORFLOW_CORE_LIB_RANDOM_SIMPLE_PHILOX_H_
diff --git a/tensorflow/core/lib/strings/numbers.h b/tensorflow/core/lib/strings/numbers.h
index 1d5bacac93..959290ba8c 100644
--- a/tensorflow/core/lib/strings/numbers.h
+++ b/tensorflow/core/lib/strings/numbers.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_STRINGS_NUMBERS_H_
-#define TENSORFLOW_LIB_STRINGS_NUMBERS_H_
+#ifndef TENSORFLOW_CORE_LIB_STRINGS_NUMBERS_H_
+#define TENSORFLOW_CORE_LIB_STRINGS_NUMBERS_H_
#include <string>
@@ -140,11 +140,11 @@ inline bool ProtoParseNumeric(StringPiece s, uint64* value) {
}
inline bool ProtoParseNumeric(StringPiece s, float* value) {
- return safe_strtof(std::string(s).c_str(), value);
+ return safe_strtof(s, value);
}
inline bool ProtoParseNumeric(StringPiece s, double* value) {
- return safe_strtod(std::string(s).c_str(), value);
+ return safe_strtod(s, value);
}
// Convert strings to number of type T.
@@ -176,4 +176,4 @@ string HumanReadableElapsedTime(double seconds);
} // namespace strings
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_STRINGS_NUMBERS_H_
+#endif // TENSORFLOW_CORE_LIB_STRINGS_NUMBERS_H_
diff --git a/tensorflow/core/lib/strings/str_util.cc b/tensorflow/core/lib/strings/str_util.cc
index cab8f81585..3aba5ec80e 100644
--- a/tensorflow/core/lib/strings/str_util.cc
+++ b/tensorflow/core/lib/strings/str_util.cc
@@ -332,7 +332,7 @@ string StringReplace(StringPiece s, StringPiece oldsub, StringPiece newsub,
bool replace_all) {
// TODO(jlebar): We could avoid having to shift data around in the string if
// we had a StringPiece::find() overload that searched for a StringPiece.
- string res = std::string(s);
+ string res(s);
size_t pos = 0;
while ((pos = res.find(oldsub.data(), pos, oldsub.size())) != string::npos) {
res.replace(pos, oldsub.size(), newsub.data(), newsub.size());
@@ -448,8 +448,7 @@ bool SplitAndParseAsFloats(StringPiece text, char delim,
std::vector<float>* result) {
return SplitAndParseAsInts<float>(text, delim,
[](StringPiece str, float* value) {
- return strings::safe_strtof(
- std::string(str).c_str(), value);
+ return strings::safe_strtof(str, value);
},
result);
}
diff --git a/tensorflow/core/lib/strings/str_util.h b/tensorflow/core/lib/strings/str_util.h
index c887db7eff..9f52cf29fc 100644
--- a/tensorflow/core/lib/strings/str_util.h
+++ b/tensorflow/core/lib/strings/str_util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_LIB_STRINGS_STR_UTIL_H_
-#define TENSORFLOW_LIB_STRINGS_STR_UTIL_H_
+#ifndef TENSORFLOW_CORE_LIB_STRINGS_STR_UTIL_H_
+#define TENSORFLOW_CORE_LIB_STRINGS_STR_UTIL_H_
#include <functional>
#include <string>
@@ -205,7 +205,7 @@ std::vector<string> Split(StringPiece text, StringPiece delims, Predicate p) {
if ((i == text.size()) || (delims.find(text[i]) != StringPiece::npos)) {
StringPiece token(text.data() + token_start, i - token_start);
if (p(token)) {
- result.push_back(std::string(token));
+ result.emplace_back(token);
}
token_start = i + 1;
}
@@ -231,4 +231,4 @@ size_t Strnlen(const char* str, const size_t string_max_len);
} // namespace str_util
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_STRINGS_STR_UTIL_H_
+#endif // TENSORFLOW_CORE_LIB_STRINGS_STR_UTIL_H_
diff --git a/tensorflow/core/lib/strings/strcat.h b/tensorflow/core/lib/strings/strcat.h
index fb2cd5bc7e..a620f59447 100644
--- a/tensorflow/core/lib/strings/strcat.h
+++ b/tensorflow/core/lib/strings/strcat.h
@@ -17,8 +17,8 @@ limitations under the License.
// #category: operations on strings
// #summary: Merges strings or numbers with no delimiter.
//
-#ifndef TENSORFLOW_LIB_STRINGS_STRCAT_H_
-#define TENSORFLOW_LIB_STRINGS_STRCAT_H_
+#ifndef TENSORFLOW_CORE_LIB_STRINGS_STRCAT_H_
+#define TENSORFLOW_CORE_LIB_STRINGS_STRCAT_H_
#include <string>
@@ -59,29 +59,29 @@ namespace tensorflow {
namespace strings {
enum PadSpec {
- NO_PAD = 1,
- ZERO_PAD_2,
- ZERO_PAD_3,
- ZERO_PAD_4,
- ZERO_PAD_5,
- ZERO_PAD_6,
- ZERO_PAD_7,
- ZERO_PAD_8,
- ZERO_PAD_9,
- ZERO_PAD_10,
- ZERO_PAD_11,
- ZERO_PAD_12,
- ZERO_PAD_13,
- ZERO_PAD_14,
- ZERO_PAD_15,
- ZERO_PAD_16,
+ kNoPad = 1,
+ kZeroPad2,
+ kZeroPad3,
+ kZeroPad4,
+ kZeroPad5,
+ kZeroPad6,
+ kZeroPad7,
+ kZeroPad8,
+ kZeroPad9,
+ kZeroPad10,
+ kZeroPad11,
+ kZeroPad12,
+ kZeroPad13,
+ kZeroPad14,
+ kZeroPad15,
+ kZeroPad16
};
struct Hex {
uint64 value;
enum PadSpec spec;
template <class Int>
- explicit Hex(Int v, PadSpec s = NO_PAD) : spec(s) {
+ explicit Hex(Int v, PadSpec s = kNoPad) : spec(s) {
// Prevent sign-extension by casting integers to
// their unsigned counterparts.
static_assert(
@@ -124,6 +124,9 @@ class AlphaNum {
AlphaNum(const StringPiece &pc) : piece_(pc) {} // NOLINT(runtime/explicit)
AlphaNum(const tensorflow::string &str) // NOLINT(runtime/explicit)
: piece_(str) {}
+ template <typename A>
+ AlphaNum(const std::basic_string<char, std::char_traits<char>, A> &str)
+ : piece_(str) {} // NOLINT(runtime/explicit)
StringPiece::size_type size() const { return piece_.size(); }
const char *data() const { return piece_.data(); }
@@ -233,4 +236,4 @@ inline void StrAppend(string *dest, const AlphaNum &a, const AlphaNum &b,
} // namespace strings
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_STRINGS_STRCAT_H_
+#endif // TENSORFLOW_CORE_LIB_STRINGS_STRCAT_H_
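The PadSpec rename above changes only the constant names; Hex keeps its printf-style zero padding, as the updated tests below exercise. A small usage sketch with the renamed constants:

#include <iostream>
#include "tensorflow/core/lib/strings/strcat.h"

int main() {
  using tensorflow::strings::Hex;
  using tensorflow::strings::StrCat;
  // kZeroPad8 mirrors printf("%08x", ...), per the CheckHex tests below.
  std::cout << StrCat(Hex(255, tensorflow::strings::kZeroPad8)) << "\n";  // 000000ff
  // kNoPad (formerly NO_PAD) is the default and emits the minimal digits.
  std::cout << StrCat(Hex(255)) << "\n";  // ff
  return 0;
}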
diff --git a/tensorflow/core/lib/strings/strcat_test.cc b/tensorflow/core/lib/strings/strcat_test.cc
index 8cc64a6f0a..6c4e5526b1 100644
--- a/tensorflow/core/lib/strings/strcat_test.cc
+++ b/tensorflow/core/lib/strings/strcat_test.cc
@@ -308,11 +308,11 @@ TEST(StrAppend, Death) {
static void CheckHex64(uint64 v) {
using tensorflow::strings::Hex;
- string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_16));
+ string actual = StrCat(Hex(v, tensorflow::strings::kZeroPad16));
string expected = Printf("%016llx", static_cast<unsigned long long>(v));
EXPECT_EQ(expected, actual) << " decimal value " << v;
- actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ actual = StrCat(Hex(v, tensorflow::strings::kZeroPad8));
expected = Printf("%08llx", static_cast<unsigned long long>(v));
EXPECT_EQ(expected, actual) << " decimal value " << v;
@@ -323,7 +323,7 @@ static void CheckHex64(uint64 v) {
static void CheckHex32(uint32 v) {
using tensorflow::strings::Hex;
- string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ string actual = StrCat(Hex(v, tensorflow::strings::kZeroPad8));
string expected = Printf("%08x", v);
EXPECT_EQ(expected, actual) << " decimal value " << v;
@@ -334,7 +334,7 @@ static void CheckHex32(uint32 v) {
static void CheckHexSigned32(int32 v) {
using tensorflow::strings::Hex;
- string actual = StrCat(Hex(v, tensorflow::strings::ZERO_PAD_8));
+ string actual = StrCat(Hex(v, tensorflow::strings::kZeroPad8));
string expected = Printf("%08x", v);
EXPECT_EQ(expected, actual) << " decimal value " << v;
diff --git a/tensorflow/core/lib/strings/stringprintf.h b/tensorflow/core/lib/strings/stringprintf.h
index f7957252ea..52af410d42 100644
--- a/tensorflow/core/lib/strings/stringprintf.h
+++ b/tensorflow/core/lib/strings/stringprintf.h
@@ -20,8 +20,8 @@ limitations under the License.
// strings::SPrintf(&result, "%d %s\n", 10, "hello");
// strings::Appendf(&result, "%d %s\n", 20, "there");
-#ifndef TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_
-#define TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_
+#ifndef TENSORFLOW_CORE_LIB_STRINGS_STRINGPRINTF_H_
+#define TENSORFLOW_CORE_LIB_STRINGS_STRINGPRINTF_H_
#include <stdarg.h>
#include <string>
@@ -49,4 +49,4 @@ extern void Appendv(string* dst, const char* format, va_list ap);
} // namespace strings
} // namespace tensorflow
-#endif // TENSORFLOW_LIB_STRINGS_STRINGPRINTF_H_
+#endif // TENSORFLOW_CORE_LIB_STRINGS_STRINGPRINTF_H_
diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc
index 36d939e061..c536b5688e 100644
--- a/tensorflow/core/lib/wav/wav_io.cc
+++ b/tensorflow/core/lib/wav/wav_io.cc
@@ -232,6 +232,11 @@ Status DecodeLin16WaveAsFloatVector(const string& wav_string,
"Bad audio format for WAV: Expected 1 (PCM), but got", audio_format);
}
TF_RETURN_IF_ERROR(ReadValue<uint16>(wav_string, channel_count, &offset));
+ if (*channel_count < 1) {
+ return errors::InvalidArgument(
+ "Bad number of channels for WAV: Expected at least 1, but got ",
+ *channel_count);
+ }
TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, sample_rate, &offset));
uint32 bytes_per_second;
TF_RETURN_IF_ERROR(ReadValue<uint32>(wav_string, &bytes_per_second, &offset));
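The added check rejects a WAV header that declares zero channels before the decoder does any per-channel arithmetic. A hypothetical stand-in for that later arithmetic, only to show why the guard is needed (the real decoder's bookkeeping differs):

#include <cstdint>

// With channel_count == 0 the division below would be undefined behavior,
// which is what the InvalidArgument check above rules out up front.
inline bool SamplesPerChannel(uint32_t data_bytes, uint16_t channel_count,
                              uint32_t* samples_per_channel) {
  constexpr uint32_t kBytesPerSample = 2;  // LIN16 audio.
  if (channel_count < 1) return false;
  *samples_per_channel = data_bytes / (kBytesPerSample * channel_count);
  return true;
}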
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc
index 1f2e57e9a9..3d03bc1d5f 100644
--- a/tensorflow/core/ops/array_grad.cc
+++ b/tensorflow/core/ops/array_grad.cc
@@ -354,6 +354,27 @@ Status TransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Transpose", TransposeGrad);
+Status GatherNdGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ *g = FDH::Define(
+ // Arg defs
+ {"params: Tparams", "indices: Tindices", "doutput: Tparams"},
+ // Ret val defs
+ {"dparams: Tparams", "dindices: Tindices"},
+ // Attr defs
+ {"Tparams: type", "Tindices: type"},
+ // Nodes
+ {
+ {{"x_shape"}, "Shape", {"params"}, {{"T", "$Tparams"}}},
+ {{"dparams"}, "ScatterNd", {"indices", "doutput", "x_shape"},
+ {{"T", "$Tparams"}, {"Tindices", "$Tindices"}}},
+ {{"dindices"}, "ZerosLike", {"indices"}, {{"T", "$Tindices"}}},
+ });
+ // clang-format on
+ return Status::OK();
+}
+REGISTER_OP_GRADIENT("GatherNd", GatherNdGrad);
+
Status ConjugateTransposeGrad(const AttrSlice& attrs, FunctionDef* g) {
*g = FDH::Define(
// Arg defs
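GatherNdGrad above routes the incoming gradient through ScatterNd back onto the shape of params and returns zeros for the integer indices. A standalone sketch of that relationship for the simplest case (rank-1 indices gathering rows), using plain vectors instead of the FunctionDef machinery:

#include <cstddef>
#include <iostream>
#include <vector>

// dparams[i] accumulates every doutput entry whose index gathered params[i];
// this accumulation is exactly what the ScatterNd node computes.
std::vector<double> GatherRowsGrad(std::size_t num_params,
                                   const std::vector<std::size_t>& indices,
                                   const std::vector<double>& doutput) {
  std::vector<double> dparams(num_params, 0.0);
  for (std::size_t k = 0; k < indices.size(); ++k) {
    dparams[indices[k]] += doutput[k];  // duplicate indices accumulate
  }
  return dparams;
}

int main() {
  // Rows 1 and 3 of a 4-row params were gathered, row 1 twice, so its
  // upstream gradients add up.
  const std::vector<std::size_t> indices = {1, 3, 1};
  const std::vector<double> doutput = {0.5, 2.0, 0.25};
  for (double g : GatherRowsGrad(4, indices, doutput)) std::cout << g << " ";
  std::cout << "\n";  // prints: 0 0.75 0 2
  return 0;
}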
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index ef8ad7972c..f55562ec99 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -133,6 +133,14 @@ Status TransposeShapeFn(InferenceContext* c) {
} else {
rank = perm->NumElements();
}
+ if (!c->RankKnown(input) && rank < 2) {
+ // A permutation array containing a single element is ambiguous. It could
+ // indicate either a scalar or a 1-dimensional array, both of which the
+ // transpose op returns unchanged.
+ c->set_output(0, input);
+ return Status::OK();
+ }
+
std::vector<DimensionHandle> dims;
dims.resize(rank);
TF_RETURN_IF_ERROR(c->WithRank(input, rank, &input));
@@ -427,7 +435,19 @@ REGISTER_OP("UnravelIndex")
.Input("dims: Tidx")
.Output("output: Tidx")
.Attr("Tidx: {int32, int64} = DT_INT32")
- .SetShapeFn([](InferenceContext* c) { return Status::OK(); });
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle indices = c->input(0);
+ ShapeHandle dims;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &dims));
+ if (c->RankKnown(indices) && c->Rank(indices) == 0) {
+ c->set_output(0, c->Vector(c->Dim(dims, 0)));
+ } else if (c->RankKnown(indices)) {
+ c->set_output(0, c->Matrix(c->Dim(dims, 0), c->NumElements(indices)));
+ } else {
+ c->set_output(0, c->UnknownShape());
+ }
+ return Status::OK();
+ });
REGISTER_OP("BroadcastTo")
.Input("input: T")
@@ -690,6 +710,16 @@ REGISTER_OP("Const")
return Status::OK();
});
+// Returns a constant tensor on the host. Useful for writing C++ tests
+// and benchmarks which run on GPU but require arguments pinned to the host.
+// Used by test::graph::HostConstant.
+// value: Attr `value` is the tensor to return.
+REGISTER_OP("HostConst")
+ .Output("output: dtype")
+ .Attr("value: tensor")
+ .Attr("dtype: type")
+ .SetShapeFn(shape_inference::UnknownShape);
+
// --------------------------------------------------------------------------
// TODO(mgubin): Update the doc when the freeze_graph script supports converting
// into memmapped format.
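The HostConst registration above is aimed at C++ tests and benchmarks, via the test::graph::HostConstant helper named in the comment. A minimal sketch of that use, assuming the helper accepts a Graph* and a Tensor (the exact overloads live in tensorflow/core/graph/testlib.h and may differ):

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/testlib.h"

namespace tensorflow {

// Builds a tiny benchmark graph whose scalar argument stays pinned to host
// memory even when the rest of the graph is placed on a GPU device.
Graph* GraphWithHostPinnedArg() {
  Graph* g = new Graph(OpRegistry::Global());
  Tensor limit(DT_INT32, TensorShape({}));
  limit.scalar<int32>()() = 1000;
  // Emits a HostConst node instead of a regular Const node.
  test::graph::HostConstant(g, limit);
  return g;
}

}  // namespace tensorflow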
@@ -1424,6 +1454,30 @@ REGISTER_OP("ShapeN")
.Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(ShapeShapeFn);
+REGISTER_OP("EnsureShape")
+ .Input("input: T")
+ .Output("output: T")
+ .Attr("shape: shape")
+ .Attr("T: type")
+ .SetShapeFn([](InferenceContext* c) {
+ // Merges desired shape and statically known shape of input
+ PartialTensorShape desired_shape;
+ TF_RETURN_IF_ERROR(c->GetAttr("shape", &desired_shape));
+
+ int rank = desired_shape.dims();
+ ShapeHandle input_shape_handle;
+ ShapeHandle desired_shape_handle;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape_handle));
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ desired_shape, &desired_shape_handle));
+
+ ShapeHandle merged_shape;
+ TF_RETURN_IF_ERROR(
+ c->Merge(desired_shape_handle, input_shape_handle, &merged_shape));
+ c->set_output(0, merged_shape);
+ return Status::OK();
+ });
+
// --------------------------------------------------------------------------
REGISTER_OP("ReverseSequence")
.Input("input: T")
@@ -1485,37 +1539,6 @@ REGISTER_OP("Size")
.Attr("out_type: {int32, int64} = DT_INT32")
.SetShapeFn(shape_inference::ScalarShape);
-namespace {
-
-// This SliceHelper processes the output shape of the `slice`
-// when the tensor of `sizes` is available.
-template <typename T>
-Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
- const Tensor* sizes_value,
- std::vector<DimensionHandle>* dims) {
- auto sizes_vec = sizes_value->vec<T>();
- for (int i = 0; i < sizes_value->NumElements(); ++i) {
- DimensionHandle dim = c->Dim(c->input(0), i);
- if (sizes_vec(i) != -1) {
- auto dim_val = c->Value(dim);
- if (sizes_vec(i) < 0) {
- return errors::InvalidArgument(
- "Out of bounds slicing on dimension ", i, " of length ", dim_val,
- ": sizes vector cannot be < -1, but was ", sizes_vec(i));
- }
-
- dims->emplace_back(c->MakeDim(sizes_vec(i)));
- } else {
- DimensionHandle result;
- TF_RETURN_IF_ERROR(c->Subtract(dim, c->Dim(begin_value, i), &result));
- dims->emplace_back(result);
- }
- }
-
- return Status::OK();
-}
-} // namespace
-
// --------------------------------------------------------------------------
REGISTER_OP("Slice")
.Input("input: T")
@@ -1524,83 +1547,22 @@ REGISTER_OP("Slice")
.Output("output: T")
.Attr("T: type")
.Attr("Index: {int32,int64}")
- .SetShapeFn([](InferenceContext* c) {
- ShapeHandle input = c->input(0);
- ShapeHandle begin_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &begin_shape));
- ShapeHandle sizes_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &sizes_shape));
-
- // Merge to check compatibility of begin and sizes tensors.
- TF_RETURN_IF_ERROR(c->Merge(begin_shape, sizes_shape, &begin_shape));
-
- DimensionHandle ndims = c->Dim(begin_shape, 0);
- if (c->ValueKnown(ndims)) {
- TF_RETURN_IF_ERROR(c->WithRank(input, c->Value(ndims), &input));
- }
-
- // NOTE(mrry): Use MakeShapeFromShapeTensor to handle partially-known
- // values, even though the `begin` value does not represent a shape.
- ShapeHandle begin_value;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
-
- // We check the tensor value here and will only use
- // `MakeShapeFromShapeTensor` when `sizes_value` is null.
- // The reason is that `sizes`might contain -1, which can't
- // be represented (-1 in the ShapeHandle would mean "unknown".
- const Tensor* sizes_value = c->input_tensor(2);
-
- if (sizes_value != nullptr) {
- TF_RETURN_IF_ERROR(
- c->WithRank(begin_value, sizes_value->NumElements(), &begin_value));
- std::vector<DimensionHandle> dims;
- // If the begin and sizes tensors are available, then
- // we can be precise about the shape of the output.
- if (sizes_value->dtype() == DT_INT64) {
- TF_RETURN_IF_ERROR(
- SliceHelper<int64>(c, begin_value, sizes_value, &dims));
- } else {
- TF_RETURN_IF_ERROR(
- SliceHelper<int32>(c, begin_value, sizes_value, &dims));
- }
+ .SetShapeFn(shape_inference::SliceShape);
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
- } else {
- // In case `sizes` is not available (`sizes_value` is null),
- // we could try to use `MakeShapeFromShapeTensor` here.
- // If sizes contain -1, we will simply consider it as `Unknown`.
- // This is less than ideal but still an improvement of shape inference.
- // The following is an example that returns [None, 1, None] with this
- // code path:
- // z = tf.zeros((1, 2, 3))
- // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
- // m.get_shape().as_list()
- ShapeHandle sizes_value;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
- if (c->RankKnown(sizes_value)) {
- TF_RETURN_IF_ERROR(
- c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
- std::vector<DimensionHandle> dims;
- dims.reserve(c->Rank(sizes_value));
- for (int i = 0; i < c->Rank(sizes_value); ++i) {
- dims.emplace_back(c->Dim(sizes_value, i));
- }
- c->set_output(0, c->MakeShape(dims));
- return Status::OK();
- }
-
- // We might know the rank of the input.
- if (c->RankKnown(input)) {
- c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));
- return Status::OK();
- } else {
- return shape_inference::UnknownShape(c);
- }
- }
-
- return Status::OK();
- });
+#ifdef INTEL_MKL
+REGISTER_OP("_MklSlice")
+ .Input("input: T")
+ .Input("begin: Index")
+ .Input("size: Index")
+ .Input("mkl_input: uint8")
+ .Input("mkl_begin: uint8")
+ .Input("mkl_size: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: type")
+ .Attr("Index: {int32,int64}")
+ .SetShapeFn(shape_inference::SliceShape);
+#endif
REGISTER_OP("StridedSlice")
.Input("input: T")
@@ -2549,6 +2511,116 @@ REGISTER_OP("ExtractImagePatches")
// --------------------------------------------------------------------------
+// To enable rates, uncomment the commented-out lines below and use
+// ksize_*_eff as the second parameter of all GetWindowedOutputSizeVerbose
+// calls instead of ksize_*.
+REGISTER_OP("ExtractVolumePatches")
+ .Input("input: T")
+ .Output("patches: T")
+ .Attr("ksizes: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ /* .Attr("rates: list(int) >= 5") */
+ .Attr("T: realnumbertype")
+ .Attr(GetPaddingAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle input_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
+
+ std::vector<int32> ksizes;
+ TF_RETURN_IF_ERROR(c->GetAttr("ksizes", &ksizes));
+ if (ksizes.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the ksizes attribute to contain 5 "
+ "values, but got: ",
+ ksizes.size());
+ }
+
+ std::vector<int32> strides;
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
+ if (strides.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the stride attribute to contain 5 "
+ "values, but got: ",
+ strides.size());
+ }
+
+ /*
+ // TODO(hsgkim): Enable rates.
+ // See extract_volume_patches_op.cc for why rates are disabled now.
+
+ std::vector<int32> rates;
+ TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
+ if (rates.size() != 5) {
+ return errors::InvalidArgument(
+ "ExtractVolumePatches requires the rates attribute to contain 5 "
+ "values, but got: ",
+ rates.size());
+ }
+ */
+
+ int32 ksize_planes = ksizes[1];
+ int32 ksize_rows = ksizes[2];
+ int32 ksize_cols = ksizes[3];
+
+ int32 stride_planes = strides[1];
+ int32 stride_rows = strides[2];
+ int32 stride_cols = strides[3];
+
+ /*
+ int32 rate_planes = rates[1];
+ int32 rate_rows = rates[2];
+ int32 rate_cols = rates[3];
+
+ int32 ksize_planes_eff = ksize_planes +
+ (ksize_planes - 1) * (rate_planes - 1);
+ int32 ksize_rows_eff = ksize_rows + (ksize_rows - 1) * (rate_rows - 1);
+ int32 ksize_cols_eff = ksize_cols + (ksize_cols - 1) * (rate_cols - 1);
+ */
+
+ DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
+ DimensionHandle in_planes_dim = c->Dim(input_shape, 1);
+ DimensionHandle in_rows_dim = c->Dim(input_shape, 2);
+ DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
+ DimensionHandle output_depth_dim;
+ TF_RETURN_IF_ERROR(c->Multiply(c->Dim(input_shape, 4),
+ ksize_planes * ksize_rows * ksize_cols,
+ &output_depth_dim));
+
+ if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
+ !c->ValueKnown(in_cols_dim)) {
+ ShapeHandle output_shape =
+ c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
+ InferenceContext::kUnknownDim, output_depth_dim});
+ c->set_output(0, output_shape);
+ return Status::OK();
+ }
+ auto in_planes = c->Value(in_planes_dim);
+ auto in_rows = c->Value(in_rows_dim);
+ auto in_cols = c->Value(in_cols_dim);
+
+ Padding padding;
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
+
+ int64 output_planes, output_rows, output_cols;
+ int64 padding_before, padding_after;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_planes, ksize_planes, stride_planes, padding, &output_planes,
+ &padding_before, &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_rows, ksize_rows, stride_rows, padding, &output_rows,
+ &padding_before, &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_cols, ksize_cols, stride_cols, padding, &output_cols,
+ &padding_before, &padding_after));
+ ShapeHandle output_shape =
+ c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
+ output_depth_dim});
+ c->set_output(0, output_shape);
+ return Status::OK();
+ });
+
+// --------------------------------------------------------------------------
+
REGISTER_OP("Bitcast")
.Input("input: T")
.Output("output: type")
@@ -2870,6 +2942,34 @@ Status ScatterNdShape(InferenceContext* c) {
} // namespace
+REGISTER_OP("UpperBound")
+ .Input("sorted_inputs: T")
+ .Input("values: T")
+ .Output("output: out_type")
+ .Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
+ c->set_output(0, c->input(1));
+ return Status::OK();
+ });
+
+REGISTER_OP("LowerBound")
+ .Input("sorted_inputs: T")
+ .Input("values: T")
+ .Output("output: out_type")
+ .Attr("T: type")
+ .Attr("out_type: {int32, int64} = DT_INT32")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &unused_shape));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &unused_shape));
+ c->set_output(0, c->input(1));
+ return Status::OK();
+ });
+
REGISTER_OP("ScatterNd")
.Input("indices: Tindices")
.Input("updates: T")
diff --git a/tensorflow/core/ops/array_ops_test.cc b/tensorflow/core/ops/array_ops_test.cc
index b1463338fb..1c29cd2491 100644
--- a/tensorflow/core/ops/array_ops_test.cc
+++ b/tensorflow/core/ops/array_ops_test.cc
@@ -27,6 +27,21 @@ limitations under the License.
namespace tensorflow {
+TEST(ArrayOpsTest, UnravelIndex_ShapeFn) {
+ ShapeInferenceTestOp op("UnravelIndex");
+
+ INFER_OK(op, "?;?", "?");
+
+ INFER_OK(op, "[];[?]", "[d1_0]");
+
+ INFER_OK(op, "[4,5];[?]", "[d1_0,20]");
+ INFER_OK(op, "[2,3,4];[?]", "[d1_0,24]");
+ INFER_OK(op, "?;[?]", "?");
+ INFER_OK(op, "[?];[?]", "[d1_0,?]");
+
+ INFER_ERROR("Shape must be rank 1 but is rank 2", op, "?;[1,1]");
+}
+
TEST(ArrayOpsTest, Pack_ShapeFn) {
ShapeInferenceTestOp op("Pack");
auto set_axis = [&op](int axis) {
@@ -960,6 +975,7 @@ TEST(ArrayOpsTest, Transpose_ShapeFn) {
INFER_OK(op, "?;[2]", "[?,?]");
INFER_OK(op, "[?,?];[2]", "[d0_1,d0_0]");
INFER_OK(op, "[1,?];[2]", "[d0_1,d0_0]");
+ INFER_OK(op, "?;[0]", "in0");
// Invalid arguments.
perm = test::AsTensor<int32>({1, 2});
@@ -1605,6 +1621,24 @@ TEST(ArrayOpsTest, Slice_ShapeFn) {
INFER_ERROR("cannot be < -1", op, "[2,3,4,5];[4];[4]");
}
+TEST(ArrayOpsTest, StridedSlice_ShapeFn) {
+ ShapeInferenceTestOp op("StridedSlice");
+ TF_ASSERT_OK(NodeDefBuilder("test", "StridedSlice")
+ .Input("input", 0, DT_FLOAT)
+ .Input("begin", 1, DT_INT32)
+ .Input("end", 2, DT_INT32)
+ .Input("strides", 3, DT_INT32)
+ .Attr("shrink_axis_mask", 1)
+ .Finalize(&op.node_def));
+ op.input_tensors.resize(4);
+ Tensor strides = test::AsTensor<int32>({1});
+ op.input_tensors[3] = &strides;
+ // Slicing on the 0-th dimension.
+ INFER_OK(op, "[2,3,4,5];[1];[1];[1]", "[3,4,5]");
+ // Slicing on the 0-th dimension. This time one of the result dimensions is 0.
+ INFER_OK(op, "[2,0,3,4];[1];[1];[1]", "[0,3,4]");
+}
+
TEST(ArrayOpsTest, StridedSliceGrad_ShapeFn) {
ShapeInferenceTestOp op("StridedSliceGrad");
op.input_tensors.resize(5);
diff --git a/tensorflow/core/ops/boosted_trees_ops.cc b/tensorflow/core/ops/boosted_trees_ops.cc
index 01452b3e85..b8cf538554 100644
--- a/tensorflow/core/ops/boosted_trees_ops.cc
+++ b/tensorflow/core/ops/boosted_trees_ops.cc
@@ -22,6 +22,10 @@ limitations under the License.
namespace tensorflow {
+using shape_inference::DimensionHandle;
+using shape_inference::InferenceContext;
+using shape_inference::ShapeHandle;
+
REGISTER_RESOURCE_HANDLE_OP(BoostedTreesEnsembleResource);
REGISTER_OP("IsBoostedTreesEnsembleInitialized")
@@ -176,6 +180,8 @@ REGISTER_OP("BoostedTreesMakeStatsSummary")
return Status::OK();
});
+// TODO(nponomareva): when/if creating the new op for unbucketized data, rename
+// bucketized_features to features.
REGISTER_OP("BoostedTreesPredict")
.Input("tree_ensemble_handle: resource")
.Input("bucketized_features: num_bucketized_features * int32")
@@ -354,4 +360,125 @@ REGISTER_OP("BoostedTreesCenterBias")
return Status::OK();
});
+REGISTER_RESOURCE_HANDLE_OP(BoostedTreesQuantileStreamResource);
+
+REGISTER_OP("IsBoostedTreesQuantileStreamResourceInitialized")
+ .Input("quantile_stream_resource_handle: resource")
+ .Output("is_initialized: bool")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesCreateQuantileStreamResource")
+ .Attr("max_elements: int = 1099511627776") // 1 << 40
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("epsilon: float")
+ .Input("num_streams: int64")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesMakeQuantileSummaries")
+ .Attr("num_features: int >= 0")
+ .Input("float_values: num_features * float")
+ .Input("example_weights: float")
+ .Input("epsilon: float")
+ .Output("summaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ ShapeHandle example_weights_shape;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(num_features), 1, &example_weights_shape));
+ for (int i = 0; i < num_features; ++i) {
+ ShapeHandle feature_shape;
+ DimensionHandle unused_dim;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
+ c->Dim(example_weights_shape, 0),
+ &unused_dim));
+ // the columns are value, weight, min_rank, max_rank.
+ c->set_output(i, c->MakeShape({c->UnknownDim(), 4}));
+ }
+ // epsilon must be a scalar.
+ ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(c->input(num_features + 1), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceAddSummaries")
+ .Attr("num_features: int >= 0")
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("summaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ // resource handle must be a scalar.
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ // each summary must be rank 2.
+ for (int i = 1; i < num_features + 1; i++) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &unused_input));
+ }
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceFlush")
+ .Attr("generate_quantiles: bool = False")
+ .Input("quantile_stream_resource_handle: resource")
+ .Input("num_buckets: int64")
+ .SetShapeFn([](InferenceContext* c) {
+ // All the inputs are scalars.
+ shape_inference::ShapeHandle unused_input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused_input));
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesQuantileStreamResourceGetBucketBoundaries")
+ .Attr("num_features: int >= 0")
+ .Input("quantile_stream_resource_handle: resource")
+ .Output("bucket_boundaries: num_features * float")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ shape_inference::ShapeHandle unused_input;
+ // resource handle must be a scalar.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused_input));
+ for (int i = 0; i < num_features; i++) {
+ c->set_output(i, c->Vector(c->UnknownDim()));
+ }
+ return Status::OK();
+ });
+
+REGISTER_OP("BoostedTreesBucketize")
+ .Attr("num_features: int >= 0")
+ .Input("float_values: num_features * float")
+ .Input("bucket_boundaries: num_features * float")
+ .Output("buckets: num_features * int32")
+ .SetShapeFn([](InferenceContext* c) {
+ int num_features;
+ TF_RETURN_IF_ERROR(c->GetAttr("num_features", &num_features));
+ ShapeHandle feature_shape;
+ DimensionHandle unused_dim;
+ for (int i = 0; i < num_features; i++) {
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 2, &feature_shape));
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(feature_shape, 0),
+ c->Dim(c->input(0), 0), &unused_dim));
+ }
+ // The bucketized result should have the same first dimension as the input.
+ for (int i = 0; i < num_features; i++) {
+ c->set_output(i, c->MakeShape({c->Dim(c->input(i), 0), 1}));
+ }
+ return Status::OK();
+ });
+
} // namespace tensorflow
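The quantile ops above form a small pipeline: create the stream resource, add per-feature (value, weight, min_rank, max_rank) summaries, flush into num_buckets boundaries, then bucketize feature values against those boundaries. A much-simplified, single-machine sketch of what the flush/bucketize pair computes, assuming exact quantiles over a plain vector rather than the weighted streaming summaries the real resource maintains:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Exact stand-in for the flush step: pick num_buckets evenly spaced
// quantiles of the observed values as bucket boundaries.
std::vector<float> BucketBoundaries(std::vector<float> values,
                                    int64_t num_buckets) {
  std::sort(values.begin(), values.end());
  std::vector<float> boundaries;
  for (int64_t b = 1; b <= num_buckets; ++b) {
    const std::size_t idx = (values.size() - 1) * b / num_buckets;
    boundaries.push_back(values[idx]);
  }
  return boundaries;
}

// Stand-in for bucketizing one feature value: index of the first boundary
// that is >= the value, clamped so larger values land in the last bucket.
int32_t Bucketize(const std::vector<float>& boundaries, float value) {
  const std::ptrdiff_t idx =
      std::lower_bound(boundaries.begin(), boundaries.end(), value) -
      boundaries.begin();
  const std::ptrdiff_t last =
      static_cast<std::ptrdiff_t>(boundaries.size()) - 1;
  return static_cast<int32_t>(std::min(idx, last));
}

int main() {
  const std::vector<float> feature = {0.1f, 0.4f, 0.2f, 0.9f, 0.7f, 0.5f};
  const std::vector<float> bounds = BucketBoundaries(feature, 3);
  for (float v : feature) std::cout << Bucketize(bounds, v) << " ";
  std::cout << "\n";  // prints: 0 1 0 2 2 1
  return 0;
}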
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index bb48da86ca..780c6f6448 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -11360,6 +11360,29 @@ op {
is_commutative: true
}
op {
+ name: "BoostedTreesBucketize"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ output_arg {
+ name: "buckets"
+ type: DT_INT32
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesCalculateBestGainsPerFeature"
input_arg {
name: "node_id_range"
@@ -11469,6 +11492,29 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesCreateQuantileStreamResource"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "num_streams"
+ type: DT_INT64
+ }
+ attr {
+ name: "max_elements"
+ type: "int"
+ default_value {
+ i: 1099511627776
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesDeserializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -11562,6 +11608,32 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesMakeQuantileSummaries"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "example_weights"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesMakeStatsSummary"
input_arg {
name: "node_ids"
@@ -11631,6 +11703,83 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesQuantileStreamResourceAddSummaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceFlush"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "num_buckets"
+ type: DT_INT64
+ }
+ attr {
+ name: "generate_quantiles"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceHandleOp"
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesSerializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -13070,6 +13219,71 @@ op {
is_stateful: true
}
op {
+ name: "ConditionalAccumulator"
+ output_arg {
+ name: "handle"
+ type: DT_STRING
+ is_ref: true
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "Conj"
input_arg {
name: "input"
@@ -20317,6 +20531,31 @@ op {
}
}
op {
+ name: "DivNoNan"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "DrawBoundingBoxes"
input_arg {
name: "images"
@@ -20865,6 +21104,25 @@ op {
is_stateful: true
}
op {
+ name: "EnsureShape"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
name: "Enter"
input_arg {
name: "data"
@@ -21274,6 +21532,421 @@ op {
}
}
op {
+ name: "ExperimentalAssertNextDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "transformations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalCSVDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "compression_type"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "buffer_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "header"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "field_delim"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "use_quote_delim"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "na_value"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "select_cols"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "record_defaults"
+ type_list_attr: "output_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalDirectedInterleaveDataset"
+ input_arg {
+ name: "selector_input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "data_input_datasets"
+ type: DT_VARIANT
+ number_attr: "N"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalFunctionBufferingResource"
+ input_arg {
+ name: "string_arg"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "target_device"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "buffer_size"
+ type: "int"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceGetNext"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceReset"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIdentityIndexedDataset"
+ input_arg {
+ name: "size"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalIndexedDatasetGet"
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "index"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIndexedDatasetMaterialize"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIteratorGetDevice"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "device"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalLMDBDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalMaterializedIndexDatasetHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "thread_pool"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "num_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ type: "int"
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "display_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalUniqueDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Expm1"
input_arg {
name: "x"
@@ -21644,6 +22317,59 @@ op {
}
}
op {
+ name: "ExtractVolumePatches"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "patches"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksizes"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+}
+op {
name: "FFT"
input_arg {
name: "input"
@@ -22461,33 +23187,6 @@ op {
is_stateful: true
}
op {
- name: "FeatureStatsDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- input_arg {
- name: "tag"
- type: DT_STRING
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "Fill"
input_arg {
name: "dims"
@@ -23821,6 +24520,158 @@ op {
}
}
op {
+ name: "FusedBatchNorm"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "offset"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "mean"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "variance"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "batch_mean"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "batch_variance"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_1"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_2"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
+ name: "FusedBatchNormGrad"
+ input_arg {
+ name: "y_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reserve_space_1"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reserve_space_2"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "x_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "scale_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "offset_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_3"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space_4"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "FusedBatchNormGrad"
input_arg {
name: "y_backprop"
@@ -23884,6 +24735,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -24061,6 +24918,96 @@ op {
}
}
op {
+ name: "FusedBatchNormGradV2"
+ input_arg {
+ name: "y_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "reserve_space_1"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "reserve_space_2"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "x_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "scale_backprop"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "offset_backprop"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_3"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_4"
+ type_attr: "U"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "U"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "FusedBatchNormV2"
input_arg {
name: "x"
@@ -24228,6 +25175,96 @@ op {
}
}
op {
+ name: "FusedBatchNormV2"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scale"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "offset"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "mean"
+ type_attr: "U"
+ }
+ input_arg {
+ name: "variance"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "batch_mean"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "batch_variance"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_1"
+ type_attr: "U"
+ }
+ output_arg {
+ name: "reserve_space_2"
+ type_attr: "U"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "U"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "epsilon"
+ type: "float"
+ default_value {
+ f: 0.0001
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "FusedPadConv2D"
input_arg {
name: "input"
@@ -25613,6 +26650,21 @@ op {
}
}
op {
+ name: "HostConst"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "value"
+ type: "tensor"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+}
+op {
name: "IFFT"
input_arg {
name: "input"
@@ -26018,6 +27070,52 @@ op {
is_stateful: true
}
op {
+ name: "If"
+ input_arg {
+ name: "cond"
+ type_attr: "Tcond"
+ }
+ input_arg {
+ name: "input"
+ type_list_attr: "Tin"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "Tcond"
+ type: "type"
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "then_branch"
+ type: "func"
+ }
+ attr {
+ name: "else_branch"
+ type: "func"
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "Igamma"
input_arg {
name: "a"
@@ -27095,6 +28193,18 @@ op {
is_stateful: true
}
op {
+ name: "IsBoostedTreesQuantileStreamResourceInitialized"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "is_initialized"
+ type: DT_BOOL
+ }
+ is_stateful: true
+}
+op {
name: "IsFinite"
input_arg {
name: "x"
@@ -29130,6 +30240,38 @@ op {
}
}
op {
+ name: "LowerBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "MakeIterator"
input_arg {
name: "dataset"
@@ -29349,11 +30491,91 @@ op {
}
}
op {
+ name: "MapDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
+ name: "MapDefun"
+ input_arg {
+ name: "arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+}
+op {
name: "MapDefun"
input_arg {
name: "arguments"
type_list_attr: "Targuments"
}
+ input_arg {
+ name: "captured_inputs"
+ type_list_attr: "Tcaptured"
+ }
output_arg {
name: "output"
type_list_attr: "output_types"
@@ -29365,6 +30587,15 @@ op {
minimum: 1
}
attr {
+ name: "Tcaptured"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
name: "output_types"
type: "list(type)"
has_minimum: true
@@ -29976,6 +31207,32 @@ op {
}
}
op {
+ name: "MatrixExponential"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_DOUBLE
+ type: DT_FLOAT
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ deprecation {
+ version: 27
+ }
+}
+op {
name: "MatrixInverse"
input_arg {
name: "input"
@@ -34784,6 +36041,29 @@ op {
}
}
op {
+ name: "ModelDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Mul"
input_arg {
name: "x"
@@ -34891,6 +36171,134 @@ op {
is_commutative: true
}
op {
+ name: "MultiDeviceIterator"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "devices"
+ type: "list(string)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorFromStringHandle"
+ input_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorGetNextFromShard"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "shard_num"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorInit"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "max_buffer_size"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorToStringHandle"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
name: "Multinomial"
input_arg {
name: "logits"
@@ -35624,6 +37032,42 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV2"
+ input_arg {
+ name: "boxes"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scores"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+}
+op {
name: "NonMaxSuppressionV3"
input_arg {
name: "boxes"
@@ -35651,6 +37095,46 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV3"
+ input_arg {
+ name: "boxes"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scores"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+}
+op {
name: "NonMaxSuppressionV4"
input_arg {
name: "boxes"
@@ -35689,6 +37173,57 @@ op {
}
}
op {
+ name: "NonMaxSuppressionV4"
+ input_arg {
+ name: "boxes"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "scores"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "max_output_size"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "iou_threshold"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "score_threshold"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "selected_indices"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "valid_outputs"
+ type: DT_INT32
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
+ name: "pad_to_max_output_size"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "NonMaxSuppressionWithOverlaps"
input_arg {
name: "overlaps"
@@ -36979,6 +38514,54 @@ op {
}
}
op {
+ name: "ParallelInterleaveDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "cycle_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "block_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ParallelMapDataset"
input_arg {
name: "input_dataset"
@@ -37060,6 +38643,53 @@ op {
}
}
op {
+ name: "ParallelMapDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "ParameterizedTruncatedNormal"
input_arg {
name: "shape"
@@ -37269,6 +38899,271 @@ op {
}
}
op {
+ name: "ParseExampleDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "dense_defaults"
+ type_list_attr: "Tdense"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "sparse_types"
+ type: "list(type)"
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "Tdense"
+ type: "list(type)"
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "dense_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ParseSequenceExample"
+ input_arg {
+ name: "serialized"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "debug_name"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "context_dense_defaults"
+ type_list_attr: "Tcontext_dense"
+ }
+ output_arg {
+ name: "context_sparse_indices"
+ type: DT_INT64
+ number_attr: "Ncontext_sparse"
+ }
+ output_arg {
+ name: "context_sparse_values"
+ type_list_attr: "context_sparse_types"
+ }
+ output_arg {
+ name: "context_sparse_shapes"
+ type: DT_INT64
+ number_attr: "Ncontext_sparse"
+ }
+ output_arg {
+ name: "context_dense_values"
+ type_list_attr: "Tcontext_dense"
+ }
+ output_arg {
+ name: "feature_list_sparse_indices"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
+ name: "feature_list_sparse_values"
+ type_list_attr: "feature_list_sparse_types"
+ }
+ output_arg {
+ name: "feature_list_sparse_shapes"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
+ name: "feature_list_dense_values"
+ type_list_attr: "feature_list_dense_types"
+ }
+ output_arg {
+ name: "feature_list_dense_lengths"
+ type: DT_INT64
+ number_attr: "Nfeature_list_dense"
+ }
+ attr {
+ name: "feature_list_dense_missing_assumed_empty"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "context_sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "context_dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "Ncontext_sparse"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Ncontext_dense"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Nfeature_list_sparse"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Nfeature_list_dense"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "context_sparse_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "Tcontext_dense"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "feature_list_dense_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "context_dense_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_sparse_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "feature_list_dense_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+}
+op {
name: "ParseSingleExample"
input_arg {
name: "serialized"
@@ -38043,6 +39938,30 @@ op {
is_stateful: true
}
op {
+ name: "PrintV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ attr {
+ name: "output_stream"
+ type: "string"
+ default_value {
+ s: "stderr"
+ }
+ allowed_values {
+ list {
+ s: "stdout"
+ s: "stderr"
+ s: "log(info)"
+ s: "log(warning)"
+ s: "log(error)"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "PriorityQueue"
output_arg {
name: "handle"
@@ -43444,6 +45363,59 @@ op {
is_stateful: true
}
op {
+ name: "ReduceDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "initial_state"
+ type_list_attr: "Tstate"
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Tstate"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "ReduceJoin"
input_arg {
name: "inputs"
@@ -43804,6 +45776,38 @@ op {
}
}
op {
+ name: "Relu"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ type: DT_QINT8
+ }
+ }
+ }
+}
+op {
name: "Relu6"
input_arg {
name: "features"
@@ -56310,6 +58314,125 @@ op {
}
}
op {
+ name: "SdcaOptimizer"
+ input_arg {
+ name: "sparse_example_indices"
+ type: DT_INT64
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "sparse_feature_indices"
+ type: DT_INT64
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "sparse_feature_values"
+ type: DT_FLOAT
+ number_attr: "num_sparse_features_with_values"
+ }
+ input_arg {
+ name: "dense_features"
+ type: DT_FLOAT
+ number_attr: "num_dense_features"
+ }
+ input_arg {
+ name: "example_weights"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "example_labels"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "sparse_indices"
+ type: DT_INT64
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "sparse_weights"
+ type: DT_FLOAT
+ number_attr: "num_sparse_features"
+ }
+ input_arg {
+ name: "dense_weights"
+ type: DT_FLOAT
+ number_attr: "num_dense_features"
+ }
+ input_arg {
+ name: "example_state_data"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "out_example_state_data"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "out_delta_sparse_weights"
+ type: DT_FLOAT
+ number_attr: "num_sparse_features"
+ }
+ output_arg {
+ name: "out_delta_dense_weights"
+ type: DT_FLOAT
+ number_attr: "num_dense_features"
+ }
+ attr {
+ name: "loss_type"
+ type: "string"
+ allowed_values {
+ list {
+ s: "logistic_loss"
+ s: "squared_loss"
+ s: "hinge_loss"
+ s: "smooth_hinge_loss"
+ s: "poisson_loss"
+ }
+ }
+ }
+ attr {
+ name: "adaptative"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "num_sparse_features"
+ type: "int"
+ has_minimum: true
+ }
+ attr {
+ name: "num_sparse_features_with_values"
+ type: "int"
+ has_minimum: true
+ }
+ attr {
+ name: "num_dense_features"
+ type: "int"
+ has_minimum: true
+ }
+ attr {
+ name: "l1"
+ type: "float"
+ }
+ attr {
+ name: "l2"
+ type: "float"
+ }
+ attr {
+ name: "num_loss_partitions"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "num_inner_iterations"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "SdcaShrinkL1"
input_arg {
name: "weights"
@@ -57708,6 +59831,14 @@ op {
name: "stats_aggregator"
type: DT_RESOURCE
}
+ input_arg {
+ name: "tag"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "counter_prefix"
+ type: DT_STRING
+ }
output_arg {
name: "handle"
type: DT_VARIANT
@@ -58860,6 +60991,29 @@ op {
}
}
op {
+ name: "Softplus"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "SoftplusGrad"
input_arg {
name: "gradients"
@@ -58996,6 +61150,33 @@ op {
}
}
op {
+ name: "SoftplusGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "Softsign"
input_arg {
name: "features"
@@ -59116,6 +61297,29 @@ op {
}
}
op {
+ name: "Softsign"
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "activations"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "SoftsignGrad"
input_arg {
name: "gradients"
@@ -59252,6 +61456,33 @@ op {
}
}
op {
+ name: "SoftsignGrad"
+ input_arg {
+ name: "gradients"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "features"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "backprops"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_BFLOAT16
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "SpaceToBatch"
input_arg {
name: "input"
@@ -64026,6 +66257,71 @@ op {
is_stateful: true
}
op {
+ name: "SparseConditionalAccumulator"
+ output_arg {
+ name: "handle"
+ type: DT_STRING
+ is_ref: true
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "SparseCross"
input_arg {
name: "indices"
@@ -68819,6 +71115,47 @@ op {
}
}
op {
+ name: "StaticRegexFullMatch"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_BOOL
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+}
+op {
+ name: "StaticRegexReplace"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+ attr {
+ name: "rewrite"
+ type: "string"
+ }
+ attr {
+ name: "replace_global"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "StatsAggregatorHandle"
output_arg {
name: "handle"
@@ -69094,6 +71431,43 @@ op {
}
}
op {
+ name: "StringFormat"
+ input_arg {
+ name: "inputs"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "template"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "placeholder"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "summarize"
+ type: "int"
+ default_value {
+ i: 3
+ }
+ }
+}
+op {
name: "StringJoin"
input_arg {
name: "inputs"
@@ -69119,6 +71493,41 @@ op {
}
}
op {
+ name: "StringLength"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+}
+op {
+ name: "StringLength"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
+}
+op {
name: "StringSplit"
input_arg {
name: "input"
@@ -69481,6 +71890,48 @@ op {
}
}
op {
+ name: "Substr"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "pos"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "len"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
+}
+op {
name: "Sum"
input_arg {
name: "input"
@@ -71643,6 +74094,25 @@ op {
}
}
op {
+ name: "TensorListGather"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "values"
+ type_attr: "element_dtype"
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+}
+op {
name: "TensorListGetItem"
input_arg {
name: "input_handle"
@@ -71759,6 +74229,39 @@ op {
}
}
op {
+ name: "TensorListScatter"
+ input_arg {
+ name: "tensor"
+ type_attr: "element_dtype"
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "element_shape"
+ type_attr: "shape_type"
+ }
+ output_arg {
+ name: "output_handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape_type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "TensorListSetItem"
input_arg {
name: "input_handle"
@@ -73018,6 +75521,17 @@ op {
}
}
op {
+ name: "UnicodeScript"
+ input_arg {
+ name: "input"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+}
+op {
name: "UniformCandidateSampler"
input_arg {
name: "true_classes"
@@ -74110,6 +76624,38 @@ op {
is_stateful: true
}
op {
+ name: "UpperBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "VarHandleOp"
output_arg {
name: "resource"
@@ -74394,6 +76940,39 @@ op {
is_stateful: true
}
op {
+ name: "While"
+ input_arg {
+ name: "input"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "cond"
+ type: "func"
+ }
+ attr {
+ name: "body"
+ type: "func"
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "WholeFileReader"
output_arg {
name: "reader_handle"
@@ -74445,9 +77024,21 @@ op {
type: DT_VARIANT
}
input_arg {
- name: "window_size"
+ name: "size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "shift"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "stride"
type: DT_INT64
}
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
output_arg {
name: "handle"
type: DT_VARIANT
@@ -74684,6 +77275,62 @@ op {
is_stateful: true
}
op {
+ name: "Xdivy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
+ name: "Xlogy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "ZerosLike"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc
index f78f7a897a..f84142c992 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops.cc
@@ -37,7 +37,6 @@ using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
-
REGISTER_OP("CudnnRNNParamsSize")
.Input("num_layers: int32")
.Input("num_units: int32")
@@ -52,11 +51,16 @@ REGISTER_OP("CudnnRNNParamsSize")
.Attr("seed2: int = 0")
.Output("params_size: S")
.SetShapeFn([](InferenceContext* c) {
+ ShapeHandle unused;
+ // num_layers, num_units, and input_size should be scalars.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+
c->set_output(0, c->Vector(1));
return Status::OK();
});
-
REGISTER_OP("CudnnRNN")
.Input("input: T")
.Input("input_h: T")
@@ -248,7 +252,6 @@ REGISTER_OP("CudnnRNNParamsToCanonical")
return Status::OK();
});
-
REGISTER_OP("CudnnRNNCanonicalToParams")
.Input("num_layers: int32")
.Input("num_units: int32")
diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
index 2dd867561b..13c3b933f4 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
@@ -26,7 +26,16 @@ namespace tensorflow {
TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) {
ShapeInferenceTestOp op("CudnnRNNParamsSize");
- INFER_OK(op, "[1];[1];[1]", "[1]");
+ INFER_OK(op, "[];[];[]", "[1]");
+ INFER_OK(op, "?;[];[]", "[1]");
+ INFER_OK(op, "[];?;[]", "[1]");
+ INFER_OK(op, "[];[];?", "[1]");
+ INFER_OK(op, "[];?;?", "[1]");
+ INFER_OK(op, "?;?;?", "[1]");
+
+ INFER_ERROR("Shape must be rank 0 ", op, "[1,2];?;[]");
+ INFER_ERROR("Shape must be rank 0 ", op, "?;[2];[]");
+ INFER_ERROR("Shape must be rank 0 ", op, "?;?;[1]");
}
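For reference, in these shape strings "[]" is a known scalar, "?" is an input of unknown rank (WithRank still succeeds, since the rank may later resolve to 0), and "[1]" is a length-1 vector, which the new checks reject. One more case in the same style (a sketch, not part of this change) would be:

  INFER_ERROR("Shape must be rank 0 ", op, "[];[];[1,2]");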
TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
diff --git a/tensorflow/core/ops/data_flow_ops.cc b/tensorflow/core/ops/data_flow_ops.cc
index eed0bce174..ffab8ad661 100644
--- a/tensorflow/core/ops/data_flow_ops.cc
+++ b/tensorflow/core/ops/data_flow_ops.cc
@@ -419,6 +419,7 @@ REGISTER_OP("ConditionalAccumulator")
.Attr("shape: shape")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
+ .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(2));
@@ -456,6 +457,7 @@ REGISTER_OP("SparseConditionalAccumulator")
.Attr("shape: shape")
.Attr("container: string = ''")
.Attr("shared_name: string = ''")
+ .Attr("reduction_type: { 'MEAN', 'SUM' } = 'MEAN' ")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(2));
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 13733d48f0..ec22eee874 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -166,21 +166,27 @@ REGISTER_OP("LatencyStatsDataset")
return shape_inference::ScalarShape(c);
});
-REGISTER_OP("FeatureStatsDataset")
+REGISTER_OP("ParseExampleDataset")
.Input("input_dataset: variant")
- .Input("tag: string")
+ .Input("num_parallel_calls: int64")
+ .Input("dense_defaults: Tdense")
.Output("handle: variant")
+ .Attr("sparse_keys: list(string) >= 0")
+ .Attr("dense_keys: list(string) >= 0")
+ .Attr("sparse_types: list({float,int64,string}) >= 0")
+ .Attr("Tdense: list({float,int64,string}) >= 0")
+ .Attr("dense_shapes: list(shape) >= 0")
.Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle tag_shape;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &tag_shape));
- return shape_inference::ScalarShape(c);
- });
+ .Attr("output_shapes: list(shape) >= 1") // Output components will be
+ // sorted by key (dense_keys and
+ // sparse_keys combined) here.
+ .SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("SetStatsAggregatorDataset")
.Input("input_dataset: variant")
.Input("stats_aggregator: resource")
+ .Input("tag: string")
+ .Input("counter_prefix: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
@@ -194,6 +200,7 @@ REGISTER_OP("MapDataset")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
+ .Attr("use_inter_op_parallelism: bool = true")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("ParallelMapDataset")
@@ -205,6 +212,7 @@ REGISTER_OP("ParallelMapDataset")
.Attr("Targuments: list(type) >= 0")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
+ .Attr("use_inter_op_parallelism: bool = true")
.SetShapeFn(shape_inference::ScalarShape);
REGISTER_OP("MapAndBatchDataset")
@@ -321,6 +329,19 @@ REGISTER_OP("ParallelInterleaveDataset")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ParallelInterleaveDatasetV2")
+ .Input("input_dataset: variant")
+ .Input("other_arguments: Targuments")
+ .Input("cycle_length: int64")
+ .Input("block_length: int64")
+ .Input("num_parallel_calls: int64")
+ .Output("handle: variant")
+ .Attr("f: func")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("GroupByReducerDataset")
.Input("input_dataset: variant")
.Input("key_func_other_arguments: Tkey_func_other_arguments")
@@ -377,14 +398,20 @@ REGISTER_OP("FilterByLastComponentDataset")
REGISTER_OP("WindowDataset")
.Input("input_dataset: variant")
- .Input("window_size: int64")
+ .Input("size: int64")
+ .Input("shift: int64")
+ .Input("stride: int64")
+ .Input("drop_remainder: bool")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn([](shape_inference::InferenceContext* c) {
shape_inference::ShapeHandle unused;
- // batch_size should be a scalar.
+ // size, shift, stride, and drop_remainder should be scalars.
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
return shape_inference::ScalarShape(c);
});
@@ -731,6 +758,19 @@ REGISTER_OP("DatasetToSingleElement")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(IteratorGetNextShapeFn);
+REGISTER_OP("ReduceDataset")
+ .Input("input_dataset: variant")
+ .Input("initial_state: Tstate")
+ .Input("other_arguments: Targuments")
+ .Output("components: output_types")
+ .Attr("f: func")
+ .Attr("Tstate: list(type) >= 1")
+ .Attr("Targuments: list(type) >= 0")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Attr("use_inter_op_parallelism: bool = true")
+ .SetShapeFn(IteratorGetNextShapeFn);
+
REGISTER_OP("IteratorToStringHandle")
.Input("resource_handle: resource")
.Output("string_handle: string")
@@ -854,16 +894,27 @@ REGISTER_OP("IteratorGetNextAsOptional")
.Attr("output_shapes: list(shape) >= 1")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ModelDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
REGISTER_OP("MapDefun")
.Input("arguments: Targuments")
+ .Input("captured_inputs: Tcaptured")
.Output("output: output_types")
.Attr("Targuments: list(type) >= 1")
+ .Attr("Tcaptured: list(type) >= 0 = []")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("f: func")
.SetShapeFn([](shape_inference::InferenceContext* c) {
- std::vector<TensorShape> output_shapes;
+ std::vector<PartialTensorShape> output_shapes;
TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ DataTypeVector t_args;
+ TF_RETURN_IF_ERROR(c->GetAttr("Targuments", &t_args));
if (output_shapes.size() != c->num_outputs()) {
return errors::InvalidArgument(
"`output_shapes` must be the same length as `output_types` (",
@@ -871,14 +922,19 @@ REGISTER_OP("MapDefun")
}
int64 dim_zero = -1;
- for (size_t i = 0; i < static_cast<size_t>(c->num_inputs()); ++i) {
+ for (size_t i = 0; i < t_args.size(); ++i) {
+ if (c->Rank(c->input(i)) == 0) {
+ return errors::InvalidArgument(
+ "Arguments must have rank at least 1. Input ", i,
+ " has rank of 0.");
+ }
auto dim_handle = c->Dim(c->input(i), 0);
if (c->ValueKnown(dim_handle)) {
if (dim_zero == -1) {
dim_zero = c->Value(dim_handle);
} else if (c->Value(dim_handle) != dim_zero) {
return errors::InvalidArgument(
- "Inputs must have the same dimension 0.");
+ "Arguments must have the same dimension 0.");
}
}
}
@@ -896,4 +952,41 @@ REGISTER_OP("MapDefun")
return Status::OK();
});
+REGISTER_OP("MultiDeviceIterator")
+ .Output("handle: resource")
+ .Attr("devices: list(string) >= 1")
+ .Attr("shared_name: string")
+ .Attr("container: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("MultiDeviceIteratorInit")
+ .Input("dataset: variant")
+ .Input("multi_device_iterator: resource")
+ .Input("max_buffer_size: int64")
+ .Output("incarnation_id: int64")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("MultiDeviceIteratorGetNextFromShard")
+ .Input("multi_device_iterator: resource")
+ .Input("shard_num: int32")
+ .Input("incarnation_id: int64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(IteratorGetNextShapeFn);
+
+REGISTER_OP("MultiDeviceIteratorToStringHandle")
+ .Input("multi_device_iterator: resource")
+ .Output("string_handle: string")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("MultiDeviceIteratorFromStringHandle")
+ .Input("string_handle: string")
+ .Output("multi_device_iterator: resource")
+ .Attr("output_types: list(type) >= 0 = []")
+ .Attr("output_shapes: list(shape) >= 0 = []")
+ .SetShapeFn(shape_inference::ScalarShape);
+
} // namespace tensorflow
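A sketch (hypothetical test, not part of this commit) of how the reworked WindowDataset signature could be exercised with the same shape-inference test utilities used above for CudnnRNNParamsSize:

#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

TEST(DatasetOpsTest, WindowDataset_ShapeFn) {
  ShapeInferenceTestOp op("WindowDataset");
  // input_dataset plus four scalar arguments yield a scalar variant handle.
  INFER_OK(op, "[];[];[];[];[]", "[]");
  INFER_OK(op, "?;?;?;?;?", "[]");
  // A non-scalar size/shift/stride/drop_remainder is rejected.
  INFER_ERROR("Shape must be rank 0 ", op, "[];[1];[];[];[]");
}

}  // namespace tensorflow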
diff --git a/tensorflow/core/ops/experimental_dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc
new file mode 100644
index 0000000000..f6bd5dce26
--- /dev/null
+++ b/tensorflow/core/ops/experimental_dataset_ops.cc
@@ -0,0 +1,207 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+
+namespace tensorflow {
+
+REGISTER_OP("ExperimentalDirectedInterleaveDataset")
+ .Input("selector_input_dataset: variant")
+ .Input("data_input_datasets: N * variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .Attr("N: int >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalCSVDataset")
+ .Input("filenames: string")
+ .Input("compression_type: string")
+ .Input("buffer_size: int64")
+ .Input("header: bool")
+ .Input("field_delim: string")
+ .Input("use_quote_delim: bool")
+ .Input("na_value: string")
+ .Input("select_cols: int64")
+ .Input("record_defaults: output_types")
+ .Output("handle: variant")
+ .Attr("output_types: list({float,double,int32,int64,string}) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // `filenames` must be a scalar or a vector.
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 1, &unused));
+ // `compression_type`, `buffer_size`, `header`, `field_delim`,
+ // `use_quote_delim`, `na_value` must be scalars
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
+ // `select_cols` must be a vector
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 1, &unused));
+ // `record_defaults` must be lists of scalars
+ for (size_t i = 8; i < c->num_inputs(); ++i) {
+ shape_inference::ShapeHandle v;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+ if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
+ return errors::InvalidArgument(
+ "Shape of a default must be a length-0 or length-1 vector, or a "
+ "scalar.");
+ }
+ }
+ return shape_inference::ScalarShape(c);
+ });
+
+REGISTER_OP("ExperimentalIgnoreErrorsDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalUniqueDataset")
+ .Input("input_dataset: variant")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalIteratorGetDevice")
+ .Input("resource: resource")
+ .Output("device: string")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResource")
+ .Input("string_arg: string")
+ .Input("target_device: string")
+ .Output("resource: resource")
+ .Attr("shared_name: string")
+ .Attr("container: string")
+ .Attr("f: func")
+ .Attr("buffer_size: int")
+ .Attr("output_types: list(type)")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResourceGetNext")
+ .Input("function_buffer_resource: resource")
+ .Attr("output_types: list(type)")
+ .Output("output: output_types")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResourceReset")
+ .Input("function_buffer_resource: resource")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalThreadPoolDataset")
+ .Input("input_dataset: variant")
+ .Input("thread_pool: resource")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalThreadPoolHandle")
+ .Output("handle: resource")
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Attr("num_threads: int")
+ .Attr("max_intra_op_parallelism: int = 1")
+ .Attr("display_name: string")
+ .Attr("container: string = ''")
+ .Attr("shared_name: string = ''");
+
+REGISTER_OP("ExperimentalAssertNextDataset")
+ .Input("input_dataset: variant")
+ .Input("transformations: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ // transformations should be a vector.
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
+ return shape_inference::ScalarShape(c);
+ });
+
+REGISTER_OP("ExperimentalLMDBDataset")
+ .Input("filenames: string")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape);
+
+REGISTER_OP("ExperimentalIdentityIndexedDataset")
+ .Input("size: uint64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(
+ shape_inference::ScalarShape); // TODO(saeta): check input shapes.
+
+///////////////////////////////////////////////////////////////////////////////
+// IndexedDataset Internals
+///////////////////////////////////////////////////////////////////////////////
+
+// Creates the handle.
+REGISTER_OP("ExperimentalMaterializedIndexDatasetHandle")
+ .Output("handle: resource")
+ .Attr("container: string")
+ .Attr("shared_name: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+// Actually materialize the materialize handle.
+REGISTER_OP("ExperimentalIndexedDatasetMaterialize")
+ .Input("dataset: variant")
+ .Input("materialized: resource")
+ .SetShapeFn(shape_inference::NoOutputs);
+
+namespace {
+
+Status GetShapeFn(shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+REGISTER_OP("ExperimentalIndexedDatasetGet")
+ .Input("materialized: resource")
+ .Input("index: uint64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(GetShapeFn);
+
+} // namespace tensorflow
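A sketch in the same vein (hypothetical test, not part of this change) for the vector-rank check in ExperimentalAssertNextDataset:

#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

TEST(ExperimentalDatasetOpsTest, AssertNextDataset_ShapeFn) {
  ShapeInferenceTestOp op("ExperimentalAssertNextDataset");
  INFER_OK(op, "[];[2]", "[]");
  // `transformations` must be a vector, so a scalar is rejected.
  INFER_ERROR("Shape must be rank 1 ", op, "[];[]");
}

}  // namespace tensorflow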
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index bda4a75c5d..22b4b07eff 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -110,8 +110,27 @@ REGISTER_OP("If")
.Attr("Tout: list(type) >= 0")
.Attr("then_branch: func")
.Attr("else_branch: func")
+ .Attr("output_shapes: list(shape) = []")
.SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape);
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ // If `output_shapes` attr is set use that as the shapes of the outputs
+ // else return unknown shapes.
+ if (output_shapes.empty()) return shape_inference::UnknownShape(c);
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as num outputs (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+ });
// TODO(drpng): remove this.
REGISTER_OP("_While")
@@ -150,10 +169,29 @@ REGISTER_OP("While")
.Attr("T: list(type) >= 0")
.Attr("cond: func")
.Attr("body: func")
+ .Attr("output_shapes: list(shape) = []")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
- for (int i = 0; i < c->num_outputs(); ++i) {
- c->set_output(i, c->input(i));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ // If `output_shapes` attr is set use that as the shapes of the outputs
+ // else use the input shapes.
+ if (!output_shapes.empty()) {
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as num outputs (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ } else {
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->input(i));
+ }
}
return Status::OK();
});
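// Hedged note on the two shape functions above: `output_shapes` defaults to an
// empty list, so graphs that do not set it keep the previous behavior
// (unknown output shapes for If, input shapes passed through for While);
// inference only changes when a caller supplies one PartialTensorShape per
// output, in which case mismatched list lengths are reported as errors.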
diff --git a/tensorflow/core/ops/image_ops.cc b/tensorflow/core/ops/image_ops.cc
index 81f324a3ef..5427275284 100644
--- a/tensorflow/core/ops/image_ops.cc
+++ b/tensorflow/core/ops/image_ops.cc
@@ -108,6 +108,29 @@ Status ColorspaceShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status NMSShapeFn(InferenceContext* c) {
+ // Get inputs and validate ranks.
+ ShapeHandle boxes;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
+ ShapeHandle scores;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
+ ShapeHandle max_output_size;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
+ ShapeHandle iou_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
+ ShapeHandle score_threshold;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
+ // boxes is a 2-D float Tensor of shape [num_boxes, 4].
+ DimensionHandle unused;
+ // boxes[0] and scores[0] must both equal num_boxes.
+ TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
+ // boxes[1] must be 4.
+ TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
+
+ c->set_output(0, c->Vector(c->UnknownDim()));
+ return Status::OK();
+}
+
} // namespace
// --------------------------------------------------------------------------
@@ -660,11 +683,12 @@ REGISTER_OP("NonMaxSuppression")
});
REGISTER_OP("NonMaxSuppressionV2")
- .Input("boxes: float")
- .Input("scores: float")
+ .Input("boxes: T")
+ .Input("scores: T")
.Input("max_output_size: int32")
.Input("iou_threshold: float")
.Output("selected_indices: int32")
+ .Attr("T: {half, float} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
// Get inputs and validate ranks.
ShapeHandle boxes;
@@ -688,66 +712,36 @@ REGISTER_OP("NonMaxSuppressionV2")
});
REGISTER_OP("NonMaxSuppressionV3")
- .Input("boxes: float")
- .Input("scores: float")
+ .Input("boxes: T")
+ .Input("scores: T")
.Input("max_output_size: int32")
.Input("iou_threshold: float")
.Input("score_threshold: float")
.Output("selected_indices: int32")
- .SetShapeFn([](InferenceContext* c) {
- // Get inputs and validate ranks.
- ShapeHandle boxes;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
- ShapeHandle scores;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
- ShapeHandle max_output_size;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
- ShapeHandle iou_threshold;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
- ShapeHandle score_threshold;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
- // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
- DimensionHandle unused;
- // The boxes[0] and scores[0] are both num_boxes.
- TF_RETURN_IF_ERROR(
- c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
- // The boxes[1] is 4.
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
-
- c->set_output(0, c->Vector(c->UnknownDim()));
- return Status::OK();
- });
+ .Attr("T: {half, float} = DT_FLOAT")
+ .SetShapeFn(NMSShapeFn);
REGISTER_OP("NonMaxSuppressionV4")
- .Input("boxes: float")
- .Input("scores: float")
+ .Input("boxes: T")
+ .Input("scores: T")
.Input("max_output_size: int32")
.Input("iou_threshold: float")
.Input("score_threshold: float")
.Output("selected_indices: int32")
.Output("valid_outputs: int32")
+ .Attr("T: {half, float} = DT_FLOAT")
.Attr("pad_to_max_output_size: bool = false")
.SetShapeFn([](InferenceContext* c) {
- // Get inputs and validate ranks.
- ShapeHandle boxes;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
- ShapeHandle scores;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
- ShapeHandle max_output_size;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
- ShapeHandle iou_threshold;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
- ShapeHandle score_threshold;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
- // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
- DimensionHandle unused;
- // The boxes[0] and scores[0] are both num_boxes.
- TF_RETURN_IF_ERROR(
- c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
- // The boxes[1] is 4.
- TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
-
- c->set_output(0, c->Vector(c->UnknownDim()));
+ TF_RETURN_IF_ERROR(NMSShapeFn(c));
+
+ bool pad_to_max;
+ TF_RETURN_IF_ERROR(c->GetAttr("pad_to_max_output_size", &pad_to_max));
+ if (pad_to_max) {
+ // If padded, overwrite the shape of the output to be static.
+ DimensionHandle output_dim;
+ TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &output_dim));
+ c->set_output(0, c->MakeShape({output_dim}));
+ }
c->set_output(1, c->MakeShape({}));
return Status::OK();
});
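A sketch (hypothetical test, not part of this commit) of the shared NMSShapeFn behavior introduced above, using the same shape-inference test utilities:

#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

TEST(ImageOpsTest, NonMaxSuppressionV3_ShapeFn) {
  ShapeInferenceTestOp op("NonMaxSuppressionV3");
  // boxes [num_boxes, 4] and scores [num_boxes] give a dynamically sized
  // vector of selected indices.
  INFER_OK(op, "[10,4];[10];[];[];[]", "[?]");
  // The last dimension of boxes must be 4.
  INFER_ERROR("Dimension must be 4 but is 3", op, "[10,3];[10];[];[];[]");
}

}  // namespace tensorflow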
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index f37f79ddbf..1d4d51a25d 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -235,6 +235,8 @@ REGISTER_OP("MatrixInverse")
.SetShapeFn(BatchUnchangedSquareShapeFn);
REGISTER_OP("MatrixExponential")
+ .Deprecated(
+ 27, "Use Python implementation tf.linalg.matrix_exponential instead.")
.Input("input: T")
.Output("output: T")
.Attr("T: {double, float, complex64, complex128}")
diff --git a/tensorflow/core/ops/list_ops.cc b/tensorflow/core/ops/list_ops.cc
index b9f94ba1c5..7d79df9c1c 100644
--- a/tensorflow/core/ops/list_ops.cc
+++ b/tensorflow/core/ops/list_ops.cc
@@ -210,7 +210,8 @@ REGISTER_OP("TensorListFromTensor")
shape_inference::ShapeHandle o;
TF_RETURN_IF_ERROR(c->Subshape(s, 1, &o));
shape_inference::ShapeHandle element_shape;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &element_shape));
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(
+ 1, &element_shape));
TF_RETURN_IF_ERROR(c->Merge(o, element_shape, &o));
c->set_output_handle_shapes_and_types(
0, std::vector<shape_inference::ShapeAndType>{{element_shape, t}});
@@ -240,7 +241,8 @@ REGISTER_OP("TensorListReserve")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->Scalar());
shape_inference::ShapeHandle s;
- TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(0, &s));
DataType t;
TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
c->set_output_handle_shapes_and_types(
@@ -295,6 +297,51 @@ REGISTER_OP("TensorListSetItem")
return Status::OK();
});
+REGISTER_OP("TensorListGather")
+ .Input("input_handle: variant")
+ .Input("indices: int32")
+ .Output("values: element_dtype")
+ .Attr("element_dtype: type")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ shape_inference::ShapeHandle element_shape = c->UnknownShape();
+ if (handle_data != nullptr) {
+ const shape_inference::ShapeAndType& list_shape_type =
+ (*handle_data)[0];
+ element_shape = list_shape_type.shape;
+ if (list_shape_type.dtype != t) {
+ return errors::InvalidArgument("Expected list with element dtype ",
+ DataTypeString(t),
+ " but got list with element dtype ",
+ DataTypeString(list_shape_type.dtype));
+ }
+ }
+ shape_inference::ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->Concatenate(c->input(1), element_shape, &out));
+ c->set_output(0, out);
+ return Status::OK();
+ });
+
+REGISTER_OP("TensorListScatter")
+ .Input("tensor: element_dtype")
+ .Input("indices: int32")
+ .Input("element_shape: shape_type")
+ .Output("output_handle: variant")
+ .Attr("element_dtype: type")
+ .Attr("shape_type: {int32, int64}")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("element_dtype", &t));
+ shape_inference::ShapeHandle s;
+ TF_RETURN_IF_ERROR(
+ c->MakeShapeFromShapeTensorTreatScalarAsUnknownShape(2, &s));
+ c->set_output_handle_shapes_and_types(0, {{s, t}});
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
REGISTER_OP("TensorListConcatLists")
.Input("input_a: variant")
.Input("input_b: variant")
diff --git a/tensorflow/core/ops/logging_ops.cc b/tensorflow/core/ops/logging_ops.cc
index fbde692e95..2034d3601b 100644
--- a/tensorflow/core/ops/logging_ops.cc
+++ b/tensorflow/core/ops/logging_ops.cc
@@ -14,11 +14,14 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/dataset_stateful_op_whitelist.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
+using shape_inference::InferenceContext;
+
REGISTER_OP("Assert")
.Input("condition: bool")
.Input("data: T")
@@ -27,6 +30,8 @@ REGISTER_OP("Assert")
.Attr("summarize: int = 3")
.SetShapeFn(shape_inference::NoOutputs);
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Assert");
+
REGISTER_OP("Print")
.Input("input: T")
.Input("data: U")
@@ -39,6 +44,25 @@ REGISTER_OP("Print")
.Attr("summarize: int = 3")
.SetShapeFn(shape_inference::UnchangedShape);
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Print");
+
+REGISTER_OP("PrintV2")
+ .Input("input: string")
+ .SetIsStateful()
+ .Attr(
+ "output_stream: {'stdout', 'stderr', 'log(info)', "
+ "'log(warning)', 'log(error)'} = 'stderr'")
+ .SetShapeFn([](InferenceContext* c) {
+ // Make sure that the input is a scalar.
+ if (c->Rank(c->input(0)) != 0) {
+ return errors::InvalidArgument("input must be a scalar, but has rank: ",
+ c->Rank(c->input(0)));
+ }
+ return Status::OK();
+ });
+
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("PrintV2");
+
// ----------------------------------------------------------------------------
// Operators that deal with SummaryProtos (encoded as DT_STRING tensors) as
// inputs or outputs in various ways.
@@ -116,4 +140,6 @@ REGISTER_OP("Timestamp")
.SetIsStateful()
.SetShapeFn(shape_inference::ScalarShape);
+WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("Timestamp");
+
} // end namespace tensorflow
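A sketch (hypothetical test, not part of this commit) of how the scalar check in PrintV2 could be exercised with the shape-inference test utilities used elsewhere in this change; note the check as written also rejects inputs of unknown rank, since Rank() then reports kUnknownRank rather than 0.

#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {

TEST(LoggingOpsTest, PrintV2_ShapeFn) {
  ShapeInferenceTestOp op("PrintV2");
  // A non-scalar input fails shape inference with the message above.
  INFER_ERROR("input must be a scalar", op, "[2]");
}

}  // namespace tensorflow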
diff --git a/tensorflow/core/ops/lookup_ops.cc b/tensorflow/core/ops/lookup_ops.cc
index 2059741da9..72a77be70d 100644
--- a/tensorflow/core/ops/lookup_ops.cc
+++ b/tensorflow/core/ops/lookup_ops.cc
@@ -23,6 +23,7 @@ namespace tensorflow {
using shape_inference::DimensionHandle;
using shape_inference::InferenceContext;
+using shape_inference::ShapeAndType;
using shape_inference::ShapeHandle;
// --------------------------------------------------------------------------
@@ -86,6 +87,74 @@ REGISTER_OP("LookupTableFind")
return Status::OK();
});
+Status ValidateTableResourceHandle(InferenceContext* c, ShapeHandle keys,
+ const string& key_dtype_attr,
+ const string& value_dtype_attr,
+ bool is_lookup,
+ ShapeAndType* output_shape_and_type) {
+ auto* handle_data = c->input_handle_shapes_and_types(0);
+ if (handle_data == nullptr || handle_data->size() != 2) {
+ output_shape_and_type->shape = c->UnknownShape();
+ output_shape_and_type->dtype = DT_INVALID;
+ } else {
+ const ShapeAndType& key_shape_and_type = (*handle_data)[0];
+ const ShapeAndType& value_shape_and_type = (*handle_data)[1];
+ DataType key_dtype;
+ TF_RETURN_IF_ERROR(c->GetAttr(key_dtype_attr, &key_dtype));
+ if (key_shape_and_type.dtype != key_dtype) {
+ return errors::InvalidArgument(
+ "Trying to read value with wrong dtype. "
+ "Expected ",
+ DataTypeString(key_shape_and_type.dtype), " got ",
+ DataTypeString(key_dtype));
+ }
+ DataType value_dtype;
+ TF_RETURN_IF_ERROR(c->GetAttr(value_dtype_attr, &value_dtype));
+ if (value_shape_and_type.dtype != value_dtype) {
+ return errors::InvalidArgument(
+ "Trying to read value with wrong dtype. "
+ "Expected ",
+ DataTypeString(value_shape_and_type.dtype), " got ",
+ DataTypeString(value_dtype));
+ }
+ output_shape_and_type->dtype = value_shape_and_type.dtype;
+
+ if (is_lookup) {
+ if (c->RankKnown(key_shape_and_type.shape) && c->RankKnown(keys)) {
+ int keys_rank = c->Rank(keys);
+ int key_suffix_rank = c->Rank(key_shape_and_type.shape);
+ if (keys_rank < key_suffix_rank) {
+ return errors::InvalidArgument(
+ "Expected keys to have suffix ",
+ c->DebugString(key_shape_and_type.shape),
+ " but saw shape: ", c->DebugString(keys));
+ }
+ for (int d = 0; d < key_suffix_rank; d++) {
+ // Ensure the suffix of keys match what's in the Table.
+ DimensionHandle dim = c->Dim(key_shape_and_type.shape, d);
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(keys, keys_rank - key_suffix_rank + d, dim, &keys));
+ }
+ std::vector<DimensionHandle> keys_prefix_vec;
+ keys_prefix_vec.reserve(keys_rank - key_suffix_rank);
+ for (int d = 0; d < keys_rank - key_suffix_rank; ++d) {
+ keys_prefix_vec.push_back(c->Dim(keys, d));
+ }
+ ShapeHandle keys_prefix = c->MakeShape(keys_prefix_vec);
+ TF_RETURN_IF_ERROR(c->Concatenate(keys_prefix,
+ value_shape_and_type.shape,
+ &output_shape_and_type->shape));
+ } else {
+ output_shape_and_type->shape = c->UnknownShape();
+ }
+ } else {
+ TF_RETURN_IF_ERROR(c->Concatenate(keys, value_shape_and_type.shape,
+ &output_shape_and_type->shape));
+ }
+ }
+ return Status::OK();
+}
+
REGISTER_OP("LookupTableFindV2")
.Input("table_handle: resource")
.Input("keys: Tin")
@@ -98,9 +167,18 @@ REGISTER_OP("LookupTableFindV2")
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
// Default value must be scalar or vector.
- ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &unused));
- c->set_output(0, c->UnknownShape());
+ ShapeHandle keys;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(2), 1, &keys));
+
+ ShapeAndType value_shape_and_type;
+ TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
+ c,
+ /*keys=*/c->input(1),
+ /*key_dtype_attr=*/"Tin",
+ /*value_dtype_attr=*/"Tout",
+ /*is_lookup=*/true, &value_shape_and_type));
+ c->set_output(0, value_shape_and_type.shape);
+
return Status::OK();
});
WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("LookupTableFindV2");
@@ -177,12 +255,16 @@ REGISTER_OP("LookupTableExportV2")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
-
- ShapeHandle values = c->UnknownShape();
- TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
- ShapeHandle keys = c->Vector(c->Dim(values, 0));
+ ShapeHandle keys = c->UnknownShapeOfRank(1);
+ ShapeAndType value_shape_and_type;
+ TF_RETURN_IF_ERROR(ValidateTableResourceHandle(
+ c,
+ /*keys=*/keys,
+ /*key_dtype_attr=*/"Tkeys",
+ /*value_dtype_attr=*/"Tvalues",
+ /*is_lookup=*/false, &value_shape_and_type));
c->set_output(0, keys);
- c->set_output(1, values);
+ c->set_output(1, value_shape_and_type.shape);
return Status::OK();
});
@@ -212,10 +294,32 @@ REGISTER_OP("LookupTableImportV2")
ShapeHandle handle;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &handle));
- // TODO: Validate keys and values shape.
+ ShapeHandle keys;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &keys));
+ TF_RETURN_IF_ERROR(c->Merge(keys, c->input(2), &keys));
return Status::OK();
});
+Status MutableHashTableShape(InferenceContext* c, const ShapeHandle& key,
+ const ShapeHandle& value) {
+ c->set_output(0, c->Scalar());
+
+ ShapeHandle key_s;
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(key, 1, &key_s));
+
+ DataType key_t;
+ TF_RETURN_IF_ERROR(c->GetAttr("key_dtype", &key_t));
+
+ DataType value_t;
+ TF_RETURN_IF_ERROR(c->GetAttr("value_dtype", &value_t));
+
+ // ShapeAndType vector for {key, value}.
+ c->set_output_handle_shapes_and_types(
+ 0, std::vector<ShapeAndType>{{key_s, key_t}, {value, value_t}});
+
+ return Status::OK();
+}
+
REGISTER_OP("HashTable")
.Output("table_handle: Ref(string)")
.Attr("container: string = ''")
@@ -254,7 +358,10 @@ REGISTER_OP("MutableHashTableV2")
.Attr("key_dtype: type")
.Attr("value_dtype: type")
.SetIsStateful()
- .SetShapeFn(ScalarOutput);
+ .SetShapeFn([](InferenceContext* c) {
+ return MutableHashTableShape(c, /*key=*/c->Scalar(),
+ /*value=*/c->Scalar());
+ });
REGISTER_OP("MutableHashTableOfTensors")
.Output("table_handle: Ref(string)")
@@ -276,7 +383,13 @@ REGISTER_OP("MutableHashTableOfTensorsV2")
.Attr("value_dtype: type")
.Attr("value_shape: shape = {}")
.SetIsStateful()
- .SetShapeFn(ScalarOutput);
+ .SetShapeFn([](InferenceContext* c) {
+ PartialTensorShape value_p;
+ TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p));
+ ShapeHandle value_s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s));
+ return MutableHashTableShape(c, /*key=*/c->Scalar(), /*value=*/value_s);
+ });
REGISTER_OP("MutableDenseHashTable")
.Input("empty_key: key_dtype")
@@ -304,7 +417,13 @@ REGISTER_OP("MutableDenseHashTableV2")
.Attr("initial_num_buckets: int = 131072") // 2^17
.Attr("max_load_factor: float = 0.8")
.SetIsStateful()
- .SetShapeFn(ScalarOutput);
+ .SetShapeFn([](InferenceContext* c) {
+ PartialTensorShape value_p;
+ TF_RETURN_IF_ERROR(c->GetAttr("value_shape", &value_p));
+ ShapeHandle value_s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(value_p, &value_s));
+ return MutableHashTableShape(c, /*key=*/c->input(0), /*value=*/value_s);
+ });
REGISTER_OP("InitializeTable")
.Input("table_handle: Ref(string)")
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 783d292389..55dcc50325 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -495,6 +495,19 @@ Status RealDivGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("RealDiv", RealDivGrad);
+Status DivNoNanGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"gx"}, "DivNoNan", {"dz", "y"}},
+ {{"nx"}, "Neg", {"x"}, {}, {"dz"}},
+ {{"y2"}, "Square", {"y"}, {}, {"dz"}},
+ {{"nx_y2"}, "DivNoNan", {"nx", "y2"}},
+ {{"gy"}, "Mul", {"dz", "nx_y2"}}, // dz * (- x / y^2)
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("DivNoNan", DivNoNanGrad);
+
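Editorial note (not part of the patch): DivNoNanGrad wires up the ordinary quotient-rule gradients, except that both partials are forced to zero wherever y == 0, matching the forward op's convention of returning 0 there instead of Inf/NaN. A standalone sketch of the math the DivNoNan/Neg/Square/Mul nodes compute (helper names are illustrative only):

// dz/dx and dz/dy for z = DivNoNan(x, y), i.e. x / y when y != 0 and 0 otherwise.
#include <cstdio>

float DivNoNan(float x, float y) { return y == 0.f ? 0.f : x / y; }
float DDx(float /*x*/, float y) { return y == 0.f ? 0.f : 1.f / y; }   // dz/dx
float DDy(float x, float y) { return y == 0.f ? 0.f : -x / (y * y); }  // dz/dy

int main() {
  std::printf("y=2: z=%g dz/dx=%g dz/dy=%g\n", DivNoNan(3.f, 2.f), DDx(3.f, 2.f), DDy(3.f, 2.f));
  std::printf("y=0: z=%g dz/dx=%g dz/dy=%g\n", DivNoNan(3.f, 0.f), DDx(3.f, 0.f), DDy(3.f, 0.f));
  return 0;
}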
Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
std::vector<FDH::Node> nodes = {
@@ -536,6 +549,40 @@ Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("Pow", PowGrad);
+Status XlogyGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"zeros"}, "ZerosLike", {"x"}},
+ {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
+ {{"is_zero_cast"}, "Cast", {"is_x_zero"},
+ {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
+ {{"safe_logy"}, "Xlogy", {"is_zero_cast", "y"}},
+ {{"xlogygrad"}, "Xdivy", {"x", "y"}},
+ {{"gx"}, "Mul", {"safe_logy", "dz"}},
+ {{"gy"}, "Mul", {"xlogygrad", "dz"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Xlogy", XlogyGrad);
+
+Status XdivyGrad(const AttrSlice& attrs, FunctionDef* g) {
+ // clang-format off
+ return GradForBinaryCwise(g, {
+ {{"zeros"}, "ZerosLike", {"x"}},
+ {{"is_x_zero"}, "NotEqual", {"x", "zeros"}},
+ {{"is_zero_cast"}, "Cast", {"is_x_zero"},
+ {{"SrcT", DT_BOOL}, {"DstT", "$T"}}},
+ {{"safe_divy"}, "Xdivy", {"is_zero_cast", "y"}},
+ {{"y2"}, "Square", {"y"}},
+ {{"negy2"}, "Neg", {"y2"}},
+ {{"xdivygrad"}, "Xdivy", {"x", "negy2"}},
+ {{"gx"}, "Mul", {"safe_divy", "dz"}},
+ {{"gy"}, "Mul", {"xdivygrad", "dz"}},
+ });
+ // clang-format on
+}
+REGISTER_OP_GRADIENT("Xdivy", XdivyGrad);
+
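Editorial note (not part of the patch): XlogyGrad and XdivyGrad gate the x-gradient on x != 0 via the is_zero_cast mask, so the gradient is exactly zero wherever the forward result is defined to be zero, rather than log(y) or 1/y evaluated at a point that never contributed. The partial derivatives these graphs encode, sketched standalone (helper names are illustrative; they mirror the lambdas in the tests later in this diff):

// Partials for xlogy(x, y) = x * log(y) and xdivy(x, y) = x / y,
// both with the "0 when x == 0" convention.
#include <cmath>
#include <cstdio>

float XlogyDx(float x, float y) { return x == 0.f ? 0.f : std::log(y); }   // d/dx
float XlogyDy(float x, float y) { return x == 0.f ? 0.f : x / y; }         // d/dy
float XdivyDx(float x, float y) { return x == 0.f ? 0.f : 1.f / y; }       // d/dx
float XdivyDy(float x, float y) { return x == 0.f ? 0.f : -x / (y * y); }  // d/dy

int main() {
  std::printf("xlogy partials at (2, 0.5): %g %g\n", XlogyDx(2.f, .5f), XlogyDy(2.f, .5f));
  std::printf("xdivy partials at (2, 0.5): %g %g\n", XdivyDx(2.f, .5f), XdivyDy(2.f, .5f));
  return 0;
}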
Status MaximumMinimumGradHelper(const string& comparator,
const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index 2a27ef3ddb..9fc6b34147 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -753,6 +753,78 @@ TEST_F(MathGradTest, Div) {
}
}
+TEST_F(MathGradTest, DivNoNan) {
+ auto x = test::AsTensor<float>(
+ {0.f, -3.f, -2.f, -1.f, 0.f, 1.f, 2.f, 3.f, 0.f}, TensorShape({3, 3}));
+ auto y = test::AsTensor<float>({-10.f, 0.f, 10.f}, TensorShape({3, 1}));
+ Tensor dx;
+ Tensor dy;
+ {
+ SymGrad("DivNoNan", x, y, &dx, &dy);
+ {
+ auto g = [](float x, float y) {
+ if (y == 0.f) {
+ return 0.f;
+ } else {
+ return 1.f / y;
+ }
+ };
+ test::ExpectClose(dx, test::AsTensor<float>(
+ {g(0.f, -10.f), g(-3.f, -10.f), g(-2.f, -10.f),
+ g(-1.f, 0.f), g(0.f, 0.f), g(1.f, 0.f),
+ g(2.f, 10.f), g(3.f, 10.f), g(0.f, 10.f)},
+ TensorShape({3, 3})));
+ }
+ {
+ auto g = [](float x, float y) {
+ if (y == 0.f) {
+ return 0.f;
+ } else {
+ return -x / (y * y);
+ }
+ };
+ test::ExpectClose(dy,
+ test::AsTensor<float>(
+ {g(0.f, -10.f) + g(-3.f, -10.f) + g(-2.f, -10.f),
+ g(-1.f, 0.f) + g(0.f, 0.f) + g(1.f, 0.f),
+ g(2.f, 10.f) + g(3.f, 10.f) + g(0.f, 10.f)},
+ TensorShape({3, 1})));
+ }
+ }
+ { // Swap x and y.
+ SymGrad("DivNoNan", y, x, &dy, &dx);
+ {
+ auto g = [](float x, float y) {
+ if (y == 0.f) {
+ return 0.f;
+ } else {
+ return 1.f / y;
+ }
+ };
+ test::ExpectClose(dy,
+ test::AsTensor<float>(
+ {g(-10.f, 0.f) + g(-10.f, -3.f) + g(-10.f, -2.f),
+ g(0.f, -1.f) + g(0.f, 0.f) + g(0.f, 1.f),
+ g(10.f, 2.f) + g(10.f, 3.f) + g(10.f, 0.f)},
+ TensorShape({3, 1})));
+ }
+ {
+ auto g = [](float x, float y) {
+ if (y == 0.f) {
+ return 0.f;
+ } else {
+ return -x / (y * y);
+ }
+ };
+ test::ExpectClose(dx, test::AsTensor<float>(
+ {g(-10.f, 0.f), g(-10.f, -3.f), g(-10.f, -2.f),
+ g(0.f, -1.f), g(0.f, 0.f), g(0.f, 1.f),
+ g(10.f, 2.f), g(10.f, 3.f), g(10.f, 0.f)},
+ TensorShape({3, 3})));
+ }
+ }
+}
+
TEST_F(MathGradTest, Pow) {
auto x = test::AsTensor<float>({0.f, 1.f, 2.f, 3.f, 4.f, 5.f},
TensorShape({2, 3}));
@@ -837,6 +909,46 @@ TEST_F(MathGradTest, ComplexPow) {
}
#endif // TENSORFLOW_USE_SYCL
+TEST_F(MathGradTest, Xlogy) {
+ auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f},
+ TensorShape({2, 3}));
+ auto y = test::AsTensor<float>({.5f, 2.f}, TensorShape({2, 1}));
+ Tensor dx;
+ Tensor dy;
+ auto g = [](float x, float y) -> float { return x == 0. ? 0. : std::log(y); };
+ auto h = [](float x, float y) -> float { return x == 0. ? 0. : x / y; };
+ SymGrad("Xlogy", x, y, &dx, &dy);
+ test::ExpectClose(
+ dx, test::AsTensor<float>({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f),
+ g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)},
+ TensorShape({2, 3})));
+ test::ExpectClose(
+ dy, test::AsTensor<float>({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f),
+ h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)},
+ TensorShape({2, 1})));
+}
+
+TEST_F(MathGradTest, Xdivy) {
+ auto x = test::AsTensor<float>({0.f, 0.f, 2.f, 3.f, 4.f, 5.f},
+ TensorShape({2, 3}));
+ auto y = test::AsTensor<float>({.5f, 2.f}, TensorShape({2, 1}));
+ Tensor dx;
+ Tensor dy;
+ auto g = [](float x, float y) -> float { return x == 0. ? 0. : 1 / y; };
+ auto h = [](float x, float y) -> float {
+ return x == 0. ? 0. : -x / (y * y);
+ };
+ SymGrad("Xdivy", x, y, &dx, &dy);
+ test::ExpectClose(
+ dx, test::AsTensor<float>({g(0.f, .5f), g(0.f, 0.f), g(2.f, .5f),
+ g(3.f, 2.f), g(4.f, 2.f), g(5.f, 2.f)},
+ TensorShape({2, 3})));
+ test::ExpectClose(
+ dy, test::AsTensor<float>({h(0.f, .5f) + h(0.f, 0.f) + h(2.f, .5f),
+ h(3.f, 2.f) + h(4.f, 2.f) + h(5.f, 2.f)},
+ TensorShape({2, 1})));
+}
+
TEST_F(MathGradTest, Maximum) {
auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
TensorShape({2, 3}));
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index 0ba4a9a005..a9e5e7824d 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -392,6 +392,13 @@ Returns x * y element-wise.
REGISTER_OP("Div").BINARY_MORE().SetShapeFn(
shape_inference::BroadcastBinaryOpShapeFn);
+REGISTER_OP("DivNoNan")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {float, double}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
REGISTER_OP("FloorDiv")
.BINARY_MORE()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
@@ -422,6 +429,20 @@ Returns (x - y)(x - y) element-wise.
[here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
)doc");
+REGISTER_OP("Xlogy")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {half, float, double, complex64, complex128}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
+REGISTER_OP("Xdivy")
+ .Input("x: T")
+ .Input("y: T")
+ .Output("z: T")
+ .Attr("T: {half, float, double, complex64, complex128}")
+ .SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn);
+
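Editorial note (not part of the patch): neither registration above carries a .Doc() string, so for orientation: Xlogy computes x * log(y) and Xdivy computes x / y elementwise (broadcasting handled by BroadcastBinaryOpShapeFn), with the result defined to be 0 wherever x == 0 even if log(y) or 1/y would be non-finite there. A hedged reference sketch of the forward semantics:

#include <cmath>
#include <cstdio>

float Xlogy(float x, float y) { return x == 0.f ? 0.f : x * std::log(y); }
float Xdivy(float x, float y) { return x == 0.f ? 0.f : x / y; }

int main() {
  std::printf("%g %g\n", Xlogy(0.f, 0.f), Xdivy(0.f, 0.f));            // 0 0, not NaN
  std::printf("%g %g\n", Xlogy(2.f, std::exp(1.f)), Xdivy(2.f, 4.f));  // 2 0.5
  return 0;
}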
#undef BINARY_FEWER
#undef BINARY_MORE
diff --git a/tensorflow/core/ops/math_ops_test.cc b/tensorflow/core/ops/math_ops_test.cc
index 7bf7c476f4..05379a7d69 100644
--- a/tensorflow/core/ops/math_ops_test.cc
+++ b/tensorflow/core/ops/math_ops_test.cc
@@ -120,7 +120,8 @@ TEST(MathOpsTest, BroadcastBinaryOps_ShapeFn) {
"Maximum", "Minimum",
"Mod", "Mul",
"NotEqual", "Pow",
- "Sub", "SquaredDifference"}) {
+ "Sub", "SquaredDifference",
+ "DivNoNan"}) {
ShapeInferenceTestOp op(op_name);
INFER_OK(op, "?;?", "?");
INFER_OK(op, "[1,2];?", "?");
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index e0f25fb4ef..d1d81b27cc 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -178,7 +178,7 @@ REGISTER_OP("FusedBatchNorm")
.Output("reserve_space_2: T")
.Attr("T: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormShape);
@@ -196,7 +196,7 @@ REGISTER_OP("FusedBatchNormV2")
.Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormShape);
@@ -213,7 +213,7 @@ REGISTER_OP("FusedBatchNormGrad")
.Output("reserve_space_4: T")
.Attr("T: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormGradShape);
@@ -231,7 +231,7 @@ REGISTER_OP("FusedBatchNormGradV2")
.Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
- .Attr("data_format: string = 'NHWC'")
+ .Attr(GetConvnetDataFormatAttrString())
.Attr("is_training: bool = true")
.SetShapeFn(shape_inference::FusedBatchNormGradShape);
@@ -960,7 +960,7 @@ REGISTER_OP("Dilation2DBackpropFilter")
REGISTER_OP("Relu")
.Input("features: T")
.Output("activations: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {realnumbertype, qint8}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("ReluGrad")
@@ -1012,27 +1012,27 @@ REGISTER_OP("SeluGrad")
REGISTER_OP("Softplus")
.Input("features: T")
.Output("activations: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("SoftplusGrad")
.Input("gradients: T")
.Input("features: T")
.Output("backprops: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
REGISTER_OP("Softsign")
.Input("features: T")
.Output("activations: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape);
REGISTER_OP("SoftsignGrad")
.Input("gradients: T")
.Input("features: T")
.Output("backprops: T")
- .Attr("T: realnumbertype")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn);
// --------------------------------------------------------------------------
@@ -1736,6 +1736,87 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
+REGISTER_OP("_MklConv3D")
+ .Input("input: T")
+ .Input("filter: T")
+ .Input("mkl_input: uint8")
+ .Input("mkl_filter: uint8")
+ .Output("output: T")
+ .Output("filter_output: T")
+ .Output("mkl_output: uint8")
+ .Output("mkl_filter_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+ .SetShapeFn(shape_inference::Conv3DShape)
+ .Doc(R"doc(
+MKL version of Conv3D operator. Uses MKL DNN APIs to perform 3D convolution.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklConv3DBackpropInputV2")
+ .Input("input_sizes: Tshape")
+ .Input("filter: T")
+ .Input("out_backprop: T")
+ .Input("mkl_input_sizes: uint8")
+ .Input("mkl_filter: uint8")
+ .Input("mkl_out_backprop: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int) >= 5")
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+ .Attr("Tshape: {int32, int64} = DT_INT32")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of Convolution3D backward input. Uses MKL DNN APIs to compute the
+gradients of convolution with respect to the input.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklConv3DBackpropFilterV2")
+ .Input("input: T")
+ .Input("filter_sizes: int32")
+ .Input("out_backprop: T")
+ .Input("mkl_input: uint8")
+ .Input("mkl_filter_size: uint8")
+ .Input("mkl_out_backprop: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("T: {half, float, double}")
+ .Attr("strides: list(int)")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of Conv3DBackpropFilter. Uses MKL DNN APIs to compute the
+gradients of convolution with respect to the filter.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
REGISTER_OP("_MklRelu")
.Input("features: T")
.Input("mkl_features: uint8")
@@ -1943,6 +2024,104 @@ NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
+REGISTER_OP("_MklAvgPool3D")
+ .Input("value: T")
+ .Input("mkl_input: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("T: {float, half, double}")
+ .SetShapeFn(shape_inference::Pool3DShape)
+ .Doc(R"doc(
+MKL version of AvgPool3D operator. Uses MKL DNN APIs to perform average pooling
+on the input.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+
+REGISTER_OP("_MklAvgPool3DGrad")
+ .Input("orig_input_shape: int32")
+ .Input("grad: T")
+ .Input("mkl_orig_input: uint8")
+ .Input("mkl_grad: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("T: {float, half, double}")
+ .SetShapeFn([](InferenceContext* c) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
+ TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
+ c->set_output(0, s);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+MKL version of AvgPool3DGrad operator. Uses MKL DNN APIs to compute gradients
+of the AvgPool3D function.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklMaxPool3D")
+ .Input("input: T")
+ .Input("mkl_input: uint8")
+ .Output("output: T")
+ .Output("workspace: uint8")
+ .Output("mkl_output: uint8")
+ .Output("mkl_workspace: uint8")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("T: {half, bfloat16, float}")
+ .Attr("workspace_enabled: bool = false")
+ .SetShapeFn(shape_inference::Pool3DShape)
+ .Doc(R"doc(
+MKL version of MaxPool3D operator. Uses MKL DNN APIs to perform max pooling
+on the input.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
+REGISTER_OP("_MklMaxPool3DGrad")
+ .Input("orig_input: TInput")
+ .Input("orig_output: TInput")
+ .Input("grad: T")
+ .Input("workspace: uint8")
+ .Input("mkl_orig_input: uint8")
+ .Input("mkl_orig_output: uint8")
+ .Input("mkl_grad: uint8")
+ .Input("mkl_workspace: uint8")
+ .Output("output: T")
+ .Output("mkl_output: uint8")
+ .Attr("ksize: list(int) >= 5")
+ .Attr("strides: list(int) >= 5")
+ .Attr(GetPaddingAttrString())
+ .Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("T: {half, bfloat16, float} = DT_FLOAT")
+ .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
+ .Attr("workspace_enabled: bool = false")
+ .SetShapeFn([](InferenceContext* c) {
+ return UnchangedShapeWithRank(c, 5);
+ })
+ .Doc(R"doc(
+MKL version of MaxPool3DGrad operator. Uses MKL DNN APIs to compute gradients
+of the MaxPool3D function.
+
+NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
+expected to invoke these operators.
+)doc");
+
REGISTER_OP("_MklLRN")
.Input("input: T")
.Input("mkl_input: uint8")
@@ -2161,7 +2340,7 @@ REGISTER_OP("_MklToTf")
.Input("mkl_input: uint8")
.Output("output: T")
.Attr("T: {half, float, double}")
- .Attr(GetConvnetDataFormatAttrString())
+ .Attr(GetConvnetDataFormat2D3DAttrString())
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
MKL operator to convert a tensor from MKL layout to TensorFlow layout.
@@ -2183,7 +2362,7 @@ REGISTER_OP("_MklInputConversion")
.Attr(
"T: {half, float, double, uint8, int8, uint16, int16, int32, int64, "
"complex64, complex128}")
- .Attr(GetConvnetDataFormatAttrString())
+ .Attr(GetConvnetDataFormat2D3DAttrString())
.SetShapeFn(shape_inference::UnknownShape)
.Doc(R"doc(
MKL operator to process the inputs to an elementwise MKL op. Both inputs
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 59ba1c5306..0d8997c1bd 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -4272,6 +4272,29 @@ op {
is_commutative: true
}
op {
+ name: "BoostedTreesBucketize"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ output_arg {
+ name: "buckets"
+ type: DT_INT32
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesCalculateBestGainsPerFeature"
input_arg {
name: "node_id_range"
@@ -4381,6 +4404,29 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesCreateQuantileStreamResource"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "num_streams"
+ type: DT_INT64
+ }
+ attr {
+ name: "max_elements"
+ type: "int"
+ default_value {
+ i: 1099511627776
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesDeserializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -4474,6 +4520,32 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesMakeQuantileSummaries"
+ input_arg {
+ name: "float_values"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ input_arg {
+ name: "example_weights"
+ type: DT_FLOAT
+ }
+ input_arg {
+ name: "epsilon"
+ type: DT_FLOAT
+ }
+ output_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+}
+op {
name: "BoostedTreesMakeStatsSummary"
input_arg {
name: "node_ids"
@@ -4543,6 +4615,83 @@ op {
is_stateful: true
}
op {
+ name: "BoostedTreesQuantileStreamResourceAddSummaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "summaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceFlush"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "num_buckets"
+ type: DT_INT64
+ }
+ attr {
+ name: "generate_quantiles"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceGetBucketBoundaries"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "bucket_boundaries"
+ type: DT_FLOAT
+ number_attr: "num_features"
+ }
+ attr {
+ name: "num_features"
+ type: "int"
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "BoostedTreesQuantileStreamResourceHandleOp"
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
name: "BoostedTreesSerializeEnsemble"
input_arg {
name: "tree_ensemble_handle"
@@ -5592,6 +5741,19 @@ op {
s: ""
}
}
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
is_stateful: true
}
op {
@@ -9190,6 +9352,31 @@ op {
}
}
op {
+ name: "DivNoNan"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "DrawBoundingBoxes"
input_arg {
name: "images"
@@ -9642,6 +9829,25 @@ op {
is_stateful: true
}
op {
+ name: "EnsureShape"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+}
+op {
name: "Enter"
input_arg {
name: "data"
@@ -9833,6 +10039,421 @@ op {
}
}
op {
+ name: "ExperimentalAssertNextDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "transformations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalCSVDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "compression_type"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "buffer_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "header"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "field_delim"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "use_quote_delim"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "na_value"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "select_cols"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "record_defaults"
+ type_list_attr: "output_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalDirectedInterleaveDataset"
+ input_arg {
+ name: "selector_input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "data_input_datasets"
+ type: DT_VARIANT
+ number_attr: "N"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalFunctionBufferingResource"
+ input_arg {
+ name: "string_arg"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "target_device"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "buffer_size"
+ type: "int"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceGetNext"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceReset"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIdentityIndexedDataset"
+ input_arg {
+ name: "size"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalIndexedDatasetGet"
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "index"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIndexedDatasetMaterialize"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIteratorGetDevice"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "device"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalLMDBDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalMaterializedIndexDatasetHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "thread_pool"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "num_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ type: "int"
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "display_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalUniqueDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Expm1"
input_arg {
name: "x"
@@ -9981,6 +10602,59 @@ op {
}
}
op {
+ name: "ExtractVolumePatches"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "patches"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksizes"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 5
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_INT64
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+}
+op {
name: "FFT"
input_arg {
name: "input"
@@ -10415,33 +11089,6 @@ op {
is_stateful: true
}
op {
- name: "FeatureStatsDataset"
- input_arg {
- name: "input_dataset"
- type: DT_VARIANT
- }
- input_arg {
- name: "tag"
- type: DT_STRING
- }
- output_arg {
- name: "handle"
- type: DT_VARIANT
- }
- attr {
- name: "output_types"
- type: "list(type)"
- has_minimum: true
- minimum: 1
- }
- attr {
- name: "output_shapes"
- type: "list(shape)"
- has_minimum: true
- minimum: 1
- }
-}
-op {
name: "Fill"
input_arg {
name: "dims"
@@ -11227,6 +11874,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -11300,6 +11953,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -11384,6 +12043,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -11468,6 +12133,12 @@ op {
default_value {
s: "NHWC"
}
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
}
attr {
name: "is_training"
@@ -12257,6 +12928,21 @@ op {
}
}
op {
+ name: "HostConst"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "value"
+ type: "tensor"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+}
+op {
name: "IFFT"
input_arg {
name: "input"
@@ -12490,6 +13176,14 @@ op {
name: "else_branch"
type: "func"
}
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ }
is_stateful: true
}
op {
@@ -13117,6 +13811,18 @@ op {
is_stateful: true
}
op {
+ name: "IsBoostedTreesQuantileStreamResourceInitialized"
+ input_arg {
+ name: "quantile_stream_resource_handle"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "is_initialized"
+ type: DT_BOOL
+ }
+ is_stateful: true
+}
+op {
name: "IsFinite"
input_arg {
name: "x"
@@ -14330,6 +15036,38 @@ op {
}
}
op {
+ name: "LowerBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "MakeIterator"
input_arg {
name: "dataset"
@@ -14510,6 +15248,13 @@ op {
has_minimum: true
minimum: 1
}
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
}
op {
name: "MapDefun"
@@ -14517,6 +15262,10 @@ op {
name: "arguments"
type_list_attr: "Targuments"
}
+ input_arg {
+ name: "captured_inputs"
+ type_list_attr: "Tcaptured"
+ }
output_arg {
name: "output"
type_list_attr: "output_types"
@@ -14528,6 +15277,15 @@ op {
minimum: 1
}
attr {
+ name: "Tcaptured"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
name: "output_types"
type: "list(type)"
has_minimum: true
@@ -15006,6 +15764,10 @@ op {
}
}
}
+ deprecation {
+ version: 27
+ explanation: "Use Python implementation tf.linalg.matrix_exponential instead."
+ }
}
op {
name: "MatrixInverse"
@@ -16504,6 +17266,29 @@ op {
}
}
op {
+ name: "ModelDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Mul"
input_arg {
name: "x"
@@ -16540,6 +17325,134 @@ op {
is_commutative: true
}
op {
+ name: "MultiDeviceIterator"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "devices"
+ type: "list(string)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorFromStringHandle"
+ input_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorGetNextFromShard"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "shard_num"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorInit"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "max_buffer_size"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "incarnation_id"
+ type: DT_INT64
+ }
+ is_stateful: true
+}
+op {
+ name: "MultiDeviceIteratorToStringHandle"
+ input_arg {
+ name: "multi_device_iterator"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "string_handle"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
name: "Multinomial"
input_arg {
name: "logits"
@@ -17042,11 +17955,11 @@ op {
name: "NonMaxSuppressionV2"
input_arg {
name: "boxes"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "scores"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "max_output_size"
@@ -17060,16 +17973,29 @@ op {
name: "selected_indices"
type: DT_INT32
}
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
}
op {
name: "NonMaxSuppressionV3"
input_arg {
name: "boxes"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "scores"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "max_output_size"
@@ -17087,16 +18013,29 @@ op {
name: "selected_indices"
type: DT_INT32
}
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
}
op {
name: "NonMaxSuppressionV4"
input_arg {
name: "boxes"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "scores"
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "max_output_size"
@@ -17119,6 +18058,19 @@ op {
type: DT_INT32
}
attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ }
+ }
+ }
+ attr {
name: "pad_to_max_output_size"
type: "bool"
default_value {
@@ -18156,6 +19108,54 @@ op {
}
}
op {
+ name: "ParallelInterleaveDatasetV2"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ input_arg {
+ name: "cycle_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "block_length"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "ParallelMapDataset"
input_arg {
name: "input_dataset"
@@ -18194,6 +19194,13 @@ op {
has_minimum: true
minimum: 1
}
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
}
op {
name: "ParameterizedTruncatedNormal"
@@ -18342,6 +19349,271 @@ op {
}
}
op {
+ name: "ParseExampleDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "num_parallel_calls"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "dense_defaults"
+ type_list_attr: "Tdense"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "sparse_types"
+ type: "list(type)"
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "Tdense"
+ type: "list(type)"
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "dense_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ParseSequenceExample"
+ input_arg {
+ name: "serialized"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "debug_name"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "context_dense_defaults"
+ type_list_attr: "Tcontext_dense"
+ }
+ output_arg {
+ name: "context_sparse_indices"
+ type: DT_INT64
+ number_attr: "Ncontext_sparse"
+ }
+ output_arg {
+ name: "context_sparse_values"
+ type_list_attr: "context_sparse_types"
+ }
+ output_arg {
+ name: "context_sparse_shapes"
+ type: DT_INT64
+ number_attr: "Ncontext_sparse"
+ }
+ output_arg {
+ name: "context_dense_values"
+ type_list_attr: "Tcontext_dense"
+ }
+ output_arg {
+ name: "feature_list_sparse_indices"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
+ name: "feature_list_sparse_values"
+ type_list_attr: "feature_list_sparse_types"
+ }
+ output_arg {
+ name: "feature_list_sparse_shapes"
+ type: DT_INT64
+ number_attr: "Nfeature_list_sparse"
+ }
+ output_arg {
+ name: "feature_list_dense_values"
+ type_list_attr: "feature_list_dense_types"
+ }
+ output_arg {
+ name: "feature_list_dense_lengths"
+ type: DT_INT64
+ number_attr: "Nfeature_list_dense"
+ }
+ attr {
+ name: "feature_list_dense_missing_assumed_empty"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "context_sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "context_dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_sparse_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_dense_keys"
+ type: "list(string)"
+ has_minimum: true
+ }
+ attr {
+ name: "Ncontext_sparse"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Ncontext_dense"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Nfeature_list_sparse"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "Nfeature_list_dense"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "context_sparse_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "Tcontext_dense"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "feature_list_dense_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "context_dense_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+ attr {
+ name: "feature_list_sparse_types"
+ type: "list(type)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "feature_list_dense_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ has_minimum: true
+ }
+}
+op {
name: "ParseSingleExample"
input_arg {
name: "serialized"
@@ -18922,6 +20194,30 @@ op {
is_stateful: true
}
op {
+ name: "PrintV2"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ attr {
+ name: "output_stream"
+ type: "string"
+ default_value {
+ s: "stderr"
+ }
+ allowed_values {
+ list {
+ s: "stdout"
+ s: "stderr"
+ s: "log(info)"
+ s: "log(warning)"
+ s: "log(error)"
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "PriorityQueue"
output_arg {
name: "handle"
@@ -22009,6 +23305,59 @@ op {
is_stateful: true
}
op {
+ name: "ReduceDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "initial_state"
+ type_list_attr: "Tstate"
+ }
+ input_arg {
+ name: "other_arguments"
+ type_list_attr: "Targuments"
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "Tstate"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "Targuments"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "use_inter_op_parallelism"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "ReduceJoin"
input_arg {
name: "inputs"
@@ -22275,6 +23624,7 @@ op {
type: DT_HALF
type: DT_UINT32
type: DT_UINT64
+ type: DT_QINT8
}
}
}
@@ -26675,6 +28025,7 @@ op {
s: "squared_loss"
s: "hinge_loss"
s: "smooth_hinge_loss"
+ s: "poisson_loss"
}
}
}
@@ -27265,6 +28616,14 @@ op {
name: "stats_aggregator"
type: DT_RESOURCE
}
+ input_arg {
+ name: "tag"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "counter_prefix"
+ type: DT_STRING
+ }
output_arg {
name: "handle"
type: DT_VARIANT
@@ -27876,18 +29235,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -27911,18 +29262,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -27942,18 +29285,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -27977,18 +29312,10 @@ op {
type: "type"
allowed_values {
list {
+ type: DT_HALF
+ type: DT_BFLOAT16
type: DT_FLOAT
type: DT_DOUBLE
- type: DT_INT32
- type: DT_UINT8
- type: DT_INT16
- type: DT_INT8
- type: DT_INT64
- type: DT_BFLOAT16
- type: DT_UINT16
- type: DT_HALF
- type: DT_UINT32
- type: DT_UINT64
}
}
}
@@ -29307,6 +30634,19 @@ op {
s: ""
}
}
+ attr {
+ name: "reduction_type"
+ type: "string"
+ default_value {
+ s: "MEAN"
+ }
+ allowed_values {
+ list {
+ s: "MEAN"
+ s: "SUM"
+ }
+ }
+ }
is_stateful: true
}
op {
@@ -31805,6 +33145,47 @@ op {
}
}
op {
+ name: "StaticRegexFullMatch"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_BOOL
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+}
+op {
+ name: "StaticRegexReplace"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "pattern"
+ type: "string"
+ }
+ attr {
+ name: "rewrite"
+ type: "string"
+ }
+ attr {
+ name: "replace_global"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "StatsAggregatorHandle"
output_arg {
name: "handle"
@@ -32080,6 +33461,43 @@ op {
}
}
op {
+ name: "StringFormat"
+ input_arg {
+ name: "inputs"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "template"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "placeholder"
+ type: "string"
+ default_value {
+ s: "%s"
+ }
+ }
+ attr {
+ name: "summarize"
+ type: "int"
+ default_value {
+ i: 3
+ }
+ }
+}
+op {
name: "StringJoin"
input_arg {
name: "inputs"
@@ -32105,6 +33523,30 @@ op {
}
}
op {
+ name: "StringLength"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
+}
+op {
name: "StringSplit"
input_arg {
name: "input"
@@ -32319,6 +33761,19 @@ op {
}
}
}
+ attr {
+ name: "unit"
+ type: "string"
+ default_value {
+ s: "BYTE"
+ }
+ allowed_values {
+ list {
+ s: "BYTE"
+ s: "UTF8_CHAR"
+ }
+ }
+ }
}
op {
name: "Sum"
@@ -33810,6 +35265,25 @@ op {
}
}
op {
+ name: "TensorListGather"
+ input_arg {
+ name: "input_handle"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "values"
+ type_attr: "element_dtype"
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+}
+op {
name: "TensorListGetItem"
input_arg {
name: "input_handle"
@@ -33926,6 +35400,39 @@ op {
}
}
op {
+ name: "TensorListScatter"
+ input_arg {
+ name: "tensor"
+ type_attr: "element_dtype"
+ }
+ input_arg {
+ name: "indices"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "element_shape"
+ type_attr: "shape_type"
+ }
+ output_arg {
+ name: "output_handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "element_dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape_type"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "TensorListSetItem"
input_arg {
name: "input_handle"
@@ -34652,6 +36159,17 @@ op {
}
}
op {
+ name: "UnicodeScript"
+ input_arg {
+ name: "input"
+ type: DT_INT32
+ }
+ output_arg {
+ name: "output"
+ type: DT_INT32
+ }
+}
+op {
name: "UniformCandidateSampler"
input_arg {
name: "true_classes"
@@ -35236,6 +36754,38 @@ op {
is_stateful: true
}
op {
+ name: "UpperBound"
+ input_arg {
+ name: "sorted_inputs"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "values"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "out_type"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "out_type"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "VarHandleOp"
output_arg {
name: "resource"
@@ -35427,6 +36977,14 @@ op {
name: "body"
type: "func"
}
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ }
is_stateful: true
}
op {
@@ -35481,9 +37039,21 @@ op {
type: DT_VARIANT
}
input_arg {
- name: "window_size"
+ name: "size"
type: DT_INT64
}
+ input_arg {
+ name: "shift"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "stride"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "drop_remainder"
+ type: DT_BOOL
+ }
output_arg {
name: "handle"
type: DT_VARIANT
@@ -35720,6 +37290,62 @@ op {
is_stateful: true
}
op {
+ name: "Xdivy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
+ name: "Xlogy"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+}
+op {
name: "ZerosLike"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/parsing_ops.cc b/tensorflow/core/ops/parsing_ops.cc
index ddb714b4e9..eff453241d 100644
--- a/tensorflow/core/ops/parsing_ops.cc
+++ b/tensorflow/core/ops/parsing_ops.cc
@@ -132,6 +132,99 @@ REGISTER_OP("ParseSingleExample")
return Status::OK();
});
+REGISTER_OP("ParseSequenceExample")
+ .Input("serialized: string")
+ .Input("debug_name: string")
+ .Input("context_dense_defaults: Tcontext_dense")
+ .Output("context_sparse_indices: Ncontext_sparse * int64")
+ .Output("context_sparse_values: context_sparse_types")
+ .Output("context_sparse_shapes: Ncontext_sparse * int64")
+ .Output("context_dense_values: Tcontext_dense")
+ .Output("feature_list_sparse_indices: Nfeature_list_sparse * int64")
+ .Output("feature_list_sparse_values: feature_list_sparse_types")
+ .Output("feature_list_sparse_shapes: Nfeature_list_sparse * int64")
+ .Output("feature_list_dense_values: feature_list_dense_types")
+ .Output("feature_list_dense_lengths: Nfeature_list_dense * int64")
+ .Attr("feature_list_dense_missing_assumed_empty: list(string) >= 0")
+ .Attr("context_sparse_keys: list(string) >= 0")
+ .Attr("context_dense_keys: list(string) >= 0")
+ .Attr("feature_list_sparse_keys: list(string) >= 0")
+ .Attr("feature_list_dense_keys: list(string) >= 0")
+ .Attr("Ncontext_sparse: int >= 0 = 0")
+ .Attr("Ncontext_dense: int >= 0 = 0")
+ .Attr("Nfeature_list_sparse: int >= 0 = 0")
+ .Attr("Nfeature_list_dense: int >= 0 = 0")
+ .Attr("context_sparse_types: list({float,int64,string}) >= 0 = []")
+ .Attr("Tcontext_dense: list({float,int64,string}) >= 0 = []")
+ .Attr("feature_list_dense_types: list({float,int64,string}) >= 0 = []")
+ .Attr("context_dense_shapes: list(shape) >= 0 = []")
+ .Attr("feature_list_sparse_types: list({float,int64,string}) >= 0 = []")
+ .Attr("feature_list_dense_shapes: list(shape) >= 0 = []")
+ .SetShapeFn([](InferenceContext* c) {
+ ParseSequenceExampleAttrs attrs;
+ TF_RETURN_IF_ERROR(attrs.Init(c));
+
+ // Verify that the input is a vector, and carry the shape if known.
+ ShapeHandle input;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &input));
+ shape_inference::DimensionHandle num_examples = c->Dim(input, 0);
+
+ ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused)); // debug_name
+
+ int output_idx = 0;
+
+ // Output context_sparse_indices, context_sparse_values, and
+ // context_sparse_shapes.
+ for (int i = 0; i < attrs.num_context_sparse; ++i) {
+ c->set_output(output_idx++, c->Matrix(c->UnknownDim(), 2));
+ }
+ for (int i = 0; i < attrs.num_context_sparse; ++i) {
+ c->set_output(output_idx++, c->Vector(c->UnknownDim()));
+ }
+ for (int i = 0; i < attrs.num_context_sparse; ++i) {
+ c->set_output(output_idx++, c->Vector(2));
+ }
+
+ // Output context_dense_values.
+ for (int i = 0; i < attrs.num_context_dense; ++i) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ attrs.context_dense_shapes[i], &s));
+ TF_RETURN_IF_ERROR(c->Concatenate(c->Vector(num_examples), s, &s));
+ c->set_output(output_idx++, s);
+ }
+
+ // Output feature_list_sparse_indices, feature_list_sparse_values,
+ // feature_list_sparse_shapes.
+ for (int i = 0; i < attrs.num_feature_list_sparse; ++i) {
+ c->set_output(output_idx++, c->Matrix(c->UnknownDim(), 3));
+ }
+ for (int i = 0; i < attrs.num_feature_list_sparse; ++i) {
+ c->set_output(output_idx++, c->Vector(c->UnknownDim()));
+ }
+ for (int i = 0; i < attrs.num_feature_list_sparse; ++i) {
+ c->set_output(output_idx++, c->Vector(3));
+ }
+
+ // Output feature_list_dense_shapes.
+ for (int i = 0; i < attrs.num_feature_list_dense; ++i) {
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ attrs.feature_list_dense_shapes[i], &s));
+ TF_RETURN_IF_ERROR(
+ c->Concatenate(c->Matrix(num_examples, c->UnknownDim()), s, &s));
+ c->set_output(output_idx++, s);
+ }
+
+ // Output feature_list_dense_lengths.
+ for (int i = 0; i < attrs.num_feature_list_dense; ++i) {
+ c->set_output(output_idx++, c->Vector(num_examples));
+ }
+
+ return Status::OK();
+ });
+
REGISTER_OP("ParseSingleSequenceExample")
.Input("serialized: string")
.Input("feature_list_dense_missing_assumed_empty: string")
@@ -250,10 +343,11 @@ REGISTER_OP("DecodeCSV")
// Validate the record_defaults inputs.
for (int i = 1; i < c->num_inputs(); ++i) {
ShapeHandle v;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &v));
- if (c->Value(c->Dim(v, 0)) > 1) {
+ TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(i), 1, &v));
+ if (c->Rank(c->input(i)) == 1 && c->Value(c->Dim(v, 0)) > 1) {
return errors::InvalidArgument(
- "Shape of a default must be a length-0 or length-1 vector");
+ "Shape of a default must be a length-0 or length-1 vector, or a "
+ "scalar.");
}
}
diff --git a/tensorflow/core/ops/parsing_ops_test.cc b/tensorflow/core/ops/parsing_ops_test.cc
index 9121d7ae92..ba594e400c 100644
--- a/tensorflow/core/ops/parsing_ops_test.cc
+++ b/tensorflow/core/ops/parsing_ops_test.cc
@@ -52,9 +52,12 @@ TEST(ParsingOpsTest, DecodeCSV_ShapeFn) {
INFER_OK(op, "[1,2,?,4];?;?", "in0;in0");
INFER_OK(op, "[1,2,?,4];[?];[?]", "in0;in0");
+ // Scalar defaults are ok
+ INFER_OK(op, "?;?;[]", "in0;in0");
+
// Check errors in the record_defaults inputs.
- INFER_ERROR("must be rank 1", op, "?;?;[]");
- INFER_ERROR("must be rank 1", op, "?;[];?");
+ INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;?;[1,2]");
+ INFER_ERROR("must be at most rank 1 but is rank 2", op, "?;[3,4];?");
INFER_ERROR("Shape of a default must be", op, "?;?;[2]");
INFER_ERROR("Shape of a default must be", op, "?;[2];?");
}
@@ -143,6 +146,88 @@ TEST(ParsingOpsTest, ParseExample_ShapeFn) {
"?;?;?;?;?;?;?;?;?;?");
}
+TEST(ParsingOpsTest, ParseSequenceExample_ShapeFn) {
+ ShapeInferenceTestOp op("ParseSequenceExample");
+ auto set_outputs = [&op](int num_context_sparse, int num_context_dense,
+ int num_feature_list_sparse,
+ int num_feature_list_dense,
+ bool add_extra_shape = false) {
+ using NodeOutList = std::vector<NodeDefBuilder::NodeOut>;
+ using DataTypeList = std::vector<DataType>;
+ string string_in("test");
+ NodeDefBuilder::NodeOut node_in{"a", 0, DT_STRING};
+ TF_ASSERT_OK(
+ NodeDefBuilder("test", "ParseSequenceExample")
+ .Input("serialized", 0, DT_STRING)
+ .Input("debug_name", 0, DT_STRING)
+ .Input(NodeOutList(num_context_dense, node_in))
+ .Attr("Ncontext_sparse", num_context_sparse)
+ .Attr("Ncontext_dense", num_context_dense)
+ .Attr("Nfeature_list_sparse", num_feature_list_sparse)
+ .Attr("Nfeature_list_dense", num_feature_list_dense)
+ .Attr("feature_list_dense_missing_assumed_empty",
+ std::vector<string>(num_feature_list_dense, string_in))
+ .Attr("context_sparse_keys",
+ std::vector<string>(num_context_sparse, string_in))
+ .Attr("context_dense_keys",
+ std::vector<string>(num_context_dense, string_in))
+ .Attr("feature_list_sparse_keys",
+ std::vector<string>(num_feature_list_sparse, string_in))
+ .Attr("feature_list_dense_keys",
+ std::vector<string>(num_feature_list_dense, string_in))
+ .Attr("context_sparse_types",
+ DataTypeList(num_context_sparse, DT_FLOAT))
+ .Attr("context_dense_types",
+ DataTypeList(num_context_dense, DT_FLOAT))
+ .Attr("context_dense_shapes",
+ MakeDenseShapes(num_context_dense, add_extra_shape, 0))
+ .Attr("feature_list_sparse_types",
+ DataTypeList(num_feature_list_sparse, DT_FLOAT))
+ .Attr("feature_list_dense_types",
+ DataTypeList(num_feature_list_dense, DT_FLOAT))
+ .Attr("feature_list_dense_shapes",
+ MakeDenseShapes(num_feature_list_dense, add_extra_shape, 0))
+ .Finalize(&op.node_def));
+ };
+
+ // Verify inputs 'serialized' and 'debug_name'.
+ set_outputs(0, 0, 0, 0);
+ INFER_OK(op, "[?];[?]", "");
+ INFER_OK(op, "[8];[8]", "");
+ INFER_ERROR("must be rank 1", op, "[];[?]");
+ INFER_ERROR("must be rank 1", op, "[?];[]");
+
+ // context inputs with no feature_list inputs.
+ set_outputs(2 /* num_context_sparse */, 3 /* num_context_dense */, 0, 0);
+ INFER_OK(op, "[?];[?];?;?;?",
+ ("[?,2];[?,2];[?];[?];[2];[2];" // context sparse
+ "[d0_0,1];[d0_0,1,2];[d0_0,1,2,3]")); // context dense
+
+ // feature_list inputs with no context inputs.
+ set_outputs(0, 0, 2 /* num_feature_list_sparse */,
+ 3 /* num_feature_list_dense */);
+ INFER_OK(op, "[?];[?]",
+ ("[?,3];[?,3];[?];[?];[3];[3];" // feature_list sparse
+ "[d0_0,?,1];[d0_0,?,1,2];[d0_0,?,1,2,3];" // feature_list dense
+ "[d0_0];[d0_0];[d0_0]")); // feature_list length
+
+ // Combine previous two test cases.
+ set_outputs(2, 3, 2, 3);
+ INFER_OK(op, "[7];[7];?;?;?",
+ ("[?,2];[?,2];[?];[?];[2];[2];" // context sparse
+ "[d0_0,1];[d0_0,1,2];[d0_0,1,2,3];" // context dense
+ "[?,3];[?,3];[?];[?];[3];[3];" // feature_list sparse
+ "[d0_0,?,1];[d0_0,?,1,2];[d0_0,?,1,2,3];" // feature_list dense
+ "[d0_0];[d0_0];[d0_0]")); // feature_list length
+
+ // Confirm an error from ParseSequenceExampleAttrs.Init().
+ set_outputs(1, 1, 1, 1, true /* add_extra_shape */);
+ INFER_ERROR(
+ "num_context_dense (1) must match the size of context_dense_keys (1), "
+ "context_dense_types (1) and context_dense_shapes (2)",
+ op, "[?];[?];?");
+}
+
TEST(ParsingOpsTest, ParseSingleSequenceExample_ShapeFn) {
ShapeInferenceTestOp op("ParseSingleSequenceExample");
auto set_outputs = [&op](int num_context_sparse, int num_context_dense,
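
A minimal standalone sketch (not part of the patch) of the relaxed record_defaults rule exercised by the DecodeCSV tests above: after switching to WithRankAtMost, a default may be a scalar or a vector of length 0 or 1; anything of rank 2 or a longer vector is rejected.

#include <cstdio>
#include <cstdint>
#include <vector>

bool DefaultShapeOk(const std::vector<int64_t>& shape) {
  if (shape.size() > 1) return false;                  // "must be at most rank 1"
  if (shape.size() == 1 && shape[0] > 1) return false; // "length-0 or length-1 vector, or a scalar"
  return true;
}

int main() {
  std::printf("[]    -> %d\n", DefaultShapeOk({}));     // scalar: now accepted
  std::printf("[1]   -> %d\n", DefaultShapeOk({1}));    // accepted
  std::printf("[2]   -> %d\n", DefaultShapeOk({2}));    // rejected
  std::printf("[1,2] -> %d\n", DefaultShapeOk({1, 2})); // rejected (rank 2)
  return 0;
}
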
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 26499540f1..adc9cd1486 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -19,6 +19,7 @@
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/lib/core/errors.h"
using ::tensorflow::shape_inference::InferenceContext;
using ::tensorflow::shape_inference::ShapeAndType;
@@ -56,6 +57,36 @@ Status ReadVariableShapeFn(InferenceContext* c) {
return Status::OK();
}
+Status ReadVariablesShapeFn(InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ DataTypeVector value_dtypes;
+ TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &value_dtypes));
+ if (n != value_dtypes.size()) {
+ return errors::InvalidArgument(
+ "Mismatched number of arguments to ReadVariablesOp");
+ }
+ for (int i = 0; i < n; ++i) {
+ ShapeAndType shape_and_type;
+ auto* handle_data = c->input_handle_shapes_and_types(i);
+ if (handle_data == nullptr || handle_data->empty()) {
+ shape_and_type.shape = c->UnknownShape();
+ shape_and_type.dtype = DT_INVALID;
+ } else {
+ shape_and_type = (*handle_data)[0];
+ if (shape_and_type.dtype != value_dtypes[i]) {
+ return errors::InvalidArgument(
+ "Trying to read variable with wrong dtype. "
+ "Expected ",
+ DataTypeString(shape_and_type.dtype), " got ",
+ DataTypeString(value_dtypes[i]));
+ }
+ }
+ c->set_output(i, shape_and_type.shape);
+ }
+ return Status::OK();
+}
+
} // namespace
REGISTER_OP("VarHandleOp")
@@ -79,12 +110,53 @@ REGISTER_OP("VarHandleOp")
return Status::OK();
});
+REGISTER_OP("_VarHandlesOp")
+ .Attr("containers: list(string)")
+ .Attr("shared_names: list(string)")
+ .Attr("N: int >= 0")
+ .Attr("dtypes: list(type)")
+ .Attr("shapes: list(shape)")
+ .Output("resources: N * resource")
+ .SetIsStateful()
+ .SetShapeFn([](InferenceContext* c) {
+ int n;
+ TF_RETURN_IF_ERROR(c->GetAttr("N", &n));
+ DataTypeVector dtypes;
+ TF_RETURN_IF_ERROR(c->GetAttr("dtypes", &dtypes));
+ std::vector<PartialTensorShape> shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("shapes", &shapes));
+ if (dtypes.size() != n) {
+ return errors::InvalidArgument("Mismatched number of dtypes (n=", n,
+ ", num dtypes=", dtypes.size(), ")");
+ }
+ if (shapes.size() != n) {
+ return errors::InvalidArgument("Mismatched number of shapes (n=", n,
+ ", num shapes=", shapes.size(), ")");
+ }
+ for (int i = 0; i < n; ++i) {
+ c->set_output(i, c->Scalar());
+ ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(shapes[i], &s));
+ c->set_output_handle_shapes_and_types(
+ i, std::vector<ShapeAndType>{{s, dtypes[i]}});
+ }
+
+ return Status::OK();
+ });
+
REGISTER_OP("ReadVariableOp")
.Input("resource: resource")
.Output("value: dtype")
.Attr("dtype: type")
.SetShapeFn(ReadVariableShapeFn);
+REGISTER_OP("_ReadVariablesOp")
+ .Attr("N: int >= 0")
+ .Input("resources: N * resource")
+ .Output("values: dtypes")
+ .Attr("dtypes: list(type)")
+ .SetShapeFn(ReadVariablesShapeFn);
+
Status ReadGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
*g = FunctionDefHelper::Define(
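
A minimal standalone sketch (not part of the patch; the helper name is made up) of the attr consistency rule the _VarHandlesOp shape function above enforces: N, dtypes and shapes must all describe the same number of variables, otherwise shape inference fails with a mismatch error.

#include <cstdio>
#include <string>
#include <vector>

bool HandlesAttrsConsistent(int n, size_t num_dtypes, size_t num_shapes,
                            std::string* error) {
  if (num_dtypes != static_cast<size_t>(n)) {
    *error = "Mismatched number of dtypes (n=" + std::to_string(n) +
             ", num dtypes=" + std::to_string(num_dtypes) + ")";
    return false;
  }
  if (num_shapes != static_cast<size_t>(n)) {
    *error = "Mismatched number of shapes (n=" + std::to_string(n) +
             ", num shapes=" + std::to_string(num_shapes) + ")";
    return false;
  }
  return true;
}

int main() {
  std::string err;
  std::printf("ok:  %d\n", HandlesAttrsConsistent(2, 2, 2, &err));
  std::printf("bad: %d (%s)\n", HandlesAttrsConsistent(2, 3, 2, &err), err.c_str());
  return 0;
}
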
diff --git a/tensorflow/core/ops/sdca_ops.cc b/tensorflow/core/ops/sdca_ops.cc
index 4025070adb..fdf53a55dd 100644
--- a/tensorflow/core/ops/sdca_ops.cc
+++ b/tensorflow/core/ops/sdca_ops.cc
@@ -41,7 +41,7 @@ static Status ApplySdcaOptimizerShapeFn(InferenceContext* c) {
REGISTER_OP("SdcaOptimizer")
.Attr(
"loss_type: {'logistic_loss', 'squared_loss', 'hinge_loss',"
- "'smooth_hinge_loss'}")
+ "'smooth_hinge_loss', 'poisson_loss'}")
.Attr("adaptative : bool=false")
.Attr("num_sparse_features: int >= 0")
.Attr("num_sparse_features_with_values: int >= 0")
diff --git a/tensorflow/core/ops/stateless_random_grad.cc b/tensorflow/core/ops/stateless_random_grad.cc
new file mode 100644
index 0000000000..331e1d0152
--- /dev/null
+++ b/tensorflow/core/ops/stateless_random_grad.cc
@@ -0,0 +1,23 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/function.h"
+
+namespace tensorflow {
+REGISTER_OP_NO_GRADIENT("StatelessRandomUniform");
+REGISTER_OP_NO_GRADIENT("StatelessRandomNormal");
+REGISTER_OP_NO_GRADIENT("StatelessTruncatedNormal");
+REGISTER_OP_NO_GRADIENT("StatelessMultinomial");
+} // end namespace tensorflow
diff --git a/tensorflow/core/ops/string_ops.cc b/tensorflow/core/ops/string_ops.cc
index 8c39d69157..94d71a4113 100644
--- a/tensorflow/core/ops/string_ops.cc
+++ b/tensorflow/core/ops/string_ops.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "absl/strings/str_split.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
@@ -37,6 +38,14 @@ REGISTER_OP("RegexReplace")
return Status::OK();
});
+REGISTER_OP("StaticRegexReplace")
+ .Input("input: string")
+ .Attr("pattern: string")
+ .Attr("rewrite: string")
+ .Output("output: string")
+ .Attr("replace_global: bool = true")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
REGISTER_OP("RegexFullMatch")
.Input("input: string")
.Input("pattern: string")
@@ -48,6 +57,12 @@ REGISTER_OP("RegexFullMatch")
return Status::OK();
});
+REGISTER_OP("StaticRegexFullMatch")
+ .Input("input: string")
+ .Attr("pattern: string")
+ .Output("output: bool")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
REGISTER_OP("StringToHashBucketFast")
.Input("input: string")
.Output("output: int64")
@@ -88,6 +103,32 @@ REGISTER_OP("AsString")
.Attr("fill: string = ''")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("StringFormat")
+ .Input("inputs: T")
+ .Output("output: string")
+ .Attr("T: list(type) >= 0")
+ .Attr("template: string = '%s'")
+ .Attr("placeholder: string = '%s'")
+ .Attr("summarize: int = 3")
+ .SetShapeFn([](InferenceContext* c) {
+ string template_;
+ string placeholder;
+ TF_RETURN_IF_ERROR(c->GetAttr("template", &template_));
+ TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &placeholder));
+
+ std::vector<std::string> split_template;
+ split_template = absl::StrSplit(template_, placeholder);
+ int64 num_placeholders = split_template.size() - 1;
+ if (c->num_inputs() != num_placeholders) {
+ return errors::InvalidArgument(strings::StrCat(
+ "num placeholders in template and num inputs must match: ",
+ num_placeholders, " vs. ", c->num_inputs()));
+ }
+
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
REGISTER_OP("StringJoin")
.Input("inputs: N * string")
.Attr("N: int")
@@ -159,6 +200,12 @@ REGISTER_OP("StringStrip")
.Output("output: string")
.SetShapeFn(shape_inference::UnchangedShape);
+REGISTER_OP("StringLength")
+ .Input("input: string")
+ .Output("output: int32")
+ .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
REGISTER_OP("EncodeBase64")
.Input("input: string")
.Output("output: string")
@@ -176,6 +223,7 @@ REGISTER_OP("Substr")
.Input("len: T")
.Output("output: string")
.Attr("T: {int32, int64}")
+ .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle pos_shape = c->input(1);
ShapeHandle len_shape = c->input(2);
@@ -197,4 +245,9 @@ REGISTER_OP("Substr")
return shape_inference::BroadcastBinaryOpShapeFn(c);
});
+REGISTER_OP("UnicodeScript")
+ .Input("input: int32")
+ .Output("output: int32")
+ .SetShapeFn(shape_inference::UnchangedShape);
+
} // namespace tensorflow
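
A minimal standalone sketch (not part of the patch) of how the StringFormat shape function above derives the expected input count: the number of placeholder occurrences in the template must equal the number of inputs. The op uses absl::StrSplit; plain std::string::find is used here to keep the sketch dependency-free.

#include <cstdio>
#include <string>

int CountPlaceholders(const std::string& tmpl, const std::string& placeholder) {
  if (placeholder.empty()) return 0;
  int count = 0;
  for (std::string::size_type pos = tmpl.find(placeholder);
       pos != std::string::npos;
       pos = tmpl.find(placeholder, pos + placeholder.size())) {
    ++count;
  }
  return count;
}

int main() {
  // "loss: %s, step: %s" requires exactly two inputs.
  std::printf("%d\n", CountPlaceholders("loss: %s, step: %s", "%s"));   // 2
  std::printf("%d\n", CountPlaceholders("no placeholders here", "%s")); // 0
  return 0;
}
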
diff --git a/tensorflow/core/platform/abi.cc b/tensorflow/core/platform/abi.cc
index e597a490d6..d7a13a3528 100644
--- a/tensorflow/core/platform/abi.cc
+++ b/tensorflow/core/platform/abi.cc
@@ -37,13 +37,13 @@ extern "C" char* __unDName(char* output_string, const char* name,
namespace tensorflow {
namespace port {
-std::string MaybeAbiDemangle(const char* name) {
+string MaybeAbiDemangle(const char* name) {
#if defined(_MSC_VER)
std::unique_ptr<char> demangled{__unDName(nullptr, name, 0, std::malloc,
std::free,
static_cast<unsigned short>(0))};
- return std::string(demangled.get() != nullptr ? demangled.get() : name);
+ return string(demangled.get() != nullptr ? demangled.get() : name);
#else
int status = 0;
std::unique_ptr<char, void (*)(void*)> res{
diff --git a/tensorflow/core/platform/abi.h b/tensorflow/core/platform/abi.h
index 763d467457..d1498a6a64 100644
--- a/tensorflow/core/platform/abi.h
+++ b/tensorflow/core/platform/abi.h
@@ -13,17 +13,18 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_ABI_H_
-#define TENSORFLOW_PLATFORM_ABI_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_ABI_H_
+#define TENSORFLOW_CORE_PLATFORM_ABI_H_
#include <string>
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace port {
-std::string MaybeAbiDemangle(const char* name);
+string MaybeAbiDemangle(const char* name);
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_ABI_H_
+#endif // TENSORFLOW_CORE_PLATFORM_ABI_H_
diff --git a/tensorflow/core/platform/cloud/auth_provider.h b/tensorflow/core/platform/cloud/auth_provider.h
index 465ff248d9..7347bc626d 100644
--- a/tensorflow/core/platform/cloud/auth_provider.h
+++ b/tensorflow/core/platform/cloud/auth_provider.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PLATFORM_AUTH_PROVIDER_H_
-#define TENSORFLOW_CORE_PLATFORM_AUTH_PROVIDER_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_AUTH_PROVIDER_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_AUTH_PROVIDER_H_
#include <string>
#include "tensorflow/core/lib/core/errors.h"
@@ -51,4 +51,4 @@ class EmptyAuthProvider : public AuthProvider {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PLATFORM_AUTH_PROVIDER_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_AUTH_PROVIDER_H_
diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc b/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc
index f41b83ac34..affb68ebbb 100644
--- a/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc
+++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/platform/cloud/curl_http_request.h"
-#include "tensorflow/core/platform/cloud/retrying_utils.h"
namespace tensorflow {
@@ -25,21 +24,14 @@ namespace {
// The URL to retrieve metadata when running in Google Compute Engine.
constexpr char kGceMetadataBaseUrl[] = "http://metadata/computeMetadata/v1/";
-// The default initial delay between retries with exponential backoff.
-constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec
} // namespace
ComputeEngineMetadataClient::ComputeEngineMetadataClient(
- std::shared_ptr<HttpRequest::Factory> http_request_factory)
- : ComputeEngineMetadataClient(std::move(http_request_factory),
- kInitialRetryDelayUsec) {}
-
-ComputeEngineMetadataClient::ComputeEngineMetadataClient(
std::shared_ptr<HttpRequest::Factory> http_request_factory,
- int64 initial_retry_delay_usec)
+ const RetryConfig& config)
: http_request_factory_(std::move(http_request_factory)),
- initial_retry_delay_usec_(initial_retry_delay_usec) {}
+ retry_config_(config) {}
Status ComputeEngineMetadataClient::GetMetadata(
const string& path, std::vector<char>* response_buffer) {
@@ -52,8 +44,7 @@ Status ComputeEngineMetadataClient::GetMetadata(
return Status::OK();
};
- return RetryingUtils::CallWithRetries(get_metadata_from_gce,
- initial_retry_delay_usec_);
+ return RetryingUtils::CallWithRetries(get_metadata_from_gce, retry_config_);
}
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
index 534ccf30b2..7f060327da 100644
--- a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
+++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/cloud/http_request.h"
+#include "tensorflow/core/platform/cloud/retrying_utils.h"
namespace tensorflow {
@@ -31,10 +32,11 @@ namespace tensorflow {
class ComputeEngineMetadataClient {
public:
explicit ComputeEngineMetadataClient(
- std::shared_ptr<HttpRequest::Factory> http_request_factory);
- ComputeEngineMetadataClient(
std::shared_ptr<HttpRequest::Factory> http_request_factory,
- int64 initial_retry_delay_usec);
+ const RetryConfig& config = RetryConfig(
+          10000, /* init_delay_time_us = 10 ms */
+ 1000000 /* max_delay_time_us = 1 s */
+ ));
virtual ~ComputeEngineMetadataClient() {}
/// \brief Get the metadata value for a given attribute of the metadata
@@ -54,7 +56,7 @@ class ComputeEngineMetadataClient {
private:
std::shared_ptr<HttpRequest::Factory> http_request_factory_;
- const int64 initial_retry_delay_usec_;
+ const RetryConfig retry_config_;
TF_DISALLOW_COPY_AND_ASSIGN(ComputeEngineMetadataClient);
};
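
A minimal standalone sketch (not part of the patch, and not the actual RetryingUtils implementation) of the retry pattern RetryConfig parameterizes throughout this change: exponential backoff that starts at init_delay_time_us and is capped at max_delay_time_us.

#include <algorithm>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <functional>
#include <thread>

struct SketchRetryConfig {
  int64_t init_delay_time_us;
  int64_t max_delay_time_us;
  int max_retries;
};

bool CallWithRetriesSketch(const std::function<bool()>& f,
                           const SketchRetryConfig& config) {
  int64_t delay_us = config.init_delay_time_us;
  for (int attempt = 0; attempt <= config.max_retries; ++attempt) {
    if (f()) return true;
    if (attempt == config.max_retries) break;
    std::this_thread::sleep_for(std::chrono::microseconds(delay_us));
    delay_us = std::min<int64_t>(delay_us * 2, config.max_delay_time_us);
  }
  return false;
}

int main() {
  int calls = 0;
  auto flaky = [&calls]() { return ++calls >= 3; };  // succeeds on the 3rd call
  SketchRetryConfig config{10000 /* 10 ms */, 1000000 /* 1 s */, 5};
  std::printf("succeeded: %d after %d calls\n",
              CallWithRetriesSketch(flaky, config), calls);
  return 0;
}
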
diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc b/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc
index 4c41ccaa0e..e891b4a5e9 100644
--- a/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc
+++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc
@@ -30,7 +30,8 @@ TEST(ComputeEngineMetadataClientTest, GetMetadata) {
std::shared_ptr<HttpRequest::Factory> http_factory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- ComputeEngineMetadataClient client(http_factory, 0);
+ ComputeEngineMetadataClient client(http_factory,
+ RetryConfig(0 /* init_delay_time_us */));
std::vector<char> result;
TF_EXPECT_OK(
@@ -56,7 +57,8 @@ TEST(ComputeEngineMetadataClientTest, RetryOnFailure) {
std::shared_ptr<HttpRequest::Factory> http_factory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- ComputeEngineMetadataClient client(http_factory, 0);
+ ComputeEngineMetadataClient client(http_factory,
+ RetryConfig(0 /* init_delay_time_us */));
std::vector<char> result;
TF_EXPECT_OK(
diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc b/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc
index dacf56187c..e147d88371 100644
--- a/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc
+++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider.cc
@@ -43,7 +43,7 @@ Status ComputeEngineZoneProvider::GetZone(string* zone) {
*zone = cached_zone;
} else {
LOG(ERROR) << "Failed to parse the zone name from location: "
- << location.ToString();
+ << string(location);
}
return Status::OK();
diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc b/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc
index f7477eca23..476e4f9c1f 100644
--- a/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc
+++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc
@@ -34,8 +34,8 @@ TEST_F(ComputeEngineZoneProviderTest, GetZone) {
auto httpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadata_client =
- std::make_shared<ComputeEngineMetadataClient>(httpRequestFactory, 0);
+ auto metadata_client = std::make_shared<ComputeEngineMetadataClient>(
+ httpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
ComputeEngineZoneProvider provider(metadata_client);
@@ -55,8 +55,8 @@ TEST_F(ComputeEngineZoneProviderTest, InvalidZoneString) {
auto httpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadata_client =
- std::make_shared<ComputeEngineMetadataClient>(httpRequestFactory, 0);
+ auto metadata_client = std::make_shared<ComputeEngineMetadataClient>(
+ httpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
ComputeEngineZoneProvider provider(metadata_client);
diff --git a/tensorflow/core/platform/cloud/curl_http_request.cc b/tensorflow/core/platform/cloud/curl_http_request.cc
index a1be4aacce..5e1eabee5b 100644
--- a/tensorflow/core/platform/cloud/curl_http_request.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request.cc
@@ -394,9 +394,9 @@ size_t CurlHttpRequest::HeaderCallback(const void* ptr, size_t size,
.StopCapture()
.OneLiteral(": ")
.GetResult(&value, &name)) {
- string str_value = std::string(value);
+ string str_value(value);
str_util::StripTrailingWhitespace(&str_value);
- that->response_headers_[std::string(name)] = str_value;
+ that->response_headers_[string(name)] = str_value;
}
return size * nmemb;
}
diff --git a/tensorflow/core/platform/cloud/gcs_dns_cache.h b/tensorflow/core/platform/cloud/gcs_dns_cache.h
index 40f16f1044..07d0e59fd5 100644
--- a/tensorflow/core/platform/cloud/gcs_dns_cache.h
+++ b/tensorflow/core/platform/cloud/gcs_dns_cache.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_
-#define TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_DNS_CACHE_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_DNS_CACHE_H_
#include <random>
@@ -74,4 +74,4 @@ class GcsDnsCache {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATNFORM_CLOUD_DNS_CACHE_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_GCS_DNS_CACHE_H_
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 9d33787bd5..c61b68aeeb 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -25,6 +25,7 @@ limitations under the License.
#ifdef _WIN32
#include <io.h> // for _mktemp
#endif
+#include "absl/base/macros.h"
#include "include/json/json.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
@@ -63,7 +64,7 @@ constexpr int kGetChildrenDefaultPageSize = 1000;
// The HTTP response code "308 Resume Incomplete".
constexpr uint64 HTTP_CODE_RESUME_INCOMPLETE = 308;
// The environment variable that overrides the size of the readahead buffer.
-// DEPRECATED. Use GCS_BLOCK_SIZE_MB instead.
+ABSL_DEPRECATED("Use GCS_BLOCK_SIZE_MB instead.")
constexpr char kReadaheadBufferSize[] = "GCS_READAHEAD_BUFFER_SIZE_BYTES";
// The environment variable that disables the GCS block cache for reads.
// This is the explicit alternative to setting BLOCK_SIZE or MAX_SIZE to 0, and
@@ -179,13 +180,13 @@ Status ParseGcsPath(StringPiece fname, bool empty_object_ok, string* bucket,
return errors::InvalidArgument("GCS path doesn't start with 'gs://': ",
fname);
}
- *bucket = std::string(bucketp);
+ *bucket = string(bucketp);
if (bucket->empty() || *bucket == ".") {
return errors::InvalidArgument("GCS path doesn't contain a bucket name: ",
fname);
}
str_util::ConsumePrefix(&objectp, "/");
- *object = std::string(objectp);
+ *object = string(objectp);
if (!empty_object_ok && object->empty()) {
return errors::InvalidArgument("GCS path doesn't contain an object name: ",
fname);
@@ -224,7 +225,7 @@ std::set<string> AddAllSubpaths(const std::vector<string>& paths) {
for (const string& path : paths) {
StringPiece subpath = io::Dirname(path);
while (!subpath.empty()) {
- result.emplace(std::string(subpath));
+ result.emplace(string(subpath));
subpath = io::Dirname(subpath);
}
}
@@ -332,14 +333,14 @@ class GcsWritableFile : public WritableFile {
GcsFileSystem* filesystem,
GcsFileSystem::TimeoutConfig* timeouts,
std::function<void()> file_cache_erase,
- int64 initial_retry_delay_usec)
+ RetryConfig retry_config)
: bucket_(bucket),
object_(object),
filesystem_(filesystem),
timeouts_(timeouts),
file_cache_erase_(std::move(file_cache_erase)),
sync_needed_(true),
- initial_retry_delay_usec_(initial_retry_delay_usec) {
+ retry_config_(retry_config) {
// TODO: to make it safer, outfile_ should be constructed from an FD
if (GetTmpFilename(&tmp_content_filename_).ok()) {
outfile_.open(tmp_content_filename_,
@@ -356,14 +357,14 @@ class GcsWritableFile : public WritableFile {
GcsFileSystem* filesystem, const string& tmp_content_filename,
GcsFileSystem::TimeoutConfig* timeouts,
std::function<void()> file_cache_erase,
- int64 initial_retry_delay_usec)
+ RetryConfig retry_config)
: bucket_(bucket),
object_(object),
filesystem_(filesystem),
timeouts_(timeouts),
file_cache_erase_(std::move(file_cache_erase)),
sync_needed_(true),
- initial_retry_delay_usec_(initial_retry_delay_usec) {
+ retry_config_(retry_config) {
tmp_content_filename_ = tmp_content_filename;
outfile_.open(tmp_content_filename_,
std::ofstream::binary | std::ofstream::app);
@@ -371,7 +372,7 @@ class GcsWritableFile : public WritableFile {
~GcsWritableFile() override { Close().IgnoreError(); }
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
TF_RETURN_IF_ERROR(CheckWritable());
sync_needed_ = true;
outfile_ << data;
@@ -440,7 +441,7 @@ class GcsWritableFile : public WritableFile {
first_attempt = false;
return UploadToSession(session_uri, already_uploaded);
},
- initial_retry_delay_usec_);
+ retry_config_);
if (upload_status.code() == errors::Code::NOT_FOUND) {
// GCS docs recommend retrying the whole upload. We're relying on the
// RetryingFileSystem to retry the Sync() call.
@@ -585,7 +586,7 @@ class GcsWritableFile : public WritableFile {
GcsFileSystem::TimeoutConfig* timeouts_;
std::function<void()> file_cache_erase_;
bool sync_needed_; // whether there is buffered data that needs to be synced
- int64 initial_retry_delay_usec_;
+ RetryConfig retry_config_;
};
class GcsReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
@@ -723,7 +724,7 @@ GcsFileSystem::GcsFileSystem() {
if (!header_name.empty() && !header_value.empty()) {
additional_header_.reset(new std::pair<const string, const string>(
- std::string(header_name), std::string(header_value)));
+ string(header_name), string(header_value)));
VLOG(1) << "GCS additional header ENABLED. "
<< "Name: " << additional_header_->first << ", "
@@ -790,7 +791,7 @@ GcsFileSystem::GcsFileSystem(
std::unique_ptr<ZoneProvider> zone_provider, size_t block_size,
size_t max_bytes, uint64 max_staleness, uint64 stat_cache_max_age,
size_t stat_cache_max_entries, uint64 matching_paths_cache_max_age,
- size_t matching_paths_cache_max_entries, int64 initial_retry_delay_usec,
+ size_t matching_paths_cache_max_entries, RetryConfig retry_config,
TimeoutConfig timeouts, const std::unordered_set<string>& allowed_locations,
std::pair<const string, const string>* additional_header)
: auth_provider_(std::move(auth_provider)),
@@ -805,7 +806,7 @@ GcsFileSystem::GcsFileSystem(
kCacheNeverExpire, kBucketLocationCacheMaxEntries)),
allowed_locations_(allowed_locations),
timeouts_(timeouts),
- initial_retry_delay_usec_(initial_retry_delay_usec),
+ retry_config_(retry_config),
additional_header_(additional_header) {}
Status GcsFileSystem::NewRandomAccessFile(
@@ -940,7 +941,7 @@ Status GcsFileSystem::NewWritableFile(const string& fname,
TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
result->reset(new GcsWritableFile(bucket, object, this, &timeouts_,
[this, fname]() { ClearFileCaches(fname); },
- initial_retry_delay_usec_));
+ retry_config_));
return Status::OK();
}
@@ -980,7 +981,7 @@ Status GcsFileSystem::NewAppendableFile(const string& fname,
TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
result->reset(new GcsWritableFile(
bucket, object, this, old_content_filename, &timeouts_,
- [this, fname]() { ClearFileCaches(fname); }, initial_retry_delay_usec_));
+ [this, fname]() { ClearFileCaches(fname); }, retry_config_));
return Status::OK();
}
@@ -1229,7 +1230,7 @@ Status GcsFileSystem::GetMatchingPaths(const string& pattern,
// Find the fixed prefix by looking for the first wildcard.
const string& fixed_prefix =
pattern.substr(0, pattern.find_first_of("*?[\\"));
- const string& dir = std::string(io::Dirname(fixed_prefix));
+ const string dir(io::Dirname(fixed_prefix));
if (dir.empty()) {
return errors::InvalidArgument(
"A GCS pattern doesn't have a bucket name: ", pattern);
@@ -1326,7 +1327,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname,
" doesn't match the prefix ", object_prefix));
}
if (!relative_path.empty() || include_self_directory_marker) {
- result->emplace_back(std::string(relative_path));
+ result->emplace_back(relative_path);
}
if (++retrieved_results >= max_results) {
return Status::OK();
@@ -1354,7 +1355,7 @@ Status GcsFileSystem::GetChildrenBounded(const string& dirname,
"Unexpected response: the returned folder name ", prefix_str,
" doesn't match the prefix ", object_prefix);
}
- result->emplace_back(std::string(relative_path));
+ result->emplace_back(relative_path);
if (++retrieved_results >= max_results) {
return Status::OK();
}
@@ -1533,7 +1534,7 @@ Status GcsFileSystem::RenameObject(const string& src, const string& target) {
// on the server side, we can't just retry the whole RenameFile operation
// because the source object is already gone.
return RetryingUtils::DeleteWithRetries(
- [this, &src]() { return DeleteFile(src); }, initial_retry_delay_usec_);
+ [this, &src]() { return DeleteFile(src); }, retry_config_);
}
Status GcsFileSystem::IsDirectory(const string& fname) {
@@ -1589,8 +1590,7 @@ Status GcsFileSystem::DeleteRecursively(const string& dirname,
// and therefore RetryingFileSystem won't pay attention to the failures,
// we need to make sure these failures are properly retried.
const auto& delete_file_status = RetryingUtils::DeleteWithRetries(
- [this, &full_path]() { return DeleteFile(full_path); },
- initial_retry_delay_usec_);
+ [this, &full_path]() { return DeleteFile(full_path); }, retry_config_);
if (!delete_file_status.ok()) {
if (IsDirectory(full_path).ok()) {
// The object is a directory marker.
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index 71db707687..d0840a3046 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -93,7 +93,7 @@ class GcsFileSystem : public FileSystem {
uint64 stat_cache_max_age, size_t stat_cache_max_entries,
uint64 matching_paths_cache_max_age,
size_t matching_paths_cache_max_entries,
- int64 initial_retry_delay_usec, TimeoutConfig timeouts,
+ RetryConfig retry_config, TimeoutConfig timeouts,
const std::unordered_set<string>& allowed_locations,
std::pair<const string, const string>* additional_header);
@@ -332,7 +332,7 @@ class GcsFileSystem : public FileSystem {
GcsStatsInterface* stats_ = nullptr; // Not owned.
/// The initial delay for exponential backoffs when retrying failed calls.
- const int64 initial_retry_delay_usec_ = 1000000L;
+ RetryConfig retry_config_;
// Additional header material to be transmitted with all GCS requests
std::unique_ptr<std::pair<const string, const string>> additional_header_;
@@ -344,7 +344,8 @@ class GcsFileSystem : public FileSystem {
class RetryingGcsFileSystem : public RetryingFileSystem<GcsFileSystem> {
public:
RetryingGcsFileSystem()
- : RetryingFileSystem(std::unique_ptr<GcsFileSystem>(new GcsFileSystem)) {}
+ : RetryingFileSystem(std::unique_ptr<GcsFileSystem>(new GcsFileSystem),
+ RetryConfig(100000 /* init_delay_time_us */)) {}
};
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index 14376ad339..702802b185 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -24,6 +24,8 @@ namespace tensorflow {
namespace {
static GcsFileSystem::TimeoutConfig kTestTimeoutConfig(5, 1, 10, 20, 30);
+static RetryConfig kTestRetryConfig(0 /* init_delay_time_us */);
+
// Default (empty) constraint config
static std::unordered_set<string>* kAllowedLocationsDefault =
new std::unordered_set<string>();
@@ -62,16 +64,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) {
"Range: 6-11\n"
"Timeouts: 5 1 20\n",
"6789")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -108,9 +110,9 @@ TEST(GcsFileSystemTest,
0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- *kAllowedLocationsAuto, nullptr /* gcs additional header */);
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsAuto,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -150,9 +152,9 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) {
0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- *kAllowedLocationsAuto, nullptr /* gcs additional header */);
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsAuto,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
@@ -191,9 +193,9 @@ TEST(GcsFileSystemTest,
0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- *kAllowedLocationsAuto, nullptr /* gcs additional header */);
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsAuto,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
EXPECT_EQ(tensorflow::errors::FailedPrecondition(
@@ -216,16 +218,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache_DifferentN) {
"Range: 3-12\n"
"Timeouts: 5 1 20\n",
"3456789")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -283,7 +285,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */,
18 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -372,7 +374,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */,
18 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -414,17 +416,17 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) {
"Range: 8-15\n"
"Timeouts: 5 1 20\n",
"89abcdef")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */,
- 16 /* max bytes */, 3600 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 8 /* block size */, 16 /* max bytes */,
+ 3600 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
char scratch[100];
StringPiece result;
// There should only be two HTTP requests issued to GCS even though we iterate
@@ -492,7 +494,7 @@ TEST(GcsFileSystemTest,
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */,
18 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -513,17 +515,17 @@ TEST(GcsFileSystemTest,
TEST(GcsFileSystemTest, NewRandomAccessFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
- 0 /* read ahead bytes */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* read ahead bytes */, 0 /* max bytes */,
+ 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -547,16 +549,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_InconsistentRead) {
"012")});
// Set stat_cache_max_age to 1000s so that StatCache could work.
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 1e3 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 1e3 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Stat the file first so that the file stats are cached.
FileStatistics stat;
@@ -621,7 +623,7 @@ TEST(GcsFileSystemTest, NewWritableFile) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */,
8 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -703,16 +705,16 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceeds) {
"Timeouts: 5 1 30\n"
"Put body: t2\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -773,17 +775,17 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) {
"Range: 0-7\n"
"Timeouts: 5 1 20\n",
"01234567")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */,
- 8 /* max bytes */, 3600 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 8 /* block size */, 8 /* max bytes */,
+ 3600 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Pull the file's first block into the cache. This will trigger the first
// HTTP request to GCS.
std::unique_ptr<RandomAccessFile> rfile;
@@ -867,9 +869,9 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 2 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ 0 /* matching paths cache max entries */,
+      RetryConfig(2 /* init_delay_time_us */), kTestTimeoutConfig,
+ *kAllowedLocationsDefault, nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -918,16 +920,16 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) {
"Timeouts: 5 1 30\n"
"Put body: content1,content2\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -948,16 +950,16 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) {
TEST(GcsFileSystemTest, NewWritableFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1013,7 +1015,7 @@ TEST(GcsFileSystemTest, NewAppendableFile) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 32 /* block size */,
32 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1041,16 +1043,16 @@ TEST(GcsFileSystemTest, NewAppendableFile) {
TEST(GcsFileSystemTest, NewAppendableFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1075,16 +1077,16 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) {
"Range: 0-",
content.size() - 1, "\n", "Timeouts: 5 1 20\n"),
content)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<ReadOnlyMemoryRegion> region;
TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile(
@@ -1096,16 +1098,16 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) {
TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<ReadOnlyMemoryRegion> region;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1120,16 +1122,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsObject) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket/path/file1.txt"));
}
@@ -1150,16 +1152,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsFolder) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/subfolder/\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket/path/subfolder"));
}
@@ -1176,16 +1178,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsBucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"size\": \"100\"}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket1"));
TF_EXPECT_OK(fs.FileExists("gs://bucket1/"));
@@ -1206,16 +1208,16 @@ TEST(GcsFileSystemTest, FileExists_NotAsObjectOrFolder) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"items\": []}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::NOT_FOUND,
fs.FileExists("gs://bucket/path/file1.txt").code());
@@ -1233,16 +1235,16 @@ TEST(GcsFileSystemTest, FileExists_NotAsBucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
fs.FileExists("gs://bucket2/").code());
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1279,7 +1281,7 @@ TEST(GcsFileSystemTest, FileExists_StatCache) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1306,7 +1308,7 @@ TEST(GcsFileSystemTest, FileExists_DirectoryMark) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1322,16 +1324,16 @@ TEST(GcsFileSystemTest, GetChildren_NoItems) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1350,16 +1352,16 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/file3.txt\" }],"
"\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1379,16 +1381,16 @@ TEST(GcsFileSystemTest, GetChildren_SelfDirectoryMarker) {
" { \"name\": \"path/\" },"
" { \"name\": \"path/file3.txt\" }],"
"\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1407,16 +1409,16 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/file3.txt\" }],"
"\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children));
@@ -1432,16 +1434,16 @@ TEST(GcsFileSystemTest, GetChildren_Root) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket-a-b-c", &children));
@@ -1457,16 +1459,16 @@ TEST(GcsFileSystemTest, GetChildren_Empty) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1498,16 +1500,16 @@ TEST(GcsFileSystemTest, GetChildren_Pagination) {
" { \"name\": \"path/file4.txt\" },"
" { \"name\": \"path/file5.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children));
@@ -1525,16 +1527,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_NoWildcard) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/subpath/file2.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(
@@ -1553,16 +1555,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_BucketAndWildcard) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/*/*", &result));
@@ -1582,16 +1584,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_Matches) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file2.txt", &result));
@@ -1608,16 +1610,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) {
"{\"items\": [ "
" { \"name\": \"path/\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*", &result));
@@ -1634,16 +1636,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file3.txt", &result));
@@ -1652,16 +1654,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) {
TEST(GcsFileSystemTest, GetMatchingPaths_OnlyWildcard) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1686,16 +1688,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 3600 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 3600 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Repeated calls to fs.GetMatchingPaths on these patterns should not lead to
// any additional HTTP requests to GCS.
@@ -1729,16 +1731,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/subpath/file2.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 3600 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 3600 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// This loop should trigger the first HTTP request to GCS.
for (int i = 0; i < 10; i++) {
@@ -1800,7 +1802,7 @@ TEST(GcsFileSystemTest, DeleteFile) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */,
16 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1821,16 +1823,16 @@ TEST(GcsFileSystemTest, DeleteFile) {
TEST(GcsFileSystemTest, DeleteFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
fs.DeleteFile("gs://bucket/").code());
@@ -1871,7 +1873,7 @@ TEST(GcsFileSystemTest, DeleteFile_StatCacheRemoved) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */,
16 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1894,16 +1896,16 @@ TEST(GcsFileSystemTest, DeleteDir_Empty) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/"));
}
@@ -1923,16 +1925,16 @@ TEST(GcsFileSystemTest, DeleteDir_OnlyDirMarkerLeft) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/"));
}
@@ -1943,16 +1945,16 @@ TEST(GcsFileSystemTest, DeleteDir_BucketOnly) {
"name%2CnextPageToken&maxResults=2\nAuth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket"));
}
@@ -1965,16 +1967,16 @@ TEST(GcsFileSystemTest, DeleteDir_NonEmpty) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/file1.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::FAILED_PRECONDITION,
fs.DeleteDir("gs://bucket/path/").code());
@@ -1988,16 +1990,16 @@ TEST(GcsFileSystemTest, GetFileSize) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
uint64 size;
TF_EXPECT_OK(fs.GetFileSize("gs://bucket/file.txt", &size));
@@ -2006,16 +2008,16 @@ TEST(GcsFileSystemTest, GetFileSize) {
TEST(GcsFileSystemTest, GetFileSize_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
uint64 size;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -2092,16 +2094,16 @@ TEST(GcsFileSystemTest, RenameFile_Folder) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.RenameFile("gs://bucket/path1", "gs://bucket/path2/"));
}
@@ -2191,7 +2193,7 @@ TEST(GcsFileSystemTest, RenameFile_Object) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */,
64 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
// Do an initial read of the source and destination files to load their
@@ -2272,7 +2274,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
// Do an initial stat of the destination file to load their contents into the
@@ -2332,16 +2334,16 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(
fs.RenameFile("gs://bucket/path/src.txt", "gs://bucket/path/dst.txt"));
@@ -2374,16 +2376,16 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) {
"Post: yes\n"
"Timeouts: 5 1 10\n",
"{\"done\": false}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(
errors::Code::UNIMPLEMENTED,
@@ -2399,16 +2401,16 @@ TEST(GcsFileSystemTest, Stat_Object) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/file.txt", &stat));
@@ -2433,16 +2435,16 @@ TEST(GcsFileSystemTest, Stat_Folder) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"subfolder/\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/subfolder", &stat));
@@ -2466,16 +2468,16 @@ TEST(GcsFileSystemTest, Stat_ObjectOrFolderNotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/path", &stat).code());
@@ -2487,16 +2489,16 @@ TEST(GcsFileSystemTest, Stat_Bucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/", &stat));
@@ -2511,16 +2513,16 @@ TEST(GcsFileSystemTest, Stat_BucketNotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/", &stat).code());
@@ -2556,7 +2558,7 @@ TEST(GcsFileSystemTest, Stat_Cache) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -2598,7 +2600,7 @@ TEST(GcsFileSystemTest, Stat_Cache_Flush) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
// There should be a single HTTP request to GCS for fs.Stat in this loop.
@@ -2628,16 +2630,16 @@ TEST(GcsFileSystemTest, Stat_FilenameEndingWithSlash) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"5\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/dir/", &stat));
@@ -2660,16 +2662,16 @@ TEST(GcsFileSystemTest, IsDirectory_NotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::NOT_FOUND,
fs.IsDirectory("gs://bucket/file.txt").code());
@@ -2691,16 +2693,16 @@ TEST(GcsFileSystemTest, IsDirectory_NotDirectoryButObject) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::FAILED_PRECONDITION,
fs.IsDirectory("gs://bucket/file.txt").code());
@@ -2722,16 +2724,16 @@ TEST(GcsFileSystemTest, IsDirectory_Yes) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"items\": [{\"name\": \"subfolder/\"}]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder"));
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder/"));
@@ -2749,16 +2751,16 @@ TEST(GcsFileSystemTest, IsDirectory_Bucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.IsDirectory("gs://bucket"));
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/"));
@@ -2770,16 +2772,16 @@ TEST(GcsFileSystemTest, IsDirectory_BucketNotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::NOT_FOUND, fs.IsDirectory("gs://bucket/").code());
}
@@ -2812,16 +2814,16 @@ TEST(GcsFileSystemTest, CreateDir_Folder) {
"Timeouts: 5 1 30\n"
"Put body: \n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath"));
TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath/"));
@@ -2839,16 +2841,16 @@ TEST(GcsFileSystemTest, CreateDir_Bucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.CreateDir("gs://bucket/"));
TF_EXPECT_OK(fs.CreateDir("gs://bucket"));
@@ -2911,16 +2913,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files,
@@ -3004,16 +3006,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) {
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files,
@@ -3039,16 +3041,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
EXPECT_EQ(error::Code::NOT_FOUND,
@@ -3130,7 +3132,7 @@ TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
add_header /* gcs additional header */);
@@ -3199,16 +3201,16 @@ TEST(GcsFileSystemTest, CreateHttpRequest) {
"Auth Token: fake_token\n"
"Header Hello: world\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<HttpRequest> request;
TF_EXPECT_OK(fs.CreateHttpRequest(&request));
@@ -3262,16 +3264,16 @@ TEST(GcsFileSystemTest, Stat_StatsRecording) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TestGcsStats stats;
fs.SetStats(&stats);
@@ -3289,16 +3291,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_StatsRecording) {
"Range: 0-5\n"
"Timeouts: 5 1 20\n",
"012345")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TestGcsStats stats;
fs.SetStats(&stats);
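
Throughout the GcsFileSystem tests above, the old tenth constructor argument (a bare 0 /* initial retry delay */ integer) is replaced by a shared kTestRetryConfig constant. Its definition lives earlier in gcs_file_system_test.cc and is not part of these hunks; below is a minimal sketch of what such a zero-delay config looks like, assuming the RetryConfig constructor introduced later in this patch.

// Hypothetical sketch only; the real kTestRetryConfig is defined elsewhere in
// this test file and may differ in detail.
static const RetryConfig kTestRetryConfig(0 /* init_delay_time_us */);

With a zero initial delay, RetryingUtils::CallWithRetries computes a zero backoff, so retried calls in these tests return immediately instead of sleeping.
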
diff --git a/tensorflow/core/platform/cloud/google_auth_provider.h b/tensorflow/core/platform/cloud/google_auth_provider.h
index 58a785fd60..3755b124a8 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider.h
+++ b/tensorflow/core/platform/cloud/google_auth_provider.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PLATFORM_GOOGLE_AUTH_PROVIDER_H_
-#define TENSORFLOW_CORE_PLATFORM_GOOGLE_AUTH_PROVIDER_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GOOGLE_AUTH_PROVIDER_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_GOOGLE_AUTH_PROVIDER_H_
#include <memory>
#include "tensorflow/core/platform/cloud/auth_provider.h"
@@ -65,4 +65,4 @@ class GoogleAuthProvider : public AuthProvider {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PLATFORM_GOOGLE_AUTH_PROVIDER_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_GOOGLE_AUTH_PROVIDER_H_
diff --git a/tensorflow/core/platform/cloud/google_auth_provider_test.cc b/tensorflow/core/platform/cloud/google_auth_provider_test.cc
index 07b88a880f..ec31c5ee8c 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider_test.cc
+++ b/tensorflow/core/platform/cloud/google_auth_provider_test.cc
@@ -93,8 +93,8 @@ TEST_F(GoogleAuthProviderTest, EnvironmentVariable_Caching) {
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
oauth_client->return_token = "fake-token";
@@ -129,8 +129,8 @@ TEST_F(GoogleAuthProviderTest, GCloudRefreshToken) {
FakeEnv env;
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
@@ -178,8 +178,8 @@ TEST_F(GoogleAuthProviderTest, RunningOnGCE) {
FakeEnv env;
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
@@ -206,8 +206,8 @@ TEST_F(GoogleAuthProviderTest, OverrideForTesting) {
FakeEnv env;
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&empty_requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
@@ -228,8 +228,8 @@ TEST_F(GoogleAuthProviderTest, NothingAvailable) {
FakeEnv env;
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
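
The auth-provider tests follow the same migration: the trailing 0 passed to ComputeEngineMetadataClient becomes an explicit RetryConfig. Only the first field is overridden at these call sites; per the RetryConfig definition added later in this patch, the single-argument form is shorthand for the spelled-out version below.

// Equivalent spelled-out form of RetryConfig(0 /* init_delay_time_us */):
RetryConfig config(0 /* init_delay_time_us */,
                   32 * 1000 * 1000 /* max_delay_time_us (default) */,
                   10 /* max_retries (default) */);

So the tests still permit up to 10 retries; they simply never sleep between attempts.
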
diff --git a/tensorflow/core/platform/cloud/http_request.h b/tensorflow/core/platform/cloud/http_request.h
index 2343bca608..e925eefb1f 100644
--- a/tensorflow/core/platform/cloud/http_request.h
+++ b/tensorflow/core/platform/cloud/http_request.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
-#define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_H_
#include <string>
#include <unordered_map>
@@ -188,4 +188,4 @@ class HttpRequest {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_H_
diff --git a/tensorflow/core/platform/cloud/http_request_fake.h b/tensorflow/core/platform/cloud/http_request_fake.h
index 7711eaceb2..0a1164b64a 100644
--- a/tensorflow/core/platform/cloud/http_request_fake.h
+++ b/tensorflow/core/platform/cloud/http_request_fake.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
-#define TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_
#include <algorithm>
#include <fstream>
@@ -212,4 +212,4 @@ class FakeHttpRequestFactory : public HttpRequest::Factory {
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PLATFORM_HTTP_REQUEST_FAKE_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_HTTP_REQUEST_FAKE_H_
diff --git a/tensorflow/core/platform/cloud/oauth_client.cc b/tensorflow/core/platform/cloud/oauth_client.cc
index ee6ba7b041..9b85cae9b9 100644
--- a/tensorflow/core/platform/cloud/oauth_client.cc
+++ b/tensorflow/core/platform/cloud/oauth_client.cc
@@ -216,7 +216,7 @@ Status OAuthClient::GetTokenFromServiceAccountJson(
// Send the request to the Google OAuth 2.0 server to get the token.
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
std::vector<char> response_buffer;
- request->SetUri(std::string(oauth_server_uri));
+ request->SetUri(string(oauth_server_uri));
request->SetPostFromBuffer(request_body.c_str(), request_body.size());
request->SetResultBuffer(&response_buffer);
TF_RETURN_IF_ERROR(request->Send());
@@ -248,7 +248,7 @@ Status OAuthClient::GetTokenFromRefreshTokenJson(
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
std::vector<char> response_buffer;
- request->SetUri(std::string(oauth_server_uri));
+ request->SetUri(string(oauth_server_uri));
request->SetPostFromBuffer(request_body.c_str(), request_body.size());
request->SetResultBuffer(&response_buffer);
TF_RETURN_IF_ERROR(request->Send());
diff --git a/tensorflow/core/platform/cloud/oauth_client_test.cc b/tensorflow/core/platform/cloud/oauth_client_test.cc
index 4ffa72288b..1cd0641cd3 100644
--- a/tensorflow/core/platform/cloud/oauth_client_test.cc
+++ b/tensorflow/core/platform/cloud/oauth_client_test.cc
@@ -126,9 +126,9 @@ TEST(OAuthClientTest, GetTokenFromServiceAccountJson) {
EXPECT_EQ("urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Ajwt-bearer",
grant_type);
- int last_dot = std::string(assertion).find_last_of(".");
- string header_dot_claim = std::string(assertion.substr(0, last_dot));
- string signature_encoded = std::string(assertion.substr(last_dot + 1));
+ int last_dot = assertion.rfind('.');
+ string header_dot_claim(assertion.substr(0, last_dot));
+ string signature_encoded(assertion.substr(last_dot + 1));
// Check that 'signature' signs 'header_dot_claim'.
diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h
index 92aa72be89..5ce6670dc7 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system.h
+++ b/tensorflow/core/platform/cloud/retrying_file_system.h
@@ -34,9 +34,9 @@ template <typename Underlying>
class RetryingFileSystem : public FileSystem {
public:
RetryingFileSystem(std::unique_ptr<Underlying> base_file_system,
- int64 delay_microseconds = 1000000)
+ const RetryConfig& retry_config)
: base_file_system_(std::move(base_file_system)),
- initial_delay_microseconds_(delay_microseconds) {}
+ retry_config_(retry_config) {}
Status NewRandomAccessFile(
const string& filename,
@@ -55,7 +55,7 @@ class RetryingFileSystem : public FileSystem {
Status FileExists(const string& fname) override {
return RetryingUtils::CallWithRetries(
[this, &fname]() { return base_file_system_->FileExists(fname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status GetChildren(const string& dir, std::vector<string>* result) override {
@@ -63,7 +63,7 @@ class RetryingFileSystem : public FileSystem {
[this, &dir, result]() {
return base_file_system_->GetChildren(dir, result);
},
- initial_delay_microseconds_);
+ retry_config_);
}
Status GetMatchingPaths(const string& pattern,
@@ -72,31 +72,31 @@ class RetryingFileSystem : public FileSystem {
[this, &pattern, result]() {
return base_file_system_->GetMatchingPaths(pattern, result);
},
- initial_delay_microseconds_);
+ retry_config_);
}
Status Stat(const string& fname, FileStatistics* stat) override {
return RetryingUtils::CallWithRetries(
[this, &fname, stat]() { return base_file_system_->Stat(fname, stat); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status DeleteFile(const string& fname) override {
return RetryingUtils::DeleteWithRetries(
[this, &fname]() { return base_file_system_->DeleteFile(fname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status CreateDir(const string& dirname) override {
return RetryingUtils::CallWithRetries(
[this, &dirname]() { return base_file_system_->CreateDir(dirname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status DeleteDir(const string& dirname) override {
return RetryingUtils::DeleteWithRetries(
[this, &dirname]() { return base_file_system_->DeleteDir(dirname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status GetFileSize(const string& fname, uint64* file_size) override {
@@ -104,7 +104,7 @@ class RetryingFileSystem : public FileSystem {
[this, &fname, file_size]() {
return base_file_system_->GetFileSize(fname, file_size);
},
- initial_delay_microseconds_);
+ retry_config_);
}
Status RenameFile(const string& src, const string& target) override {
@@ -112,13 +112,13 @@ class RetryingFileSystem : public FileSystem {
[this, &src, &target]() {
return base_file_system_->RenameFile(src, target);
},
- initial_delay_microseconds_);
+ retry_config_);
}
Status IsDirectory(const string& dirname) override {
return RetryingUtils::CallWithRetries(
[this, &dirname]() { return base_file_system_->IsDirectory(dirname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status DeleteRecursively(const string& dirname, int64* undeleted_files,
@@ -128,7 +128,7 @@ class RetryingFileSystem : public FileSystem {
return base_file_system_->DeleteRecursively(dirname, undeleted_files,
undeleted_dirs);
},
- initial_delay_microseconds_);
+ retry_config_);
}
void FlushCaches() override { base_file_system_->FlushCaches(); }
@@ -137,7 +137,7 @@ class RetryingFileSystem : public FileSystem {
private:
std::unique_ptr<Underlying> base_file_system_;
- const int64 initial_delay_microseconds_;
+ const RetryConfig retry_config_;
TF_DISALLOW_COPY_AND_ASSIGN(RetryingFileSystem);
};
@@ -147,9 +147,8 @@ namespace retrying_internals {
class RetryingRandomAccessFile : public RandomAccessFile {
public:
RetryingRandomAccessFile(std::unique_ptr<RandomAccessFile> base_file,
- int64 delay_microseconds)
- : base_file_(std::move(base_file)),
- initial_delay_microseconds_(delay_microseconds) {}
+ const RetryConfig& retry_config)
+ : base_file_(std::move(base_file)), retry_config_(retry_config) {}
Status Read(uint64 offset, size_t n, StringPiece* result,
char* scratch) const override {
@@ -157,47 +156,45 @@ class RetryingRandomAccessFile : public RandomAccessFile {
[this, offset, n, result, scratch]() {
return base_file_->Read(offset, n, result, scratch);
},
- initial_delay_microseconds_);
+ retry_config_);
}
private:
std::unique_ptr<RandomAccessFile> base_file_;
- const int64 initial_delay_microseconds_;
+ const RetryConfig retry_config_;
};
class RetryingWritableFile : public WritableFile {
public:
RetryingWritableFile(std::unique_ptr<WritableFile> base_file,
- int64 delay_microseconds)
- : base_file_(std::move(base_file)),
- initial_delay_microseconds_(delay_microseconds) {}
+ const RetryConfig& retry_config)
+ : base_file_(std::move(base_file)), retry_config_(retry_config) {}
~RetryingWritableFile() override {
// Makes sure the retrying version of Close() is called in the destructor.
Close().IgnoreError();
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
return RetryingUtils::CallWithRetries(
- [this, &data]() { return base_file_->Append(data); },
- initial_delay_microseconds_);
+ [this, &data]() { return base_file_->Append(data); }, retry_config_);
}
Status Close() override {
return RetryingUtils::CallWithRetries(
- [this]() { return base_file_->Close(); }, initial_delay_microseconds_);
+ [this]() { return base_file_->Close(); }, retry_config_);
}
Status Flush() override {
return RetryingUtils::CallWithRetries(
- [this]() { return base_file_->Flush(); }, initial_delay_microseconds_);
+ [this]() { return base_file_->Flush(); }, retry_config_);
}
Status Sync() override {
return RetryingUtils::CallWithRetries(
- [this]() { return base_file_->Sync(); }, initial_delay_microseconds_);
+ [this]() { return base_file_->Sync(); }, retry_config_);
}
private:
std::unique_ptr<WritableFile> base_file_;
- const int64 initial_delay_microseconds_;
+ const RetryConfig retry_config_;
};
} // namespace retrying_internals
@@ -210,9 +207,9 @@ Status RetryingFileSystem<Underlying>::NewRandomAccessFile(
[this, &filename, &base_file]() {
return base_file_system_->NewRandomAccessFile(filename, &base_file);
},
- initial_delay_microseconds_));
+ retry_config_));
result->reset(new retrying_internals::RetryingRandomAccessFile(
- std::move(base_file), initial_delay_microseconds_));
+ std::move(base_file), retry_config_));
return Status::OK();
}
@@ -224,9 +221,9 @@ Status RetryingFileSystem<Underlying>::NewWritableFile(
[this, &filename, &base_file]() {
return base_file_system_->NewWritableFile(filename, &base_file);
},
- initial_delay_microseconds_));
+ retry_config_));
result->reset(new retrying_internals::RetryingWritableFile(
- std::move(base_file), initial_delay_microseconds_));
+ std::move(base_file), retry_config_));
return Status::OK();
}
@@ -238,9 +235,9 @@ Status RetryingFileSystem<Underlying>::NewAppendableFile(
[this, &filename, &base_file]() {
return base_file_system_->NewAppendableFile(filename, &base_file);
},
- initial_delay_microseconds_));
+ retry_config_));
result->reset(new retrying_internals::RetryingWritableFile(
- std::move(base_file), initial_delay_microseconds_));
+ std::move(base_file), retry_config_));
return Status::OK();
}
@@ -252,7 +249,7 @@ Status RetryingFileSystem<Underlying>::NewReadOnlyMemoryRegionFromFile(
return base_file_system_->NewReadOnlyMemoryRegionFromFile(filename,
result);
},
- initial_delay_microseconds_);
+ retry_config_);
}
} // namespace tensorflow
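
RetryingFileSystem and its wrapped RandomAccessFile/WritableFile helpers now carry a RetryConfig instead of a single initial-delay integer. A minimal usage sketch of the new constructor follows; MyFileSystem is a hypothetical stand-in for any concrete FileSystem subclass (the GCS filesystem in this patch is one example).

// Wrap an underlying filesystem so every operation goes through
// RetryingUtils::CallWithRetries / DeleteWithRetries with this config.
std::unique_ptr<MyFileSystem> base(new MyFileSystem);
RetryingFileSystem<MyFileSystem> fs(
    std::move(base),
    RetryConfig(1000000 /* init_delay_time_us */,
                32 * 1000 * 1000 /* max_delay_time_us */,
                10 /* max_retries */));
Status s = fs.CreateDir("some/dir");  // retried on UNAVAILABLE etc.

The config is stored by value in each wrapper (const RetryConfig retry_config_), mirroring how the old initial_delay_microseconds_ was copied, so the wrappers need no reference back to the filesystem that created them.
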
diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
index ec2c470db7..868eea096c 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
@@ -72,7 +72,7 @@ class MockRandomAccessFile : public RandomAccessFile {
class MockWritableFile : public WritableFile {
public:
explicit MockWritableFile(const ExpectedCalls& calls) : calls_(calls) {}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
return calls_.ConsumeNextCall("Append");
}
Status Close() override { return calls_.ConsumeNextCall("Close"); }
@@ -184,7 +184,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_ImmediateSuccess) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->random_access_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped random access file.
std::unique_ptr<RandomAccessFile> random_access_file;
@@ -211,7 +212,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_SuccessWith3rdTry) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->random_access_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped random access file.
std::unique_ptr<RandomAccessFile> random_access_file;
@@ -235,7 +237,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_AllRetriesFailed) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->random_access_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped random access file.
std::unique_ptr<RandomAccessFile> random_access_file;
@@ -265,7 +268,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_NoRetriesForSomeErrors) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->random_access_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped random access file.
std::unique_ptr<RandomAccessFile> random_access_file;
@@ -291,7 +295,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_ImmediateSuccess) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped writable file.
std::unique_ptr<WritableFile> writable_file;
@@ -317,7 +322,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_SuccessWith3rdTry) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped writable file.
std::unique_ptr<WritableFile> writable_file;
@@ -343,7 +349,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_SuccessWith3rdTry_ViaDestructor) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped writable file.
std::unique_ptr<WritableFile> writable_file;
@@ -368,7 +375,8 @@ TEST(RetryingFileSystemTest, NewAppendableFile_SuccessWith3rdTry) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped appendable file.
std::unique_ptr<WritableFile> writable_file;
@@ -391,7 +399,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_AllRetriesFailed) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped writable file.
std::unique_ptr<WritableFile> writable_file;
@@ -412,7 +421,8 @@ TEST(RetryingFileSystemTest,
std::make_tuple("NewReadOnlyMemoryRegionFromFile", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::unique_ptr<ReadOnlyMemoryRegion> result;
TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile("filename.txt", &result));
@@ -423,7 +433,8 @@ TEST(RetryingFileSystemTest, NewReadOnlyMemoryRegionFromFile_AllRetriesFailed) {
CreateRetriableErrors("NewReadOnlyMemoryRegionFromFile", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::unique_ptr<ReadOnlyMemoryRegion> result;
const auto& status =
@@ -440,7 +451,8 @@ TEST(RetryingFileSystemTest, GetChildren_SuccessWith2ndTry) {
std::make_tuple("GetChildren", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.GetChildren("gs://path", &result));
@@ -450,7 +462,8 @@ TEST(RetryingFileSystemTest, GetChildren_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("GetChildren", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.GetChildren("gs://path", &result);
@@ -466,7 +479,8 @@ TEST(RetryingFileSystemTest, GetMatchingPaths_SuccessWith2ndTry) {
std::make_tuple("GetMatchingPaths", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://path/dir", &result));
@@ -477,7 +491,8 @@ TEST(RetryingFileSystemTest, GetMatchingPaths_AllRetriesFailed) {
CreateRetriableErrors("GetMatchingPaths", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.GetMatchingPaths("gs://path/dir", &result);
@@ -492,7 +507,8 @@ TEST(RetryingFileSystemTest, DeleteFile_SuccessWith2ndTry) {
std::make_tuple("DeleteFile", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.DeleteFile("gs://path/file.txt"));
@@ -502,7 +518,8 @@ TEST(RetryingFileSystemTest, DeleteFile_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("DeleteFile", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.DeleteFile("gs://path/file.txt");
@@ -517,7 +534,8 @@ TEST(RetryingFileSystemTest, CreateDir_SuccessWith2ndTry) {
std::make_tuple("CreateDir", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.CreateDir("gs://path/newdir"));
@@ -527,7 +545,8 @@ TEST(RetryingFileSystemTest, CreateDir_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("CreateDir", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.CreateDir("gs://path/newdir");
@@ -542,7 +561,8 @@ TEST(RetryingFileSystemTest, DeleteDir_SuccessWith2ndTry) {
std::make_tuple("DeleteDir", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.DeleteDir("gs://path/dir"));
@@ -552,7 +572,8 @@ TEST(RetryingFileSystemTest, DeleteDir_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("DeleteDir", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.DeleteDir("gs://path/dir");
@@ -568,7 +589,8 @@ TEST(RetryingFileSystemTest, GetFileSize_SuccessWith2ndTry) {
std::make_tuple("GetFileSize", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
uint64 size;
TF_EXPECT_OK(fs.GetFileSize("gs://path/file.txt", &size));
@@ -578,7 +600,8 @@ TEST(RetryingFileSystemTest, GetFileSize_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("GetFileSize", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
uint64 size;
const auto& status = fs.GetFileSize("gs://path/file.txt", &size);
@@ -593,7 +616,8 @@ TEST(RetryingFileSystemTest, RenameFile_SuccessWith2ndTry) {
std::make_tuple("RenameFile", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
TF_EXPECT_OK(fs.RenameFile("old_name", "new_name"));
}
@@ -602,7 +626,8 @@ TEST(RetryingFileSystemTest, RenameFile_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("RenameFile", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
const auto& status = fs.RenameFile("old_name", "new_name");
EXPECT_TRUE(
@@ -616,7 +641,8 @@ TEST(RetryingFileSystemTest, Stat_SuccessWith2ndTry) {
std::make_tuple("Stat", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("file_name", &stat));
@@ -626,7 +652,8 @@ TEST(RetryingFileSystemTest, Stat_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("Stat", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
FileStatistics stat;
const auto& status = fs.Stat("file_name", &stat);
@@ -639,7 +666,8 @@ TEST(RetryingFileSystemTest, FileExists_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("FileExists", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
const auto& status = fs.FileExists("file_name");
EXPECT_TRUE(
@@ -653,7 +681,8 @@ TEST(RetryingFileSystemTest, FileExists_SuccessWith2ndTry) {
std::make_tuple("FileExists", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
TF_EXPECT_OK(fs.FileExists("gs://path/dir"));
}
@@ -665,7 +694,8 @@ TEST(RetryingFileSystemTest, IsDirectory_SuccessWith2ndTry) {
std::make_tuple("IsDirectory", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
TF_EXPECT_OK(fs.IsDirectory("gs://path/dir"));
}
@@ -674,7 +704,8 @@ TEST(RetryingFileSystemTest, IsDirectory_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("IsDirectory", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
const auto& status = fs.IsDirectory("gs://path/dir");
EXPECT_TRUE(
@@ -689,7 +720,8 @@ TEST(RetryingFileSystemTest, DeleteRecursively_SuccessWith2ndTry) {
std::make_tuple("DeleteRecursively", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
int64 undeleted_files, undeleted_dirs;
TF_EXPECT_OK(
@@ -701,7 +733,8 @@ TEST(RetryingFileSystemTest, DeleteRecursively_AllRetriesFailed) {
CreateRetriableErrors("DeleteRecursively", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
int64 undeleted_files, undeleted_dirs;
const auto& status =
@@ -715,7 +748,8 @@ TEST(RetryingFileSystemTest, FlushCaches) {
ExpectedCalls none;
bool flushed = false;
std::unique_ptr<MockFileSystem> base_fs(new MockFileSystem(none, &flushed));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
fs.FlushCaches();
EXPECT_TRUE(flushed);
}
diff --git a/tensorflow/core/platform/cloud/retrying_utils.cc b/tensorflow/core/platform/cloud/retrying_utils.cc
index d2df422024..cb0aecdd35 100644
--- a/tensorflow/core/platform/cloud/retrying_utils.cc
+++ b/tensorflow/core/platform/cloud/retrying_utils.cc
@@ -23,11 +23,6 @@ namespace tensorflow {
namespace {
-// In case of failure, every call will be retried kMaxRetries times.
-constexpr int kMaxRetries = 10;
-// Maximum backoff time in microseconds.
-constexpr int64 kMaximumBackoffMicroseconds = 32000000; // 32 seconds.
-
bool IsRetriable(error::Code code) {
switch (code) {
case error::UNAVAILABLE:
@@ -43,40 +38,41 @@ bool IsRetriable(error::Code code) {
} // namespace
Status RetryingUtils::CallWithRetries(const std::function<Status()>& f,
- const int64 initial_delay_microseconds) {
- return CallWithRetries(f, initial_delay_microseconds, [](int64 micros) {
- return Env::Default()->SleepForMicroseconds(micros);
- });
+ const RetryConfig& config) {
+ return CallWithRetries(
+ f,
+ [](int64 micros) { return Env::Default()->SleepForMicroseconds(micros); },
+ config);
}
Status RetryingUtils::CallWithRetries(
- const std::function<Status()>& f, const int64 initial_delay_microseconds,
- const std::function<void(int64)>& sleep_usec) {
+ const std::function<Status()>& f,
+ const std::function<void(int64)>& sleep_usec, const RetryConfig& config) {
int retries = 0;
while (true) {
auto status = f();
if (!IsRetriable(status.code())) {
return status;
}
- if (retries >= kMaxRetries) {
+ if (retries >= config.max_retries) {
// Return AbortedError, so that it doesn't get retried again somewhere
// at a higher level.
return Status(
error::ABORTED,
strings::StrCat(
- "All ", kMaxRetries,
+ "All ", config.max_retries,
" retry attempts failed. The last failure: ", status.ToString()));
}
int64 delay_micros = 0;
- if (initial_delay_microseconds > 0) {
+ if (config.init_delay_time_us > 0) {
const int64 random_micros = random::New64() % 1000000;
- delay_micros = std::min(initial_delay_microseconds << retries,
- kMaximumBackoffMicroseconds) +
+ delay_micros = std::min(config.init_delay_time_us << retries,
+ config.max_delay_time_us) +
random_micros;
}
LOG(INFO) << "The operation failed and will be automatically retried in "
<< (delay_micros / 1000000.0) << " seconds (attempt "
- << (retries + 1) << " out of " << kMaxRetries
+ << (retries + 1) << " out of " << config.max_retries
<< "), caused by: " << status.ToString();
sleep_usec(delay_micros);
retries++;
@@ -84,8 +80,7 @@ Status RetryingUtils::CallWithRetries(
}
Status RetryingUtils::DeleteWithRetries(
- const std::function<Status()>& delete_func,
- const int64 initial_delay_microseconds) {
+ const std::function<Status()>& delete_func, const RetryConfig& config) {
bool is_retried = false;
return RetryingUtils::CallWithRetries(
[delete_func, &is_retried]() {
@@ -96,7 +91,7 @@ Status RetryingUtils::DeleteWithRetries(
is_retried = true;
return status;
},
- initial_delay_microseconds);
+ config);
}
} // namespace tensorflow
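
With the hard-coded kMaxRetries and kMaximumBackoffMicroseconds constants removed, the retry count and backoff cap now come from the RetryConfig. A small standalone sketch (not TensorFlow code) reproduces the backoff schedule implied by the defaults, with the random jitter omitted for clarity.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t init_delay_us = 100 * 1000;       // default init_delay_time_us (0.1 s)
  const int64_t max_delay_us = 32 * 1000 * 1000;  // default max_delay_time_us (32 s)
  const int max_retries = 10;                     // default max_retries
  int64_t total_us = 0;
  for (int retries = 0; retries < max_retries; ++retries) {
    const int64_t delay_us = std::min(init_delay_us << retries, max_delay_us);
    total_us += delay_us;
    std::printf("attempt %d: sleep %.1f s\n", retries + 1, delay_us / 1e6);
  }
  // Prints ~83 s of cumulative backoff; with up to 1 s of jitter per attempt
  // this matches the "~100 seconds" figure in the RetryConfig header comment
  // (see the retrying_utils.h hunk below).
  std::printf("total backoff: %.1f s\n", total_us / 1e6);
  return 0;
}
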
diff --git a/tensorflow/core/platform/cloud/retrying_utils.h b/tensorflow/core/platform/cloud/retrying_utils.h
index 546b8d1c4a..1a7ce1b122 100644
--- a/tensorflow/core/platform/cloud/retrying_utils.h
+++ b/tensorflow/core/platform/cloud/retrying_utils.h
@@ -21,6 +21,26 @@ limitations under the License.
namespace tensorflow {
+// Retry configuration; with the defaults below, a failed call is retried for ~100 seconds in total before giving up.
+struct RetryConfig {
+ RetryConfig(int64 init_delay_time_us = 100 * 1000,
+ int64 max_delay_time_us = 32 * 1000 * 1000,
+ int max_retries = 10) {
+ this->init_delay_time_us = init_delay_time_us;
+ this->max_delay_time_us = max_delay_time_us;
+ this->max_retries = max_retries;
+ }
+
+ // In case of failure, every call will be retried max_retries times.
+ int max_retries;
+
+  // Initial backoff time in microseconds.
+ int64 init_delay_time_us;
+
+ // Maximum backoff time in microseconds.
+ int64 max_delay_time_us;
+};
+
class RetryingUtils {
public:
/// \brief Retries the function in case of failure with exponential backoff.
@@ -31,18 +51,19 @@ class RetryingUtils {
/// retries.
/// If all retries failed, returns the last error status.
static Status CallWithRetries(const std::function<Status()>& f,
- const int64 initial_delay_microseconds);
+ const RetryConfig& config);
+
/// sleep_usec is a function that sleeps for the given number of microseconds.
static Status CallWithRetries(const std::function<Status()>& f,
- const int64 initial_delay_microseconds,
- const std::function<void(int64)>& sleep_usec);
+ const std::function<void(int64)>& sleep_usec,
+ const RetryConfig& config);
/// \brief A retrying wrapper for a function that deletes a resource.
///
/// The function takes care of the scenario when a delete operation
/// returns a failure but succeeds under the hood: if a retry returns
/// NOT_FOUND, the whole operation is considered a success.
static Status DeleteWithRetries(const std::function<Status()>& delete_func,
- const int64 initial_delay_microseconds);
+ const RetryConfig& config);
};
} // namespace tensorflow
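
The new header pairs RetryConfig with the reworked RetryingUtils entry points, so callers pass a config object rather than a bare delay (and, in the sleep_usec overload, the config now comes last). A minimal sketch of a migrated caller, using an illustrative helper that retries FileSystem::Stat:

// Illustrative helper only; mirrors how RetryingFileSystem::Stat wraps the
// underlying call with the new config-based API.
Status StatWithRetries(FileSystem* fs, const string& fname,
                       FileStatistics* stat) {
  return RetryingUtils::CallWithRetries(
      [fs, &fname, stat]() { return fs->Stat(fname, stat); },
      RetryConfig(500000 /* init_delay_time_us */));
}
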
diff --git a/tensorflow/core/platform/cloud/retrying_utils_test.cc b/tensorflow/core/platform/cloud/retrying_utils_test.cc
index 1b6527618a..75fe8a98f4 100644
--- a/tensorflow/core/platform/cloud/retrying_utils_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_utils_test.cc
@@ -30,7 +30,8 @@ TEST(RetryingUtilsTest, CallWithRetries_RetryDelays) {
};
std::function<Status()> f = []() { return errors::Unavailable("Failed."); };
- const auto& status = RetryingUtils::CallWithRetries(f, 500000L, sleep);
+ const auto& status = RetryingUtils::CallWithRetries(
+ f, sleep, RetryConfig(500000 /* init_delay_time_us */));
EXPECT_EQ(errors::Code::ABORTED, status.code());
EXPECT_TRUE(str_util::StrContains(
status.error_message(),
@@ -60,8 +61,10 @@ TEST(RetryingUtilsTest, CallWithRetries_NotFoundIsNotRetried) {
results.erase(results.begin());
return result;
};
- EXPECT_EQ(errors::Code::NOT_FOUND,
- RetryingUtils::CallWithRetries(f, 0).code());
+ EXPECT_EQ(
+ errors::Code::NOT_FOUND,
+ RetryingUtils::CallWithRetries(f, RetryConfig(0 /* init_delay_time_us */))
+ .code());
}
TEST(RetryingUtilsTest, CallWithRetries_ImmediateSuccess) {
@@ -74,7 +77,8 @@ TEST(RetryingUtilsTest, CallWithRetries_ImmediateSuccess) {
results.erase(results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::CallWithRetries(f, 1.0, sleep));
+ TF_EXPECT_OK(RetryingUtils::CallWithRetries(
+ f, sleep, RetryConfig(1L /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, CallWithRetries_EventualSuccess) {
@@ -86,7 +90,8 @@ TEST(RetryingUtilsTest, CallWithRetries_EventualSuccess) {
results.erase(results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::CallWithRetries(f, 0));
+ TF_EXPECT_OK(RetryingUtils::CallWithRetries(
+ f, RetryConfig(0 /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, DeleteWithRetries_ImmediateSuccess) {
@@ -96,7 +101,8 @@ TEST(RetryingUtilsTest, DeleteWithRetries_ImmediateSuccess) {
delete_results.erase(delete_results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(delete_func, 0));
+ TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, DeleteWithRetries_EventualSuccess) {
@@ -106,7 +112,8 @@ TEST(RetryingUtilsTest, DeleteWithRetries_EventualSuccess) {
delete_results.erase(delete_results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(delete_func, 0));
+ TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, DeleteWithRetries_PermissionDeniedNotRetried) {
@@ -118,7 +125,9 @@ TEST(RetryingUtilsTest, DeleteWithRetries_PermissionDeniedNotRetried) {
return result;
};
EXPECT_EQ(errors::Code::PERMISSION_DENIED,
- RetryingUtils::DeleteWithRetries(delete_func, 0).code());
+ RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */))
+ .code());
}
TEST(RetryingUtilsTest, DeleteWithRetries_SuccessThroughFileNotFound) {
@@ -129,7 +138,8 @@ TEST(RetryingUtilsTest, DeleteWithRetries_SuccessThroughFileNotFound) {
delete_results.erase(delete_results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(delete_func, 0));
+ TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, DeleteWithRetries_FirstNotFoundReturnedAsIs) {
@@ -140,7 +150,9 @@ TEST(RetryingUtilsTest, DeleteWithRetries_FirstNotFoundReturnedAsIs) {
return result;
};
EXPECT_EQ(error::NOT_FOUND,
- RetryingUtils::DeleteWithRetries(delete_func, 0).code());
+ RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */))
+ .code());
}
} // namespace
diff --git a/tensorflow/core/platform/context.h b/tensorflow/core/platform/context.h
index 728ef91631..9f7beb7a68 100644
--- a/tensorflow/core/platform/context.h
+++ b/tensorflow/core/platform/context.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_CONTEXT_H_
-#define TENSORFLOW_PLATFORM_CONTEXT_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CONTEXT_H_
+#define TENSORFLOW_CORE_PLATFORM_CONTEXT_H_
namespace tensorflow {
@@ -42,4 +42,4 @@ class WithContext;
#include "tensorflow/core/platform/default/context.h"
#endif
-#endif // TENSORFLOW_PLATFORM_CONTEXT_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CONTEXT_H_
diff --git a/tensorflow/core/kernels/warn_about_ints.h b/tensorflow/core/platform/cord.h
index 20666b230e..7c5c6655be 100644
--- a/tensorflow/core/kernels/warn_about_ints.h
+++ b/tensorflow/core/platform/cord.h
@@ -13,17 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_WARN_ABOUT_INTS_H_
-#define TENSORFLOW_KERNELS_WARN_ABOUT_INTS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CORD_H_
+#define TENSORFLOW_CORE_PLATFORM_CORD_H_
-#include "tensorflow/core/framework/op_kernel.h"
+// Include appropriate platform-dependent implementations
+#if defined(PLATFORM_GOOGLE)
+#include "tensorflow/core/platform/google/cord.h"
+#else
+#include "tensorflow/core/platform/default/cord.h"
+#endif
-namespace tensorflow {
-
-// Warn if a kernel is being created using ints
-// TODO(irving): Remove in TF 2.0 along with the bad op registrations.
-void WarnAboutInts(OpKernelConstruction* context);
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_KERNELS_WARN_ABOUT_INTS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CORD_H_
diff --git a/tensorflow/core/platform/cpu_feature_guard.h b/tensorflow/core/platform/cpu_feature_guard.h
index 586a6be55e..3d7bfe95b1 100644
--- a/tensorflow/core/platform/cpu_feature_guard.h
+++ b/tensorflow/core/platform/cpu_feature_guard.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_CPU_FEATURE_GUARD_H_
-#define TENSORFLOW_PLATFORM_CPU_FEATURE_GUARD_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CPU_FEATURE_GUARD_H_
+#define TENSORFLOW_CORE_PLATFORM_CPU_FEATURE_GUARD_H_
namespace tensorflow {
namespace port {
@@ -29,4 +29,4 @@ void InfoAboutUnusedCPUFeatures();
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_CPU_FEATURE_GUARD_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CPU_FEATURE_GUARD_H_
diff --git a/tensorflow/core/platform/cpu_info.h b/tensorflow/core/platform/cpu_info.h
index 175c9ae8b1..6eba83224a 100644
--- a/tensorflow/core/platform/cpu_info.h
+++ b/tensorflow/core/platform/cpu_info.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_CPU_INFO_H_
-#define TENSORFLOW_PLATFORM_CPU_INFO_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_CPU_INFO_H_
+#define TENSORFLOW_CORE_PLATFORM_CPU_INFO_H_
#include <string>
@@ -117,4 +117,4 @@ int CPUIDNumSMT();
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_CPU_INFO_H_
+#endif // TENSORFLOW_CORE_PLATFORM_CPU_INFO_H_
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 28891320c4..d884c1aa7c 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -8,224 +8,229 @@ load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load(
"//third_party/mkl:build_defs.bzl",
- "if_mkl",
+ "if_mkl_ml",
)
# Appends a suffix to a list of deps.
def tf_deps(deps, suffix):
- tf_deps = []
+ tf_deps = []
- # If the package name is in shorthand form (ie: does not contain a ':'),
- # expand it to the full name.
- for dep in deps:
- tf_dep = dep
+ # If the package name is in shorthand form (ie: does not contain a ':'),
+ # expand it to the full name.
+ for dep in deps:
+ tf_dep = dep
- if not ":" in dep:
- dep_pieces = dep.split("/")
- tf_dep += ":" + dep_pieces[len(dep_pieces) - 1]
+ if not ":" in dep:
+ dep_pieces = dep.split("/")
+ tf_dep += ":" + dep_pieces[len(dep_pieces) - 1]
- tf_deps += [tf_dep + suffix]
+ tf_deps += [tf_dep + suffix]
- return tf_deps
+ return tf_deps
# Modified from @cython//:Tools/rules.bzl
def pyx_library(
- name,
- deps=[],
- py_deps=[],
- srcs=[],
- **kwargs):
- """Compiles a group of .pyx / .pxd / .py files.
-
- First runs Cython to create .cpp files for each input .pyx or .py + .pxd
- pair. Then builds a shared object for each, passing "deps" to each cc_binary
- rule (includes Python headers by default). Finally, creates a py_library rule
- with the shared objects and any pure Python "srcs", with py_deps as its
- dependencies; the shared objects can be imported like normal Python files.
-
- Args:
- name: Name for the rule.
- deps: C/C++ dependencies of the Cython (e.g. Numpy headers).
- py_deps: Pure Python dependencies of the final library.
- srcs: .py, .pyx, or .pxd files to either compile or pass through.
- **kwargs: Extra keyword arguments passed to the py_library.
- """
- # First filter out files that should be run compiled vs. passed through.
- py_srcs = []
- pyx_srcs = []
- pxd_srcs = []
- for src in srcs:
- if src.endswith(".pyx") or (src.endswith(".py")
- and src[:-3] + ".pxd" in srcs):
- pyx_srcs.append(src)
- elif src.endswith(".py"):
- py_srcs.append(src)
- else:
- pxd_srcs.append(src)
- if src.endswith("__init__.py"):
- pxd_srcs.append(src)
-
- # Invoke cython to produce the shared object libraries.
- for filename in pyx_srcs:
- native.genrule(
- name = filename + "_cython_translation",
- srcs = [filename],
- outs = [filename.split(".")[0] + ".cpp"],
- # Optionally use PYTHON_BIN_PATH on Linux platforms so that python 3
- # works. Windows has issues with cython_binary so skip PYTHON_BIN_PATH.
- cmd = "PYTHONHASHSEED=0 $(location @cython//:cython_binary) --cplus $(SRCS) --output-file $(OUTS)",
- tools = ["@cython//:cython_binary"] + pxd_srcs,
+ name,
+ deps = [],
+ py_deps = [],
+ srcs = [],
+ **kwargs):
+ """Compiles a group of .pyx / .pxd / .py files.
+
+ First runs Cython to create .cpp files for each input .pyx or .py + .pxd
+ pair. Then builds a shared object for each, passing "deps" to each cc_binary
+ rule (includes Python headers by default). Finally, creates a py_library rule
+ with the shared objects and any pure Python "srcs", with py_deps as its
+ dependencies; the shared objects can be imported like normal Python files.
+
+ Args:
+ name: Name for the rule.
+ deps: C/C++ dependencies of the Cython (e.g. Numpy headers).
+ py_deps: Pure Python dependencies of the final library.
+ srcs: .py, .pyx, or .pxd files to either compile or pass through.
+ **kwargs: Extra keyword arguments passed to the py_library.
+ """
+
+ # First filter out files that should be run compiled vs. passed through.
+ py_srcs = []
+ pyx_srcs = []
+ pxd_srcs = []
+ for src in srcs:
+ if src.endswith(".pyx") or (src.endswith(".py") and
+ src[:-3] + ".pxd" in srcs):
+ pyx_srcs.append(src)
+ elif src.endswith(".py"):
+ py_srcs.append(src)
+ else:
+ pxd_srcs.append(src)
+ if src.endswith("__init__.py"):
+ pxd_srcs.append(src)
+
+ # Invoke cython to produce the shared object libraries.
+ for filename in pyx_srcs:
+ native.genrule(
+ name = filename + "_cython_translation",
+ srcs = [filename],
+ outs = [filename.split(".")[0] + ".cpp"],
+ # Optionally use PYTHON_BIN_PATH on Linux platforms so that python 3
+ # works. Windows has issues with cython_binary so skip PYTHON_BIN_PATH.
+ cmd = "PYTHONHASHSEED=0 $(location @cython//:cython_binary) --cplus $(SRCS) --output-file $(OUTS)",
+ tools = ["@cython//:cython_binary"] + pxd_srcs,
+ )
+
+ shared_objects = []
+ for src in pyx_srcs:
+ stem = src.split(".")[0]
+ shared_object_name = stem + ".so"
+ native.cc_binary(
+ name = shared_object_name,
+ srcs = [stem + ".cpp"],
+ deps = deps + ["//third_party/python_runtime:headers"],
+ linkshared = 1,
+ )
+ shared_objects.append(shared_object_name)
+
+ # Now create a py_library with these shared objects as data.
+ native.py_library(
+ name = name,
+ srcs = py_srcs,
+ deps = py_deps,
+ srcs_version = "PY2AND3",
+ data = shared_objects,
+ **kwargs
)
- shared_objects = []
- for src in pyx_srcs:
- stem = src.split(".")[0]
- shared_object_name = stem + ".so"
- native.cc_binary(
- name=shared_object_name,
- srcs=[stem + ".cpp"],
- deps=deps + ["//third_party/python_runtime:headers"],
- linkshared = 1,
- )
- shared_objects.append(shared_object_name)
-
- # Now create a py_library with these shared objects as data.
- native.py_library(
- name=name,
- srcs=py_srcs,
- deps=py_deps,
- srcs_version = "PY2AND3",
- data=shared_objects,
- **kwargs
- )
-
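A minimal sketch of how the pyx_library macro above might be invoked from a BUILD file; all target and file names here are hypothetical:

    # Hypothetical BUILD usage of pyx_library (names are illustrative only).
    pyx_library(
        name = "fast_ops",
        srcs = [
            "fast_ops.pyx",   # compiled by Cython and built into fast_ops.so
            "helpers.py",     # pure Python, passed straight to the py_library
        ],
        deps = ["//third_party/some_c_lib"],  # hypothetical C/C++ dependency
        py_deps = ["//some/python:helpers"],  # hypothetical Python dependency
    )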
-def _proto_cc_hdrs(srcs, use_grpc_plugin=False):
- ret = [s[:-len(".proto")] + ".pb.h" for s in srcs]
- if use_grpc_plugin:
- ret += [s[:-len(".proto")] + ".grpc.pb.h" for s in srcs]
- return ret
-
-def _proto_cc_srcs(srcs, use_grpc_plugin=False):
- ret = [s[:-len(".proto")] + ".pb.cc" for s in srcs]
- if use_grpc_plugin:
- ret += [s[:-len(".proto")] + ".grpc.pb.cc" for s in srcs]
- return ret
-
-def _proto_py_outs(srcs, use_grpc_plugin=False):
- ret = [s[:-len(".proto")] + "_pb2.py" for s in srcs]
- if use_grpc_plugin:
- ret += [s[:-len(".proto")] + "_pb2_grpc.py" for s in srcs]
- return ret
+def _proto_cc_hdrs(srcs, use_grpc_plugin = False):
+ ret = [s[:-len(".proto")] + ".pb.h" for s in srcs]
+ if use_grpc_plugin:
+ ret += [s[:-len(".proto")] + ".grpc.pb.h" for s in srcs]
+ return ret
+
+def _proto_cc_srcs(srcs, use_grpc_plugin = False):
+ ret = [s[:-len(".proto")] + ".pb.cc" for s in srcs]
+ if use_grpc_plugin:
+ ret += [s[:-len(".proto")] + ".grpc.pb.cc" for s in srcs]
+ return ret
+
+def _proto_py_outs(srcs, use_grpc_plugin = False):
+ ret = [s[:-len(".proto")] + "_pb2.py" for s in srcs]
+ if use_grpc_plugin:
+ ret += [s[:-len(".proto")] + "_pb2_grpc.py" for s in srcs]
+ return ret
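As a worked example of the naming helpers above (with a hypothetical source file), the generated outputs would be:

    # Hypothetical example of the generated file names:
    #   _proto_cc_srcs(["foo/bar.proto"], use_grpc_plugin = True)
    #       -> ["foo/bar.pb.cc", "foo/bar.grpc.pb.cc"]
    #   _proto_cc_hdrs(["foo/bar.proto"], use_grpc_plugin = True)
    #       -> ["foo/bar.pb.h", "foo/bar.grpc.pb.h"]
    #   _proto_py_outs(["foo/bar.proto"], use_grpc_plugin = True)
    #       -> ["foo/bar_pb2.py", "foo/bar_pb2_grpc.py"]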
# Re-defined protocol buffer rule to allow building "header only" protocol
# buffers, to avoid duplicate registrations. Also allows non-iterable cc_libs
# containing select() statements.
def cc_proto_library(
- name,
- srcs=[],
- deps=[],
- cc_libs=[],
- include=None,
- protoc="@protobuf_archive//:protoc",
- internal_bootstrap_hack=False,
- use_grpc_plugin=False,
- use_grpc_namespace=False,
- default_header=False,
- **kargs):
- """Bazel rule to create a C++ protobuf library from proto source files.
-
- Args:
- name: the name of the cc_proto_library.
- srcs: the .proto files of the cc_proto_library.
- deps: a list of dependency labels; must be cc_proto_library.
- cc_libs: a list of other cc_library targets depended by the generated
- cc_library.
- include: a string indicating the include path of the .proto files.
- protoc: the label of the protocol compiler to generate the sources.
- internal_bootstrap_hack: a flag indicate the cc_proto_library is used only
- for bootstraping. When it is set to True, no files will be generated.
- The rule will simply be a provider for .proto files, so that other
- cc_proto_library can depend on it.
- use_grpc_plugin: a flag to indicate whether to call the grpc C++ plugin
- when processing the proto files.
- default_header: Controls the naming of generated rules. If True, the `name`
- rule will be header-only, and an _impl rule will contain the
- implementation. Otherwise the header-only rule (name + "_headers_only")
- must be referred to explicitly.
- **kargs: other keyword arguments that are passed to cc_library.
- """
-
- includes = []
- if include != None:
- includes = [include]
-
- if internal_bootstrap_hack:
- # For pre-checked-in generated files, we add the internal_bootstrap_hack
- # which will skip the codegen action.
+ name,
+ srcs = [],
+ deps = [],
+ cc_libs = [],
+ include = None,
+ protoc = "@protobuf_archive//:protoc",
+ internal_bootstrap_hack = False,
+ use_grpc_plugin = False,
+ use_grpc_namespace = False,
+ default_header = False,
+ **kargs):
+ """Bazel rule to create a C++ protobuf library from proto source files.
+
+ Args:
+ name: the name of the cc_proto_library.
+ srcs: the .proto files of the cc_proto_library.
+ deps: a list of dependency labels; must be cc_proto_library.
+      cc_libs: a list of other cc_library targets depended on by the generated
+ cc_library.
+ include: a string indicating the include path of the .proto files.
+ protoc: the label of the protocol compiler to generate the sources.
+      internal_bootstrap_hack: a flag indicating the cc_proto_library is used only
+        for bootstrapping. When it is set to True, no files will be generated.
+ The rule will simply be a provider for .proto files, so that other
+ cc_proto_library can depend on it.
+ use_grpc_plugin: a flag to indicate whether to call the grpc C++ plugin
+ when processing the proto files.
+ default_header: Controls the naming of generated rules. If True, the `name`
+ rule will be header-only, and an _impl rule will contain the
+ implementation. Otherwise the header-only rule (name + "_headers_only")
+ must be referred to explicitly.
+ **kargs: other keyword arguments that are passed to cc_library.
+ """
+
+ includes = []
+ if include != None:
+ includes = [include]
+
+ if internal_bootstrap_hack:
+ # For pre-checked-in generated files, we add the internal_bootstrap_hack
+ # which will skip the codegen action.
+ proto_gen(
+ name = name + "_genproto",
+ srcs = srcs,
+ deps = [s + "_genproto" for s in deps],
+ includes = includes,
+ protoc = protoc,
+ visibility = ["//visibility:public"],
+ )
+
+ # An empty cc_library to make rule dependency consistent.
+ native.cc_library(
+ name = name,
+ **kargs
+ )
+ return
+
+ grpc_cpp_plugin = None
+ plugin_options = []
+ if use_grpc_plugin:
+ grpc_cpp_plugin = "//external:grpc_cpp_plugin"
+ if use_grpc_namespace:
+ plugin_options = ["services_namespace=grpc"]
+
+ gen_srcs = _proto_cc_srcs(srcs, use_grpc_plugin)
+ gen_hdrs = _proto_cc_hdrs(srcs, use_grpc_plugin)
+ outs = gen_srcs + gen_hdrs
+
proto_gen(
- name=name + "_genproto",
- srcs=srcs,
- deps=[s + "_genproto" for s in deps],
- includes=includes,
- protoc=protoc,
- visibility=["//visibility:public"],
+ name = name + "_genproto",
+ srcs = srcs,
+ deps = [s + "_genproto" for s in deps],
+ includes = includes,
+ protoc = protoc,
+ plugin = grpc_cpp_plugin,
+ plugin_language = "grpc",
+ plugin_options = plugin_options,
+ gen_cc = 1,
+ outs = outs,
+ visibility = ["//visibility:public"],
)
- # An empty cc_library to make rule dependency consistent.
- native.cc_library(
- name=name,
- **kargs)
- return
-
- grpc_cpp_plugin = None
- plugin_options = []
- if use_grpc_plugin:
- grpc_cpp_plugin = "//external:grpc_cpp_plugin"
- if use_grpc_namespace:
- plugin_options = ["services_namespace=grpc"]
-
- gen_srcs = _proto_cc_srcs(srcs, use_grpc_plugin)
- gen_hdrs = _proto_cc_hdrs(srcs, use_grpc_plugin)
- outs = gen_srcs + gen_hdrs
-
- proto_gen(
- name=name + "_genproto",
- srcs=srcs,
- deps=[s + "_genproto" for s in deps],
- includes=includes,
- protoc=protoc,
- plugin=grpc_cpp_plugin,
- plugin_language="grpc",
- plugin_options=plugin_options,
- gen_cc=1,
- outs=outs,
- visibility=["//visibility:public"],
- )
-
- if use_grpc_plugin:
- cc_libs += select({
- "//tensorflow:linux_s390x": ["//external:grpc_lib_unsecure"],
- "//conditions:default": ["//external:grpc_lib"],
- })
- if default_header:
- header_only_name = name
- impl_name = name + "_impl"
- else:
- header_only_name = name + "_headers_only"
- impl_name = name
-
- native.cc_library(
- name=impl_name,
- srcs=gen_srcs,
- hdrs=gen_hdrs,
- deps=cc_libs + deps,
- includes=includes,
- **kargs)
- native.cc_library(
- name=header_only_name,
- deps=["@protobuf_archive//:protobuf_headers"] + if_static([impl_name]),
- hdrs=gen_hdrs,
- **kargs)
+ if use_grpc_plugin:
+ cc_libs += select({
+ "//tensorflow:linux_s390x": ["//external:grpc_lib_unsecure"],
+ "//conditions:default": ["//external:grpc_lib"],
+ })
+
+ if default_header:
+ header_only_name = name
+ impl_name = name + "_impl"
+ else:
+ header_only_name = name + "_headers_only"
+ impl_name = name
+
+ native.cc_library(
+ name = impl_name,
+ srcs = gen_srcs,
+ hdrs = gen_hdrs,
+ deps = cc_libs + deps,
+ includes = includes,
+ **kargs
+ )
+ native.cc_library(
+ name = header_only_name,
+ deps = ["@protobuf_archive//:protobuf_headers"] + if_static([impl_name]),
+ hdrs = gen_hdrs,
+ **kargs
+ )
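A sketch, with hypothetical target names, of how the default_header switch above affects the generated rule names:

    # Hypothetical invocation; "my_proto" and "my.proto" are illustrative only.
    cc_proto_library(
        name = "my_proto",
        srcs = ["my.proto"],
        default_header = True,
    )
    # default_header = True:  "my_proto" is header-only and "my_proto_impl"
    #                         holds the implementation.
    # default_header = False: "my_proto" holds the implementation and the
    #                         header-only rule is "my_proto_headers_only".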
# Re-defined protocol buffer rule to bring in the change introduced in commit
# https://github.com/google/protobuf/commit/294b5758c373cbab4b72f35f4cb62dc1d8332b68
@@ -234,477 +239,490 @@ def cc_proto_library(
# to include the above commit.
def py_proto_library(
name,
- srcs=[],
- deps=[],
- py_libs=[],
- py_extra_srcs=[],
- include=None,
- default_runtime="@protobuf_archive//:protobuf_python",
- protoc="@protobuf_archive//:protoc",
- use_grpc_plugin=False,
+ srcs = [],
+ deps = [],
+ py_libs = [],
+ py_extra_srcs = [],
+ include = None,
+ default_runtime = "@protobuf_archive//:protobuf_python",
+ protoc = "@protobuf_archive//:protoc",
+ use_grpc_plugin = False,
**kargs):
- """Bazel rule to create a Python protobuf library from proto source files
-
- NOTE: the rule is only an internal workaround to generate protos. The
- interface may change and the rule may be removed when bazel has introduced
- the native rule.
-
- Args:
- name: the name of the py_proto_library.
- srcs: the .proto files of the py_proto_library.
- deps: a list of dependency labels; must be py_proto_library.
- py_libs: a list of other py_library targets depended by the generated
- py_library.
- py_extra_srcs: extra source files that will be added to the output
- py_library. This attribute is used for internal bootstrapping.
- include: a string indicating the include path of the .proto files.
- default_runtime: the implicitly default runtime which will be depended on by
- the generated py_library target.
- protoc: the label of the protocol compiler to generate the sources.
- use_grpc_plugin: a flag to indicate whether to call the Python C++ plugin
- when processing the proto files.
- **kargs: other keyword arguments that are passed to cc_library.
- """
- outs = _proto_py_outs(srcs, use_grpc_plugin)
-
- includes = []
- if include != None:
- includes = [include]
-
- grpc_python_plugin = None
- if use_grpc_plugin:
- grpc_python_plugin = "//external:grpc_python_plugin"
- # Note: Generated grpc code depends on Python grpc module. This dependency
- # is not explicitly listed in py_libs. Instead, host system is assumed to
- # have grpc installed.
-
- proto_gen(
- name=name + "_genproto",
- srcs=srcs,
- deps=[s + "_genproto" for s in deps],
- includes=includes,
- protoc=protoc,
- gen_py=1,
- outs=outs,
- visibility=["//visibility:public"],
- plugin=grpc_python_plugin,
- plugin_language="grpc"
- )
-
- if default_runtime and not default_runtime in py_libs + deps:
- py_libs = py_libs + [default_runtime]
-
- native.py_library(
- name=name,
- srcs=outs+py_extra_srcs,
- deps=py_libs+deps,
- imports=includes,
- **kargs)
-
-def tf_proto_library_cc(name, srcs = [], has_services = None,
- protodeps = [],
- visibility = [], testonly = 0,
- cc_libs = [],
- cc_stubby_versions = None,
- cc_grpc_version = None,
- j2objc_api_version = 1,
- cc_api_version = 2,
- dart_api_version = 2,
- java_api_version = 2, py_api_version = 2,
- js_api_version = 2, js_codegen = "jspb",
- default_header = False):
- js_codegen = js_codegen # unused argument
- js_api_version = js_api_version # unused argument
- native.filegroup(
- name = name + "_proto_srcs",
- srcs = srcs + tf_deps(protodeps, "_proto_srcs"),
- testonly = testonly,
- visibility = visibility,
- )
-
- use_grpc_plugin = None
- if cc_grpc_version:
- use_grpc_plugin = True
-
- cc_deps = tf_deps(protodeps, "_cc")
- cc_name = name + "_cc"
- if not srcs:
- # This is a collection of sub-libraries. Build header-only and impl
- # libraries containing all the sources.
+    """Bazel rule to create a Python protobuf library from proto source files.
+
+ NOTE: the rule is only an internal workaround to generate protos. The
+ interface may change and the rule may be removed when bazel has introduced
+ the native rule.
+
+ Args:
+ name: the name of the py_proto_library.
+ srcs: the .proto files of the py_proto_library.
+ deps: a list of dependency labels; must be py_proto_library.
+      py_libs: a list of other py_library targets depended on by the generated
+ py_library.
+ py_extra_srcs: extra source files that will be added to the output
+ py_library. This attribute is used for internal bootstrapping.
+ include: a string indicating the include path of the .proto files.
+      default_runtime: the implicit default runtime which will be depended on by
+ the generated py_library target.
+ protoc: the label of the protocol compiler to generate the sources.
+      use_grpc_plugin: a flag to indicate whether to call the grpc Python plugin
+ when processing the proto files.
+      **kargs: other keyword arguments that are passed to py_library.
+ """
+ outs = _proto_py_outs(srcs, use_grpc_plugin)
+
+ includes = []
+ if include != None:
+ includes = [include]
+
+ grpc_python_plugin = None
+ if use_grpc_plugin:
+ grpc_python_plugin = "//external:grpc_python_plugin"
+ # Note: Generated grpc code depends on Python grpc module. This dependency
+        # is not explicitly listed in py_libs. Instead, the host system is assumed to
+ # have grpc installed.
+
proto_gen(
- name = cc_name + "_genproto",
- deps = [s + "_genproto" for s in cc_deps],
- protoc = "@protobuf_archive//:protoc",
- visibility=["//visibility:public"],
+ name = name + "_genproto",
+ srcs = srcs,
+ deps = [s + "_genproto" for s in deps],
+ includes = includes,
+ protoc = protoc,
+ gen_py = 1,
+ outs = outs,
+ visibility = ["//visibility:public"],
+ plugin = grpc_python_plugin,
+ plugin_language = "grpc",
)
- native.cc_library(
- name = cc_name,
- deps = cc_deps + ["@protobuf_archive//:protobuf_headers"] +
- if_static([name + "_cc_impl"]),
+
+ if default_runtime and not default_runtime in py_libs + deps:
+ py_libs = py_libs + [default_runtime]
+
+ native.py_library(
+ name = name,
+ srcs = outs + py_extra_srcs,
+ deps = py_libs + deps,
+ imports = includes,
+ **kargs
+ )
+
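Roughly, a call such as py_proto_library(name = "my_proto_py", srcs = ["my.proto"]) (hypothetical names) declares two rules under the hood:

    # Simplified sketch of what the macro above generates (hypothetical names):
    #   proto_gen(name = "my_proto_py_genproto", gen_py = 1, outs = ["my_pb2.py"], ...)
    #   py_library(name = "my_proto_py", srcs = ["my_pb2.py"] + py_extra_srcs, ...)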
+def tf_proto_library_cc(
+ name,
+ srcs = [],
+ has_services = None,
+ protodeps = [],
+ visibility = [],
+ testonly = 0,
+ cc_libs = [],
+ cc_stubby_versions = None,
+ cc_grpc_version = None,
+ j2objc_api_version = 1,
+ cc_api_version = 2,
+ dart_api_version = 2,
+ java_api_version = 2,
+ py_api_version = 2,
+ js_api_version = 2,
+ js_codegen = "jspb",
+ default_header = False):
+ js_codegen = js_codegen # unused argument
+ js_api_version = js_api_version # unused argument
+ native.filegroup(
+ name = name + "_proto_srcs",
+ srcs = srcs + tf_deps(protodeps, "_proto_srcs"),
testonly = testonly,
visibility = visibility,
)
- native.cc_library(
- name = cc_name + "_impl",
- deps = [s + "_impl" for s in cc_deps] + ["@protobuf_archive//:cc_wkt_protos"],
- )
- return
-
- cc_proto_library(
- name = cc_name,
- srcs = srcs,
- deps = cc_deps + ["@protobuf_archive//:cc_wkt_protos"],
- cc_libs = cc_libs + if_static(
- ["@protobuf_archive//:protobuf"],
- ["@protobuf_archive//:protobuf_headers"]
- ),
- copts = if_not_windows([
- "-Wno-unknown-warning-option",
- "-Wno-unused-but-set-variable",
- "-Wno-sign-compare",
- ]),
- protoc = "@protobuf_archive//:protoc",
- use_grpc_plugin = use_grpc_plugin,
- testonly = testonly,
- visibility = visibility,
- default_header = default_header,
- )
-
-def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[],
- testonly=0, srcs_version="PY2AND3", use_grpc_plugin=False):
- py_deps = tf_deps(protodeps, "_py")
- py_name = name + "_py"
- if not srcs:
- # This is a collection of sub-libraries. Build header-only and impl
- # libraries containing all the sources.
- proto_gen(
- name = py_name + "_genproto",
- deps = [s + "_genproto" for s in py_deps],
+ use_grpc_plugin = None
+ if cc_grpc_version:
+ use_grpc_plugin = True
+
+ cc_deps = tf_deps(protodeps, "_cc")
+ cc_name = name + "_cc"
+ if not srcs:
+ # This is a collection of sub-libraries. Build header-only and impl
+ # libraries containing all the sources.
+ proto_gen(
+ name = cc_name + "_genproto",
+ deps = [s + "_genproto" for s in cc_deps],
+ protoc = "@protobuf_archive//:protoc",
+ visibility = ["//visibility:public"],
+ )
+ native.cc_library(
+ name = cc_name,
+ deps = cc_deps + ["@protobuf_archive//:protobuf_headers"] +
+ if_static([name + "_cc_impl"]),
+ testonly = testonly,
+ visibility = visibility,
+ )
+ native.cc_library(
+ name = cc_name + "_impl",
+ deps = [s + "_impl" for s in cc_deps] + ["@protobuf_archive//:cc_wkt_protos"],
+ )
+
+ return
+
+ cc_proto_library(
+ name = cc_name,
+ srcs = srcs,
+ deps = cc_deps + ["@protobuf_archive//:cc_wkt_protos"],
+ cc_libs = cc_libs + if_static(
+ ["@protobuf_archive//:protobuf"],
+ ["@protobuf_archive//:protobuf_headers"],
+ ),
+ copts = if_not_windows([
+ "-Wno-unknown-warning-option",
+ "-Wno-unused-but-set-variable",
+ "-Wno-sign-compare",
+ ]),
protoc = "@protobuf_archive//:protoc",
- visibility=["//visibility:public"],
+ use_grpc_plugin = use_grpc_plugin,
+ testonly = testonly,
+ visibility = visibility,
+ default_header = default_header,
)
- native.py_library(
+
+def tf_proto_library_py(
+ name,
+ srcs = [],
+ protodeps = [],
+ deps = [],
+ visibility = [],
+ testonly = 0,
+ srcs_version = "PY2AND3",
+ use_grpc_plugin = False):
+ py_deps = tf_deps(protodeps, "_py")
+ py_name = name + "_py"
+ if not srcs:
+ # This is a collection of sub-libraries. Build header-only and impl
+ # libraries containing all the sources.
+ proto_gen(
+ name = py_name + "_genproto",
+ deps = [s + "_genproto" for s in py_deps],
+ protoc = "@protobuf_archive//:protoc",
+ visibility = ["//visibility:public"],
+ )
+ native.py_library(
+ name = py_name,
+ deps = py_deps + ["@protobuf_archive//:protobuf_python"],
+ testonly = testonly,
+ visibility = visibility,
+ )
+ return
+
+ py_proto_library(
name = py_name,
- deps = py_deps + ["@protobuf_archive//:protobuf_python"],
- testonly = testonly,
+ srcs = srcs,
+ srcs_version = srcs_version,
+ deps = deps + py_deps + ["@protobuf_archive//:protobuf_python"],
+ protoc = "@protobuf_archive//:protoc",
+ default_runtime = "@protobuf_archive//:protobuf_python",
visibility = visibility,
+ testonly = testonly,
+ use_grpc_plugin = use_grpc_plugin,
)
- return
-
- py_proto_library(
- name = py_name,
- srcs = srcs,
- srcs_version = srcs_version,
- deps = deps + py_deps + ["@protobuf_archive//:protobuf_python"],
- protoc = "@protobuf_archive//:protoc",
- default_runtime = "@protobuf_archive//:protobuf_python",
- visibility = visibility,
- testonly = testonly,
- use_grpc_plugin = use_grpc_plugin,
- )
def tf_jspb_proto_library(**kwargs):
- pass
+ pass
def tf_nano_proto_library(**kwargs):
- pass
-
-def tf_proto_library(name, srcs = [], has_services = None,
- protodeps = [],
- visibility = [], testonly = 0,
- cc_libs = [],
- cc_api_version = 2, cc_grpc_version = None,
- dart_api_version = 2, j2objc_api_version = 1,
- java_api_version = 2, py_api_version = 2,
- js_api_version = 2, js_codegen = "jspb",
- provide_cc_alias = False,
- default_header = False):
- """Make a proto library, possibly depending on other proto libraries."""
- _ignore = (js_api_version, js_codegen, provide_cc_alias)
-
- tf_proto_library_cc(
- name = name,
- srcs = srcs,
- protodeps = protodeps,
- cc_grpc_version = cc_grpc_version,
- cc_libs = cc_libs,
- testonly = testonly,
- visibility = visibility,
- default_header = default_header,
- )
-
- tf_proto_library_py(
- name = name,
- srcs = srcs,
- protodeps = protodeps,
- srcs_version = "PY2AND3",
- testonly = testonly,
- visibility = visibility,
- use_grpc_plugin = has_services,
- )
+ pass
+
+def tf_proto_library(
+ name,
+ srcs = [],
+ has_services = None,
+ protodeps = [],
+ visibility = [],
+ testonly = 0,
+ cc_libs = [],
+ cc_api_version = 2,
+ cc_grpc_version = None,
+ dart_api_version = 2,
+ j2objc_api_version = 1,
+ java_api_version = 2,
+ py_api_version = 2,
+ js_api_version = 2,
+ js_codegen = "jspb",
+ provide_cc_alias = False,
+ default_header = False):
+ """Make a proto library, possibly depending on other proto libraries."""
+ _ignore = (js_api_version, js_codegen, provide_cc_alias)
+
+ tf_proto_library_cc(
+ name = name,
+ srcs = srcs,
+ protodeps = protodeps,
+ cc_grpc_version = cc_grpc_version,
+ cc_libs = cc_libs,
+ testonly = testonly,
+ visibility = visibility,
+ default_header = default_header,
+ )
+
+ tf_proto_library_py(
+ name = name,
+ srcs = srcs,
+ protodeps = protodeps,
+ srcs_version = "PY2AND3",
+ testonly = testonly,
+ visibility = visibility,
+ use_grpc_plugin = has_services,
+ )
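A hedged usage sketch of the top-level macro above; the target name is hypothetical:

    # Hypothetical example; "example_proto" is illustrative only.
    tf_proto_library(
        name = "example_proto",
        srcs = ["example.proto"],
        cc_grpc_version = 1,  # enables the grpc C++ plugin for the _cc variant
    )
    # Fans out via tf_proto_library_cc / tf_proto_library_py into targets such
    # as "example_proto_cc" and "example_proto_py".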
# A list of all files under platform matching the pattern in 'files'. In
# contrast with 'tf_platform_srcs' below, which selectively collects files that
# must be compiled in the 'default' platform, this is a list of all headers
# mentioned in the platform/* files.
def tf_platform_hdrs(files):
- return native.glob(["platform/*/" + f for f in files])
+ return native.glob(["platform/*/" + f for f in files])
def tf_platform_srcs(files):
- base_set = ["platform/default/" + f for f in files]
- windows_set = base_set + ["platform/windows/" + f for f in files]
- posix_set = base_set + ["platform/posix/" + f for f in files]
-
- # Handle cases where we must also bring the posix file in. Usually, the list
- # of files to build on windows builds is just all the stuff in the
- # windows_set. However, in some cases the implementations in 'posix/' are
- # just what is necessary and historically we choose to simply use the posix
- # file instead of making a copy in 'windows'.
- for f in files:
- if f == "error.cc":
- windows_set.append("platform/posix/" + f)
-
- return select({
- "//tensorflow:windows" : native.glob(windows_set),
- "//tensorflow:windows_msvc" : native.glob(windows_set),
- "//conditions:default" : native.glob(posix_set),
- })
+ base_set = ["platform/default/" + f for f in files]
+ windows_set = base_set + ["platform/windows/" + f for f in files]
+ posix_set = base_set + ["platform/posix/" + f for f in files]
+
+    # Handle cases where we must also bring the posix file in. Usually, the list
+    # of files to build on Windows is just everything in windows_set. However,
+    # in some cases the posix implementation is all that is necessary, and
+    # historically we have chosen to simply reuse the posix file instead of
+    # making a copy in 'windows'.
+ for f in files:
+ if f == "error.cc":
+ windows_set.append("platform/posix/" + f)
+
+ return select({
+ "//tensorflow:windows": native.glob(windows_set),
+ "//conditions:default": native.glob(posix_set),
+ })
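To illustrate the selection above with a hypothetical input, tf_platform_srcs(["error.cc"]) would now resolve as follows:

    # Hypothetical expansion of tf_platform_srcs(["error.cc"]) after this change:
    #   select({
    #       "//tensorflow:windows": glob([
    #           "platform/default/error.cc",
    #           "platform/windows/error.cc",
    #           "platform/posix/error.cc",  # special-cased above for error.cc
    #       ]),
    #       "//conditions:default": glob([
    #           "platform/default/error.cc",
    #           "platform/posix/error.cc",
    #       ]),
    #   })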
def tf_additional_lib_hdrs(exclude = []):
- windows_hdrs = native.glob([
- "platform/default/*.h",
- "platform/windows/*.h",
- "platform/posix/error.h",
- ], exclude = exclude)
- return select({
- "//tensorflow:windows" : windows_hdrs,
- "//tensorflow:windows_msvc" : windows_hdrs,
- "//conditions:default" : native.glob([
+ windows_hdrs = native.glob([
"platform/default/*.h",
- "platform/posix/*.h",
- ], exclude = exclude),
- })
+ "platform/windows/*.h",
+ "platform/posix/error.h",
+ ], exclude = exclude)
+ return select({
+ "//tensorflow:windows": windows_hdrs,
+ "//conditions:default": native.glob([
+ "platform/default/*.h",
+ "platform/posix/*.h",
+ ], exclude = exclude),
+ })
def tf_additional_lib_srcs(exclude = []):
- windows_srcs = native.glob([
- "platform/default/*.cc",
- "platform/windows/*.cc",
- "platform/posix/error.cc",
- ], exclude = exclude)
- return select({
- "//tensorflow:windows" : windows_srcs,
- "//tensorflow:windows_msvc" : windows_srcs,
- "//conditions:default" : native.glob([
+ windows_srcs = native.glob([
"platform/default/*.cc",
- "platform/posix/*.cc",
- ], exclude = exclude),
- })
+ "platform/windows/*.cc",
+ "platform/posix/error.cc",
+ ], exclude = exclude)
+ return select({
+ "//tensorflow:windows": windows_srcs,
+ "//conditions:default": native.glob([
+ "platform/default/*.cc",
+ "platform/posix/*.cc",
+ ], exclude = exclude),
+ })
def tf_additional_minimal_lib_srcs():
- return [
- "platform/default/integral_types.h",
- "platform/default/mutex.h",
- ]
+ return [
+ "platform/default/integral_types.h",
+ "platform/default/mutex.h",
+ ]
def tf_additional_proto_hdrs():
- return [
- "platform/default/integral_types.h",
- "platform/default/logging.h",
- "platform/default/protobuf.h"
- ] + if_windows([
- "platform/windows/integral_types.h",
- ])
+ return [
+ "platform/default/integral_types.h",
+ "platform/default/logging.h",
+ "platform/default/protobuf.h",
+ ] + if_windows([
+ "platform/windows/integral_types.h",
+ ])
+
+def tf_additional_proto_compiler_hdrs():
+ return [
+ "platform/default/protobuf_compiler.h",
+ ]
def tf_additional_proto_srcs():
- return [
- "platform/default/protobuf.cc",
- ]
+ return [
+ "platform/default/protobuf.cc",
+ ]
def tf_additional_human_readable_json_deps():
- return []
+ return []
def tf_additional_all_protos():
- return ["//tensorflow/core:protos_all"]
+ return ["//tensorflow/core:protos_all"]
def tf_protos_all_impl():
- return ["//tensorflow/core:protos_all_cc_impl"]
+ return ["//tensorflow/core:protos_all_cc_impl"]
def tf_protos_all():
- return if_static(
- extra_deps=tf_protos_all_impl(),
- otherwise=["//tensorflow/core:protos_all_cc"])
+ return if_static(
+ extra_deps = tf_protos_all_impl(),
+ otherwise = ["//tensorflow/core:protos_all_cc"],
+ )
def tf_protos_grappler_impl():
- return ["//tensorflow/core/grappler/costs:op_performance_data_cc_impl"]
+ return ["//tensorflow/core/grappler/costs:op_performance_data_cc_impl"]
def tf_protos_grappler():
- return if_static(
- extra_deps=tf_protos_grappler_impl(),
- otherwise=["//tensorflow/core/grappler/costs:op_performance_data_cc"])
+ return if_static(
+ extra_deps = tf_protos_grappler_impl(),
+ otherwise = ["//tensorflow/core/grappler/costs:op_performance_data_cc"],
+ )
def tf_additional_cupti_wrapper_deps():
- return ["//tensorflow/core/platform/default/gpu:cupti_wrapper"]
+ return ["//tensorflow/core/platform/default/gpu:cupti_wrapper"]
def tf_additional_device_tracer_srcs():
- return ["platform/default/device_tracer.cc"]
+ return ["platform/default/device_tracer.cc"]
def tf_additional_device_tracer_cuda_deps():
- return []
+ return []
def tf_additional_device_tracer_deps():
- return []
+ return []
def tf_additional_libdevice_data():
- return []
+ return []
def tf_additional_libdevice_deps():
- return ["@local_config_cuda//cuda:cuda_headers"]
+ return ["@local_config_cuda//cuda:cuda_headers"]
def tf_additional_libdevice_srcs():
- return ["platform/default/cuda_libdevice_path.cc"]
+ return ["platform/default/cuda_libdevice_path.cc"]
def tf_additional_test_deps():
- return []
+ return []
def tf_additional_test_srcs():
- return [
- "platform/default/test_benchmark.cc",
- ] + select({
- "//tensorflow:windows" : [
- "platform/windows/test.cc"
+ return [
+ "platform/default/test_benchmark.cc",
+ ] + select({
+ "//tensorflow:windows": [
+ "platform/windows/test.cc",
],
- "//conditions:default" : [
- "platform/posix/test.cc",
+ "//conditions:default": [
+ "platform/posix/test.cc",
],
})
def tf_kernel_tests_linkstatic():
- return 0
+ return 0
def tf_additional_lib_defines():
- """Additional defines needed to build TF libraries."""
- return select({
- "//tensorflow:with_jemalloc_linux_x86_64": ["TENSORFLOW_USE_JEMALLOC"],
- "//tensorflow:with_jemalloc_linux_ppc64le":["TENSORFLOW_USE_JEMALLOC"],
- "//conditions:default": [],
- }) + if_not_mobile(["TENSORFLOW_USE_ABSL"])
+ """Additional defines needed to build TF libraries."""
+ return []
def tf_additional_lib_deps():
- """Additional dependencies needed to build TF libraries."""
- return if_not_mobile(["@com_google_absl//absl/base:base"]) + if_static(
- ["@nsync//:nsync_cpp"],
- ["@nsync//:nsync_headers"]
- ) + select({
- "//tensorflow:with_jemalloc_linux_x86_64_dynamic": ["@jemalloc//:jemalloc_headers"],
- "//tensorflow:with_jemalloc_linux_ppc64le_dynamic": ["@jemalloc//:jemalloc_headers"],
- "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
- "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
- "//conditions:default": [],
- })
+ """Additional dependencies needed to build TF libraries."""
+ return [
+ "@com_google_absl//absl/base:base",
+ "@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/types:span",
+ "@com_google_absl//absl/types:optional",
+ ] + if_static(
+ ["@nsync//:nsync_cpp"],
+ ["@nsync//:nsync_headers"],
+ )
def tf_additional_core_deps():
- return select({
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
- "//tensorflow/core/platform/cloud:gcs_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_hdfs_support_windows_override": [],
- "//tensorflow:with_hdfs_support_android_override": [],
- "//tensorflow:with_hdfs_support_ios_override": [],
- "//tensorflow:with_hdfs_support": [
- "//tensorflow/core/platform/hadoop:hadoop_file_system",
- ],
- "//conditions:default": [],
- }) + select({
- "//tensorflow:with_aws_support_windows_override": [],
- "//tensorflow:with_aws_support_android_override": [],
- "//tensorflow:with_aws_support_ios_override": [],
- "//tensorflow:with_aws_support": [
- "//tensorflow/core/platform/s3:s3_file_system",
- ],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
+ "//tensorflow/core/platform/cloud:gcs_file_system",
+ "//tensorflow/core/platform/s3:s3_file_system",
+ "//tensorflow/core/platform/hadoop:hadoop_file_system",
+ ],
+ })
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_op_deps():
- return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
- "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
- "//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
- ],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
+ "//tensorflow/contrib/cloud:bigquery_reader_ops_op_lib",
+ "//tensorflow/contrib/cloud:gcs_config_ops_op_lib",
+ ],
+ })
# TODO(jart, jhseu): Delete when GCP is default on.
def tf_additional_cloud_kernel_deps():
- return select({
- "//tensorflow:with_gcp_support_windows_override": [],
- "//tensorflow:with_gcp_support_android_override": [],
- "//tensorflow:with_gcp_support_ios_override": [],
- "//tensorflow:with_gcp_support": [
- "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
- "//tensorflow/contrib/cloud/kernels:gcs_config_ops",
- ],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:android": [],
+ "//tensorflow:windows": [],
+ "//tensorflow:ios": [],
+ "//tensorflow:linux_s390x": [],
+ "//conditions:default": [
+ "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops",
+ "//tensorflow/contrib/cloud/kernels:gcs_config_ops",
+ ],
+ })
def tf_lib_proto_parsing_deps():
- return [
- ":protos_all_cc",
- "//third_party/eigen3",
- "//tensorflow/core/platform/default/build_config:proto_parsing",
- ]
+ return [
+ ":protos_all_cc",
+ "//third_party/eigen3",
+ "//tensorflow/core/platform/default/build_config:proto_parsing",
+ ]
+
+def tf_lib_proto_compiler_deps():
+ return [
+ "@protobuf_archive//:protoc_lib",
+ ]
def tf_additional_verbs_lib_defines():
- return select({
- "//tensorflow:with_verbs_support": ["TENSORFLOW_USE_VERBS"],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_verbs_support": ["TENSORFLOW_USE_VERBS"],
+ "//conditions:default": [],
+ })
def tf_additional_mpi_lib_defines():
- return select({
- "//tensorflow:with_mpi_support": ["TENSORFLOW_USE_MPI"],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_mpi_support": ["TENSORFLOW_USE_MPI"],
+ "//conditions:default": [],
+ })
def tf_additional_gdr_lib_defines():
- return select({
- "//tensorflow:with_gdr_support": ["TENSORFLOW_USE_GDR"],
- "//conditions:default": [],
- })
+ return select({
+ "//tensorflow:with_gdr_support": ["TENSORFLOW_USE_GDR"],
+ "//conditions:default": [],
+ })
-def tf_py_clif_cc(name, visibility=None, **kwargs):
- pass
+def tf_py_clif_cc(name, visibility = None, **kwargs):
+ pass
-def tf_pyclif_proto_library(name, proto_lib, proto_srcfile="", visibility=None,
- **kwargs):
- pass
+def tf_pyclif_proto_library(
+ name,
+ proto_lib,
+ proto_srcfile = "",
+ visibility = None,
+ **kwargs):
+ pass
def tf_additional_binary_deps():
- return ["@nsync//:nsync_cpp"] + if_cuda(
- [
- "//tensorflow/stream_executor:cuda_platform",
- "//tensorflow/core/platform/default/build_config:cuda",
- ],
- ) + select({
- "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
- "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
- "//conditions:default": [],
- }) + [
- # TODO(allenl): Split these out into their own shared objects (they are
- # here because they are shared between contrib/ op shared objects and
- # core).
- "//tensorflow/core/kernels:lookup_util",
- "//tensorflow/core/util/tensor_bundle",
- ] + if_mkl(
- [
- "//third_party/mkl:intel_binary_blob",
- ],
- )
+ return ["@nsync//:nsync_cpp"] + if_cuda(
+ [
+ "//tensorflow/stream_executor:cuda_platform",
+ "//tensorflow/core/platform/default/build_config:cuda",
+ ],
+ ) + [
+ # TODO(allenl): Split these out into their own shared objects (they are
+ # here because they are shared between contrib/ op shared objects and
+ # core).
+ "//tensorflow/core/kernels:lookup_util",
+ "//tensorflow/core/util/tensor_bundle",
+ ] + if_mkl_ml(
+ [
+ "//third_party/mkl:intel_binary_blob",
+ ],
+ )
diff --git a/tensorflow/core/platform/default/build_config_root.bzl b/tensorflow/core/platform/default/build_config_root.bzl
index 3a012c23fd..37475feebe 100644
--- a/tensorflow/core/platform/default/build_config_root.bzl
+++ b/tensorflow/core/platform/default/build_config_root.bzl
@@ -3,64 +3,64 @@
# be separate to avoid cyclic references.
def tf_cuda_tests_tags():
- return ["requires-gpu"]
+ return ["requires-gpu", "local", "gpu"]
def tf_sycl_tests_tags():
- return ["requires-gpu"]
+ return ["requires-gpu", "local", "gpu"]
def tf_additional_plugin_deps():
- return select({
- str(Label("//tensorflow:with_xla_support")): [
- str(Label("//tensorflow/compiler/jit"))
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_xla_support")): [
+ str(Label("//tensorflow/compiler/jit")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_xla_deps_py():
- return []
+ return []
def tf_additional_grpc_deps_py():
- return []
+ return []
def tf_additional_license_deps():
- return select({
- str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_xla_support")): ["@llvm//:LICENSE.TXT"],
+ "//conditions:default": [],
+ })
def tf_additional_verbs_deps():
- return select({
- str(Label("//tensorflow:with_verbs_support")): [
- str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
- str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_verbs_support")): [
+ str(Label("//tensorflow/contrib/verbs:verbs_server_lib")),
+ str(Label("//tensorflow/contrib/verbs:grpc_verbs_client")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_mpi_deps():
- return select({
- str(Label("//tensorflow:with_mpi_support")): [
- str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_mpi_support")): [
+ str(Label("//tensorflow/contrib/mpi:mpi_server_lib")),
+ ],
+ "//conditions:default": [],
+ })
def tf_additional_gdr_deps():
- return select({
- str(Label("//tensorflow:with_gdr_support")): [
- str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
- ],
- "//conditions:default": [],
- })
+ return select({
+ str(Label("//tensorflow:with_gdr_support")): [
+ str(Label("//tensorflow/contrib/gdr:gdr_server_lib")),
+ ],
+ "//conditions:default": [],
+ })
-def if_static(extra_deps, otherwise=[]):
- return select({
- str(Label("//tensorflow:framework_shared_object")): otherwise,
- "//conditions:default": extra_deps,
- })
+def if_static(extra_deps, otherwise = []):
+ return select({
+ str(Label("//tensorflow:framework_shared_object")): otherwise,
+ "//conditions:default": extra_deps,
+ })
-def if_dynamic_kernels(extra_deps, otherwise=[]):
- return select({
- str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps,
- "//conditions:default": otherwise,
- })
+def if_dynamic_kernels(extra_deps, otherwise = []):
+ return select({
+ str(Label("//tensorflow:dynamic_loaded_kernels")): extra_deps,
+ "//conditions:default": otherwise,
+ })
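A brief, hypothetical sketch of how the two helpers above are consumed in a deps list:

    # Hypothetical deps list using the helpers above (targets are illustrative).
    deps = ["//tensorflow/core:lib"] + if_static(
        extra_deps = ["//tensorflow/core:protos_all_cc_impl"],  # monolithic/static build
        otherwise = ["//tensorflow/core:protos_all_cc"],        # framework_shared_object build
    ) + if_dynamic_kernels(
        extra_deps = [],                   # kernels are loaded from separate .so files
        otherwise = [":some_kernel_lib"],  # hypothetical kernel target linked in directly
    )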
diff --git a/tensorflow/core/platform/default/cord.h b/tensorflow/core/platform/default/cord.h
new file mode 100644
index 0000000000..5823374d1a
--- /dev/null
+++ b/tensorflow/core/platform/default/cord.h
@@ -0,0 +1,21 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
+
+// TODO(ebrevdo): Fill this in.
+
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_CORD_H_
diff --git a/tensorflow/core/platform/default/device_tracer.cc b/tensorflow/core/platform/default/device_tracer.cc
index ccddf1eafc..83c65dbfa9 100644
--- a/tensorflow/core/platform/default/device_tracer.cc
+++ b/tensorflow/core/platform/default/device_tracer.cc
@@ -321,6 +321,16 @@ class DeviceTracerImpl : public DeviceTracer,
return nullptr;
}
+ bool IsEnabledForAnnotations() const override {
+ // We are always enabled for 'Annotations'.
+ return true;
+ }
+
+ bool IsEnabledForActivities(bool is_expensive) const override {
+ // We don't do anything with 'Activities' so we are never 'enabled'.
+ return false;
+ }
+
protected:
// This callback is used exclusively by CUPTIManager.
friend class CUPTIManager;
diff --git a/tensorflow/core/platform/default/integral_types.h b/tensorflow/core/platform/default/integral_types.h
index 7cbe7d62f7..92186bc912 100644
--- a/tensorflow/core/platform/default/integral_types.h
+++ b/tensorflow/core/platform/default/integral_types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
-#define TENSORFLOW_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
// IWYU pragma: private, include "third_party/tensorflow/core/platform/types.h"
// IWYU pragma: friend third_party/tensorflow/core/platform/types.h
@@ -33,4 +33,4 @@ typedef unsigned long long uint64;
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_INTEGRAL_TYPES_H_
diff --git a/tensorflow/core/platform/default/logging.h b/tensorflow/core/platform/default/logging.h
index 2c134f1be9..08a692fff7 100644
--- a/tensorflow/core/platform/default/logging.h
+++ b/tensorflow/core/platform/default/logging.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DEFAULT_LOGGING_H_
-#define TENSORFLOW_PLATFORM_DEFAULT_LOGGING_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_LOGGING_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_LOGGING_H_
// IWYU pragma: private, include "third_party/tensorflow/core/platform/logging.h"
// IWYU pragma: friend third_party/tensorflow/core/platform/logging.h
@@ -314,4 +314,4 @@ int64 MinVLogLevelFromEnv();
} // namespace internal
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DEFAULT_LOGGING_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_LOGGING_H_
diff --git a/tensorflow/core/platform/default/mutex.h b/tensorflow/core/platform/default/mutex.h
index 48d90779e1..bef7801037 100644
--- a/tensorflow/core/platform/default/mutex.h
+++ b/tensorflow/core/platform/default/mutex.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DEFAULT_MUTEX_H_
-#define TENSORFLOW_PLATFORM_DEFAULT_MUTEX_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_MUTEX_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_MUTEX_H_
// IWYU pragma: private, include "third_party/tensorflow/core/platform/mutex.h"
// IWYU pragma: friend third_party/tensorflow/core/platform/mutex.h
@@ -173,4 +173,4 @@ inline ConditionResult WaitForMilliseconds(mutex_lock* mu,
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DEFAULT_MUTEX_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_MUTEX_H_
diff --git a/tensorflow/core/platform/default/protobuf.h b/tensorflow/core/platform/default/protobuf.h
index c732c76ff7..bd9d41c62b 100644
--- a/tensorflow/core/platform/default/protobuf.h
+++ b/tensorflow/core/platform/default/protobuf.h
@@ -20,8 +20,8 @@ limitations under the License.
// IWYU pragma: friend third_party/tensorflow/core/platform/protobuf.h
#include "google/protobuf/arena.h"
-#include "google/protobuf/compiler/importer.h"
#include "google/protobuf/descriptor.h"
+#include "google/protobuf/descriptor.pb.h"
#include "google/protobuf/dynamic_message.h"
#include "google/protobuf/io/coded_stream.h"
#include "google/protobuf/io/zero_copy_stream.h"
diff --git a/tensorflow/core/kernels/warn_about_ints.cc b/tensorflow/core/platform/default/protobuf_compiler.h
index 75ecdf2ae4..a93d7a184b 100644
--- a/tensorflow/core/kernels/warn_about_ints.cc
+++ b/tensorflow/core/platform/default/protobuf_compiler.h
@@ -13,21 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/core/kernels/warn_about_ints.h"
-#include "tensorflow/core/framework/node_def.pb.h"
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_PROTOBUF_COMPILER_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_PROTOBUF_COMPILER_H_
-namespace tensorflow {
+// IWYU pragma: private, include "third_party/tensorflow/core/platform/protobuf_compiler.h"
+// IWYU pragma: friend third_party/tensorflow/core/platform/protobuf_compiler.h
-void WarnAboutInts(OpKernelConstruction* context) {
- DataType dtype;
- OP_REQUIRES_OK(context, context->GetAttr("T", &dtype));
- if (DataTypeIsInteger(dtype)) {
- LOG(WARNING) << "Op " << context->def().name() << " of type "
- << context->def().op() << " used with integer dtype "
- << DataTypeString(dtype)
- << ". This op was registered with integer support "
- << "accidentally, and you won't like the result.";
- }
-}
+#include "google/protobuf/compiler/importer.h"
+#include "tensorflow/core/platform/default/protobuf.h"
-} // namespace tensorflow
+#endif  // TENSORFLOW_CORE_PLATFORM_DEFAULT_PROTOBUF_COMPILER_H_
diff --git a/tensorflow/core/platform/default/thread_annotations.h b/tensorflow/core/platform/default/thread_annotations.h
index a6aa5b1b5e..d21d60ab0b 100644
--- a/tensorflow/core/platform/default/thread_annotations.h
+++ b/tensorflow/core/platform/default/thread_annotations.h
@@ -32,8 +32,8 @@ limitations under the License.
// (e.g. &MyClass::mutex_) to refer to a mutex in some (unknown) object.
//
-#ifndef TENSORFLOW_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
-#define TENSORFLOW_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
// IWYU pragma: private, include "third_party/tensorflow/core/platform/thread_annotations.h"
// IWYU pragma: friend third_party/tensorflow/core/platform/thread_annotations.h
@@ -174,4 +174,4 @@ inline T& ts_unchecked_read(T& v) NO_THREAD_SAFETY_ANALYSIS {
} // namespace thread_safety_analysis
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_THREAD_ANNOTATIONS_H_
diff --git a/tensorflow/core/platform/default/tracing_impl.h b/tensorflow/core/platform/default/tracing_impl.h
index b161378405..b7a5f1386c 100644
--- a/tensorflow/core/platform/default/tracing_impl.h
+++ b/tensorflow/core/platform/default/tracing_impl.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DEFAULT_TRACING_IMPL_H_
-#define TENSORFLOW_PLATFORM_DEFAULT_TRACING_IMPL_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DEFAULT_TRACING_IMPL_H_
+#define TENSORFLOW_CORE_PLATFORM_DEFAULT_TRACING_IMPL_H_
// Stub implementations of tracing functionality.
@@ -43,4 +43,4 @@ inline bool EventCollector::IsEnabled() { return false; }
} // namespace tracing
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DEFAULT_TRACING_IMPL_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DEFAULT_TRACING_IMPL_H_
diff --git a/tensorflow/core/platform/denormal.h b/tensorflow/core/platform/denormal.h
index 09bb0352a2..555ac023db 100644
--- a/tensorflow/core/platform/denormal.h
+++ b/tensorflow/core/platform/denormal.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DENORMAL_H_
-#define TENSORFLOW_PLATFORM_DENORMAL_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DENORMAL_H_
+#define TENSORFLOW_CORE_PLATFORM_DENORMAL_H_
#include "tensorflow/core/platform/macros.h"
@@ -59,4 +59,4 @@ class ScopedDontFlushDenormal {
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DENORMAL_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DENORMAL_H_
diff --git a/tensorflow/core/platform/dynamic_annotations.h b/tensorflow/core/platform/dynamic_annotations.h
index f51f3f33a3..dad0d0f4e4 100644
--- a/tensorflow/core/platform/dynamic_annotations.h
+++ b/tensorflow/core/platform/dynamic_annotations.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DYNAMIC_ANNOTATIONS_H_
-#define TENSORFLOW_PLATFORM_DYNAMIC_ANNOTATIONS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_DYNAMIC_ANNOTATIONS_H_
+#define TENSORFLOW_CORE_PLATFORM_DYNAMIC_ANNOTATIONS_H_
#include "tensorflow/core/platform/platform.h"
@@ -28,4 +28,4 @@ limitations under the License.
#error Define the appropriate PLATFORM_<foo> macro for this platform
#endif
-#endif // TENSORFLOW_PLATFORM_DYNAMIC_ANNOTATIONS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_DYNAMIC_ANNOTATIONS_H_
diff --git a/tensorflow/core/platform/env.cc b/tensorflow/core/platform/env.cc
index 47c59d435b..afc4201e53 100644
--- a/tensorflow/core/platform/env.cc
+++ b/tensorflow/core/platform/env.cc
@@ -92,7 +92,7 @@ Env::Env() : file_system_registry_(new FileSystemRegistryImpl) {}
Status Env::GetFileSystemForFile(const string& fname, FileSystem** result) {
StringPiece scheme, host, path;
io::ParseURI(fname, &scheme, &host, &path);
- FileSystem* file_system = file_system_registry_->Lookup(std::string(scheme));
+ FileSystem* file_system = file_system_registry_->Lookup(string(scheme));
if (!file_system) {
if (scheme.empty()) {
scheme = "[local]";
@@ -166,7 +166,7 @@ bool Env::FilesExist(const std::vector<string>& files,
for (const auto& file : files) {
StringPiece scheme, host, path;
io::ParseURI(file, &scheme, &host, &path);
- files_per_fs[std::string(scheme)].push_back(file);
+ files_per_fs[string(scheme)].push_back(file);
}
std::unordered_map<string, Status> per_file_status;
diff --git a/tensorflow/core/platform/env.h b/tensorflow/core/platform/env.h
index 5b237c4736..5732271f15 100644
--- a/tensorflow/core/platform/env.h
+++ b/tensorflow/core/platform/env.h
@@ -228,6 +228,10 @@ class Env {
/// |suffix|. Returns true if success.
bool CreateUniqueFileName(string* prefix, const string& suffix);
+ /// \brief Return the runfiles directory if running under bazel. Returns
+ /// the directory the executable is located in if not running under bazel.
+ virtual string GetRunfilesDir() = 0;
+
// TODO(jeff,sanjay): Add back thread/thread-pool support if needed.
// TODO(jeff,sanjay): if needed, tighten spec so relative to epoch, or
// provide a routine to get the absolute time.
@@ -360,6 +364,8 @@ class EnvWrapper : public Env {
return target_->FormatLibraryFileName(name, version);
}
+ string GetRunfilesDir() override { return target_->GetRunfilesDir(); }
+
private:
void GetLocalTempDirectories(std::vector<string>* list) override {
target_->GetLocalTempDirectories(list);
diff --git a/tensorflow/core/platform/env_test.cc b/tensorflow/core/platform/env_test.cc
index 305a9a682f..2e32abdffb 100644
--- a/tensorflow/core/platform/env_test.cc
+++ b/tensorflow/core/platform/env_test.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/null_file_system.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
@@ -345,7 +346,13 @@ TEST_F(DefaultEnvTest, LocalTempFilename) {
// Write something to the temporary file.
std::unique_ptr<WritableFile> file_to_write;
TF_CHECK_OK(env->NewWritableFile(filename, &file_to_write));
+#if defined(PLATFORM_GOOGLE)
+ TF_CHECK_OK(file_to_write->Append("Nu"));
+ TF_CHECK_OK(file_to_write->Append(absl::Cord("ll")));
+#else
+ // TODO(ebrevdo): Remove this version.
TF_CHECK_OK(file_to_write->Append("Null"));
+#endif
TF_CHECK_OK(file_to_write->Close());
TF_CHECK_OK(env->FileExists(filename));
diff --git a/tensorflow/core/platform/file_system.cc b/tensorflow/core/platform/file_system.cc
index 922773684b..3ab542a5d8 100644
--- a/tensorflow/core/platform/file_system.cc
+++ b/tensorflow/core/platform/file_system.cc
@@ -158,7 +158,7 @@ Status FileSystem::RecursivelyCreateDir(const string& dirname) {
std::reverse(sub_dirs.begin(), sub_dirs.end());
// Now create the directories.
- string built_path = std::string(remaining_dir);
+ string built_path(remaining_dir);
for (const StringPiece sub_dir : sub_dirs) {
built_path = io::JoinPath(built_path, sub_dir);
Status status = CreateDir(io::CreateURI(scheme, host, built_path));
diff --git a/tensorflow/core/platform/file_system.h b/tensorflow/core/platform/file_system.h
index 077b1d79cf..156af6cdea 100644
--- a/tensorflow/core/platform/file_system.h
+++ b/tensorflow/core/platform/file_system.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/cord.h"
#include "tensorflow/core/platform/file_statistics.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/platform.h"
@@ -252,7 +253,15 @@ class WritableFile {
virtual ~WritableFile();
/// \brief Append 'data' to the file.
- virtual Status Append(const StringPiece& data) = 0;
+ virtual Status Append(StringPiece data) = 0;
+
+ // TODO(ebrevdo): Remove this ifdef when absl is updated.
+#if defined(PLATFORM_GOOGLE)
+  /// \brief Append 'data' to the file.
+ virtual Status Append(const absl::Cord& cord) {
+ return errors::Unimplemented("Append(absl::Cord) is not implemented");
+ }
+#endif
/// \brief Close the file.
///
diff --git a/tensorflow/core/platform/file_system_helper.cc b/tensorflow/core/platform/file_system_helper.cc
index 0ba0e6304f..342cf28e38 100644
--- a/tensorflow/core/platform/file_system_helper.cc
+++ b/tensorflow/core/platform/file_system_helper.cc
@@ -59,7 +59,7 @@ Status GetMatchingPaths(FileSystem* fs, Env* env, const string& pattern,
string fixed_prefix = pattern.substr(0, pattern.find_first_of("*?[\\"));
string eval_pattern = pattern;
std::vector<string> all_files;
- string dir = std::string(io::Dirname(fixed_prefix));
+ string dir(io::Dirname(fixed_prefix));
// If dir is empty then we need to fix up fixed_prefix and eval_pattern to
// include . as the top level directory.
if (dir.empty()) {
diff --git a/tensorflow/core/platform/file_system_test.cc b/tensorflow/core/platform/file_system_test.cc
index c0a16c95f9..a637d42a92 100644
--- a/tensorflow/core/platform/file_system_test.cc
+++ b/tensorflow/core/platform/file_system_test.cc
@@ -125,7 +125,7 @@ class InterPlanetaryFileSystem : public NullFileSystem {
ASSERT_EQ(scheme, "ipfs");
ASSERT_EQ(host, "solarsystem");
str_util::ConsumePrefix(&path, "/");
- *parsed_path = std::string(path);
+ *parsed_path = string(path);
}
std::map<string, std::set<string>> celestial_bodies_ = {
diff --git a/tensorflow/core/platform/hadoop/hadoop_file_system.cc b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
index ff4b4436bb..eb35531e9f 100644
--- a/tensorflow/core/platform/hadoop/hadoop_file_system.cc
+++ b/tensorflow/core/platform/hadoop/hadoop_file_system.cc
@@ -144,7 +144,7 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) {
StringPiece scheme, namenode, path;
io::ParseURI(fname, &scheme, &namenode, &path);
- const string nn = namenode.ToString();
+ const string nn(namenode);
hdfsBuilder* builder = hdfs_->hdfsNewBuilder();
if (scheme == "file") {
@@ -183,7 +183,7 @@ Status HadoopFileSystem::Connect(StringPiece fname, hdfsFS* fs) {
string HadoopFileSystem::TranslateName(const string& name) const {
StringPiece scheme, namenode, path;
io::ParseURI(name, &scheme, &namenode, &path);
- return path.ToString();
+ return string(path);
}
class HDFSRandomAccessFile : public RandomAccessFile {
@@ -282,7 +282,7 @@ class HDFSWritableFile : public WritableFile {
}
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
if (hdfs_->hdfsWrite(fs_, file_, data.data(),
static_cast<tSize>(data.size())) == -1) {
return IOError(filename_, errno);
@@ -392,7 +392,7 @@ Status HadoopFileSystem::GetChildren(const string& dir,
return IOError(dir, errno);
}
for (int i = 0; i < entries; i++) {
- result->push_back(io::Basename(info[i].mName).ToString());
+ result->push_back(string(io::Basename(info[i].mName)));
}
hdfs_->hdfsFreeFileInfo(info, entries);
return Status::OK();
diff --git a/tensorflow/core/platform/host_info.h b/tensorflow/core/platform/host_info.h
index 6124c95923..e76b83adf3 100644
--- a/tensorflow/core/platform/host_info.h
+++ b/tensorflow/core/platform/host_info.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_HOST_INFO_H_
-#define TENSORFLOW_PLATFORM_HOST_INFO_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_HOST_INFO_H_
+#define TENSORFLOW_CORE_PLATFORM_HOST_INFO_H_
#include "tensorflow/core/platform/types.h"
@@ -27,4 +27,4 @@ string Hostname();
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_HOST_INFO_H_
+#endif // TENSORFLOW_CORE_PLATFORM_HOST_INFO_H_
diff --git a/tensorflow/core/platform/init_main.h b/tensorflow/core/platform/init_main.h
index 20cbc615b1..834c529816 100644
--- a/tensorflow/core/platform/init_main.h
+++ b/tensorflow/core/platform/init_main.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_INIT_MAIN_H_
-#define TENSORFLOW_PLATFORM_INIT_MAIN_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_INIT_MAIN_H_
+#define TENSORFLOW_CORE_PLATFORM_INIT_MAIN_H_
namespace tensorflow {
namespace port {
@@ -28,4 +28,4 @@ void InitMain(const char* usage, int* argc, char*** argv);
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_INIT_MAIN_H_
+#endif // TENSORFLOW_CORE_PLATFORM_INIT_MAIN_H_
diff --git a/tensorflow/core/platform/load_library.h b/tensorflow/core/platform/load_library.h
index 9038de25f3..c7eeb2918c 100644
--- a/tensorflow/core/platform/load_library.h
+++ b/tensorflow/core/platform/load_library.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_LOAD_LIBRARY_H_
-#define TENSORFLOW_PLATFORM_LOAD_LIBRARY_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_LOAD_LIBRARY_H_
+#define TENSORFLOW_CORE_PLATFORM_LOAD_LIBRARY_H_
#include "tensorflow/core/lib/core/status.h"
@@ -31,4 +31,4 @@ string FormatLibraryFileName(const string& name, const string& version);
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_LOAD_LIBRARY_H_
+#endif // TENSORFLOW_CORE_PLATFORM_LOAD_LIBRARY_H_
diff --git a/tensorflow/core/platform/logging.h b/tensorflow/core/platform/logging.h
index 985c061676..17a5d5fb5b 100644
--- a/tensorflow/core/platform/logging.h
+++ b/tensorflow/core/platform/logging.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_LOGGING_H_
-#define TENSORFLOW_PLATFORM_LOGGING_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_LOGGING_H_
+#define TENSORFLOW_CORE_PLATFORM_LOGGING_H_
#include "tensorflow/core/platform/platform.h" // To pick up PLATFORM_define
@@ -36,4 +36,4 @@ void LogString(const char* fname, int line, int severity,
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_LOGGING_H_
+#endif // TENSORFLOW_CORE_PLATFORM_LOGGING_H_
diff --git a/tensorflow/core/platform/macros.h b/tensorflow/core/platform/macros.h
index b65eb43146..e1d83e18ac 100644
--- a/tensorflow/core/platform/macros.h
+++ b/tensorflow/core/platform/macros.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_MACROS_H_
-#define TENSORFLOW_PLATFORM_MACROS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_MACROS_H_
+#define TENSORFLOW_CORE_PLATFORM_MACROS_H_
// Compiler attributes
#if (defined(__GNUC__) || defined(__APPLE__)) && !defined(SWIG)
@@ -125,4 +125,4 @@ limitations under the License.
} while (0)
#endif
-#endif // TENSORFLOW_PLATFORM_MACROS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_MACROS_H_
diff --git a/tensorflow/core/platform/mem.h b/tensorflow/core/platform/mem.h
index fca3a2332d..e8150f7322 100644
--- a/tensorflow/core/platform/mem.h
+++ b/tensorflow/core/platform/mem.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_MEM_H_
-#define TENSORFLOW_PLATFORM_MEM_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_MEM_H_
+#define TENSORFLOW_CORE_PLATFORM_MEM_H_
// TODO(cwhipkey): remove this when callers use annotations directly.
#include "tensorflow/core/platform/dynamic_annotations.h"
@@ -65,4 +65,4 @@ int64 AvailableRam();
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_MEM_H_
+#endif // TENSORFLOW_CORE_PLATFORM_MEM_H_
diff --git a/tensorflow/core/platform/mutex.h b/tensorflow/core/platform/mutex.h
index 42d46ceb5b..66b20da95a 100644
--- a/tensorflow/core/platform/mutex.h
+++ b/tensorflow/core/platform/mutex.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_MUTEX_H_
-#define TENSORFLOW_PLATFORM_MUTEX_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_MUTEX_H_
+#define TENSORFLOW_CORE_PLATFORM_MUTEX_H_
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/types.h"
@@ -50,4 +50,4 @@ ConditionResult WaitForMilliseconds(mutex_lock* mu, condition_variable* cv,
int64 ms);
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_MUTEX_H_
+#endif // TENSORFLOW_CORE_PLATFORM_MUTEX_H_
diff --git a/tensorflow/core/platform/net.h b/tensorflow/core/platform/net.h
index 9e7851728d..7dbc92f058 100644
--- a/tensorflow/core/platform/net.h
+++ b/tensorflow/core/platform/net.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_NET_H_
-#define TENSORFLOW_PLATFORM_NET_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_NET_H_
+#define TENSORFLOW_CORE_PLATFORM_NET_H_
namespace tensorflow {
namespace internal {
@@ -24,4 +24,4 @@ int PickUnusedPortOrDie();
} // namespace internal
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_NET_H_
+#endif // TENSORFLOW_CORE_PLATFORM_NET_H_
diff --git a/tensorflow/core/platform/png.h b/tensorflow/core/platform/png.h
index b110d63aba..93b1425f7a 100644
--- a/tensorflow/core/platform/png.h
+++ b/tensorflow/core/platform/png.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PNG_H_
-#define TENSORFLOW_PLATFORM_PNG_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_PNG_H_
+#define TENSORFLOW_CORE_PLATFORM_PNG_H_
#include "tensorflow/core/platform/platform.h"
@@ -27,4 +27,4 @@ limitations under the License.
#error Define the appropriate PLATFORM_<foo> macro for this platform
#endif
-#endif // TENSORFLOW_PLATFORM_PNG_H_
+#endif // TENSORFLOW_CORE_PLATFORM_PNG_H_
diff --git a/tensorflow/core/platform/posix/env.cc b/tensorflow/core/platform/posix/env.cc
index 418874d340..af95d8201e 100644
--- a/tensorflow/core/platform/posix/env.cc
+++ b/tensorflow/core/platform/posix/env.cc
@@ -119,6 +119,17 @@ class PosixEnv : public Env {
return tensorflow::internal::FormatLibraryFileName(name, version);
}
+ string GetRunfilesDir() override {
+ string bin_path = this->GetExecutablePath();
+ string runfiles_path = bin_path + ".runfiles/org_tensorflow";
+ Status s = this->IsDirectory(runfiles_path);
+ if (s.ok()) {
+ return runfiles_path;
+ } else {
+ return bin_path.substr(0, bin_path.find_last_of("/\\"));
+ }
+ }
+
private:
void GetLocalTempDirectories(std::vector<string>* list) override;
};
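
The PosixEnv::GetRunfilesDir() override added here prefers <binary>.runfiles/org_tensorflow and otherwise falls back to the binary's own directory. A usage sketch (Env::Default() and io::JoinPath are existing TensorFlow helpers; the data subdirectory below is hypothetical):

    #include <string>

    #include "tensorflow/core/lib/io/path.h"
    #include "tensorflow/core/platform/env.h"

    std::string TestDataPath() {
      // Under bazel this resolves into the runfiles tree; elsewhere it is the
      // directory containing the executable.
      const std::string runfiles = tensorflow::Env::Default()->GetRunfilesDir();
      return tensorflow::io::JoinPath(runfiles, "tensorflow/core/kernels/testdata");
    }
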
diff --git a/tensorflow/core/platform/posix/error.h b/tensorflow/core/platform/posix/error.h
index 9b614d0f70..9df5f2daa1 100644
--- a/tensorflow/core/platform/posix/error.h
+++ b/tensorflow/core/platform/posix/error.h
@@ -24,4 +24,4 @@ Status IOError(const string& context, int err_number);
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PLATFORM_POSIX_POSIX_FILE_SYSTEM_H_
+#endif // TENSORFLOW_CORE_PLATFORM_POSIX_ERROR_H_
diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc
index 1939cf72fb..acdd7798ea 100644
--- a/tensorflow/core/platform/posix/port.cc
+++ b/tensorflow/core/platform/posix/port.cc
@@ -13,13 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef TENSORFLOW_USE_JEMALLOC
-#include "jemalloc/jemalloc.h"
-#endif
-
-#ifdef TENSORFLOW_USE_ABSL
#include "absl/base/internal/sysinfo.h"
-#endif
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
@@ -103,11 +97,7 @@ void* AlignedMalloc(size_t size, int minimum_alignment) {
// memory aligned to at least the size of a pointer.
const int required_alignment = sizeof(void*);
if (minimum_alignment < required_alignment) return Malloc(size);
-#ifdef TENSORFLOW_USE_JEMALLOC
- int err = jemalloc_posix_memalign(&ptr, minimum_alignment, size);
-#else
int err = posix_memalign(&ptr, minimum_alignment, size);
-#endif
if (err != 0) {
return nullptr;
} else {
@@ -118,29 +108,11 @@ void* AlignedMalloc(size_t size, int minimum_alignment) {
void AlignedFree(void* aligned_memory) { Free(aligned_memory); }
-void* Malloc(size_t size) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_malloc(size);
-#else
- return malloc(size);
-#endif
-}
+void* Malloc(size_t size) { return malloc(size); }
-void* Realloc(void* ptr, size_t size) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_realloc(ptr, size);
-#else
- return realloc(ptr, size);
-#endif
-}
+void* Realloc(void* ptr, size_t size) { return realloc(ptr, size); }
-void Free(void* ptr) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- jemalloc_free(ptr);
-#else
- free(ptr);
-#endif
-}
+void Free(void* ptr) { free(ptr); }
void* NUMAMalloc(int node, size_t size, int minimum_alignment) {
return AlignedMalloc(size, minimum_alignment);
@@ -148,9 +120,7 @@ void* NUMAMalloc(int node, size_t size, int minimum_alignment) {
void NUMAFree(void* ptr, size_t size) { Free(ptr); }
-int NUMAGetMemAffinity(const void* addr) {
- return kNUMANoAffinity;
-}
+int NUMAGetMemAffinity(const void* addr) { return kNUMANoAffinity; }
void MallocExtension_ReleaseToSystem(std::size_t num_bytes) {
// No-op.
@@ -194,11 +164,7 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output) {
string Demangle(const char* mangled) { return mangled; }
double NominalCPUFrequency() {
-#ifdef TENSORFLOW_USE_ABSL
return absl::base_internal::NominalCPUFrequency();
-#else
- return 1.0;
-#endif
}
int64 AvailableRam() {
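
With the jemalloc branches removed, the POSIX allocation helpers always delegate to the system allocator; AlignedMalloc keeps its fallback to plain Malloc when the requested alignment is smaller than sizeof(void*), as posix_memalign requires. A behavior sketch under those assumptions:

    #include <cstdlib>

    // Mirrors the simplified AlignedMalloc above: tiny alignments go through
    // malloc, everything else through posix_memalign (which needs a power-of-two
    // alignment that is a multiple of sizeof(void*)).
    void* AlignedMallocSketch(std::size_t size, int minimum_alignment) {
      if (minimum_alignment < static_cast<int>(sizeof(void*))) return malloc(size);
      void* ptr = nullptr;
      return posix_memalign(&ptr, minimum_alignment, size) == 0 ? ptr : nullptr;
    }
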
diff --git a/tensorflow/core/platform/posix/posix_file_system.cc b/tensorflow/core/platform/posix/posix_file_system.cc
index 47bfa020ce..c7afab9583 100644
--- a/tensorflow/core/platform/posix/posix_file_system.cc
+++ b/tensorflow/core/platform/posix/posix_file_system.cc
@@ -91,7 +91,7 @@ class PosixWritableFile : public WritableFile {
}
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
size_t r = fwrite(data.data(), 1, data.size(), file_);
if (r != data.size()) {
return IOError(filename_, errno);
diff --git a/tensorflow/core/platform/posix/posix_file_system.h b/tensorflow/core/platform/posix/posix_file_system.h
index e8898d0a97..752eccea66 100644
--- a/tensorflow/core/platform/posix/posix_file_system.h
+++ b/tensorflow/core/platform/posix/posix_file_system.h
@@ -70,7 +70,7 @@ class LocalPosixFileSystem : public PosixFileSystem {
string TranslateName(const string& name) const override {
StringPiece scheme, host, path;
io::ParseURI(name, &scheme, &host, &path);
- return path.ToString();
+ return string(path);
}
};
diff --git a/tensorflow/core/platform/posix/subprocess.h b/tensorflow/core/platform/posix/subprocess.h
index 53f95f3c14..9740d75595 100644
--- a/tensorflow/core/platform/posix/subprocess.h
+++ b/tensorflow/core/platform/posix/subprocess.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_DEFAULT_SUBPROCESS_H_
-#define TENSORFLOW_PLATFORM_DEFAULT_SUBPROCESS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_POSIX_SUBPROCESS_H_
+#define TENSORFLOW_CORE_PLATFORM_POSIX_SUBPROCESS_H_
#include <errno.h>
#include <unistd.h>
@@ -128,4 +128,4 @@ class SubProcess {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_DEFAULT_SUBPROCESS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_POSIX_SUBPROCESS_H_
diff --git a/tensorflow/core/platform/prefetch.h b/tensorflow/core/platform/prefetch.h
index 81e1a5210a..9cefab3c1b 100644
--- a/tensorflow/core/platform/prefetch.h
+++ b/tensorflow/core/platform/prefetch.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PREFETCH_H_
-#define TENSORFLOW_PLATFORM_PREFETCH_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_PREFETCH_H_
+#define TENSORFLOW_CORE_PLATFORM_PREFETCH_H_
#include "tensorflow/core/platform/platform.h"
@@ -56,4 +56,4 @@ inline void prefetch(const void* x) {
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_PREFETCH_H_
+#endif // TENSORFLOW_CORE_PLATFORM_PREFETCH_H_
diff --git a/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h b/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h
index ce2069b004..2d94736c97 100644
--- a/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h
+++ b/tensorflow/core/platform/profile_utils/android_armv7a_cpu_utils_helper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PROFILEUTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H__
-#define TENSORFLOW_PLATFORM_PROFILEUTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H__
+#ifndef TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H_
+#define TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H_
#include <sys/types.h>
@@ -64,4 +64,4 @@ class AndroidArmV7ACpuUtilsHelper : public ICpuUtilsHelper {
#endif // defined(__ANDROID__) && (__ANDROID_API__ >= 21) &&
// (defined(__ARM_ARCH_7A__) || defined(__aarch64__))
-#endif // TENSORFLOW_PLATFORM_PROFILEUTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H__
+#endif // TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_ANDROID_ARMV7A_CPU_UTILS_HELPER_H_
diff --git a/tensorflow/core/platform/profile_utils/clock_cycle_profiler.h b/tensorflow/core/platform/profile_utils/clock_cycle_profiler.h
index de4eec28e3..e25456374c 100644
--- a/tensorflow/core/platform/profile_utils/clock_cycle_profiler.h
+++ b/tensorflow/core/platform/profile_utils/clock_cycle_profiler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_
-#define TENSORFLOW_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_
+#define TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_
#include <algorithm>
@@ -103,4 +103,4 @@ class ClockCycleProfiler {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_
+#endif // TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CLOCK_CYCLE_PROFILER_H_
diff --git a/tensorflow/core/platform/profile_utils/cpu_utils.h b/tensorflow/core/platform/profile_utils/cpu_utils.h
index 8f06290303..b0b1ef0363 100644
--- a/tensorflow/core/platform/profile_utils/cpu_utils.h
+++ b/tensorflow/core/platform/profile_utils/cpu_utils.h
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// This class is designed to get accurate profile for programs.
-#ifndef TENSORFLOW_PLATFORM_PROFILEUTILS_CPU_UTILS_H__
-#define TENSORFLOW_PLATFORM_PROFILEUTILS_CPU_UTILS_H__
+#ifndef TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CPU_UTILS_H_
+#define TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CPU_UTILS_H_
#include <chrono>
#include <memory>
@@ -164,4 +164,4 @@ class CpuUtils {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_PROFILEUTILS_CPU_UTILS_H__
+#endif // TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_CPU_UTILS_H_
diff --git a/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h b/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h
index 11b739c009..cab7618a70 100644
--- a/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h
+++ b/tensorflow/core/platform/profile_utils/i_cpu_utils_helper.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PROFILEUTILS_I_CPU_UTILS_HELPER_H__
-#define TENSORFLOW_PLATFORM_PROFILEUTILS_I_CPU_UTILS_HELPER_H__
+#ifndef TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_I_CPU_UTILS_HELPER_H_
+#define TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_I_CPU_UTILS_HELPER_H_
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -50,4 +50,4 @@ class ICpuUtilsHelper {
} // namespace profile_utils
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_PROFILEUTILS_I_CPU_UTILS_HELPER_H__
+#endif // TENSORFLOW_CORE_PLATFORM_PROFILE_UTILS_I_CPU_UTILS_HELPER_H_
diff --git a/tensorflow/core/platform/protobuf.h b/tensorflow/core/platform/protobuf.h
index 288d091624..fcbf1fc8c5 100644
--- a/tensorflow/core/platform/protobuf.h
+++ b/tensorflow/core/platform/protobuf.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PROTOBUF_H_
-#define TENSORFLOW_PLATFORM_PROTOBUF_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_
+#define TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_
#include "tensorflow/core/platform/platform.h"
#include "tensorflow/core/platform/types.h"
@@ -52,4 +52,4 @@ inline void SetProtobufStringSwapAllowed(string* src, string* dest) {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_PROTOBUF_H_
+#endif // TENSORFLOW_CORE_PLATFORM_PROTOBUF_H_
diff --git a/tensorflow/core/platform/protobuf_compiler.h b/tensorflow/core/platform/protobuf_compiler.h
new file mode 100644
index 0000000000..29679e0089
--- /dev/null
+++ b/tensorflow/core/platform/protobuf_compiler.h
@@ -0,0 +1,25 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_PLATFORM_PROTOBUF_COMPILER_H_
+#define TENSORFLOW_PLATFORM_PROTOBUF_COMPILER_H_
+
+#if defined(PLATFORM_GOOGLE) && !defined(USE_DEFAULT_PROTOBUF)
+#include "tensorflow/core/platform/google/protobuf_compiler.h"
+#else
+#include "tensorflow/core/platform/default/protobuf_compiler.h"
+#endif
+
+#endif // TENSORFLOW_PLATFORM_PROTOBUF_COMPILER_H_
diff --git a/tensorflow/core/platform/protobuf_internal.h b/tensorflow/core/platform/protobuf_internal.h
index 2f151a5aee..d0cfde09bc 100644
--- a/tensorflow/core/platform/protobuf_internal.h
+++ b/tensorflow/core/platform/protobuf_internal.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_PROTOBUF_INTERNAL_H_
-#define TENSORFLOW_PLATFORM_PROTOBUF_INTERNAL_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_PROTOBUF_INTERNAL_H_
+#define TENSORFLOW_CORE_PLATFORM_PROTOBUF_INTERNAL_H_
#include "google/protobuf/any.pb.h"
#include "tensorflow/core/lib/core/errors.h"
@@ -69,4 +69,4 @@ Status ParseAny(const google::protobuf::Any& any, T* message,
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_PROTOBUF_INTERNAL_H_
+#endif // TENSORFLOW_CORE_PLATFORM_PROTOBUF_INTERNAL_H_
diff --git a/tensorflow/core/platform/s3/s3_file_system.cc b/tensorflow/core/platform/s3/s3_file_system.cc
index 462113f9bb..e0b8e37745 100644
--- a/tensorflow/core/platform/s3/s3_file_system.cc
+++ b/tensorflow/core/platform/s3/s3_file_system.cc
@@ -150,13 +150,13 @@ Status ParseS3Path(const string& fname, bool empty_object_ok, string* bucket,
return errors::InvalidArgument("S3 path doesn't start with 's3://': ",
fname);
}
- *bucket = bucketp.ToString();
+ *bucket = string(bucketp);
if (bucket->empty() || *bucket == ".") {
return errors::InvalidArgument("S3 path doesn't contain a bucket name: ",
fname);
}
str_util::ConsumePrefix(&objectp, "/");
- *object = objectp.ToString();
+ *object = string(objectp);
if (!empty_object_ok && object->empty()) {
return errors::InvalidArgument("S3 path doesn't contain an object name: ",
fname);
@@ -211,7 +211,7 @@ class S3WritableFile : public WritableFile {
std::ios_base::binary | std::ios_base::trunc | std::ios_base::in |
std::ios_base::out)) {}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
if (!outfile_) {
return errors::FailedPrecondition(
"The internal temporary file is not writable.");
diff --git a/tensorflow/core/platform/setround.h b/tensorflow/core/platform/setround.h
index d076e7acc6..ded00b23b1 100644
--- a/tensorflow/core/platform/setround.h
+++ b/tensorflow/core/platform/setround.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_SETROUND_H_
-#define TENSORFLOW_PLATFORM_SETROUND_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_SETROUND_H_
+#define TENSORFLOW_CORE_PLATFORM_SETROUND_H_
#include <cfenv>
@@ -42,4 +42,4 @@ class ScopedSetRound {
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_SETROUND_H_
+#endif // TENSORFLOW_CORE_PLATFORM_SETROUND_H_
diff --git a/tensorflow/core/platform/snappy.h b/tensorflow/core/platform/snappy.h
index 62c208ffb4..5477b097ef 100644
--- a/tensorflow/core/platform/snappy.h
+++ b/tensorflow/core/platform/snappy.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_SNAPPY_H_
-#define TENSORFLOW_PLATFORM_SNAPPY_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_SNAPPY_H_
+#define TENSORFLOW_CORE_PLATFORM_SNAPPY_H_
#include "tensorflow/core/platform/types.h"
@@ -31,4 +31,4 @@ bool Snappy_Uncompress(const char* input, size_t length, char* output);
} // namespace port
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_SNAPPY_H_
+#endif // TENSORFLOW_CORE_PLATFORM_SNAPPY_H_
diff --git a/tensorflow/core/platform/stacktrace_handler.h b/tensorflow/core/platform/stacktrace_handler.h
index a52970fdaa..9f118b91b8 100644
--- a/tensorflow/core/platform/stacktrace_handler.h
+++ b/tensorflow/core/platform/stacktrace_handler.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_
-#define TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_
+#define TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_
namespace tensorflow {
namespace testing {
@@ -25,4 +25,4 @@ void InstallStacktraceHandler();
} // namespace testing
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PLATFORM_BACKTRACE_H_
+#endif // TENSORFLOW_CORE_PLATFORM_STACKTRACE_HANDLER_H_
diff --git a/tensorflow/core/platform/subprocess.h b/tensorflow/core/platform/subprocess.h
index dcc0c1a4ee..7c11e6232f 100644
--- a/tensorflow/core/platform/subprocess.h
+++ b/tensorflow/core/platform/subprocess.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_SUBPROCESS_H_
-#define TENSORFLOW_PLATFORM_SUBPROCESS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_SUBPROCESS_H_
+#define TENSORFLOW_CORE_PLATFORM_SUBPROCESS_H_
#include <memory>
#include <vector>
@@ -67,4 +67,4 @@ std::unique_ptr<SubProcess> CreateSubProcess(const std::vector<string>& argv);
#error Define the appropriate PLATFORM_<foo> macro for this platform
#endif
-#endif // TENSORFLOW_PLATFORM_SUBPROCESS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_SUBPROCESS_H_
diff --git a/tensorflow/core/platform/test.h b/tensorflow/core/platform/test.h
index 99bae63edf..f5d3282f57 100644
--- a/tensorflow/core/platform/test.h
+++ b/tensorflow/core/platform/test.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_TEST_H_
-#define TENSORFLOW_PLATFORM_TEST_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_TEST_H_
+#define TENSORFLOW_CORE_PLATFORM_TEST_H_
#include <memory>
#include <vector>
@@ -55,4 +55,4 @@ int PickUnusedPortOrDie();
} // namespace testing
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_TEST_H_
+#endif // TENSORFLOW_CORE_PLATFORM_TEST_H_
diff --git a/tensorflow/core/platform/test_benchmark.h b/tensorflow/core/platform/test_benchmark.h
index 9b8726d98f..61fcd0d372 100644
--- a/tensorflow/core/platform/test_benchmark.h
+++ b/tensorflow/core/platform/test_benchmark.h
@@ -14,8 +14,8 @@ limitations under the License.
==============================================================================*/
// Simple benchmarking facility.
-#ifndef TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_
-#define TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_
+#define TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_
#include <utility>
#include <vector>
@@ -115,4 +115,4 @@ void UseRealTime();
} // namespace testing
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_
+#endif // TENSORFLOW_CORE_PLATFORM_TEST_BENCHMARK_H_
diff --git a/tensorflow/core/platform/thread_annotations.h b/tensorflow/core/platform/thread_annotations.h
index 50195cbbc7..aec34df8a1 100644
--- a/tensorflow/core/platform/thread_annotations.h
+++ b/tensorflow/core/platform/thread_annotations.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_THREAD_ANNOTATIONS_H_
-#define TENSORFLOW_PLATFORM_THREAD_ANNOTATIONS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_THREAD_ANNOTATIONS_H_
+#define TENSORFLOW_CORE_PLATFORM_THREAD_ANNOTATIONS_H_
#include "tensorflow/core/platform/types.h"
@@ -27,4 +27,4 @@ limitations under the License.
#error Define the appropriate PLATFORM_<foo> macro for this platform
#endif
-#endif // TENSORFLOW_PLATFORM_THREAD_ANNOTATIONS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_THREAD_ANNOTATIONS_H_
diff --git a/tensorflow/core/platform/tracing.h b/tensorflow/core/platform/tracing.h
index c322777705..aefbe64425 100644
--- a/tensorflow/core/platform/tracing.h
+++ b/tensorflow/core/platform/tracing.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_TRACING_H_
-#define TENSORFLOW_PLATFORM_TRACING_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_TRACING_H_
+#define TENSORFLOW_CORE_PLATFORM_TRACING_H_
// Tracing interface
@@ -155,6 +155,13 @@ class TraceCollector {
StringPiece name_part1, StringPiece name_part2,
bool is_expensive) const = 0;
+ // Returns true if this annotation tracing is enabled for any op.
+ virtual bool IsEnabledForAnnotations() const = 0;
+
+ // Returns true if this activity handle tracking is enabled for an op of the
+ // given expensiveness.
+ virtual bool IsEnabledForActivities(bool is_expensive) const = 0;
+
protected:
static string ConcatenateNames(StringPiece first, StringPiece second);
@@ -238,4 +245,4 @@ const char* GetLogDir();
#include "tensorflow/core/platform/default/tracing_impl.h"
#endif
-#endif // TENSORFLOW_PLATFORM_TRACING_H_
+#endif // TENSORFLOW_CORE_PLATFORM_TRACING_H_
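
The two queries added to TraceCollector are pure virtual, so every concrete collector now decides when annotations and activity handles are worth recording. A hypothetical policy sketch (the base class and its other members are omitted to keep the fragment self-contained; a real collector derives from tracing::TraceCollector):

    class SamplingCollectorSketch {
     public:
      // Record annotations whenever tracing is switched on at all.
      bool IsEnabledForAnnotations() const { return tracing_enabled_; }

      // Track cheap ops only while a detailed trace is being captured.
      bool IsEnabledForActivities(bool is_expensive) const {
        return tracing_enabled_ && (is_expensive || detailed_trace_);
      }

     private:
      bool tracing_enabled_ = false;
      bool detailed_trace_ = false;
    };
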
diff --git a/tensorflow/core/platform/types.h b/tensorflow/core/platform/types.h
index 68897ac423..a4fa790317 100644
--- a/tensorflow/core/platform/types.h
+++ b/tensorflow/core/platform/types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_TYPES_H_
-#define TENSORFLOW_PLATFORM_TYPES_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_TYPES_H_
+#define TENSORFLOW_CORE_PLATFORM_TYPES_H_
#include <string>
#include "tensorflow/core/platform/platform.h"
@@ -66,4 +66,4 @@ namespace tensorflow {
namespace se = ::stream_executor;
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_TYPES_H_
+#endif // TENSORFLOW_CORE_PLATFORM_TYPES_H_
diff --git a/tensorflow/core/platform/windows/cpu_info.h b/tensorflow/core/platform/windows/cpu_info.h
index ba2126abcf..8b42cbec7a 100644
--- a/tensorflow/core/platform/windows/cpu_info.h
+++ b/tensorflow/core/platform/windows/cpu_info.h
@@ -13,10 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_WINDOWS_CPU_INFO_H_
-#define TENSORFLOW_PLATFORM_WINDOWS_CPU_INFO_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_CPU_INFO_H_
+#define TENSORFLOW_CORE_PLATFORM_WINDOWS_CPU_INFO_H_
// included so __cpuidex function is available for GETCPUID on Windows
#include <intrin.h>
-#endif // TENSORFLOW_PLATFORM_WINDOWS_CPU_INFO_H_
+#endif // TENSORFLOW_CORE_PLATFORM_WINDOWS_CPU_INFO_H_
diff --git a/tensorflow/core/platform/windows/env.cc b/tensorflow/core/platform/windows/env.cc
index 68ee3595a2..f26ccd1662 100644
--- a/tensorflow/core/platform/windows/env.cc
+++ b/tensorflow/core/platform/windows/env.cc
@@ -160,6 +160,17 @@ class WindowsEnv : public Env {
return filename;
}
+ string GetRunfilesDir() override {
+ string bin_path = this->GetExecutablePath();
+ string runfiles_path = bin_path + ".runfiles\\org_tensorflow";
+ Status s = this->IsDirectory(runfiles_path);
+ if (s.ok()) {
+ return runfiles_path;
+ } else {
+ return bin_path.substr(0, bin_path.find_last_of("/\\"));
+ }
+ }
+
private:
void GetLocalTempDirectories(std::vector<string>* list) override;
diff --git a/tensorflow/core/platform/windows/integral_types.h b/tensorflow/core/platform/windows/integral_types.h
index 46338a536d..283af49f20 100644
--- a/tensorflow/core/platform/windows/integral_types.h
+++ b/tensorflow/core/platform/windows/integral_types.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
-#define TENSORFLOW_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
+#define TENSORFLOW_CORE_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
#include "tensorflow/core/platform/default/integral_types.h"
@@ -22,4 +22,4 @@ limitations under the License.
typedef std::ptrdiff_t ssize_t;
-#endif // TENSORFLOW_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
+#endif // TENSORFLOW_CORE_PLATFORM_WINDOWS_INTEGRAL_TYPES_H_
diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc
index 5375f56372..911ea1902f 100644
--- a/tensorflow/core/platform/windows/port.cc
+++ b/tensorflow/core/platform/windows/port.cc
@@ -13,10 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef TENSORFLOW_USE_JEMALLOC
-#include "jemalloc/jemalloc.h"
-#endif
-
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
@@ -70,55 +66,16 @@ void NUMASetThreadNodeAffinity(int node) {}
int NUMAGetThreadNodeAffinity() { return kNUMANoAffinity; }
void* AlignedMalloc(size_t size, int minimum_alignment) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- void* ptr = NULL;
- // posix_memalign requires that the requested alignment be at least
- // sizeof(void*). In this case, fall back on malloc which should return
- // memory aligned to at least the size of a pointer.
- const int required_alignment = sizeof(void*);
- if (minimum_alignment < required_alignment) return Malloc(size);
- int err = jemalloc_posix_memalign(&ptr, minimum_alignment, size);
- if (err != 0) {
- return NULL;
- } else {
- return ptr;
- }
-#else
return _aligned_malloc(size, minimum_alignment);
-#endif
}
-void AlignedFree(void* aligned_memory) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- jemalloc_free(aligned_memory);
-#else
- _aligned_free(aligned_memory);
-#endif
-}
+void AlignedFree(void* aligned_memory) { _aligned_free(aligned_memory); }
-void* Malloc(size_t size) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_malloc(size);
-#else
- return malloc(size);
-#endif
-}
+void* Malloc(size_t size) { return malloc(size); }
-void* Realloc(void* ptr, size_t size) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_realloc(ptr, size);
-#else
- return realloc(ptr, size);
-#endif
-}
+void* Realloc(void* ptr, size_t size) { return realloc(ptr, size); }
-void Free(void* ptr) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_free(ptr);
-#else
- return free(ptr);
-#endif
-}
+void Free(void* ptr) { return free(ptr); }
void* NUMAMalloc(int node, size_t size, int minimum_alignment) {
return AlignedMalloc(size, minimum_alignment);
diff --git a/tensorflow/core/platform/windows/subprocess.h b/tensorflow/core/platform/windows/subprocess.h
index f00471d484..9084ff5a92 100644
--- a/tensorflow/core/platform/windows/subprocess.h
+++ b/tensorflow/core/platform/windows/subprocess.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PLATFORM_WINDOWS_SUBPROCESS_H_
-#define TENSORFLOW_PLATFORM_WINDOWS_SUBPROCESS_H_
+#ifndef TENSORFLOW_CORE_PLATFORM_WINDOWS_SUBPROCESS_H_
+#define TENSORFLOW_CORE_PLATFORM_WINDOWS_SUBPROCESS_H_
#include <memory>
#include <vector>
@@ -33,4 +33,4 @@ std::unique_ptr<SubProcess> CreateSubProcess(const std::vector<string>& argv) {
} // namespace tensorflow
-#endif // TENSORFLOW_PLATFORM_WINDOWS_SUBPROCESS_H_
+#endif // TENSORFLOW_CORE_PLATFORM_WINDOWS_SUBPROCESS_H_
diff --git a/tensorflow/core/platform/windows/windows_file_system.cc b/tensorflow/core/platform/windows/windows_file_system.cc
index 9079a5ccaa..6cf79634d7 100644
--- a/tensorflow/core/platform/windows/windows_file_system.cc
+++ b/tensorflow/core/platform/windows/windows_file_system.cc
@@ -150,7 +150,7 @@ class WindowsWritableFile : public WritableFile {
}
}
- Status Append(const StringPiece& data) override {
+ Status Append(StringPiece data) override {
DWORD bytes_written = 0;
DWORD data_size = static_cast<DWORD>(data.size());
BOOL write_result =
diff --git a/tensorflow/core/platform/windows/windows_file_system.h b/tensorflow/core/platform/windows/windows_file_system.h
index 6b04720c68..1f4c535f24 100644
--- a/tensorflow/core/platform/windows/windows_file_system.h
+++ b/tensorflow/core/platform/windows/windows_file_system.h
@@ -71,7 +71,7 @@ class LocalWinFileSystem : public WindowsFileSystem {
string TranslateName(const string& name) const override {
StringPiece scheme, host, path;
io::ParseURI(name, &scheme, &host, &path);
- return path.ToString();
+ return string(path);
}
};
diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD
index af034bdd7d..2bf371276e 100644
--- a/tensorflow/core/profiler/BUILD
+++ b/tensorflow/core/profiler/BUILD
@@ -40,7 +40,6 @@ tf_proto_library(
name = "protos_all",
srcs = glob(["**/*.proto"]),
cc_api_version = 2,
- java_api_version = 2,
protodeps = tf_additional_all_protos(),
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h b/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h
index f5ac5c9c5a..0d1c92eb08 100644
--- a/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h
+++ b/tensorflow/core/profiler/internal/advisor/expensive_operation_checker.h
@@ -137,4 +137,4 @@ class ExpensiveOperationChecker : public Checker {
} // namespace tfprof
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OP_CHECKER_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_EXPENSIVE_OPERATION_CHECKER_H_
diff --git a/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h b/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h
index 270662bd4a..e1533f882f 100644
--- a/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h
+++ b/tensorflow/core/profiler/internal/advisor/tfprof_advisor.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_
-#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_
+#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVISOR_H_
+#define TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVISOR_H_
#include "tensorflow/core/profiler/internal/advisor/accelerator_utilization_checker.h"
#include "tensorflow/core/profiler/internal/advisor/checker.h"
@@ -78,4 +78,4 @@ class Advisor {
} // namespace tfprof
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVICE_H_
+#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_ADVISOR_TFPROF_ADVISOR_H_
diff --git a/tensorflow/core/profiler/internal/tfprof_code.cc b/tensorflow/core/profiler/internal/tfprof_code.cc
index 2c4f52e3ad..744e1e95de 100644
--- a/tensorflow/core/profiler/internal/tfprof_code.cc
+++ b/tensorflow/core/profiler/internal/tfprof_code.cc
@@ -37,7 +37,7 @@ const char* const kGradientSuffix = " (gradient)";
// Convert to Trace proto into a short readable string.
string GetTraceString(const CallStack::Trace& trace) {
- string ntrace = io::Basename(trace.file()).ToString();
+ string ntrace(io::Basename(trace.file()));
ntrace += strings::StrCat(":", trace.lineno());
if (trace.function().length() < 20) {
ntrace += ":" + trace.function();
@@ -113,7 +113,7 @@ class FunctionTable {
// function index should start from 1.
func_pb->set_id(function_table_.size());
- string file_base = io::Basename(file_path).ToString();
+ string file_base(io::Basename(file_path));
file_base = file_base.substr(0, file_base.find_last_of("."));
func_pb->set_name(
string_table_->GetIndex(strings::StrCat(file_base, ":", func_name)));
diff --git a/tensorflow/core/profiler/tfprof_options.h b/tensorflow/core/profiler/tfprof_options.h
index d61deb72ac..57c7e11fa2 100644
--- a/tensorflow/core/profiler/tfprof_options.h
+++ b/tensorflow/core/profiler/tfprof_options.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_
-#define TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_
+#ifndef TENSORFLOW_CORE_PROFILER_TFPROF_OPTIONS_H_
+#define TENSORFLOW_CORE_PROFILER_TFPROF_OPTIONS_H_
#include <set>
#include <string>
@@ -183,4 +183,4 @@ tensorflow::Status ParseOutput(const string& output_opt, string* output_type,
} // namespace tfprof
} // namespace tensorflow
-#endif // TENSORFLOW_CORE_PROFILER_INTERNAL_TFPROF_OPTIONS_H_
+#endif // TENSORFLOW_CORE_PROFILER_TFPROF_OPTIONS_H_
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index da3a99565e..104ab039cb 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -68,7 +68,7 @@ message GPUOptions {
// after the process starts. Users are required to use vendor
// specific mechanisms (e.g., CUDA_VISIBLE_DEVICES) to control the
// physical to visible device mapping prior to invoking TensorFlow.
- // 2. In the code, the ids in this list are also called "CUDA GPU id"s,
+ // 2. In the code, the ids in this list are also called "platform GPU id"s,
// and the 'virtual' ids of GPU devices (i.e. the ids in the device
// name "/device:GPU:<id>") are also called "TF GPU id"s. Please
// refer to third_party/tensorflow/core/common_runtime/gpu/gpu_id.h
@@ -390,9 +390,12 @@ message ConfigProto {
message Experimental {
// Task name for group resolution.
string collective_group_leader = 1;
- // Whether the client will format templated errors. For example, the string:
- // "The node was defined on ^^node:Foo:${file}:${line}^^".
- bool client_handles_error_formatting = 2;
+
+ // We removed the flag client_handles_error_formatting. Marking the tag
+ // number as reserved.
+ // TODO(shikharagarwal): Should we just remove this tag so that it can be
+ // used in the future for another purpose?
+ reserved 2;
// Which executor to use, the default executor will be used
// if it is an empty string or "DEFAULT"
@@ -450,6 +453,11 @@ message RunOptions {
// same group_key value (in a distributed computation where tasks
// run disjoint graphs).
int64 collective_graph_key = 1;
+ // If true, then operations (using the inter-op pool) across all
+ // session::run() calls will be centrally scheduled, optimizing for (median
+ // and tail) latency.
+ // Consider using this option for CPU-bound workloads like inference.
+ bool use_run_handler_pool = 2;
};
Experimental experimental = 8;
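
The new RunOptions.Experimental.use_run_handler_pool field can be flipped through the generated accessors; a sketch assuming standard protobuf C++ codegen naming:

    #include "tensorflow/core/protobuf/config.pb.h"

    // Opt a Session::Run call into centrally scheduled inter-op work.
    tensorflow::RunOptions MakeLatencyFriendlyRunOptions() {
      tensorflow::RunOptions run_options;
      run_options.mutable_experimental()->set_use_run_handler_pool(true);
      return run_options;
    }
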
diff --git a/tensorflow/core/protobuf/debug.proto b/tensorflow/core/protobuf/debug.proto
index 811cf406b9..8ca76c44c0 100644
--- a/tensorflow/core/protobuf/debug.proto
+++ b/tensorflow/core/protobuf/debug.proto
@@ -60,6 +60,12 @@ message DebugOptions {
// Note that this is distinct from the session run count and the executor
// step count.
int64 global_step = 10;
+
+ // Whether the total disk usage of tfdbg is to be reset to zero
+ // in this Session.run call. This is used by wrappers and hooks
+ // such as the local CLI ones to indicate that the dumped tensors
+ // are cleaned up from the disk after each Session.run.
+ bool reset_disk_byte_usage = 11;
}
message DebuggedSourceFile {
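
The reset_disk_byte_usage flag is set the same way as any other DebugOptions field; a sketch assuming standard protobuf codegen:

    #include "tensorflow/core/protobuf/debug.pb.h"

    // Ask tfdbg to clear its accumulated on-disk byte count for this run.
    tensorflow::DebugOptions MakeResettingDebugOptions() {
      tensorflow::DebugOptions options;
      options.set_reset_disk_byte_usage(true);
      return options;
    }
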
diff --git a/tensorflow/core/protobuf/replay_log.proto b/tensorflow/core/protobuf/replay_log.proto
new file mode 100644
index 0000000000..7644314fc9
--- /dev/null
+++ b/tensorflow/core/protobuf/replay_log.proto
@@ -0,0 +1,47 @@
+syntax = "proto3";
+
+option cc_enable_arenas = true;
+package tensorflow;
+
+import "tensorflow/core/framework/graph.proto";
+import "tensorflow/core/protobuf/cluster.proto";
+import "tensorflow/core/protobuf/master.proto";
+
+// Records the creation of a new replay session. We record the device listing
+// here to capture the state of the cluster.
+message NewReplaySession {
+ ListDevicesResponse devices = 1;
+ string session_handle = 2;
+}
+
+message ReplayOp {
+ double start_time_us = 31;
+ double end_time_us = 32;
+
+ oneof op {
+ CreateSessionRequest create_session = 1;
+ ExtendSessionRequest extend_session = 2;
+ PartialRunSetupRequest partial_run_setup = 3;
+ RunStepRequest run_step = 4;
+ CloseSessionRequest close_session = 5;
+ ListDevicesRequest list_devices = 6;
+ ResetRequest reset_request = 7;
+ MakeCallableRequest make_callable = 8;
+ RunCallableRequest run_callable = 9;
+ ReleaseCallableRequest release_callable = 10;
+ NewReplaySession new_replay_session = 11;
+ }
+
+ oneof response {
+ CreateSessionResponse create_session_response = 21;
+ ExtendSessionResponse extend_session_response = 22;
+ PartialRunSetupResponse partial_run_setup_response = 23;
+ RunStepResponse run_step_response = 24;
+ CloseSessionResponse close_session_response = 25;
+ ListDevicesResponse list_devices_response = 26;
+ ResetResponse reset_request_response = 27;
+ MakeCallableResponse make_callable_response = 28;
+ RunCallableResponse run_callable_response = 29;
+ ReleaseCallableResponse release_callable_response = 30;
+ }
+}
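
Each ReplayOp pairs one request from the "op" oneof with its response and a timing window. A recording sketch, assuming the generated classes for this proto and the RunStep messages imported from master.proto:

    #include "tensorflow/core/protobuf/replay_log.pb.h"

    tensorflow::ReplayOp RecordRunStep(const tensorflow::RunStepRequest& request,
                                       const tensorflow::RunStepResponse& response,
                                       double start_us, double end_us) {
      tensorflow::ReplayOp op;
      op.set_start_time_us(start_us);
      op.set_end_time_us(end_us);
      *op.mutable_run_step() = request;            // selects the "op" oneof
      *op.mutable_run_step_response() = response;  // selects the "response" oneof
      return op;
    }
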
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 07f984ceea..8c31468ff5 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -75,6 +75,10 @@ message RewriterConfig {
// Try to allocate some independent Op outputs contiguously in order to
// merge or eliminate downstream Ops (off by default).
Toggle scoped_allocator_optimization = 15;
+ // Force small ops onto the CPU (default is OFF).
+ Toggle pin_to_host_optimization = 18;
+ // Disable the entire meta optimizer (off by default).
+ bool disable_meta_optimizer = 19;
// Controls how many times we run the optimizers in meta optimizer (default
// is once).
@@ -141,8 +145,8 @@ message RewriterConfig {
// not configurable (in contrast to memory optimization passes through the
// meta-optimizer) and act only on manual op annotations.
//
- // Custom registered optimizers will be run after the base optimizers, in
- // the order that they are specified.
+ // Custom optimizers (see custom_optimizers) that are not part of this
+ // schedule will be run after - in the order that they were specified.
repeated string optimizers = 100;
// Message to describe custom graph optimizer and its parameters
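
The two new RewriterConfig knobs are plain fields, so enabling the host-pinning pass or switching the meta optimizer off entirely is a one-line change each; a sketch assuming the generated accessors and the existing Toggle enum:

    #include "tensorflow/core/protobuf/rewriter_config.pb.h"

    tensorflow::RewriterConfig MakeRewriterConfigSketch() {
      tensorflow::RewriterConfig config;
      // Force small ops onto the CPU.
      config.set_pin_to_host_optimization(tensorflow::RewriterConfig::ON);
      // Alternatively, skip Grappler altogether:
      // config.set_disable_meta_optimizer(true);
      return config;
    }
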
diff --git a/tensorflow/core/public/session.h b/tensorflow/core/public/session.h
index cc8596ef3d..536a07c413 100644
--- a/tensorflow/core/public/session.h
+++ b/tensorflow/core/public/session.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_PUBLIC_SESSION_H_
-#define TENSORFLOW_PUBLIC_SESSION_H_
+#ifndef TENSORFLOW_CORE_PUBLIC_SESSION_H_
+#define TENSORFLOW_CORE_PUBLIC_SESSION_H_
#include <string>
#include <vector>
@@ -279,4 +279,4 @@ Session* NewSession(const SessionOptions& options);
} // end namespace tensorflow
-#endif // TENSORFLOW_PUBLIC_SESSION_H_
+#endif // TENSORFLOW_CORE_PUBLIC_SESSION_H_
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index 563564119f..b043a69431 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -19,12 +19,12 @@ limitations under the License.
// TensorFlow uses semantic versioning, see http://semver.org/.
#define TF_MAJOR_VERSION 1
-#define TF_MINOR_VERSION 10
+#define TF_MINOR_VERSION 11
#define TF_PATCH_VERSION 0
// TF_VERSION_SUFFIX is non-empty for pre-releases (e.g. "-alpha", "-alpha.1",
// "-beta", "-rc", "-rc.1")
-#define TF_VERSION_SUFFIX ""
+#define TF_VERSION_SUFFIX "-rc1"
#define TF_STR_HELPER(x) #x
#define TF_STR(x) TF_STR_HELPER(x)
@@ -96,10 +96,12 @@ limitations under the License.
// GraphDef. (7dec2017)
// 27. Deprecate TensorArray ops v2 in favor of v3 and deprecated io_ops
// deprecated in favor of V2 ops. (2018/01/23)
+// 28. Deprecate MatrixExponential op in favor of Python implementation.
+// (2018/08/21).
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 26
+#define TF_GRAPH_DEF_VERSION 27
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//
diff --git a/tensorflow/core/util/activation_mode.h b/tensorflow/core/util/activation_mode.h
index 2e03ccd5c8..2f7820fb47 100644
--- a/tensorflow/core/util/activation_mode.h
+++ b/tensorflow/core/util/activation_mode.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_ACTIVATION_MODE_H_
-#define TENSORFLOW_UTIL_ACTIVATION_MODE_H_
+#ifndef TENSORFLOW_CORE_UTIL_ACTIVATION_MODE_H_
+#define TENSORFLOW_CORE_UTIL_ACTIVATION_MODE_H_
// This file contains helper routines to deal with activation mode in various
// ops and kernels.
@@ -43,4 +43,4 @@ Status GetActivationModeFromString(const string& str_value,
} // end namespace tensorflow
-#endif // TENSORFLOW_UTIL_ACTIVATION_MODE_H_
+#endif // TENSORFLOW_CORE_UTIL_ACTIVATION_MODE_H_
diff --git a/tensorflow/core/util/bcast.h b/tensorflow/core/util/bcast.h
index 81d64e5676..6d73c38e3c 100644
--- a/tensorflow/core/util/bcast.h
+++ b/tensorflow/core/util/bcast.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_BCAST_H_
-#define TENSORFLOW_UTIL_BCAST_H_
+#ifndef TENSORFLOW_CORE_UTIL_BCAST_H_
+#define TENSORFLOW_CORE_UTIL_BCAST_H_
#include <algorithm>
@@ -132,4 +132,4 @@ class BCast {
} // end namespace tensorflow
-#endif // TENSORFLOW_UTIL_BCAST_H_
+#endif // TENSORFLOW_CORE_UTIL_BCAST_H_
diff --git a/tensorflow/core/util/command_line_flags.cc b/tensorflow/core/util/command_line_flags.cc
index b281acb2b0..55f1e30880 100644
--- a/tensorflow/core/util/command_line_flags.cc
+++ b/tensorflow/core/util/command_line_flags.cc
@@ -32,7 +32,7 @@ bool ParseStringFlag(tensorflow::StringPiece arg, tensorflow::StringPiece flag,
if (str_util::ConsumePrefix(&arg, "--") &&
str_util::ConsumePrefix(&arg, flag) &&
str_util::ConsumePrefix(&arg, "=")) {
- *value_parsing_ok = hook(std::string(arg));
+ *value_parsing_ok = hook(string(arg));
return true;
}
diff --git a/tensorflow/core/util/ctc/ctc_beam_entry.h b/tensorflow/core/util/ctc/ctc_beam_entry.h
index 973e315f09..24002e72a0 100644
--- a/tensorflow/core/util/ctc/ctc_beam_entry.h
+++ b/tensorflow/core/util/ctc/ctc_beam_entry.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_ENTRY_H_
diff --git a/tensorflow/core/util/ctc/ctc_beam_scorer.h b/tensorflow/core/util/ctc/ctc_beam_scorer.h
index 1a622babe1..1e45a8abd3 100644
--- a/tensorflow/core/util/ctc/ctc_beam_scorer.h
+++ b/tensorflow/core/util/ctc/ctc_beam_scorer.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
// Collection of scoring classes that can be extended and provided to the
// CTCBeamSearchDecoder to incorporate additional scoring logic (such as a
diff --git a/tensorflow/core/util/ctc/ctc_beam_search.h b/tensorflow/core/util/ctc/ctc_beam_search.h
index aee647a1b3..6fbb1ed0da 100644
--- a/tensorflow/core/util/ctc/ctc_beam_search.h
+++ b/tensorflow/core/util/ctc/ctc_beam_search.h
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_BEAM_SEARCH_H_
@@ -259,6 +260,16 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
} else {
max_coeff = raw_input.maxCoeff();
}
+
+ // Get normalization term of softmax: log(sum(exp(logit[j]-max_coeff))).
+ float logsumexp = 0.0;
+ for (int j = 0; j < raw_input.size(); ++j) {
+ logsumexp += Eigen::numext::exp(raw_input(j) - max_coeff);
+ }
+ logsumexp = Eigen::numext::log(logsumexp);
+ // Final normalization offset to get correct log probabilities.
+ float norm_offset = max_coeff + logsumexp;
+
const float label_selection_input_min =
(label_selection_margin_ >= 0) ? (max_coeff - label_selection_margin_)
: -std::numeric_limits<float>::infinity();
@@ -290,10 +301,10 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
beam_scorer_->GetStateExpansionScore(b->state, previous));
}
// Plabel(l=abc @ t=6) *= P(c @ 6)
- b->newp.label += raw_input(b->label) - max_coeff;
+ b->newp.label += raw_input(b->label) - norm_offset;
}
// Pblank(l=abc @ t=6) = P(l=abc @ t=5) * P(- @ 6)
- b->newp.blank = b->oldp.total + raw_input(blank_index_) - max_coeff;
+ b->newp.blank = b->oldp.total + raw_input(blank_index_) - norm_offset;
// P(l=abc @ t=6) = Plabel(l=abc @ t=6) + Pblank(l=abc @ t=6)
b->newp.total = LogSumExp(b->newp.blank, b->newp.label);
@@ -328,6 +339,8 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
const float logit = top_k ? top_k_logits[ind] : raw_input(ind);
// Perform label selection: if input for this label looks very
// unpromising, never evaluate it with a scorer.
+ // We may compare logits instead of log probabilities,
+ // since the difference is the same in both cases.
if (logit < label_selection_input_min) {
continue;
}
@@ -341,7 +354,7 @@ void CTCBeamSearchDecoder<CTCBeamState, CTCBeamComparer>::Step(
// Plabel(l=abcd @ t=6) = P(l=abc @ t=5) * P(d @ 6)
beam_scorer_->ExpandState(b->state, b->label, &c.state, c.label);
float previous = (c.label == b->label) ? b->oldp.blank : b->oldp.total;
- c.newp.label = logit - max_coeff +
+ c.newp.label = logit - norm_offset +
beam_scorer_->GetStateExpansionScore(c.state, previous);
// P(l=abcd @ t=6) = Plabel(l=abcd @ t=6)
c.newp.total = c.newp.label;
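
The norm_offset introduced above is the classic max-shifted log-sum-exp: subtracting it turns raw logits into proper log probabilities, so label and blank scores accumulate on the same normalized scale. A host-side sketch of the same computation:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Returns log(sum_j exp(logits[j])); logit - NormOffset(logits) is then the
    // log probability of that label under the softmax.
    float NormOffset(const std::vector<float>& logits) {
      const float max_coeff = *std::max_element(logits.begin(), logits.end());
      float sum = 0.0f;
      for (const float logit : logits) sum += std::exp(logit - max_coeff);
      return max_coeff + std::log(sum);
    }
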
diff --git a/tensorflow/core/util/ctc/ctc_decoder.h b/tensorflow/core/util/ctc/ctc_decoder.h
index 3be36822e5..b55d7d77ac 100644
--- a/tensorflow/core/util/ctc/ctc_decoder.h
+++ b/tensorflow/core/util/ctc/ctc_decoder.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_DECODER_H_
diff --git a/tensorflow/core/util/ctc/ctc_loss_util.h b/tensorflow/core/util/ctc/ctc_loss_util.h
index 36be9e92ef..054412d388 100644
--- a/tensorflow/core/util/ctc/ctc_loss_util.h
+++ b/tensorflow/core/util/ctc/ctc_loss_util.h
@@ -1,4 +1,3 @@
-// LINT.IfChange
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -13,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// LINT.IfChange
#ifndef TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_
#define TENSORFLOW_CORE_UTIL_CTC_CTC_LOSS_UTIL_H_
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index 540adb58d4..f6f0408ccc 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -93,11 +93,11 @@ __device__ EIGEN_ALWAYS_INLINE Eigen::half CudaShuffleXorSync(
}
namespace cuda_helper {
-template <typename IntType>
-__device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
- IntType* orig = first;
- IntType* it = nullptr;
- IntType step = 0;
+template <typename T, typename OutType = int32>
+__device__ OutType upper_bound(const T* first, OutType count, T val) {
+ const T* orig = first;
+ const T* it = nullptr;
+ OutType step = 0;
while (count > 0) {
it = first;
step = count / 2;
@@ -112,6 +112,27 @@ __device__ IntType upper_bound(IntType* first, IntType count, IntType val) {
return first - orig;
}
+
+template <typename T, typename OutType = int32>
+__device__ OutType lower_bound(const T* first, OutType count, T val) {
+ const T* orig = first;
+ const T* it = nullptr;
+ OutType step = 0;
+ while (count > 0) {
+ it = first;
+ step = count / 2;
+ it += step;
+ if (*it < val) {
+ first = ++it;
+ count -= step + 1;
+ } else {
+ count = step;
+ }
+ }
+
+ return first - orig;
+}
+
} // namespace cuda_helper
} // namespace tensorflow
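The new lower_bound mirrors the existing upper_bound helper: a binary search over a sorted device array that returns an index of type OutType rather than a pointer. A host-side sketch with the same logic (illustration only; the real helper is a __device__ function) shows that it agrees with std::lower_bound:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Host-side copy of the device helper's logic: returns the first index i in
// [0, count) with first[i] >= val, or count if no such element exists.
template <typename T, typename OutType = int32_t>
OutType lower_bound_host(const T* first, OutType count, T val) {
  const T* orig = first;
  const T* it = nullptr;
  OutType step = 0;
  while (count > 0) {
    it = first;
    step = count / 2;
    it += step;
    if (*it < val) {
      first = ++it;
      count -= step + 1;
    } else {
      count = step;
    }
  }
  return first - orig;
}

int main() {
  const float data[] = {1.f, 2.f, 2.f, 5.f, 9.f};
  assert(lower_bound_host(data, 5, 2.f) == 1);   // first element >= 2
  assert(lower_bound_host(data, 5, 3.f) == 3);   // first element >= 3
  assert(lower_bound_host(data, 5, 10.f) == 5);  // past the end
  assert(std::lower_bound(data, data + 5, 3.f) - data == 3);
  return 0;
}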
diff --git a/tensorflow/core/util/device_name_utils.h b/tensorflow/core/util/device_name_utils.h
index 4071a70836..3f0bc60562 100644
--- a/tensorflow/core/util/device_name_utils.h
+++ b/tensorflow/core/util/device_name_utils.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_
-#define TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_
+#ifndef TENSORFLOW_CORE_UTIL_DEVICE_NAME_UTILS_H_
+#define TENSORFLOW_CORE_UTIL_DEVICE_NAME_UTILS_H_
#include <string>
@@ -173,4 +173,4 @@ class DeviceNameUtils {
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_DEVICE_NAME_UTILS_H_
+#endif // TENSORFLOW_CORE_UTIL_DEVICE_NAME_UTILS_H_
diff --git a/tensorflow/core/util/env_var.cc b/tensorflow/core/util/env_var.cc
index 8d43bcc927..2604a5d66a 100644
--- a/tensorflow/core/util/env_var.cc
+++ b/tensorflow/core/util/env_var.cc
@@ -28,7 +28,7 @@ namespace tensorflow {
Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val,
bool* value) {
*value = default_val;
- const char* tf_env_var_val = getenv(std::string(env_var_name).c_str());
+ const char* tf_env_var_val = getenv(string(env_var_name).c_str());
if (tf_env_var_val == nullptr) {
return Status::OK();
}
@@ -48,7 +48,7 @@ Status ReadBoolFromEnvVar(StringPiece env_var_name, bool default_val,
Status ReadInt64FromEnvVar(StringPiece env_var_name, int64 default_val,
int64* value) {
*value = default_val;
- const char* tf_env_var_val = getenv(std::string(env_var_name).c_str());
+ const char* tf_env_var_val = getenv(string(env_var_name).c_str());
if (tf_env_var_val == nullptr) {
return Status::OK();
}
@@ -62,11 +62,11 @@ Status ReadInt64FromEnvVar(StringPiece env_var_name, int64 default_val,
Status ReadStringFromEnvVar(StringPiece env_var_name, StringPiece default_val,
string* value) {
- const char* tf_env_var_val = getenv(std::string(env_var_name).c_str());
+ const char* tf_env_var_val = getenv(string(env_var_name).c_str());
if (tf_env_var_val != nullptr) {
*value = tf_env_var_val;
} else {
- *value = std::string(default_val);
+ *value = string(default_val);
}
return Status::OK();
}
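These readers are typically used to gate optional behavior via environment variables. A hedged usage sketch (the TF_MY_FEATURE variable and MyFeatureEnabled helper are hypothetical, introduced only for illustration):

#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/env_var.h"

namespace tensorflow {

bool MyFeatureEnabled() {
  bool enabled = false;
  // Falls back to the default (false) if TF_MY_FEATURE is unset; returns a
  // non-OK status (and keeps the default) if the value cannot be parsed.
  Status s =
      ReadBoolFromEnvVar("TF_MY_FEATURE", /*default_val=*/false, &enabled);
  if (!s.ok()) LOG(WARNING) << s;
  return enabled;
}

}  // namespace tensorflow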
diff --git a/tensorflow/core/util/env_var.h b/tensorflow/core/util/env_var.h
index 47f9ff3a3b..724ca35729 100644
--- a/tensorflow/core/util/env_var.h
+++ b/tensorflow/core/util/env_var.h
@@ -13,7 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_ENV_VAR_H_
+#ifndef TENSORFLOW_CORE_UTIL_ENV_VAR_H_
+#define TENSORFLOW_CORE_UTIL_ENV_VAR_H_
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -42,4 +43,4 @@ Status ReadStringFromEnvVar(StringPiece env_var_name, StringPiece default_val,
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_ENV_VAR_H_
+#endif // TENSORFLOW_CORE_UTIL_ENV_VAR_H_
diff --git a/tensorflow/core/util/events_writer.h b/tensorflow/core/util/events_writer.h
index 5dbaf97af4..d5952c3cbd 100644
--- a/tensorflow/core/util/events_writer.h
+++ b/tensorflow/core/util/events_writer.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_EVENTS_WRITER_H_
-#define TENSORFLOW_UTIL_EVENTS_WRITER_H_
+#ifndef TENSORFLOW_CORE_UTIL_EVENTS_WRITER_H_
+#define TENSORFLOW_CORE_UTIL_EVENTS_WRITER_H_
#include <memory>
#include <string>
@@ -95,4 +95,4 @@ class EventsWriter {
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_EVENTS_WRITER_H_
+#endif // TENSORFLOW_CORE_UTIL_EVENTS_WRITER_H_
diff --git a/tensorflow/core/util/example_proto_fast_parsing.cc b/tensorflow/core/util/example_proto_fast_parsing.cc
index 1fec0010a1..e52d55e2ff 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing.cc
@@ -353,7 +353,7 @@ bool TestFastParse(const string& serialized, Example* example) {
// I.e. last entry in the map overwrites all the previous ones.
parsed::FeatureMapEntry& name_and_feature =
parsed_example[parsed_example_size - i - 1];
- string name = std::string(name_and_feature.first);
+ string name(name_and_feature.first);
if ((*features.mutable_feature()).count(name) > 0) continue;
auto& value = (*features.mutable_feature())[name];
@@ -1722,10 +1722,11 @@ Status FastParseSequenceExample(
const FastParseExampleConfig& feature_list_config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
thread::ThreadPool* thread_pool, Result* context_result,
- Result* feature_list_result) {
+ Result* feature_list_result, std::vector<Tensor>* dense_feature_lengths) {
int num_examples = serialized.size();
DCHECK(context_result != nullptr);
DCHECK(feature_list_result != nullptr);
+ DCHECK(dense_feature_lengths != nullptr);
std::map<StringPiece, bool> context_is_sparse;
std::map<StringPiece, std::pair<DataType, size_t>>
context_feature_type_and_lengths;
@@ -1740,9 +1741,22 @@ Status FastParseSequenceExample(
context_is_sparse[c.feature_name] = true;
}
for (auto& c : context_config.dense) {
+ if (context_is_sparse[c.feature_name]) {
+ return errors::InvalidArgument("Context feature " + c.feature_name +
+ " cannot be both dense and sparse");
+ }
TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
context_feature_type_and_lengths[c.feature_name] =
- std::make_pair(c.dtype, 0);
+ std::make_pair(c.dtype, c.default_value.NumElements());
+ if (c.default_value.NumElements() > 0) {
+ if (!c.shape.IsCompatibleWith(c.default_value.shape())) {
+ return errors::InvalidArgument("Default value for context feature ",
+ c.feature_name,
+ " has an incorrect shape: saw ",
+ c.default_value.shape().DebugString(),
+ " but expected ", c.shape.DebugString());
+ }
+ }
context_is_sparse[c.feature_name] = false;
}
std::map<StringPiece, bool> sequence_is_sparse;
@@ -1755,6 +1769,10 @@ Status FastParseSequenceExample(
sequence_is_sparse[c.feature_name] = true;
}
for (auto& c : feature_list_config.dense) {
+ if (sequence_is_sparse[c.feature_name]) {
+ return errors::InvalidArgument("Sequence feature " + c.feature_name +
+ " cannot be both dense and sparse");
+ }
TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
sequence_feature_type_and_lengths[c.feature_name] =
std::make_pair(c.dtype, 0);
@@ -1792,14 +1810,14 @@ Status FastParseSequenceExample(
features = sequence_features;
config = &sequence_feature_type_and_lengths;
} else if (!SkipExtraneousTag(&stream)) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
if (features != nullptr) {
uint32 length;
if (!stream.ReadVarint32(&length)) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
auto limit = stream.PushLimit(length);
while (!stream.ExpectAtEnd()) {
@@ -1807,16 +1825,16 @@ Status FastParseSequenceExample(
uint32 length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&length)) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
auto limit = stream.PushLimit(length);
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!ParseString(&stream, &key) ||
!stream.ExpectTag(kDelimitedTag(2)) ||
!ParseString(&stream, &value) || !stream.ExpectAtEnd()) {
- return errors::InvalidArgument(strings::StrCat(
- "Invalid protocol message input, example id: ", example_name));
+ return errors::InvalidArgument(
+ "Invalid protocol message input, example id: ", example_name);
}
stream.PopLimit(limit);
// Only save if this feature was requested.
@@ -1851,9 +1869,8 @@ Status FastParseSequenceExample(
break;
}
if (num == -1) {
- return errors::InvalidArgument(
- strings::StrCat("Error in context feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in context feature ", c.first,
+ " in example ", example_name);
}
num_elements += num;
}
@@ -1876,9 +1893,9 @@ Status FastParseSequenceExample(
uint32 feature_length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&feature_length)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
if (feature_length > 2) {
auto limit = stream.PushLimit(feature_length);
@@ -1898,22 +1915,22 @@ Status FastParseSequenceExample(
break;
}
if (num == -1) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
num_elements += num;
stream.PopLimit(limit);
} else if (feature_length == 2) {
if (!SkipEmptyFeature(&stream, dtype)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
} else if (feature_length != 0) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.first,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.first, " in example ",
+ example_name);
}
}
}
@@ -1936,15 +1953,19 @@ Status FastParseSequenceExample(
feature_list_result->sparse_indices.resize(feature_list_config.sparse.size());
feature_list_result->sparse_shapes.resize(feature_list_config.sparse.size());
feature_list_result->dense_values.resize(feature_list_config.dense.size());
+ dense_feature_lengths->resize(feature_list_config.dense.size());
+
int t = 0;
for (const auto& c : context_config.dense) {
- TensorShape dense_shape;
+ TensorShape dense_shape, example_shape;
DataType dtype = c.dtype;
- size_t expected_max_elements =
+ const size_t expected_max_elements =
context_feature_type_and_lengths[c.feature_name].second;
- if (expected_max_elements != dense_shape.num_elements()) {
- return errors::InvalidArgument(strings::StrCat(
- "Inconsistent number of elements for feature ", c.feature_name));
+ if (!c.shape.AsTensorShape(&example_shape) ||
+ expected_max_elements != example_shape.num_elements()) {
+ return errors::InvalidArgument(
+ "Inconsistent number of elements for feature ", c.feature_name, ": ",
+ expected_max_elements, " vs ", example_shape.num_elements());
}
dense_shape.AddDim(num_examples);
for (const int dim : c.shape.dim_sizes()) {
@@ -1968,18 +1989,58 @@ Status FastParseSequenceExample(
out_int64 = context_result->dense_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
t++;
// Fill in the values.
for (int e = 0; e < num_examples; e++) {
size_t num_elements = 0;
- const auto& feature = all_context_features[e][c.feature_name];
+ const auto feature_iter = all_context_features[e].find(c.feature_name);
const string& example_name =
example_names.empty() ? kUnknown : example_names[e];
- if (!feature.empty()) {
+ if (feature_iter == all_context_features[e].end()) {
+ // Copy the default value, if present. If not, return an error.
+ if (c.default_value.NumElements() == 0) {
+ return errors::InvalidArgument(
+ "Feature: ", c.feature_name,
+ " (data type: ", DataTypeString(c.dtype), ")",
+ " is required but could not be found.");
+ }
+ const string* in_bytes = nullptr;
+ const float* in_float = nullptr;
+ const int64* in_int64 = nullptr;
+ size_t num = 0;
+ switch (dtype) {
+ case DT_STRING:
+ in_bytes = c.default_value.flat<string>().data();
+ num = c.default_value.NumElements();
+ for (int p = 0; p < num; p++) {
+ *out_bytes++ = *in_bytes++;
+ }
+ break;
+ case DT_FLOAT:
+ in_float = c.default_value.flat<float>().data();
+ num = c.default_value.NumElements();
+ for (int p = 0; p < num; p++) {
+ *out_float++ = *in_float++;
+ }
+ break;
+ case DT_INT64:
+ in_int64 = c.default_value.flat<int64>().data();
+ num = c.default_value.NumElements();
+ for (int p = 0; p < num; p++) {
+ *out_int64++ = *in_int64++;
+ }
+ break;
+ default:
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
+ }
+ num_elements += num;
+ } else if (!feature_iter->second.empty()) {
+ const auto& feature = feature_iter->second;
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(feature.data()), feature.size());
EnableAliasing(&stream);
@@ -1998,14 +2059,14 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
}
if (num_elements != expected_max_elements) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected number of elements in example ", example_name));
+ return errors::InvalidArgument(
+ "Unexpected number of elements in example ", example_name);
}
}
}
@@ -2037,8 +2098,8 @@ Status FastParseSequenceExample(
out_int64 = context_result->sparse_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
int64* out_indices = context_result->sparse_indices[t].flat<int64>().data();
auto out_shape = context_result->sparse_shapes[t].vec<int64>();
@@ -2070,8 +2131,8 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
max_num_cols = std::max(max_num_cols, num_added);
@@ -2082,30 +2143,35 @@ Status FastParseSequenceExample(
}
}
if (num_elements != expected_num_elements) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected total number of elements in feature ", c.feature_name));
+ return errors::InvalidArgument(
+ "Unexpected total number of elements in feature ", c.feature_name);
}
out_shape(0) = num_examples;
out_shape(1) = max_num_cols;
}
t = 0;
+ TensorShape dense_length_shape({num_examples});
for (const auto& c : feature_list_config.dense) {
TensorShape dense_shape, row_shape;
DataType dtype = c.dtype;
- size_t expected_max_elements =
+ const size_t expected_max_elements =
sequence_feature_type_and_lengths[c.feature_name].second;
- int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
if (!c.shape.AsTensorShape(&row_shape) ||
- expected_max_elements != expected_max_rows * row_shape.num_elements()) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected shape error in feature ", c.feature_name));
+ expected_max_elements !=
+ (expected_max_elements / row_shape.num_elements()) *
+ row_shape.num_elements()) {
+ return errors::InvalidArgument("Unexpected shape error in feature ",
+ c.feature_name);
}
+ int64 expected_max_rows = expected_max_elements / row_shape.num_elements();
dense_shape.AddDim(num_examples);
dense_shape.AddDim(expected_max_rows);
for (const int dim : feature_list_config.dense[t].shape.dim_sizes()) {
dense_shape.AddDim(dim);
}
feature_list_result->dense_values[t] = Tensor(dtype, dense_shape);
+ (*dense_feature_lengths)[t] = Tensor(DT_INT64, dense_length_shape);
+ int64* out_lengths = (*dense_feature_lengths)[t].flat<int64>().data();
string* out_bytes = nullptr;
float* out_float = nullptr;
@@ -2121,18 +2187,26 @@ Status FastParseSequenceExample(
out_int64 = feature_list_result->dense_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
t++;
// Fill in the values.
for (int e = 0; e < num_examples; e++) {
- size_t num_elements = 0;
- const auto& feature = all_sequence_features[e][c.feature_name];
+ size_t num_elements = 0, num_rows = 0;
+ const auto feature_iter = all_sequence_features[e].find(c.feature_name);
const string& example_name =
example_names.empty() ? kUnknown : example_names[e];
- if (!feature.empty()) {
+ if (feature_iter == all_sequence_features[e].end()) {
+ // Return an error if this feature was not allowed to be missing.
+ // Otherwise, we'll pad as needed below.
+ if (!c.variable_length) {
+ return errors::InvalidArgument("Missing feature ", c.feature_name,
+ " in example ", example_name);
+ }
+ } else if (!feature_iter->second.empty()) {
+ const auto& feature = feature_iter->second;
protobuf::io::CodedInputStream stream(
reinterpret_cast<const uint8*>(feature.data()), feature.size());
EnableAliasing(&stream);
@@ -2140,9 +2214,9 @@ Status FastParseSequenceExample(
uint32 feature_length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&feature_length)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
auto limit = stream.PushLimit(feature_length);
size_t num_added;
@@ -2160,10 +2234,11 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
+ num_rows++;
if (num_added != row_shape.num_elements()) {
return errors::InvalidArgument(
"Unexpected number of elements in feature ", c.feature_name,
@@ -2172,6 +2247,7 @@ Status FastParseSequenceExample(
stream.PopLimit(limit);
}
}
+ *out_lengths++ = num_rows;
// Pad as necessary.
int num_to_pad = expected_max_elements - num_elements;
switch (dtype) {
@@ -2187,8 +2263,8 @@ Status FastParseSequenceExample(
out_int64 += num_to_pad;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
}
}
@@ -2219,8 +2295,8 @@ Status FastParseSequenceExample(
out_int64 = feature_list_result->sparse_values[t].flat<int64>().data();
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in feature ", c.feature_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in feature ", c.feature_name);
}
int64* out_indices =
feature_list_result->sparse_indices[t].flat<int64>().data();
@@ -2244,9 +2320,9 @@ Status FastParseSequenceExample(
uint32 feature_length;
if (!stream.ExpectTag(kDelimitedTag(1)) ||
!stream.ReadVarint32(&feature_length)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
if (feature_length > 2) {
auto limit = stream.PushLimit(feature_length);
@@ -2265,8 +2341,8 @@ Status FastParseSequenceExample(
out_int64 += num_added;
break;
default:
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected dtype ", dtype, " in example ", example_name));
+ return errors::InvalidArgument("Unexpected dtype ", dtype,
+ " in example ", example_name);
}
num_elements += num_added;
max_num_cols = std::max(max_num_cols, num_added);
@@ -2278,14 +2354,14 @@ Status FastParseSequenceExample(
stream.PopLimit(limit);
} else if (feature_length == 2) {
if (!SkipEmptyFeature(&stream, dtype)) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
} else if (feature_length != 0) {
- return errors::InvalidArgument(
- strings::StrCat("Error in sequence feature ", c.feature_name,
- " in example ", example_name));
+ return errors::InvalidArgument("Error in sequence feature ",
+ c.feature_name, " in example ",
+ example_name);
}
num_rows++;
}
@@ -2293,8 +2369,8 @@ Status FastParseSequenceExample(
}
}
if (num_elements != expected_num_elements) {
- return errors::InvalidArgument(strings::StrCat(
- "Unexpected number of elements in feature ", c.feature_name));
+ return errors::InvalidArgument(
+ "Unexpected number of elements in feature ", c.feature_name);
}
out_shape(0) = num_examples;
out_shape(1) = max_num_rows;
diff --git a/tensorflow/core/util/example_proto_fast_parsing.h b/tensorflow/core/util/example_proto_fast_parsing.h
index db5b5ff929..055d9c2c30 100644
--- a/tensorflow/core/util/example_proto_fast_parsing.h
+++ b/tensorflow/core/util/example_proto_fast_parsing.h
@@ -118,7 +118,8 @@ Status FastParseSequenceExample(
const example::FastParseExampleConfig& feature_list_config,
gtl::ArraySlice<string> serialized, gtl::ArraySlice<string> example_names,
thread::ThreadPool* thread_pool, example::Result* context_result,
- example::Result* feature_list_result);
+ example::Result* feature_list_result,
+ std::vector<Tensor>* dense_feature_lengths);
// This function parses serialized Example and populates given example.
// It uses the same specialized parser as FastParseExample which is efficient.
diff --git a/tensorflow/core/util/example_proto_fast_parsing_test.cc b/tensorflow/core/util/example_proto_fast_parsing_test.cc
index 37faa927bf..6c5f80a535 100644
--- a/tensorflow/core/util/example_proto_fast_parsing_test.cc
+++ b/tensorflow/core/util/example_proto_fast_parsing_test.cc
@@ -42,7 +42,7 @@ string SerializedToReadable(string serialized) {
string result;
result += '"';
for (char c : serialized)
- result += strings::StrCat("\\x", strings::Hex(c, strings::ZERO_PAD_2));
+ result += strings::StrCat("\\x", strings::Hex(c, strings::kZeroPad2));
result += '"';
return result;
}
diff --git a/tensorflow/core/util/example_proto_helper.cc b/tensorflow/core/util/example_proto_helper.cc
index e156a3bc8f..41fb20c00a 100644
--- a/tensorflow/core/util/example_proto_helper.cc
+++ b/tensorflow/core/util/example_proto_helper.cc
@@ -443,6 +443,59 @@ Status ParseSingleExampleAttrs::FinishInit() {
return Status::OK();
}
+Status ParseSequenceExampleAttrs::FinishInit() {
+ if (num_context_sparse != context_sparse_keys.size() ||
+ num_context_sparse != context_sparse_types.size()) {
+ return errors::InvalidArgument(
+ "num_context_sparse (", num_context_sparse,
+ ") must match the size of context_sparse_keys (",
+ context_sparse_keys.size(), ") and context_sparse_types (",
+ context_sparse_types.size(), ")");
+ }
+ if (num_context_dense != context_dense_keys.size() ||
+ num_context_dense != context_dense_types.size() ||
+ num_context_dense != context_dense_shapes.size()) {
+ return errors::InvalidArgument(
+ "num_context_dense (", num_context_dense,
+ ") must match the size of context_dense_keys (",
+ context_dense_keys.size(), "), context_dense_types (",
+ context_dense_types.size(), ") and context_dense_shapes (",
+ context_dense_shapes.size(), ")");
+ }
+ if (num_feature_list_sparse != feature_list_sparse_keys.size() ||
+ num_feature_list_sparse != feature_list_sparse_types.size()) {
+ return errors::InvalidArgument(
+ "num_feature_list_sparse (", num_feature_list_sparse,
+ ") must match the size of feature_list_sparse_keys (",
+ feature_list_sparse_keys.size(), ") and feature_list_sparse_types (",
+ feature_list_sparse_types.size(), ")");
+ }
+ if (num_feature_list_dense != feature_list_dense_keys.size() ||
+ num_feature_list_dense != feature_list_dense_types.size() ||
+ num_feature_list_dense != feature_list_dense_shapes.size()) {
+ return errors::InvalidArgument(
+ "num_feature_list_dense (", num_feature_list_dense,
+ ") must match the size of feature_list_dense_keys (",
+ feature_list_dense_keys.size(), "), feature_list_dense_types (",
+ feature_list_dense_types.size(), ") and feature_list_dense_shapes (",
+ feature_list_dense_shapes.size(), ")");
+ }
+ for (const DataType& type : context_dense_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+ for (const DataType& type : context_sparse_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+ for (const DataType& type : feature_list_dense_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+ for (const DataType& type : feature_list_sparse_types) {
+ TF_RETURN_IF_ERROR(CheckValidType(type));
+ }
+
+ return Status::OK();
+}
+
Status ParseSingleSequenceExampleAttrs::FinishInit() {
if (static_cast<size_t>(num_context_sparse) != context_sparse_types.size()) {
return errors::InvalidArgument(
diff --git a/tensorflow/core/util/example_proto_helper.h b/tensorflow/core/util/example_proto_helper.h
index e511704962..c183ee4d96 100644
--- a/tensorflow/core/util/example_proto_helper.h
+++ b/tensorflow/core/util/example_proto_helper.h
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/sparse/sparse_tensor.h"
@@ -271,6 +272,66 @@ class ParseSingleExampleAttrs {
Status FinishInit(); // for context-independent parts of Init.
};
+// Parses the attributes passed to ParseSequenceExample.
+// REQUIRES: Init must be called after construction.
+class ParseSequenceExampleAttrs {
+ public:
+ template <typename ContextType>
+ Status Init(ContextType* ctx) {
+ std::vector<string> feature_list_dense_missing_assumed_empty_tmp;
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_missing_assumed_empty",
+ &feature_list_dense_missing_assumed_empty_tmp));
+ for (const string& feature : feature_list_dense_missing_assumed_empty_tmp) {
+ feature_list_dense_missing_assumed_empty.insert(feature);
+ }
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("context_sparse_keys", &context_sparse_keys));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("context_dense_keys", &context_dense_keys));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_sparse_keys", &feature_list_sparse_keys));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_keys", &feature_list_dense_keys));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("context_sparse_types", &context_sparse_types));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_dense", &num_context_dense));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("Nfeature_list_dense", &num_feature_list_dense));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_sparse", &num_context_sparse));
+ TF_RETURN_IF_ERROR(ctx->GetAttr("Tcontext_dense", &context_dense_types));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_sparse_types", &feature_list_sparse_types));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_types", &feature_list_dense_types));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("Nfeature_list_sparse", &num_feature_list_sparse));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("context_dense_shapes", &context_dense_shapes));
+ TF_RETURN_IF_ERROR(
+ ctx->GetAttr("feature_list_dense_shapes", &feature_list_dense_shapes));
+ return FinishInit();
+ }
+
+ std::unordered_set<string> feature_list_dense_missing_assumed_empty;
+ int64 num_context_sparse;
+ int64 num_context_dense;
+ int64 num_feature_list_sparse;
+ int64 num_feature_list_dense;
+ std::vector<string> context_sparse_keys;
+ std::vector<string> context_dense_keys;
+ std::vector<string> feature_list_sparse_keys;
+ std::vector<string> feature_list_dense_keys;
+ std::vector<DataType> context_sparse_types;
+ std::vector<DataType> context_dense_types;
+ std::vector<TensorShape> context_dense_shapes;
+ std::vector<DataType> feature_list_sparse_types;
+ std::vector<DataType> feature_list_dense_types;
+ std::vector<TensorShape> feature_list_dense_shapes;
+
+ private:
+ Status FinishInit(); // for context-independent parts of Init.
+};
+
// Parses the attributes passed to ParseSingleSequenceExample.
// REQUIRES: Init must be called after construction.
class ParseSingleSequenceExampleAttrs {
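The Init-after-construction contract matches the other attr helpers in this header: a kernel constructs the struct, calls Init with its OpKernelConstruction, and FinishInit validates that the counts and the type/shape lists agree. A hedged sketch (the kernel name is hypothetical and assumes an op registered with the ParseSequenceExample attrs):

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/util/example_proto_helper.h"

namespace tensorflow {

// Hypothetical kernel illustrating the intended usage pattern.
class MyParseSequenceExampleOp : public OpKernel {
 public:
  explicit MyParseSequenceExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    // Reads all ParseSequenceExample attrs and runs the FinishInit checks.
    OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
  }

  void Compute(OpKernelContext* ctx) override {
    // attrs_.num_context_dense, attrs_.feature_list_dense_shapes, etc. are
    // now available for building the fast-parsing configuration.
  }

 private:
  ParseSequenceExampleAttrs attrs_;
};

}  // namespace tensorflow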
diff --git a/tensorflow/core/util/guarded_philox_random.h b/tensorflow/core/util/guarded_philox_random.h
index 44970eb949..8be7a374f0 100644
--- a/tensorflow/core/util/guarded_philox_random.h
+++ b/tensorflow/core/util/guarded_philox_random.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
-#define TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
+#ifndef TENSORFLOW_CORE_UTIL_GUARDED_PHILOX_RANDOM_H_
+#define TENSORFLOW_CORE_UTIL_GUARDED_PHILOX_RANDOM_H_
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/random/philox_random.h"
@@ -79,4 +79,4 @@ class GuardedPhiloxRandom {
} // namespace tensorflow
-#endif // TENSORFLOW_KERNELS_GUARDED_PHILOX_RANDOM_H_
+#endif // TENSORFLOW_CORE_UTIL_GUARDED_PHILOX_RANDOM_H_
diff --git a/tensorflow/core/util/mirror_pad_mode.h b/tensorflow/core/util/mirror_pad_mode.h
index f703d47ab1..ceee9b06b0 100644
--- a/tensorflow/core/util/mirror_pad_mode.h
+++ b/tensorflow/core/util/mirror_pad_mode.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_MIRROR_PAD_MODE_H_
-#define TENSORFLOW_UTIL_MIRROR_PAD_MODE_H_
+#ifndef TENSORFLOW_CORE_UTIL_MIRROR_PAD_MODE_H_
+#define TENSORFLOW_CORE_UTIL_MIRROR_PAD_MODE_H_
// This file contains helper routines to deal with padding in various ops and
// kernels.
@@ -49,4 +49,4 @@ Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name,
} // end namespace tensorflow
-#endif // TENSORFLOW_UTIL_MIRROR_PAD_MODE_H_
+#endif // TENSORFLOW_CORE_UTIL_MIRROR_PAD_MODE_H_
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index 159a787d05..04aaea4f89 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
#ifdef INTEL_MKL
+#include <string>
#include <memory>
#include <unordered_map>
#include <utility>
@@ -33,6 +34,11 @@ limitations under the License.
#endif
#ifdef INTEL_MKL_ML_ONLY
+#error \
+ "Compiling for INTEL MKL ML only is no longer supported. Please use MKL DNN (the default option for --config=mkl)"
+#endif
+
+#ifdef INTEL_MKL_ML_ONLY
#include "mkl_dnn.h"
#include "mkl_dnn_types.h"
#include "mkl_service.h"
@@ -50,6 +56,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/util/env_var.h"
#ifndef INTEL_MKL_ML_ONLY
#include "mkldnn.hpp"
@@ -66,7 +73,6 @@ using mkldnn::reorder;
typedef unsigned int uint;
#endif
-
namespace tensorflow {
// The file contains a number of utility classes and functions used by MKL
@@ -87,6 +93,18 @@ typedef enum {
Dim_I = 1
} MklDnnDims;
+typedef enum {
+ Dim3d_N = 0,
+ Dim3d_C = 1,
+ Dim3d_D = 2,
+ Dim3d_H = 3,
+ Dim3d_W = 4,
+ Dim3d_O = 0,
+ Dim3d_I = 1
+} MklDnnDims3D;
+
+static const int kSmallBatchSize = 32;
+
#ifdef INTEL_MKL_ML_ONLY
class MklShape {
public:
@@ -351,6 +369,7 @@ class MklShape {
#else
// Forward decl
+TensorFormat MklDnn3DDataFormatToTFDataFormat(memory::format format);
TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format);
memory::dims CalculateTFStrides(const memory::dims& dims_tf_order);
memory::desc CreateBlockedMemDescHelper(const memory::dims& dim,
@@ -453,6 +472,13 @@ class MklDnnShape {
return this->DimSize(index);
}
+ inline size_t GetDimension3D(char dimension) const {
+ int index = GetMklDnnTensor3DDimIndex(dimension);
+ CHECK(index >= 0 && index < this->GetDimension())
+ << "Invalid index from the dimension: " << index << ", " << dimension;
+ return this->DimSize(index);
+ }
+
inline int32 GetMklDnnTensorDimIndex(char dimension) const {
switch (dimension) {
case 'N':
@@ -469,6 +495,24 @@ class MklDnnShape {
}
}
+ inline int32 GetMklDnnTensor3DDimIndex(char dimension) const {
+ switch (dimension) {
+ case 'N':
+ return MklDnnDims3D::Dim3d_N;
+ case 'C':
+ return MklDnnDims3D::Dim3d_C;
+ case 'D':
+ return MklDnnDims3D::Dim3d_D;
+ case 'H':
+ return MklDnnDims3D::Dim3d_H;
+ case 'W':
+ return MklDnnDims3D::Dim3d_W;
+ default:
+ LOG(FATAL) << "Invalid dimension: " << dimension;
+ return -1; // Avoid compiler warning about missing return value
+ }
+ }
+
inline size_t GetDimension() const { return data_.dimension_; }
inline const int* GetSizes() const {
return reinterpret_cast<const int*>(&data_.sizes_[0]);
@@ -587,15 +631,29 @@ class MklDnnShape {
}
inline void SetTfDimOrder(const size_t dimension, TensorFormat data_format) {
- // TODO(nhasabni): Why do we restrict this to 4D?
- CHECK_EQ(dimension, 4);
- CHECK(dimension == data_.dimension_);
- data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
- data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
- data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
- data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
+ if (dimension == 5) {
+ CHECK(dimension == data_.dimension_);
+ data_.map_[GetTensorDimIndex<3>(data_format, '0')] =
+ MklDnnDims3D::Dim3d_D;
+ data_.map_[GetTensorDimIndex<3>(data_format, '1')] =
+ MklDnnDims3D::Dim3d_H;
+ data_.map_[GetTensorDimIndex<3>(data_format, '2')] =
+ MklDnnDims3D::Dim3d_W;
+ data_.map_[GetTensorDimIndex<3>(data_format, 'C')] =
+ MklDnnDims3D::Dim3d_C;
+ data_.map_[GetTensorDimIndex<3>(data_format, 'N')] =
+ MklDnnDims3D::Dim3d_N;
+ } else {
+ CHECK_EQ(dimension, 4);
+ CHECK(dimension == data_.dimension_);
+ data_.map_[GetTensorDimIndex<2>(data_format, 'W')] = MklDnnDims::Dim_W;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'H')] = MklDnnDims::Dim_H;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'C')] = MklDnnDims::Dim_C;
+ data_.map_[GetTensorDimIndex<2>(data_format, 'N')] = MklDnnDims::Dim_N;
+ }
}
+
inline void SetTfDimOrder(const size_t dimension, memory::format format) {
TensorFormat data_format = MklDnnDataFormatToTFDataFormat(format);
SetTfDimOrder(dimension, data_format);
@@ -1329,6 +1387,19 @@ memory::data_type MklDnnType<float>() {
return memory::data_type::f32;
}
+/// Map TensorFlow's data format into MKL-DNN 3D data format
+/// @input: TensorFlow data format
+/// @return: memory::format corresponding to TensorFlow data format;
+/// Fails with an error if invalid data format.
+inline memory::format TFDataFormatToMklDnn3DDataFormat(TensorFormat format) {
+ if (format == FORMAT_NHWC)
+ return memory::format::ndhwc;
+ else if (format == FORMAT_NCHW)
+ return memory::format::ncdhw;
+ TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
+ return memory::format::format_undef;
+}
+
/// Map TensorFlow's data format into MKL-DNN data format
///
/// @input: TensorFlow data format
@@ -1340,7 +1411,6 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
else if (format == FORMAT_NCHW)
return memory::format::nchw;
TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
- // Return to get rid of compiler warning
return memory::format::format_undef;
}
@@ -1350,9 +1420,9 @@ inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
/// @return: Tensorflow data format corresponding to memory::format
/// Fails with an error if invalid data format.
inline TensorFormat MklDnnDataFormatToTFDataFormat(memory::format format) {
- if (format == memory::format::nhwc)
+ if (format == memory::format::nhwc || format == memory::format::ndhwc)
return FORMAT_NHWC;
- else if (format == memory::format::nchw)
+ else if (format == memory::format::nchw || format == memory::format::ncdhw)
return FORMAT_NCHW;
TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT, "Unsupported data format"));
@@ -1402,6 +1472,22 @@ inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
return memory::dims({n, c, h, w});
}
+inline memory::dims TFShapeToMklDnnDimsInNCDHW(const TensorShape& shape,
+ TensorFormat format) {
+ // Check validity of format.
+ CHECK_NE(TFDataFormatToMklDnn3DDataFormat(format),
+ memory::format::format_undef);
+
+ int n = shape.dim_size(GetTensorDimIndex<3>(format, 'N'));
+ int c = shape.dim_size(GetTensorDimIndex<3>(format, 'C'));
+ int d = shape.dim_size(GetTensorDimIndex<3>(format, '0'));
+ int h = shape.dim_size(GetTensorDimIndex<3>(format, '1'));
+ int w = shape.dim_size(GetTensorDimIndex<3>(format, '2'));
+
+ // MKL-DNN requires dimensions in NCDHW format.
+ return memory::dims({n, c, d, h, w});
+}
+
/// Overloaded version of function above. Input parameters are
/// self-explanatory.
inline memory::dims MklDnnDimsInNCHW(const memory::dims& in_dims,
@@ -1514,6 +1600,8 @@ class MklDnnData {
/// Operations memory descriptor
memory::desc* op_md_;
+ // Flag to indicate whether the data is 3D.
+ bool bIs3D;
/// Operations temp buffer
void* allocated_buffer_;
/// CPU engine on which operation will be executed
@@ -1540,6 +1628,10 @@ class MklDnnData {
static_cast<const void*>(tensor->flat<T>().data()));
}
+ void SetIs3DData(bool bIs3D_) { bIs3D = bIs3D_; }
+
+ bool GetIs3D() { return bIs3D; }
+
/// Set user memory primitive using specified dimensions, memory format and
/// data_buffer. Function automatically uses element data type by using
/// input type T used for creating call object.
@@ -1911,7 +2003,9 @@ const mkldnn::memory::dims NONE_DIMS = {};
template <typename T>
class MklPrimitiveFactory {
public:
- MklPrimitiveFactory() {}
+ MklPrimitiveFactory() {
+ }
+
~MklPrimitiveFactory() {}
MklPrimitive* GetOp(const string& key) {
@@ -1934,6 +2028,22 @@ class MklPrimitiveFactory {
map[key] = op;
}
+ /// Function to decide whether HW has AVX512 or AVX2
+ /// For legacy devices (without AVX512 and AVX2),
+ /// MKL-DNN GEMM will be used.
+ static inline bool IsLegacyPlatform() {
+ return (!port::TestCPUFeature(port::CPUFeature::AVX512F)
+ && !port::TestCPUFeature(port::CPUFeature::AVX2));
+ }
+
+ /// Function to check whether primitive memory optimization is enabled
+ static inline bool IsPrimitiveMemOptEnabled() {
+ bool is_primitive_mem_opt_enabled = true;
+ TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true,
+ &is_primitive_mem_opt_enabled));
+ return is_primitive_mem_opt_enabled;
+ }
+
private:
static inline std::unordered_map<string, MklPrimitive*>& GetHashMap() {
static thread_local std::unordered_map<string, MklPrimitive*> map_;
@@ -1971,21 +2081,24 @@ class FactoryKeyCreator {
const char delimiter = 'x';
const int kMaxKeyLength = 256;
void Append(StringPiece s) {
- key_.append(s.ToString());
+ key_.append(string(s));
key_.append(1, delimiter);
}
};
-static inline memory::format get_desired_format(int channel) {
+
+static inline memory::format get_desired_format(int channel,
+ bool is_2d = true) {
memory::format fmt_desired = memory::format::any;
- if (port::TestCPUFeature(port::CPUFeature::AVX512F) && (channel % 16) == 0) {
- fmt_desired = memory::format::nChw16c;
+ if (port::TestCPUFeature(port::CPUFeature::AVX512F)) {
+ fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c;
} else if (port::TestCPUFeature(port::CPUFeature::AVX2) &&
(channel % 8) == 0) {
- fmt_desired = memory::format::nChw8c;
+ fmt_desired = is_2d ? memory::format::nChw8c
+ : memory::format::ncdhw; // no avx2 support for 3d yet.
} else {
- fmt_desired = memory::format::nchw;
+ fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw;
}
return fmt_desired;
}
@@ -2006,7 +2119,7 @@ class MklReorderPrimitive : public MklPrimitive {
context_.dst_mem->set_data_handle(to->get_data_handle());
}
- private:
+ private:
struct ReorderContext {
std::shared_ptr<mkldnn::memory> src_mem;
std::shared_ptr<mkldnn::memory> dst_mem;
@@ -2048,7 +2161,7 @@ class MklReorderPrimitiveFactory : public MklPrimitiveFactory<T> {
return instance_;
}
- private:
+ private:
MklReorderPrimitiveFactory() {}
~MklReorderPrimitiveFactory() {}
@@ -2093,6 +2206,16 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
return *reorder_prim->GetPrimitive();
}
+// Utility function to determine whether a convolution is 1x1 with a stride
+// other than 1, for the purpose of temporarily disabling primitive reuse.
+inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
+ memory::dims strides) {
+ if (filter_dims.size() != 4 || strides.size() != 2) return false;
+
+ return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
+ ((strides[0] != 1) || (strides[1] != 1)));
+}
+
#endif // INTEL_MKL_DNN
} // namespace tensorflow
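The 3D additions follow the existing 2D helpers: TFDataFormatToMklDnn3DDataFormat maps FORMAT_NHWC/FORMAT_NCHW onto ndhwc/ncdhw, and TFShapeToMklDnnDimsInNCDHW reorders a 5-D TensorFlow shape into MKL-DNN's NCDHW order. A hedged usage sketch (requires an INTEL_MKL build; the function and values below are illustration only):

#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/mkl_util.h"

namespace tensorflow {

memory::dims ExampleNcdhwDims() {
  // A batch of 8 NDHWC volumes: depth 16, height 32, width 32, 4 channels.
  TensorShape shape({8, 16, 32, 32, 4});
  // With FORMAT_NHWC and three spatial dims this is interpreted as NDHWC and
  // reordered to MKL-DNN's NCDHW: {8, 4, 16, 32, 32}.
  return TFShapeToMklDnnDimsInNCDHW(shape, FORMAT_NHWC);
}

}  // namespace tensorflow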
diff --git a/tensorflow/core/util/padding.h b/tensorflow/core/util/padding.h
index a4278ff2b4..76f9b4dd9a 100644
--- a/tensorflow/core/util/padding.h
+++ b/tensorflow/core/util/padding.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_PADDING_H_
-#define TENSORFLOW_UTIL_PADDING_H_
+#ifndef TENSORFLOW_CORE_UTIL_PADDING_H_
+#define TENSORFLOW_CORE_UTIL_PADDING_H_
// This file contains helper routines to deal with padding in various ops and
// kernels.
@@ -50,4 +50,4 @@ Status GetNodeAttr(const NodeDef& node_def, StringPiece attr_name,
} // end namespace tensorflow
-#endif // TENSORFLOW_UTIL_PADDING_H_
+#endif // TENSORFLOW_CORE_UTIL_PADDING_H_
diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc
index c081ceae57..e01058dff6 100644
--- a/tensorflow/core/util/port.cc
+++ b/tensorflow/core/util/port.cc
@@ -38,10 +38,10 @@ bool CudaSupportsHalfMatMulAndConv() {
}
bool IsMklEnabled() {
-#ifdef INTEL_MKL
+#if defined(INTEL_MKL) && defined(ENABLE_MKL)
return true;
#else
return false;
-#endif
+#endif // INTEL_MKL && ENABLE_MKL
}
} // end namespace tensorflow
diff --git a/tensorflow/core/util/port.h b/tensorflow/core/util/port.h
index 981def9d22..e9b9cb1cd2 100644
--- a/tensorflow/core/util/port.h
+++ b/tensorflow/core/util/port.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_PORT_H_
-#define TENSORFLOW_UTIL_PORT_H_
+#ifndef TENSORFLOW_CORE_UTIL_PORT_H_
+#define TENSORFLOW_CORE_UTIL_PORT_H_
namespace tensorflow {
@@ -30,4 +30,4 @@ bool IsMklEnabled();
} // end namespace tensorflow
-#endif // TENSORFLOW_UTIL_PORT_H_
+#endif // TENSORFLOW_CORE_UTIL_PORT_H_
diff --git a/tensorflow/core/util/saved_tensor_slice_util.h b/tensorflow/core/util/saved_tensor_slice_util.h
index 90672a10a8..7c9cfa35f7 100644
--- a/tensorflow/core/util/saved_tensor_slice_util.h
+++ b/tensorflow/core/util/saved_tensor_slice_util.h
@@ -15,8 +15,8 @@ limitations under the License.
// Utilities for saving/restoring tensor slice checkpoints.
-#ifndef TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
-#define TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
+#ifndef TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
+#define TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
#include <string> // for string
#include "tensorflow/core/framework/tensor.pb.h"
@@ -210,4 +210,4 @@ inline void Fill(const string* data, size_t n, TensorProto* t) {
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
+#endif // TENSORFLOW_CORE_UTIL_SAVED_TENSOR_SLICE_UTIL_H_
diff --git a/tensorflow/core/util/sparse/group_iterator.cc b/tensorflow/core/util/sparse/group_iterator.cc
index 204b933051..546b0a833c 100644
--- a/tensorflow/core/util/sparse/group_iterator.cc
+++ b/tensorflow/core/util/sparse/group_iterator.cc
@@ -21,8 +21,8 @@ namespace sparse {
void GroupIterable::IteratorStep::UpdateEndOfGroup() {
++next_loc_;
- int64 N = iter_->ix_.dim_size(0);
- auto ix_t = iter_->ix_.template matrix<int64>();
+ const auto& ix_t = iter_->ix_matrix_;
+ const int64 N = ix_t.dimension(0);
while (next_loc_ < N && iter_->GroupMatches(ix_t, loc_, next_loc_)) {
++next_loc_;
}
@@ -54,7 +54,7 @@ GroupIterable::IteratorStep GroupIterable::IteratorStep::operator++(
std::vector<int64> Group::group() const {
std::vector<int64> g;
- auto ix_t = iter_->ix_.template matrix<int64>();
+ const auto& ix_t = iter_->ix_matrix_;
for (const int d : iter_->group_dims_) {
g.push_back(ix_t(loc_, d));
}
@@ -62,8 +62,8 @@ std::vector<int64> Group::group() const {
}
TTypes<int64>::UnalignedConstMatrix Group::indices() const {
- return TTypes<int64>::UnalignedConstMatrix(
- &(iter_->ix_.matrix<int64>()(loc_, 0)), next_loc_ - loc_, iter_->dims_);
+ return TTypes<int64>::UnalignedConstMatrix(&(iter_->ix_matrix_(loc_, 0)),
+ next_loc_ - loc_, iter_->dims_);
}
} // namespace sparse
diff --git a/tensorflow/core/util/sparse/group_iterator.h b/tensorflow/core/util/sparse/group_iterator.h
index 3fa8cb6116..14610c61d9 100644
--- a/tensorflow/core/util/sparse/group_iterator.h
+++ b/tensorflow/core/util/sparse/group_iterator.h
@@ -79,6 +79,7 @@ class GroupIterable {
GroupIterable(Tensor ix, Tensor vals, int dims, const VarDimArray& group_dims)
: ix_(ix),
+ ix_matrix_(ix_.matrix<int64>()),
vals_(vals),
dims_(dims),
group_dims_(group_dims.begin(), group_dims.end()) {}
@@ -127,7 +128,8 @@ class GroupIterable {
private:
friend class Group;
- Tensor ix_;
+ const Tensor ix_;
+ const TTypes<int64>::ConstMatrix ix_matrix_;
Tensor vals_;
const int dims_;
const gtl::InlinedVector<int64, 8> group_dims_;
diff --git a/tensorflow/core/util/sparse/sparse_tensor.h b/tensorflow/core/util/sparse/sparse_tensor.h
index 0f04b65f60..b9ca8ab395 100644
--- a/tensorflow/core/util/sparse/sparse_tensor.h
+++ b/tensorflow/core/util/sparse/sparse_tensor.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <numeric>
#include <vector>
+#include "absl/base/macros.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_types.h"
@@ -95,21 +96,21 @@ class SparseTensor {
SparseTensor() : dims_(0) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape)
: SparseTensor(ix, vals, TensorShapeToVector(shape),
UndefinedOrder(TensorShapeToVector(shape))) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape)
: SparseTensor(ix, vals, shape, UndefinedOrder(shape)) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const TensorShape& shape,
const VarDimArray order)
: SparseTensor(ix, vals, TensorShapeToVector(shape), order) {}
- // DEPRECATED: use Create() functions instead of constructors directly.
+ ABSL_DEPRECATED("Use Create() functions instead of constructors directly.")
SparseTensor(Tensor ix, Tensor vals, const VarDimArray shape,
const VarDimArray order)
: ix_(ix),
@@ -237,9 +238,10 @@ class SparseTensor {
static Status Split(const SparseTensor& tensor, const int split_dim,
const int num_split, std::vector<SparseTensor>* result);
- // DEPRECATED: use the form of Split() that takes an output pointer and
- // returns a status instead.
template <typename T>
+ ABSL_DEPRECATED(
+ "Use the form of Split() that takes an output pointer and returns a "
+ "status instead.")
static std::vector<SparseTensor> Split(const SparseTensor& tensor,
const int split_dim,
const int num_split,
diff --git a/tensorflow/core/util/strided_slice_op.cc b/tensorflow/core/util/strided_slice_op.cc
index aca60b942d..ad8a44a518 100644
--- a/tensorflow/core/util/strided_slice_op.cc
+++ b/tensorflow/core/util/strided_slice_op.cc
@@ -326,7 +326,7 @@ Status ValidateStridedSliceOp(
// Even if we don't have values for begin or end, we do know that this
// dimension covers the whole interval. If we have shape information for
// this dimension, that tells us the interval length.
- if (dim_i > 0) {
+ if (dim_i >= 0) {
if (stride_i < 0) {
interval_length = -dim_i;
} else {
diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD
index 648358606c..f40ec9b752 100644
--- a/tensorflow/core/util/tensor_bundle/BUILD
+++ b/tensorflow/core/util/tensor_bundle/BUILD
@@ -64,6 +64,11 @@ cc_library(
tf_cc_test(
name = "tensor_bundle_test",
srcs = ["tensor_bundle_test.cc"],
+ data = glob(["testdata/**"]),
+ tags = [
+ "nomsan",
+ "notsan",
+ ],
deps = [
":tensor_bundle",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/util/tensor_bundle/naming.h b/tensorflow/core/util/tensor_bundle/naming.h
index 3d21570c74..7b101971a8 100644
--- a/tensorflow/core/util/tensor_bundle/naming.h
+++ b/tensorflow/core/util/tensor_bundle/naming.h
@@ -31,10 +31,11 @@ limitations under the License.
//
// Regexp can also be used: e.g. R"<prefix>.data-\d{5}-of-\d{5}" for data files.
-#ifndef TENSORFLOW_UTIL_TENSOR_BUNDLE_NAMING_H_
-#define TENSORFLOW_UTIL_TENSOR_BUNDLE_NAMING_H_
+#ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_NAMING_H_
+#define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_NAMING_H_
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -43,4 +44,4 @@ string DataFilename(StringPiece prefix, int32 shard_id, int32 num_shards);
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_TENSOR_BUNDLE_NAMING_H_
+#endif // TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_NAMING_H_
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
index 7190614706..2dcb57a1f9 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.cc
@@ -64,27 +64,36 @@ namespace {
// Reads "num_elements" string elements from file[offset, offset+size) into the
// length-N "destination". Discards the original content of "destination".
//
-// Checksums the string lengths (as restored uint32, not varint32 bytes) and
-// string bytes, and stores it into "actual_crc32c".
+// Checksums the string lengths (as restored uint32 or uint64, not varint64
+// bytes) and string bytes, and stores it into "actual_crc32c".
Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
size_t offset, size_t size, string* destination,
uint32* actual_crc32c) {
if (size == 0) return Status::OK();
CHECK_GT(size, 0);
- // Reads "num_elements" varint32's from "buffered_file".
+ // Reads "num_elements" varint64's from "buffered_file".
TF_RETURN_IF_ERROR(buffered_file->Seek(offset));
- std::vector<uint32> string_lengths(num_elements);
+ std::vector<uint64> string_lengths(num_elements);
for (size_t i = 0; i < num_elements; ++i) {
- TF_RETURN_IF_ERROR(buffered_file->ReadVarint32(&string_lengths[i]));
+ TF_RETURN_IF_ERROR(buffered_file->ReadVarint64(&string_lengths[i]));
+ if (string_lengths[i] <= UINT32_MAX) {
+ // We need to do this because older checkpoints only used uint32s and we
+ // should still support them.
+ const uint32 elem_size_uint32 = static_cast<uint32>(string_lengths[i]);
+ *actual_crc32c = crc32c::Extend(
+ *actual_crc32c, reinterpret_cast<const char*>(&elem_size_uint32),
+ sizeof(uint32));
+ } else {
+ *actual_crc32c = crc32c::Extend(
+ *actual_crc32c, reinterpret_cast<const char*>(&string_lengths[i]),
+ sizeof(uint64));
+ }
}
if (offset + size < buffered_file->Tell()) {
return errors::DataLoss("String lengths longer than expected offset ",
offset + size);
}
- *actual_crc32c =
- crc32c::Value(reinterpret_cast<const char*>(string_lengths.data()),
- sizeof(uint32) * num_elements);
// Reads the length-checksum.
uint32 length_checksum = 0;
@@ -104,7 +113,7 @@ Status ReadStringTensor(io::InputBuffer* buffered_file, size_t num_elements,
// Reads the actual string bytes.
for (size_t i = 0; i < num_elements; ++i) {
- const uint32 string_length = string_lengths[i];
+ const uint64 string_length = string_lengths[i];
string* buffer = &destination[i];
buffer->resize(string_length);
@@ -218,8 +227,8 @@ Status WriteTensor(const Tensor& val, FileOutputBuffer* out,
Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
size_t* bytes_written, uint32* crc32c) {
// On-disk format:
- // [varint32 len0]..[varint32 lenL][4 byte cksum on lengths][string bytes]
- // Var "crc32c" checksums the string lengths (as uint32, not varint32 bytes),
+ // [varint64 len0]..[varint64 lenL][4 byte cksum on lengths][string bytes]
+ // Var "crc32c" checksums the string lengths (as uint64, not varint64 bytes),
// the length-checksum, and all the string bytes.
DCHECK_EQ(val.dtype(), DT_STRING);
const string* strings = GetStringBackingBuffer(val);
@@ -230,12 +239,21 @@ Status WriteStringTensor(const Tensor& val, FileOutputBuffer* out,
*crc32c = 0;
for (int64 i = 0; i < val.NumElements(); ++i) {
const string* elem = &strings[i];
- DCHECK_EQ(elem->size(), static_cast<uint32>(elem->size()));
- const uint32 elem_size = static_cast<uint32>(elem->size());
-
- core::PutVarint32(&lengths, elem_size);
- *crc32c = crc32c::Extend(*crc32c, reinterpret_cast<const char*>(&elem_size),
- sizeof(uint32));
+ DCHECK_EQ(elem->size(), static_cast<uint64>(elem->size()));
+ const uint64 elem_size = static_cast<uint64>(elem->size());
+
+ core::PutVarint64(&lengths, elem_size);
+ if (elem_size <= UINT32_MAX) {
+ // We need to do this because older checkpoints only used uint32s and we
+ // should still support them.
+ const uint32 elem_size_uint32 = static_cast<uint32>(elem_size);
+ *crc32c = crc32c::Extend(*crc32c,
+ reinterpret_cast<const char*>(&elem_size_uint32),
+ sizeof(uint32));
+ } else {
+ *crc32c = crc32c::Extend(
+ *crc32c, reinterpret_cast<const char*>(&elem_size), sizeof(uint64));
+ }
}
TF_RETURN_IF_ERROR(out->Append(lengths));
*bytes_written = lengths.size();
@@ -370,14 +388,14 @@ Status PadAlignment(FileOutputBuffer* out, int alignment, int64* size) {
BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
: env_(env),
options_(options),
- prefix_(std::string(prefix)),
+ prefix_(prefix),
tmp_metadata_path_(strings::StrCat(MetaFilename(prefix_), ".tempstate",
random::New64())),
tmp_data_path_(strings::StrCat(DataFilename(prefix_, 0, 1), ".tempstate",
random::New64())),
out_(nullptr),
size_(0) {
- status_ = env_->CreateDir(std::string(io::Dirname(prefix_)));
+ status_ = env_->CreateDir(string(io::Dirname(prefix_)));
if (!status_.ok() && !errors::IsAlreadyExists(status_)) {
return;
}
@@ -394,7 +412,7 @@ BundleWriter::BundleWriter(Env* env, StringPiece prefix, const Options& options)
Status BundleWriter::Add(StringPiece key, const Tensor& val) {
if (!status_.ok()) return status_;
CHECK_NE(key, kHeaderEntryKey);
- const string key_string = std::string(key);
+ const string key_string(key);
if (entries_.find(key_string) != entries_.end()) {
status_ = errors::InvalidArgument("Adding duplicate key: ", key);
return status_;
@@ -445,7 +463,7 @@ Status BundleWriter::AddSlice(StringPiece full_tensor_key,
// In the case of a sharded save, MergeBundles() is responsible for merging
// the "slices" field of multiple metadata entries corresponding to the same
// full tensor.
- const string full_tensor_key_string = std::string(full_tensor_key);
+ const string full_tensor_key_string(full_tensor_key);
BundleEntryProto* full_entry = &entries_[full_tensor_key_string];
if (full_entry->dtype() != DT_INVALID) {
CHECK_EQ(full_entry->dtype(), slice_tensor.dtype());
@@ -600,7 +618,7 @@ static Status MergeOneBundle(Env* env, StringPiece prefix,
// Loops through the non-header to-merge entries.
BundleEntryProto to_merge_entry;
for (; iter->Valid(); iter->Next()) {
- const string key = std::string(iter->key());
+ const string key(iter->key());
const auto entry_iter = merge_state->entries.find(key);
// Illegal: the duplicated entry is a non-slice tensor.
@@ -649,7 +667,7 @@ Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
// Merges all metadata tables.
// TODO(zhifengc): KeyValue sorter if it becomes too big.
MergeState merge;
- Status status = env->CreateDir(std::string(io::Dirname(merged_prefix)));
+ Status status = env->CreateDir(string(io::Dirname(merged_prefix)));
if (!status.ok() && !errors::IsAlreadyExists(status)) return status;
for (int i = 0; i < prefixes.size(); ++i) {
TF_RETURN_IF_ERROR(MergeOneBundle(env, prefixes[i], &merge));
@@ -697,7 +715,7 @@ Status MergeBundles(Env* env, gtl::ArraySlice<string> prefixes,
BundleReader::BundleReader(Env* env, StringPiece prefix)
: env_(env),
- prefix_(std::string(prefix)),
+ prefix_(prefix),
metadata_(nullptr),
table_(nullptr),
iter_(nullptr) {
@@ -919,7 +937,7 @@ Status BundleReader::GetSliceValue(StringPiece full_tensor_key,
const TensorShape full_shape(TensorShape(full_tensor_entry.shape()));
std::vector<std::pair<TensorSlice, string>> details;
- const string full_tensor_key_string = std::string(full_tensor_key);
+ const string full_tensor_key_string(full_tensor_key);
const TensorSliceSet* tss =
gtl::FindPtrOrNull(tensor_slices_, full_tensor_key_string);
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle.h b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
index d30ce3f0cf..3a2ffbb495 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle.h
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle.h
@@ -58,8 +58,8 @@ limitations under the License.
// "/fs/model/train/ckpt-step/ckpt" /* merged prefix */);
//
-#ifndef TENSORFLOW_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
-#define TENSORFLOW_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
+#ifndef TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
+#define TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
#include "tensorflow/core/protobuf/tensor_bundle.pb.h"
@@ -346,4 +346,4 @@ class FileOutputBuffer {
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
+#endif // TENSORFLOW_CORE_UTIL_TENSOR_BUNDLE_TENSOR_BUNDLE_H_
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index 92ce8ae00e..9567e4750b 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -39,6 +39,11 @@ string Prefix(const string& prefix) {
return strings::StrCat(testing::TmpDir(), "/", prefix);
}
+string TestdataPrefix(const string& prefix) {
+ return strings::StrCat(testing::TensorFlowSrcRoot(),
+ "/core/util/tensor_bundle/testdata/", prefix);
+}
+
template <typename T>
Tensor Constant(T v, TensorShape shape) {
Tensor ret(DataTypeToEnum<T>::value, shape);
@@ -107,7 +112,7 @@ std::vector<string> AllTensorKeys(BundleReader* reader) {
reader->Seek(kHeaderEntryKey);
reader->Next();
for (; reader->Valid(); reader->Next()) {
- ret.push_back(std::string(reader->key()));
+ ret.emplace_back(reader->key());
}
return ret;
}
@@ -458,7 +463,26 @@ TEST(TensorBundleTest, NonStandardShapes) {
TestNonStandardShapes<qint8>();
}
+TEST(TensorBundleTest, StringTensorsOldFormat) {
+ // Test a string tensor bundle written by a previous version of the code that
+ // used varint32s to store string lengths (we now use varint64s).
+ BundleReader reader(Env::Default(), TestdataPrefix("old_string_tensors/foo"));
+ TF_ASSERT_OK(reader.status());
+ EXPECT_EQ(AllTensorKeys(&reader),
+ std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));
+
+ Expect<string>(&reader, "string_tensor", Tensor(DT_STRING, TensorShape({1})));
+ Expect<string>(&reader, "scalar", test::AsTensor<string>({"hello"}));
+ Expect<string>(
+ &reader, "strs",
+ test::AsTensor<string>({"hello", "", "x01", string(1 << 10, 'c')}));
+ Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
+}
+
TEST(TensorBundleTest, StringTensors) {
+ constexpr size_t kLongLength = static_cast<size_t>(UINT32_MAX) + 1;
+ Tensor long_string_tensor(DT_STRING, TensorShape({1}));
+
{
BundleWriter writer(Env::Default(), Prefix("foo"));
TF_EXPECT_OK(writer.Add("string_tensor",
@@ -467,6 +491,12 @@ TEST(TensorBundleTest, StringTensors) {
TF_EXPECT_OK(writer.Add(
"strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')})));
+
+ // Requires a 64-bit length.
+ string* backing_string = long_string_tensor.flat<string>().data();
+ backing_string->assign(kLongLength, 'd');
+ TF_EXPECT_OK(writer.Add("long_scalar", long_string_tensor));
+
// Mixes in some floats.
TF_EXPECT_OK(writer.Add("floats", Constant_2x3<float>(16.18)));
TF_ASSERT_OK(writer.Finish());
@@ -474,9 +504,9 @@ TEST(TensorBundleTest, StringTensors) {
{
BundleReader reader(Env::Default(), Prefix("foo"));
TF_ASSERT_OK(reader.status());
- EXPECT_EQ(
- AllTensorKeys(&reader),
- std::vector<string>({"floats", "scalar", "string_tensor", "strs"}));
+ EXPECT_EQ(AllTensorKeys(&reader),
+ std::vector<string>({"floats", "long_scalar", "scalar",
+ "string_tensor", "strs"}));
Expect<string>(&reader, "string_tensor",
Tensor(DT_STRING, TensorShape({1})));
@@ -484,7 +514,35 @@ TEST(TensorBundleTest, StringTensors) {
Expect<string>(
&reader, "strs",
test::AsTensor<string>({"hello", "", "x01", string(1 << 25, 'c')}));
+
Expect<float>(&reader, "floats", Constant_2x3<float>(16.18));
+
+ // We don't use the Expect helper here so that we can re-use the
+ // `long_string_tensor` buffer when reading out long_scalar, which keeps
+ // memory usage reasonable.
+ EXPECT_TRUE(reader.Contains("long_scalar"));
+ DataType dtype;
+ TensorShape shape;
+ TF_ASSERT_OK(reader.LookupDtypeAndShape("long_scalar", &dtype, &shape));
+ EXPECT_EQ(DT_STRING, dtype);
+ EXPECT_EQ(TensorShape({1}), shape);
+
+ // Zero out the string so that we can be sure the new contents are read in.
+ string* backing_string = long_string_tensor.flat<string>().data();
+ backing_string->assign("");
+
+ // Read long_scalar and check it contains kLongLength 'd's.
+ TF_ASSERT_OK(reader.Lookup("long_scalar", &long_string_tensor));
+ ASSERT_EQ(backing_string, long_string_tensor.flat<string>().data());
+ EXPECT_EQ(kLongLength, backing_string->length());
+ for (char c : *backing_string) {
+ // Not using ASSERT_EQ('d', c) because this way is twice as fast due to
+ // compiler optimizations.
+ if (c != 'd') {
+ FAIL() << "long_scalar is not full of 'd's as expected.";
+ break;
+ }
+ }
}
}
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README
new file mode 100644
index 0000000000..428d3ef79e
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/README
@@ -0,0 +1,3 @@
+This tensor bundle was generated from cl/214343133, before string tensor
+lengths were written as varint64s. This is here to check backwards
+compatibility between the new code and old checkpoints.
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001 b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001
new file mode 100644
index 0000000000..23b488e5fe
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.data-00000-of-00001
Binary files differ
diff --git a/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index
new file mode 100644
index 0000000000..a22a69e6e1
--- /dev/null
+++ b/tensorflow/core/util/tensor_bundle/testdata/old_string_tensors/foo.index
Binary files differ
diff --git a/tensorflow/core/util/tensor_format.cc b/tensorflow/core/util/tensor_format.cc
index a5f7ecf0d1..f331973f5c 100644
--- a/tensorflow/core/util/tensor_format.cc
+++ b/tensorflow/core/util/tensor_format.cc
@@ -25,6 +25,10 @@ string GetConvnet3dDataFormatAttrString() {
return "data_format: { 'NDHWC', 'NCDHW' } = 'NDHWC' ";
}
+string GetConvnetDataFormat2D3DAttrString() {
+ return "data_format: { 'NHWC', 'NCHW', 'NDHWC', 'NCDHW' } = 'NHWC' ";
+}
+
string GetConvnetFilterFormatAttrString() {
return "filter_format: { 'HWIO', 'OIHW' } = 'HWIO' ";
}
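
A minimal sketch of how the new attr string would typically be consumed, assuming the usual REGISTER_OP pattern (the op name below is hypothetical): the returned spec constrains data_format to the listed 2-D and 3-D layouts, defaulting to 'NHWC'.

    #include "tensorflow/core/framework/op.h"
    #include "tensorflow/core/util/tensor_format.h"

    namespace tensorflow {

    // Hypothetical registration: accepts both 2-D ('NHWC', 'NCHW') and 3-D
    // ('NDHWC', 'NCDHW') layouts, with 'NHWC' as the default.
    REGISTER_OP("ExampleConvLikeOp")
        .Input("input: float")
        .Output("output: float")
        .Attr(GetConvnetDataFormat2D3DAttrString());

    }  // namespace tensorflow
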
diff --git a/tensorflow/core/util/tensor_format.h b/tensorflow/core/util/tensor_format.h
index 918835e1fb..b0c349dd90 100644
--- a/tensorflow/core/util/tensor_format.h
+++ b/tensorflow/core/util/tensor_format.h
@@ -483,6 +483,7 @@ string GetConvnet3dDataFormatAttrString();
// Return the string that specifies the filter format for convnet operations.
string GetConvnetFilterFormatAttrString();
string GetConvnet3dFilterFormatAttrString();
+string GetConvnetDataFormat2D3DAttrString();
// Returns a tensor shape for the specified format and dimension sizes.
// Works for both 2D and 3D operations. The output shapes are as follows:
diff --git a/tensorflow/core/util/tensor_slice_reader.h b/tensorflow/core/util/tensor_slice_reader.h
index 263f56c7fc..4aa9a4708e 100644
--- a/tensorflow/core/util/tensor_slice_reader.h
+++ b/tensorflow/core/util/tensor_slice_reader.h
@@ -16,8 +16,8 @@ limitations under the License.
// The utility to read checkpoints for google brain tensor ops and v3
// checkpoints for dist_belief.
-#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
-#define TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
+#ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_
+#define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_
#include <unordered_map>
@@ -192,4 +192,4 @@ bool TensorSliceReader::CopySliceData(const string& name,
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_TENSOR_SLICE_READER_H_
+#endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_H_
diff --git a/tensorflow/core/util/tensor_slice_reader_cache.h b/tensorflow/core/util/tensor_slice_reader_cache.h
index 63a8d0b068..9f1919df4e 100644
--- a/tensorflow/core/util/tensor_slice_reader_cache.h
+++ b/tensorflow/core/util/tensor_slice_reader_cache.h
@@ -16,8 +16,8 @@ limitations under the License.
// The utility to read checkpoints for google brain tensor ops and v3
// checkpoints for dist_belief.
-#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_
-#define TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_
+#ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_
+#define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_
#include <unordered_map>
@@ -85,4 +85,4 @@ class TensorSliceReaderCache {
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_TENSOR_SLICE_READER_CACHE_H_
+#endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_READER_CACHE_H_
diff --git a/tensorflow/core/util/tensor_slice_writer.h b/tensorflow/core/util/tensor_slice_writer.h
index 2888c66d10..0db2fb4804 100644
--- a/tensorflow/core/util/tensor_slice_writer.h
+++ b/tensorflow/core/util/tensor_slice_writer.h
@@ -16,8 +16,8 @@ limitations under the License.
// The utility to write checkpoints for google brain tensor ops and v3
// checkpoints for dist_belief.
-#ifndef TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_
-#define TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_
+#ifndef TENSORFLOW_CORE_UTIL_TENSOR_SLICE_WRITER_H_
+#define TENSORFLOW_CORE_UTIL_TENSOR_SLICE_WRITER_H_
#include <unordered_map>
@@ -192,4 +192,4 @@ Status CreateTableTensorSliceBuilder(const string& filename,
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_TENSOR_SLICE_WRITER_H_
+#endif // TENSORFLOW_CORE_UTIL_TENSOR_SLICE_WRITER_H_
diff --git a/tensorflow/core/util/util.cc b/tensorflow/core/util/util.cc
index 1e5a9c5712..489999d1e8 100644
--- a/tensorflow/core/util/util.cc
+++ b/tensorflow/core/util/util.cc
@@ -120,4 +120,20 @@ string SliceDebugString(const TensorShape& shape, const int64 flat) {
return result;
}
+#ifdef INTEL_MKL
+bool DisableMKL() {
+ enum MklStatus { MKL_DEFAULT = 0, MKL_ON = 1, MKL_OFF = 2 };
+ static MklStatus status = MKL_DEFAULT;
+ if (status == MKL_DEFAULT) {
+ char* tf_disable_mkl = getenv("TF_DISABLE_MKL");
+ if ((tf_disable_mkl != NULL) && (std::stoi(tf_disable_mkl) == 1)) {
+ VLOG(2) << "TF-MKL: Disabling MKL";
+ status = MKL_OFF;
+ } else {
+ status = MKL_ON;
+ }
+ }
+ return status == MKL_OFF;
+}
+#endif // INTEL_MKL
} // namespace tensorflow
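
A minimal usage sketch, assuming an MKL-enabled build: a call site that selects between an MKL code path and the default path can consult DisableMKL(), which reads TF_DISABLE_MKL from the environment once and caches the result. The function below is hypothetical; DisableMKL() and the INTEL_MKL guard come from the hunk above.

    #include "tensorflow/core/util/util.h"

    namespace tensorflow {

    // Hypothetical call site: prefer MKL kernels unless disabled at runtime.
    bool ShouldUseMklKernels() {
    #ifdef INTEL_MKL
      // Setting TF_DISABLE_MKL=1 in the environment turns MKL off at runtime.
      return !DisableMKL();
    #else
      return false;
    #endif  // INTEL_MKL
    }

    }  // namespace tensorflow
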
diff --git a/tensorflow/core/util/util.h b/tensorflow/core/util/util.h
index 4adf2f14dc..4aa47aa48a 100644
--- a/tensorflow/core/util/util.h
+++ b/tensorflow/core/util/util.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_UTIL_H_
-#define TENSORFLOW_UTIL_UTIL_H_
+#ifndef TENSORFLOW_CORE_UTIL_UTIL_H_
+#define TENSORFLOW_CORE_UTIL_UTIL_H_
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -56,6 +56,11 @@ string PrintMemory(const char* ptr, size_t n);
// "tensor", "tensor[i]", "tensor[i, j]", etc.
string SliceDebugString(const TensorShape& shape, const int64 flat);
+// Returns true if MKL should be disabled at runtime (TF_DISABLE_MKL=1).
+#ifdef INTEL_MKL
+bool DisableMKL();
+#endif // INTEL_MKL
+
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_UTIL_H_
+#endif // TENSORFLOW_CORE_UTIL_UTIL_H_
diff --git a/tensorflow/core/util/work_sharder.cc b/tensorflow/core/util/work_sharder.cc
index f4bd2950e9..74f0713a61 100644
--- a/tensorflow/core/util/work_sharder.cc
+++ b/tensorflow/core/util/work_sharder.cc
@@ -50,6 +50,8 @@ void Shard(int max_parallelism, thread::ThreadPool* workers, int64 total,
max_parallelism);
}
+// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you
+// to directly specify the shard size.
void Sharder::Do(int64 total, int64 cost_per_unit, const Work& work,
const Runner& runner, int max_parallelism) {
cost_per_unit = std::max(int64{1}, cost_per_unit);
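
As an illustration of the deprecation note above, a hedged sketch comparing the two APIs: Shard() derives shard sizes from a per-unit cost estimate, while ThreadPool::TransformRangeConcurrently takes an explicit block size. The ThreadPool method's signature is assumed from the comment rather than shown in this diff, and the work body is a placeholder.

    #include "tensorflow/core/lib/core/threadpool.h"
    #include "tensorflow/core/util/work_sharder.h"

    namespace tensorflow {

    // Sketch only: process `total` items with the legacy and preferred APIs.
    void ProcessItems(thread::ThreadPool* workers, int64 total) {
      // Legacy path: the shard size is derived from the cost estimate.
      Shard(workers->NumThreads(), workers, total, /*cost_per_unit=*/1000,
            [](int64 start, int64 limit) {
              for (int64 i = start; i < limit; ++i) {
                // ... work on item i ...
              }
            });

      // Preferred path per the comment above; signature assumed to be
      // TransformRangeConcurrently(block_size, total, fn).
      workers->TransformRangeConcurrently(
          /*block_size=*/256, total, [](int64 start, int64 limit) {
            for (int64 i = start; i < limit; ++i) {
              // ... work on item i ...
            }
          });
    }

    }  // namespace tensorflow
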
diff --git a/tensorflow/core/util/work_sharder.h b/tensorflow/core/util/work_sharder.h
index 72ce493c1b..9db85a54c6 100644
--- a/tensorflow/core/util/work_sharder.h
+++ b/tensorflow/core/util/work_sharder.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_UTIL_WORK_SHARDER_H_
-#define TENSORFLOW_UTIL_WORK_SHARDER_H_
+#ifndef TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_
+#define TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_
#include <functional>
@@ -23,6 +23,9 @@ limitations under the License.
namespace tensorflow {
+// DEPRECATED: Prefer threadpool->TransformRangeConcurrently, which allows you
+// to directly specify the shard size. Use this function only if you want to
+// manually cap parallelism.
// Shards the "total" unit of work assuming each unit of work having
// roughly "cost_per_unit". Each unit of work is indexed 0, 1, ...,
// total - 1. Each shard contains 1 or more units of work and the
@@ -95,4 +98,4 @@ class Sharder {
} // end namespace tensorflow
-#endif // TENSORFLOW_UTIL_WORK_SHARDER_H_
+#endif // TENSORFLOW_CORE_UTIL_WORK_SHARDER_H_