-rw-r--r--  tensorflow/core/kernels/bias_op.cc                     |  24
-rw-r--r--  tensorflow/core/kernels/bias_op_gpu.cu.cc              |  38
-rw-r--r--  tensorflow/core/kernels/cwise_ops_test.cc              | 119
-rw-r--r--  tensorflow/core/platform/default/test_benchmark.cc     |  38
-rw-r--r--  tensorflow/core/platform/test_benchmark.h              |   8
-rw-r--r--  tensorflow/core/util/cuda_kernel_helper.h              |  73
-rw-r--r--  tensorflow/python/kernel_tests/bias_op_test.py         |  53
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_driver.cc         |  17
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_driver.h          |   7
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_gpu_executor.cc   |  17
-rw-r--r--  tensorflow/stream_executor/cuda/cuda_gpu_executor.h    |   2
-rw-r--r--  tensorflow/stream_executor/stream.h                    |   1
-rw-r--r--  tensorflow/stream_executor/stream_executor_internal.h  |   2
13 files changed, 300 insertions, 99 deletions
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc index ba31221dad..c31cdeb712 100644 --- a/tensorflow/core/kernels/bias_op.cc +++ b/tensorflow/core/kernels/bias_op.cc @@ -149,6 +149,18 @@ void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format, } } +template <class T> +struct AccumulatorType { + typedef T type; +}; + +// float is faster on the CPU than half, and also more precise, +// so use float for the temporary accumulators. +template <> +struct AccumulatorType<Eigen::half> { + typedef float type; +}; + } // namespace template <typename Device, typename T> @@ -197,7 +209,11 @@ class BiasGradOp<CPUDevice, T> : public OpKernel { Eigen::array<int, 1> reduction_axis = {0}; #endif output->template flat<T>().device(context->eigen_device<CPUDevice>()) = - output_backprop.flat<T>().reshape(two_dims).sum(reduction_axis); + output_backprop.flat<T>() + .template cast<typename AccumulatorType<T>::type>() + .reshape(two_dims) + .sum(reduction_axis) + .template cast<T>(); } private: @@ -268,7 +284,7 @@ class BiasOp<GPUDevice, T> : public BinaryOp<T> { Name("BiasAddV1").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ BiasOp<GPUDevice, type>); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_GPU_KERNEL); +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL template <typename T> @@ -302,7 +318,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel { OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); perftools::gputools::DeviceMemoryBase output_ptr( output->flat<T>().data(), output->NumElements() * sizeof(T)); - stream->ThenMemset32(&output_ptr, 0, output->NumElements() * sizeof(T)); + stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T)); BiasGradGPU<T>::compute(context->template eigen_device<Device>(), output_backprop.template flat<T>().data(), output->flat<T>().data(), batch, width, height, @@ -319,7 +335,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel { Name("BiasAddGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \ BiasGradOp<GPUDevice, type>); -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_GPU_KERNEL); +TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL); #undef REGISTER_GPU_KERNEL #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc index 364555f141..48e56acaab 100644 --- a/tensorflow/core/kernels/bias_op_gpu.cu.cc +++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc @@ -28,6 +28,20 @@ namespace tensorflow { typedef Eigen::GpuDevice GPUDevice; +// There are no native fp16 atomics (we simulate them using 32-bit atomics), +// so fp16 sums are done in fp32 internally. (We don't have a lot of shared +// memory traffic; BiasGradNCHW_SharedAtomics in particular works almost +// entirely on a local variable.) +template <class T> +struct AccumulatorType { + typedef T type; +}; + +template <> +struct AccumulatorType<Eigen::half> { + typedef float type; +}; + // Definition of the GPU implementations declared in bias_op.cc. 
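The CPU-side BiasAddGrad change above casts half gradients up to float, performs the sum there, and casts back to half only when writing the result. For reference, a minimal standalone sketch of that cast-sum-cast pattern with Eigen tensors (the function name and shapes are illustrative, not the actual TensorFlow kernel):

```cpp
#include <unsupported/Eigen/CXX11/Tensor>

// Maps an element type to the type used for intermediate accumulation.
template <class T>
struct AccumulatorType {
  typedef T type;
};

// float is both faster and more precise than half on the CPU, so half
// reductions accumulate in float.
template <>
struct AccumulatorType<Eigen::half> {
  typedef float type;
};

// Sum a (rows x cols) tensor over dimension 0, accumulating in
// AccumulatorType<T>::type and casting back to T only for the result.
template <class T>
Eigen::Tensor<T, 1> ColumnSums(const Eigen::Tensor<T, 2>& grad) {
  typedef typename AccumulatorType<T>::type AccT;
  Eigen::array<int, 1> reduction_axis = {0};
  return grad.template cast<AccT>().sum(reduction_axis).template cast<T>();
}
```

The wide accumulator matters because a long half-precision running sum quickly starts rounding away small addends, while a float accumulator rounds only once at the end.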
template <typename T> @@ -102,21 +116,22 @@ template <typename T> __global__ void BiasGradNHWC_SharedAtomics(int32 nthreads, const T* output_backprop, T* bias_backprop, int32 bias_size) { - T* s_data = reinterpret_cast<T*>(s_buf); + typedef typename AccumulatorType<T>::type AccT; + AccT* s_data = reinterpret_cast<AccT*>(s_buf); for (int32 index = threadIdx.x; index < bias_size; index += blockDim.x) { - s_data[index] = T(0); + s_data[index] = AccT(0); } __syncthreads(); for (int32 index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int32 bias_offset = index % bias_size; - CudaAtomicAdd(s_data + bias_offset, ldg(output_backprop + index)); + CudaAtomicAdd(s_data + bias_offset, AccT(ldg(output_backprop + index))); } __syncthreads(); for (int32 index = threadIdx.x; index < bias_size; index += blockDim.x) { - CudaAtomicAdd(bias_backprop + index, s_data[index]); + CudaAtomicAdd(bias_backprop + index, T(s_data[index])); } } @@ -126,10 +141,11 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop, int32 bias_size, int32 image_size, int group_size) { // Initialize the shared memory. - __shared__ T s_data[32]; + typedef typename AccumulatorType<T>::type AccT; + __shared__ AccT s_data[32]; int32 s_data_size = sizeof(s_data) / sizeof(T); for (int32 index = threadIdx.x; index < s_data_size; index += blockDim.x) { - s_data[index] = 0; + s_data[index] = AccT(0); } __syncthreads(); @@ -138,14 +154,14 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop, int32 bias_index = blockIdx.x % bias_size; int32 group_index = blockIdx.x / bias_size; int32 total_count = batch * image_size; - T sum = 0; + AccT sum(0); for (int32 index = group_index * blockDim.x + threadIdx.x; index < total_count; index += blockDim.x * group_size) { int32 image_offset = index % image_size; int32 batch = index / image_size; T val = ldg(output_backprop + (batch * bias_size + bias_index) * image_size + image_offset); - sum += val; + sum += AccT(val); } // Write the accumulated sum in this thread to the shared memory. Each thread @@ -165,7 +181,7 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop, // The first thread writes out the accumulated result to the global location. if (thread_index == 0) { - CudaAtomicAdd(bias_backprop + bias_index, s_data[0]); + CudaAtomicAdd(bias_backprop + bias_index, T(s_data[0])); } } @@ -186,7 +202,7 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop, const int max_shared_memory_size = d.sharedMemPerBlock() / 2; int32 shared_memory_size = 0; if (data_format == FORMAT_NHWC) { - shared_memory_size = bias_size * sizeof(T); + shared_memory_size = bias_size * sizeof(typename AccumulatorType<T>::type); } // Check if we have enough shared memory. if (shared_memory_size <= max_shared_memory_size) { @@ -227,7 +243,7 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop, template struct BiasGPU<T>; \ template struct BiasGradGPU<T>; -TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_GPU_SPECS); +TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS); } // end namespace tensorflow diff --git a/tensorflow/core/kernels/cwise_ops_test.cc b/tensorflow/core/kernels/cwise_ops_test.cc index 40f0d04329..5b205175c1 100644 --- a/tensorflow/core/kernels/cwise_ops_test.cc +++ b/tensorflow/core/kernels/cwise_ops_test.cc @@ -19,6 +19,7 @@ limitations under the License. 
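The GPU kernels above follow the same rule: shared-memory partial sums are kept in AccumulatorType<T>::type (float when T is Eigen::half), and values are converted back to T only for the final per-element atomic into global memory. A stripped-down sketch of that structure, assuming the CudaAtomicAdd overloads from tensorflow/core/util/cuda_kernel_helper.h (kernel name and launch details are illustrative):

```cpp
#include "tensorflow/core/util/cuda_kernel_helper.h"

using tensorflow::CudaAtomicAdd;  // includes the new Eigen::half overload

// NHWC-style bias-gradient block. Dynamic shared memory holds bias_size
// accumulators of the wide type AccT, so launch the kernel with
// bias_size * sizeof(AccT) bytes of dynamic shared memory.
template <typename T, typename AccT>
__global__ void BiasGradSharedAccum(int nthreads, const T* output_backprop,
                                    T* bias_backprop, int bias_size) {
  extern __shared__ char s_buf[];
  AccT* s_data = reinterpret_cast<AccT*>(s_buf);

  for (int i = threadIdx.x; i < bias_size; i += blockDim.x)
    s_data[i] = AccT(0);                             // zero the accumulators
  __syncthreads();

  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < nthreads;
       i += blockDim.x * gridDim.x)
    CudaAtomicAdd(s_data + i % bias_size,            // accumulate in AccT
                  AccT(output_backprop[i]));
  __syncthreads();

  for (int i = threadIdx.x; i < bias_size; i += blockDim.x)
    CudaAtomicAdd(bias_backprop + i, T(s_data[i]));  // one narrow atomic each
}
```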
#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/test_benchmark.h" +#include "tensorflow/core/util/tensor_format.h" namespace tensorflow { @@ -82,73 +83,97 @@ BM_BINARY_SCALAR(cpu, Add); BM_BINARY_SCALAR(gpu, Add); #undef BM_BINARY_SCALAR -static Graph* BiasAdd(int rows, int cols) { +template <class T> +static Graph* BiasAdd(int rows, int cols, DataType type) { Graph* g = new Graph(OpRegistry::Global()); - Tensor lhs(DT_FLOAT, TensorShape({rows, cols})); - lhs.flat<float>().setRandom(); + Tensor lhs(type, TensorShape({rows, cols})); + lhs.template flat<T>().setRandom(); TensorShape rhs_shape; rhs_shape = TensorShape({cols}); - Tensor rhs(DT_FLOAT, rhs_shape); - rhs.flat<float>().setRandom(); + Tensor rhs(type, rhs_shape); + rhs.template flat<T>().setRandom(); test::graph::Binary(g, "BiasAdd", test::graph::Constant(g, lhs), test::graph::Constant(g, rhs)); return g; } -#define BM_BIAS_ADD(DEVICE, R, C) \ - static void BM_##DEVICE##_BiasAdd_R##R##_C##C(int iters, int arg) { \ - const int rows = RowsFromArg(arg); \ - const int cols = ColsFromArg(arg); \ - const int64 tot = static_cast<int64>(iters) * rows * cols; \ - testing::ItemsProcessed(tot); \ - testing::BytesProcessed(tot * sizeof(float)); \ - test::Benchmark(#DEVICE, BiasAdd(rows, cols)).Run(iters); \ - } \ - BENCHMARK(BM_##DEVICE##_BiasAdd_R##R##_C##C)->Arg(RowsAndColsArg(R, C)); - -#define BM_BIAS_ADD_ALL(DEVICE) \ - BM_BIAS_ADD(DEVICE, 512, 2048); \ - BM_BIAS_ADD(DEVICE, 512, 4096); \ - BM_BIAS_ADD(DEVICE, 2048, 512); \ - BM_BIAS_ADD(DEVICE, 4096, 512); - -BM_BIAS_ADD_ALL(cpu); -BM_BIAS_ADD_ALL(gpu); +#define BM_BIAS_ADD(DEVICE, C_TYPE, TF_TYPE, R, C) \ + static void BM_##DEVICE##_##C_TYPE##_BiasAdd_R##R##_C##C(int iters, \ + int arg) { \ + const int rows = RowsFromArg(arg); \ + const int cols = ColsFromArg(arg); \ + const int64 tot = static_cast<int64>(iters) * rows * cols; \ + testing::ItemsProcessed(tot); \ + testing::BytesProcessed(tot * sizeof(C_TYPE)); \ + test::Benchmark(#DEVICE, BiasAdd<C_TYPE>(rows, cols, TF_TYPE)).Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_##C_TYPE##_BiasAdd_R##R##_C##C) \ + ->Arg(RowsAndColsArg(R, C)); + +#define BM_BIAS_ADD_ALL(DEVICE, C_TYPE, TF_TYPE) \ + BM_BIAS_ADD(DEVICE, C_TYPE, TF_TYPE, 512, 2048); \ + BM_BIAS_ADD(DEVICE, C_TYPE, TF_TYPE, 512, 4096); \ + BM_BIAS_ADD(DEVICE, C_TYPE, TF_TYPE, 2048, 512); \ + BM_BIAS_ADD(DEVICE, C_TYPE, TF_TYPE, 4096, 512); + +using Eigen::half; +BM_BIAS_ADD_ALL(cpu, float, DT_FLOAT); +BM_BIAS_ADD_ALL(gpu, float, DT_FLOAT); +BM_BIAS_ADD_ALL(cpu, half, DT_HALF); +BM_BIAS_ADD_ALL(gpu, half, DT_HALF); #undef BM_BIAS_ADD_ALL #undef BM_BIAS_ADD -static Graph* BiasAddGrad(int rows, int cols) { +template <class T> +static Graph* BiasAddGrad(int rows, int cols, int channels, DataType type, + TensorFormat format) { Graph* g = new Graph(OpRegistry::Global()); TensorShape lhs_shape; - lhs_shape = TensorShape({rows, cols}); - Tensor lhs(DT_FLOAT, lhs_shape); - lhs.template flat<float>().setRandom(); + if (format == FORMAT_NCHW) { + lhs_shape = TensorShape({channels, rows, cols}); + } else { + lhs_shape = TensorShape({rows, cols, channels}); + } + Tensor lhs(type, lhs_shape); + lhs.template flat<T>().setRandom(); Node* n; TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BiasAddGrad") + .Attr("data_format", ToString(format)) .Input(test::graph::Constant(g, lhs), /*index=*/0) .Finalize(g, &n)); return g; } -#define BM_BIAS_ADD_GRAD(DEVICE, R, C) \ - static void BM_##DEVICE##_BiasAddGrad_R##R##_C##C(int iters, int 
arg) { \ - const int rows = RowsFromArg(arg); \ - const int cols = ColsFromArg(arg); \ - const int64 tot = static_cast<int64>(iters) * rows * cols; \ - testing::ItemsProcessed(tot); \ - testing::BytesProcessed(tot * sizeof(float)); \ - test::Benchmark(#DEVICE, BiasAddGrad(rows, cols)).Run(iters); \ - } \ - BENCHMARK(BM_##DEVICE##_BiasAddGrad_R##R##_C##C)->Arg(RowsAndColsArg(R, C)); - -#define BM_BIAS_ADD_GRAD_ALL(DEVICE) \ - BM_BIAS_ADD_GRAD(DEVICE, 512, 2048); \ - BM_BIAS_ADD_GRAD(DEVICE, 512, 4096); \ - BM_BIAS_ADD_GRAD(DEVICE, 2048, 512); \ - BM_BIAS_ADD_GRAD(DEVICE, 4096, 512); - -BM_BIAS_ADD_GRAD_ALL(cpu); -BM_BIAS_ADD_GRAD_ALL(gpu); +#define BM_BIAS_ADD_GRAD(DEVICE, FMT, C_TYPE, TF_TYPE, R, C, CH) \ + static void \ + BM_##DEVICE##_##FMT##_##C_TYPE##_BiasAddGrad_R##R##_C##C##_CH##CH( \ + int iters, int arg, int channels) { \ + const int rows = RowsFromArg(arg); \ + const int cols = ColsFromArg(arg); \ + const int64 tot = static_cast<int64>(iters) * rows * cols * channels; \ + testing::ItemsProcessed(tot); \ + testing::BytesProcessed(tot * sizeof(C_TYPE)); \ + test::Benchmark(#DEVICE, BiasAddGrad<C_TYPE>(rows, cols, channels, \ + TF_TYPE, FORMAT_##FMT)) \ + .Run(iters); \ + } \ + BENCHMARK(BM_##DEVICE##_##FMT##_##C_TYPE##_BiasAddGrad_R##R##_C##C##_CH##CH) \ + ->ArgPair(RowsAndColsArg(R, C), CH); + +#define BM_BIAS_ADD_GRAD_ALL(DEVICE, FORMAT, C_TYPE, TF_TYPE) \ + BM_BIAS_ADD_GRAD(DEVICE, FORMAT, C_TYPE, TF_TYPE, 64, 64, 64); \ + BM_BIAS_ADD_GRAD(DEVICE, FORMAT, C_TYPE, TF_TYPE, 512, 512, 4); \ + BM_BIAS_ADD_GRAD(DEVICE, FORMAT, C_TYPE, TF_TYPE, 512, 512, 1); \ + BM_BIAS_ADD_GRAD(DEVICE, FORMAT, C_TYPE, TF_TYPE, 4096, 4096, 4); \ + BM_BIAS_ADD_GRAD(DEVICE, FORMAT, C_TYPE, TF_TYPE, 4096, 4096, 1); + +using Eigen::half; +BM_BIAS_ADD_GRAD_ALL(gpu, NCHW, float, DT_FLOAT); +BM_BIAS_ADD_GRAD_ALL(gpu, NCHW, half, DT_HALF); +BM_BIAS_ADD_GRAD_ALL(cpu, NHWC, float, DT_FLOAT); +BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, float, DT_FLOAT); +BM_BIAS_ADD_GRAD_ALL(cpu, NHWC, half, DT_HALF); +BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, half, DT_HALF); #undef BM_BIAS_ADD_GRAD_ALL #undef BM_BIAS_ADD_GRAD diff --git a/tensorflow/core/platform/default/test_benchmark.cc b/tensorflow/core/platform/default/test_benchmark.cc index 15c01b5233..ffdbb0b761 100644 --- a/tensorflow/core/platform/default/test_benchmark.cc +++ b/tensorflow/core/platform/default/test_benchmark.cc @@ -18,6 +18,7 @@ limitations under the License. 
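To make the macro soup above concrete, this is (up to whitespace) what a single instantiation such as BM_BIAS_ADD_GRAD(gpu, NCHW, half, DT_HALF, 64, 64, 64) expands to; the second value passed to ArgPair arrives in the benchmark body as channels through the new three-argument Run path:

```cpp
static void BM_gpu_NCHW_half_BiasAddGrad_R64_C64_CH64(int iters, int arg,
                                                      int channels) {
  const int rows = RowsFromArg(arg);
  const int cols = ColsFromArg(arg);
  const int64 tot = static_cast<int64>(iters) * rows * cols * channels;
  testing::ItemsProcessed(tot);
  testing::BytesProcessed(tot * sizeof(half));
  test::Benchmark("gpu", BiasAddGrad<half>(rows, cols, channels, DT_HALF,
                                           FORMAT_NCHW))
      .Run(iters);
}
BENCHMARK(BM_gpu_NCHW_half_BiasAddGrad_R64_C64_CH64)
    ->ArgPair(RowsAndColsArg(64, 64), 64);
```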
#include <cstdio> #include <cstdlib> +#include <algorithm> #include <vector> #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/env.h" @@ -38,7 +39,7 @@ static Env* env; Benchmark::Benchmark(const char* name, void (*fn)(int)) : name_(name), num_args_(0), fn0_(fn) { - args_.push_back(-1); + args_.push_back(std::make_pair(-1, -1)); Register(); } @@ -47,9 +48,20 @@ Benchmark::Benchmark(const char* name, void (*fn)(int, int)) Register(); } +Benchmark::Benchmark(const char* name, void (*fn)(int, int, int)) + : name_(name), num_args_(2), fn2_(fn) { + Register(); +} + Benchmark* Benchmark::Arg(int x) { CHECK_EQ(num_args_, 1); - args_.push_back(x); + args_.push_back(std::make_pair(x, -1)); + return this; +} + +Benchmark* Benchmark::ArgPair(int x, int y) { + CHECK_EQ(num_args_, 2); + args_.push_back(std::make_pair(x, y)); return this; } @@ -76,8 +88,11 @@ void Benchmark::Run(const char* pattern) { name = b->name_; for (auto arg : b->args_) { name.resize(b->name_.size()); - if (arg >= 0) { - strings::StrAppend(&name, "/", arg); + if (arg.first >= 0) { + strings::StrAppend(&name, "/", arg.first); + if (arg.second >= 0) { + strings::StrAppend(&name, "/", arg.second); + } } if (RE2::PartialMatch(name, pattern)) { width = std::max<int>(width, name.size()); @@ -91,8 +106,11 @@ void Benchmark::Run(const char* pattern) { name = b->name_; for (auto arg : b->args_) { name.resize(b->name_.size()); - if (arg >= 0) { - strings::StrAppend(&name, "/", arg); + if (arg.first >= 0) { + strings::StrAppend(&name, "/", arg.first); + if (arg.second >= 0) { + strings::StrAppend(&name, "/", arg.second); + } } if (!RE2::PartialMatch(name, pattern)) { continue; @@ -100,7 +118,7 @@ void Benchmark::Run(const char* pattern) { int iters; double seconds; - b->Run(arg, &iters, &seconds); + b->Run(arg.first, arg.second, &iters, &seconds); char buf[100]; std::string full_label = label; @@ -143,7 +161,7 @@ void Benchmark::Register() { all_benchmarks->push_back(this); } -void Benchmark::Run(int arg, int* run_count, double* run_seconds) { +void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) { env = Env::Default(); static const int64 kMinIters = 100; static const int64 kMaxIters = 1000000000; @@ -157,8 +175,10 @@ void Benchmark::Run(int arg, int* run_count, double* run_seconds) { label.clear(); if (fn0_) { (*fn0_)(iters); + } else if (fn1_) { + (*fn1_)(iters, arg1); } else { - (*fn1_)(iters, arg); + (*fn2_)(iters, arg1, arg2); } StopTiming(); const double seconds = accum_time * 1e-6; diff --git a/tensorflow/core/platform/test_benchmark.h b/tensorflow/core/platform/test_benchmark.h index 4c90cf7f85..3536548f00 100644 --- a/tensorflow/core/platform/test_benchmark.h +++ b/tensorflow/core/platform/test_benchmark.h @@ -17,6 +17,7 @@ limitations under the License. 
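On the framework side, the new (int, int, int) constructor, ArgPair, and the pair-valued args_ vector are what make the expansion above registrable. A hedged stand-alone example of a two-parameter benchmark written directly against the extended API (BM_FillMatrix is a made-up name, not part of TensorFlow):

```cpp
#include <algorithm>
#include <vector>

#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// iters comes from the framework; rows and cols come from ArgPair below.
static void BM_FillMatrix(int iters, int rows, int cols) {
  std::vector<float> m(rows * cols);
  testing::ItemsProcessed(static_cast<int64>(iters) * rows * cols);
  for (int i = 0; i < iters; ++i) {
    std::fill(m.begin(), m.end(), 1.0f);
  }
}
// Run() reports these as "BM_FillMatrix/128/64" and "BM_FillMatrix/512/512":
// both halves of each registered pair are appended to the benchmark name.
BENCHMARK(BM_FillMatrix)->ArgPair(128, 64)->ArgPair(512, 512);

}  // namespace tensorflow
```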
#ifndef TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_ #define TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_ +#include <utility> #include <vector> #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/platform.h" @@ -48,20 +49,23 @@ class Benchmark { public: Benchmark(const char* name, void (*fn)(int)); Benchmark(const char* name, void (*fn)(int, int)); + Benchmark(const char* name, void (*fn)(int, int, int)); Benchmark* Arg(int x); + Benchmark* ArgPair(int x, int y); Benchmark* Range(int lo, int hi); static void Run(const char* pattern); private: string name_; int num_args_; - std::vector<int> args_; + std::vector<std::pair<int, int>> args_; void (*fn0_)(int) = nullptr; void (*fn1_)(int, int) = nullptr; + void (*fn2_)(int, int, int) = nullptr; void Register(); - void Run(int arg, int* run_count, double* run_seconds); + void Run(int arg1, int arg2, int* run_count, double* run_seconds); }; #endif diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h index a86567a7cc..86c55031c6 100644 --- a/tensorflow/core/util/cuda_kernel_helper.h +++ b/tensorflow/core/util/cuda_kernel_helper.h @@ -20,6 +20,7 @@ limitations under the License. #include <algorithm> +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/core/platform/types.h" #define CUDA_1D_KERNEL_LOOP(i, n) \ @@ -104,6 +105,78 @@ CUDA_ATOMIC_WRAPPER(Add, double) { return __longlong_as_double(old); } +// Helper functions for CudaAtomicAdd(half*, half), below. +// +// Note that if __CUDA_ARCH__ >= 530, we could probably use __hadd2() +// for a more efficient implementation, assuming that adding -0.0 +// will never harm the neighboring value. In this version, we take special +// care to guarantee the bits of the untouched value are unchanged. +inline __device__ uint32 add_to_low_half(uint32 val, float x) { + Eigen::half low_half; + low_half.x = static_cast<uint16>(val & 0xffffu); + low_half = static_cast<Eigen::half>(static_cast<float>(low_half) + x); + return (val & 0xffff0000u) | low_half.x; +} + +inline __device__ uint32 add_to_high_half(uint32 val, float x) { + Eigen::half high_half; + high_half.x = static_cast<uint16>(val >> 16); + high_half = static_cast<Eigen::half>(static_cast<float>(high_half) + x); + return (val & 0xffffu) | (high_half.x << 16); +} + +// Custom implementation of atomicAdd for half. Note that we don't have +// atomicCAS() for anything less than 32 bits, so we need to include the +// other 16 bits in the operation. +// +// Unlike the other atomic adds, this version is going to be very slow +// under high concurrency, since most threads will be spinning on failing +// their compare-and-swap tests. (The fact that we get false sharing on the +// neighboring fp16 makes this even worse.) If you are doing a large reduction, +// you are much better off with doing the intermediate steps in fp32 and then +// switching to fp16 as late as you can in the calculations. +// +// Note: Assumes little endian. +CUDA_ATOMIC_WRAPPER(Add, Eigen::half) { + float val_as_float(val); + intptr_t address_int = reinterpret_cast<intptr_t>(address); + if ((address_int & 0x2) == 0) { + // The half is in the first part of the uint32 (lower 16 bits). 
+ uint32* address_as_uint32 = reinterpret_cast<uint32*>(address); + assert(((intptr_t)address_as_uint32 & 0x3) == 0); + uint32 old = *address_as_uint32, assumed; + + do { + assumed = old; + old = atomicCAS(address_as_uint32, assumed, + add_to_low_half(assumed, val_as_float)); + + // Note: uses integer comparison to avoid hang in case of NaN + } while (assumed != old); + + Eigen::half ret; + ret.x = old & 0xffffu; + return ret; + } else { + // The half is in the second part of the uint32 (upper 16 bits). + uint32* address_as_uint32 = reinterpret_cast<uint32*>(address_int - 2); + assert(((intptr_t)address_as_uint32 & 0x3) == 0); + uint32 old = *address_as_uint32, assumed; + + do { + assumed = old; + old = atomicCAS(address_as_uint32, assumed, + add_to_high_half(assumed, val_as_float)); + + // Note: uses integer comparison to avoid hang in case of NaN + } while (assumed != old); + + Eigen::half ret; + ret.x = old >> 16; + return ret; + } +} + template <typename T> __global__ void SetZero(const int nthreads, T* bottom_diff) { CUDA_1D_KERNEL_LOOP(index, nthreads) { *(bottom_diff + index) = T(0); } diff --git a/tensorflow/python/kernel_tests/bias_op_test.py b/tensorflow/python/kernel_tests/bias_op_test.py index 1e75513d48..fadd9c5794 100644 --- a/tensorflow/python/kernel_tests/bias_op_test.py +++ b/tensorflow/python/kernel_tests/bias_op_test.py @@ -54,7 +54,7 @@ class BiasAddTest(tf.test.TestCase): np_val = self._npBias(np_inputs, np_bias) with self.test_session(use_gpu=use_gpu): tf_val = tf.nn.bias_add(np_inputs, np_bias).eval() - self.assertAllClose(np_val, tf_val) + self.assertAllCloseAccordingToType(np_val, tf_val) def _NHWCToNCHW(self, np_value): # fill the input value to at least 3-dimension @@ -79,11 +79,11 @@ class BiasAddTest(tf.test.TestCase): with self.test_session(use_gpu=use_gpu): tf_val = tf.nn.bias_add(np_inputs, np_bias, data_format="NCHW").eval() tf_val = self._NCHWToNHWC(tf_val) - self.assertAllClose(np_val, tf_val) + self.assertAllCloseAccordingToType(np_val, tf_val) def _testAll(self, np_inputs, np_bias): self._testBias(np_inputs, np_bias, use_gpu=False) - if np_inputs.dtype == np.float32 or np_inputs.dtype == np.float64: + if np_inputs.dtype in [np.float16, np.float32, np.float64]: self._testBias(np_inputs, np_bias, use_gpu=True) if tf.test.is_built_with_cuda(): self._testBiasNCHW(np_inputs, np_bias, use_gpu=True) @@ -108,7 +108,7 @@ class BiasAddTest(tf.test.TestCase): np.array([1, 2, 3]).astype(t)) def testFloatTypes(self): - for t in [np.float32, np.float64]: + for t in [np.float16, np.float32, np.float64]: self._testAll(np.random.rand(4, 3, 3).astype(t), np.random.rand(3).astype(t)) @@ -120,39 +120,44 @@ class BiasAddTest(tf.test.TestCase): bias_tensor = tf.constant(bias, shape=bias.shape, dtype=dtype) output_tensor = tf.nn.bias_add(input_tensor, bias_tensor, data_format=data_format) - err = tf.test.compute_gradient_error(input_tensor, np_input.shape, - output_tensor, np_input.shape) + tensor_jacob_t, tensor_jacob_n = tf.test.compute_gradient( + input_tensor, np_input.shape, output_tensor, np_input.shape) + bias_jacob_t, bias_jacob_n = tf.test.compute_gradient( + bias_tensor, bias.shape, output_tensor, np_input.shape) + + if dtype == np.float16: + # Compare fp16 theoretical gradients to fp32 numerical gradients, + # since fp16 numerical gradients are too imprecise unless great + # care is taken with choosing the inputs and the delta. This is + # a weaker check (in particular, it does not test the op itself, + # only its gradient), but it's much better than nothing. 
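Stepping back to the cuda_kernel_helper.h change: the CudaAtomicAdd(Eigen::half*, Eigen::half) overload above is what lets generic kernels compile for half at all, but as its comment warns, every call spins on a 32-bit compare-and-swap and false-shares with the neighbouring half, so it belongs only at the very end of a reduction. A toy usage sketch, assuming that header (kernel name is illustrative):

```cpp
#include "tensorflow/core/util/cuda_kernel_helper.h"

// Accumulate n half values into *out. Every element goes through the
// CAS-based half atomic, so this is deliberately the slow pattern the
// comment warns about -- real kernels should sum in float first and touch
// the half destination once per block.
__global__ void AccumulateHalfSlow(const Eigen::half* in, Eigen::half* out,
                                   int n) {
  CUDA_1D_KERNEL_LOOP(i, n) {
    tensorflow::CudaAtomicAdd(out, in[i]);
  }
}
```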
+ input_tensor = tf.constant(np_input, shape=np_input.shape, + dtype=np.float32) + bias_tensor = tf.constant(bias, shape=bias.shape, dtype=np.float32) + output_tensor = tf.nn.bias_add(input_tensor, bias_tensor, + data_format=data_format) + _, tensor_jacob_n = tf.test.compute_gradient( + input_tensor, np_input.shape, output_tensor, np_input.shape) + _, bias_jacob_n = tf.test.compute_gradient( + bias_tensor, bias.shape, output_tensor, np_input.shape) + threshold = 2e-3 if dtype == tf.float64: threshold = 1e-10 - print("bias add tensor gradient err = ", err) - self.assertLess(err, threshold) - err = tf.test.compute_gradient_error(bias_tensor, bias.shape, - output_tensor, np_input.shape) - print("bias-add bias gradient err = ", err) - self.assertLess(err, threshold) + self.assertAllClose(tensor_jacob_t, tensor_jacob_n, threshold, threshold) + self.assertAllClose(bias_jacob_t, bias_jacob_n, threshold, threshold) def testGradientTensor(self): for (data_format, use_gpu) in GetTestConfigs(): - for dtype in (tf.float32, tf.float64): + for dtype in (tf.float16, tf.float32, tf.float64): np_input = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], dtype=dtype.as_numpy_dtype).reshape(3, 2) bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype) self._testGradient(np_input, bias, dtype, data_format, use_gpu) - def testGradientBias(self): - with self.test_session(): - t = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2], - dtype=tf.float64) - b = tf.constant([1.3, 2.4], dtype=tf.float64) - bo = tf.nn.bias_add(t, b) - err = tf.test.compute_gradient_error(b, [2], bo, [3, 2]) - print("bias add bias gradient err = ", err) - self.assertLess(err, 1e-10) - def testGradientTensor4D(self): for (data_format, use_gpu) in GetTestConfigs(): - for dtype in (tf.float32, tf.float64): + for dtype in (tf.float16, tf.float32, tf.float64): np_input = np.arange(1.0, 49.0, dtype=dtype.as_numpy_dtype).reshape( [2, 3, 4, 2]).astype(np.float32) bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype) diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc index 66cfccdd90..7f992c6073 100644 --- a/tensorflow/stream_executor/cuda/cuda_driver.cc +++ b/tensorflow/stream_executor/cuda/cuda_driver.cc @@ -118,6 +118,7 @@ PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemHostUnregister); PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD32_v2); PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD32Async); PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD8_v2); +PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD8Async); PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleGetFunction); PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleGetGlobal_v2); PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleLoadDataEx); @@ -800,6 +801,22 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) { return true; } +/* static */ bool CUDADriver::AsynchronousMemsetUint8(CUcontext context, + CUdeviceptr location, + uint8 value, + size_t uint32_count, + CUstream stream) { + ScopedActivateContext activation{context}; + CUresult res = + dynload::cuMemsetD8Async(location, value, uint32_count, stream); + if (res != CUDA_SUCCESS) { + LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res); + return false; + } + VLOG(2) << "successfully enqueued async memset operation"; + return true; +} + /* static */ bool CUDADriver::AsynchronousMemsetUint32(CUcontext context, CUdeviceptr location, uint32 value, diff --git a/tensorflow/stream_executor/cuda/cuda_driver.h b/tensorflow/stream_executor/cuda/cuda_driver.h index 645fb66ada..b887d048a4 100644 --- 
a/tensorflow/stream_executor/cuda/cuda_driver.h +++ b/tensorflow/stream_executor/cuda/cuda_driver.h @@ -233,6 +233,13 @@ class CUDADriver { uint32 value, size_t uint32_count); // Performs an asynchronous memset of the device memory segment via + // cuMemsetD8Async. + // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gaef08a7ccd61112f94e82f2b30d43627 + static bool AsynchronousMemsetUint8(CUcontext context, CUdeviceptr location, + uint8 value, size_t uint32_count, + CUstream stream); + + // Performs an asynchronous memset of the device memory segment via // cuMemsetD32Async. // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g58229da5d30f1c0cdf667b320ec2c0f5 static bool AsynchronousMemsetUint32(CUcontext context, CUdeviceptr location, diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc index f4b31ad304..9757ef640d 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc @@ -533,7 +533,22 @@ bool CUDAExecutor::SynchronousMemcpyDeviceToDevice( bool CUDAExecutor::MemZero(Stream *stream, DeviceMemoryBase *location, uint64 size) { - return Memset32(stream, location, 0x0, size); + if (reinterpret_cast<uintptr_t>(location->opaque()) % 4 == 0 && + size % 4 == 0) { + return Memset32(stream, location, 0x0, size); + } else { + return Memset(stream, location, 0x0, size); + } +} + +bool CUDAExecutor::Memset(Stream *stream, DeviceMemoryBase *location, + uint8 pattern, uint64 size) { + VLOG(2) << "enqueueing memset8 operation onto stream " << stream + << " at location " << location << " with size " << size + << " and pattern " << std::hex << pattern; + return CUDADriver::AsynchronousMemsetUint8( + context_, AsCudaDevicePtr(location), pattern, size, + AsCUDAStreamValue(stream)); } bool CUDAExecutor::Memset32(Stream *stream, DeviceMemoryBase *location, diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h index 2a0c6dc456..ccbe6f26fd 100644 --- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h +++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h @@ -120,6 +120,8 @@ class CUDAExecutor : public internal::StreamExecutorInterface { bool MemZero(Stream *stream, DeviceMemoryBase *location, uint64 size) override; + bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern, + uint64 size) override; bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern, uint64 size) override; diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index ca52949657..b800e03ae7 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -1231,7 +1231,6 @@ class Stream { // Entrain onto the stream: a memset of zero at a GPU location of size // bytes. // The location must not be null. - // TODO(leary) Presently the size must be a 4-byte multiple. 
Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size); // Entrain onto the stream: a memset of a 32-bit pattern at a GPU location diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h index dff756c8fc..8ff9532a55 100644 --- a/tensorflow/stream_executor/stream_executor_internal.h +++ b/tensorflow/stream_executor/stream_executor_internal.h @@ -209,6 +209,8 @@ class StreamExecutorInterface { uint64 size) = 0; virtual bool MemZero(Stream *stream, DeviceMemoryBase *location, uint64 size) = 0; + virtual bool Memset(Stream *stream, DeviceMemoryBase *location, + uint8 pattern, uint64 size) = 0; virtual bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern, uint64 size) = 0; virtual bool Memcpy(Stream *stream, void *host_dst, |
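Finally, the StreamExecutor additions above (the Memset entry point, AsynchronousMemsetUint8 via cuMemsetD8Async, and the alignment check in MemZero) are what allow bias_op.cc to switch from ThenMemset32 to ThenMemZero: a half output buffer whose byte size is not a multiple of 4 can now be zeroed on-stream. A small caller-side sketch with a hypothetical helper name (ZeroDeviceBuffer is not part of the API):

```cpp
#include "tensorflow/stream_executor/stream.h"

namespace gpu = perftools::gputools;

// Zero num_bytes bytes of device memory on the given stream. Before this
// change ThenMemZero always forwarded to the 32-bit memset, so num_bytes
// had to be a 4-byte multiple; now the CUDA executor falls back to the
// byte-wise cuMemsetD8Async path when the pointer or size is unaligned
// (e.g. an Eigen::half buffer of 7 elements = 14 bytes).
void ZeroDeviceBuffer(gpu::Stream* stream, void* device_ptr,
                      size_t num_bytes) {
  gpu::DeviceMemoryBase buf(device_ptr, num_bytes);
  stream->ThenMemZero(&buf, num_bytes);
}
```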