path: root/tensorflow/stream_executor
author    Vijay Vasudevan <vrv@google.com>  2015-12-08 09:58:59 -0800
committer Vijay Vasudevan <vrv@google.com>  2015-12-08 09:58:59 -0800
commit    ddd4aaf5286de24ba70402ee0ec8b836d3aed8c7 (patch)
tree      4efdf6cf4d69b45041fd2a02cd2b7327ea9f1f58 /tensorflow/stream_executor
parent    cd53f3c3302c9312c1840389a9988a879b8b9dd5 (diff)
TensorFlow: upstream changes to git.
Change 109695551: Update FAQ.
Change 109694725: Add a gradient for resize_bilinear op.
Change 109694505: Don't mention the variables module in docs; variables.Variable should be tf.Variable.
Change 109658848: Add an option to create a new thread-pool for each session.
Change 109640570: Take the snapshot of stream-executor. Expose an interface for scratch space allocation in the interface.
Change 109638559: Let image_summary accept uint8 input. This allows users to do their own normalization / scaling if the default (very weird) behavior of image_summary is undesired. This required a slight tweak to fake_input.cc to make polymorphically typed fake inputs infer their type attr when it is not set but has a default. Unfortunately, adding a second valid type to image_summary *disables* automatic implicit conversion from np.float64 to tf.float32, so this change is slightly backwards incompatible.
Change 109636969: Add serialization operations for SparseTensor.
Change 109636644: Update generated Op docs.
Change 109634899: Add a markdown file for producing release notes for our releases; seed it with 0.5.0 and a boring but accurate description.
Change 109634502: Let histogram_summary take any realnumbertype. It used to take only floats; now it also understands ints.
Change 109634434: Update the locations where we mention Python 3 support to reflect the current truth.
Change 109632108: Move HSV <-> RGB conversions, grayscale conversions, and adjust_* ops back to TensorFlow:
  - make GPU-capable versions of RGBToHSV and HSVToRGB, allowing only float input/output
  - change docs to reflect new size constraints
  - change the HSV format to be [0,1] for all components
  - add automatic dtype conversion for all adjust_* and grayscale conversion ops
  - fix up docs
Change 109631077: Improve optimizer exceptions:
  1. grads_and_vars is now a tuple, so it must be wrapped when passed to format.
  2. Use '%r' instead of '%s' for dtype formatting.
Base CL: 109697989
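
The scratch-space allocation interface mentioned above is exercised in the cuda_dnn.cc hunks below through GetMemoryLimitInBytes and AllocateBytes; the interface itself lives in the new scratch_allocator.h listed in the diffstat. The following is only a minimal sketch of an allocator satisfying those two calls, using the stream_executor types visible in this diff; the class name and body are illustrative assumptions, not the actual scratch_allocator.h declarations.

// Hypothetical sketch: hands out one device buffer per request, mirroring the
// GetMemoryLimitInBytes / AllocateBytes calls made by the convolution code below.
class SketchScratchAllocator {
 public:
  // Upper bound on scratch memory a convolution may request (1 GiB here).
  int64 GetMemoryLimitInBytes(Stream* stream) { return 1LL << 30; }

  // Allocates byte_size bytes of device memory for use as cuDNN workspace.
  port::StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
                                                    int64 byte_size) {
    DeviceMemory<uint8> buffer =
        stream->parent()->AllocateArray<uint8>(byte_size);
    if (buffer == nullptr) {
      return port::Status{port::error::RESOURCE_EXHAUSTED,
                          "could not allocate scratch memory"};
    }
    return buffer;
  }
};
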
Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r--tensorflow/stream_executor/blas.h2
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.cc3
-rw-r--r--tensorflow/stream_executor/cuda/cuda_diagnostics.cc8
-rw-r--r--tensorflow/stream_executor/cuda/cuda_diagnostics.h1
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc477
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.h29
-rw-r--r--tensorflow/stream_executor/cuda/cuda_event.cc1
-rw-r--r--tensorflow/stream_executor/cuda/cuda_fft.cc3
-rw-r--r--tensorflow/stream_executor/cuda/cuda_gpu_executor.cc60
-rw-r--r--tensorflow/stream_executor/cuda/cuda_gpu_executor.h10
-rw-r--r--tensorflow/stream_executor/cuda/cuda_helpers.h3
-rw-r--r--tensorflow/stream_executor/cuda/cuda_platform.cc8
-rw-r--r--tensorflow/stream_executor/cuda/cuda_platform_id.cc26
-rw-r--r--tensorflow/stream_executor/cuda/cuda_platform_id.h36
-rw-r--r--tensorflow/stream_executor/cuda/cuda_rng.cc3
-rw-r--r--tensorflow/stream_executor/cuda/cuda_stream.cc12
-rw-r--r--tensorflow/stream_executor/cuda/cuda_stream.h9
-rw-r--r--tensorflow/stream_executor/device_memory.h1
-rw-r--r--tensorflow/stream_executor/dnn.cc41
-rw-r--r--tensorflow/stream_executor/dnn.h139
-rw-r--r--tensorflow/stream_executor/dso_loader.cc19
-rw-r--r--tensorflow/stream_executor/event.cc17
-rw-r--r--tensorflow/stream_executor/event.h10
-rw-r--r--tensorflow/stream_executor/fft.h2
-rw-r--r--tensorflow/stream_executor/gcuda.h2
-rw-r--r--tensorflow/stream_executor/kernel.cc23
-rw-r--r--tensorflow/stream_executor/kernel.h7
-rw-r--r--tensorflow/stream_executor/lib/casts.h2
-rw-r--r--tensorflow/stream_executor/lib/error.h4
-rw-r--r--tensorflow/stream_executor/lib/ptr_util.h1
-rw-r--r--tensorflow/stream_executor/lib/status.h6
-rw-r--r--tensorflow/stream_executor/lib/statusor.h2
-rw-r--r--tensorflow/stream_executor/lib/strcat.h2
-rw-r--r--tensorflow/stream_executor/platform.cc2
-rw-r--r--tensorflow/stream_executor/platform.h3
-rw-r--r--tensorflow/stream_executor/platform/port.h2
-rw-r--r--tensorflow/stream_executor/plugin_registry.cc3
-rw-r--r--tensorflow/stream_executor/scratch_allocator.cc42
-rw-r--r--tensorflow/stream_executor/scratch_allocator.h83
-rw-r--r--tensorflow/stream_executor/stream.cc168
-rw-r--r--tensorflow/stream_executor/stream.h137
-rw-r--r--tensorflow/stream_executor/stream_executor.h3
-rw-r--r--tensorflow/stream_executor/stream_executor_internal.cc34
-rw-r--r--tensorflow/stream_executor/stream_executor_internal.h209
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc79
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h17
-rw-r--r--tensorflow/stream_executor/timer.cc22
-rw-r--r--tensorflow/stream_executor/timer.h8
48 files changed, 1240 insertions, 541 deletions
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index be81683166..94475817e0 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -1140,7 +1140,7 @@ class BlasSupport {
// Macro used to quickly declare overrides for abstract virtuals in the
// BlasSupport base class.
-#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \
+#define TENSORFLOW_STREAM_EXECUTOR_GPU_BLAS_SUPPORT_OVERRIDES \
bool DoBlasAsum(Stream *stream, uint64 elem_count, \
const DeviceMemory<float> &x, int incx, \
DeviceMemory<float> *result) override; \
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index d2d84f4ea0..19ad12d28b 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -22,7 +22,8 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
-#include "tensorflow/stream_executor/cuda/cuda_platform.h"
+#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
+#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/dso_loader.h"
#include "tensorflow/stream_executor/lib/initialize.h"
diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
index a3f1820094..c2ae035c96 100644
--- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
+++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc
@@ -29,16 +29,17 @@ limitations under the License.
#include <memory>
#include <vector>
-#include "tensorflow/stream_executor/lib/error.h"
-#include "tensorflow/stream_executor/lib/inlined_vector.h"
-#include "tensorflow/stream_executor/lib/numbers.h"
#include "tensorflow/stream_executor/lib/process_state.h"
+#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/platform/logging.h"
+#include "tensorflow/stream_executor/lib/numbers.h"
+#include "tensorflow/stream_executor/lib/str_util.h"
+#include "tensorflow/stream_executor/lib/inlined_vector.h"
namespace perftools {
namespace gputools {
@@ -113,7 +114,6 @@ void Diagnostician::LogDiagnosticInformation() {
LOG(INFO) << "retrieving CUDA diagnostic information for host: "
<< port::Hostname();
-
LogDriverVersionInformation();
}
diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.h b/tensorflow/stream_executor/cuda/cuda_diagnostics.h
index 15387c4161..42336c337f 100644
--- a/tensorflow/stream_executor/cuda/cuda_diagnostics.h
+++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_DIAGNOSTICS_H_
+#include "tensorflow/stream_executor/platform/port.h"
#include <tuple>
#include "tensorflow/stream_executor/lib/statusor.h"
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index b73a9f9ce2..880d52b589 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -18,6 +18,12 @@ limitations under the License.
#include <dlfcn.h>
#include <functional>
+#include "tensorflow/stream_executor/cuda/cuda_activation.h"
+#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
+#include "tensorflow/stream_executor/cuda/cuda_driver.h"
+#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
+#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
+#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/dso_loader.h"
#include "tensorflow/stream_executor/lib/env.h"
@@ -27,13 +33,9 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/plugin_registry.h"
+#include "tensorflow/stream_executor/scratch_allocator.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
-#include "tensorflow/stream_executor/cuda/cuda_activation.h"
-#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
-#include "tensorflow/stream_executor/cuda/cuda_driver.h"
-#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
-#include "tensorflow/stream_executor/cuda/cuda_platform.h"
#include "third_party/gpus/cuda/include/cudnn.h"
namespace {
@@ -62,8 +64,6 @@ namespace cuda {
PLUGIN_REGISTRY_DEFINE_PLUGIN_ID(kCuDnnPlugin);
-extern CUstream AsCUDAStreamValue(Stream* stream);
-
string ToString(cudnnStatus_t status) {
switch (status) {
case CUDNN_STATUS_SUCCESS:
@@ -117,14 +117,34 @@ static port::ThreadPool* GetCudaThreadpool() {
return cudnn_threadpool;
}
+// Retrieves the CUDNN DSO, dies on failure.
+void* GetDsoHandle() {
+ static auto result = internal::CachedDsoLoader::GetCudnnDsoHandle();
+ return result.ValueOrDie();
+}
+
+// Calls cudnnGetVersion in the loaded DSO.
+size_t cudnnGetVersion() {
+ static void* f = dlsym(GetDsoHandle(), "cudnnGetVersion");
+ if (f == nullptr) {
+ LOG(FATAL) << "could not find cudnnGetVersion in cudnn DSO; dlerror: "
+ << dlerror();
+ }
+ auto callable = reinterpret_cast<size_t (*)(void)>(f);
+ return callable();
+}
+
+// Returns whether the currently loaded cuDNN version is R2.
+bool IsCudnnR2() {
+ static auto version = cudnnGetVersion();
+ DCHECK_GE(version, 2000);
+ return version < 3000;
+}
+
#define PERFTOOLS_GPUTOOLS_CUDNN_WRAP(__name) \
struct DynLoadShim__##__name { \
static const char* kName; \
typedef std::add_pointer<decltype(::__name)>::type FuncPointerT; \
- static void* GetDsoHandle() { \
- static auto result = internal::CachedDsoLoader::GetCudnnDsoHandle(); \
- return result.ValueOrDie(); \
- } \
static FuncPointerT DynLoad() { \
static void* f = dlsym(GetDsoHandle(), kName); \
if (f == nullptr) { \
@@ -154,41 +174,53 @@ static port::ThreadPool* GetCudaThreadpool() {
} __name; \
const char* DynLoadShim__##__name::kName = #__name;
-#define CUDNN_DNN_ROUTINE_EACH(__macro) \
- __macro(cudnnSetTensor4dDescriptor) __macro( \
- cudnnGetConvolutionNdForwardOutputDim) \
- __macro(cudnnGetConvolutionForwardAlgorithm) __macro( \
- cudnnCreateTensorDescriptor) __macro(cudnnDestroyTensorDescriptor) \
- __macro(cudnnCreateFilterDescriptor) \
- __macro(cudnnSetFilter4dDescriptor) \
- __macro(cudnnSetPooling2dDescriptor) \
- __macro(cudnnDestroyFilterDescriptor) \
- __macro(cudnnCreateConvolutionDescriptor) \
- __macro(cudnnCreatePoolingDescriptor) \
- __macro(cudnnAddTensor) \
- __macro(cudnnDestroyPoolingDescriptor)
+// clang-format off
+#define CUDNN_DNN_ROUTINE_EACH(__macro) \
+ __macro(cudnnSetTensor4dDescriptor) \
+ __macro(cudnnGetConvolutionNdForwardOutputDim) \
+ __macro(cudnnGetConvolutionForwardAlgorithm) \
+ __macro(cudnnCreateTensorDescriptor) \
+ __macro(cudnnDestroyTensorDescriptor) \
+ __macro(cudnnCreateFilterDescriptor) \
+ __macro(cudnnSetFilter4dDescriptor) \
+ __macro(cudnnSetPooling2dDescriptor) \
+ __macro(cudnnDestroyFilterDescriptor) \
+ __macro(cudnnCreateConvolutionDescriptor) \
+ __macro(cudnnCreatePoolingDescriptor) \
+ __macro(cudnnAddTensor) \
+ __macro(cudnnDestroyPoolingDescriptor) \
+ __macro(cudnnSetConvolution2dDescriptor) \
+ __macro(cudnnDestroyConvolutionDescriptor) \
+ __macro(cudnnCreate) \
+ __macro(cudnnDestroy) \
+ __macro(cudnnSetStream) \
+ __macro(cudnnActivationForward) \
+ __macro(cudnnConvolutionForward) \
+ __macro(cudnnConvolutionBackwardData) \
+ __macro(cudnnConvolutionBackwardFilter) \
+ __macro(cudnnGetConvolutionForwardWorkspaceSize) \
+ __macro(cudnnTransformTensor) \
+ __macro(cudnnPoolingForward) \
+ __macro(cudnnPoolingBackward)
+// clang-format on
CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
-#undef CUDNN_DNN_ROUTINE_EACH
// clang-format off
-#define CUDNN_DNN_ROUTINE_EACH(__macro) \
- __macro(cudnnSetConvolution2dDescriptor) \
- __macro(cudnnDestroyConvolutionDescriptor) \
- __macro(cudnnCreate) \
- __macro(cudnnDestroy) \
- __macro(cudnnSetStream) \
- __macro(cudnnActivationForward) \
- __macro(cudnnConvolutionForward) \
- __macro(cudnnConvolutionBackwardData) \
- __macro(cudnnConvolutionBackwardFilter) \
- __macro(cudnnGetConvolutionForwardWorkspaceSize) \
- __macro(cudnnTransformTensor) \
- __macro(cudnnPoolingForward) \
- __macro(cudnnPoolingBackward)
+#if CUDNN_VERSION >= 3000
+#define CUDNN_DNN_ROUTINE_EACH_R3(__macro) \
+ __macro(cudnnGetConvolutionBackwardFilterWorkspaceSize) \
+ __macro(cudnnGetConvolutionBackwardDataAlgorithm) \
+ __macro(cudnnGetConvolutionBackwardFilterAlgorithm) \
+ __macro(cudnnAddTensor_v3) \
+ __macro(cudnnConvolutionBackwardData_v3) \
+ __macro(cudnnConvolutionBackwardFilter_v3) \
+ __macro(cudnnGetConvolutionBackwardDataWorkspaceSize)
// clang-format on
-CUDNN_DNN_ROUTINE_EACH(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+CUDNN_DNN_ROUTINE_EACH_R3(PERFTOOLS_GPUTOOLS_CUDNN_WRAP)
+#undef CUDNN_DNN_ROUTINE_EACH_R3
+#endif
#undef CUDNN_DNN_ROUTINE_EACH
} // namespace dynload
@@ -467,8 +499,8 @@ bool CudnnSupport::DoConvolve(
const FilterDescriptor& filter_descriptor,
const DeviceMemory<float>& filter_data,
const ConvolutionDescriptor& convolution_descriptor,
- const BatchDescriptor& output_descriptor,
- DeviceMemory<float>* output_data) {
+ const BatchDescriptor& output_descriptor, DeviceMemory<float>* output_data,
+ ScratchAllocator* scratch_allocator) {
ScopedTensorDescriptor input_4d{parent_, batch_descriptor, CUDNN_DATA_FLOAT};
ScopedTensorDescriptor output_4d{parent_, output_descriptor,
CUDNN_DATA_FLOAT};
@@ -486,23 +518,62 @@ bool CudnnSupport::DoConvolve(
// Beta is the scaling factor for output.
float beta = 0.0;
- // The NO_WORKSPACE versions are possibly slower for certain shapes, but
- // not so for the shapes currently used by Brain. Also, it seems prudent to
- // keep cuMemAlloc off the critical path.
- cudnnConvolutionFwdAlgo_t algo;
- status = dynload::cudnnGetConvolutionForwardAlgorithm(
- parent_, ToHandle(dnn_handle_), input_4d.handle(), filter.handle(),
- conv.handle(), output_4d.handle(), CUDNN_CONVOLUTION_FWD_NO_WORKSPACE, 0,
- &algo);
+ auto get_algorithm = [&](bool specify_limit)
+ SHARED_LOCKS_REQUIRED(dnn_handle_mutex_) {
+ cudnnConvolutionFwdPreference_t preference =
+ specify_limit ? CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT
+ : CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
+
+ auto memory_limit_bytes =
+ scratch_allocator == nullptr
+ ? 0
+ : scratch_allocator->GetMemoryLimitInBytes(stream);
+ if (memory_limit_bytes < 0) {
+ memory_limit_bytes = 0;
+ }
+
+ cudnnConvolutionFwdAlgo_t algo;
+ status = dynload::cudnnGetConvolutionForwardAlgorithm(
+ parent_, ToHandle(dnn_handle_), input_4d.handle(), filter.handle(),
+ conv.handle(), output_4d.handle(),
+ /*preference=*/preference,
+ /*memoryLimitInBytes=*/memory_limit_bytes, /*algo=*/&algo);
+ CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
+ "algorithm for doing forward "
+ "convolution";
+ return algo;
+ };
+
+ auto algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
+
+ DeviceMemory<uint8> scratch;
+ if (scratch_allocator != nullptr) {
+ size_t size_in_bytes;
+ status = dynload::cudnnGetConvolutionForwardWorkspaceSize(
+ parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_4d.handle(),
+ /*filterDesc=*/filter.handle(), /*convDesc=*/conv.handle(),
+ /*destDesc=*/output_4d.handle(), /*algo=*/algo,
+ /*sizeInBytes=*/&size_in_bytes);
+ if (status == CUDNN_STATUS_SUCCESS && size_in_bytes != 0) {
+ scratch =
+ scratch_allocator->AllocateBytes(stream, size_in_bytes).ValueOrDie();
+ }
+ }
- CHECK_EQ(status, CUDNN_STATUS_SUCCESS)
- << "Unable to find a suitable algorithm for doing forward convolution";
+ // If we didn't allocate any scratch space (perhaps because of failed
+ // allocation), we force a switch back to the "no workspace" algorithm.
+ if (scratch == nullptr) {
+ algo = get_algorithm(/*specify_limit=*/false);
+ }
status = dynload::cudnnConvolutionForward(
- parent_, ToHandle(dnn_handle_), &alpha, input_4d.handle(),
- input_data.opaque(), filter.handle(), filter_data.opaque(), conv.handle(),
- algo, nullptr /* workspace ptr */, 0 /* workspace size */, &beta,
- output_4d.handle(), output_data->opaque());
+ parent_, ToHandle(dnn_handle_),
+ /*alpha=*/&alpha, /*srcDesc=*/input_4d.handle(),
+ /*srcData=*/input_data.opaque(), /*filterDesc=*/filter.handle(),
+ /*filterData=*/filter_data.opaque(), /*convDesc=*/conv.handle(),
+ /*algo=*/algo, /*workSpace=*/scratch.opaque(),
+ /*workSpaceSizeInBytes=*/scratch.size(), /*beta=*/&beta,
+ /*destDesc=*/output_4d.handle(), /*destData=*/output_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "failed to enqueue convolution on stream: "
@@ -565,7 +636,8 @@ bool CudnnSupport::DoConvolveBackwardData(
DeviceMemory<float> backward_output_data,
const ConvolutionDescriptor& convolution_descriptor,
const BatchDescriptor& input_descriptor,
- DeviceMemory<float>* backward_input_data) {
+ DeviceMemory<float>* backward_input_data,
+ ScratchAllocator* scratch_allocator) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
@@ -592,16 +664,101 @@ bool CudnnSupport::DoConvolveBackwardData(
ScopedFilterDescriptor filter{parent_, filter_descriptor, CUDNN_DATA_FLOAT};
ScopedConvolutionDescriptor conv{parent_, convolution_descriptor};
- status = dynload::cudnnConvolutionBackwardData(
- parent_, ToHandle(dnn_handle_), &alpha, filter.handle(),
- filter_data.opaque(), out_back_4d.handle(), backward_output_data.opaque(),
- conv.handle(), &beta, in_back_4d.handle(), backward_input_data->opaque());
+#if CUDNN_VERSION >= 3000
+ if (dynload::IsCudnnR2()) {
+#endif
+ status = dynload::cudnnConvolutionBackwardData(
+ parent_, ToHandle(dnn_handle_), &alpha, filter.handle(),
+ filter_data.opaque(), out_back_4d.handle(),
+ backward_output_data.opaque(), conv.handle(), &beta,
+ in_back_4d.handle(), backward_input_data->opaque());
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "failed to enqueue convolution on stream: "
+ << ToString(status);
+ return false;
+ }
+ return true;
+#if CUDNN_VERSION >= 3000
+ }
+#endif
+
+#if CUDNN_VERSION >= 3000
+ auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED(
+ dnn_handle_mutex_) -> cudnnConvolutionBwdDataAlgo_t {
+ cudnnConvolutionBwdDataPreference_t preference =
+ specify_limit ? CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT
+ : CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
+
+ auto memory_limit_bytes =
+ scratch_allocator == nullptr
+ ? 0
+ : scratch_allocator->GetMemoryLimitInBytes(stream);
+ if (memory_limit_bytes < 0) {
+ memory_limit_bytes = 0;
+ }
+
+ cudnnConvolutionBwdDataAlgo_t algo;
+ cudnnStatus_t status = dynload::cudnnGetConvolutionBackwardDataAlgorithm(
+ parent_, ToHandle(dnn_handle_),
+ /*filterDesc=*/filter.handle(),
+ /*diffDesc=*/out_back_4d.handle(),
+ /*convDesc=*/conv.handle(),
+ /*gradDesc=*/in_back_4d.handle(),
+ /*preference=*/preference,
+ /*memoryLimitInBytes=*/memory_limit_bytes,
+ /*algo=*/&algo);
+ CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
+ "algorithm for doing backward "
+ "filter convolution";
+ return algo;
+ };
+
+ auto algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
+
+ DeviceMemory<uint8> scratch;
+ if (scratch_allocator != nullptr) {
+ size_t size_in_bytes;
+ status = dynload::cudnnGetConvolutionBackwardDataWorkspaceSize(
+ parent_, ToHandle(dnn_handle_),
+ /*filterDesc=*/filter.handle(),
+ /*diffDesc=*/out_back_4d.handle(),
+ /*convDesc=*/conv.handle(),
+ /*gradDesc=*/in_back_4d.handle(),
+ /*algo=*/algo,
+ /*sizeInBytes=*/&size_in_bytes);
+ if (status == CUDNN_STATUS_SUCCESS && size_in_bytes != 0) {
+ scratch =
+ scratch_allocator->AllocateBytes(stream, size_in_bytes).ValueOrDie();
+ }
+ }
+
+ // If we didn't allocate any scratch space (perhaps because of failed
+ // allocation), we force a switch back to the "no workspace" algorithm.
+ if (scratch == nullptr) {
+ algo = get_algorithm(/*specify_limit=*/false);
+ }
+
+ status = dynload::cudnnConvolutionBackwardData_v3(
+ parent_, ToHandle(dnn_handle_),
+ /*alpha=*/&alpha,
+ /*filterDesc=*/filter.handle(),
+ /*filterData=*/filter_data.opaque(),
+ /*diffDesc=*/out_back_4d.handle(),
+ /*diffData=*/backward_output_data.opaque(),
+ /*convDesc=*/conv.handle(),
+ /*algo=*/algo,
+ /*workSpace=*/scratch.opaque(),
+ /*workSpaceSizeInBytes=*/scratch.size(),
+ /*beta=*/&beta,
+ /*gradDesc=*/in_back_4d.handle(),
+ /*gradData=*/backward_input_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "failed to enqueue convolution on stream: "
<< ToString(status);
return false;
}
return true;
+#endif
}
bool CudnnSupport::DoConvolveBackwardFilter(
@@ -611,7 +768,8 @@ bool CudnnSupport::DoConvolveBackwardFilter(
DeviceMemory<float> backward_output_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::FilterDescriptor& filter_descriptor,
- DeviceMemory<float>* backward_filter_data) {
+ DeviceMemory<float>* backward_filter_data,
+ ScratchAllocator* scratch_allocator) {
mutex_lock lock{dnn_handle_mutex_};
auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
AsCUDAStreamValue(stream));
@@ -637,16 +795,100 @@ bool CudnnSupport::DoConvolveBackwardFilter(
ScopedFilterDescriptor filter{parent_, filter_descriptor, CUDNN_DATA_FLOAT};
ScopedConvolutionDescriptor conv{parent_, convolution_descriptor};
- status = dynload::cudnnConvolutionBackwardFilter(
- parent_, ToHandle(dnn_handle_), &alpha, input_4d.handle(),
- input_data.opaque(), out_back_4d.handle(), backward_output_data.opaque(),
- conv.handle(), &beta, filter.handle(), backward_filter_data->opaque());
+#if CUDNN_VERSION >= 3000
+ if (dynload::IsCudnnR2()) {
+#endif
+ status = dynload::cudnnConvolutionBackwardFilter(
+ parent_, ToHandle(dnn_handle_), &alpha, input_4d.handle(),
+ input_data.opaque(), out_back_4d.handle(),
+ backward_output_data.opaque(), conv.handle(), &beta, filter.handle(),
+ backward_filter_data->opaque());
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(FATAL) << "failed to enqueue convolution on stream: "
+ << ToString(status);
+ return false;
+ }
+ return true;
+#if CUDNN_VERSION >= 3000
+ }
+#endif
+
+#if CUDNN_VERSION >= 3000
+ // Lambda that retrieves the algorithm.
+ // specify_limit will occur when we have a scratch allocator and it succeeds
+ // in allocating; otherwise, we'll fall back to the "no workspace" version.
+ auto get_algorithm = [&](bool specify_limit) SHARED_LOCKS_REQUIRED(
+ dnn_handle_mutex_) {
+ cudnnConvolutionBwdFilterPreference_t preference =
+ specify_limit ? CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT
+ : CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
+
+ auto memory_limit_bytes =
+ scratch_allocator == nullptr
+ ? 0
+ : scratch_allocator->GetMemoryLimitInBytes(stream);
+ if (memory_limit_bytes < 0) {
+ memory_limit_bytes = 0;
+ }
+
+ cudnnConvolutionBwdFilterAlgo_t algo;
+ cudnnStatus_t status = dynload::cudnnGetConvolutionBackwardFilterAlgorithm(
+ parent_, ToHandle(dnn_handle_),
+ /*srcDesc=*/input_4d.handle(),
+ /*diffDesc=*/out_back_4d.handle(),
+ /*convDesc=*/conv.handle(),
+ /*gradDesc=*/filter.handle(),
+ /*preference=*/preference,
+ /*memoryLimitInBytes=*/memory_limit_bytes,
+ /*algo=*/&algo);
+ CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << "Unable to find a suitable "
+ "algorithm for doing backward "
+ "filter convolution";
+ return algo;
+ };
+
+ auto algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
+
+ DeviceMemory<uint8> scratch;
+ if (scratch_allocator != nullptr) {
+ size_t size_in_bytes;
+ status = dynload::cudnnGetConvolutionBackwardFilterWorkspaceSize(
+ parent_, ToHandle(dnn_handle_), /*srcDesc=*/input_4d.handle(),
+ /*diffDesc=*/out_back_4d.handle(), /*convDesc=*/conv.handle(),
+ /*gradDesc=*/filter.handle(), /*algo=*/algo,
+ /*sizeInBytes=*/&size_in_bytes);
+ if (status == CUDNN_STATUS_SUCCESS && size_in_bytes != 0) {
+ scratch =
+ scratch_allocator->AllocateBytes(stream, size_in_bytes).ValueOrDie();
+ }
+ }
+
+ // If we didn't allocate any scratch space (perhaps because of failed
+ // allocation), we force a switch back to the "no workspace" algorithm.
+ if (scratch == nullptr) {
+ algo = get_algorithm(/*specify_limit=*/false);
+ }
+
+ status = dynload::cudnnConvolutionBackwardFilter_v3(
+ parent_, ToHandle(dnn_handle_), /*alpha=*/&alpha,
+ /*srcDesc=*/input_4d.handle(),
+ /*srcData=*/input_data.opaque(),
+ /*diffDesc=*/out_back_4d.handle(),
+ /*diffData=*/backward_output_data.opaque(),
+ /*convDesc=*/conv.handle(),
+ /*algo=*/algo,
+ /*workSpace=*/scratch.opaque(),
+ /*workSpaceSizeInBytes=*/scratch.size(),
+ /*beta=*/&beta,
+ /*gradDesc=*/filter.handle(),
+ /*gradData=*/backward_filter_data->opaque());
if (status != CUDNN_STATUS_SUCCESS) {
LOG(FATAL) << "failed to enqueue convolution on stream: "
<< ToString(status);
return false;
}
return true;
+#endif
}
bool CudnnSupport::DoMatMul(Stream* stream,
@@ -800,7 +1042,7 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
ScopedTensorDescriptor bias_descriptor{parent_, bias_dimensions,
CUDNN_DATA_FLOAT};
- // cudnnAddTensor is in-place, so we need to copy input_data to
+ // cudnnAddTensor_v3 is in-place, so we need to copy input_data to
// output_data before doing the addition, unless the input and
// output are at the same address.
if (input_data.opaque() != output_data->opaque()) {
@@ -815,13 +1057,30 @@ bool CudnnSupport::DoBiasAdd(Stream* stream,
}
mutex_lock lock{dnn_handle_mutex_};
+ auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
+ AsCUDAStreamValue(stream));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
+ return false;
+ }
const float alpha = 1.0f;
const float beta = 1.0f;
- auto status = dynload::cudnnAddTensor(
- parent_, ToHandle(dnn_handle_), CUDNN_ADD_SAME_C, &alpha,
- bias_descriptor.handle(), biases.opaque(), &beta,
- input_descriptor.handle(), output_data->opaque());
+#if CUDNN_VERSION >= 3000
+ if (dynload::IsCudnnR2()) {
+#endif
+ status = dynload::cudnnAddTensor(
+ parent_, ToHandle(dnn_handle_), CUDNN_ADD_SAME_C, &alpha,
+ bias_descriptor.handle(), biases.opaque(), &beta,
+ input_descriptor.handle(), output_data->opaque());
+#if CUDNN_VERSION >= 3000
+ } else {
+ status = dynload::cudnnAddTensor_v3(
+ parent_, ToHandle(dnn_handle_), &alpha, bias_descriptor.handle(),
+ biases.opaque(), &beta, input_descriptor.handle(),
+ output_data->opaque());
+ }
+#endif
if (status != CUDNN_STATUS_SUCCESS) {
LOG(ERROR) << "stream " << stream << " could not enqueue bias addition.";
@@ -970,7 +1229,52 @@ bool CudnnSupport::DoDepthConcatenate(
Stream* stream, port::ArraySlice<dnn::BatchDescriptor> input_dimensions,
port::ArraySlice<const DeviceMemory<float>*> input_data,
DeviceMemory<float>* output_data) {
- LOG(FATAL) << "not yet implemented"; // TODO(leary)
+ CHECK_EQ(input_dimensions.size(), input_data.size());
+
+ for (const auto& dimensions : input_dimensions) {
+ if (dimensions.layout() != dnn::DataLayout::kBatchDepthYX) {
+ LOG(ERROR) << "CudnnSupport::DoDepthConcatenate currently only "
+ "supports the kBatchDepthYX layout.";
+ return false;
+ }
+ }
+
+ if (input_dimensions.empty()) {
+ return true; // Nothing to do.
+ }
+
+ dnn::BatchDescriptor output_dimensions =
+ dnn::BatchDescriptor::DepthConcatenateOutputDescriptor(input_dimensions);
+
+ const int64 area = output_dimensions.width() * output_dimensions.height();
+ const auto index = [area](int64 batch, int64 depth, int64 yx,
+ int64 max_depth) {
+ return (batch * max_depth + depth) * area + yx;
+ };
+
+ std::vector<float> output_host(output_dimensions.ElementCount());
+ std::vector<float> tmp;
+ int64 depth_sum = 0;
+ for (size_t i = 0; i < input_data.size(); ++i) {
+ const auto& dimensions = input_dimensions[i];
+ tmp.resize(dimensions.ElementCount());
+ stream->ThenMemcpyD2H<float>(*input_data[i], &tmp).BlockHostUntilDone();
+
+ for (int64 batch = 0; batch < output_dimensions.count(); ++batch) {
+ for (int64 yx = 0; yx < area; ++yx) {
+ for (int64 depth = 0; depth < dimensions.feature_map_count(); ++depth) {
+ LOG(INFO) << output_dimensions.ElementCount() << ' ' << batch << ' '
+ << yx << ' ' << depth;
+ output_host[index(batch, depth + depth_sum, yx,
+ output_dimensions.feature_map_count())] =
+ tmp[index(batch, depth, yx, dimensions.feature_map_count())];
+ }
+ }
+ }
+ depth_sum += dimensions.feature_map_count();
+ }
+ stream->ThenMemcpyH2D<float>(output_host, output_data);
+ return true;
}
bool CudnnSupport::DoElementwiseOperate(
@@ -982,29 +1286,30 @@ bool CudnnSupport::DoElementwiseOperate(
LOG(FATAL) << "not yet implemented"; // TODO(leary)
}
-bool CudnnSupport::DoMemcpyD2HQuantized(
- Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
- port::MutableArraySlice<uint8> host_dst) {
- LOG(ERROR) << "quantized memcpy not supported by cuDNN";
- return false;
+bool CudnnSupport::DoXYPad(
+ Stream* stream, const dnn::BatchDescriptor &dimensions,
+ const DeviceMemory<float> &input_data, int64 left_pad, int64 right_pad,
+ int64 top_pad, int64 bottom_pad, DeviceMemory<float> *output_data) {
+ LOG(FATAL) << "not yet implemented"; // TODO(leary)
}
-bool CudnnSupport::DoMemcpyD2HQuantized(
- Stream* stream, const DeviceMemory<float>& device_unquantized_src,
- port::MutableArraySlice<uint16> host_dst) {
- LOG(ERROR) << "quantized memcpy not supported by cuDNN";
- return false;
+bool CudnnSupport::DoXYSlice(
+ Stream* stream, const dnn::BatchDescriptor &dimensions,
+ const DeviceMemory<float> &input_data, int64 left_trim, int64 right_trim,
+ int64 top_trim, int64 bottom_trim, DeviceMemory<float> *output_data) {
+ LOG(FATAL) << "not yet implemented"; // TODO(leary)
}
bool CudnnSupport::DoMemcpyD2HQuantized(
- Stream* stream, const DeviceMemory<float>& device_unquantized_src,
- port::MutableArraySlice<int32> host_dst) {
+ Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
+ dnn::QuantizedActivationMode mode, void* host_dst, int64 size) {
LOG(ERROR) << "quantized memcpy not supported by cuDNN";
return false;
}
bool CudnnSupport::DoMemcpyH2DQuantized(
- Stream* stream, port::ArraySlice<uint8> host_src,
+ Stream* stream, const void* host_src, int64 size,
+ dnn::QuantizedActivationMode mode,
DeviceMemory<float>* gpu_unquantized_dst) {
LOG(ERROR) << "quantized memcpy not supported by cuDNN";
return false;
@@ -1074,7 +1379,7 @@ void initialize_cudnn() {
// Prime the cuDNN DSO. The loader will log more information.
auto statusor = gpu::internal::CachedDsoLoader::GetCudnnDsoHandle();
if (!statusor.ok()) {
- LOG(INFO) << "Unable to load cuDNN DSO.";
+ LOG(INFO) << "Unable to load cuDNN DSO";
}
gpu::PluginRegistry::Instance()->SetDefaultFactory(gpu::cuda::kCudaPlatformId,
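
The three convolution paths above (forward, backward-data, backward-filter) follow the same pattern: ask cuDNN for an algorithm under a workspace limit when a scratch allocator is present, size and allocate the workspace, and fall back to the no-workspace algorithm when no scratch was obtained. The sketch below is a hypothetical condensation of that control flow, not code from this commit; PickAlgorithmWithScratch, get_algorithm, and workspace_size_fn are illustrative names, and AlgoT stands in for the per-pass cudnnConvolution*Algo_t type.

// Hypothetical condensation of the scratch-allocation fallback used by
// DoConvolve, DoConvolveBackwardData, and DoConvolveBackwardFilter above.
// Requires <functional> and the stream_executor headers included by cuda_dnn.cc.
template <typename AlgoT>
AlgoT PickAlgorithmWithScratch(Stream* stream,
                               ScratchAllocator* scratch_allocator,
                               std::function<AlgoT(bool)> get_algorithm,
                               std::function<size_t(AlgoT)> workspace_size_fn,
                               DeviceMemory<uint8>* scratch) {
  // Prefer a workspace-limited algorithm only when a scratch allocator exists.
  AlgoT algo = get_algorithm(/*specify_limit=*/scratch_allocator != nullptr);
  if (scratch_allocator != nullptr) {
    size_t size_in_bytes = workspace_size_fn(algo);
    if (size_in_bytes != 0) {
      *scratch =
          scratch_allocator->AllocateBytes(stream, size_in_bytes).ValueOrDie();
    }
  }
  // No scratch obtained (allocator absent or zero-sized workspace): switch
  // back to the "no workspace" algorithm so the convolution can still run.
  if (*scratch == nullptr) {
    algo = get_algorithm(/*specify_limit=*/false);
  }
  return algo;
}
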
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 4103cf8e06..37be60ec63 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -50,7 +50,8 @@ class CudnnSupport : public dnn::DnnSupport {
const DeviceMemory<float>& filter_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& output_descriptor,
- DeviceMemory<float>* output_data) override;
+ DeviceMemory<float>* output_data,
+ ScratchAllocator* scratch_allocator) override;
bool DoConvolve(Stream* stream, const dnn::BatchDescriptor& batch_descriptor,
const DeviceMemory<double>& input_data,
@@ -80,7 +81,8 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemory<float> backward_output_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& input_descriptor,
- DeviceMemory<float>* backward_input_data) override;
+ DeviceMemory<float>* backward_input_data,
+ ScratchAllocator* scratch_allocator) override;
bool DoConvolveBackwardFilter(
Stream* stream, const dnn::BatchDescriptor& input_descriptor,
@@ -89,7 +91,8 @@ class CudnnSupport : public dnn::DnnSupport {
DeviceMemory<float> backward_output_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::FilterDescriptor& filter_descriptor,
- DeviceMemory<float>* backward_filter_data) override;
+ DeviceMemory<float>* backward_filter_data,
+ ScratchAllocator* scratch_allocator) override;
bool DoMatMul(Stream* stream, const DeviceMemory<float>& input_data,
const DeviceMemory<float>& weights,
@@ -160,20 +163,24 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) override;
- bool DoMemcpyD2HQuantized(Stream* stream,
- const DeviceMemory<float>& device_unquantized_src,
- port::MutableArraySlice<uint8> host_dst) override;
+ bool DoXYPad(Stream* stream, const dnn::BatchDescriptor &dimensions,
+ const DeviceMemory<float> &input_data,
+ int64 left_pad, int64 right_pad, int64 top_pad,
+ int64 bottom_pad, DeviceMemory<float> *output_data) override;
- bool DoMemcpyD2HQuantized(Stream* stream,
- const DeviceMemory<float>& device_unquantized_src,
- port::MutableArraySlice<uint16> host_dst) override;
+ bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor &dimensions,
+ const DeviceMemory<float> &input_data,
+ int64 left_trim, int64 right_trim, int64 top_trim,
+ int64 bottom_trim, DeviceMemory<float> *output_data) override;
bool DoMemcpyD2HQuantized(Stream* stream,
const DeviceMemory<float>& device_unquantized_src,
- port::MutableArraySlice<int32> host_dst) override;
+ dnn::QuantizedActivationMode mode, void* host_dst,
+ int64 size) override;
bool DoMemcpyH2DQuantized(
- Stream* stream, port::ArraySlice<uint8> host_src,
+ Stream* stream, const void* host_src, int64 size,
+ dnn::QuantizedActivationMode mode,
DeviceMemory<float>* device_unquantized_dst) override;
// Derives an output batch descriptor from an input batch and convolution
diff --git a/tensorflow/stream_executor/cuda/cuda_event.cc b/tensorflow/stream_executor/cuda/cuda_event.cc
index c952257aa5..ad7c0bbf3b 100644
--- a/tensorflow/stream_executor/cuda/cuda_event.cc
+++ b/tensorflow/stream_executor/cuda/cuda_event.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_event.h"
+#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/lib/statusor.h"
diff --git a/tensorflow/stream_executor/cuda/cuda_fft.cc b/tensorflow/stream_executor/cuda/cuda_fft.cc
index 4ee549ab03..b5740a1d7b 100644
--- a/tensorflow/stream_executor/cuda/cuda_fft.cc
+++ b/tensorflow/stream_executor/cuda/cuda_fft.cc
@@ -22,7 +22,8 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
-#include "tensorflow/stream_executor/cuda/cuda_platform.h"
+#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
+#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/dso_loader.h"
#include "tensorflow/stream_executor/lib/initialize.h"
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index 93dc90635e..2565078bb2 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
#include "tensorflow/stream_executor/cuda/cuda_event.h"
-#include "tensorflow/stream_executor/cuda/cuda_platform.h"
+#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/cuda/cuda_timer.h"
#include "tensorflow/stream_executor/dso_loader.h"
@@ -88,20 +88,6 @@ static CUDAEvent *AsCUDAEvent(Event *event) {
return static_cast<CUDAEvent *>(event->implementation());
}
-// Given a platform-independent stream datatype, returns the internal CUDA
-// platform implementation pointer.
-static CUDAStream *AsCUDAStream(Stream *stream) {
- DCHECK(stream != nullptr);
- return static_cast<CUDAStream *>(stream->implementation());
-}
-
-// Given a platform-independent stream datatype, returns the platform
-// implementation's internal value, suitable for passing directly to libcuda
-// APIs.
-CUstream AsCUDAStreamValue(Stream *stream) {
- DCHECK(stream != nullptr);
- return AsCUDAStream(stream)->cuda_stream();
-}
// Given a platform-independent timer datatype, returns the internal CUDA
// platform implementation pointer.
@@ -861,6 +847,26 @@ bool CUDAExecutor::SupportsFft() const { return true; }
bool CUDAExecutor::SupportsRng() const { return true; }
+std::unique_ptr<internal::EventInterface>
+CUDAExecutor::CreateEventImplementation() {
+ return std::unique_ptr<internal::EventInterface>(new CUDAEvent(this));
+}
+
+std::unique_ptr<internal::KernelInterface>
+CUDAExecutor::CreateKernelImplementation() {
+ return std::unique_ptr<internal::KernelInterface>(new CUDAKernel());
+}
+
+std::unique_ptr<internal::StreamInterface>
+CUDAExecutor::GetStreamImplementation() {
+ return std::unique_ptr<internal::StreamInterface>(new CUDAStream(this));
+}
+
+std::unique_ptr<internal::TimerInterface>
+CUDAExecutor::GetTimerImplementation() {
+ return std::unique_ptr<internal::TimerInterface>(new CUDATimer(this));
+}
+
void *CUDAExecutor::CudaContextHack() { return context_; }
CUcontext CUDAExecutor::cuda_context() { return context_; }
@@ -1064,30 +1070,6 @@ void initialize_cuda_gpu_executor() {
const gpu::PluginConfig &config) {
return new gpu::cuda::CUDAExecutor{config};
};
-
- *gpu::internal::MakeCUDAKernelImplementation() = []() {
- return new gpu::cuda::CUDAKernel;
- };
-
- *gpu::internal::MakeCUDAEventImplementation() = [](
- gpu::StreamExecutor *parent) {
- gpu::cuda::CUDAExecutor *cuda_executor =
- static_cast<gpu::cuda::CUDAExecutor *>(parent->implementation());
- return new gpu::cuda::CUDAEvent{cuda_executor};
- };
-
- *gpu::internal::MakeCUDAStreamImplementation() = [](
- gpu::StreamExecutor *parent) {
- gpu::cuda::CUDAExecutor *cuda_executor =
- static_cast<gpu::cuda::CUDAExecutor *>(parent->implementation());
- return new gpu::cuda::CUDAStream{cuda_executor};
- };
- *gpu::internal::MakeCUDATimerImplementation() = [](
- gpu::StreamExecutor *parent) {
- gpu::cuda::CUDAExecutor *cuda_executor =
- static_cast<gpu::cuda::CUDAExecutor *>(parent->implementation());
- return new gpu::cuda::CUDATimer{cuda_executor};
- };
}
} // namespace gputools
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
index 01ccf82ec6..2a0c6dc456 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
@@ -203,6 +203,16 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
dnn::DnnSupport *CreateDnn() override;
+ std::unique_ptr<internal::EventInterface> CreateEventImplementation()
+ override;
+
+ std::unique_ptr<internal::KernelInterface> CreateKernelImplementation()
+ override;
+
+ std::unique_ptr<internal::StreamInterface> GetStreamImplementation() override;
+
+ std::unique_ptr<internal::TimerInterface> GetTimerImplementation() override;
+
void *CudaContextHack() override;
CUcontext cuda_context();
diff --git a/tensorflow/stream_executor/cuda/cuda_helpers.h b/tensorflow/stream_executor/cuda/cuda_helpers.h
index dad62900f1..c52516c589 100644
--- a/tensorflow/stream_executor/cuda/cuda_helpers.h
+++ b/tensorflow/stream_executor/cuda/cuda_helpers.h
@@ -30,7 +30,6 @@ limitations under the License.
namespace perftools {
namespace gputools {
-class Stream;
template <typename ElemT>
class DeviceMemory;
@@ -51,8 +50,6 @@ T *CUDAMemoryMutable(DeviceMemory<T> *mem) {
return static_cast<T *>(mem->opaque());
}
-CUstream AsCUDAStreamValue(Stream *stream);
-
static_assert(sizeof(std::complex<float>) == sizeof(cuComplex),
"std::complex<float> and cuComplex should have the same size");
static_assert(offsetof(cuComplex, x) == 0,
diff --git a/tensorflow/stream_executor/cuda/cuda_platform.cc b/tensorflow/stream_executor/cuda/cuda_platform.cc
index 172f9e2cf8..f658ce216d 100644
--- a/tensorflow/stream_executor/cuda/cuda_platform.cc
+++ b/tensorflow/stream_executor/cuda/cuda_platform.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_platform.h"
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
+#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
+#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/initialize.h"
#include "tensorflow/stream_executor/lib/ptr_util.h"
@@ -26,8 +28,6 @@ namespace perftools {
namespace gputools {
namespace cuda {
-PLATFORM_DEFINE_ID(kCudaPlatformId);
-
CudaPlatform::CudaPlatform()
: name_("CUDA"), min_numa_node_(0), limit_numa_node_(0) {}
@@ -147,8 +147,8 @@ port::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
port::StatusOr<std::unique_ptr<StreamExecutor>>
CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
- auto executor = port::MakeUnique<StreamExecutor>(PlatformKind::kCuda,
- config.plugin_config);
+ auto executor = port::MakeUnique<StreamExecutor>(
+ this, new CUDAExecutor(config.plugin_config));
auto init_status = executor->Init(config.ordinal, config.device_options);
if (!init_status.ok()) {
return port::Status{
diff --git a/tensorflow/stream_executor/cuda/cuda_platform_id.cc b/tensorflow/stream_executor/cuda/cuda_platform_id.cc
new file mode 100644
index 0000000000..09ece156d2
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_platform_id.cc
@@ -0,0 +1,26 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+PLATFORM_DEFINE_ID(kCudaPlatformId);
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/cuda/cuda_platform_id.h b/tensorflow/stream_executor/cuda/cuda_platform_id.h
new file mode 100644
index 0000000000..c91ccc0e44
--- /dev/null
+++ b/tensorflow/stream_executor/cuda/cuda_platform_id.h
@@ -0,0 +1,36 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_ID_H_
+#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_ID_H_
+
+#include "tensorflow/stream_executor/platform.h"
+
+namespace perftools {
+namespace gputools {
+namespace cuda {
+
+// Opaque and unique identifier for the cuda platform.
+// This is needed so that plugins can refer to/identify this platform without
+// instantiating a CudaPlatform object.
+// This is broken out here to avoid a circular dependency between CudaPlatform
+// and CudaExecutor.
+extern const Platform::Id kCudaPlatformId;
+
+} // namespace cuda
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_PLATFORM_ID_H_
diff --git a/tensorflow/stream_executor/cuda/cuda_rng.cc b/tensorflow/stream_executor/cuda/cuda_rng.cc
index 3d244ed2e7..220ca85df9 100644
--- a/tensorflow/stream_executor/cuda/cuda_rng.cc
+++ b/tensorflow/stream_executor/cuda/cuda_rng.cc
@@ -20,7 +20,8 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/cuda/cuda_helpers.h"
-#include "tensorflow/stream_executor/cuda/cuda_platform.h"
+#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
+#include "tensorflow/stream_executor/cuda/cuda_stream.h"
#include "tensorflow/stream_executor/device_memory.h"
#include "tensorflow/stream_executor/dso_loader.h"
#include "tensorflow/stream_executor/lib/initialize.h"
diff --git a/tensorflow/stream_executor/cuda/cuda_stream.cc b/tensorflow/stream_executor/cuda/cuda_stream.cc
index c2d8f95e9c..3bc42982e0 100644
--- a/tensorflow/stream_executor/cuda/cuda_stream.cc
+++ b/tensorflow/stream_executor/cuda/cuda_stream.cc
@@ -15,7 +15,9 @@ limitations under the License.
#include "tensorflow/stream_executor/cuda/cuda_stream.h"
+#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
#include "tensorflow/stream_executor/lib/status.h"
+#include "tensorflow/stream_executor/stream.h"
namespace perftools {
namespace gputools {
@@ -61,6 +63,16 @@ bool CUDAStream::GetOrCreateCompletedEvent(CUevent *completed_event) {
return true;
}
+CUDAStream *AsCUDAStream(Stream *stream) {
+ DCHECK(stream != nullptr);
+ return static_cast<CUDAStream *>(stream->implementation());
+}
+
+CUstream AsCUDAStreamValue(Stream *stream) {
+ DCHECK(stream != nullptr);
+ return AsCUDAStream(stream)->cuda_stream();
+}
+
} // namespace cuda
} // namespace gputools
} // namespace perftools
diff --git a/tensorflow/stream_executor/cuda/cuda_stream.h b/tensorflow/stream_executor/cuda/cuda_stream.h
index fda93a1a01..cb12a094e6 100644
--- a/tensorflow/stream_executor/cuda/cuda_stream.h
+++ b/tensorflow/stream_executor/cuda/cuda_stream.h
@@ -20,7 +20,7 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_STREAM_H_
#include "tensorflow/stream_executor/cuda/cuda_driver.h"
-#include "tensorflow/stream_executor/cuda/cuda_gpu_executor.h"
+#include "tensorflow/stream_executor/platform/thread_annotations.h"
#include "tensorflow/stream_executor/stream_executor_internal.h"
namespace perftools {
@@ -82,6 +82,13 @@ class CUDAStream : public internal::StreamInterface {
CUevent completed_event_ GUARDED_BY(mu_);
};
+// Helper functions to simplify extremely common flows.
+// Converts a Stream to the underlying CUDAStream implementation.
+CUDAStream *AsCUDAStream(Stream *stream);
+
+// Extracts a CUstream from a CUDAStream-backed Stream object.
+CUstream AsCUDAStreamValue(Stream *stream);
+
} // namespace cuda
} // namespace gputools
} // namespace perftools
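
These helpers replace the file-static AsCUDAStream that was removed from cuda_gpu_executor.cc above, so other CUDA plugins (cuda_blas.cc, cuda_fft.cc, cuda_rng.cc) can obtain the raw CUstream by including cuda_stream.h instead of cuda_gpu_executor.h. An illustrative call site, under the assumption of a cuBLAS handle already created by the plugin:

// Illustrative only: hand the Stream's underlying CUstream to a CUDA library.
bool SetLibraryStream(Stream* stream, cublasHandle_t handle) {
  CUstream cu_stream = cuda::AsCUDAStreamValue(stream);  // from cuda_stream.h
  return cublasSetStream(handle, cu_stream) == CUBLAS_STATUS_SUCCESS;
}
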
diff --git a/tensorflow/stream_executor/device_memory.h b/tensorflow/stream_executor/device_memory.h
index b5125c27e1..54f348c094 100644
--- a/tensorflow/stream_executor/device_memory.h
+++ b/tensorflow/stream_executor/device_memory.h
@@ -111,6 +111,7 @@ class DeviceMemory final : public DeviceMemoryBase {
public:
// Default constructor instantiates a null-pointed, zero-sized memory region.
DeviceMemory() : DeviceMemoryBase(nullptr, 0) {}
+ DeviceMemory(std::nullptr_t) : DeviceMemory() {}
// Typed device memory regions may be constructed from untyped device memory
// regions, this effectively amounts to a cast from a void*.
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 97a15dd263..fbc9342081 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -22,6 +22,20 @@ namespace perftools {
namespace gputools {
namespace dnn {
+string QuantizedActivationModeString(QuantizedActivationMode mode) {
+ switch (mode) {
+ case dnn::QuantizedActivationMode::k8Bit:
+ return "uint8";
+ case dnn::QuantizedActivationMode::k16Bit:
+ return "uint16";
+ case dnn::QuantizedActivationMode::k32Bit:
+ return "int32";
+ default:
+ LOG(FATAL) << "Unknown quantized_activation_mode "
+ << static_cast<int32>(mode);
+ }
+}
+
string ActivationModeString(ActivationMode mode) {
switch (mode) {
case ActivationMode::kSigmoid:
@@ -78,6 +92,17 @@ string FilterLayoutString(FilterLayout layout) {
}
}
+string ShortPoolingModeString(PoolingMode mode) {
+ switch (mode) {
+ case PoolingMode::kMaximum:
+ return "Max";
+ case PoolingMode::kAverage:
+ return "Avg";
+ default:
+ LOG(FATAL) << "Unknown filter layout " << static_cast<int32>(mode);
+ }
+}
+
// -- BatchDescriptor
BatchDescriptor::BatchDescriptor()
@@ -137,7 +162,6 @@ string BatchDescriptor::ToShortString() const {
return port::StrCat(batch, depth, y, x, suffix);
default:
LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout());
- return ""; // Avoid lack-of-return warning
}
}
@@ -160,6 +184,20 @@ int64 BatchDescriptor::FullyConnectedBiasCount(const BatchDescriptor& output) {
return output.NodesAcrossFeatureMaps();
}
+BatchDescriptor BatchDescriptor::DepthConcatenateOutputDescriptor(
+ port::ArraySlice<dnn::BatchDescriptor> inputs) {
+ if (inputs.empty()) {
+ return BatchDescriptor();
+ }
+ int feature_map_count = 0;
+ for (const auto& dimensions : inputs) {
+ feature_map_count += dimensions.feature_map_count();
+ }
+ BatchDescriptor output = inputs[0];
+ output.set_feature_map_count(feature_map_count);
+ return output;
+}
+
// -- FilterDescriptor
FilterDescriptor::FilterDescriptor()
@@ -205,7 +243,6 @@ string FilterDescriptor::ToShortString() const {
return port::StrCat(y, x, id, od);
default:
LOG(FATAL) << "Unknown layout " << static_cast<int32>(layout_);
- return ""; // Avoid lack-of-return warning
}
}
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 7d1dfe3d0e..237f60c6ca 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -32,6 +32,7 @@ namespace perftools {
namespace gputools {
class Stream;
+class ScratchAllocator;
namespace dnn {
@@ -55,6 +56,9 @@ enum class QuantizedActivationMode {
k32Bit = 4,
};
+// Returns a string representation of the given quantization mode.
+string QuantizedActivationModeString(QuantizedActivationMode mode);
+
// Describes the dimensions that a layer consumes/produces.
//
// This is a matrix (height, width), its "depth" (feature_map_count),
@@ -175,6 +179,13 @@ class BatchDescriptor {
// with dimensions given the 'output' descriptor.
static int64 FullyConnectedBiasCount(const BatchDescriptor& output);
+ // Return a BatchDescriptor for the output of a depth concatenation
+ // with the given input descriptors. The inputs should have the same
+ // dimensions, except possibly for feature_map_count(), though this
+ // function does not verify that.
+ static BatchDescriptor DepthConcatenateOutputDescriptor(
+ port::ArraySlice<dnn::BatchDescriptor> inputs);
+
private:
int64 count_;
int64 feature_map_count_;
@@ -280,8 +291,6 @@ class FilterDescriptor {
int64 input_filter_height_;
int64 input_filter_width_;
FilterLayout layout_;
-
- SE_DISALLOW_COPY_AND_ASSIGN(FilterDescriptor);
};
// Describes a convolution.
@@ -356,6 +365,9 @@ enum class PoolingMode : int64 {
kAverage,
};
+// Returns a short name for the pooling mode, e.g. "Avg".
+string ShortPoolingModeString(PoolingMode mode);
+
// Describes a pooling operation to be enqueued onto a stream via a platform's
// DnnSupport.
//
@@ -423,18 +435,31 @@ class PoolingDescriptor {
int64 horizontal_padding_;
int64 vertical_stride_;
int64 horizontal_stride_;
-
- SE_DISALLOW_COPY_AND_ASSIGN(PoolingDescriptor);
};
-// Describes a dist_belief local response normalization.
-// The normalization equation is:
-// y_i = x_i / (bias + alpha * (sum_j_{i - range}^{i + range} x_j^2)) ^ beta
-// where x_i is the input in feature map i, y_i is the output.
-// Each feature map is split into segment_size segments for performing the
-// sum_j_. If wrap_around is true, the sum_j_ for y_i on the left and right of
-// a segment wrap around at the edges of the segment, if wrap_around is false
-// zeros are inserted instead.
+// Describes a local response normalization (LRN). LRN is used e.g. in
+// dist_belief.
+//
+// Let V be the vector of feature maps at some (batch, y, x)
+// coordinate. LRN applies independently to each vector V in the
+// input, across all coordinates (batch, y, x), by mapping each V to
+// another vector U of the same size using the formula
+//
+// V_i = U_i / ((bias + alpha * (sum_j U_j^2)) ^ beta)
+//
+// where the sum is taken for j in the inclusive range [i - range, i + range].
+//
+// When calculating V_i the j in the sum can extend beyond the bounds
+// of U. If wrap_around is true, then U_j = U_{j mod F} where F is the
+// size of U, which is the number of feature maps. If wrap_around is
+// false, then U_j = 0 for j outside [0, F-1].
+//
+// If segment_size <= F, where F is the number of feature_maps, then
+// segment_size has no effect. Otherwise, each consecutive segment of
+// segment_size entries in V are normalized separately.
+//
+// Not all StreamExecutors allow wrap_around == true or segment_size
+// != 64. Some do not implement normalization at all.
class NormalizeDescriptor {
public:
NormalizeDescriptor();
@@ -488,8 +513,6 @@ class NormalizeDescriptor {
float beta_;
bool wrap_around_;
int32 segment_size_;
-
- SE_DISALLOW_COPY_AND_ASSIGN(NormalizeDescriptor);
};
// Describes a kind of non-linearity (threshold-like mathematical function).
@@ -503,6 +526,8 @@ enum class ActivationMode {
// BatchDescriptor::value_max().
kReluX,
kTanh,
+ // Like ReluX, but passes all values in the range [-X,X].
+ kBandPass,
};
// Returns a string representation of the given activation mode.
@@ -510,10 +535,7 @@ string ActivationModeString(ActivationMode mode);
// Describes the operation that DoElementwiseOperation should perform on its
// inputs.
-enum class ElementwiseOperation {
- kAdd,
- kMultiply
-};
+enum class ElementwiseOperation { kAdd, kMultiply };
string ElementwiseOperationString(ElementwiseOperation op);
@@ -541,6 +563,8 @@ class DnnSupport {
// output_descriptor: dimensions of the output layer.
// output_data: un-owned device memory region in which to place the
// convolution result.
+ // scratch_allocator: un-owned, may-be-null object that may allocate scratch
+ // space in order to speed up the convolution operation.
//
// input_descriptor, filter_descriptor, convolution_descriptor and
// output_descriptor together specify exactly how the convolution is aligned
@@ -564,7 +588,8 @@ class DnnSupport {
const DeviceMemory<float>& filter_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
const dnn::BatchDescriptor& output_descriptor,
- DeviceMemory<float>* output_data) = 0;
+ DeviceMemory<float>* output_data,
+ ScratchAllocator* scratch_allocator) = 0;
// Enqueues a double-precision convolution operation onto the stream.
// See DoConvolve above for argument details.
@@ -612,6 +637,8 @@ class DnnSupport {
// input_descriptor: dimensions of the input layer.
// backward_input_data: un-owned device memory region in which to place the
// backprop of the input.
+ // scratch_allocator: un-owned, may-be-null object that may allocate scratch
+ // space in order to speed up the convolution operation.
virtual bool DoConvolveBackwardData(
Stream* stream, const FilterDescriptor& filter_descriptor,
const DeviceMemory<float>& filter_data,
@@ -619,7 +646,8 @@ class DnnSupport {
DeviceMemory<float> backward_output_data,
const ConvolutionDescriptor& convolution_descriptor,
const BatchDescriptor& input_descriptor,
- DeviceMemory<float>* backward_input_data) = 0;
+ DeviceMemory<float>* backward_input_data,
+ ScratchAllocator* scratch_allocator) = 0;
// Enqueues a single-precision backward convolution (for filter) operation
// onto
@@ -640,6 +668,8 @@ class DnnSupport {
// filter_descriptor: dimensions of the convolution filter.
// backward_filter_data: un-owned device memory region in which to place the
// backprop of the filter.
+ // scratch_allocator: un-owned, may-be-null object that may allocate scratch
+ // space in order to speed up the convolution operation.
virtual bool DoConvolveBackwardFilter(
Stream* stream, const BatchDescriptor& input_descriptor,
const DeviceMemory<float>& input_data,
@@ -647,7 +677,8 @@ class DnnSupport {
DeviceMemory<float> backward_output_data,
const ConvolutionDescriptor& convolution_descriptor,
const FilterDescriptor& filter_descriptor,
- DeviceMemory<float>* backward_filter_data) = 0;
+ DeviceMemory<float>* backward_filter_data,
+ ScratchAllocator* scratch_allocator) = 0;
// Fully connects the "nodes" (float values) in input_data with
// shape input_dimensions to output_data with output_dimensions
@@ -784,8 +815,10 @@ class DnnSupport {
const DeviceMemory<float>& input_diff_data,
DeviceMemory<float>* output_diff_data) = 0;
- // Applies local response normalization to all of the values
- // held on the device in 'input_data'.
+ // Applies local response normalization to the values from
+ // input_data and writes the result to output_data. See comments on
+ // NormalizeDescriptor for a description of local response
+ // normalization.
virtual bool DoNormalize(Stream* stream,
const dnn::NormalizeDescriptor& normalize_descriptor,
const DeviceMemory<float>& input_data,
@@ -850,6 +883,46 @@ class DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) = 0;
+ // Pads the input with zeros in the X and Y dimensions. The feature_map
+ // dimension is unchanged.
+ //
+ // Arguments (all borrowed):
+  //  stream: borrowed pointer to the stream that the padding operation
+  //    should be enqueued onto.
+ // dimensions: The dimensions of the input.
+ // input_data: un-owned device memory region which contains the
+ // input data for the input layer.
+ // left_pad: Amount to pad the input on the left.
+ // right_pad: Amount to pad the input on the right.
+ // top_pad: Amount to pad the input at the top (low Y).
+ // bottom_pad: Amount to pad the input at the bottom (high Y).
+ // output_data: un-owned device memory region in which to place the
+ // padded result.
+ virtual bool DoXYPad(Stream* stream, const dnn::BatchDescriptor &dimensions,
+ const DeviceMemory<float> &input_data,
+ int64 left_pad, int64 right_pad, int64 top_pad,
+ int64 bottom_pad, DeviceMemory<float> *output_data) = 0;
+
+ // Extracts a slice of the input in the X and Y dimensions. The feature_map
+ // dimension is unchanged.
+ //
+ // Arguments (all borrowed):
+  //  stream: borrowed pointer to the stream that the slice operation
+  //    should be enqueued onto.
+ // dimensions: The dimensions of the input.
+ // input_data: un-owned device memory region which contains the
+ // input data for the input layer.
+ // left_trim: Amount to cut off the input on the left.
+ // right_trim: Amount to cut off the input on the right.
+  //  top_trim: Amount to cut off the input at the top (low Y).
+  //  bottom_trim: Amount to cut off the input at the bottom (high Y).
+  //  output_data: un-owned device memory region in which to place the
+  //    sliced result.
+ virtual bool DoXYSlice(Stream* stream, const dnn::BatchDescriptor &dimensions,
+ const DeviceMemory<float> &input_data,
+ int64 left_trim, int64 right_trim, int64 top_trim,
+ int64 bottom_trim, DeviceMemory<float> *output_data) = 0;
+
// Enqueues an asynchronous memcpy of the *quantized* output of a layer (that
// is, bytes instead of scaled floats) into 'host_dst' if they are available
// for the underlying DNN implementation. If this quantized output is not
@@ -862,23 +935,14 @@ class DnnSupport {
// gpu_unquantized_src: the device memory that contains the unquantized data
// -- this data should also have a corresponding quantized representation
// on the device for this operation to succeed.
+ // mode: Type of quantization of the data to write into host_dst.
// host_dst: un-owned host memory region that is mutated in place,
// it is clobbered by the values in 'gpu_unquantized_src' when the enqueued
// (asynchronous) memcpy operation is performed.
- // TODO(wgulland) Merge all these versions of DoMemcpyD2HQuantized.
- virtual bool DoMemcpyD2HQuantized(
- Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
- port::MutableArraySlice<uint8> host_dst) = 0;
-
- // As above, but for 16-bit values.
- virtual bool DoMemcpyD2HQuantized(
- Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
- port::MutableArraySlice<uint16> host_dst) = 0;
-
- // As above, but for signed 32-bit values.
+ // size: size in bytes of the host_dst host memory region.
virtual bool DoMemcpyD2HQuantized(
Stream* stream, const DeviceMemory<float>& gpu_unquantized_src,
- port::MutableArraySlice<int32> host_dst) = 0;
+ QuantizedActivationMode mode, void* host_dst, int64 size) = 0;
// Enqueues an asynchronous memcpy of 'host_dst' into the *quantized* input
// of a layer (that is, bytes instead of scaled floats) if they are supported
@@ -890,13 +954,16 @@ class DnnSupport {
// stream: borrowed pointer to the stream that the 'quantized memcpy'
// operation should be enqueued onto.
// host_src: un-owned host memory region that contains the quantized data.
+ // size: size in bytes of the host_src host memory region.
+ // mode: Type of quantization of the data to read from host_src.
// gpu_unquantized_dst: the device memory that is clobbered by the values in
// 'host_src' when the enqueued (asynchronous) memcpy operation is
// performed. -- this data should also have a corresponding quantized
// representation on the device for this operation to
// succeed.
virtual bool DoMemcpyH2DQuantized(
- Stream* stream, port::ArraySlice<uint8> host_src,
+ Stream* stream, const void* host_src, int64 size,
+ QuantizedActivationMode mode,
DeviceMemory<float>* gpu_unquantized_dst) = 0;
private:
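For readers of the NormalizeDescriptor comment added to dnn.h above, here is a minimal host-side sketch of the LRN formula. It is purely illustrative (plain C++, standard library only), is not part of this change, and ignores wrap_around and segment_size.

    // Reference LRN over one feature-map vector V, mirroring the dnn.h comment:
    //   U_i = V_i / ((bias + alpha * sum_j V_j^2) ^ beta), j in [i-range, i+range]
    // Out-of-range V_j are treated as zero (i.e. wrap_around == false).
    #include <cmath>
    #include <vector>

    std::vector<float> ReferenceLrn(const std::vector<float>& v, int range,
                                    float bias, float alpha, float beta) {
      const int f = static_cast<int>(v.size());
      std::vector<float> u(f);
      for (int i = 0; i < f; ++i) {
        float sum = 0.0f;
        for (int j = i - range; j <= i + range; ++j) {
          if (j < 0 || j >= f) continue;  // wrap_around == false
          sum += v[j] * v[j];
        }
        u[i] = v[i] / std::pow(bias + alpha * sum, beta);
      }
      return u;
    }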
diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc
index c8e1d7fa48..600f083840 100644
--- a/tensorflow/stream_executor/dso_loader.cc
+++ b/tensorflow/stream_executor/dso_loader.cc
@@ -42,11 +42,12 @@ namespace internal {
}
/* static */ port::Status DsoLoader::GetCudnnDsoHandle(void** dso_handle) {
- // libcudnn is versioned differently than the other libraries. See b/22397368
- // for some details about the complications surrounding this.
- return GetDsoHandle(FindDsoPath("libcudnn.so.6.5",
- "third_party/gpus/cuda/lib64"),
- dso_handle);
+  // libcudnn is versioned differently than the other CUDA libraries, so its DSO
+  // may carry a different version number. See b/22397368 for some details about
+  // the complications surrounding this.
+ return GetDsoHandle(
+ FindDsoPath("libcudnn.so.6.5", "third_party/gpus/cuda/lib64"),
+ dso_handle);
}
/* static */ port::Status DsoLoader::GetCufftDsoHandle(void** dso_handle) {
@@ -89,16 +90,16 @@ namespace internal {
string path_string = path.ToString();
*dso_handle = dlopen(path_string.c_str(), dynload_flags);
if (*dso_handle == nullptr) {
- LOG(INFO) << "LD_LIBRARY_PATH: " << getenv("LD_LIBRARY_PATH");
+ LOG(INFO) << "Couldn't open CUDA library " << path
+ << ". LD_LIBRARY_PATH: " << getenv("LD_LIBRARY_PATH");
// TODO(b/22689637): Eliminate unnecessary ToString once StrCat has been
// moved to the open-sourceable version.
return port::Status(
port::error::FAILED_PRECONDITION,
port::StrCat("could not dlopen DSO: ", path, "; dlerror: ", dlerror()));
}
-
- VLOG(2) << "loaded path \"" << path << "\" "
- << (load_kind == LoadKind::kLocal ? "locally" : "globally");
+ LOG(INFO) << "successfully opened CUDA library " << path
+ << (load_kind == LoadKind::kLocal ? " locally" : " globally");
return port::Status::OK();
}
diff --git a/tensorflow/stream_executor/event.cc b/tensorflow/stream_executor/event.cc
index 5fdd7b9021..5ded7c590b 100644
--- a/tensorflow/stream_executor/event.cc
+++ b/tensorflow/stream_executor/event.cc
@@ -22,21 +22,10 @@ limitations under the License.
namespace perftools {
namespace gputools {
-internal::EventInterface* CreateEventImplementation(
- StreamExecutor* stream_exec) {
- PlatformKind platform_kind = stream_exec->platform_kind();
- switch (platform_kind) {
- case PlatformKind::kCuda:
- return (*internal::MakeCUDAEventImplementation())(stream_exec);
- default:
- LOG(FATAL) << "Cannot create event implementation for platform kind: "
- << PlatformKindString(platform_kind);
- }
-}
-
Event::Event(StreamExecutor* stream_exec)
- : implementation_(CreateEventImplementation(stream_exec)),
- stream_exec_(stream_exec) {}
+ : stream_exec_(stream_exec),
+ implementation_(
+ stream_exec_->implementation()->CreateEventImplementation()) {}
Event::~Event() {
auto status = stream_exec_->DeallocateEvent(this);
diff --git a/tensorflow/stream_executor/event.h b/tensorflow/stream_executor/event.h
index 42827e96aa..4b95889547 100644
--- a/tensorflow/stream_executor/event.h
+++ b/tensorflow/stream_executor/event.h
@@ -18,6 +18,8 @@ limitations under the License.
#include <memory>
+#include "tensorflow/stream_executor/platform/port.h"
+
namespace perftools {
namespace gputools {
@@ -63,13 +65,15 @@ class Event {
private:
friend class Stream;
+ // Pointer to the StreamExecutor interface used to create this object.
+ // Not owned.
+ StreamExecutor* stream_exec_;
+
// Pointer to the platform-specific EventInterface implementation underlying
// the object. Owned.
std::unique_ptr<internal::EventInterface> implementation_;
- // Pointer to the StreamExecutor interface used to create this object.
- // Not owned.
- StreamExecutor* stream_exec_;
+ SE_DISALLOW_COPY_AND_ASSIGN(Event);
};
} // namespace gputools
diff --git a/tensorflow/stream_executor/fft.h b/tensorflow/stream_executor/fft.h
index 69b49ae92e..505e0e4d13 100644
--- a/tensorflow/stream_executor/fft.h
+++ b/tensorflow/stream_executor/fft.h
@@ -161,7 +161,7 @@ class FftSupport {
// Macro used to quickly declare overrides for abstract virtuals in the
// fft::FftSupport base class. Assumes that it's emitted somewhere inside the
// ::perftools::gputools namespace.
-#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \
+#define TENSORFLOW_STREAM_EXECUTOR_GPU_FFT_SUPPORT_OVERRIDES \
std::unique_ptr<fft::Plan> Create1dPlan(Stream *stream, uint64 num_x, \
fft::Type type, bool in_place_fft) \
override; \
diff --git a/tensorflow/stream_executor/gcuda.h b/tensorflow/stream_executor/gcuda.h
index 4710e30009..4ed0e3d8d0 100644
--- a/tensorflow/stream_executor/gcuda.h
+++ b/tensorflow/stream_executor/gcuda.h
@@ -87,6 +87,8 @@ enum SharedMemConfig {
#include "tensorflow/stream_executor/kernel_cache_config.h"
#include "tensorflow/stream_executor/launch_dim.h"
#include "tensorflow/stream_executor/machine_manager.h"
+#include "tensorflow/stream_executor/multi_platform_manager.h"
+#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/shared_memory_config.h"
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor.h"
diff --git a/tensorflow/stream_executor/kernel.cc b/tensorflow/stream_executor/kernel.cc
index ee0b706eef..64a4e6f49e 100644
--- a/tensorflow/stream_executor/kernel.cc
+++ b/tensorflow/stream_executor/kernel.cc
@@ -25,7 +25,6 @@ limitations under the License.
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/stream_executor.h"
-#include "tensorflow/stream_executor/stream_executor_internal.h"
namespace perftools {
namespace gputools {
@@ -58,29 +57,13 @@ void KernelMetadata::set_shared_memory_bytes(int shared_memory_bytes) {
has_shared_memory_bytes_ = true;
}
-static internal::KernelInterface *KernelImplementationFromPlatformKind(
- PlatformKind platform_kind) {
- if (platform_kind == PlatformKind::kCuda) {
- return (*internal::MakeCUDAKernelImplementation())();
- } else if (platform_kind == PlatformKind::kOpenCL ||
- platform_kind == PlatformKind::kOpenCLAltera) {
- return (*internal::MakeOpenCLKernelImplementation())();
- } else {
- LOG(FATAL) << "cannot create kernel implementation for platform kind: "
- << PlatformKindString(platform_kind);
- }
-}
-
KernelBase::KernelBase(StreamExecutor *parent)
- : implementation_(
- KernelImplementationFromPlatformKind(parent->platform_kind())),
- parent_(parent) {
- DCHECK(parent_ != nullptr);
-}
+ : parent_(parent),
+ implementation_(parent->implementation()->CreateKernelImplementation()) {}
KernelBase::KernelBase(StreamExecutor *parent,
internal::KernelInterface *implementation)
- : implementation_(implementation), parent_(parent) {}
+ : parent_(parent), implementation_(implementation) {}
KernelBase::~KernelBase() {}
diff --git a/tensorflow/stream_executor/kernel.h b/tensorflow/stream_executor/kernel.h
index 1346fd7bfa..16ffeb2a57 100644
--- a/tensorflow/stream_executor/kernel.h
+++ b/tensorflow/stream_executor/kernel.h
@@ -141,7 +141,6 @@ class KernelBase {
explicit KernelBase(StreamExecutor *parent);
// Test-only constructor that can take a mock KernelInterface implementation.
- // Takes ownership of implementation, it should not be null.
KernelBase(StreamExecutor *parent, internal::KernelInterface *implementation);
// Releases resources associated with the kernel instance (i.e.
@@ -181,12 +180,12 @@ class KernelBase {
const string &demangled_name() const { return demangled_name_; }
private:
- // Implementation delegated to for platform-specific functionality.
- std::unique_ptr<internal::KernelInterface> implementation_;
-
// The StreamExecutor that loads this kernel object.
StreamExecutor *parent_;
+ // Implementation delegated to for platform-specific functionality.
+ std::unique_ptr<internal::KernelInterface> implementation_;
+
string name_;
string demangled_name_;
diff --git a/tensorflow/stream_executor/lib/casts.h b/tensorflow/stream_executor/lib/casts.h
index 3edfcde8bb..19087ff7dc 100644
--- a/tensorflow/stream_executor/lib/casts.h
+++ b/tensorflow/stream_executor/lib/casts.h
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
+
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_CASTS_H_
diff --git a/tensorflow/stream_executor/lib/error.h b/tensorflow/stream_executor/lib/error.h
index 368a8af79c..5aac04f525 100644
--- a/tensorflow/stream_executor/lib/error.h
+++ b/tensorflow/stream_executor/lib/error.h
@@ -13,10 +13,12 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
+
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_ERROR_H_
-#include "tensorflow/core/lib/core/error_codes.pb.h"
+#include "tensorflow/core/lib/core/error_codes.pb.h" // IWYU pragma: export
namespace perftools {
namespace gputools {
diff --git a/tensorflow/stream_executor/lib/ptr_util.h b/tensorflow/stream_executor/lib/ptr_util.h
index 578ed67ade..e42de83c16 100644
--- a/tensorflow/stream_executor/lib/ptr_util.h
+++ b/tensorflow/stream_executor/lib/ptr_util.h
@@ -60,4 +60,5 @@ typename MakeUniqueResult<T>::invalid MakeUnique(Args&&... /* args */) =
} // namespace gputools
} // namespace perftools
+
#endif // TENSORFLOW_STREAM_EXECUTOR_LIB_PTR_UTIL_H_
diff --git a/tensorflow/stream_executor/lib/status.h b/tensorflow/stream_executor/lib/status.h
index 0ec243c38a..af67769253 100644
--- a/tensorflow/stream_executor/lib/status.h
+++ b/tensorflow/stream_executor/lib/status.h
@@ -13,18 +13,20 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
+
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STATUS_H_
#include "tensorflow/core/public/status.h"
-#include "tensorflow/stream_executor/lib/error.h"
+#include "tensorflow/stream_executor/lib/error.h" // IWYU pragma: export
#include "tensorflow/stream_executor/platform/logging.h"
namespace perftools {
namespace gputools {
namespace port {
-using tensorflow::Status;
+using Status = tensorflow::Status;
#define SE_CHECK_OK(val) \
CHECK_EQ(::perftools::gputools::port::Status::OK(), (val))
diff --git a/tensorflow/stream_executor/lib/statusor.h b/tensorflow/stream_executor/lib/statusor.h
index 6c0fd0f9cd..d9b7787e30 100644
--- a/tensorflow/stream_executor/lib/statusor.h
+++ b/tensorflow/stream_executor/lib/statusor.h
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
+//
// StatusOr<T> is the union of a Status object and a T
// object. StatusOr models the concept of an object that is either a
// usable value, or an error Status explaining why such a value is
diff --git a/tensorflow/stream_executor/lib/strcat.h b/tensorflow/stream_executor/lib/strcat.h
index de41b6ac70..a25eca1e9a 100644
--- a/tensorflow/stream_executor/lib/strcat.h
+++ b/tensorflow/stream_executor/lib/strcat.h
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
+
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
diff --git a/tensorflow/stream_executor/platform.cc b/tensorflow/stream_executor/platform.cc
index 8ace26f578..dcdcb8f42b 100644
--- a/tensorflow/stream_executor/platform.cc
+++ b/tensorflow/stream_executor/platform.cc
@@ -31,8 +31,6 @@ string PlatformKindString(PlatformKind kind) {
return "CUDA";
case PlatformKind::kOpenCL:
return "OpenCL";
- case PlatformKind::kOpenCLAltera:
- return "OpenCL+Altera";
case PlatformKind::kHost:
return "Host";
case PlatformKind::kMock:
diff --git a/tensorflow/stream_executor/platform.h b/tensorflow/stream_executor/platform.h
index 338f9940e1..3f850ba327 100644
--- a/tensorflow/stream_executor/platform.h
+++ b/tensorflow/stream_executor/platform.h
@@ -42,9 +42,6 @@ enum class PlatformKind {
kInvalid,
kCuda,
kOpenCL,
- kOpenCLAltera, // Altera FPGA OpenCL platform.
- // See documentation: go/fpgaopencl
- // (StreamExecutor integration)
kHost,
kMock,
kSize,
diff --git a/tensorflow/stream_executor/platform/port.h b/tensorflow/stream_executor/platform/port.h
index 33792214f5..c1b770e486 100644
--- a/tensorflow/stream_executor/platform/port.h
+++ b/tensorflow/stream_executor/platform/port.h
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// IWYU pragma: private, include "perftools/gputools/executor/stream_executor.h"
+
#ifndef TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_
#define TENSORFLOW_STREAM_EXECUTOR_PLATFORM_PORT_H_
diff --git a/tensorflow/stream_executor/plugin_registry.cc b/tensorflow/stream_executor/plugin_registry.cc
index 83cf1cffaf..e05d42985b 100644
--- a/tensorflow/stream_executor/plugin_registry.cc
+++ b/tensorflow/stream_executor/plugin_registry.cc
@@ -211,7 +211,8 @@ bool PluginRegistry::HasFactory(Platform::Id platform_id,
if (plugin_id == kNullPlugin) { \
return port::Status{port::error::FAILED_PRECONDITION, \
"No suitable " PLUGIN_STRING \
- " plugin registered, default or otherwise."}; \
+ " plugin registered. Have you linked in a " \
+ PLUGIN_STRING "-providing plugin?"}; \
} else { \
VLOG(2) << "Selecting default " PLUGIN_STRING " plugin, " \
<< plugin_names_[plugin_id]; \
diff --git a/tensorflow/stream_executor/scratch_allocator.cc b/tensorflow/stream_executor/scratch_allocator.cc
new file mode 100644
index 0000000000..c216de3ab9
--- /dev/null
+++ b/tensorflow/stream_executor/scratch_allocator.cc
@@ -0,0 +1,42 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/stream_executor/scratch_allocator.h"
+
+#include "tensorflow/stream_executor/lib/status_macros.h"
+#include "tensorflow/stream_executor/stream.h"
+
+namespace perftools {
+namespace gputools {
+
+ScratchAllocator::~ScratchAllocator() {}
+
+OneTimeScratchAllocator::OneTimeScratchAllocator() {}
+OneTimeScratchAllocator::~OneTimeScratchAllocator() {}
+
+int64 OneTimeScratchAllocator::GetMemoryLimitInBytes(Stream* stream) {
+ return -1;
+}
+
+port::StatusOr<DeviceMemory<uint8>> OneTimeScratchAllocator::AllocateBytes(
+ Stream* stream, int64 byte_size) {
+ CHECK(temporary_ == nullptr);
+ SE_ASSIGN_OR_RETURN(temporary_,
+ stream->AllocateTemporaryArray<uint8>(byte_size));
+ return temporary_->device_memory();
+}
+
+} // namespace gputools
+} // namespace perftools
diff --git a/tensorflow/stream_executor/scratch_allocator.h b/tensorflow/stream_executor/scratch_allocator.h
new file mode 100644
index 0000000000..52697d6f8e
--- /dev/null
+++ b/tensorflow/stream_executor/scratch_allocator.h
@@ -0,0 +1,83 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_SCRATCH_ALLOCATOR_H_
+#define TENSORFLOW_STREAM_EXECUTOR_SCRATCH_ALLOCATOR_H_
+
+#include <memory>
+
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/lib/statusor.h"
+#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/temporary_device_memory.h"
+
+namespace perftools {
+namespace gputools {
+
+class Stream;
+
+// Interface that allows stream operations (e.g.
+// Stream::ThenConvolveWithScratch) to optionally request scratch space be
+// allocated in order to speed up the operation being enqueued.
+//
+// Note that the caller is responsible for deallocating the scratch space at a
+// known-safe point, when all scratch-memory-consuming kernels are known for
+// sure to have finished; e.g. at stream synchronization time. This is different
+// from a traditional C++ object allocator, where the client is responsible for
+// releasing. (Conceptually, scratch memory is a form of "temporary" device
+// memory allocation.)
+class ScratchAllocator {
+ public:
+ virtual ~ScratchAllocator();
+
+  // Returns an upper limit, in bytes, on the scratch memory this allocator is
+  // willing to provide. This information may be used to help select an
+  // algorithm.
+  //
+  // Returns a value < 0 to indicate that there is no recommended limit.
+ virtual int64 GetMemoryLimitInBytes(Stream* stream) = 0;
+
+  // Returns an allocation of byte_size bytes for use in an operation on stream.
+ //
+ // This is a temporary allocation, and the caller is responsible for
+ // deallocating at some known-safe point. See the class comment above.
+ virtual port::StatusOr<DeviceMemory<uint8>> AllocateBytes(
+ Stream* stream, int64 byte_size) = 0;
+};
+
+// Allocates a single temporary memory allocation -- this memory is deallocated
+// at the next stream synchronization point after this object has gone out of
+// scope. This satisfies the lifetime and deallocation properties given in the
+// class comment above.
+//
+// Thread-compatible, but not thread-safe (use in scenarios where only one
+// thread will request the scratch allocation).
+class OneTimeScratchAllocator : public ScratchAllocator {
+ public:
+ OneTimeScratchAllocator();
+ ~OneTimeScratchAllocator() override;
+ int64 GetMemoryLimitInBytes(Stream* stream) override;
+ port::StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
+ int64 byte_size) override;
+
+ private:
+ std::unique_ptr<TemporaryDeviceMemory<uint8>> temporary_;
+
+ SE_DISALLOW_COPY_AND_ASSIGN(OneTimeScratchAllocator);
+};
+
+} // namespace gputools
+} // namespace perftools
+
+#endif // TENSORFLOW_STREAM_EXECUTOR_SCRATCH_ALLOCATOR_H_
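As a sketch of how a client other than OneTimeScratchAllocator might implement the interface above: the class name and policy below are hypothetical, and the code assumes the stream_executor headers from this change (it only compiles inside the TensorFlow tree). The allocator advertises a caller-supplied limit and keeps every temporary allocation alive for its own lifetime.

    // Hypothetical scratch allocator: reports a fixed memory limit and holds each
    // temporary allocation until the allocator itself is destroyed.
    #include <memory>
    #include <vector>

    #include "tensorflow/stream_executor/lib/status_macros.h"
    #include "tensorflow/stream_executor/scratch_allocator.h"
    #include "tensorflow/stream_executor/stream.h"

    namespace perftools {
    namespace gputools {

    class CappedScratchAllocator : public ScratchAllocator {
     public:
      explicit CappedScratchAllocator(int64 limit_bytes)
          : limit_bytes_(limit_bytes) {}

      int64 GetMemoryLimitInBytes(Stream* stream) override {
        return limit_bytes_;
      }

      port::StatusOr<DeviceMemory<uint8>> AllocateBytes(Stream* stream,
                                                        int64 byte_size) override {
        // Same allocation path as OneTimeScratchAllocator, but allows multiple
        // allocations by keeping each temporary in a vector.
        std::unique_ptr<TemporaryDeviceMemory<uint8>> temporary;
        SE_ASSIGN_OR_RETURN(temporary,
                            stream->AllocateTemporaryArray<uint8>(byte_size));
        temporaries_.push_back(std::move(temporary));
        return temporaries_.back()->device_memory();
      }

     private:
      int64 limit_bytes_;
      std::vector<std::unique_ptr<TemporaryDeviceMemory<uint8>>> temporaries_;
    };

    }  // namespace gputools
    }  // namespace perftools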
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 5d6971a5a5..587896a2ab 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
@@ -29,22 +30,6 @@ namespace perftools {
namespace gputools {
namespace {
-static internal::StreamInterface *CreateStreamImplementation(
- StreamExecutor *parent) {
- PlatformKind platform_kind = parent->platform_kind();
- if (platform_kind == PlatformKind::kCuda) {
- return (*internal::MakeCUDAStreamImplementation())(parent);
- } else if (platform_kind == PlatformKind::kOpenCL ||
- platform_kind == PlatformKind::kOpenCLAltera) {
- return (*internal::MakeOpenCLStreamImplementation())(parent);
- } else if (platform_kind == PlatformKind::kHost) {
- return internal::MakeHostStreamImplementation(parent);
- } else {
- LOG(FATAL) << "cannot create stream implementation for platform kind: "
- << PlatformKindString(platform_kind);
- }
-}
-
// Code to turn parameters to functions on stream into strings that
// will be VLOG'ed. We need overloads, instead of
// e.g. BatchDescriptorToVlogString(), as the code that calls these
@@ -77,6 +62,10 @@ string ToVlogString(dnn::ElementwiseOperation op) {
return dnn::ElementwiseOperationString(op);
}
+string ToVlogString(dnn::QuantizedActivationMode mode) {
+ return dnn::QuantizedActivationModeString(mode);
+}
+
string ToVlogString(blas::Transpose t) { return blas::TransposeString(t); }
string ToVlogString(blas::UpperLower ul) { return blas::UpperLowerString(ul); }
@@ -123,6 +112,8 @@ string ToVlogString(uint32 i) { return port::StrCat(i); }
string ToVlogString(uint64 i) { return port::StrCat(i); }
+string ToVlogString(int64 i) { return port::StrCat(i); }
+
string ToVlogString(float f) { return port::StrCat(f); }
string ToVlogString(double d) { return port::StrCat(d); }
@@ -181,6 +172,9 @@ string CallStr(const char *function_name, Stream *stream,
separator = ", ";
}
port::StrAppend(&str, ") stream=", ToVlogString(stream));
+ if (VLOG_IS_ON(10)) {
+ port::StrAppend(&str, " ", port::CurrentStackTrace(), "\n");
+ }
return str;
}
@@ -206,8 +200,8 @@ string CallStr(const char *function_name, Stream *stream,
} // namespace
Stream::Stream(StreamExecutor *parent)
- : implementation_(CreateStreamImplementation(parent)),
- parent_(parent),
+ : parent_(parent),
+ implementation_(parent->implementation()->GetStreamImplementation()),
allocated_(false),
ok_(false),
temporary_memory_manager_(this) {
@@ -216,8 +210,8 @@ Stream::Stream(StreamExecutor *parent)
Stream::Stream(StreamExecutor *parent,
internal::StreamInterface *implementation)
- : implementation_(implementation),
- parent_(parent),
+ : parent_(parent),
+ implementation_(implementation),
allocated_(false),
ok_(false),
temporary_memory_manager_(this) {
@@ -283,15 +277,15 @@ Stream &Stream::ThenRecordEvent(Event *event) {
return *this;
}
-Stream &Stream::ThenConvolve(
- const dnn::BatchDescriptor &batch_descriptor,
+Stream &Stream::ThenConvolveWithScratch(
+ const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<float> &input_data,
const dnn::FilterDescriptor &filter_descriptor,
const DeviceMemory<float> &filter_data,
const dnn::ConvolutionDescriptor &convolution_descriptor,
- const dnn::BatchDescriptor &output_descriptor,
- DeviceMemory<float> *output) {
- VLOG_CALL(PARAM(batch_descriptor), PARAM(input_data),
+ const dnn::BatchDescriptor &output_descriptor, DeviceMemory<float> *output,
+ ScratchAllocator *scratch_allocator) {
+ VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
PARAM(filter_descriptor), PARAM(filter_data),
PARAM(convolution_descriptor), PARAM(output_descriptor),
PARAM(output));
@@ -299,8 +293,9 @@ Stream &Stream::ThenConvolve(
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoConvolve(
- this, batch_descriptor, input_data, filter_descriptor, filter_data,
- convolution_descriptor, output_descriptor, output));
+ this, input_descriptor, input_data, filter_descriptor, filter_data,
+ convolution_descriptor, output_descriptor, output,
+ /*scratch_allocator=*/scratch_allocator));
} else {
SetError();
LOG(WARNING)
@@ -311,6 +306,20 @@ Stream &Stream::ThenConvolve(
return *this;
}
+Stream &Stream::ThenConvolve(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> *output) {
+ return ThenConvolveWithScratch(input_descriptor, input_data,
+ filter_descriptor, filter_data,
+ convolution_descriptor, output_descriptor,
+ output, /*scratch_allocator=*/nullptr);
+}
+
Stream &Stream::ThenSeparableConvolve(
const dnn::BatchDescriptor &batch_descriptor,
const DeviceMemory<float> &input_data,
@@ -341,14 +350,15 @@ Stream &Stream::ThenSeparableConvolve(
return *this;
}
-Stream &Stream::ThenConvolveBackwardData(
+Stream &Stream::ThenConvolveBackwardDataWithScratch(
const dnn::FilterDescriptor &filter_descriptor,
const DeviceMemory<float> &filter_data,
const dnn::BatchDescriptor &output_descriptor,
DeviceMemory<float> backward_output_data,
const dnn::ConvolutionDescriptor &convolution_descriptor,
const dnn::BatchDescriptor &input_descriptor,
- DeviceMemory<float> *backward_input_data) {
+ DeviceMemory<float> *backward_input_data,
+ ScratchAllocator *scratch_allocator) {
VLOG_CALL(PARAM(filter_descriptor), PARAM(filter_data),
PARAM(output_descriptor), PARAM(backward_output_data),
PARAM(convolution_descriptor), PARAM(input_descriptor),
@@ -359,7 +369,7 @@ Stream &Stream::ThenConvolveBackwardData(
CheckError(dnn->DoConvolveBackwardData(
this, filter_descriptor, filter_data, output_descriptor,
backward_output_data, convolution_descriptor, input_descriptor,
- backward_input_data));
+ backward_input_data, scratch_allocator));
} else {
SetError();
LOG(WARNING)
@@ -370,14 +380,29 @@ Stream &Stream::ThenConvolveBackwardData(
return *this;
}
-Stream &Stream::ThenConvolveBackwardFilter(
+Stream &Stream::ThenConvolveBackwardData(
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &input_descriptor,
+ DeviceMemory<float> *backward_input_data) {
+ return ThenConvolveBackwardDataWithScratch(
+ filter_descriptor, filter_data, output_descriptor, backward_output_data,
+ convolution_descriptor, input_descriptor, backward_input_data,
+ /*scratch_allocator=*/nullptr);
+}
+
+Stream &Stream::ThenConvolveBackwardFilterWithScratch(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<float> &input_data,
const dnn::BatchDescriptor &output_descriptor,
DeviceMemory<float> backward_output_data,
const dnn::ConvolutionDescriptor &convolution_descriptor,
const dnn::FilterDescriptor &filter_descriptor,
- DeviceMemory<float> *backward_filter_data) {
+ DeviceMemory<float> *backward_filter_data,
+ ScratchAllocator *scratch_allocator) {
VLOG_CALL(PARAM(input_descriptor), PARAM(input_data),
PARAM(output_descriptor), PARAM(backward_output_data),
PARAM(convolution_descriptor), PARAM(filter_descriptor),
@@ -388,7 +413,7 @@ Stream &Stream::ThenConvolveBackwardFilter(
CheckError(dnn->DoConvolveBackwardFilter(
this, input_descriptor, input_data, output_descriptor,
backward_output_data, convolution_descriptor, filter_descriptor,
- backward_filter_data));
+ backward_filter_data, scratch_allocator));
} else {
SetError();
LOG(WARNING)
@@ -399,6 +424,20 @@ Stream &Stream::ThenConvolveBackwardFilter(
return *this;
}
+Stream &Stream::ThenConvolveBackwardFilter(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::FilterDescriptor &filter_descriptor,
+ DeviceMemory<float> *backward_filter_data) {
+ return ThenConvolveBackwardFilterWithScratch(
+ input_descriptor, input_data, output_descriptor, backward_output_data,
+ convolution_descriptor, filter_descriptor, backward_filter_data,
+ /*scratch_allocator=*/nullptr);
+}
+
Stream &Stream::ThenMatMul(const DeviceMemory<float> &input_data,
const DeviceMemory<float> &weights,
const dnn::BatchDescriptor &input_dimensions,
@@ -589,6 +628,19 @@ Stream &Stream::ThenDepthConcatenate(
DeviceMemory<float> *output_data) {
VLOG_CALL(PARAM(input_dimensions), PARAM(input_data), PARAM(output_data));
+ for (size_t i = 1; i < input_dimensions.size(); ++i) {
+ if (input_dimensions[i].count() != input_dimensions[0].count() ||
+ input_dimensions[i].height() != input_dimensions[0].height() ||
+ input_dimensions[i].width() != input_dimensions[0].width()) {
+ SetError();
+ LOG(ERROR) << "Incompatible dimensions for depth concatenation.\n"
+ << "input_dimensions[0]: " << input_dimensions[0].ToString()
+                 << "\ninput_dimensions[" << i
+ << "]: " << input_dimensions[i].ToString();
+ return *this;
+ }
+ }
+
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
CheckError(dnn->DoDepthConcatenate(this, input_dimensions, input_data,
@@ -627,15 +679,18 @@ Stream &Stream::ThenElementwiseOperate(
return *this;
}
-Stream &Stream::ThenMemcpyD2HQuantized(
- const DeviceMemory<float> &gpu_unquantized_src,
- port::MutableArraySlice<uint8> host_dst) {
- VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(host_dst));
+Stream &Stream::ThenXYPad(const dnn::BatchDescriptor &dimensions,
+ const DeviceMemory<float> &input_data, int64 left_pad,
+ int64 right_pad, int64 top_pad, int64 bottom_pad,
+ DeviceMemory<float> *output_data) {
+ VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_pad),
+ PARAM(right_pad), PARAM(top_pad), PARAM(bottom_pad),
+ PARAM(output_data));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- CheckError(
- dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, host_dst));
+ CheckError(dnn->DoXYPad(this, dimensions, input_data, left_pad, right_pad,
+ top_pad, bottom_pad, output_data));
} else {
SetError();
LOG(WARNING)
@@ -646,15 +701,20 @@ Stream &Stream::ThenMemcpyD2HQuantized(
return *this;
}
-Stream &Stream::ThenMemcpyD2HQuantized(
- const DeviceMemory<float> &gpu_unquantized_src,
- port::MutableArraySlice<uint16> host_dst) {
- VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(host_dst));
+Stream &Stream::ThenXYSlice(const dnn::BatchDescriptor &dimensions,
+ const DeviceMemory<float> &input_data,
+ int64 left_trim, int64 right_trim, int64 top_trim,
+ int64 bottom_trim,
+ DeviceMemory<float> *output_data) {
+ VLOG_CALL(PARAM(dimensions), PARAM(input_data), PARAM(left_trim),
+ PARAM(right_trim), PARAM(top_trim), PARAM(bottom_trim),
+ PARAM(output_data));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- CheckError(
- dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, host_dst));
+ CheckError(dnn->DoXYSlice(this, dimensions, input_data, left_trim,
+ right_trim, top_trim, bottom_trim,
+ output_data));
} else {
SetError();
LOG(WARNING)
@@ -667,13 +727,14 @@ Stream &Stream::ThenMemcpyD2HQuantized(
Stream &Stream::ThenMemcpyD2HQuantized(
const DeviceMemory<float> &gpu_unquantized_src,
- port::MutableArraySlice<int32> host_dst) {
- VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(host_dst));
+ dnn::QuantizedActivationMode mode, void *host_dst, uint64 size) {
+ VLOG_CALL(PARAM(gpu_unquantized_src), PARAM(mode), PARAM(host_dst),
+ PARAM(size));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- CheckError(
- dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, host_dst));
+ CheckError(dnn->DoMemcpyD2HQuantized(this, gpu_unquantized_src, mode,
+ host_dst, size));
} else {
SetError();
LOG(WARNING)
@@ -685,14 +746,15 @@ Stream &Stream::ThenMemcpyD2HQuantized(
}
Stream &Stream::ThenMemcpyH2DQuantized(
- port::ArraySlice<uint8> host_src,
+ const void *host_src, uint64 size, dnn::QuantizedActivationMode mode,
DeviceMemory<float> *gpu_unquantized_dst) {
- VLOG_CALL(PARAM(host_src), PARAM(gpu_unquantized_dst));
+ VLOG_CALL(PARAM(host_src), PARAM(size), PARAM(mode),
+ PARAM(gpu_unquantized_dst));
if (ok()) {
if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
- CheckError(
- dnn->DoMemcpyH2DQuantized(this, host_src, gpu_unquantized_dst));
+ CheckError(dnn->DoMemcpyH2DQuantized(this, host_src, size, mode,
+ gpu_unquantized_dst));
} else {
SetError();
LOG(WARNING)
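A rough call-site sketch for the new scratch-allocator overloads in stream.cc above. Descriptor and device-memory setup is elided, and everything except the stream_executor API names is hypothetical.

    // Hypothetical caller: the convolution may use scratch space provided by a
    // OneTimeScratchAllocator, which is released after the next sync point.
    #include "tensorflow/stream_executor/dnn.h"
    #include "tensorflow/stream_executor/scratch_allocator.h"
    #include "tensorflow/stream_executor/stream.h"

    namespace gpu = perftools::gputools;

    void RunConvolution(gpu::Stream* stream,
                        const gpu::dnn::BatchDescriptor& input_desc,
                        const gpu::DeviceMemory<float>& input,
                        const gpu::dnn::FilterDescriptor& filter_desc,
                        const gpu::DeviceMemory<float>& filter,
                        const gpu::dnn::ConvolutionDescriptor& conv_desc,
                        const gpu::dnn::BatchDescriptor& output_desc,
                        gpu::DeviceMemory<float>* output) {
      gpu::OneTimeScratchAllocator scratch_allocator;
      stream->ThenConvolveWithScratch(input_desc, input, filter_desc, filter,
                                      conv_desc, output_desc, output,
                                      &scratch_allocator);
      // Scratch memory is reclaimed at the next synchronization point after the
      // allocator goes out of scope (see scratch_allocator.h).
      stream->BlockHostUntilDone();
    }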
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index 3b9c2af1b0..d91c62ca26 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -69,6 +69,11 @@ struct ConvolutionDescriptor;
} // namespace dnn
class StreamExecutor;
+class ScratchAllocator;
+
+// Convert a type to the corresponding QuantizedActivationMode.
+template <typename ElementType>
+struct Quantization;
// Represents a stream of dependent computations on a GPU device.
//
@@ -214,6 +219,15 @@ class Stream {
const dnn::BatchDescriptor &output_descriptor,
DeviceMemory<float> *output);
+ Stream &ThenConvolveWithScratch(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> *output, ScratchAllocator *scratch_allocator);
+
Stream &ThenSeparableConvolve(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<float> &input_data,
@@ -233,6 +247,16 @@ class Stream {
const dnn::BatchDescriptor &input_descriptor,
DeviceMemory<float> *backward_input_data);
+ Stream &ThenConvolveBackwardDataWithScratch(
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<float> &filter_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::BatchDescriptor &input_descriptor,
+ DeviceMemory<float> *backward_input_data,
+ ScratchAllocator *scratch_allocator);
+
Stream &ThenConvolveBackwardFilter(
const dnn::BatchDescriptor &input_descriptor,
const DeviceMemory<float> &input_data,
@@ -242,6 +266,16 @@ class Stream {
const dnn::FilterDescriptor &filter_descriptor,
DeviceMemory<float> *backward_filter_data);
+ Stream &ThenConvolveBackwardFilterWithScratch(
+ const dnn::BatchDescriptor &input_descriptor,
+ const DeviceMemory<float> &input_data,
+ const dnn::BatchDescriptor &output_descriptor,
+ DeviceMemory<float> backward_output_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const dnn::FilterDescriptor &filter_descriptor,
+ DeviceMemory<float> *backward_filter_data,
+ ScratchAllocator *scratch_allocator);
+
Stream &ThenMatMul(const DeviceMemory<float> &input_data,
const DeviceMemory<float> &weights,
const dnn::BatchDescriptor &input_dimensions,
@@ -249,18 +283,18 @@ class Stream {
DeviceMemory<float> *output_data);
Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
- const DeviceMemory<int8> &weights,
- const DeviceMemory<float> &weight_scales,
- const dnn::BatchDescriptor &input_dimensions,
- const dnn::BatchDescriptor &output_dimensions,
- DeviceMemory<float> *output_data);
+ const DeviceMemory<int8> &weights,
+ const DeviceMemory<float> &weight_scales,
+ const dnn::BatchDescriptor &input_dimensions,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data);
Stream &ThenMatMulQuantized(const DeviceMemory<float> &input_data,
- const DeviceMemory<int16> &weights,
- const DeviceMemory<float> &weight_scales,
- const dnn::BatchDescriptor &input_dimensions,
- const dnn::BatchDescriptor &output_dimensions,
- DeviceMemory<float> *output_data);
+ const DeviceMemory<int16> &weights,
+ const DeviceMemory<float> &weight_scales,
+ const dnn::BatchDescriptor &input_dimensions,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<float> *output_data);
Stream &ThenBiasAdd(const DeviceMemory<float> &input_data,
const DeviceMemory<float> &biases,
@@ -302,24 +336,49 @@ class Stream {
const dnn::BatchDescriptor &output_dimensions,
DeviceMemory<float> *output_data);
- // See DnnSupport::DoMemcpyD2HQuantized.
- // TODO(wgulland) Use a template to merge the versions of
- // ThenMemcpyD2HQuantized.
- Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
- port::MutableArraySlice<uint8> host_dst);
+ Stream &ThenXYPad(const dnn::BatchDescriptor &dimensions,
+ const DeviceMemory<float> &input_data, int64 left_pad,
+ int64 right_pad, int64 top_pad, int64 bottom_pad,
+ DeviceMemory<float> *output_data);
- // See DnnSupport::DoMemcpyD2HQuantized.
- Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
- port::MutableArraySlice<uint16> host_dst);
+ Stream &ThenXYSlice(const dnn::BatchDescriptor &dimensions,
+ const DeviceMemory<float> &input_data, int64 left_trim,
+ int64 right_trim, int64 top_trim, int64 bottom_trim,
+ DeviceMemory<float> *output_data);
// See DnnSupport::DoMemcpyD2HQuantized.
Stream &ThenMemcpyD2HQuantized(const DeviceMemory<float> &gpu_unquantized_src,
- port::MutableArraySlice<int32> host_dst);
+ dnn::QuantizedActivationMode mode,
+ void *host_dst, uint64 size);
+
+ // Template version of ThenMemcpyD2HQuantized that takes a MutableArraySlice
+ // and uses the Quantization trait to call the generic version of
+ // ThenMemcpyD2HQuantized with the correct QuantizedActivationMode.
+ template <typename ElementType>
+ Stream &ThenMemcpyD2HQuantized(
+ const DeviceMemory<float> &gpu_unquantized_src,
+ port::MutableArraySlice<ElementType> host_dst) {
+ return ThenMemcpyD2HQuantized(
+ gpu_unquantized_src, Quantization<ElementType>::kModeId,
+ host_dst.data(), host_dst.size() * sizeof(ElementType));
+ }
// See DnnSupport::DoMemcpyH2DQuantized.
- Stream &ThenMemcpyH2DQuantized(port::ArraySlice<uint8> host_src,
+ Stream &ThenMemcpyH2DQuantized(const void *host_src, uint64 size,
+ dnn::QuantizedActivationMode mode,
DeviceMemory<float> *gpu_unquantized_dst);
+ // Template version of ThenMemcpyH2DQuantized that takes an ArraySlice
+ // and uses the Quantization trait to call the generic version of
+ // ThenMemcpyH2DQuantized with the correct QuantizedActivationMode.
+ template <typename ElementType>
+ Stream &ThenMemcpyH2DQuantized(port::ArraySlice<ElementType> host_src,
+ DeviceMemory<float> *gpu_unquantized_dst) {
+ return ThenMemcpyH2DQuantized(
+ host_src.data(), host_src.size() * sizeof(ElementType),
+ Quantization<ElementType>::kModeId, gpu_unquantized_dst);
+ }
+
/////////////////
// BLAS support
@@ -1143,9 +1202,11 @@ class Stream {
Stream &ThenMemset32(DeviceMemoryBase *location, const uint32 &pattern,
uint64 size);
- // (Synchronously) block the host code waiting for the operations entrained
- // on
- // the stream (enqueued to this point in program execution) to complete.
+ // (Synchronously) block the host code waiting for the operations
+ // entrained on the stream (enqueued to this point in program
+ // execution) to complete.
+ //
+ // Returns true if the stream is ok().
bool BlockHostUntilDone();
// Warning! This method interacts with internal threads in
@@ -1195,9 +1256,9 @@ class Stream {
internal::TemporaryMemoryManager *temporary_memory_manager();
private:
- friend class host::HostBlas; // for parent_.
- friend class host::HostFft; // for parent_.
- friend class host::HostRng; // for parent_.
+ friend class host::HostBlas; // for parent_.
+ friend class host::HostFft; // for parent_.
+ friend class host::HostRng; // for parent_.
template <typename... Args>
friend struct ThenBlasImpl; // for implementing ThenBlasXXX.
friend class ocl::CLBlas; // for parent_.
@@ -1219,13 +1280,13 @@ class Stream {
void SetError() { CheckError(false /* = operation_retcode */); }
+ // The StreamExecutor that supports the operation of this stream.
+ StreamExecutor *parent_;
+
// The platform-dependent implementation that the StreamExecutor interface
// delegates to.
std::unique_ptr<internal::StreamInterface> implementation_;
- // The StreamExecutor that supports the operation of this stream.
- StreamExecutor *parent_;
-
// mutex that guards the allocation / error state flags.
// Mutable so that it can be obtained via const reader lock.
mutable mutex mu_;
@@ -1267,6 +1328,24 @@ inline internal::TemporaryMemoryManager *Stream::temporary_memory_manager() {
return &temporary_memory_manager_;
}
+template <>
+struct Quantization<uint8> {
+ static constexpr dnn::QuantizedActivationMode kModeId =
+ dnn::QuantizedActivationMode::k8Bit;
+};
+
+template <>
+struct Quantization<uint16> {
+ static constexpr dnn::QuantizedActivationMode kModeId =
+ dnn::QuantizedActivationMode::k16Bit;
+};
+
+template <>
+struct Quantization<int32> {
+ static constexpr dnn::QuantizedActivationMode kModeId =
+ dnn::QuantizedActivationMode::k32Bit;
+};
+
} // namespace gputools
} // namespace perftools
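The Quantization&lt;ElementType&gt; trait above is what lets the templated wrappers forward to the single mode-taking entry point. A hypothetical caller might look like the following; names other than the stream_executor API are made up, and the uint8 alias is assumed to be the one exported by platform/port.h.

    // Hypothetical host-side consumer of quantized activations. The
    // MutableArraySlice<uint8> overload resolves Quantization<uint8>::kModeId to
    // dnn::QuantizedActivationMode::k8Bit and forwards to the void*/size variant.
    #include <vector>

    #include "tensorflow/stream_executor/stream.h"

    namespace gpu = perftools::gputools;

    void CopyQuantizedActivations(gpu::Stream* stream,
                                  const gpu::DeviceMemory<float>& activations,
                                  std::vector<gpu::uint8>* host_buffer) {
      stream->ThenMemcpyD2HQuantized(
          activations,
          gpu::port::MutableArraySlice<gpu::uint8>(host_buffer->data(),
                                                   host_buffer->size()));
      stream->BlockHostUntilDone();  // Data is valid only after synchronization.
    }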
diff --git a/tensorflow/stream_executor/stream_executor.h b/tensorflow/stream_executor/stream_executor.h
index b668f2ad0d..0eb113c056 100644
--- a/tensorflow/stream_executor/stream_executor.h
+++ b/tensorflow/stream_executor/stream_executor.h
@@ -57,9 +57,10 @@ limitations under the License.
#include "tensorflow/stream_executor/kernel.h" // IWYU pragma: export
#include "tensorflow/stream_executor/kernel_spec.h" // IWYU pragma: export
#include "tensorflow/stream_executor/launch_dim.h" // IWYU pragma: export
+#include "tensorflow/stream_executor/multi_platform_manager.h" // IWYU pragma: export
#include "tensorflow/stream_executor/platform.h" // IWYU pragma: export
#include "tensorflow/stream_executor/stream.h" // IWYU pragma: export
#include "tensorflow/stream_executor/stream_executor_pimpl.h" // IWYU pragma: export
-#include "tensorflow/stream_executor/timer.h" // IWYU pragma: export
+#include "tensorflow/stream_executor/timer.h" // IWYU pragma: export
#endif // TENSORFLOW_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
diff --git a/tensorflow/stream_executor/stream_executor_internal.cc b/tensorflow/stream_executor/stream_executor_internal.cc
index b3bb818f5f..ddc9ae0441 100644
--- a/tensorflow/stream_executor/stream_executor_internal.cc
+++ b/tensorflow/stream_executor/stream_executor_internal.cc
@@ -28,22 +28,6 @@ StreamExecutorFactory* MakeCUDAExecutorImplementation() {
static StreamExecutorFactory instance;
return &instance;
}
-EventFactory* MakeCUDAEventImplementation() {
- static EventFactory instance;
- return &instance;
-}
-StreamFactory* MakeCUDAStreamImplementation() {
- static StreamFactory instance;
- return &instance;
-}
-TimerFactory* MakeCUDATimerImplementation() {
- static TimerFactory instance;
- return &instance;
-}
-KernelFactory* MakeCUDAKernelImplementation() {
- static KernelFactory instance;
- return &instance;
-}
// -- OpenCL
@@ -51,28 +35,10 @@ StreamExecutorFactory* MakeOpenCLExecutorImplementation() {
static StreamExecutorFactory instance;
return &instance;
}
-StreamExecutorFactory* MakeOpenCLAlteraExecutorImplementation() {
- static StreamExecutorFactory instance;
- return &instance;
-}
-StreamFactory* MakeOpenCLStreamImplementation() {
- static StreamFactory instance;
- return &instance;
-}
-TimerFactory* MakeOpenCLTimerImplementation() {
- static TimerFactory instance;
- return &instance;
-}
-KernelFactory* MakeOpenCLKernelImplementation() {
- static KernelFactory instance;
- return &instance;
-}
// -- Host
StreamExecutorFactory MakeHostExecutorImplementation;
-StreamFactory MakeHostStreamImplementation;
-TimerFactory MakeHostTimerImplementation;
} // namespace internal
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index 955af9127f..dff756c8fc 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -71,6 +71,94 @@ namespace perftools {
namespace gputools {
namespace internal {
+// Platform-dependent interface class for the generic Events interface, in
+// the PIMPL style.
+class EventInterface {
+ public:
+ EventInterface() {}
+ virtual ~EventInterface() {}
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(EventInterface);
+};
+
+// Pointer-to-implementation object type (i.e. the KernelBase class delegates to
+// this interface) with virtual destruction. This class exists for the
+// platform-dependent code to hang any kernel data/resource info/functionality
+// off of.
+class KernelInterface {
+ public:
+ // Default constructor for the abstract interface.
+ KernelInterface() {}
+
+ // Default destructor for the abstract interface.
+ virtual ~KernelInterface() {}
+
+ // Returns the number of formal parameters that this kernel accepts.
+ virtual unsigned Arity() const = 0;
+
+ // Sets the preferred cache configuration.
+ virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0;
+
+ // Gets the preferred cache configuration.
+ virtual KernelCacheConfig GetPreferredCacheConfig() const = 0;
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface);
+};
+
+// Pointer-to-implementation object type (i.e. the Stream class delegates to
+// this interface) with virtual destruction. This class exists for the
+// platform-dependent code to hang any kernel data/resource info/functionality
+// off of.
+class StreamInterface {
+ public:
+ // Default constructor for the abstract interface.
+ StreamInterface() {}
+
+ // Default destructor for the abstract interface.
+ virtual ~StreamInterface() {}
+
+ // Returns the CUDA stream associated with this platform's stream
+ // implementation.
+ //
+ // WARNING: checks that the underlying platform is, in fact, CUDA, causing a
+ // fatal error if it is not. This hack is made available solely for use from
+ // distbelief code, which temporarily has strong ties to CUDA as a platform.
+ virtual void *CudaStreamHack() { return nullptr; }
+
+ // See the above comment on CudaStreamHack -- this further breaks abstraction
+ // for Eigen within distbelief, which has strong ties to CUDA as a platform,
+ // and a historical attachment to a programming model which takes a
+ // stream-slot rather than a stream-value.
+ virtual void **CudaStreamMemberHack() { return nullptr; }
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
+};
+
+// Pointer-to-implementation object type (i.e. the Timer class delegates to
+// this interface) with virtual destruction. This class exists for the
+// platform-dependent code to hang any timer data/resource info/functionality
+// off of.
+class TimerInterface {
+ public:
+ // Default constructor for the abstract interface.
+ TimerInterface() {}
+
+ // Default destructor for the abstract interface.
+ virtual ~TimerInterface() {}
+
+ // Returns the number of microseconds elapsed in a completed timer.
+ virtual uint64 Microseconds() const = 0;
+
+ // Returns the number of nanoseconds elapsed in a completed timer.
+ virtual uint64 Nanoseconds() const = 0;
+
+ private:
+ SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface);
+};
+
// Interface for the different StreamExecutor platforms (i.e. CUDA, OpenCL).
//
// Various platforms will provide an implementation that satisfy this interface.
@@ -89,6 +177,7 @@ class StreamExecutorInterface {
// See the StreamExecutor interface for comments on the same-named methods.
virtual port::Status Init(int device_ordinal,
DeviceOptions device_options) = 0;
+
virtual bool GetKernel(const MultiKernelLoaderSpec &spec,
KernelBase *kernel) {
return false;
@@ -233,9 +322,13 @@ class StreamExecutorInterface {
// initialization fails.
virtual dnn::DnnSupport *CreateDnn() { return nullptr; }
- // Please read the warning below. This method is only temporary. See
- // http://b/15759750
- //
+ // Each call creates a new instance of the platform-specific implementation of
+ // the corresponding interface type.
+ virtual std::unique_ptr<EventInterface> CreateEventImplementation() = 0;
+ virtual std::unique_ptr<KernelInterface> CreateKernelImplementation() = 0;
+ virtual std::unique_ptr<StreamInterface> GetStreamImplementation() = 0;
+ virtual std::unique_ptr<TimerInterface> GetTimerImplementation() = 0;
+
// Returns the CUDA context associated with this StreamExecutor platform
// implementation.
//
@@ -248,106 +341,6 @@ class StreamExecutorInterface {
SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutorInterface);
};
-// Pointer-to-implementation object type (i.e. the KernelBase class delegates to
-// this interface) with virtual destruction. This class exists for the
-// platform-dependent code to hang any kernel data/resource info/functionality
-// off of.
-class KernelInterface {
- public:
- // Default constructor for the abstract interface.
- KernelInterface() {}
-
- // Default destructor for the abstract interface.
- virtual ~KernelInterface() {}
-
- // Returns the number of formal parameters that this kernel accepts.
- virtual unsigned Arity() const = 0;
-
- // Sets the preferred cache configuration.
- virtual void SetPreferredCacheConfig(KernelCacheConfig config) = 0;
-
- // Gets the preferred cache configuration.
- virtual KernelCacheConfig GetPreferredCacheConfig() const = 0;
-
- private:
- SE_DISALLOW_COPY_AND_ASSIGN(KernelInterface);
-};
-
-// Platform-dependent interface class for the generic Events interface, in
-// the PIMPL style.
-class EventInterface {
- public:
- EventInterface() {}
- virtual ~EventInterface() {}
-
- private:
- SE_DISALLOW_COPY_AND_ASSIGN(EventInterface);
-};
-
-// Pointer-to-implementation object type (i.e. the Stream class delegates to
-// this interface) with virtual destruction. This class exists for the
-// platform-dependent code to hang any kernel data/resource info/functionality
-// off of.
-class StreamInterface {
- public:
- // Default constructor for the abstract interface.
- StreamInterface() {}
-
- // Default destructor for the abstract interface.
- virtual ~StreamInterface() {}
-
- // Please read the warning below. This method is only temporary. See
- // http://b/15759750
- //
- // Returns the CUDA stream associated with this platform's stream
- // implementation.
- //
- // WARNING: checks that the underlying platform is, in fact, CUDA, causing a
- // fatal error if it is not. This hack is made available solely for use from
- // distbelief code, which temporarily has strong ties to CUDA as a platform.
- virtual void *CudaStreamHack() { return nullptr; }
-
- // Please read the warning above. This method is only temporary. See
- // http://b/15759750
- //
- // See the above comment on CudaStreamHack -- this further breaks abstraction
- // for Eigen within distbelief, which has strong ties to CUDA as a platform,
- // and a historical attachment to a programming model which takes a
- // stream-slot rather than a stream-value.
- virtual void **CudaStreamMemberHack() { return nullptr; }
-
- private:
- SE_DISALLOW_COPY_AND_ASSIGN(StreamInterface);
-};
-
-// Pointer-to-implementation object type (i.e. the Timer class delegates to
-// this interface) with virtual destruction. This class exists for the
-// platform-dependent code to hang any timer data/resource info/functionality
-// off of.
-class TimerInterface {
- public:
- // Default constructor for the abstract interface.
- TimerInterface() {}
-
- // Default destructor for the abstract interface.
- virtual ~TimerInterface() {}
-
- // Returns the number of microseconds elapsed in a completed timer.
- virtual uint64 Microseconds() const = 0;
-
- // Returns the number of nanoseconds elapsed in a completed timer.
- virtual uint64 Nanoseconds() const = 0;
-
- private:
- SE_DISALLOW_COPY_AND_ASSIGN(TimerInterface);
-};
-
-// Extern functions for constructing platform-specific instances that conform to
-// the StreamExecutor interface. (Defining constructor functions extern in this
-// way prevents CUDA/OpenCL headers from leaking into any shared header files.)
-//
-// TODO(leary) switch this all over to registries.
-
using StreamExecutorFactory =
std::function<StreamExecutorInterface *(const PluginConfig &)>;
using EventFactory = std::function<EventInterface *(StreamExecutor *)>;
@@ -355,21 +348,11 @@ using StreamFactory = std::function<StreamInterface *(StreamExecutor *)>;
using TimerFactory = std::function<TimerInterface *(StreamExecutor *)>;
using KernelFactory = std::function<KernelInterface*()>;
-EventFactory* MakeCUDAEventImplementation();
StreamExecutorFactory* MakeCUDAExecutorImplementation();
-StreamFactory* MakeCUDAStreamImplementation();
-TimerFactory* MakeCUDATimerImplementation();
-KernelFactory* MakeCUDAKernelImplementation();
StreamExecutorFactory* MakeOpenCLExecutorImplementation();
-StreamExecutorFactory* MakeOpenCLAlteraExecutorImplementation();
-StreamFactory* MakeOpenCLStreamImplementation();
-TimerFactory* MakeOpenCLTimerImplementation();
-KernelFactory* MakeOpenCLKernelImplementation();
extern StreamExecutorFactory MakeHostExecutorImplementation;
-extern StreamFactory MakeHostStreamImplementation;
-extern TimerFactory MakeHostTimerImplementation;
} // namespace internal
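Only the executor-level factories survive this hunk; the per-object Event/Stream/Timer/Kernel factory functions are gone because those objects are now vended by the executor's virtual Create*/Get*Implementation methods above. A sketch of how a platform translation unit might still populate the remaining hook -- the initializer name and the cuda::CUDAExecutor class/namespace are assumptions, not code quoted from the patch:

void InitializeCudaExecutorFactory() {  // hypothetical initializer
  *internal::MakeCUDAExecutorImplementation() =
      [](const PluginConfig &config) -> StreamExecutorInterface * {
        // Assumed platform class deriving from StreamExecutorInterface.
        return new cuda::CUDAExecutor(config);
      };
}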
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index e496deaf9d..acaa0efcb2 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -26,6 +26,8 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/notification.h"
+#include "tensorflow/stream_executor/lib/stacktrace.h"
+#include "tensorflow/stream_executor/lib/str_util.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
#include "tensorflow/stream_executor/platform/port.h"
@@ -40,6 +42,14 @@ namespace perftools {
namespace gputools {
namespace {
+string StackTraceIfVLOG10() {
+ if (VLOG_IS_ON(10)) {
+ return port::StrCat(" ", port::CurrentStackTrace(), "\n");
+ } else {
+ return "";
+ }
+}
+
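The helper returns a string rather than logging directly so it can be streamed onto an existing VLOG(1) statement; the (comparatively expensive) stack trace is only materialized when verbosity level 10 is enabled. A usage sketch mirroring the call sites changed below:

VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ")"
        << StackTraceIfVLOG10();  // trace appended only at VLOG level 10+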
// Maximum stack depth to report when generating backtrace on mem allocation
// (for GPU memory leak checker)
static const int kMaxStackDepth = 256;
@@ -66,9 +76,6 @@ internal::StreamExecutorInterface *StreamExecutorImplementationFromPlatformKind(
case PlatformKind::kOpenCL:
factory = *internal::MakeOpenCLExecutorImplementation();
break;
- case PlatformKind::kOpenCLAltera:
- factory = *internal::MakeOpenCLAlteraExecutorImplementation();
- break;
case PlatformKind::kHost:
factory = internal::MakeHostExecutorImplementation;
break;
@@ -148,7 +155,8 @@ MakeScopedTracer(StreamExecutor *stream_exec, BeginCallT begin_call,
StreamExecutor::StreamExecutor(PlatformKind platform_kind,
const PluginConfig &plugin_config)
- : implementation_(StreamExecutorImplementationFromPlatformKind(
+ : platform_(nullptr),
+ implementation_(StreamExecutorImplementationFromPlatformKind(
platform_kind, plugin_config)),
platform_kind_(platform_kind),
device_ordinal_(-1),
@@ -160,16 +168,21 @@ StreamExecutor::StreamExecutor(PlatformKind platform_kind,
}
StreamExecutor::StreamExecutor(
- PlatformKind platform_kind,
- internal::StreamExecutorInterface *implementation)
- : implementation_(implementation),
- platform_kind_(platform_kind),
+ const Platform *platform, internal::StreamExecutorInterface *implementation)
+ : platform_(platform),
+ implementation_(implementation),
device_ordinal_(-1),
background_threads_(new port::ThreadPool(
port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
live_stream_count_(0),
tracing_enabled_(false) {
- CheckPlatformKindIsValid(platform_kind);
+ if (port::Lowercase(platform_->Name()) == "cuda") {
+ platform_kind_ = PlatformKind::kCuda;
+ } else if (port::Lowercase(platform_->Name()) == "opencl") {
+ platform_kind_ = PlatformKind::kOpenCL;
+ } else if (port::Lowercase(platform_->Name()) == "host") {
+ platform_kind_ = PlatformKind::kHost;
+ }
}
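Note that this constructor assigns nothing for platform names other than "cuda", "opencl", or "host", unlike the old CheckPlatformKindIsValid path it replaces. The same mapping expressed as a standalone helper, as a sketch only -- the kInvalid fallback is an assumed enumerator, not behavior taken from the patch:

static PlatformKind PlatformKindFromName(const string &name) {
  const string lower = port::Lowercase(name);
  if (lower == "cuda") return PlatformKind::kCuda;
  if (lower == "opencl") return PlatformKind::kOpenCL;
  if (lower == "host") return PlatformKind::kHost;
  return PlatformKind::kInvalid;  // assumption: explicit fallback value
}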
StreamExecutor::~StreamExecutor() {
@@ -208,7 +221,7 @@ bool StreamExecutor::GetKernel(const MultiKernelLoaderSpec &spec,
void StreamExecutor::Deallocate(DeviceMemoryBase *mem) {
VLOG(1) << "Called StreamExecutor::Deallocate(mem=" << mem->opaque()
- << ") mem->size()=" << mem->size();
+ << ") mem->size()=" << mem->size() << StackTraceIfVLOG10();
if (mem->opaque() != nullptr) {
EraseAllocRecord(mem->opaque());
@@ -333,8 +346,8 @@ bool StreamExecutor::BlockHostUntilDone(Stream *stream) {
void *StreamExecutor::Allocate(uint64 size) {
void *buf = implementation_->Allocate(size);
- VLOG(1) << "Called StreamExecutor::Allocate(size=" << size
- << ") returns " << buf;
+ VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
+ << buf << StackTraceIfVLOG10();
CreateAllocRecord(buf, size);
return buf;
@@ -348,20 +361,20 @@ bool StreamExecutor::GetSymbol(const string &symbol_name, void **mem,
void *StreamExecutor::HostMemoryAllocate(uint64 size) {
void *buffer = implementation_->HostMemoryAllocate(size);
VLOG(1) << "Called StreamExecutor::HostMemoryAllocate(size=" << size
- << ") returns " << buffer;
+ << ") returns " << buffer << StackTraceIfVLOG10();
return buffer;
}
void StreamExecutor::HostMemoryDeallocate(void *location) {
- VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location="
- << location << ")";
+ VLOG(1) << "Called StreamExecutor::HostMemoryDeallocate(location=" << location
+ << ")" << StackTraceIfVLOG10();
return implementation_->HostMemoryDeallocate(location);
}
bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
VLOG(1) << "Called StreamExecutor::HostMemoryRegister(location=" << location
- << ", size=" << size << ")";
+ << ", size=" << size << ")" << StackTraceIfVLOG10();
if (location == nullptr || size == 0) {
LOG(WARNING) << "attempting to register null or zero-sized memory: "
<< location << "; size " << size;
@@ -371,12 +384,13 @@ bool StreamExecutor::HostMemoryRegister(void *location, uint64 size) {
bool StreamExecutor::HostMemoryUnregister(void *location) {
VLOG(1) << "Called StreamExecutor::HostMemoryUnregister(location=" << location
- << ")";
+ << ")" << StackTraceIfVLOG10();
return implementation_->HostMemoryUnregister(location);
}
bool StreamExecutor::SynchronizeAllActivity() {
- VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()";
+ VLOG(1) << "Called StreamExecutor::SynchronizeAllActivity()"
+ << StackTraceIfVLOG10();
bool ok = implementation_->SynchronizeAllActivity();
// This should all be quick and infallible work, so we can perform the
@@ -388,16 +402,17 @@ bool StreamExecutor::SynchronizeAllActivity() {
bool StreamExecutor::SynchronousMemZero(DeviceMemoryBase *location,
uint64 size) {
- VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location="
- << location << ", size=" << size << ")";
+ VLOG(1) << "Called StreamExecutor::SynchronousMemZero(location=" << location
+ << ", size=" << size << ")" << StackTraceIfVLOG10();
return implementation_->SynchronousMemZero(location, size);
}
bool StreamExecutor::SynchronousMemSet(DeviceMemoryBase *location, int value,
uint64 size) {
- VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location="
- << location << ", value=" << value << ", size=" << size << ")";
+ VLOG(1) << "Called StreamExecutor::SynchronousMemSet(location=" << location
+ << ", value=" << value << ", size=" << size << ")"
+ << StackTraceIfVLOG10();
return implementation_->SynchronousMemSet(location, value, size);
}
@@ -406,7 +421,7 @@ bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
const void *host_src, uint64 size) {
VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(gpu_dst="
<< gpu_dst->opaque() << ", host_src=" << host_src << ", size=" << size
- << ") H2D";
+ << ") H2D" << StackTraceIfVLOG10();
// Tracing overloaded methods is very difficult due to issues with type
// inference on template args. Since use of these overloaded methods is
@@ -417,9 +432,9 @@ bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
bool StreamExecutor::SynchronousMemcpy(void *host_dst,
const DeviceMemoryBase &gpu_src,
uint64 size) {
- VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst="
- << host_dst << ", gpu_src=" << gpu_src.opaque() << ", size=" << size
- << ") D2H";
+ VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(host_dst=" << host_dst
+ << ", gpu_src=" << gpu_src.opaque() << ", size=" << size << ") D2H"
+ << StackTraceIfVLOG10();
return implementation_->SynchronousMemcpy(host_dst, gpu_src, size);
}
@@ -428,8 +443,8 @@ bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
const DeviceMemoryBase &gpu_src,
uint64 size) {
VLOG(1) << "Called StreamExecutor::SynchronousMemcpy(gpu_dst="
- << gpu_dst->opaque() << ", gpu_src=" << gpu_src.opaque() << ", size=" << size
- << ") D2D";
+ << gpu_dst->opaque() << ", gpu_src=" << gpu_src.opaque()
+ << ", size=" << size << ") D2D" << StackTraceIfVLOG10();
return implementation_->SynchronousMemcpyDeviceToDevice(gpu_dst, gpu_src,
size);
@@ -438,7 +453,8 @@ bool StreamExecutor::SynchronousMemcpy(DeviceMemoryBase *gpu_dst,
port::Status StreamExecutor::SynchronousMemcpyD2H(
const DeviceMemoryBase &gpu_src, int64 size, void *host_dst) {
VLOG(1) << "Called StreamExecutor::SynchronousMemcpyD2H(gpu_src="
- << gpu_src.opaque() << ", size=" << size << ", host_dst=" << host_dst << ")";
+ << gpu_src.opaque() << ", size=" << size << ", host_dst=" << host_dst
+ << ")" << StackTraceIfVLOG10();
port::Status result{port::Status::OK()};
SCOPED_TRACE(TraceListener::SynchronousMemcpyD2H,
@@ -459,8 +475,9 @@ port::Status StreamExecutor::SynchronousMemcpyD2H(
port::Status StreamExecutor::SynchronousMemcpyH2D(const void *host_src,
int64 size,
DeviceMemoryBase *gpu_dst) {
- VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src="
- << host_src << ", size=" << size << ", gpu_dst" << gpu_dst->opaque() << ")";
+ VLOG(1) << "Called StreamExecutor::SynchronousMemcpyH2D(host_src=" << host_src
+ << ", size=" << size << ", gpu_dst" << gpu_dst->opaque() << ")"
+ << StackTraceIfVLOG10();
port::Status result{port::Status::OK()};
SCOPED_TRACE(TraceListener::SynchronousMemcpyH2D,
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 5a59a5a26d..f624e0fcdb 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -71,9 +71,7 @@ class StreamExecutor {
public:
explicit StreamExecutor(PlatformKind kind,
const PluginConfig &plugin_config = PluginConfig());
-
- // Primarily used for testing.
- StreamExecutor(PlatformKind kind,
+ StreamExecutor(const Platform *platform,
internal::StreamExecutorInterface *implementation);
~StreamExecutor();
@@ -81,9 +79,13 @@ class StreamExecutor {
port::Status Init();
port::Status Init(int device_ordinal, DeviceOptions device_options);
+ // DEPRECATED: Do not use; use platform() instead.
// Returns the platform that this StreamExecutor is acting upon.
PlatformKind platform_kind() const { return platform_kind_; }
+ // Returns a reference to the platform that created this executor.
+ const Platform *platform() const { return platform_; }
+
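A usage sketch for the new accessor: call sites can recover the owning Platform directly instead of switching on the deprecated platform_kind() enum ("executor" below is a hypothetical StreamExecutor pointer):

const Platform *p = executor->platform();
VLOG(1) << "Executor created by platform: " << p->Name();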
// Retrieves (loads) a kernel for the platform this StreamExecutor is acting
// upon, if one exists.
//
@@ -538,15 +540,18 @@ class StreamExecutor {
// can acquire the lock on their first (mutating) call as well.
mutable mutex mu_;
- // A mapping of pointer (to GPU memory) to string representation of the stack
- // (of the allocating thread) at the time at which the pointer was allocated.
- std::map<void *, AllocRecord> mem_allocs_ GUARDED_BY(mu_);
+ // Reference to the platform that created this executor.
+ const Platform *platform_;
// Pointer to the platform-specific-interface implementation. This is
// delegated to by the interface routines in pointer-to-implementation
// fashion.
std::unique_ptr<internal::StreamExecutorInterface> implementation_;
+ // A mapping of pointer (to GPU memory) to string representation of the stack
+ // (of the allocating thread) at the time at which the pointer was allocated.
+ std::map<void *, AllocRecord> mem_allocs_ GUARDED_BY(mu_);
+
// Memoized BLAS support object -- we only want to create this once when asked
// for a BLAS interface.
std::unique_ptr<blas::BlasSupport> blas_ GUARDED_BY(mu_);
diff --git a/tensorflow/stream_executor/timer.cc b/tensorflow/stream_executor/timer.cc
index 6926028dbb..62fe1c9f64 100644
--- a/tensorflow/stream_executor/timer.cc
+++ b/tensorflow/stream_executor/timer.cc
@@ -20,31 +20,13 @@ limitations under the License.
#include "tensorflow/stream_executor/platform.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/stream_executor.h"
-#include "tensorflow/stream_executor/stream_executor_internal.h"
namespace perftools {
namespace gputools {
-static internal::TimerInterface *CreateTimerImplementation(
- StreamExecutor *parent) {
- PlatformKind platform_kind = parent->platform_kind();
- if (platform_kind == PlatformKind::kCuda) {
- return (*internal::MakeCUDATimerImplementation())(parent);
- } else if (platform_kind == PlatformKind::kOpenCL ||
- platform_kind == PlatformKind::kOpenCLAltera) {
- return (*internal::MakeOpenCLTimerImplementation())(parent);
- } else if (platform_kind == PlatformKind::kHost) {
- return internal::MakeHostTimerImplementation(parent);
- } else if (platform_kind == PlatformKind::kMock) {
- return nullptr;
- } else {
- LOG(FATAL) << "cannot create timer implementation for platform kind: "
- << PlatformKindString(platform_kind);
- }
-}
-
Timer::Timer(StreamExecutor *parent)
- : implementation_(CreateTimerImplementation(parent)), parent_(parent) {}
+ : parent_(parent),
+ implementation_(parent_->implementation()->GetTimerImplementation()) {}
Timer::~Timer() { parent_->DeallocateTimer(this); }
diff --git a/tensorflow/stream_executor/timer.h b/tensorflow/stream_executor/timer.h
index 4da63c86b3..c39048d70b 100644
--- a/tensorflow/stream_executor/timer.h
+++ b/tensorflow/stream_executor/timer.h
@@ -58,14 +58,14 @@ class Timer {
internal::TimerInterface *implementation() { return implementation_.get(); }
private:
- // Platform-dependent implementation of the timer internals for the underlying
- // platform. This class just delegates to this opaque instance.
- std::unique_ptr<internal::TimerInterface> implementation_;
-
// The StreamExecutor that manages the platform-specific internals for this
// timer.
StreamExecutor *parent_;
+ // Platform-dependent implementation of the timer internals for the underlying
+ // platform. This class just delegates to this opaque instance.
+ std::unique_ptr<internal::TimerInterface> implementation_;
+
SE_DISALLOW_COPY_AND_ASSIGN(Timer);
};
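The member reorder in timer.h is load-bearing rather than cosmetic: the new Timer constructor in timer.cc initializes implementation_ from parent_->implementation(), and C++ initializes data members in declaration order, so parent_ must now be declared first (the same reasoning applies to platform_ preceding implementation_ in stream_executor_pimpl.h). A minimal illustration with hypothetical types:

#include <memory>

struct Obj {};
struct Dep {
  std::unique_ptr<Obj> Make() { return std::unique_ptr<Obj>(new Obj); }
};

// Members are initialized in declaration order, not init-list order, so the
// pointer consumed by the second member's initializer must be declared first
// -- the same constraint the Timer members above now satisfy.
class Example {
 public:
  explicit Example(Dep *dep) : dep_(dep), derived_(dep_->Make()) {}

 private:
  Dep *dep_;                      // declared before derived_
  std::unique_ptr<Obj> derived_;  // initialized from dep_
};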