diff options
author | 2015-12-08 09:58:59 -0800 | |
---|---|---|
committer | 2015-12-08 09:58:59 -0800 | |
commit | ddd4aaf5286de24ba70402ee0ec8b836d3aed8c7 (patch) | |
tree | 4efdf6cf4d69b45041fd2a02cd2b7327ea9f1f58 /tensorflow/stream_executor | |
parent | cd53f3c3302c9312c1840389a9988a879b8b9dd5 (diff) |
TensorFlow: upstream changes to git.
Change 109695551
Update FAQ
Change 109694725
Add a gradient for resize_bilinear op.
Change 109694505
Don't mention variables module in docs
variables.Variable should be tf.Variable.
Change 109658848
Adding an option to create a new thread-pool for each session.
Change 109640570
Take the snapshot of stream-executor.
+ Expose an interface for scratch space allocation in the interface.
Change 109638559
Let image_summary accept uint8 input
This allows users to do their own normalization / scaling if the default
(very weird) behavior of image_summary is undesired.
This required a slight tweak to fake_input.cc to make polymorphically typed
fake inputs infer if their type attr is not set but has a default.
Unfortunately, adding a second valid type to image_summary *disables* automatic
implicit conversion from np.float64 to tf.float32, so this change is slightly
backwards incompatible.
Change 109636969
Add serialization operations for SparseTensor.
Change 109636644
Update generated Op docs.
Change 109634899
TensorFlow: add a markdown file for producing release notes for our
releases. Seed with 0.5.0 with a boring but accurate description.
Change 109634502
Let histogram_summary take any realnumbertype
It used to take only floats, not it understands ints.
Change 109634434
TensorFlow: update locations where we mention python 3 support, update
them to current truth.
Change 109632108
Move HSV <> RGB conversions, grayscale conversions, and adjust_* ops back to tensorflow
- make GPU-capable version of RGBToHSV and HSVToRGB, allows only float input/output
- change docs to reflect new size constraints
- change HSV format to be [0,1] for all components
- add automatic dtype conversion for all adjust_* and grayscale conversion ops
- fix up docs
Change 109631077
Improve optimizer exceptions
1. grads_and_vars is now a tuple, so must be wrapped when passed to format.
2. Use '%r' instead of '%s' for dtype formatting
Base CL: 109697989
Diffstat (limited to 'tensorflow/stream_executor')
48 files changed, 1240 insertions, 541 deletions
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h index be81683166..94475817e0 100644 --- a/tensorflow/stream_executor/blas.h +++ b/tensorflow/stream_executor/blas.h @@ -1140,7 +1140,7 @@ class BlasSupport { // Macro used to quickly declare overrides for abstract virtuals in the // BlasSupport base class. -#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \ +#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \ bool DoBlasAsum(Stream *stream, uint64 elem_count, \ const DeviceMemory<float> &x, int incx, \ DeviceMemory<float> *result) override; \ diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc index d2d84f4ea0..19ad12d28b 100644 --- a/tensorflow/stream_executor/cuda/cuda_blas.cc +++ b/tensorflow/stream_executor/cuda/cuda_blas.cc @@ -22,7 +22,8 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_activation.h" #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" #include "tensorflow/stream_executor/cuda/cuda_helpers.h" -#include "tensorflow/stream_executor/cuda/cuda_platform.h" +#include "tensorflow/stream_executor/cuda/cuda_platform_id.h" +#include "tensorflow/stream_executor/cuda/cuda_stream.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/dso_loader.h" #include "tensorflow/stream_executor/lib/initialize.h" diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc index a3f1820094..c2ae035c96 100644 --- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc +++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc @@ -29,16 +29,17 @@ limitations under the License. #include <memory> #include <vector> -#include "tensorflow/stream_executor/lib/error.h" -#include "tensorflow/stream_executor/lib/inlined_vector.h" -#include "tensorflow/stream_executor/lib/numbers.h" #include "tensorflow/stream_executor/lib/process_state.h" +#include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/lib/status.h" #include "tensorflow/stream_executor/lib/str_util.h" #include "tensorflow/stream_executor/lib/strcat.h" #include "tensorflow/stream_executor/lib/stringpiece.h" #include "tensorflow/stream_executor/lib/stringprintf.h" #include "tensorflow/stream_executor/platform/logging.h" +#include "tensorflow/stream_executor/lib/numbers.h" +#include "tensorflow/stream_executor/lib/str_util.h" +#include "tensorflow/stream_executor/lib/inlined_vector.h" namespace perftools { namespace gputools { @@ -113,7 +114,6 @@ void Diagnostician::LogDiagnosticInformation() { LOG(INFO) << "retrieving CUDA diagnostic information for host: " << port::Hostname(); - LogDriverVersionInformation(); } diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.h b/tensorflow/stream_executor/cuda/cuda_diagnostics.h index 15387c4161..42336c337f 100644 --- a/tensorflow/stream_executor/cuda/cuda_diagnostics.h +++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_ #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_ +#include "tensorflow/stream_executor/platform/port.h" #include <tuple> #include "tensorflow/stream_executor/lib/statusor.h" diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index b73a9f9ce2..880d52b589 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -18,6 +18,12 @@ limitations under the License. #include <dlfcn.h> #include <functional> +#include "tensorflow/stream_executor/cuda/cuda_activation.h" +#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" +#include "tensorflow/stream_executor/cuda/cuda_driver.h" +#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" +#include "tensorflow/stream_executor/cuda/cuda_platform_id.h" +#include "tensorflow/stream_executor/cuda/cuda_stream.h" #include "tensorflow/stream_executor/dnn.h" #include "tensorflow/stream_executor/dso_loader.h" #include "tensorflow/stream_executor/lib/env.h" @@ -27,13 +33,9 @@ limitations under the License. #include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/plugin_registry.h" +#include "tensorflow/stream_executor/scratch_allocator.h" #include "tensorflow/stream_executor/stream.h" #include "tensorflow/stream_executor/stream_executor_pimpl.h" -#include "tensorflow/stream_executor/cuda/cuda_activation.h" -#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" -#include "tensorflow/stream_executor/cuda/cuda_driver.h" -#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" -#include "tensorflow/stream_executor/cuda/cuda_platform.h" #include "third_party/gpus/cuda/include/cudnn.h" namespace { @@ -62,8 +64,6 @@ namespace cuda { PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin); -extern CUstream AsCUDAStreamValue(Stream* stream); - string ToString(cudnnStatus_t status) { switch (status) { case CUDNN_STATUS_SUCCESS: @@ -117,14 +117,34 @@ static port::ThreadPool* GetCudaThreadpool() { return cudnn_threadpool; } +// Retrieves the CUDNN DSO, dies on failure. +void* GetDsoHandle() { + static auto result = internal::CachedDsoLoader::GetCudnnDsoHandle(); + return result.ValueOrDie(); +} + +// Calls cudnnGetVersion in the loaded DSO. +size_t cudnnGetVersion() { + static void* f = dlsym(GetDsoHandle(), "cudnnGetVersion"); + if (f == nullptr) { + LOG(FATAL) << "could not find cudnnGetVersion in cudnn DSO; dlerror: " + << dlerror(); + } + auto callable = reinterpret_cast<size_t (*)(void)>(f); + return callable(); +} + +// Returns whether the currently loaded cuDNN version is R2. +bool IsCudnnR2() { + static auto version = cudnnGetVersion(); + DCHECK_GE(version, 2000); + return version < 3000; +} + #define PERFTOOLS_GPUTOOLS_CUDNN_WRAP(__name) \ struct DynLoadShim__##__name { \ static const char* kName; \ typedef std::add_pointer<decltype(::__name)>::type FuncPointerT; \ - static void* GetDsoHandle() { \ - static auto result = internal::CachedDsoLoader::GetCudnnDsoHandle(); \ - return result.ValueOrDie(); \ - } \ static FuncPointerT DynLoad() { \ static void* f = dlsym(GetDsoHandle(), kName); \ if (f == nullptr) { \ @@ -154,41 +174,53 @@ static port::ThreadPool* GetCudaThreadpool() { } __name; \ const char* DynLoadShim__##__name::kName = #__name; -#define CUDNN_DNN_ROUTINE_EACH(__macro) \ - __macro(cudnnSetTensor4dDescriptor) __macro( \ - cudnnGetConvolutionNdForwardOutputDim) \ - __macro(cudnnGetConvolutionForwardAlgorithm) __macro( \ - cudnnCreateTensorDescriptor) __macro(cudnnDestroyTensorDescriptor) \ - __macro(cudnnCreateFilterDescriptor) \ - __macro(cudnnSetFilter4dDescriptor) \ - __macro(cudnnSetPooling2dDescriptor) \ - __macro(cudnnDestroyFilterDescriptor) \ - __macro(cudnnCreateConvolutionDescriptor) \ - __macro(cudnnCreatePoolingDescriptor) \ - __macro(cudnnAddTensor) \ - __macro(cudnnDestroyPoolingDescriptor) +// clang-format off +#define CUDNN_DNN_ROUTINE_EACH(__macro) \ + __macro(cudnnSetTensor4dDescriptor) \ + __macro(cudnnGetConvolutionNdForwardOutputDim) \ + __macro(cudnnGetConvolutionForwardAlgorithm) \ + __macro(cudnnCreateTensorDescriptor) \ + __macro(cudnnDestroyTensorDescriptor) \ + __macro(cudnnCreateFilterDescriptor) \ + __macro(cudnnSetFilter4dDescriptor) \ + __macro(cudnnSetPooling2dDescriptor) \ + __macro(cudnnDestroyFilterDescriptor) \ + __macro(cudnnCreateConvolutionDescriptor) \ + __macro(cudnnCreatePoolingDescriptor) \ + __macro(cudnnAddTensor) \ + __macro(cudnnDestroyPoolingDescriptor) \ + __macro(cudnnSetConvolution2dDescriptor) \ + __macro(cudnnDestroyConvolutionDescriptor) \ + __macro(cudnnCreate) \ + __macro(cudnnDestroy) \ + __macro(cudnnSetStream) \ + __macro(cudnnActivationForward) \ + __macro(cudnnConvolutionForward) \ + __macro(cudnnConvolutionBackwardData) \ + __macro(cudnnConvolutionBackwardFilter) \ + __macro(cudnnGetConvolutionForwardWorkspaceSize) \ + __macro(cudnnTransformTensor) \ + __macro(cudnnPoolingForward) \ + __macro(cudnnPoolingBackward) +// clang-format on CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) -#undef CUDNN_DNN_ROUTINE_EACH // clang-format off -#define CUDNN_DNN_ROUTINE_EACH(__macro) \ - __macro(cudnnSetConvolution2dDescriptor) \ - __macro(cudnnDestroyConvolutionDescriptor) \ - __macro(cudnnCreate) \ - __macro(cudnnDestroy) \ - __macro(cudnnSetStream) \ - __macro(cudnnActivationForward) \ - __macro(cudnnConvolutionForward) \ - __macro(cudnnConvolutionBackwardData) \ - __macro(cudnnConvolutionBackwardFilter) \ - __macro(cudnnGetConvolutionForwardWorkspaceSize) \ - __macro(cudnnTransformTensor) \ - __macro(cudnnPoolingForward) \ - __macro(cudnnPoolingBackward) +#if CUDNN_VERSION >= 3000 +#define CUDNN_DNN_ROUTINE_EACH_R3(__macro) \ + __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize) \ + __macro(cudnnGetConvolutionBackwardDataAlgorithm) \ + __macro(cudnnGetConvolutionBackwardFilterAlgorithm) \ + __macro(cudnnAddTensor_v3) \ + __macro(cudnnConvolutionBackwardData_v3) \ + __macro(cudnnConvolutionBackwardFilter_v3) \ + __macro(cudnnGetConvolutionBackwardDataWorkspaceSize) // clang-format on -CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) +CUDNN_DNN_ROUTINE_EACH_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP) +#undef CUDNN_DNN_ROUTINE_EACH_R3 +#endif #undef CUDNN_DNN_ROUTINE_EACH } // namespace dynload @@ -467,8 +499,8 @@ bool CudnnSupport::DoConvolve( const FilterDescriptor& filter_descriptor, const DeviceMemory<float>& filter_data, const ConvolutionDescriptor& convolution_descriptor, - const BatchDescriptor& output_descriptor, - DeviceMemory<float>* output_data) { + const BatchDescriptor& output_descriptor, DeviceMemory<float>* output_data, + ScratchAllocator* scratch_allocator) { ScopedTensorDescriptor input_4d{parent_, batch_descriptor, CUDNN_DATA_FLOAT}; ScopedTensorDescriptor output_4d{parent_, output_descriptor, CUDNN_DATA_FLOAT}; @@ -486,23 +518,62 @@ bool CudnnSupport::DoConvolve( // Beta is the scaling factor for output. float beta = 0.0; - // The NO_WORKSPACE versions are possibly slower for certain shapes, but - // not so for the shapes currently used by Brain. Also, it seems prudent to - // keep cuMemAlloc off the critical path. - cudnnConvolutionFwdAlgo_t algo; - status = dynload::cudnnGetConvolutionForwardAlgorithm( - parent_, ToHandle(dnn_handle_), input_4d.handle(), filter.handle(), - conv.handle(), output_4d.handle(), CUDNN_CONVOLUTION_FWD_NO_WORKSPACE, 0, - &algo); + auto get_algorithm = [&](bool specify_limit) + SHARED_LOCKS_REQUIRED(dnn_handle_mutex_) { + cudnnConvolutionFwdPreference_t preference = + specify_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT + : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE; + + auto memory_limit_bytes = + scratch_allocator == nullptr + ? 0 + : scratch_allocator->GetMemoryLimitInBytes(stream); + if (memory_limit_bytes < 0) { + memory_limit_bytes = 0; + } + + cudnnConvolutionFwdAlgo_t algo; + status = dynload::cudnnGetConvolutionForwardAlgorithm( + parent_, ToHandle(dnn_handle_), input_4d.handle(), filter.handle(), + conv.handle(), output_4d.handle(), + /*preference=*/preference, + /*memoryLimitInBytes=*/memory_limit_bytes, /*algo=*/&algo); + CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable " + "algorithm for doing forward " + "convolution"; + return algo; + }; + + auto algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr); + + DeviceMemory<uint8> scratch; + if (scratch_allocator != nullptr) { + size_t size_in_bytes; + status = dynload::cudnnGetConvolutionForwardWorkspaceSize( + parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_4d.handle(), + /*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(), + /*destDesc=*/output_4d.handle(), /*algo=*/algo, + /*sizeInBytes=*/&size_in_bytes); + if (status == CUDNN_STATUS_SUCCESS && size_in_bytes != 0) { + scratch = + scratch_allocator->AllocateBytes(stream, size_in_bytes).ValueOrDie(); + } + } - CHECK_EQ(status, CUDNN_STATUS_SUCCESS) - << "Unable to find a suitable algorithm for doing forward convolution"; + // If we didn't allocate any scratch space (perhaps because of failed + // allocation), we force a switch back to the "no workspace" algorithm. + if (scratch == nullptr) { + algo = get_algorithm(/*specify_limit=*/false); + } status = dynload::cudnnConvolutionForward( - parent_, ToHandle(dnn_handle_), &alpha, input_4d.handle(), - input_data.opaque(), filter.handle(), filter_data.opaque(), conv.handle(), - algo, nullptr /* workspace ptr */, 0 /* workspace size */, &beta, - output_4d.handle(), output_data->opaque()); + parent_, ToHandle(dnn_handle_), + /*alpha=*/&alpha, /*srcDesc=*/input_4d.handle(), + /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(), + /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(), + /*algo=*/algo, /*workSpace=*/scratch.opaque(), + /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta, + /*destDesc=*/output_4d.handle(), /*destData=*/output_data->opaque()); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "failed to enqueue convolution on stream: " @@ -565,7 +636,8 @@ bool CudnnSupport::DoConvolveBackwardData( DeviceMemory<float> backward_output_data, const ConvolutionDescriptor& convolution_descriptor, const BatchDescriptor& input_descriptor, - DeviceMemory<float>* backward_input_data) { + DeviceMemory<float>* backward_input_data, + ScratchAllocator* scratch_allocator) { mutex_lock lock{dnn_handle_mutex_}; auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_), AsCUDAStreamValue(stream)); @@ -592,16 +664,101 @@ bool CudnnSupport::DoConvolveBackwardData( ScopedFilterDescriptor filter{parent_, filter_descriptor, CUDNN_DATA_FLOAT}; ScopedConvolutionDescriptor conv{parent_, convolution_descriptor}; - status = dynload::cudnnConvolutionBackwardData( - parent_, ToHandle(dnn_handle_), &alpha, filter.handle(), - filter_data.opaque(), out_back_4d.handle(), backward_output_data.opaque(), - conv.handle(), &beta, in_back_4d.handle(), backward_input_data->opaque()); +#if CUDNN_VERSION >= 3000 + if (dynload::IsCudnnR2()) { +#endif + status = dynload::cudnnConvolutionBackwardData( + parent_, ToHandle(dnn_handle_), &alpha, filter.handle(), + filter_data.opaque(), out_back_4d.handle(), + backward_output_data.opaque(), conv.handle(), &beta, + in_back_4d.handle(), backward_input_data->opaque()); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(FATAL) << "failed to enqueue convolution on stream: " + << ToString(status); + return false; + } + return true; +#if CUDNN_VERSION >= 3000 + } +#endif + +#if CUDNN_VERSION >= 3000 + auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED( + dnn_handle_mutex_) -> cudnnConvolutionBwdDataAlgo_t { + cudnnConvolutionBwdDataPreference_t preference = + specify_limit ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT + : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE; + + auto memory_limit_bytes = + scratch_allocator == nullptr + ? 0 + : scratch_allocator->GetMemoryLimitInBytes(stream); + if (memory_limit_bytes < 0) { + memory_limit_bytes = 0; + } + + cudnnConvolutionBwdDataAlgo_t algo; + cudnnStatus_t status = dynload::cudnnGetConvolutionBackwardDataAlgorithm( + parent_, ToHandle(dnn_handle_), + /*filterDesc=*/filter.handle(), + /*diffDesc=*/out_back_4d.handle(), + /*convDesc=*/conv.handle(), + /*gradDesc=*/in_back_4d.handle(), + /*preference=*/preference, + /*memoryLimitInBytes=*/memory_limit_bytes, + /*algo=*/&algo); + CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable " + "algorithm for doing backward " + "filter convolution"; + return algo; + }; + + auto algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr); + + DeviceMemory<uint8> scratch; + if (scratch_allocator != nullptr) { + size_t size_in_bytes; + status = dynload::cudnnGetConvolutionBackwardDataWorkspaceSize( + parent_, ToHandle(dnn_handle_), + /*filterDesc=*/filter.handle(), + /*diffDesc=*/out_back_4d.handle(), + /*convDesc=*/conv.handle(), + /*gradDesc=*/in_back_4d.handle(), + /*algo=*/algo, + /*sizeInBytes=*/&size_in_bytes); + if (status == CUDNN_STATUS_SUCCESS && size_in_bytes != 0) { + scratch = + scratch_allocator->AllocateBytes(stream, size_in_bytes).ValueOrDie(); + } + } + + // If we didn't allocate any scratch space (perhaps because of failed + // allocation), we force a switch back to the "no workspace" algorithm. + if (scratch == nullptr) { + algo = get_algorithm(/*specify_limit=*/false); + } + + status = dynload::cudnnConvolutionBackwardData_v3( + parent_, ToHandle(dnn_handle_), + /*alpha=*/&alpha, + /*filterDesc=*/filter.handle(), + /*filterData=*/filter_data.opaque(), + /*diffDesc=*/out_back_4d.handle(), + /*diffData=*/backward_output_data.opaque(), + /*convDesc=*/conv.handle(), + /*algo=*/algo, + /*workSpace=*/scratch.opaque(), + /*workSpaceSizeInBytes=*/scratch.size(), + /*beta=*/&beta, + /*gradDesc=*/in_back_4d.handle(), + /*gradData=*/backward_input_data->opaque()); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "failed to enqueue convolution on stream: " << ToString(status); return false; } return true; +#endif } bool CudnnSupport::DoConvolveBackwardFilter( @@ -611,7 +768,8 @@ bool CudnnSupport::DoConvolveBackwardFilter( DeviceMemory<float> backward_output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::FilterDescriptor& filter_descriptor, - DeviceMemory<float>* backward_filter_data) { + DeviceMemory<float>* backward_filter_data, + ScratchAllocator* scratch_allocator) { mutex_lock lock{dnn_handle_mutex_}; auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_), AsCUDAStreamValue(stream)); @@ -637,16 +795,100 @@ bool CudnnSupport::DoConvolveBackwardFilter( ScopedFilterDescriptor filter{parent_, filter_descriptor, CUDNN_DATA_FLOAT}; ScopedConvolutionDescriptor conv{parent_, convolution_descriptor}; - status = dynload::cudnnConvolutionBackwardFilter( - parent_, ToHandle(dnn_handle_), &alpha, input_4d.handle(), - input_data.opaque(), out_back_4d.handle(), backward_output_data.opaque(), - conv.handle(), &beta, filter.handle(), backward_filter_data->opaque()); +#if CUDNN_VERSION >= 3000 + if (dynload::IsCudnnR2()) { +#endif + status = dynload::cudnnConvolutionBackwardFilter( + parent_, ToHandle(dnn_handle_), &alpha, input_4d.handle(), + input_data.opaque(), out_back_4d.handle(), + backward_output_data.opaque(), conv.handle(), &beta, filter.handle(), + backward_filter_data->opaque()); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(FATAL) << "failed to enqueue convolution on stream: " + << ToString(status); + return false; + } + return true; +#if CUDNN_VERSION >= 3000 + } +#endif + +#if CUDNN_VERSION >= 3000 + // Lambda that retrieves the algorithm. + // specify_limit will occur when we have a scratch allocator and it succeeds + // in allocating; otherwise, we'll fall back to the "no workspace" version. + auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED( + dnn_handle_mutex_) { + cudnnConvolutionBwdFilterPreference_t preference = + specify_limit ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT + : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE; + + auto memory_limit_bytes = + scratch_allocator == nullptr + ? 0 + : scratch_allocator->GetMemoryLimitInBytes(stream); + if (memory_limit_bytes < 0) { + memory_limit_bytes = 0; + } + + cudnnConvolutionBwdFilterAlgo_t algo; + cudnnStatus_t status = dynload::cudnnGetConvolutionBackwardFilterAlgorithm( + parent_, ToHandle(dnn_handle_), + /*srcDesc=*/input_4d.handle(), + /*diffDesc=*/out_back_4d.handle(), + /*convDesc=*/conv.handle(), + /*gradDesc=*/filter.handle(), + /*preference=*/preference, + /*memoryLimitInBytes=*/memory_limit_bytes, + /*algo=*/&algo); + CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable " + "algorithm for doing backward " + "filter convolution"; + return algo; + }; + + auto algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr); + + DeviceMemory<uint8> scratch; + if (scratch_allocator != nullptr) { + size_t size_in_bytes; + status = dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize( + parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_4d.handle(), + /*diffDesc=*/out_back_4d.handle(), /*convDesc=*/conv.handle(), + /*gradDesc=*/filter.handle(), /*algo=*/algo, + /*sizeInBytes=*/&size_in_bytes); + if (status == CUDNN_STATUS_SUCCESS && size_in_bytes != 0) { + scratch = + scratch_allocator->AllocateBytes(stream, size_in_bytes).ValueOrDie(); + } + } + + // If we didn't allocate any scratch space (perhaps because of failed + // allocation), we force a switch back to the "no workspace" algorithm. + if (scratch == nullptr) { + algo = get_algorithm(/*specify_limit=*/false); + } + + status = dynload::cudnnConvolutionBackwardFilter_v3( + parent_, ToHandle(dnn_handle_), /*alpha=*/&alpha, + /*srcDesc=*/input_4d.handle(), + /*srcData=*/input_data.opaque(), + /*diffDesc=*/out_back_4d.handle(), + /*diffData=*/backward_output_data.opaque(), + /*convDesc=*/conv.handle(), + /*algo=*/algo, + /*workSpace=*/scratch.opaque(), + /*workSpaceSizeInBytes=*/scratch.size(), + /*beta=*/&beta, + /*gradDesc=*/filter.handle(), + /*gradData=*/backward_filter_data->opaque()); if (status != CUDNN_STATUS_SUCCESS) { LOG(FATAL) << "failed to enqueue convolution on stream: " << ToString(status); return false; } return true; +#endif } bool CudnnSupport::DoMatMul(Stream* stream, @@ -800,7 +1042,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream, ScopedTensorDescriptor bias_descriptor{parent_, bias_dimensions, CUDNN_DATA_FLOAT}; - // cudnnAddTensor is in-place, so we need to copy input_data to + // cudnnAddTensor_v3 is in-place, so we need to copy input_data to // output_data before doing the addition, unless the input and // output are at the same address. if (input_data.opaque() != output_data->opaque()) { @@ -815,13 +1057,30 @@ bool CudnnSupport::DoBiasAdd(Stream* stream, } mutex_lock lock{dnn_handle_mutex_}; + auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_), + AsCUDAStreamValue(stream)); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status); + return false; + } const float alpha = 1.0f; const float beta = 1.0f; - auto status = dynload::cudnnAddTensor( - parent_, ToHandle(dnn_handle_), CUDNN_ADD_SAME_C, &alpha, - bias_descriptor.handle(), biases.opaque(), &beta, - input_descriptor.handle(), output_data->opaque()); +#if CUDNN_VERSION >= 3000 + if (dynload::IsCudnnR2()) { +#endif + status = dynload::cudnnAddTensor( + parent_, ToHandle(dnn_handle_), CUDNN_ADD_SAME_C, &alpha, + bias_descriptor.handle(), biases.opaque(), &beta, + input_descriptor.handle(), output_data->opaque()); +#if CUDNN_VERSION >= 3000 + } else { + status = dynload::cudnnAddTensor_v3( + parent_, ToHandle(dnn_handle_), &alpha, bias_descriptor.handle(), + biases.opaque(), &beta, input_descriptor.handle(), + output_data->opaque()); + } +#endif if (status != CUDNN_STATUS_SUCCESS) { LOG(ERROR) << "stream " << stream << " could not enqueue bias addition."; @@ -970,7 +1229,52 @@ bool CudnnSupport::DoDepthConcatenate( Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions, port::ArraySlice<const DeviceMemory<float>*> input_data, DeviceMemory<float>* output_data) { - LOG(FATAL) << "not yet implemented"; // TODO(leary) + CHECK_EQ(input_dimensions.size(), input_data.size()); + + for (const auto& dimensions : input_dimensions) { + if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) { + LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only " + "supports the kBatchDepthYX layout."; + return false; + } + } + + if (input_dimensions.empty()) { + return true; // Nothing to do. + } + + dnn::BatchDescriptor output_dimensions = + dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions); + + const int64 area = output_dimensions.width() * output_dimensions.height(); + const auto index = [area](int64 batch, int64 depth, int64 yx, + int64 max_depth) { + return (batch * max_depth + depth) * area + yx; + }; + + std::vector<float> output_host(output_dimensions.ElementCount()); + std::vector<float> tmp; + int64 depth_sum = 0; + for (size_t i = 0; i < input_data.size(); ++i) { + const auto& dimensions = input_dimensions[i]; + tmp.resize(dimensions.ElementCount()); + stream->ThenMemcpyD2H<float>(*input_data[i], &tmp).BlockHostUntilDone(); + + for (int64 batch = 0; batch < output_dimensions.count(); ++batch) { + for (int64 yx = 0; yx < area; ++yx) { + for (int64 depth = 0; depth < dimensions.feature_map_count(); ++depth) { + LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' ' + << yx << ' ' << depth; + output_host[index(batch, depth + depth_sum, yx, + output_dimensions.feature_map_count())] = + tmp[index(batch, depth, yx, dimensions.feature_map_count())]; + } + } + } + depth_sum += dimensions.feature_map_count(); + } + stream->ThenMemcpyH2D<float>(output_host, output_data); + return true; } bool CudnnSupport::DoElementwiseOperate( @@ -982,29 +1286,30 @@ bool CudnnSupport::DoElementwiseOperate( LOG(FATAL) << "not yet implemented"; // TODO(leary) } -bool CudnnSupport::DoMemcpyD2HQuantized( - Stream* stream, const DeviceMemory<float>& gpu_unquantized_src, - port::MutableArraySlice<uint8> host_dst) { - LOG(ERROR) << "quantized memcpy not supported by cuDNN"; - return false; +bool CudnnSupport::DoXYPad( + Stream* stream, const dnn::BatchDescriptor &dimensions, + const DeviceMemory<float> &input_data, int64 left_pad, int64 right_pad, + int64 top_pad, int64 bottom_pad, DeviceMemory<float> *output_data) { + LOG(FATAL) << "not yet implemented"; // TODO(leary) } -bool CudnnSupport::DoMemcpyD2HQuantized( - Stream* stream, const DeviceMemory<float>& device_unquantized_src, - port::MutableArraySlice<uint16> host_dst) { - LOG(ERROR) << "quantized memcpy not supported by cuDNN"; - return false; +bool CudnnSupport::DoXYSlice( + Stream* stream, const dnn::BatchDescriptor &dimensions, + const DeviceMemory<float> &input_data, int64 left_trim, int64 right_trim, + int64 top_trim, int64 bottom_trim, DeviceMemory<float> *output_data) { + LOG(FATAL) << "not yet implemented"; // TODO(leary) } bool CudnnSupport::DoMemcpyD2HQuantized( - Stream* stream, const DeviceMemory<float>& device_unquantized_src, - port::MutableArraySlice<int32> host_dst) { + Stream* stream, const DeviceMemory<float>& gpu_unquantized_src, + dnn::QuantizedActivationMode mode, void* host_dst, int64 size) { LOG(ERROR) << "quantized memcpy not supported by cuDNN"; return false; } bool CudnnSupport::DoMemcpyH2DQuantized( - Stream* stream, port::ArraySlice<uint8> host_src, + Stream* stream, const void* host_src, int64 size, + dnn::QuantizedActivationMode mode, DeviceMemory<float>* gpu_unquantized_dst) { LOG(ERROR) << "quantized memcpy not supported by cuDNN"; return false; @@ -1074,7 +1379,7 @@ void initialize_cudnn() { // Prime the cuDNN DSO. The loader will log more information. auto statusor = gpu::internal::CachedDsoLoader::GetCudnnDsoHandle(); if (!statusor.ok()) { - LOG(INFO) << "Unable to load cuDNN DSO."; + LOG(INFO) << "Unable to load cuDNN DSO"; } gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId, diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 4103cf8e06..37be60ec63 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -50,7 +50,8 @@ class CudnnSupport : public dnn::DnnSupport { const DeviceMemory<float>& filter_data, const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::BatchDescriptor& output_descriptor, - DeviceMemory<float>* output_data) override; + DeviceMemory<float>* output_data, + ScratchAllocator* scratch_allocator) override; bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor, const DeviceMemory<double>& input_data, @@ -80,7 +81,8 @@ class CudnnSupport : public dnn::DnnSupport { DeviceMemory<float> backward_output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::BatchDescriptor& input_descriptor, - DeviceMemory<float>* backward_input_data) override; + DeviceMemory<float>* backward_input_data, + ScratchAllocator* scratch_allocator) override; bool DoConvolveBackwardFilter( Stream* stream, const dnn::BatchDescriptor& input_descriptor, @@ -89,7 +91,8 @@ class CudnnSupport : public dnn::DnnSupport { DeviceMemory<float> backward_output_data, const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::FilterDescriptor& filter_descriptor, - DeviceMemory<float>* backward_filter_data) override; + DeviceMemory<float>* backward_filter_data, + ScratchAllocator* scratch_allocator) override; bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data, const DeviceMemory<float>& weights, @@ -160,20 +163,24 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_dimensions, DeviceMemory<float>* output_data) override; - bool DoMemcpyD2HQuantized(Stream* stream, - const DeviceMemory<float>& device_unquantized_src, - port::MutableArraySlice<uint8> host_dst) override; + bool DoXYPad(Stream* stream, const dnn::BatchDescriptor &dimensions, + const DeviceMemory<float> &input_data, + int64 left_pad, int64 right_pad, int64 top_pad, + int64 bottom_pad, DeviceMemory<float> *output_data) override; - bool DoMemcpyD2HQuantized(Stream* stream, - const DeviceMemory<float>& device_unquantized_src, - port::MutableArraySlice<uint16> host_dst) override; + bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor &dimensions, + const DeviceMemory<float> &input_data, + int64 left_trim, int64 right_trim, int64 top_trim, + int64 bottom_trim, DeviceMemory<float> *output_data) override; bool DoMemcpyD2HQuantized(Stream* stream, const DeviceMemory<float>& device_unquantized_src, - port::MutableArraySlice<int32> host_dst) override; + dnn::QuantizedActivationMode mode, void* host_dst, + int64 size) override; bool DoMemcpyH2DQuantized( - Stream* stream, port::ArraySlice<uint8> host_src, + Stream* stream, const void* host_src, int64 size, + dnn::QuantizedActivationMode mode, DeviceMemory<float>* device_unquantized_dst) override; // Derives an output batch descriptor from an input batch and convolution diff --git a/tensorflow/stream_executor/cuda/cuda_event.cc b/tensorflow/stream_executor/cuda/cuda_event.cc index c952257aa5..ad7c0bbf3b 100644 --- a/tensorflow/stream_executor/cuda/cuda_event.cc +++ b/tensorflow/stream_executor/cuda/cuda_event.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_event.h" +#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" #include "tensorflow/stream_executor/cuda/cuda_stream.h" #include "tensorflow/stream_executor/lib/statusor.h" diff --git a/tensorflow/stream_executor/cuda/cuda_fft.cc b/tensorflow/stream_executor/cuda/cuda_fft.cc index 4ee549ab03..b5740a1d7b 100644 --- a/tensorflow/stream_executor/cuda/cuda_fft.cc +++ b/tensorflow/stream_executor/cuda/cuda_fft.cc @@ -22,7 +22,8 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_activation.h" #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" #include "tensorflow/stream_executor/cuda/cuda_helpers.h" -#include "tensorflow/stream_executor/cuda/cuda_platform.h" +#include "tensorflow/stream_executor/cuda/cuda_platform_id.h" +#include "tensorflow/stream_executor/cuda/cuda_stream.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/dso_loader.h" #include "tensorflow/stream_executor/lib/initialize.h" diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index 93dc90635e..2565078bb2 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -20,7 +20,7 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" #include "tensorflow/stream_executor/cuda/cuda_driver.h" #include "tensorflow/stream_executor/cuda/cuda_event.h" -#include "tensorflow/stream_executor/cuda/cuda_platform.h" +#include "tensorflow/stream_executor/cuda/cuda_platform_id.h" #include "tensorflow/stream_executor/cuda/cuda_stream.h" #include "tensorflow/stream_executor/cuda/cuda_timer.h" #include "tensorflow/stream_executor/dso_loader.h" @@ -88,20 +88,6 @@ static CUDAEvent *AsCUDAEvent(Event *event) { return static_cast<CUDAEvent *>(event->implementation()); } -// Given a platform-independent stream datatype, returns the internal CUDA -// platform implementation pointer. -static CUDAStream *AsCUDAStream(Stream *stream) { - DCHECK(stream != nullptr); - return static_cast<CUDAStream *>(stream->implementation()); -} - -// Given a platform-independent stream datatype, returns the platform -// implementation's internal value, suitable for passing directly to libcuda -// APIs. -CUstream AsCUDAStreamValue(Stream *stream) { - DCHECK(stream != nullptr); - return AsCUDAStream(stream)->cuda_stream(); -} // Given a platform-independent timer datatype, returns the internal CUDA // platform implementation pointer. @@ -861,6 +847,26 @@ bool CUDAExecutor::SupportsFft() const { return true; } bool CUDAExecutor::SupportsRng() const { return true; } +std::unique_ptr<internal::EventInterface> +CUDAExecutor::CreateEventImplementation() { + return std::unique_ptr<internal::EventInterface>(new CUDAEvent(this)); +} + +std::unique_ptr<internal::KernelInterface> +CUDAExecutor::CreateKernelImplementation() { + return std::unique_ptr<internal::KernelInterface>(new CUDAKernel()); +} + +std::unique_ptr<internal::StreamInterface> +CUDAExecutor::GetStreamImplementation() { + return std::unique_ptr<internal::StreamInterface>(new CUDAStream(this)); +} + +std::unique_ptr<internal::TimerInterface> +CUDAExecutor::GetTimerImplementation() { + return std::unique_ptr<internal::TimerInterface>(new CUDATimer(this)); +} + void *CUDAExecutor::CudaContextHack() { return context_; } CUcontext CUDAExecutor::cuda_context() { return context_; } @@ -1064,30 +1070,6 @@ void initialize_cuda_gpu_executor() { const gpu::PluginConfig &config) { return new gpu::cuda::CUDAExecutor{config}; }; - - *gpu::internal::MakeCUDAKernelImplementation() = []() { - return new gpu::cuda::CUDAKernel; - }; - - *gpu::internal::MakeCUDAEventImplementation() = []( - gpu::StreamExecutor *parent) { - gpu::cuda::CUDAExecutor *cuda_executor = - static_cast<gpu::cuda::CUDAExecutor *>(parent->implementation()); - return new gpu::cuda::CUDAEvent{cuda_executor}; - }; - - *gpu::internal::MakeCUDAStreamImplementation() = []( - gpu::StreamExecutor *parent) { - gpu::cuda::CUDAExecutor *cuda_executor = - static_cast<gpu::cuda::CUDAExecutor *>(parent->implementation()); - return new gpu::cuda::CUDAStream{cuda_executor}; - }; - *gpu::internal::MakeCUDATimerImplementation() = []( - gpu::StreamExecutor *parent) { - gpu::cuda::CUDAExecutor *cuda_executor = - static_cast<gpu::cuda::CUDAExecutor *>(parent->implementation()); - return new gpu::cuda::CUDATimer{cuda_executor}; - }; } } // namespace gputools diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index 01ccf82ec6..2a0c6dc456 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -203,6 +203,16 @@ class CUDAExecutor : public internal::StreamExecutorInterface { dnn::DnnSupport *CreateDnn() override; + std::unique_ptr<internal::EventInterface> CreateEventImplementation() + override; + + std::unique_ptr<internal::KernelInterface> CreateKernelImplementation() + override; + + std::unique_ptr<internal::StreamInterface> GetStreamImplementation() override; + + std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override; + void *CudaContextHack() override; CUcontext cuda_context(); diff --git a/tensorflow/stream_executor/cuda/cuda_helpers.h b/tensorflow/stream_executor/cuda/cuda_helpers.h index dad62900f1..c52516c589 100644 --- a/tensorflow/stream_executor/cuda/cuda_helpers.h +++ b/tensorflow/stream_executor/cuda/cuda_helpers.h @@ -30,7 +30,6 @@ limitations under the License. namespace perftools { namespace gputools { -class Stream; template <typename ElemT> class DeviceMemory; @@ -51,8 +50,6 @@ T *CUDAMemoryMutable(DeviceMemory<T> *mem) { return static_cast<T *>(mem->opaque()); } -CUstream AsCUDAStreamValue(Stream *stream); - static_assert(sizeof(std::complex<float>) == sizeof(cuComplex), "std::complex<float> and cuComplex should have the same size"); static_assert(offsetof(cuComplex, x) == 0, diff --git a/tensorflow/stream_executor/cuda/cuda_platform.cc b/tensorflow/stream_executor/cuda/cuda_platform.cc index 172f9e2cf8..f658ce216d 100644 --- a/tensorflow/stream_executor/cuda/cuda_platform.cc +++ b/tensorflow/stream_executor/cuda/cuda_platform.cc @@ -16,6 +16,8 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_platform.h" #include "tensorflow/stream_executor/cuda/cuda_driver.h" +#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" +#include "tensorflow/stream_executor/cuda/cuda_platform_id.h" #include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/lib/initialize.h" #include "tensorflow/stream_executor/lib/ptr_util.h" @@ -26,8 +28,6 @@ namespace perftools { namespace gputools { namespace cuda { -PLATFORM_DEFINE_ID(kCudaPlatformId); - CudaPlatform::CudaPlatform() : name_("CUDA"), min_numa_node_(0), limit_numa_node_(0) {} @@ -147,8 +147,8 @@ port::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor( port::StatusOr<std::unique_ptr<StreamExecutor>> CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) { - auto executor = port::MakeUnique<StreamExecutor>(PlatformKind::kCuda, - config.plugin_config); + auto executor = port::MakeUnique<StreamExecutor>( + this, new CUDAExecutor(config.plugin_config)); auto init_status = executor->Init(config.ordinal, config.device_options); if (!init_status.ok()) { return port::Status{ diff --git a/tensorflow/stream_executor/cuda/cuda_platform_id.cc b/tensorflow/stream_executor/cuda/cuda_platform_id.cc new file mode 100644 index 0000000000..09ece156d2 --- /dev/null +++ b/tensorflow/stream_executor/cuda/cuda_platform_id.cc @@ -0,0 +1,26 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/stream_executor/cuda/cuda_platform_id.h" + +namespace perftools { +namespace gputools { +namespace cuda { + +PLATFORM_DEFINE_ID(kCudaPlatformId); + +} // namespace cuda +} // namespace gputools +} // namespace perftools diff --git a/tensorflow/stream_executor/cuda/cuda_platform_id.h b/tensorflow/stream_executor/cuda/cuda_platform_id.h new file mode 100644 index 0000000000..c91ccc0e44 --- /dev/null +++ b/tensorflow/stream_executor/cuda/cuda_platform_id.h @@ -0,0 +1,36 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_ID_H_ +#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_ID_H_ + +#include "tensorflow/stream_executor/platform.h" + +namespace perftools { +namespace gputools { +namespace cuda { + +// Opaque and unique identifier for the cuda platform. +// This is needed so that plugins can refer to/identify this platform without +// instantiating a CudaPlatform object. +// This is broken out here to avoid a circular dependency between CudaPlatform +// and CudaExecutor. +extern const Platform::Id kCudaPlatformId; + +} // namespace cuda +} // namespace gputools +} // namespace perftools + +#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_ID_H_ diff --git a/tensorflow/stream_executor/cuda/cuda_rng.cc b/tensorflow/stream_executor/cuda/cuda_rng.cc index 3d244ed2e7..220ca85df9 100644 --- a/tensorflow/stream_executor/cuda/cuda_rng.cc +++ b/tensorflow/stream_executor/cuda/cuda_rng.cc @@ -20,7 +20,8 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_activation.h" #include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" #include "tensorflow/stream_executor/cuda/cuda_helpers.h" -#include "tensorflow/stream_executor/cuda/cuda_platform.h" +#include "tensorflow/stream_executor/cuda/cuda_platform_id.h" +#include "tensorflow/stream_executor/cuda/cuda_stream.h" #include "tensorflow/stream_executor/device_memory.h" #include "tensorflow/stream_executor/dso_loader.h" #include "tensorflow/stream_executor/lib/initialize.h" diff --git a/tensorflow/stream_executor/cuda/cuda_stream.cc b/tensorflow/stream_executor/cuda/cuda_stream.cc index c2d8f95e9c..3bc42982e0 100644 --- a/tensorflow/stream_executor/cuda/cuda_stream.cc +++ b/tensorflow/stream_executor/cuda/cuda_stream.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/stream_executor/cuda/cuda_stream.h" +#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" #include "tensorflow/stream_executor/lib/status.h" +#include "tensorflow/stream_executor/stream.h" namespace perftools { namespace gputools { @@ -61,6 +63,16 @@ bool CUDAStream::GetOrCreateCompletedEvent(CUevent *completed_event) { return true; } +CUDAStream *AsCUDAStream(Stream *stream) { + DCHECK(stream != nullptr); + return static_cast<CUDAStream *>(stream->implementation()); +} + +CUstream AsCUDAStreamValue(Stream *stream) { + DCHECK(stream != nullptr); + return AsCUDAStream(stream)->cuda_stream(); +} + } // namespace cuda } // namespace gputools } // namespace perftools diff --git a/tensorflow/stream_executor/cuda/cuda_stream.h b/tensorflow/stream_executor/cuda/cuda_stream.h index fda93a1a01..cb12a094e6 100644 --- a/tensorflow/stream_executor/cuda/cuda_stream.h +++ b/tensorflow/stream_executor/cuda/cuda_stream.h @@ -20,7 +20,7 @@ limitations under the License. #define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_ #include "tensorflow/stream_executor/cuda/cuda_driver.h" -#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h" +#include "tensorflow/stream_executor/platform/thread_annotations.h" #include "tensorflow/stream_executor/stream_executor_internal.h" namespace perftools { @@ -82,6 +82,13 @@ class CUDAStream : public internal::StreamInterface { CUevent completed_event_ GUARDED_BY(mu_); }; +// Helper functions to simplify extremely common flows. +// Converts a Stream to the underlying CUDAStream implementation. +CUDAStream *AsCUDAStream(Stream *stream); + +// Extracts a CUstream from a CUDAStream-backed Stream object. +CUstream AsCUDAStreamValue(Stream *stream); + } // namespace cuda } // namespace gputools } // namespace perftools diff --git a/tensorflow/stream_executor/device_memory.h b/tensorflow/stream_executor/device_memory.h index b5125c27e1..54f348c094 100644 --- a/tensorflow/stream_executor/device_memory.h +++ b/tensorflow/stream_executor/device_memory.h @@ -111,6 +111,7 @@ class DeviceMemory final : public DeviceMemoryBase { public: // Default constructor instantiates a null-pointed, zero-sized memory region. DeviceMemory() : DeviceMemoryBase(nullptr, 0) {} + DeviceMemory(std::nullptr_t) : DeviceMemory() {} // Typed device memory regions may be constructed from untyped device memory // regions, this effectively amounts to a cast from a void*. diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc index 97a15dd263..fbc9342081 100644 --- a/tensorflow/stream_executor/dnn.cc +++ b/tensorflow/stream_executor/dnn.cc @@ -22,6 +22,20 @@ namespace perftools { namespace gputools { namespace dnn { +string QuantizedActivationModeString(QuantizedActivationMode mode) { + switch (mode) { + case dnn::QuantizedActivationMode::k8Bit: + return "uint8"; + case dnn::QuantizedActivationMode::k16Bit: + return "uint16"; + case dnn::QuantizedActivationMode::k32Bit: + return "int32"; + default: + LOG(FATAL) << "Unknown quantized_activation_mode " + << static_cast<int32>(mode); + } +} + string ActivationModeString(ActivationMode mode) { switch (mode) { case ActivationMode::kSigmoid: @@ -78,6 +92,17 @@ string FilterLayoutString(FilterLayout layout) { } } +string ShortPoolingModeString(PoolingMode mode) { + switch (mode) { + case PoolingMode::kMaximum: + return "Max"; + case PoolingMode::kAverage: + return "Avg"; + default: + LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(mode); + } +} + // -- BatchDescriptor BatchDescriptor::BatchDescriptor() @@ -137,7 +162,6 @@ string BatchDescriptor::ToShortString() const { return port::StrCat(batch, depth, y, x, suffix); default: LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout()); - return ""; // Avoid lack-of-return warning } } @@ -160,6 +184,20 @@ int64 BatchDescriptor::FullyConnectedBiasCount(const BatchDescriptor& output) { return output.NodesAcrossFeatureMaps(); } +BatchDescriptor BatchDescriptor::DepthConcatenateOutputDescriptor( + port::ArraySlice<dnn::BatchDescriptor> inputs) { + if (inputs.empty()) { + return BatchDescriptor(); + } + int feature_map_count = 0; + for (const auto& dimensions : inputs) { + feature_map_count += dimensions.feature_map_count(); + } + BatchDescriptor output = inputs[0]; + output.set_feature_map_count(feature_map_count); + return output; +} + // -- FilterDescriptor FilterDescriptor::FilterDescriptor() @@ -205,7 +243,6 @@ string FilterDescriptor::ToShortString() const { return port::StrCat(y, x, id, od); default: LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout_); - return ""; // Avoid lack-of-return warning } } diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index 7d1dfe3d0e..237f60c6ca 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -32,6 +32,7 @@ namespace perftools { namespace gputools { class Stream; +class ScratchAllocator; namespace dnn { @@ -55,6 +56,9 @@ enum class QuantizedActivationMode { k32Bit = 4, }; +// Returns a string representation of the given quantization mode. +string QuantizedActivationModeString(QuantizedActivationMode mode); + // Describes the dimensions that a layer consumes/produces. // // This is a matrix (height, width), its "depth" (feature_map_count), @@ -175,6 +179,13 @@ class BatchDescriptor { // with dimensions given the 'output' descriptor. static int64 FullyConnectedBiasCount(const BatchDescriptor& output); + // Return a BatchDescriptor for the output of a depth concatenation + // with the given input descriptors. The inputs should have the same + // dimensions, except possibly for feature_map_count(), though this + // function does not verify that. + static BatchDescriptor DepthConcatenateOutputDescriptor( + port::ArraySlice<dnn::BatchDescriptor> inputs); + private: int64 count_; int64 feature_map_count_; @@ -280,8 +291,6 @@ class FilterDescriptor { int64 input_filter_height_; int64 input_filter_width_; FilterLayout layout_; - - SE_DISALLOW_COPY_AND_ASSIGN(FilterDescriptor); }; // Describes a convolution. @@ -356,6 +365,9 @@ enum class PoolingMode : int64 { kAverage, }; +// Returns a short name for the pooling mode, e.g. "Avg". +string ShortPoolingModeString(PoolingMode mode); + // Describes a pooling operation to be enqueued onto a stream via a platform's // DnnSupport. // @@ -423,18 +435,31 @@ class PoolingDescriptor { int64 horizontal_padding_; int64 vertical_stride_; int64 horizontal_stride_; - - SE_DISALLOW_COPY_AND_ASSIGN(PoolingDescriptor); }; -// Describes a dist_belief local response normalization. -// The normalization equation is: -// y_i = x_i / (bias + alpha * (sum_j_{i - range}^{i + range} x_j^2)) ^ beta -// where x_i is the input in feature map i, y_i is the output. -// Each feature map is split into segment_size segments for performing the -// sum_j_. If wrap_around is true, the sum_j_ for y_i on the left and right of -// a segment wrap around at the edges of the segment, if wrap_around is false -// zeros are inserted instead. +// Describes a local response normalization (LRN). LRN is used e.g. in +// dist_belief. +// +// Let V be the vector of feature maps at some (batch, y, x) +// coordinate. LRN applies independently to each vector V in the +// input, across all coordinates (batch, y, x), by mapping each V to +// another vector U of the same size using the formula +// +// V_i = U_i / ((bias + alpha * (sum_j U_j^2)) ^ beta) +// +// where the sum is taken for j in the inclusive range [i - range, i + range]. +// +// When calculating V_i the j in the sum can extend beyond the bounds +// of U. If wrap_around is true, then U_j = U_{j mod F} where F is the +// size of U, which is the number of feature maps. If wrap_around is +// false, then U_j = 0 for j outside [0, F-1]. +// +// If segment_size <= F, where F is the number of feature_maps, then +// segment_size has no effect. Otherwise, each consecutive segment of +// segment_size entries in V are normalized separately. +// +// Not all StreamExecutors allow wrap_around == true or segment_size +// != 64. Some do not implement normalization at all. class NormalizeDescriptor { public: NormalizeDescriptor(); @@ -488,8 +513,6 @@ class NormalizeDescriptor { float beta_; bool wrap_around_; int32 segment_size_; - - SE_DISALLOW_COPY_AND_ASSIGN(NormalizeDescriptor); }; // Describes a kind of non-linearity (threshold-like mathematical function). @@ -503,6 +526,8 @@ enum class ActivationMode { // BatchDescriptor::value_max(). kReluX, kTanh, + // Like ReluX, but passes all values in the range [-X,X]. + kBandPass, }; // Returns a string representation of the given activation mode. @@ -510,10 +535,7 @@ string ActivationModeString(ActivationMode mode); // Describes the operation that DoElementwiseOperation should perform on its // inputs. -enum class ElementwiseOperation { - kAdd, - kMultiply -}; +enum class ElementwiseOperation { kAdd, kMultiply }; string ElementwiseOperationString(ElementwiseOperation op); @@ -541,6 +563,8 @@ class DnnSupport { // output_descriptor: dimensions of the output layer. // output_data: un-owned device memory region in which to place the // convolution result. + // scratch_allocator: un-owned, may-be-null object that may allocate scratch + // space in order to speed up the convolution operation. // // input_descriptor, filter_descriptor, convolution_descriptor and // output_descriptor together specify exactly how the convolution is aligned @@ -564,7 +588,8 @@ class DnnSupport { const DeviceMemory<float>& filter_data, const dnn::ConvolutionDescriptor& convolution_descriptor, const dnn::BatchDescriptor& output_descriptor, - DeviceMemory<float>* output_data) = 0; + DeviceMemory<float>* output_data, + ScratchAllocator* scratch_allocator) = 0; // Enqueues a double-precision convolution operation onto the stream. // See DoConvolve above for argument details. @@ -612,6 +637,8 @@ class DnnSupport { // input_descriptor: dimensions of the input layer. // backward_input_data: un-owned device memory region in which to place the // backprop of the input. + // scratch_allocator: un-owned, may-be-null object that may allocate scratch + // space in order to speed up the convolution operation. virtual bool DoConvolveBackwardData( Stream* stream, const FilterDescriptor& filter_descriptor, const DeviceMemory<float>& filter_data, @@ -619,7 +646,8 @@ class DnnSupport { DeviceMemory<float> backward_output_data, const ConvolutionDescriptor& convolution_descriptor, const BatchDescriptor& input_descriptor, - DeviceMemory<float>* backward_input_data) = 0; + DeviceMemory<float>* backward_input_data, + ScratchAllocator* scratch_allocator) = 0; // Enqueues a single-precision backward convolution (for filter) operation // onto @@ -640,6 +668,8 @@ class DnnSupport { // filter_descriptor: dimensions of the convolution filter. // backward_filter_data: un-owned device memory region in which to place the // backprop of the filter. + // scratch_allocator: un-owned, may-be-null object that may allocate scratch + // space in order to speed up the convolution operation. virtual bool DoConvolveBackwardFilter( Stream* stream, const BatchDescriptor& input_descriptor, const DeviceMemory<float>& input_data, @@ -647,7 +677,8 @@ class DnnSupport { DeviceMemory<float> backward_output_data, const ConvolutionDescriptor& convolution_descriptor, const FilterDescriptor& filter_descriptor, - DeviceMemory<float>* backward_filter_data) = 0; + DeviceMemory<float>* backward_filter_data, + ScratchAllocator* scratch_allocator) = 0; // Fully connects the "nodes" (float values) in input_data with // shape input_dimensions to output_data with output_dimensions @@ -784,8 +815,10 @@ class DnnSupport { const DeviceMemory<float>& input_diff_data, DeviceMemory<float>* output_diff_data) = 0; - // Applies local response normalization to all of the values - // held on the device in 'input_data'. + // Applies local response normalization to the values from + // input_data and writes the result to output_data. See comments on + // NormalizeDescriptor for a description of local response + // normalization. virtual bool DoNormalize(Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, const DeviceMemory<float>& input_data, @@ -850,6 +883,46 @@ class DnnSupport { const dnn::BatchDescriptor& output_dimensions, DeviceMemory<float>* output_data) = 0; + // Pads the input with zeros in the X and Y dimensions. The feature_map + // dimension is unchanged. + // + // Arguments (all borrowed): + // stream: borrowed pointer to the stream that the 'elementwise operation' + // should be enqueued onto. + // dimensions: The dimensions of the input. + // input_data: un-owned device memory region which contains the + // input data for the input layer. + // left_pad: Amount to pad the input on the left. + // right_pad: Amount to pad the input on the right. + // top_pad: Amount to pad the input at the top (low Y). + // bottom_pad: Amount to pad the input at the bottom (high Y). + // output_data: un-owned device memory region in which to place the + // padded result. + virtual bool DoXYPad(Stream* stream, const dnn::BatchDescriptor &dimensions, + const DeviceMemory<float> &input_data, + int64 left_pad, int64 right_pad, int64 top_pad, + int64 bottom_pad, DeviceMemory<float> *output_data) = 0; + + // Extracts a slice of the input in the X and Y dimensions. The feature_map + // dimension is unchanged. + // + // Arguments (all borrowed): + // stream: borrowed pointer to the stream that the 'elementwise operation' + // should be enqueued onto. + // dimensions: The dimensions of the input. + // input_data: un-owned device memory region which contains the + // input data for the input layer. + // left_trim: Amount to cut off the input on the left. + // right_trim: Amount to cut off the input on the right. + // top_trim: Amount to cut off the input at the top (low y). + // bottom_trim: Amount to cut off the input at the bottom (high Y). + // output_data: un-owned device memory region in which to place the + // padded result. + virtual bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor &dimensions, + const DeviceMemory<float> &input_data, + int64 left_trim, int64 right_trim, int64 top_trim, + int64 bottom_trim, DeviceMemory<float> *output_data) = 0; + // Enqueues an asynchronous memcpy of the *quantized* output of a layer (that // is, bytes instead of scaled floats) into 'host_dst' if they are available // for the underlying DNN implementation. If this quantized output is not @@ -862,23 +935,14 @@ class DnnSupport { // gpu_unquantized_src: the device memory that contains the unquantized data // -- this data should also have a corresponding quantized representation // on the device for this operation to succeed. + // mode: Type of quantization of the data to write into host_dst. // host_dst: un-owned host memory region that is mutated in place, // it is clobbered by the values in 'gpu_unquantized_src' when the enqueued // (asynchronous) memcpy operation is performed. - // TODO(wgulland) Merge all these versions of DoMemcpyD2HQuantized. - virtual bool DoMemcpyD2HQuantized( - Stream* stream, const DeviceMemory<float>& gpu_unquantized_src, - port::MutableArraySlice<uint8> host_dst) = 0; - - // As above, but for 16-bit values. - virtual bool DoMemcpyD2HQuantized( - Stream* stream, const DeviceMemory<float>& gpu_unquantized_src, - port::MutableArraySlice<uint16> host_dst) = 0; - - // As above, but for signed 32-bit values. + // size: size in bytes of the host_dst host memory region. virtual bool DoMemcpyD2HQuantized( Stream* stream, const DeviceMemory<float>& gpu_unquantized_src, - port::MutableArraySlice<int32> host_dst) = 0; + QuantizedActivationMode mode, void* host_dst, int64 size) = 0; // Enqueues an asynchronous memcpy of 'host_dst' into the *quantized* input // of a layer (that is, bytes instead of scaled floats) if they are supported @@ -890,13 +954,16 @@ class DnnSupport { // stream: borrowed pointer to the stream that the 'quantized memcpy' // operation should be enqueued onto. // host_src: un-owned host memory region that contains the quantized data. + // size: size in bytes of the host_src host memory region. + // mode: Type of quantization of the data to read from host_src. // gpu_unquantized_dst: the device memory that is clobbered by the values in // 'host_src' when the enqueued (asynchronous) memcpy operation is // performed. -- this data should also have a corresponding quantized // representation on the device for this operation to // succeed. virtual bool DoMemcpyH2DQuantized( - Stream* stream, port::ArraySlice<uint8> host_src, + Stream* stream, const void* host_src, int64 size, + QuantizedActivationMode mode, DeviceMemory<float>* gpu_unquantized_dst) = 0; private: diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc index c8e1d7fa48..600f083840 100644 --- a/tensorflow/stream_executor/dso_loader.cc +++ b/tensorflow/stream_executor/dso_loader.cc @@ -42,11 +42,12 @@ namespace internal { } /* static */ port::Status DsoLoader::GetCudnnDsoHandle(void** dso_handle) { - // libcudnn is versioned differently than the other libraries. See b/22397368 - // for some details about the complications surrounding this. - return GetDsoHandle(FindDsoPath("libcudnn.so.6.5", - "third_party/gpus/cuda/lib64"), - dso_handle); + // libcudnn is versioned differently than the other libraries and may have a + // different version number than other CUDA libraries. See b/22397368 for + // some details about the complications surrounding this. + return GetDsoHandle( + FindDsoPath("libcudnn.so.6.5", "third_party/gpus/cuda/lib64"), + dso_handle); } /* static */ port::Status DsoLoader::GetCufftDsoHandle(void** dso_handle) { @@ -89,16 +90,16 @@ namespace internal { string path_string = path.ToString(); *dso_handle = dlopen(path_string.c_str(), dynload_flags); if (*dso_handle == nullptr) { - LOG(INFO) << "LD_LIBRARY_PATH: " << getenv("LD_LIBRARY_PATH"); + LOG(INFO) << "Couldn't open CUDA library " << path + << ". LD_LIBRARY_PATH: " << getenv("LD_LIBRARY_PATH"); // TODO(b/22689637): Eliminate unnecessary ToString once StrCat has been // moved to the open-sourceable version. return port::Status( port::error::FAILED_PRECONDITION, port::StrCat("could not dlopen DSO: ", path, "; dlerror: ", dlerror())); } - - VLOG(2) << "loaded path \"" << path << "\" " - << (load_kind == LoadKind::kLocal ? "locally" : "globally"); + LOG(INFO) << "successfully opened CUDA library " << path + << (load_kind == LoadKind::kLocal ? " locally" : " globally"); return port::Status::OK(); } diff --git a/tensorflow/stream_executor/event.cc b/tensorflow/stream_executor/event.cc index 5fdd7b9021..5ded7c590b 100644 --- a/tensorflow/stream_executor/event.cc +++ b/tensorflow/stream_executor/event.cc @@ -22,21 +22,10 @@ limitations under the License. namespace perftools { namespace gputools { -internal::EventInterface* CreateEventImplementation( - StreamExecutor* stream_exec) { - PlatformKind platform_kind = stream_exec->platform_kind(); - switch (platform_kind) { - case PlatformKind::kCuda: - return (*internal::MakeCUDAEventImplementation())(stream_exec); - default: - LOG(FATAL) << "Cannot create event implementation for platform kind: " - << PlatformKindString(platform_kind); - } -} - Event::Event(StreamExecutor* stream_exec) - : implementation_(CreateEventImplementation(stream_exec)), - stream_exec_(stream_exec) {} + : stream_exec_(stream_exec), + implementation_( + stream_exec_->implementation()->CreateEventImplementation()) {} Event::~Event() { auto status = stream_exec_->DeallocateEvent(this); diff --git a/tensorflow/stream_executor/event.h b/tensorflow/stream_executor/event.h index 42827e96aa..4b95889547 100644 --- a/tensorflow/stream_executor/event.h +++ b/tensorflow/stream_executor/event.h @@ -18,6 +18,8 @@ limitations under the License. #include <memory> +#include "tensorflow/stream_executor/platform/port.h" + namespace perftools { namespace gputools { @@ -63,13 +65,15 @@ class Event { private: friend class Stream; + // Pointer to the StreamExecutor interface used to create this object. + // Not owned. + StreamExecutor* stream_exec_; + // Pointer to the platform-specific EventInterface implementation underlying // the object. Owned. std::unique_ptr<internal::EventInterface> implementation_; - // Pointer to the StreamExecutor interface used to create this object. - // Not owned. - StreamExecutor* stream_exec_; + SE_DISALLOW_COPY_AND_ASSIGN(Event); }; } // namespace gputools diff --git a/tensorflow/stream_executor/fft.h b/tensorflow/stream_executor/fft.h index 69b49ae92e..505e0e4d13 100644 --- a/tensorflow/stream_executor/fft.h +++ b/tensorflow/stream_executor/fft.h @@ -161,7 +161,7 @@ class FftSupport { // Macro used to quickly declare overrides for abstract virtuals in the // fft::FftSupport base class. Assumes that it's emitted somewhere inside the // ::perftools::gputools namespace. -#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \ +#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \ std::unique_ptr<fft::Plan> Create1dPlan(Stream *stream, uint64 num_x, \ fft::Type type, bool in_place_fft) \ override; \ diff --git a/tensorflow/stream_executor/gcuda.h b/tensorflow/stream_executor/gcuda.h index 4710e30009..4ed0e3d8d0 100644 --- a/tensorflow/stream_executor/gcuda.h +++ b/tensorflow/stream_executor/gcuda.h @@ -87,6 +87,8 @@ enum SharedMemConfig { #include "tensorflow/stream_executor/kernel_cache_config.h" #include "tensorflow/stream_executor/launch_dim.h" #include "tensorflow/stream_executor/machine_manager.h" +#include "tensorflow/stream_executor/multi_platform_manager.h" +#include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/shared_memory_config.h" #include "tensorflow/stream_executor/stream.h" #include "tensorflow/stream_executor/stream_executor.h" diff --git a/tensorflow/stream_executor/kernel.cc b/tensorflow/stream_executor/kernel.cc index ee0b706eef..64a4e6f49e 100644 --- a/tensorflow/stream_executor/kernel.cc +++ b/tensorflow/stream_executor/kernel.cc @@ -25,7 +25,6 @@ limitations under the License. #include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/stream_executor.h" -#include "tensorflow/stream_executor/stream_executor_internal.h" namespace perftools { namespace gputools { @@ -58,29 +57,13 @@ void KernelMetadata::set_shared_memory_bytes(int shared_memory_bytes) { has_shared_memory_bytes_ = true; } -static internal::KernelInterface *KernelImplementationFromPlatformKind( - PlatformKind platform_kind) { - if (platform_kind == PlatformKind::kCuda) { - return (*internal::MakeCUDAKernelImplementation())(); - } else if (platform_kind == PlatformKind::kOpenCL || - platform_kind == PlatformKind::kOpenCLAltera) { - return (*internal::MakeOpenCLKernelImplementation())(); - } else { - LOG(FATAL) << "cannot create kernel implementation for platform kind: " - << PlatformKindString(platform_kind); - } -} - KernelBase::KernelBase(StreamExecutor *parent) - : implementation_( - KernelImplementationFromPlatformKind(parent->platform_kind())), - parent_(parent) { - DCHECK(parent_ != nullptr); -} + : parent_(parent), + implementation_(parent->implementation()->CreateKernelImplementation()) {} KernelBase::KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation) - : implementation_(implementation), parent_(parent) {} + : parent_(parent), implementation_(implementation) {} KernelBase::~KernelBase() {} diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h index 1346fd7bfa..16ffeb2a57 100644 --- a/tensorflow/stream_executor/kernel.h +++ b/tensorflow/stream_executor/kernel.h @@ -141,7 +141,6 @@ class KernelBase { explicit KernelBase(StreamExecutor *parent); // Test-only constructor that can take a mock KernelInterface implementation. - // Takes ownership of implementation, it should not be null. KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation); // Releases resources associated with the kernel instance (i.e. @@ -181,12 +180,12 @@ class KernelBase { const string &demangled_name() const { return demangled_name_; } private: - // Implementation delegated to for platform-specific functionality. - std::unique_ptr<internal::KernelInterface> implementation_; - // The StreamExecutor that loads this kernel object. StreamExecutor *parent_; + // Implementation delegated to for platform-specific functionality. + std::unique_ptr<internal::KernelInterface> implementation_; + string name_; string demangled_name_; diff --git a/tensorflow/stream_executor/lib/casts.h b/tensorflow/stream_executor/lib/casts.h index 3edfcde8bb..19087ff7dc 100644 --- a/tensorflow/stream_executor/lib/casts.h +++ b/tensorflow/stream_executor/lib/casts.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h" + #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_ diff --git a/tensorflow/stream_executor/lib/error.h b/tensorflow/stream_executor/lib/error.h index 368a8af79c..5aac04f525 100644 --- a/tensorflow/stream_executor/lib/error.h +++ b/tensorflow/stream_executor/lib/error.h @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h" + #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_ -#include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/error_codes.pb.h" // IWYU pragma: export namespace perftools { namespace gputools { diff --git a/tensorflow/stream_executor/lib/ptr_util.h b/tensorflow/stream_executor/lib/ptr_util.h index 578ed67ade..e42de83c16 100644 --- a/tensorflow/stream_executor/lib/ptr_util.h +++ b/tensorflow/stream_executor/lib/ptr_util.h @@ -60,4 +60,5 @@ typename MakeUniqueResult<T>::invalid MakeUnique(Args&&... /* args */) = } // namespace gputools } // namespace perftools + #endif // TENSORFLOW_STREAM_EXECUTOR_LIB_PTR_UTIL_H_ diff --git a/tensorflow/stream_executor/lib/status.h b/tensorflow/stream_executor/lib/status.h index 0ec243c38a..af67769253 100644 --- a/tensorflow/stream_executor/lib/status.h +++ b/tensorflow/stream_executor/lib/status.h @@ -13,18 +13,20 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h" + #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_ #include "tensorflow/core/public/status.h" -#include "tensorflow/stream_executor/lib/error.h" +#include "tensorflow/stream_executor/lib/error.h" // IWYU pragma: export #include "tensorflow/stream_executor/platform/logging.h" namespace perftools { namespace gputools { namespace port { -using tensorflow::Status; +using Status = tensorflow::Status; #define SE_CHECK_OK(val) \ CHECK_EQ(::perftools::gputools::port::Status::OK(), (val)) diff --git a/tensorflow/stream_executor/lib/statusor.h b/tensorflow/stream_executor/lib/statusor.h index 6c0fd0f9cd..d9b7787e30 100644 --- a/tensorflow/stream_executor/lib/statusor.h +++ b/tensorflow/stream_executor/lib/statusor.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h" +// // StatusOr<T> is the union of a Status object and a T // object. StatusOr models the concept of an object that is either a // usable value, or an error Status explaining why such a value is diff --git a/tensorflow/stream_executor/lib/strcat.h b/tensorflow/stream_executor/lib/strcat.h index de41b6ac70..a25eca1e9a 100644 --- a/tensorflow/stream_executor/lib/strcat.h +++ b/tensorflow/stream_executor/lib/strcat.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h" + #ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_ #define TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_ diff --git a/tensorflow/stream_executor/platform.cc b/tensorflow/stream_executor/platform.cc index 8ace26f578..dcdcb8f42b 100644 --- a/tensorflow/stream_executor/platform.cc +++ b/tensorflow/stream_executor/platform.cc @@ -31,8 +31,6 @@ string PlatformKindString(PlatformKind kind) { return "CUDA"; case PlatformKind::kOpenCL: return "OpenCL"; - case PlatformKind::kOpenCLAltera: - return "OpenCL+Altera"; case PlatformKind::kHost: return "Host"; case PlatformKind::kMock: diff --git a/tensorflow/stream_executor/platform.h b/tensorflow/stream_executor/platform.h index 338f9940e1..3f850ba327 100644 --- a/tensorflow/stream_executor/platform.h +++ b/tensorflow/stream_executor/platform.h @@ -42,9 +42,6 @@ enum class PlatformKind { kInvalid, kCuda, kOpenCL, - kOpenCLAltera, // Altera FPGA OpenCL platform. - // See documentation: go/fpgaopencl - // (StreamExecutor integration) kHost, kMock, kSize, diff --git a/tensorflow/stream_executor/platform/port.h b/tensorflow/stream_executor/platform/port.h index 33792214f5..c1b770e486 100644 --- a/tensorflow/stream_executor/platform/port.h +++ b/tensorflow/stream_executor/platform/port.h @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h" + #ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_ #define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_ diff --git a/tensorflow/stream_executor/plugin_registry.cc b/tensorflow/stream_executor/plugin_registry.cc index 83cf1cffaf..e05d42985b 100644 --- a/tensorflow/stream_executor/plugin_registry.cc +++ b/tensorflow/stream_executor/plugin_registry.cc @@ -211,7 +211,8 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id, if (plugin_id == kNullPlugin) { \ return port::Status{port::error::FAILED_PRECONDITION, \ "No suitable " PLUGIN_STRING \ - " plugin registered, default or otherwise."}; \ + " plugin registered. Have you linked in a " \ + PLUGIN_STRING "-providing plugin?"}; \ } else { \ VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, " \ << plugin_names_[plugin_id]; \ diff --git a/tensorflow/stream_executor/scratch_allocator.cc b/tensorflow/stream_executor/scratch_allocator.cc new file mode 100644 index 0000000000..c216de3ab9 --- /dev/null +++ b/tensorflow/stream_executor/scratch_allocator.cc @@ -0,0 +1,42 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/stream_executor/scratch_allocator.h" + +#include "tensorflow/stream_executor/lib/status_macros.h" +#include "tensorflow/stream_executor/stream.h" + +namespace perftools { +namespace gputools { + +ScratchAllocator::~ScratchAllocator() {} + +OneTimeScratchAllocator::OneTimeScratchAllocator() {} +OneTimeScratchAllocator::~OneTimeScratchAllocator() {} + +int64 OneTimeScratchAllocator::GetMemoryLimitInBytes(Stream* stream) { + return -1; +} + +port::StatusOr<DeviceMemory<uint8>> OneTimeScratchAllocator::AllocateBytes( + Stream* stream, int64 byte_size) { + CHECK(temporary_ == nullptr); + SE_ASSIGN_OR_RETURN(temporary_, + stream->AllocateTemporaryArray<uint8>(byte_size)); + return temporary_->device_memory(); +} + +} // namespace gputools +} // namespace perftools diff --git a/tensorflow/stream_executor/scratch_allocator.h b/tensorflow/stream_executor/scratch_allocator.h new file mode 100644 index 0000000000..52697d6f8e --- /dev/null +++ b/tensorflow/stream_executor/scratch_allocator.h @@ -0,0 +1,83 @@ +/* Copyright 2015 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_STREAM_EXECUTOR_SCRATCH_ALLOCATOR_H_ +#define TENSORFLOW_STREAM_EXECUTOR_SCRATCH_ALLOCATOR_H_ + +#include <memory> + +#include "tensorflow/stream_executor/device_memory.h" +#include "tensorflow/stream_executor/lib/statusor.h" +#include "tensorflow/stream_executor/platform/port.h" +#include "tensorflow/stream_executor/temporary_device_memory.h" + +namespace perftools { +namespace gputools { + +class Stream; + +// Interface that allows stream operations (e.g. +// Stream::ThenConvolveWithScratch) to optionally request scratch space be +// allocated in order to speed up the operation being enqueued. +// +// Note that the caller is responsible for deallocating the scratch space at a +// known-safe point, when all scratch-memory-consuming kernels are known for +// sure to have finished; e.g. at stream synchronization time. This is different +// from a traditional C++ object allocator, where the client is responsible for +// releasing. (Conceptually, scratch memory is a form of "temporary" device +// memory allocation.) +class ScratchAllocator { + public: + virtual ~ScratchAllocator(); + + // Returns a limit of memory this scratch allocator wants to produce, in + // bytes. This information may be used to help select an algorithm. + // + // Returns values < 0 to indicate that there is no recommended limit. + virtual int64 GetMemoryLimitInBytes(Stream* stream) = 0; + + // Returns an allocation on byte_size bytes for use in an operation on stream. + // + // This is a temporary allocation, and the caller is responsible for + // deallocating at some known-safe point. See the class comment above. + virtual port::StatusOr<DeviceMemory<uint8>> AllocateBytes( + Stream* stream, int64 byte_size) = 0; +}; + +// Allocates a single temporary memory allocation -- this memory is deallocated +// at the next stream synchronization point after this object has gone out of +// scope. This satisfies the lifetime and deallocation properties given in the +// class comment above. +// +// Thread-compatible, but not thread-safe (use in scenarios where only one +// thread will request the scratch allocation). +class OneTimeScratchAllocator : public ScratchAllocator { + public: + OneTimeScratchAllocator(); + ~OneTimeScratchAllocator() override; + int64 GetMemoryLimitInBytes(Stream* stream) override; + port::StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream, + int64 byte_size) override; + + private: + std::unique_ptr<TemporaryDeviceMemory<uint8>> temporary_; + + SE_DISALLOW_COPY_AND_ASSIGN(OneTimeScratchAllocator); +}; + +} // namespace gputools +} // namespace perftools + +#endif // TENSORFLOW_STREAM_EXECUTOR_SCRATCH_ALLOCATOR_H_ diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 5d6971a5a5..587896a2ab 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/stream_executor/platform/port.h" #include "tensorflow/stream_executor/blas.h" +#include "tensorflow/stream_executor/lib/stacktrace.h" #include "tensorflow/stream_executor/lib/strcat.h" #include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform/logging.h" @@ -29,22 +30,6 @@ namespace perftools { namespace gputools { namespace { -static internal::StreamInterface *CreateStreamImplementation( - StreamExecutor *parent) { - PlatformKind platform_kind = parent->platform_kind(); - if (platform_kind == PlatformKind::kCuda) { - return (*internal::MakeCUDAStreamImplementation())(parent); - } else if (platform_kind == PlatformKind::kOpenCL || - platform_kind == PlatformKind::kOpenCLAltera) { - return (*internal::MakeOpenCLStreamImplementation())(parent); - } else if (platform_kind == PlatformKind::kHost) { - return internal::MakeHostStreamImplementation(parent); - } else { - LOG(FATAL) << "cannot create stream implementation for platform kind: " - << PlatformKindString(platform_kind); - } -} - // Code to turn parameters to functions on stream into strings that // will be VLOG'ed. We need overloads, instead of // e.g. BatchDescriptorToVlogString(), as the code that calls these @@ -77,6 +62,10 @@ string ToVlogString(dnn::ElementwiseOperation op) { return dnn::ElementwiseOperationString(op); } +string ToVlogString(dnn::QuantizedActivationMode mode) { + return dnn::QuantizedActivationModeString(mode); +} + string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); } string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); } @@ -123,6 +112,8 @@ string ToVlogString(uint32 i) { return port::StrCat(i); } string ToVlogString(uint64 i) { return port::StrCat(i); } +string ToVlogString(int64 i) { return port::StrCat(i); } + string ToVlogString(float f) { return port::StrCat(f); } string ToVlogString(double d) { return port::StrCat(d); } @@ -181,6 +172,9 @@ string CallStr(const char *function_name, Stream *stream, separator = ", "; } port::StrAppend(&str, ") stream=", ToVlogString(stream)); + if (VLOG_IS_ON(10)) { + port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n"); + } return str; } @@ -206,8 +200,8 @@ string CallStr(const char *function_name, Stream *stream, } // namespace Stream::Stream(StreamExecutor *parent) - : implementation_(CreateStreamImplementation(parent)), - parent_(parent), + : parent_(parent), + implementation_(parent->implementation()->GetStreamImplementation()), allocated_(false), ok_(false), temporary_memory_manager_(this) { @@ -216,8 +210,8 @@ Stream::Stream(StreamExecutor *parent) Stream::Stream(StreamExecutor *parent, internal::StreamInterface *implementation) - : implementation_(implementation), - parent_(parent), + : parent_(parent), + implementation_(implementation), allocated_(false), ok_(false), temporary_memory_manager_(this) { @@ -283,15 +277,15 @@ Stream &Stream::ThenRecordEvent(Event *event) { return *this; } -Stream &Stream::ThenConvolve( - const dnn::BatchDescriptor &batch_descriptor, +Stream &Stream::ThenConvolveWithScratch( + const dnn::BatchDescriptor &input_descriptor, const DeviceMemory<float> &input_data, const dnn::FilterDescriptor &filter_descriptor, const DeviceMemory<float> &filter_data, const dnn::ConvolutionDescriptor &convolution_descriptor, - const dnn::BatchDescriptor &output_descriptor, - DeviceMemory<float> *output) { - VLOG_CALL(PARAM(batch_descriptor), PARAM(input_data), + const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output, + ScratchAllocator *scratch_allocator) { + VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(filter_descriptor), PARAM(filter_data), PARAM(convolution_descriptor), PARAM(output_descriptor), PARAM(output)); @@ -299,8 +293,9 @@ Stream &Stream::ThenConvolve( if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoConvolve( - this, batch_descriptor, input_data, filter_descriptor, filter_data, - convolution_descriptor, output_descriptor, output)); + this, input_descriptor, input_data, filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, output, + /*scratch_allocator=*/scratch_allocator)); } else { SetError(); LOG(WARNING) @@ -311,6 +306,20 @@ Stream &Stream::ThenConvolve( return *this; } +Stream &Stream::ThenConvolve( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory<float> &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<float> &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory<float> *output) { + return ThenConvolveWithScratch(input_descriptor, input_data, + filter_descriptor, filter_data, + convolution_descriptor, output_descriptor, + output, /*scratch_allocator=*/nullptr); +} + Stream &Stream::ThenSeparableConvolve( const dnn::BatchDescriptor &batch_descriptor, const DeviceMemory<float> &input_data, @@ -341,14 +350,15 @@ Stream &Stream::ThenSeparableConvolve( return *this; } -Stream &Stream::ThenConvolveBackwardData( +Stream &Stream::ThenConvolveBackwardDataWithScratch( const dnn::FilterDescriptor &filter_descriptor, const DeviceMemory<float> &filter_data, const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> backward_output_data, const dnn::ConvolutionDescriptor &convolution_descriptor, const dnn::BatchDescriptor &input_descriptor, - DeviceMemory<float> *backward_input_data) { + DeviceMemory<float> *backward_input_data, + ScratchAllocator *scratch_allocator) { VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data), PARAM(output_descriptor), PARAM(backward_output_data), PARAM(convolution_descriptor), PARAM(input_descriptor), @@ -359,7 +369,7 @@ Stream &Stream::ThenConvolveBackwardData( CheckError(dnn->DoConvolveBackwardData( this, filter_descriptor, filter_data, output_descriptor, backward_output_data, convolution_descriptor, input_descriptor, - backward_input_data)); + backward_input_data, scratch_allocator)); } else { SetError(); LOG(WARNING) @@ -370,14 +380,29 @@ Stream &Stream::ThenConvolveBackwardData( return *this; } -Stream &Stream::ThenConvolveBackwardFilter( +Stream &Stream::ThenConvolveBackwardData( + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<float> &filter_data, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory<float> backward_output_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &input_descriptor, + DeviceMemory<float> *backward_input_data) { + return ThenConvolveBackwardDataWithScratch( + filter_descriptor, filter_data, output_descriptor, backward_output_data, + convolution_descriptor, input_descriptor, backward_input_data, + /*scratch_allocator=*/nullptr); +} + +Stream &Stream::ThenConvolveBackwardFilterWithScratch( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory<float> &input_data, const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> backward_output_data, const dnn::ConvolutionDescriptor &convolution_descriptor, const dnn::FilterDescriptor &filter_descriptor, - DeviceMemory<float> *backward_filter_data) { + DeviceMemory<float> *backward_filter_data, + ScratchAllocator *scratch_allocator) { VLOG_CALL(PARAM(input_descriptor), PARAM(input_data), PARAM(output_descriptor), PARAM(backward_output_data), PARAM(convolution_descriptor), PARAM(filter_descriptor), @@ -388,7 +413,7 @@ Stream &Stream::ThenConvolveBackwardFilter( CheckError(dnn->DoConvolveBackwardFilter( this, input_descriptor, input_data, output_descriptor, backward_output_data, convolution_descriptor, filter_descriptor, - backward_filter_data)); + backward_filter_data, scratch_allocator)); } else { SetError(); LOG(WARNING) @@ -399,6 +424,20 @@ Stream &Stream::ThenConvolveBackwardFilter( return *this; } +Stream &Stream::ThenConvolveBackwardFilter( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory<float> &input_data, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory<float> backward_output_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::FilterDescriptor &filter_descriptor, + DeviceMemory<float> *backward_filter_data) { + return ThenConvolveBackwardFilterWithScratch( + input_descriptor, input_data, output_descriptor, backward_output_data, + convolution_descriptor, filter_descriptor, backward_filter_data, + /*scratch_allocator=*/nullptr); +} + Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data, const DeviceMemory<float> &weights, const dnn::BatchDescriptor &input_dimensions, @@ -589,6 +628,19 @@ Stream &Stream::ThenDepthConcatenate( DeviceMemory<float> *output_data) { VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data)); + for (size_t i = 1; i < input_dimensions.size(); ++i) { + if (input_dimensions[i].count() != input_dimensions[0].count() || + input_dimensions[i].height() != input_dimensions[0].height() || + input_dimensions[i].width() != input_dimensions[0].width()) { + SetError(); + LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n" + << "input_dimensions[0]: " << input_dimensions[0].ToString() + << "input_dimensions[" << i + << "]: " << input_dimensions[i].ToString(); + return *this; + } + } + if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data, @@ -627,15 +679,18 @@ Stream &Stream::ThenElementwiseOperate( return *this; } -Stream &Stream::ThenMemcpyD2HQuantized( - const DeviceMemory<float> &gpu_unquantized_src, - port::MutableArraySlice<uint8> host_dst) { - VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(host_dst)); +Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions, + const DeviceMemory<float> &input_data, int64 left_pad, + int64 right_pad, int64 top_pad, int64 bottom_pad, + DeviceMemory<float> *output_data) { + VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad), + PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad), + PARAM(output_data)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError( - dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, host_dst)); + CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad, + top_pad, bottom_pad, output_data)); } else { SetError(); LOG(WARNING) @@ -646,15 +701,20 @@ Stream &Stream::ThenMemcpyD2HQuantized( return *this; } -Stream &Stream::ThenMemcpyD2HQuantized( - const DeviceMemory<float> &gpu_unquantized_src, - port::MutableArraySlice<uint16> host_dst) { - VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(host_dst)); +Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions, + const DeviceMemory<float> &input_data, + int64 left_trim, int64 right_trim, int64 top_trim, + int64 bottom_trim, + DeviceMemory<float> *output_data) { + VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim), + PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim), + PARAM(output_data)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError( - dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, host_dst)); + CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim, + right_trim, top_trim, bottom_trim, + output_data)); } else { SetError(); LOG(WARNING) @@ -667,13 +727,14 @@ Stream &Stream::ThenMemcpyD2HQuantized( Stream &Stream::ThenMemcpyD2HQuantized( const DeviceMemory<float> &gpu_unquantized_src, - port::MutableArraySlice<int32> host_dst) { - VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(host_dst)); + dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) { + VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst), + PARAM(size)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError( - dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, host_dst)); + CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode, + host_dst, size)); } else { SetError(); LOG(WARNING) @@ -685,14 +746,15 @@ Stream &Stream::ThenMemcpyD2HQuantized( } Stream &Stream::ThenMemcpyH2DQuantized( - port::ArraySlice<uint8> host_src, + const void *host_src, uint64 size, dnn::QuantizedActivationMode mode, DeviceMemory<float> *gpu_unquantized_dst) { - VLOG_CALL(PARAM(host_src), PARAM(gpu_unquantized_dst)); + VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode), + PARAM(gpu_unquantized_dst)); if (ok()) { if (dnn::DnnSupport *dnn = parent_->AsDnn()) { - CheckError( - dnn->DoMemcpyH2DQuantized(this, host_src, gpu_unquantized_dst)); + CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode, + gpu_unquantized_dst)); } else { SetError(); LOG(WARNING) diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index 3b9c2af1b0..d91c62ca26 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -69,6 +69,11 @@ struct ConvolutionDescriptor; } // namespace dnn class StreamExecutor; +class ScratchAllocator; + +// Convert a type to the corresponding QuantizedActivationMode. +template <typename ElementType> +struct Quantization; // Represents a stream of dependent computations on a GPU device. // @@ -214,6 +219,15 @@ class Stream { const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output); + Stream &ThenConvolveWithScratch( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory<float> &input_data, + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<float> &filter_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory<float> *output, ScratchAllocator *scratch_allocator); + Stream &ThenSeparableConvolve( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory<float> &input_data, @@ -233,6 +247,16 @@ class Stream { const dnn::BatchDescriptor &input_descriptor, DeviceMemory<float> *backward_input_data); + Stream &ThenConvolveBackwardDataWithScratch( + const dnn::FilterDescriptor &filter_descriptor, + const DeviceMemory<float> &filter_data, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory<float> backward_output_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::BatchDescriptor &input_descriptor, + DeviceMemory<float> *backward_input_data, + ScratchAllocator *scratch_allocator); + Stream &ThenConvolveBackwardFilter( const dnn::BatchDescriptor &input_descriptor, const DeviceMemory<float> &input_data, @@ -242,6 +266,16 @@ class Stream { const dnn::FilterDescriptor &filter_descriptor, DeviceMemory<float> *backward_filter_data); + Stream &ThenConvolveBackwardFilterWithScratch( + const dnn::BatchDescriptor &input_descriptor, + const DeviceMemory<float> &input_data, + const dnn::BatchDescriptor &output_descriptor, + DeviceMemory<float> backward_output_data, + const dnn::ConvolutionDescriptor &convolution_descriptor, + const dnn::FilterDescriptor &filter_descriptor, + DeviceMemory<float> *backward_filter_data, + ScratchAllocator *scratch_allocator); + Stream &ThenMatMul(const DeviceMemory<float> &input_data, const DeviceMemory<float> &weights, const dnn::BatchDescriptor &input_dimensions, @@ -249,18 +283,18 @@ class Stream { DeviceMemory<float> *output_data); Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data, - const DeviceMemory<int8> &weights, - const DeviceMemory<float> &weight_scales, - const dnn::BatchDescriptor &input_dimensions, - const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<float> *output_data); + const DeviceMemory<int8> &weights, + const DeviceMemory<float> &weight_scales, + const dnn::BatchDescriptor &input_dimensions, + const dnn::BatchDescriptor &output_dimensions, + DeviceMemory<float> *output_data); Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data, - const DeviceMemory<int16> &weights, - const DeviceMemory<float> &weight_scales, - const dnn::BatchDescriptor &input_dimensions, - const dnn::BatchDescriptor &output_dimensions, - DeviceMemory<float> *output_data); + const DeviceMemory<int16> &weights, + const DeviceMemory<float> &weight_scales, + const dnn::BatchDescriptor &input_dimensions, + const dnn::BatchDescriptor &output_dimensions, + DeviceMemory<float> *output_data); Stream &ThenBiasAdd(const DeviceMemory<float> &input_data, const DeviceMemory<float> &biases, @@ -302,24 +336,49 @@ class Stream { const dnn::BatchDescriptor &output_dimensions, DeviceMemory<float> *output_data); - // See DnnSupport::DoMemcpyD2HQuantized. - // TODO(wgulland) Use a template to merge the versions of - // ThenMemcpyD2HQuantized. - Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src, - port::MutableArraySlice<uint8> host_dst); + Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions, + const DeviceMemory<float> &input_data, int64 left_pad, + int64 right_pad, int64 top_pad, int64 bottom_pad, + DeviceMemory<float> *output_data); - // See DnnSupport::DoMemcpyD2HQuantized. - Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src, - port::MutableArraySlice<uint16> host_dst); + Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions, + const DeviceMemory<float> &input_data, int64 left_trim, + int64 right_trim, int64 top_trim, int64 bottom_trim, + DeviceMemory<float> *output_data); // See DnnSupport::DoMemcpyD2HQuantized. Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src, - port::MutableArraySlice<int32> host_dst); + dnn::QuantizedActivationMode mode, + void *host_dst, uint64 size); + + // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice + // and uses the Quantization trait to call the generic version of + // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode. + template <typename ElementType> + Stream &ThenMemcpyD2HQuantized( + const DeviceMemory<float> &gpu_unquantized_src, + port::MutableArraySlice<ElementType> host_dst) { + return ThenMemcpyD2HQuantized( + gpu_unquantized_src, Quantization<ElementType>::kModeId, + host_dst.data(), host_dst.size() * sizeof(ElementType)); + } // See DnnSupport::DoMemcpyH2DQuantized. - Stream &ThenMemcpyH2DQuantized(port::ArraySlice<uint8> host_src, + Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64 size, + dnn::QuantizedActivationMode mode, DeviceMemory<float> *gpu_unquantized_dst); + // Template version of ThenMemcpyH2DQuantized that takes an ArraySlice + // and uses the Quantization trait to call the generic version of + // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode. + template <typename ElementType> + Stream &ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src, + DeviceMemory<float> *gpu_unquantized_dst) { + return ThenMemcpyH2DQuantized( + host_src.data(), host_src.size() * sizeof(ElementType), + Quantization<ElementType>::kModeId, gpu_unquantized_dst); + } + ///////////////// // BLAS support @@ -1143,9 +1202,11 @@ class Stream { Stream &ThenMemset32(DeviceMemoryBase *location, const uint32 &pattern, uint64 size); - // (Synchronously) block the host code waiting for the operations entrained - // on - // the stream (enqueued to this point in program execution) to complete. + // (Synchronously) block the host code waiting for the operations + // entrained on the stream (enqueued to this point in program + // execution) to complete. + // + // Returns true if the stream is ok(). bool BlockHostUntilDone(); // Warning! This method interacts with internal threads in @@ -1195,9 +1256,9 @@ class Stream { internal::TemporaryMemoryManager *temporary_memory_manager(); private: - friend class host::HostBlas; // for parent_. - friend class host::HostFft; // for parent_. - friend class host::HostRng; // for parent_. + friend class host::HostBlas; // for parent_. + friend class host::HostFft; // for parent_. + friend class host::HostRng; // for parent_. template <typename... Args> friend struct ThenBlasImpl; // for implementing ThenBlasXXX. friend class ocl::CLBlas; // for parent_. @@ -1219,13 +1280,13 @@ class Stream { void SetError() { CheckError(false /* = operation_retcode */); } + // The StreamExecutor that supports the operation of this stream. + StreamExecutor *parent_; + // The platform-dependent implementation that the StreamExecutor interface // delegates to. std::unique_ptr<internal::StreamInterface> implementation_; - // The StreamExecutor that supports the operation of this stream. - StreamExecutor *parent_; - // mutex that guards the allocation / error state flags. // Mutable so that it can be obtained via const reader lock. mutable mutex mu_; @@ -1267,6 +1328,24 @@ inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() { return &temporary_memory_manager_; } +template <> +struct Quantization<uint8> { + static constexpr dnn::QuantizedActivationMode kModeId = + dnn::QuantizedActivationMode::k8Bit; +}; + +template <> +struct Quantization<uint16> { + static constexpr dnn::QuantizedActivationMode kModeId = + dnn::QuantizedActivationMode::k16Bit; +}; + +template <> +struct Quantization<int32> { + static constexpr dnn::QuantizedActivationMode kModeId = + dnn::QuantizedActivationMode::k32Bit; +}; + } // namespace gputools } // namespace perftools diff --git a/tensorflow/stream_executor/stream_executor.h b/tensorflow/stream_executor/stream_executor.h index b668f2ad0d..0eb113c056 100644 --- a/tensorflow/stream_executor/stream_executor.h +++ b/tensorflow/stream_executor/stream_executor.h @@ -57,9 +57,10 @@ limitations under the License. #include "tensorflow/stream_executor/kernel.h" // IWYU pragma: export #include "tensorflow/stream_executor/kernel_spec.h" // IWYU pragma: export #include "tensorflow/stream_executor/launch_dim.h" // IWYU pragma: export +#include "tensorflow/stream_executor/multi_platform_manager.h" // IWYU pragma: export #include "tensorflow/stream_executor/platform.h" // IWYU pragma: export #include "tensorflow/stream_executor/stream.h" // IWYU pragma: export #include "tensorflow/stream_executor/stream_executor_pimpl.h" // IWYU pragma: export -#include "tensorflow/stream_executor/timer.h" // IWYU pragma: export +#include "tensorflow/stream_executor/timer.h" // IWYU pragma: export #endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_H_ diff --git a/tensorflow/stream_executor/stream_executor_internal.cc b/tensorflow/stream_executor/stream_executor_internal.cc index b3bb818f5f..ddc9ae0441 100644 --- a/tensorflow/stream_executor/stream_executor_internal.cc +++ b/tensorflow/stream_executor/stream_executor_internal.cc @@ -28,22 +28,6 @@ StreamExecutorFactory* MakeCUDAExecutorImplementation() { static StreamExecutorFactory instance; return &instance; } -EventFactory* MakeCUDAEventImplementation() { - static EventFactory instance; - return &instance; -} -StreamFactory* MakeCUDAStreamImplementation() { - static StreamFactory instance; - return &instance; -} -TimerFactory* MakeCUDATimerImplementation() { - static TimerFactory instance; - return &instance; -} -KernelFactory* MakeCUDAKernelImplementation() { - static KernelFactory instance; - return &instance; -} // -- OpenCL @@ -51,28 +35,10 @@ StreamExecutorFactory* MakeOpenCLExecutorImplementation() { static StreamExecutorFactory instance; return &instance; } -StreamExecutorFactory* MakeOpenCLAlteraExecutorImplementation() { - static StreamExecutorFactory instance; - return &instance; -} -StreamFactory* MakeOpenCLStreamImplementation() { - static StreamFactory instance; - return &instance; -} -TimerFactory* MakeOpenCLTimerImplementation() { - static TimerFactory instance; - return &instance; -} -KernelFactory* MakeOpenCLKernelImplementation() { - static KernelFactory instance; - return &instance; -} // -- Host StreamExecutorFactory MakeHostExecutorImplementation; -StreamFactory MakeHostStreamImplementation; -TimerFactory MakeHostTimerImplementation; } // namespace internal diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index 955af9127f..dff756c8fc 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -71,6 +71,94 @@ namespace perftools { namespace gputools { namespace internal { +// Platform-dependent interface class for the generic Events interface, in +// the PIMPL style. +class EventInterface { + public: + EventInterface() {} + virtual ~EventInterface() {} + + private: + SE_DISALLOW_COPY_AND_ASSIGN(EventInterface); +}; + +// Pointer-to-implementation object type (i.e. the KernelBase class delegates to +// this interface) with virtual destruction. This class exists for the +// platform-dependent code to hang any kernel data/resource info/functionality +// off of. +class KernelInterface { + public: + // Default constructor for the abstract interface. + KernelInterface() {} + + // Default destructor for the abstract interface. + virtual ~KernelInterface() {} + + // Returns the number of formal parameters that this kernel accepts. + virtual unsigned Arity() const = 0; + + // Sets the preferred cache configuration. + virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0; + + // Gets the preferred cache configuration. + virtual KernelCacheConfig GetPreferredCacheConfig() const = 0; + + private: + SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface); +}; + +// Pointer-to-implementation object type (i.e. the Stream class delegates to +// this interface) with virtual destruction. This class exists for the +// platform-dependent code to hang any kernel data/resource info/functionality +// off of. +class StreamInterface { + public: + // Default constructor for the abstract interface. + StreamInterface() {} + + // Default destructor for the abstract interface. + virtual ~StreamInterface() {} + + // Returns the CUDA stream associated with this platform's stream + // implementation. + // + // WARNING: checks that the underlying platform is, in fact, CUDA, causing a + // fatal error if it is not. This hack is made available solely for use from + // distbelief code, which temporarily has strong ties to CUDA as a platform. + virtual void *CudaStreamHack() { return nullptr; } + + // See the above comment on CudaStreamHack -- this further breaks abstraction + // for Eigen within distbelief, which has strong ties to CUDA as a platform, + // and a historical attachment to a programming model which takes a + // stream-slot rather than a stream-value. + virtual void **CudaStreamMemberHack() { return nullptr; } + + private: + SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface); +}; + +// Pointer-to-implementation object type (i.e. the Timer class delegates to +// this interface) with virtual destruction. This class exists for the +// platform-dependent code to hang any timer data/resource info/functionality +// off of. +class TimerInterface { + public: + // Default constructor for the abstract interface. + TimerInterface() {} + + // Default destructor for the abstract interface. + virtual ~TimerInterface() {} + + // Returns the number of microseconds elapsed in a completed timer. + virtual uint64 Microseconds() const = 0; + + // Returns the number of nanoseconds elapsed in a completed timer. + virtual uint64 Nanoseconds() const = 0; + + private: + SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface); +}; + // Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL). // // Various platforms will provide an implementation that satisfy this interface. @@ -89,6 +177,7 @@ class StreamExecutorInterface { // See the StreamExecutor interface for comments on the same-named methods. virtual port::Status Init(int device_ordinal, DeviceOptions device_options) = 0; + virtual bool GetKernel(const MultiKernelLoaderSpec &spec, KernelBase *kernel) { return false; @@ -233,9 +322,13 @@ class StreamExecutorInterface { // initialization fails. virtual dnn::DnnSupport *CreateDnn() { return nullptr; } - // Please read the warning below. This method is only temporary. See - // http://b/15759750 - // + // Each call creates a new instance of the platform-specific implementation of + // the corresponding interface type. + virtual std::unique_ptr<EventInterface> CreateEventImplementation() = 0; + virtual std::unique_ptr<KernelInterface> CreateKernelImplementation() = 0; + virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0; + virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0; + // Returns the CUDA context associated with this StreamExecutor platform // implementation. // @@ -248,106 +341,6 @@ class StreamExecutorInterface { SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface); }; -// Pointer-to-implementation object type (i.e. the KernelBase class delegates to -// this interface) with virtual destruction. This class exists for the -// platform-dependent code to hang any kernel data/resource info/functionality -// off of. -class KernelInterface { - public: - // Default constructor for the abstract interface. - KernelInterface() {} - - // Default destructor for the abstract interface. - virtual ~KernelInterface() {} - - // Returns the number of formal parameters that this kernel accepts. - virtual unsigned Arity() const = 0; - - // Sets the preferred cache configuration. - virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0; - - // Gets the preferred cache configuration. - virtual KernelCacheConfig GetPreferredCacheConfig() const = 0; - - private: - SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface); -}; - -// Platform-dependent interface class for the generic Events interface, in -// the PIMPL style. -class EventInterface { - public: - EventInterface() {} - virtual ~EventInterface() {} - - private: - SE_DISALLOW_COPY_AND_ASSIGN(EventInterface); -}; - -// Pointer-to-implementation object type (i.e. the Stream class delegates to -// this interface) with virtual destruction. This class exists for the -// platform-dependent code to hang any kernel data/resource info/functionality -// off of. -class StreamInterface { - public: - // Default constructor for the abstract interface. - StreamInterface() {} - - // Default destructor for the abstract interface. - virtual ~StreamInterface() {} - - // Please read the warning below. This method is only temporary. See - // http://b/15759750 - // - // Returns the CUDA stream associated with this platform's stream - // implementation. - // - // WARNING: checks that the underlying platform is, in fact, CUDA, causing a - // fatal error if it is not. This hack is made available solely for use from - // distbelief code, which temporarily has strong ties to CUDA as a platform. - virtual void *CudaStreamHack() { return nullptr; } - - // Please read the warning above. This method is only temporary. See - // http://b/15759750 - // - // See the above comment on CudaStreamHack -- this further breaks abstraction - // for Eigen within distbelief, which has strong ties to CUDA as a platform, - // and a historical attachment to a programming model which takes a - // stream-slot rather than a stream-value. - virtual void **CudaStreamMemberHack() { return nullptr; } - - private: - SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface); -}; - -// Pointer-to-implementation object type (i.e. the Timer class delegates to -// this interface) with virtual destruction. This class exists for the -// platform-dependent code to hang any timer data/resource info/functionality -// off of. -class TimerInterface { - public: - // Default constructor for the abstract interface. - TimerInterface() {} - - // Default destructor for the abstract interface. - virtual ~TimerInterface() {} - - // Returns the number of microseconds elapsed in a completed timer. - virtual uint64 Microseconds() const = 0; - - // Returns the number of nanoseconds elapsed in a completed timer. - virtual uint64 Nanoseconds() const = 0; - - private: - SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface); -}; - -// Extern functions for constructing platform-specific instances that conform to -// the StreamExecutor interface. (Defining constructor functions extern in this -// way prevents CUDA/OpenCL headers from leaking into any shared header files.) -// -// TODO(leary) switch this all over to registries. - using StreamExecutorFactory = std::function<StreamExecutorInterface *(const PluginConfig &)>; using EventFactory = std::function<EventInterface *(StreamExecutor *)>; @@ -355,21 +348,11 @@ using StreamFactory = std::function<StreamInterface *(StreamExecutor *)>; using TimerFactory = std::function<TimerInterface *(StreamExecutor *)>; using KernelFactory = std::function<KernelInterface*()>; -EventFactory* MakeCUDAEventImplementation(); StreamExecutorFactory* MakeCUDAExecutorImplementation(); -StreamFactory* MakeCUDAStreamImplementation(); -TimerFactory* MakeCUDATimerImplementation(); -KernelFactory* MakeCUDAKernelImplementation(); StreamExecutorFactory* MakeOpenCLExecutorImplementation(); -StreamExecutorFactory* MakeOpenCLAlteraExecutorImplementation(); -StreamFactory* MakeOpenCLStreamImplementation(); -TimerFactory* MakeOpenCLTimerImplementation(); -KernelFactory* MakeOpenCLKernelImplementation(); extern StreamExecutorFactory MakeHostExecutorImplementation; -extern StreamFactory MakeHostStreamImplementation; -extern TimerFactory MakeHostTimerImplementation; } // namespace internal diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc index e496deaf9d..acaa0efcb2 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.cc +++ b/tensorflow/stream_executor/stream_executor_pimpl.cc @@ -26,6 +26,8 @@ limitations under the License. #include "tensorflow/stream_executor/lib/env.h" #include "tensorflow/stream_executor/lib/error.h" #include "tensorflow/stream_executor/lib/notification.h" +#include "tensorflow/stream_executor/lib/stacktrace.h" +#include "tensorflow/stream_executor/lib/str_util.h" #include "tensorflow/stream_executor/lib/stringprintf.h" #include "tensorflow/stream_executor/lib/threadpool.h" #include "tensorflow/stream_executor/platform/port.h" @@ -40,6 +42,14 @@ namespace perftools { namespace gputools { namespace { +string StackTraceIfVLOG10() { + if (VLOG_IS_ON(10)) { + return port::StrCat(" ", port::CurrentStackTrace(), "\n"); + } else { + return ""; + } +} + // Maximum stack depth to report when generating backtrace on mem allocation // (for GPU memory leak checker) static const int kMaxStackDepth = 256; @@ -66,9 +76,6 @@ internal::StreamExecutorInterface *StreamExecutorImplementationFromPlatformKind( case PlatformKind::kOpenCL: factory = *internal::MakeOpenCLExecutorImplementation(); break; - case PlatformKind::kOpenCLAltera: - factory = *internal::MakeOpenCLAlteraExecutorImplementation(); - break; case PlatformKind::kHost: factory = internal::MakeHostExecutorImplementation; break; @@ -148,7 +155,8 @@ MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call, StreamExecutor::StreamExecutor(PlatformKind platform_kind, const PluginConfig &plugin_config) - : implementation_(StreamExecutorImplementationFromPlatformKind( + : platform_(nullptr), + implementation_(StreamExecutorImplementationFromPlatformKind( platform_kind, plugin_config)), platform_kind_(platform_kind), device_ordinal_(-1), @@ -160,16 +168,21 @@ StreamExecutor::StreamExecutor(PlatformKind platform_kind, } StreamExecutor::StreamExecutor( - PlatformKind platform_kind, - internal::StreamExecutorInterface *implementation) - : implementation_(implementation), - platform_kind_(platform_kind), + const Platform *platform, internal::StreamExecutorInterface *implementation) + : platform_(platform), + implementation_(implementation), device_ordinal_(-1), background_threads_(new port::ThreadPool( port::Env::Default(), "stream_executor", kNumBackgroundThreads)), live_stream_count_(0), tracing_enabled_(false) { - CheckPlatformKindIsValid(platform_kind); + if (port::Lowercase(platform_->Name()) == "cuda") { + platform_kind_ = PlatformKind::kCuda; + } else if (port::Lowercase(platform_->Name()) == "opencl") { + platform_kind_ = PlatformKind::kOpenCL; + } else if (port::Lowercase(platform_->Name()) == "host") { + platform_kind_ = PlatformKind::kHost; + } } StreamExecutor::~StreamExecutor() { @@ -208,7 +221,7 @@ bool StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec, void StreamExecutor::Deallocate(DeviceMemoryBase *mem) { VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque() - << ") mem->size()=" << mem->size(); + << ") mem->size()=" << mem->size() << StackTraceIfVLOG10(); if (mem->opaque() != nullptr) { EraseAllocRecord(mem->opaque()); @@ -333,8 +346,8 @@ bool StreamExecutor::BlockHostUntilDone(Stream *stream) { void *StreamExecutor::Allocate(uint64 size) { void *buf = implementation_->Allocate(size); - VLOG(1) << "Called StreamExecutor::Allocate(size=" << size - << ") returns " << buf; + VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns " + << buf << StackTraceIfVLOG10(); CreateAllocRecord(buf, size); return buf; @@ -348,20 +361,20 @@ bool StreamExecutor::GetSymbol(const string &symbol_name, void **mem, void *StreamExecutor::HostMemoryAllocate(uint64 size) { void *buffer = implementation_->HostMemoryAllocate(size); VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size - << ") returns " << buffer; + << ") returns " << buffer << StackTraceIfVLOG10(); return buffer; } void StreamExecutor::HostMemoryDeallocate(void *location) { - VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" - << location << ")"; + VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location + << ")" << StackTraceIfVLOG10(); return implementation_->HostMemoryDeallocate(location); } bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) { VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location - << ", size=" << size << ")"; + << ", size=" << size << ")" << StackTraceIfVLOG10(); if (location == nullptr || size == 0) { LOG(WARNING) << "attempting to register null or zero-sized memory: " << location << "; size " << size; @@ -371,12 +384,13 @@ bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) { bool StreamExecutor::HostMemoryUnregister(void *location) { VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location - << ")"; + << ")" << StackTraceIfVLOG10(); return implementation_->HostMemoryUnregister(location); } bool StreamExecutor::SynchronizeAllActivity() { - VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"; + VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()" + << StackTraceIfVLOG10(); bool ok = implementation_->SynchronizeAllActivity(); // This should all be quick and infallible work, so we can perform the @@ -388,16 +402,17 @@ bool StreamExecutor::SynchronizeAllActivity() { bool StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location, uint64 size) { - VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" - << location << ", size=" << size << ")"; + VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location + << ", size=" << size << ")" << StackTraceIfVLOG10(); return implementation_->SynchronousMemZero(location, size); } bool StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value, uint64 size) { - VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" - << location << ", value=" << value << ", size=" << size << ")"; + VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location + << ", value=" << value << ", size=" << size << ")" + << StackTraceIfVLOG10(); return implementation_->SynchronousMemSet(location, value, size); } @@ -406,7 +421,7 @@ bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst, const void *host_src, uint64 size) { VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(gpu_dst=" << gpu_dst->opaque() << ", host_src=" << host_src << ", size=" << size - << ") H2D"; + << ") H2D" << StackTraceIfVLOG10(); // Tracing overloaded methods is very difficult due to issues with type // inference on template args. Since use of these overloaded methods is @@ -417,9 +432,9 @@ bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst, bool StreamExecutor::SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &gpu_src, uint64 size) { - VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" - << host_dst << ", gpu_src=" << gpu_src.opaque() << ", size=" << size - << ") D2H"; + VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst + << ", gpu_src=" << gpu_src.opaque() << ", size=" << size << ") D2H" + << StackTraceIfVLOG10(); return implementation_->SynchronousMemcpy(host_dst, gpu_src, size); } @@ -428,8 +443,8 @@ bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst, const DeviceMemoryBase &gpu_src, uint64 size) { VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(gpu_dst=" - << gpu_dst->opaque() << ", gpu_src=" << gpu_src.opaque() << ", size=" << size - << ") D2D"; + << gpu_dst->opaque() << ", gpu_src=" << gpu_src.opaque() + << ", size=" << size << ") D2D" << StackTraceIfVLOG10(); return implementation_->SynchronousMemcpyDeviceToDevice(gpu_dst, gpu_src, size); @@ -438,7 +453,8 @@ bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst, port::Status StreamExecutor::SynchronousMemcpyD2H( const DeviceMemoryBase &gpu_src, int64 size, void *host_dst) { VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(gpu_src=" - << gpu_src.opaque() << ", size=" << size << ", host_dst=" << host_dst << ")"; + << gpu_src.opaque() << ", size=" << size << ", host_dst=" << host_dst + << ")" << StackTraceIfVLOG10(); port::Status result{port::Status::OK()}; SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H, @@ -459,8 +475,9 @@ port::Status StreamExecutor::SynchronousMemcpyD2H( port::Status StreamExecutor::SynchronousMemcpyH2D(const void *host_src, int64 size, DeviceMemoryBase *gpu_dst) { - VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" - << host_src << ", size=" << size << ", gpu_dst" << gpu_dst->opaque() << ")"; + VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src + << ", size=" << size << ", gpu_dst" << gpu_dst->opaque() << ")" + << StackTraceIfVLOG10(); port::Status result{port::Status::OK()}; SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D, diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h index 5a59a5a26d..f624e0fcdb 100644 --- a/tensorflow/stream_executor/stream_executor_pimpl.h +++ b/tensorflow/stream_executor/stream_executor_pimpl.h @@ -71,9 +71,7 @@ class StreamExecutor { public: explicit StreamExecutor(PlatformKind kind, const PluginConfig &plugin_config = PluginConfig()); - - // Primarily used for testing. - StreamExecutor(PlatformKind kind, + StreamExecutor(const Platform *platform, internal::StreamExecutorInterface *implementation); ~StreamExecutor(); @@ -81,9 +79,13 @@ class StreamExecutor { port::Status Init(); port::Status Init(int device_ordinal, DeviceOptions device_options); + // DEPRECATED: Do not use; use platform() instead. // Returns the platform that this StreamExecutor is acting upon. PlatformKind platform_kind() const { return platform_kind_; } + // Returns a reference to the platform that created this executor. + const Platform *platform() const { return platform_; } + // Retrieves (loads) a kernel for the platform this StreamExecutor is acting // upon, if one exists. // @@ -538,15 +540,18 @@ class StreamExecutor { // can acquire the lock on their first (mutating) call as well. mutable mutex mu_; - // A mapping of pointer (to GPU memory) to string representation of the stack - // (of the allocating thread) at the time at which the pointer was allocated. - std::map<void *, AllocRecord> mem_allocs_ GUARDED_BY(mu_); + // Reference to the platform that created this executor. + const Platform *platform_; // Pointer to the platform-specific-interface implementation. This is // delegated to by the interface routines in pointer-to-implementation // fashion. std::unique_ptr<internal::StreamExecutorInterface> implementation_; + // A mapping of pointer (to GPU memory) to string representation of the stack + // (of the allocating thread) at the time at which the pointer was allocated. + std::map<void *, AllocRecord> mem_allocs_ GUARDED_BY(mu_); + // Memoized BLAS support object -- we only want to create this once when asked // for a BLAS interface. std::unique_ptr<blas::BlasSupport> blas_ GUARDED_BY(mu_); diff --git a/tensorflow/stream_executor/timer.cc b/tensorflow/stream_executor/timer.cc index 6926028dbb..62fe1c9f64 100644 --- a/tensorflow/stream_executor/timer.cc +++ b/tensorflow/stream_executor/timer.cc @@ -20,31 +20,13 @@ limitations under the License. #include "tensorflow/stream_executor/platform.h" #include "tensorflow/stream_executor/platform/logging.h" #include "tensorflow/stream_executor/stream_executor.h" -#include "tensorflow/stream_executor/stream_executor_internal.h" namespace perftools { namespace gputools { -static internal::TimerInterface *CreateTimerImplementation( - StreamExecutor *parent) { - PlatformKind platform_kind = parent->platform_kind(); - if (platform_kind == PlatformKind::kCuda) { - return (*internal::MakeCUDATimerImplementation())(parent); - } else if (platform_kind == PlatformKind::kOpenCL || - platform_kind == PlatformKind::kOpenCLAltera) { - return (*internal::MakeOpenCLTimerImplementation())(parent); - } else if (platform_kind == PlatformKind::kHost) { - return internal::MakeHostTimerImplementation(parent); - } else if (platform_kind == PlatformKind::kMock) { - return nullptr; - } else { - LOG(FATAL) << "cannot create timer implementation for platform kind: " - << PlatformKindString(platform_kind); - } -} - Timer::Timer(StreamExecutor *parent) - : implementation_(CreateTimerImplementation(parent)), parent_(parent) {} + : parent_(parent), + implementation_(parent_->implementation()->GetTimerImplementation()) {} Timer::~Timer() { parent_->DeallocateTimer(this); } diff --git a/tensorflow/stream_executor/timer.h b/tensorflow/stream_executor/timer.h index 4da63c86b3..c39048d70b 100644 --- a/tensorflow/stream_executor/timer.h +++ b/tensorflow/stream_executor/timer.h @@ -58,14 +58,14 @@ class Timer { internal::TimerInterface *implementation() { return implementation_.get(); } private: - // Platform-dependent implementation of the timer internals for the underlying - // platform. This class just delegates to this opaque instance. - std::unique_ptr<internal::TimerInterface> implementation_; - // The StreamExecutor that manages the platform-specific internals for this // timer. StreamExecutor *parent_; + // Platform-dependent implementation of the timer internals for the underlying + // platform. This class just delegates to this opaque instance. + std::unique_ptr<internal::TimerInterface> implementation_; + SE_DISALLOW_COPY_AND_ASSIGN(Timer); }; |