diff options
author | 2017-07-11 22:48:31 -0700 | |
---|---|---|
committer | 2017-07-11 22:52:35 -0700 | |
commit | a05c3ce52ffd4909cb9bf7c155805406b5b1fc06 (patch) | |
tree | 3193719801c2aee4bb2e967d59f582ce40dabf2b /tensorflow/core/kernels/fft_ops.cc | |
parent | 54137ffd39bfb86bc6f259ebaa3c85fecfbe7e93 (diff) |
Add scratch allocator option for 1D, 2D, 3D, and batched cufft plan creation.
PiperOrigin-RevId: 161627209
Diffstat (limited to 'tensorflow/core/kernels/fft_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/fft_ops.cc | 108 |
1 file changed, 93 insertions, 15 deletions
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc index 35a3f7b189..6924de284b 100644 --- a/tensorflow/core/kernels/fft_ops.cc +++ b/tensorflow/core/kernels/fft_ops.cc @@ -17,6 +17,7 @@ limitations under the License. // See docs in ../ops/spectral_ops.cc. +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor.h" @@ -24,8 +25,8 @@ limitations under the License. #include "tensorflow/core/framework/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/env_var.h" #include "tensorflow/core/util/work_sharder.h" -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #if GOOGLE_CUDA #include "tensorflow/core/platform/stream_executor.h" @@ -276,22 +277,93 @@ REGISTER_KERNEL_BUILDER(Name("IRFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL), #undef FFT_LABEL #if GOOGLE_CUDA +namespace gpu = ::perftools::gputools; namespace { -// TODO(vrv/zhifengc): Refactor AsDeviceMemory() into GPUUtil. template <typename T> -perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) { - perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory)); - perftools::gputools::DeviceMemory<T> typed(wrapped); +gpu::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) { + gpu::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory)); + gpu::DeviceMemory<T> typed(wrapped); + return typed; +} + +template <typename T> +gpu::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) { + gpu::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T)); + gpu::DeviceMemory<T> typed(wrapped); return typed; } + +// A class to provide scratch-space allocator for Stream-Executor Cufft +// callback. Tensorflow is responsible for releasing the temporary buffers after +// the kernel finishes. 
+// TODO(yangzihao): Refactor redundant code in subclasses of ScratchAllocator +// into base class. +class CufftScratchAllocator : public gpu::ScratchAllocator { + public: + virtual ~CufftScratchAllocator() {} + CufftScratchAllocator(int64 memory_limit, OpKernelContext* context) + : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {} + int64 GetMemoryLimitInBytes(gpu::Stream* stream) override { + return memory_limit_; + } + gpu::port::StatusOr<gpu::DeviceMemory<uint8>> AllocateBytes( + gpu::Stream* stream, int64 byte_size) override { + Tensor temporary_memory; + if (byte_size > memory_limit_) { + return gpu::port::StatusOr<gpu::DeviceMemory<uint8>>(); + } + AllocationAttributes allocation_attr; + allocation_attr.no_retry_on_failure = true; + Status allocation_status(context_->allocate_temp( + DT_UINT8, TensorShape({byte_size}), &temporary_memory, + AllocatorAttributes(), allocation_attr)); + if (!allocation_status.ok()) { + return gpu::port::StatusOr<gpu::DeviceMemory<uint8>>(); + } + // Hold the reference of the allocated tensors until the end of the + // allocator. 
+ allocated_tensors_.push_back(temporary_memory); + total_byte_size_ += byte_size; + return gpu::port::StatusOr<gpu::DeviceMemory<uint8>>( + AsDeviceMemory(temporary_memory.flat<uint8>().data(), + temporary_memory.flat<uint8>().size())); + } + int64 TotalByteSize() { return total_byte_size_; } + + private: + int64 memory_limit_; + int64 total_byte_size_; + OpKernelContext* context_; + std::vector<Tensor> allocated_tensors_; +}; + } // end namespace +int64 GetCufftWorkspaceLimit(const string& envvar_in_mb, + int64 default_value_in_bytes) { + const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str()); + if (workspace_limit_in_mb_str != nullptr && + strcmp(workspace_limit_in_mb_str, "") != 0) { + int64 scratch_limit_in_mb = -1; + Status status = ReadInt64FromEnvVar(envvar_in_mb, default_value_in_bytes, + &scratch_limit_in_mb); + if (!status.ok()) { + LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": " + << workspace_limit_in_mb_str; + } else { + return scratch_limit_in_mb * (1 << 20); + } + } + return default_value_in_bytes; +} + class FFTGPUBase : public FFTBase { public: using FFTBase::FFTBase; protected: + static int64 CufftScratchSize; void DoFFT(OpKernelContext* ctx, const Tensor& in, uint64* fft_shape, Tensor* out) override { auto* stream = ctx->op_device_context()->stream(); @@ -306,10 +378,10 @@ class FFTGPUBase : public FFTBase { batch_size *= input_shape.dim_size(i); } uint64 input_embed[3]; - uint64 input_stride = 1; + const uint64 input_stride = 1; uint64 input_distance = 1; uint64 output_embed[3]; - uint64 output_stride = 1; + const uint64 output_stride = 1; uint64 output_distance = 1; for (int i = 0; i < fft_rank; ++i) { @@ -322,15 +394,16 @@ class FFTGPUBase : public FFTBase { constexpr bool kInPlaceFft = false; const auto kFftType = - IsReal() ? (IsForward() ? perftools::gputools::fft::Type::kR2C - : perftools::gputools::fft::Type::kC2R) - : (IsForward() ? 
perftools::gputools::fft::Type::kC2CForward - : perftools::gputools::fft::Type::kC2CInverse); + IsReal() ? (IsForward() ? gpu::fft::Type::kR2C : gpu::fft::Type::kC2R) + : (IsForward() ? gpu::fft::Type::kC2CForward + : gpu::fft::Type::kC2CInverse); - auto plan = stream->parent()->AsFft()->CreateBatchedPlan( - stream, fft_rank, fft_shape, input_embed, input_stride, input_distance, - output_embed, output_stride, output_distance, kFftType, kInPlaceFft, - batch_size); + CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx); + auto plan = + stream->parent()->AsFft()->CreateBatchedPlanWithScratchAllocator( + stream, fft_rank, fft_shape, input_embed, input_stride, + input_distance, output_embed, output_stride, output_distance, + kFftType, kInPlaceFft, batch_size, &scratch_allocator); if (IsReal()) { if (IsForward()) { @@ -375,6 +448,11 @@ class FFTGPUBase : public FFTBase { } }; +int64 FFTGPUBase::CufftScratchSize = GetCufftWorkspaceLimit( + // default value is in bytes despite the name of the environment variable + "TF_CUFFT_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB +); + template <bool Forward, bool _Real, int FFTRank> class FFTGPU : public FFTGPUBase { public: |