Diffstat (limited to 'tensorflow/stream_executor')
-rw-r--r--  tensorflow/stream_executor/blas.h                      |   1
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.cc            | 151
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_dnn.h             |  16
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_gpu_executor.cc   |  18
-rw-r--r--  tensorflow/stream_executor/device_description.h        |   6
-rw-r--r--  tensorflow/stream_executor/dnn.h                       |   4
-rw-r--r--  tensorflow/stream_executor/lib/array_slice.h           |   8
-rw-r--r--  tensorflow/stream_executor/lib/inlined_vector.h        |   4
-rw-r--r--  tensorflow/stream_executor/lib/strcat.h                |   6
-rw-r--r--  tensorflow/stream_executor/lib/stringpiece.h           |   5
-rw-r--r--  tensorflow/stream_executor/plugin_registry.h           |   2
-rw-r--r--  tensorflow/stream_executor/stream.cc                   |  38
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.cc    |  24
-rw-r--r--  tensorflow/stream_executor/stream_executor_pimpl.h     |  18
14 files changed, 234 insertions, 67 deletions
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 7f851e3646..f25ed700d6 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -41,6 +41,7 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
#include <complex>
+#include <vector>
#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 55408ab9ab..ca90c383f9 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/env.h"
#include "tensorflow/stream_executor/lib/error.h"
#include "tensorflow/stream_executor/lib/initialize.h"
+#include "tensorflow/stream_executor/lib/mathutil.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/lib/threadpool.h"
@@ -132,23 +133,42 @@ string ToString(cudnnStatus_t status) {
}
template <typename T>
-cudnnDataType_t GetCudnnDataType();
+cudnnDataType_t GetCudnnDataType(
+ dnn::DataLayout = dnn::DataLayout::kBatchDepthYX);
template <>
-cudnnDataType_t GetCudnnDataType<double>() {
+cudnnDataType_t GetCudnnDataType<double>(dnn::DataLayout) {
return CUDNN_DATA_DOUBLE;
}
template <>
-cudnnDataType_t GetCudnnDataType<float>() {
+cudnnDataType_t GetCudnnDataType<float>(dnn::DataLayout) {
return CUDNN_DATA_FLOAT;
}
template <>
-cudnnDataType_t GetCudnnDataType<Eigen::half>() {
+cudnnDataType_t GetCudnnDataType<Eigen::half>(dnn::DataLayout) {
return CUDNN_DATA_HALF;
}
+template <>
+cudnnDataType_t GetCudnnDataType<int8>(dnn::DataLayout layout) {
+ switch (layout) {
+ case dnn::DataLayout::kYXDepthBatch:
+ case dnn::DataLayout::kYXBatchDepth:
+ case dnn::DataLayout::kBatchYXDepth:
+ case dnn::DataLayout::kBatchDepthYX:
+ return CUDNN_DATA_INT8;
+ case dnn::DataLayout::kBatchDepthYX4:
+ return CUDNN_DATA_INT8x4;
+ }
+}
+
+template <>
+cudnnDataType_t GetCudnnDataType<int32>(dnn::DataLayout) {
+ return CUDNN_DATA_INT32;
+}
+
// RAII wrapper for all calls to cuDNN with a cuDNN handle argument.
//
// See CudnnAccess::GetHandle() for details.
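The specializations above make the cuDNN data type selection layout-aware: int8 maps to CUDNN_DATA_INT8x4 only for the vectorized kBatchDepthYX4 layout and to CUDNN_DATA_INT8 otherwise, while the floating-point overloads ignore the layout argument (which defaults to kBatchDepthYX). A small usage sketch, illustrative only and not part of this change:

// Illustrative only: which enum each specialization above selects.
cudnnDataType_t f   = GetCudnnDataType<float>();                                // CUDNN_DATA_FLOAT (layout ignored)
cudnnDataType_t i8  = GetCudnnDataType<int8>(dnn::DataLayout::kBatchDepthYX);   // CUDNN_DATA_INT8
cudnnDataType_t i8x = GetCudnnDataType<int8>(dnn::DataLayout::kBatchDepthYX4);  // CUDNN_DATA_INT8x4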
@@ -2387,6 +2407,33 @@ cudnnDataType_t GetRnnComputeType(dnn::DataType data_type) {
}
}
+// Determines whether we can safely perform a Winograd non-fused convolution for
+// the given input and output shapes. This works around b/68264959, an integer
+// overflow in cuDNNv5 and cuDNNv6.
+#if CUDNN_VERSION >= 7000
+bool ShouldIncludeWinogradNonfusedAlgo(const dnn::BatchDescriptor&,
+ const dnn::BatchDescriptor&) {
+ return true;
+}
+#else
+bool ShouldIncludeWinogradNonfusedAlgo(
+ const dnn::BatchDescriptor& input_desc,
+ const dnn::BatchDescriptor& output_desc) {
+ int64 batch = input_desc.count();
+ int64 in_depths = input_desc.feature_map_count();
+ int64 in_rows = input_desc.height();
+ int64 in_cols = input_desc.ndims() == 1 ? 1 : input_desc.width();
+ int64 out_depths = output_desc.feature_map_count();
+
+ int64 total_size = port::MathUtil::CeilOfRatio(batch, int64{16}) *
+ std::max(in_depths, out_depths) * in_cols * in_rows *
+ sizeof(float);
+
+ const int64 threshold = 1L << 31;
+ return total_size < threshold;
+}
+#endif
+
} // namespace
template <class T>
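The check above excludes the non-fused Winograd algorithm whenever an estimate of the intermediate buffer size reaches 2^31 bytes. Below is a self-contained sketch of the same arithmetic with a worked example; it is standalone illustration code written for this note, not code from the change:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// ceil(a / b) for positive integers, mirroring port::MathUtil::CeilOfRatio.
int64_t CeilOfRatio(int64_t a, int64_t b) { return (a + b - 1) / b; }

bool WinogradNonfusedIsSafe(int64_t batch, int64_t in_depths, int64_t in_rows,
                            int64_t in_cols, int64_t out_depths) {
  const int64_t total_size = CeilOfRatio(batch, int64_t{16}) *
                             std::max(in_depths, out_depths) * in_cols *
                             in_rows * static_cast<int64_t>(sizeof(float));
  return total_size < (int64_t{1} << 31);
}

int main() {
  // 16 x 512 x 1024 x 1024 activations: 1 * 512 * 1024 * 1024 * 4 = 2^31 bytes,
  // so the algorithm is excluded (prints 0).
  std::printf("%d\n", WinogradNonfusedIsSafe(16, 512, 1024, 1024, 512));
  // 32 x 256 x 512 x 512 activations: 2 * 256 * 512 * 512 * 4 = 2^29 bytes,
  // well under the threshold (prints 1).
  std::printf("%d\n", WinogradNonfusedIsSafe(32, 256, 512, 512, 256));
}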
@@ -2465,6 +2512,13 @@ port::Status CudnnSupport::DoConvolveImpl(
return port::Status::OK();
}());
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
RETURN_IF_CUDNN_ERROR(cudnnConvolutionForward(
cudnn.handle(),
/*alpha=*/alpha, /*srcDesc=*/input_nd.handle(),
@@ -2486,19 +2540,19 @@ port::Status CudnnSupport::DoConvolveImpl(
return port::Status::OK();
}
-template <typename Type, typename BiasType, typename ScaleType,
- int cudnn_data_type, int cudnn_compute_type>
+template <typename AccumulatorType, typename ElementType, typename BiasType,
+ typename ScaleType>
port::Status CudnnSupport::DoFusedConvolveImpl(
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
- const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
- const dnn::FilterDescriptor& filter_descriptor,
- const DeviceMemory<Type>& filter_data,
+ const DeviceMemory<ElementType>& conv_input_data,
+ ScaleType conv_input_scale, const dnn::FilterDescriptor& filter_descriptor,
+ const DeviceMemory<ElementType>& filter_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
- const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale,
- const dnn::BatchDescriptor& bias_descriptor,
+ const DeviceMemory<ElementType>& side_input_data,
+ ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
const dnn::BatchDescriptor& output_descriptor,
- DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
+ DeviceMemory<ElementType>* output_data, ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
if (activation_mode != dnn::ActivationMode::kRelu &&
@@ -2509,14 +2563,17 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
}
CudnnTensorDescriptor conv_input_nd(
- conv_input_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
+ conv_input_descriptor,
+ GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
CudnnTensorDescriptor output_nd(
- output_descriptor, static_cast<cudnnDataType_t>(cudnn_data_type));
- CudnnFilterDescriptor filter(filter_descriptor,
- static_cast<cudnnDataType_t>(cudnn_data_type));
- CudnnTensorDescriptor bias_nd(bias_descriptor, CUDNN_DATA_FLOAT);
- CudnnConvolutionDescriptor conv(
- convolution_descriptor, static_cast<cudnnDataType_t>(cudnn_compute_type));
+ output_descriptor,
+ GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+ CudnnFilterDescriptor filter(
+ filter_descriptor,
+ GetCudnnDataType<ElementType>(conv_input_descriptor.layout()));
+ CudnnTensorDescriptor bias_nd(bias_descriptor, GetCudnnDataType<BiasType>());
+ CudnnConvolutionDescriptor conv(convolution_descriptor,
+ GetCudnnDataType<AccumulatorType>());
auto cudnn = cudnn_->GetHandle(parent_, stream);
@@ -2566,6 +2623,14 @@ port::Status CudnnSupport::DoFusedConvolveImpl(
<< "\noutput_nd.handle() = " << output_nd.handle()
<< "\noutput_data->opaque() = " << output_data->opaque();
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(conv_input_descriptor,
+ output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+                          "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
RETURN_IF_CUDNN_ERROR(cudnnConvolutionBiasActivationForward(
cudnn.handle(),
/*alpha1=*/&conv_input_scale,
@@ -2933,8 +2998,7 @@ bool CudnnSupport::DoFusedConvolve(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return IsStatusOk(
- DoFusedConvolveImpl<double, double, double, CUDNN_DATA_DOUBLE,
- CUDNN_DATA_DOUBLE>(
+ DoFusedConvolveImpl<double>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -2957,8 +3021,7 @@ bool CudnnSupport::DoFusedConvolve(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return IsStatusOk(
- DoFusedConvolveImpl<float, float, float, CUDNN_DATA_FLOAT,
- CUDNN_DATA_FLOAT>(
+ DoFusedConvolveImpl<float>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -2982,8 +3045,7 @@ bool CudnnSupport::DoFusedConvolve(
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result) {
return IsStatusOk(
- DoFusedConvolveImpl<Eigen::half, Eigen::half, float, CUDNN_DATA_HALF,
- CUDNN_DATA_FLOAT>(
+ DoFusedConvolveImpl<float>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -3014,8 +3076,7 @@ bool CudnnSupport::DoFusedConvolve(
return false;
}
return IsStatusOk(
- DoFusedConvolveImpl<int8, float, float, CUDNN_DATA_INT8x4,
- CUDNN_DATA_INT32>(
+ DoFusedConvolveImpl<int32>(
stream, conv_input_descriptor, conv_input_data, conv_input_scale,
filter_descriptor, filter_data, convolution_descriptor,
side_input_data, side_input_scale, bias_descriptor, biases,
@@ -3096,6 +3157,13 @@ port::Status CudnnSupport::DoConvolveBackwardDataImpl(
}
}
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
// Cudnn 7.1.4 has a bug if the workspace of the following convolution is not
// zero-initialized, nvbugs/2254619.
if (CUDNN_VERSION >= 7000 &&
@@ -3275,6 +3343,33 @@ port::Status CudnnSupport::DoConvolveBackwardFilterImpl(
"This configuration potentially produces incorrect results.");
}());
+ if (algo_desc.algo_id() == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED &&
+ !ShouldIncludeWinogradNonfusedAlgo(input_descriptor, output_descriptor)) {
+ return port::Status(port::error::FAILED_PRECONDITION,
+ "This configuration has potential integer overflow in "
+ "cuDNNv5 and cuDNNv6. See b/68264959.");
+ }
+
+ // Zero out the result buffer for strided conv backward filter for NHWC
+  // layouts. cuDNN 7.1.4 and 7.2 have a non-deterministic bug if the buffer is
+  // not zeroed.
+  //
+  // The incorrect result caused by this bug is very flaky; a run may need to be
+  // repeated up to 20 times to produce a mismatch.
+  //
+  // TODO(timshen): add an nvbugs link.
+ if (CUDNN_VERSION >= 7100 &&
+ algorithm_config.algorithm().algo_id() ==
+ CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1 &&
+ cudnn_type == CUDNN_DATA_HALF &&
+ input_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
+ filter_descriptor.layout() == dnn::FilterLayout::kOutputYXInput &&
+ output_descriptor.layout() == dnn::DataLayout::kBatchYXDepth &&
+ (convolution_descriptor.vertical_filter_stride() > 1 ||
+ convolution_descriptor.horizontal_filter_stride() > 1)) {
+ stream->ThenMemZero(backward_filter_data, backward_filter_data->size());
+ }
+
RETURN_IF_CUDNN_ERROR(cudnnConvolutionBackwardFilter(
cudnn.handle(),
/*alpha=*/alpha,
@@ -3894,7 +3989,7 @@ bool CudnnSupport::DoDepthConcatenate(
for (size_t i = 0; i < input_data.size(); ++i) {
const auto& dimensions = input_dimensions[i];
tmp.resize(dimensions.ElementCount());
- stream->ThenMemcpyD2H<float>(*input_data[i], &tmp);
+ stream->ThenMemcpyD2H<float>(*input_data[i], absl::MakeSpan(tmp));
port::Status block_status = stream->BlockHostUntilDone();
if (!block_status.ok()) {
LOG(ERROR) << "BlockHostUntilDone failed: " << block_status;
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 9d88f971bb..74f6f935b8 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -674,19 +674,21 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result);
- template <typename Type, typename BiasType, typename ScaleType,
- int cudnn_data_type, int cudnn_compute_type>
+ template <typename AccumulatorType, typename ElementType, typename BiasType,
+ typename ScaleType>
port::Status DoFusedConvolveImpl(
Stream* stream, const dnn::BatchDescriptor& conv_input_descriptor,
- const DeviceMemory<Type>& conv_input_data, ScaleType conv_input_scale,
+ const DeviceMemory<ElementType>& conv_input_data,
+ ScaleType conv_input_scale,
const dnn::FilterDescriptor& filter_descriptor,
- const DeviceMemory<Type>& filter_data,
+ const DeviceMemory<ElementType>& filter_data,
const dnn::ConvolutionDescriptor& convolution_descriptor,
- const DeviceMemory<Type>& side_input_data, ScaleType side_input_scale,
- const dnn::BatchDescriptor& bias_descriptor,
+ const DeviceMemory<ElementType>& side_input_data,
+ ScaleType side_input_scale, const dnn::BatchDescriptor& bias_descriptor,
const DeviceMemory<BiasType>& biases, dnn::ActivationMode activation_mode,
const dnn::BatchDescriptor& output_descriptor,
- DeviceMemory<Type>* output_data, ScratchAllocator* scratch_allocator,
+ DeviceMemory<ElementType>* output_data,
+ ScratchAllocator* scratch_allocator,
const dnn::AlgorithmConfig& algorithm_config,
dnn::ProfileResult* output_profile_result);
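With this signature change only the accumulator type is spelled out at the call site; the element, bias, and scale types are deduced from the arguments, and the cuDNN data/compute enums are derived via GetCudnnDataType instead of being passed as integer template parameters. A self-contained sketch of the deduction pattern, using a hypothetical stand-in type rather than the real StreamExecutor classes:

#include <iostream>
#include <typeinfo>

template <typename T>
struct DeviceMemorySketch {};  // stand-in for stream_executor::DeviceMemory<T>

// Only AccumulatorType must be given explicitly; ElementType is deduced.
template <typename AccumulatorType, typename ElementType>
void FusedConvolveSketch(const DeviceMemorySketch<ElementType>&) {
  std::cout << "element=" << typeid(ElementType).name()
            << " accumulator=" << typeid(AccumulatorType).name() << "\n";
}

int main() {
  DeviceMemorySketch<short> half_data;  // stand-in for DeviceMemory<Eigen::half>
  // Mirrors DoFusedConvolveImpl<float>(...): float accumulation, with the
  // element type deduced from the argument.
  FusedConvolveSketch<float>(half_data);
}

This is why the Eigen::half overload now instantiates DoFusedConvolveImpl<float> and the int8 overload instantiates DoFusedConvolveImpl<int32>: the remaining types come from the DeviceMemory arguments themselves.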
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index 9d5bcc7f77..5cceb8983c 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -467,8 +467,6 @@ void CUDAExecutor::VlogOccupancyInfo(const KernelBase &kernel,
return;
}
- int block_size = thread_dims.x * thread_dims.y * thread_dims.z;
-
const DeviceDescription &device_description =
kernel.parent()->GetDeviceDescription();
@@ -485,7 +483,7 @@ void CUDAExecutor::VlogOccupancyInfo(const KernelBase &kernel,
if (suggested_threads != 0) {
VLOG(2) << "The cuda occupancy calculator recommends using "
<< suggested_threads
- << " threads per block to acheive an occupancy of " << blocks_per_sm
+ << " threads per block to achieve an occupancy of " << blocks_per_sm
<< " blocks per SM.";
}
}
@@ -499,14 +497,14 @@ int CUDAExecutor::CalculateOccupancy(
CUfunction func) {
int suggested_blocks = 0;
int suggested_threads = 0;
- CUresult err =
- cuOccupancyMaxPotentialBlockSize(&suggested_blocks, &suggested_threads,
- func, NULL, shared_memory_per_block, 0);
+ CUresult err = cuOccupancyMaxPotentialBlockSize(
+ &suggested_blocks, &suggested_threads, func, nullptr,
+ shared_memory_per_block, 0);
CHECK_EQ(err, CUDA_SUCCESS);
return suggested_blocks;
}
-// Compute and return the suggested thread count to acheive ideal occupancy.
+// Compute and return the suggested thread count to achieve ideal occupancy.
// If the provided thread dimensions match this number, zero is returned.
int CUDAExecutor::CompareOccupancy(int *initial_blocks,
const DeviceDescription &device_description,
@@ -516,9 +514,9 @@ int CUDAExecutor::CompareOccupancy(int *initial_blocks,
CUfunction func) {
int suggested_blocks = 0;
int suggested_threads = 0;
- CUresult err =
- cuOccupancyMaxPotentialBlockSize(&suggested_blocks, &suggested_threads,
- func, NULL, shared_memory_per_block, 0);
+ CUresult err = cuOccupancyMaxPotentialBlockSize(
+ &suggested_blocks, &suggested_threads, func, nullptr,
+ shared_memory_per_block, 0);
CHECK_EQ(err, CUDA_SUCCESS);
if (suggested_blocks > *initial_blocks) {
*initial_blocks = suggested_blocks;
diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h
index b15ce31216..8ddf18629d 100644
--- a/tensorflow/stream_executor/device_description.h
+++ b/tensorflow/stream_executor/device_description.h
@@ -22,8 +22,7 @@ limitations under the License.
#include <map>
#include <memory>
-#include "tensorflow/stream_executor/platform/port.h"
-
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/launch_dim.h"
#include "tensorflow/stream_executor/platform/port.h"
@@ -310,9 +309,8 @@ class DeviceDescriptionBuilder {
bool ThreadDimOk(const DeviceDescription &device_description,
const ThreadDim &thread_dim);
-// [deprecated] Use MathUtil::CeilOfRatio directly instead.
-//
// Equivalent to ceil(double(element_count) / threads_per_block).
+ABSL_DEPRECATED("Use MathUtil::CeilOfRatio directly instead.")
uint64 DivideCeil(uint64 x, uint64 y);
// Calculate the number of threads/blocks required to process element_count
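ABSL_DEPRECATED replaces the comment-only deprecation notice with a compiler-visible attribute. A small illustrative sketch of the effect (the function name here is hypothetical; DivideCeil itself is declared above):

#include <cstdint>
#include "absl/base/macros.h"

ABSL_DEPRECATED("Use MathUtil::CeilOfRatio directly instead.")
inline uint64_t DivideCeilSketch(uint64_t x, uint64_t y) {
  return (x + y - 1) / y;  // ceil(x / y) for positive integers
}

// Any caller of DivideCeilSketch() now gets a -Wdeprecated-declarations
// warning on compilers that support the attribute, instead of relying on a
// comment being noticed during review.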
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 9abfa1db6a..621b155240 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -873,7 +873,7 @@ class NormalizeDescriptor {
// Describes a kind of non-linearity (threshold-like mathematical function).
enum class ActivationMode {
- kNone,
+ kNone = 0,
kSigmoid,
// Rectified linear activation: f(x) = x < 0 ? 0 : x
kRelu,
@@ -885,6 +885,8 @@ enum class ActivationMode {
kTanh,
// Like ReluX, but passes all values in the range [-X,X].
kBandPass,
+
+  kNumActivationModes,  // Always keep this last.
};
// Returns a string representation of the given activation mode.
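Pinning kNone to 0 and adding a trailing kNumActivationModes sentinel makes it possible to bounds-check or iterate over the enum without listing every member. A short illustrative sketch, not code from this change:

bool IsValidActivationMode(int value) {
  return value >= 0 &&
         value < static_cast<int>(dnn::ActivationMode::kNumActivationModes);
}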
diff --git a/tensorflow/stream_executor/lib/array_slice.h b/tensorflow/stream_executor/lib/array_slice.h
index 8e3c4ca047..5f4e586762 100644
--- a/tensorflow/stream_executor/lib/array_slice.h
+++ b/tensorflow/stream_executor/lib/array_slice.h
@@ -16,13 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_ARRAY_SLICE_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_ARRAY_SLICE_H_
-#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "absl/types/span.h"
namespace stream_executor {
namespace port {
-using tensorflow::gtl::ArraySlice;
-using tensorflow::gtl::MutableArraySlice;
+template <typename T>
+using ArraySlice = absl::Span<const T>;
+template <typename T>
+using MutableArraySlice = absl::Span<T>;
} // namespace port
} // namespace stream_executor
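With these aliases, port::ArraySlice<T> is a read-only absl::Span and port::MutableArraySlice<T> is a writable one, which is also why the ThenMemcpyD2H call earlier in this change wraps its vector with absl::MakeSpan. A minimal sketch of the two views, assuming Abseil is on the include path:

#include <vector>
#include "absl/types/span.h"

// Writable view: equivalent to port::MutableArraySlice<float>.
void FillWithOnes(absl::Span<float> out) {
  for (float& v : out) v = 1.0f;
}

// Read-only view: equivalent to port::ArraySlice<float>.
float Sum(absl::Span<const float> in) {
  float s = 0.0f;
  for (float v : in) s += v;
  return s;
}

int main() {
  std::vector<float> tmp(4);
  FillWithOnes(absl::MakeSpan(tmp));  // explicit adaptor for the mutable view
  return Sum(tmp) == 4.0f ? 0 : 1;    // const view converts implicitly
}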
diff --git a/tensorflow/stream_executor/lib/inlined_vector.h b/tensorflow/stream_executor/lib/inlined_vector.h
index 40bdddb180..0198947e5b 100644
--- a/tensorflow/stream_executor/lib/inlined_vector.h
+++ b/tensorflow/stream_executor/lib/inlined_vector.h
@@ -16,12 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_INLINED_VECTOR_H_
-#include "tensorflow/core/lib/gtl/inlined_vector.h"
+#include "absl/container/inlined_vector.h"
namespace stream_executor {
namespace port {
-using tensorflow::gtl::InlinedVector;
+using absl::InlinedVector;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/strcat.h b/tensorflow/stream_executor/lib/strcat.h
index c959e4df5b..3688d7b4eb 100644
--- a/tensorflow/stream_executor/lib/strcat.h
+++ b/tensorflow/stream_executor/lib/strcat.h
@@ -18,13 +18,13 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRCAT_H_
-#include "tensorflow/core/lib/strings/strcat.h"
+#include "absl/strings/str_cat.h"
namespace stream_executor {
namespace port {
-using tensorflow::strings::StrCat;
-using tensorflow::strings::StrAppend;
+using absl::StrAppend;
+using absl::StrCat;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/lib/stringpiece.h b/tensorflow/stream_executor/lib/stringpiece.h
index b80de5df30..7624910129 100644
--- a/tensorflow/stream_executor/lib/stringpiece.h
+++ b/tensorflow/stream_executor/lib/stringpiece.h
@@ -16,13 +16,12 @@ limitations under the License.
#ifndef TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
#define TENSORFLOW_STREAM_EXECUTOR_LIB_STRINGPIECE_H_
-#include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/stream_executor/platform/port.h"
+#include "absl/strings/string_view.h"
namespace stream_executor {
namespace port {
-using tensorflow::StringPiece;
+using StringPiece = absl::string_view;
} // namespace port
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/plugin_registry.h b/tensorflow/stream_executor/plugin_registry.h
index 49628ecd24..3065b5cb77 100644
--- a/tensorflow/stream_executor/plugin_registry.h
+++ b/tensorflow/stream_executor/plugin_registry.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <map>
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/fft.h"
@@ -97,6 +98,7 @@ class PluginRegistry {
// TODO(b/22689637): Deprecated/temporary. Will be deleted once all users are
// on MultiPlatformManager / PlatformId.
template <typename FactoryT>
+ ABSL_DEPRECATED("Use MultiPlatformManager / PlatformId instead.")
port::StatusOr<FactoryT> GetFactory(PlatformKind platform_kind,
PluginId plugin_id);
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 19d3b2389a..69558fd14b 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -587,6 +587,44 @@ Stream &Stream::ThenConvolveWithScratch(
Stream &Stream::ThenFusedConvolveWithAlgorithm(
const dnn::BatchDescriptor &conv_input_descriptor,
+ const DeviceMemory<double> &conv_input_data, double conv_input_scale,
+ const dnn::FilterDescriptor &filter_descriptor,
+ const DeviceMemory<double> &filter_data,
+ const dnn::ConvolutionDescriptor &convolution_descriptor,
+ const DeviceMemory<double> &side_input_data, double side_input_scale,
+ const dnn::BatchDescriptor &bias_descriptor,
+ const DeviceMemory<double> &biases, dnn::ActivationMode activation_mode,
+ const dnn::BatchDescriptor &output_descriptor, DeviceMemory<double> *output,
+ ScratchAllocator *scratch_allocator,
+ const dnn::AlgorithmConfig &algorithm_config,
+ dnn::ProfileResult *output_profile_result) {
+ VLOG_CALL(PARAM(conv_input_descriptor), PARAM(conv_input_data),
+ PARAM(conv_input_scale), PARAM(filter_descriptor),
+ PARAM(filter_data), PARAM(convolution_descriptor), PARAM(biases),
+ PARAM(side_input_data), PARAM(side_input_scale),
+ PARAM(activation_mode), PARAM(output_descriptor), PARAM(output),
+ PARAM(algorithm_config));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ auto status = dnn->DoFusedConvolve(
+ this, conv_input_descriptor, conv_input_data, conv_input_scale,
+ filter_descriptor, filter_data, convolution_descriptor,
+ side_input_data, side_input_scale, bias_descriptor, biases,
+ activation_mode, output_descriptor, output, scratch_allocator,
+ algorithm_config, output_profile_result);
+ if (!status && !output_profile_result) {
+ SetError();
+ }
+ } else {
+ SetErrorAndLogNoDnnSupport();
+ }
+ }
+ return *this;
+}
+
+Stream &Stream::ThenFusedConvolveWithAlgorithm(
+ const dnn::BatchDescriptor &conv_input_descriptor,
const DeviceMemory<float> &conv_input_data, float conv_input_scale,
const dnn::FilterDescriptor &filter_descriptor,
const DeviceMemory<float> &filter_data,
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 9515d8e62a..10bf006787 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <atomic>
#include <utility>
+#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/lib/env.h"
@@ -163,6 +164,15 @@ StreamExecutor::StreamExecutor(PlatformKind platform_kind,
CheckPlatformKindIsValid(platform_kind);
}
+// Gets the per-device memory limit in bytes. Returns 0 if the
+// TF_PER_DEVICE_MEMORY_LIMIT_MB environment variable is not set.
+static int64 GetMemoryLimitBytes() {
+ int64 value;
+ SE_CHECK_OK(tensorflow::ReadInt64FromEnvVar("TF_PER_DEVICE_MEMORY_LIMIT_MB",
+ 0, &value));
+ return value * (1ll << 20);
+}
+
StreamExecutor::StreamExecutor(
const Platform *platform,
std::unique_ptr<internal::StreamExecutorInterface> implementation)
@@ -172,7 +182,9 @@ StreamExecutor::StreamExecutor(
background_threads_(new port::ThreadPool(
port::Env::Default(), "stream_executor", kNumBackgroundThreads)),
live_stream_count_(0),
- tracing_enabled_(false) {
+ tracing_enabled_(false),
+ mem_alloc_bytes_(0),
+ memory_limit_bytes_(GetMemoryLimitBytes()) {
if (port::Lowercase(platform_->Name()) == "cuda") {
platform_kind_ = PlatformKind::kCuda;
} else if (port::Lowercase(platform_->Name()) == "opencl") {
@@ -460,6 +472,14 @@ port::Status StreamExecutor::BlockHostUntilDone(Stream *stream) {
}
void *StreamExecutor::Allocate(uint64 size) {
+ if (memory_limit_bytes_ > 0 &&
+ mem_alloc_bytes_ + size > memory_limit_bytes_) {
+ LOG(WARNING) << "Not enough memory to allocate " << size << " on device "
+ << device_ordinal_
+ << " within provided limit. [used=" << mem_alloc_bytes_
+ << ", limit=" << memory_limit_bytes_ << "]";
+ return nullptr;
+ }
void *buf = implementation_->Allocate(size);
VLOG(1) << "Called StreamExecutor::Allocate(size=" << size << ") returns "
<< buf << StackTraceIfVLOG10();
@@ -779,6 +799,7 @@ void StreamExecutor::CreateAllocRecord(void *opaque, uint64 bytes) {
mutex_lock lock(mu_);
mem_allocs_[opaque] = AllocRecord{
bytes, ""};
+ mem_alloc_bytes_ += bytes;
}
}
@@ -789,6 +810,7 @@ void StreamExecutor::EraseAllocRecord(void *opaque) {
LOG(ERROR) << "Deallocating unknown pointer: "
<< port::Printf("0x%p", opaque);
} else {
+ mem_alloc_bytes_ -= mem_allocs_[opaque].bytes;
mem_allocs_.erase(opaque);
}
}
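Taken together, the changes in this file read TF_PER_DEVICE_MEMORY_LIMIT_MB once at construction, convert it from MiB to bytes, track the running total of allocated bytes, and make Allocate() fail when a request would exceed a positive limit. A self-contained sketch of that accounting (standalone illustration, not the StreamExecutor code):

#include <cstdint>
#include <cstdlib>

// Returns the limit in bytes, or 0 (no limit) when the variable is unset.
int64_t MemoryLimitBytesFromEnv() {
  const char* v = std::getenv("TF_PER_DEVICE_MEMORY_LIMIT_MB");
  const int64_t mib = v ? std::atoll(v) : 0;
  return mib * (int64_t{1} << 20);
}

struct AllocationTracker {
  int64_t limit_bytes = MemoryLimitBytesFromEnv();
  int64_t allocated_bytes = 0;

  // Mirrors the check added to StreamExecutor::Allocate(): reject a request
  // that would push the running total over a positive limit.
  bool CanAllocate(int64_t size) const {
    return limit_bytes <= 0 || allocated_bytes + size <= limit_bytes;
  }
  void OnAllocate(int64_t size) { allocated_bytes += size; }
  void OnDeallocate(int64_t size) { allocated_bytes -= size; }
};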
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 437f298616..4a8a270afa 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <tuple>
#include <vector>
+#include "absl/base/macros.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/strcat.h"
@@ -81,8 +82,8 @@ class StreamExecutor {
port::Status Init();
port::Status Init(int device_ordinal, DeviceOptions device_options);
- // DEPRECATED: Do not use; use platform() instead.
// Returns the platform that this StreamExecutor is acting upon.
+ ABSL_DEPRECATED("Use platform() instead.")
PlatformKind platform_kind() const { return platform_kind_; }
// Returns a reference to the platform that created this executor.
@@ -255,15 +256,15 @@ class StreamExecutor {
// [deprecated] Blocks the caller while a data segment of the given size is
// copied from the host source to the device destination.
- //
- // Deprecation: prefer explicit H2D below, to avoid error-prone API usage.
+ ABSL_DEPRECATED(
+ "Prefer SynchronousMemcpyH2D, to avoid error-prone API usage.")
bool SynchronousMemcpy(DeviceMemoryBase *device_dst, const void *host_src,
uint64 size) SE_MUST_USE_RESULT;
// [deprecated] Blocks the caller while a data segment of the given size is
// copied from the device source to the host destination.
- //
- // Deprecation: prefer explicit D2H below, to avoid error-prone API usage.
+ ABSL_DEPRECATED(
+ "Prefer SynchronousMemcpyD2H, to avoid error-prone API usage.")
bool SynchronousMemcpy(void *host_dst, const DeviceMemoryBase &device_src,
uint64 size) SE_MUST_USE_RESULT;
@@ -699,6 +700,13 @@ class StreamExecutor {
// The set of TraceListeners registered for this StreamExecutor.
std::set<TraceListener*> listeners_ GUARDED_BY(mu_);
+ // Allocated memory in bytes.
+ int64 mem_alloc_bytes_;
+
+  // Memory limit in bytes. A value less than or equal to 0 indicates there is
+  // no limit.
+ int64 memory_limit_bytes_;
+
SE_DISALLOW_COPY_AND_ASSIGN(StreamExecutor);
};