 tensorflow/core/kernels/bias_op.cc                    |  24
 tensorflow/core/kernels/bias_op_gpu.cu.cc             |  38
 tensorflow/core/kernels/cwise_ops_test.cc             | 119
 tensorflow/core/platform/default/test_benchmark.cc    |  38
 tensorflow/core/platform/test_benchmark.h             |   8
 tensorflow/core/util/cuda_kernel_helper.h             |  73
 tensorflow/python/kernel_tests/bias_op_test.py        |  53
 tensorflow/stream_executor/cuda/cuda_driver.cc        |  17
 tensorflow/stream_executor/cuda/cuda_driver.h         |   7
 tensorflow/stream_executor/cuda/cuda_gpu_executor.cc  |  17
 tensorflow/stream_executor/cuda/cuda_gpu_executor.h   |   2
 tensorflow/stream_executor/stream.h                   |   1
 tensorflow/stream_executor/stream_executor_internal.h |   2
 13 files changed, 300 insertions(+), 99 deletions(-)
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index ba31221dad..c31cdeb712 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -149,6 +149,18 @@ void GetBiasValueDims(const Tensor& value_tensor, TensorFormat data_format,
}
}
+template <class T>
+struct AccumulatorType {
+ typedef T type;
+};
+
+// float is faster on the CPU than half, and also more precise,
+// so use float for the temporary accumulators.
+template <>
+struct AccumulatorType<Eigen::half> {
+ typedef float type;
+};
+
} // namespace
template <typename Device, typename T>
@@ -197,7 +209,11 @@ class BiasGradOp<CPUDevice, T> : public OpKernel {
Eigen::array<int, 1> reduction_axis = {0};
#endif
output->template flat<T>().device(context->eigen_device<CPUDevice>()) =
- output_backprop.flat<T>().reshape(two_dims).sum(reduction_axis);
+ output_backprop.flat<T>()
+ .template cast<typename AccumulatorType<T>::type>()
+ .reshape(two_dims)
+ .sum(reduction_axis)
+ .template cast<T>();
}
private:
@@ -268,7 +284,7 @@ class BiasOp<GPUDevice, T> : public BinaryOp<T> {
Name("BiasAddV1").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
BiasOp<GPUDevice, type>);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_GPU_KERNEL);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
template <typename T>
@@ -302,7 +318,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
perftools::gputools::DeviceMemoryBase output_ptr(
output->flat<T>().data(), output->NumElements() * sizeof(T));
- stream->ThenMemset32(&output_ptr, 0, output->NumElements() * sizeof(T));
+ stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T));
BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
output_backprop.template flat<T>().data(),
output->flat<T>().data(), batch, width, height,
@@ -319,7 +335,7 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
Name("BiasAddGrad").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
BiasGradOp<GPUDevice, type>);
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_GPU_KERNEL);
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNEL);
#undef REGISTER_GPU_KERNEL
#endif // GOOGLE_CUDA
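
Note: the CPU-side change above is the standard "wide accumulator" idiom: cast the half inputs to float, reduce, then cast the result back. Below is a minimal sketch of the same pattern outside the op, reusing only the AccumulatorType trait introduced in this hunk; the function and tensor names are illustrative, not part of the patch.

// Sketch: reduce Eigen::half data with float accumulators, mirroring
// BiasGradOp<CPUDevice, T> above. Illustrative only.
#include <unsupported/Eigen/CXX11/Tensor>

template <class T>
struct AccumulatorType { typedef T type; };

// float is both faster and more precise than half for intermediate sums.
template <>
struct AccumulatorType<Eigen::half> { typedef float type; };

template <typename T>
Eigen::Tensor<T, 1> ColumnSums(const Eigen::Tensor<T, 2>& m) {
  typedef typename AccumulatorType<T>::type AccT;
  Eigen::array<int, 1> reduction_axis = {0};
  Eigen::Tensor<T, 1> result(m.dimension(1));
  // Cast up, reduce in the wide type, cast the result back down.
  result = m.template cast<AccT>().sum(reduction_axis).template cast<T>();
  return result;
}
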
diff --git a/tensorflow/core/kernels/bias_op_gpu.cu.cc b/tensorflow/core/kernels/bias_op_gpu.cu.cc
index 364555f141..48e56acaab 100644
--- a/tensorflow/core/kernels/bias_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/bias_op_gpu.cu.cc
@@ -28,6 +28,20 @@ namespace tensorflow {
typedef Eigen::GpuDevice GPUDevice;
+// There are no native fp16 atomics (we simulate them using 32-bit atomics),
+// so fp16 sums are done in fp32 internally. (We don't have a lot of shared
+// memory traffic; BiasGradNCHW_SharedAtomics in particular works almost
+// entirely on a local variable.)
+template <class T>
+struct AccumulatorType {
+ typedef T type;
+};
+
+template <>
+struct AccumulatorType<Eigen::half> {
+ typedef float type;
+};
+
// Definition of the GPU implementations declared in bias_op.cc.
template <typename T>
@@ -102,21 +116,22 @@ template <typename T>
__global__ void BiasGradNHWC_SharedAtomics(int32 nthreads,
const T* output_backprop,
T* bias_backprop, int32 bias_size) {
- T* s_data = reinterpret_cast<T*>(s_buf);
+ typedef typename AccumulatorType<T>::type AccT;
+ AccT* s_data = reinterpret_cast<AccT*>(s_buf);
for (int32 index = threadIdx.x; index < bias_size; index += blockDim.x) {
- s_data[index] = T(0);
+ s_data[index] = AccT(0);
}
__syncthreads();
for (int32 index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
index += blockDim.x * gridDim.x) {
int32 bias_offset = index % bias_size;
- CudaAtomicAdd(s_data + bias_offset, ldg(output_backprop + index));
+ CudaAtomicAdd(s_data + bias_offset, AccT(ldg(output_backprop + index)));
}
__syncthreads();
for (int32 index = threadIdx.x; index < bias_size; index += blockDim.x) {
- CudaAtomicAdd(bias_backprop + index, s_data[index]);
+ CudaAtomicAdd(bias_backprop + index, T(s_data[index]));
}
}
@@ -126,10 +141,11 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
int32 bias_size, int32 image_size,
int group_size) {
// Initialize the shared memory.
- __shared__ T s_data[32];
+ typedef typename AccumulatorType<T>::type AccT;
+ __shared__ AccT s_data[32];
int32 s_data_size = sizeof(s_data) / sizeof(T);
for (int32 index = threadIdx.x; index < s_data_size; index += blockDim.x) {
- s_data[index] = 0;
+ s_data[index] = AccT(0);
}
__syncthreads();
@@ -138,14 +154,14 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
int32 bias_index = blockIdx.x % bias_size;
int32 group_index = blockIdx.x / bias_size;
int32 total_count = batch * image_size;
- T sum = 0;
+ AccT sum(0);
for (int32 index = group_index * blockDim.x + threadIdx.x;
index < total_count; index += blockDim.x * group_size) {
int32 image_offset = index % image_size;
int32 batch = index / image_size;
T val = ldg(output_backprop +
(batch * bias_size + bias_index) * image_size + image_offset);
- sum += val;
+ sum += AccT(val);
}
// Write the accumulated sum in this thread to the shared memory. Each thread
@@ -165,7 +181,7 @@ __global__ void BiasGradNCHW_SharedAtomics(const T* output_backprop,
// The first thread writes out the accumulated result to the global location.
if (thread_index == 0) {
- CudaAtomicAdd(bias_backprop + bias_index, s_data[0]);
+ CudaAtomicAdd(bias_backprop + bias_index, T(s_data[0]));
}
}
@@ -186,7 +202,7 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
const int max_shared_memory_size = d.sharedMemPerBlock() / 2;
int32 shared_memory_size = 0;
if (data_format == FORMAT_NHWC) {
- shared_memory_size = bias_size * sizeof(T);
+ shared_memory_size = bias_size * sizeof(typename AccumulatorType<T>::type);
}
// Check if we have enough shared memory.
if (shared_memory_size <= max_shared_memory_size) {
@@ -227,7 +243,7 @@ void BiasGradGPU<T>::compute(const GPUDevice& d, const T* output_backprop,
template struct BiasGPU<T>; \
template struct BiasGradGPU<T>;
-TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DEFINE_GPU_SPECS);
+TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
} // end namespace tensorflow
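
Note: the GPU kernels above follow the same split. Per-block partial sums live in float shared memory, and only the final per-column update goes through a half atomic (the Eigen::half CudaAtomicAdd overload added in cuda_kernel_helper.h below). A stripped-down sketch of that structure, with illustrative kernel, trait, and variable names rather than the patch's own:

// Sketch: shared-memory column sums in float, half-precision output.
// Assumes the CudaAtomicAdd overloads from cuda_kernel_helper.h
// (including the Eigen::half one added below); names are illustrative.
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace tensorflow {

template <class T> struct WideAccumulator { typedef T type; };
template <> struct WideAccumulator<Eigen::half> { typedef float type; };

template <typename T>
__global__ void ColumnSumSharedAtomics(int n, int num_cols, const T* in,
                                       T* out /* pre-zeroed */) {
  typedef typename WideAccumulator<T>::type AccT;
  extern __shared__ char s_buf[];
  AccT* s_data = reinterpret_cast<AccT*>(s_buf);

  // Zero this block's accumulators (one AccT per column).
  for (int c = threadIdx.x; c < num_cols; c += blockDim.x) s_data[c] = AccT(0);
  __syncthreads();

  // Accumulate in the wide type; fp16 never hits a shared-memory atomic.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    CudaAtomicAdd(s_data + (i % num_cols), AccT(in[i]));
  }
  __syncthreads();

  // One (possibly fp16) global atomic per column per block.
  for (int c = threadIdx.x; c < num_cols; c += blockDim.x) {
    CudaAtomicAdd(out + c, T(s_data[c]));
  }
}

}  // namespace tensorflow

// Launch with dynamic shared memory sized by the *accumulator* type,
// exactly as BiasGradGPU<T>::compute does above, e.g.:
//   ColumnSumSharedAtomics<Eigen::half>
//       <<<blocks, threads, num_cols * sizeof(float)>>>(n, num_cols, in, out);
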
diff --git a/tensorflow/core/kernels/cwise_ops_test.cc b/tensorflow/core/kernels/cwise_ops_test.cc
index 40f0d04329..5b205175c1 100644
--- a/tensorflow/core/kernels/cwise_ops_test.cc
+++ b/tensorflow/core/kernels/cwise_ops_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
+#include "tensorflow/core/util/tensor_format.h"
namespace tensorflow {
@@ -82,73 +83,97 @@ BM_BINARY_SCALAR(cpu, Add);
BM_BINARY_SCALAR(gpu, Add);
#undef BM_BINARY_SCALAR
-static Graph* BiasAdd(int rows, int cols) {
+template <class T>
+static Graph* BiasAdd(int rows, int cols, DataType type) {
Graph* g = new Graph(OpRegistry::Global());
- Tensor lhs(DT_FLOAT, TensorShape({rows, cols}));
- lhs.flat<float>().setRandom();
+ Tensor lhs(type, TensorShape({rows, cols}));
+ lhs.template flat<T>().setRandom();
TensorShape rhs_shape;
rhs_shape = TensorShape({cols});
- Tensor rhs(DT_FLOAT, rhs_shape);
- rhs.flat<float>().setRandom();
+ Tensor rhs(type, rhs_shape);
+ rhs.template flat<T>().setRandom();
test::graph::Binary(g, "BiasAdd", test::graph::Constant(g, lhs),
test::graph::Constant(g, rhs));
return g;
}
-#define BM_BIAS_ADD(DEVICE, R, C) \
- static void BM_##DEVICE##_BiasAdd_R##R##_C##C(int iters, int arg) { \
- const int rows = RowsFromArg(arg); \
- const int cols = ColsFromArg(arg); \
- const int64 tot = static_cast<int64>(iters) * rows * cols; \
- testing::ItemsProcessed(tot); \
- testing::BytesProcessed(tot * sizeof(float)); \
- test::Benchmark(#DEVICE, BiasAdd(rows, cols)).Run(iters); \
- } \
- BENCHMARK(BM_##DEVICE##_BiasAdd_R##R##_C##C)->Arg(RowsAndColsArg(R, C));
-
-#define BM_BIAS_ADD_ALL(DEVICE) \
- BM_BIAS_ADD(DEVICE, 512, 2048); \
- BM_BIAS_ADD(DEVICE, 512, 4096); \
- BM_BIAS_ADD(DEVICE, 2048, 512); \
- BM_BIAS_ADD(DEVICE, 4096, 512);
-
-BM_BIAS_ADD_ALL(cpu);
-BM_BIAS_ADD_ALL(gpu);
+#define BM_BIAS_ADD(DEVICE, C_TYPE, TF_TYPE, R, C) \
+ static void BM_##DEVICE##_##C_TYPE##_BiasAdd_R##R##_C##C(int iters, \
+ int arg) { \
+ const int rows = RowsFromArg(arg); \
+ const int cols = ColsFromArg(arg); \
+ const int64 tot = static_cast<int64>(iters) * rows * cols; \
+ testing::ItemsProcessed(tot); \
+ testing::BytesProcessed(tot * sizeof(C_TYPE)); \
+ test::Benchmark(#DEVICE, BiasAdd<C_TYPE>(rows, cols, TF_TYPE)).Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_##C_TYPE##_BiasAdd_R##R##_C##C) \
+ ->Arg(RowsAndColsArg(R, C));
+
+#define BM_BIAS_ADD_ALL(DEVICE, C_TYPE, TF_TYPE) \
+ BM_BIAS_ADD(DEVICE, C_TYPE, TF_TYPE, 512, 2048); \
+ BM_BIAS_ADD(DEVICE, C_TYPE, TF_TYPE, 512, 4096); \
+ BM_BIAS_ADD(DEVICE, C_TYPE, TF_TYPE, 2048, 512); \
+ BM_BIAS_ADD(DEVICE, C_TYPE, TF_TYPE, 4096, 512);
+
+using Eigen::half;
+BM_BIAS_ADD_ALL(cpu, float, DT_FLOAT);
+BM_BIAS_ADD_ALL(gpu, float, DT_FLOAT);
+BM_BIAS_ADD_ALL(cpu, half, DT_HALF);
+BM_BIAS_ADD_ALL(gpu, half, DT_HALF);
#undef BM_BIAS_ADD_ALL
#undef BM_BIAS_ADD
-static Graph* BiasAddGrad(int rows, int cols) {
+template <class T>
+static Graph* BiasAddGrad(int rows, int cols, int channels, DataType type,
+ TensorFormat format) {
Graph* g = new Graph(OpRegistry::Global());
TensorShape lhs_shape;
- lhs_shape = TensorShape({rows, cols});
- Tensor lhs(DT_FLOAT, lhs_shape);
- lhs.template flat<float>().setRandom();
+ if (format == FORMAT_NCHW) {
+ lhs_shape = TensorShape({channels, rows, cols});
+ } else {
+ lhs_shape = TensorShape({rows, cols, channels});
+ }
+ Tensor lhs(type, lhs_shape);
+ lhs.template flat<T>().setRandom();
Node* n;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "BiasAddGrad")
+ .Attr("data_format", ToString(format))
.Input(test::graph::Constant(g, lhs), /*index=*/0)
.Finalize(g, &n));
return g;
}
-#define BM_BIAS_ADD_GRAD(DEVICE, R, C) \
- static void BM_##DEVICE##_BiasAddGrad_R##R##_C##C(int iters, int arg) { \
- const int rows = RowsFromArg(arg); \
- const int cols = ColsFromArg(arg); \
- const int64 tot = static_cast<int64>(iters) * rows * cols; \
- testing::ItemsProcessed(tot); \
- testing::BytesProcessed(tot * sizeof(float)); \
- test::Benchmark(#DEVICE, BiasAddGrad(rows, cols)).Run(iters); \
- } \
- BENCHMARK(BM_##DEVICE##_BiasAddGrad_R##R##_C##C)->Arg(RowsAndColsArg(R, C));
-
-#define BM_BIAS_ADD_GRAD_ALL(DEVICE) \
- BM_BIAS_ADD_GRAD(DEVICE, 512, 2048); \
- BM_BIAS_ADD_GRAD(DEVICE, 512, 4096); \
- BM_BIAS_ADD_GRAD(DEVICE, 2048, 512); \
- BM_BIAS_ADD_GRAD(DEVICE, 4096, 512);
-
-BM_BIAS_ADD_GRAD_ALL(cpu);
-BM_BIAS_ADD_GRAD_ALL(gpu);
+#define BM_BIAS_ADD_GRAD(DEVICE, FMT, C_TYPE, TF_TYPE, R, C, CH) \
+ static void \
+ BM_##DEVICE##_##FMT##_##C_TYPE##_BiasAddGrad_R##R##_C##C##_CH##CH( \
+ int iters, int arg, int channels) { \
+ const int rows = RowsFromArg(arg); \
+ const int cols = ColsFromArg(arg); \
+ const int64 tot = static_cast<int64>(iters) * rows * cols * channels; \
+ testing::ItemsProcessed(tot); \
+ testing::BytesProcessed(tot * sizeof(C_TYPE)); \
+ test::Benchmark(#DEVICE, BiasAddGrad<C_TYPE>(rows, cols, channels, \
+ TF_TYPE, FORMAT_##FMT)) \
+ .Run(iters); \
+ } \
+ BENCHMARK(BM_##DEVICE##_##FMT##_##C_TYPE##_BiasAddGrad_R##R##_C##C##_CH##CH) \
+ ->ArgPair(RowsAndColsArg(R, C), CH);
+
+#define BM_BIAS_ADD_GRAD_ALL(DEVICE, FORMAT, C_TYPE, TF_TYPE) \
+ BM_BIAS_ADD_GRAD(DEVICE, FORMAT, C_TYPE, TF_TYPE, 64, 64, 64); \
+ BM_BIAS_ADD_GRAD(DEVICE, FORMAT, C_TYPE, TF_TYPE, 512, 512, 4); \
+ BM_BIAS_ADD_GRAD(DEVICE, FORMAT, C_TYPE, TF_TYPE, 512, 512, 1); \
+ BM_BIAS_ADD_GRAD(DEVICE, FORMAT, C_TYPE, TF_TYPE, 4096, 4096, 4); \
+ BM_BIAS_ADD_GRAD(DEVICE, FORMAT, C_TYPE, TF_TYPE, 4096, 4096, 1);
+
+using Eigen::half;
+BM_BIAS_ADD_GRAD_ALL(gpu, NCHW, float, DT_FLOAT);
+BM_BIAS_ADD_GRAD_ALL(gpu, NCHW, half, DT_HALF);
+BM_BIAS_ADD_GRAD_ALL(cpu, NHWC, float, DT_FLOAT);
+BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, float, DT_FLOAT);
+BM_BIAS_ADD_GRAD_ALL(cpu, NHWC, half, DT_HALF);
+BM_BIAS_ADD_GRAD_ALL(gpu, NHWC, half, DT_HALF);
#undef BM_BIAS_ADD_GRAD_ALL
#undef BM_BIAS_ADD_GRAD
diff --git a/tensorflow/core/platform/default/test_benchmark.cc b/tensorflow/core/platform/default/test_benchmark.cc
index 15c01b5233..ffdbb0b761 100644
--- a/tensorflow/core/platform/default/test_benchmark.cc
+++ b/tensorflow/core/platform/default/test_benchmark.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <cstdio>
#include <cstdlib>
+#include <algorithm>
#include <vector>
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
@@ -38,7 +39,7 @@ static Env* env;
Benchmark::Benchmark(const char* name, void (*fn)(int))
: name_(name), num_args_(0), fn0_(fn) {
- args_.push_back(-1);
+ args_.push_back(std::make_pair(-1, -1));
Register();
}
@@ -47,9 +48,20 @@ Benchmark::Benchmark(const char* name, void (*fn)(int, int))
Register();
}
+Benchmark::Benchmark(const char* name, void (*fn)(int, int, int))
+ : name_(name), num_args_(2), fn2_(fn) {
+ Register();
+}
+
Benchmark* Benchmark::Arg(int x) {
CHECK_EQ(num_args_, 1);
- args_.push_back(x);
+ args_.push_back(std::make_pair(x, -1));
+ return this;
+}
+
+Benchmark* Benchmark::ArgPair(int x, int y) {
+ CHECK_EQ(num_args_, 2);
+ args_.push_back(std::make_pair(x, y));
return this;
}
@@ -76,8 +88,11 @@ void Benchmark::Run(const char* pattern) {
name = b->name_;
for (auto arg : b->args_) {
name.resize(b->name_.size());
- if (arg >= 0) {
- strings::StrAppend(&name, "/", arg);
+ if (arg.first >= 0) {
+ strings::StrAppend(&name, "/", arg.first);
+ if (arg.second >= 0) {
+ strings::StrAppend(&name, "/", arg.second);
+ }
}
if (RE2::PartialMatch(name, pattern)) {
width = std::max<int>(width, name.size());
@@ -91,8 +106,11 @@ void Benchmark::Run(const char* pattern) {
name = b->name_;
for (auto arg : b->args_) {
name.resize(b->name_.size());
- if (arg >= 0) {
- strings::StrAppend(&name, "/", arg);
+ if (arg.first >= 0) {
+ strings::StrAppend(&name, "/", arg.first);
+ if (arg.second >= 0) {
+ strings::StrAppend(&name, "/", arg.second);
+ }
}
if (!RE2::PartialMatch(name, pattern)) {
continue;
@@ -100,7 +118,7 @@ void Benchmark::Run(const char* pattern) {
int iters;
double seconds;
- b->Run(arg, &iters, &seconds);
+ b->Run(arg.first, arg.second, &iters, &seconds);
char buf[100];
std::string full_label = label;
@@ -143,7 +161,7 @@ void Benchmark::Register() {
all_benchmarks->push_back(this);
}
-void Benchmark::Run(int arg, int* run_count, double* run_seconds) {
+void Benchmark::Run(int arg1, int arg2, int* run_count, double* run_seconds) {
env = Env::Default();
static const int64 kMinIters = 100;
static const int64 kMaxIters = 1000000000;
@@ -157,8 +175,10 @@ void Benchmark::Run(int arg, int* run_count, double* run_seconds) {
label.clear();
if (fn0_) {
(*fn0_)(iters);
+ } else if (fn1_) {
+ (*fn1_)(iters, arg1);
} else {
- (*fn1_)(iters, arg);
+ (*fn2_)(iters, arg1, arg2);
}
StopTiming();
const double seconds = accum_time * 1e-6;
diff --git a/tensorflow/core/platform/test_benchmark.h b/tensorflow/core/platform/test_benchmark.h
index 4c90cf7f85..3536548f00 100644
--- a/tensorflow/core/platform/test_benchmark.h
+++ b/tensorflow/core/platform/test_benchmark.h
@@ -17,6 +17,7 @@ limitations under the License.
#ifndef TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_
#define TENSORFLOW_PLATFORM_TEST_BENCHMARK_H_
+#include <utility>
#include <vector>
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/platform.h"
@@ -48,20 +49,23 @@ class Benchmark {
public:
Benchmark(const char* name, void (*fn)(int));
Benchmark(const char* name, void (*fn)(int, int));
+ Benchmark(const char* name, void (*fn)(int, int, int));
Benchmark* Arg(int x);
+ Benchmark* ArgPair(int x, int y);
Benchmark* Range(int lo, int hi);
static void Run(const char* pattern);
private:
string name_;
int num_args_;
- std::vector<int> args_;
+ std::vector<std::pair<int, int>> args_;
void (*fn0_)(int) = nullptr;
void (*fn1_)(int, int) = nullptr;
+ void (*fn2_)(int, int, int) = nullptr;
void Register();
- void Run(int arg, int* run_count, double* run_seconds);
+ void Run(int arg1, int arg2, int* run_count, double* run_seconds);
};
#endif
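
Note: the new three-argument constructor and ArgPair() above are what let the BiasAddGrad benchmarks pass a channel count alongside the RowsAndColsArg-encoded shape. A minimal usage sketch, assuming TensorFlow's BENCHMARK macro and the testing helpers from this header; BM_TwoArgExample is a made-up benchmark, not part of the patch.

// Sketch: a benchmark taking two arguments via the new ArgPair() API.
// Assumes this lives inside namespace tensorflow so that testing::
// resolves to tensorflow::testing; BM_TwoArgExample is illustrative only.
static void BM_TwoArgExample(int iters, int rows, int channels) {
  testing::ItemsProcessed(static_cast<int64>(iters) * rows * channels);
  for (int i = 0; i < iters; ++i) {
    // ... per-iteration work proportional to rows * channels ...
  }
}
// Each ArgPair() call registers one (arg1, arg2) combination; Run() prints
// the name as BM_TwoArgExample/512/4, BM_TwoArgExample/4096/1, etc.
BENCHMARK(BM_TwoArgExample)->ArgPair(512, 4)->ArgPair(4096, 1);
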
diff --git a/tensorflow/core/util/cuda_kernel_helper.h b/tensorflow/core/util/cuda_kernel_helper.h
index a86567a7cc..86c55031c6 100644
--- a/tensorflow/core/util/cuda_kernel_helper.h
+++ b/tensorflow/core/util/cuda_kernel_helper.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <algorithm>
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/platform/types.h"
#define CUDA_1D_KERNEL_LOOP(i, n) \
@@ -104,6 +105,78 @@ CUDA_ATOMIC_WRAPPER(Add, double) {
return __longlong_as_double(old);
}
+// Helper functions for CudaAtomicAdd(half*, half), below.
+//
+// Note that if __CUDA_ARCH__ >= 530, we could probably use __hadd2()
+// for a more efficient implementation, assuming that adding -0.0
+// will never harm the neighboring value. In this version, we take special
+// care to guarantee the bits of the untouched value are unchanged.
+inline __device__ uint32 add_to_low_half(uint32 val, float x) {
+ Eigen::half low_half;
+ low_half.x = static_cast<uint16>(val & 0xffffu);
+ low_half = static_cast<Eigen::half>(static_cast<float>(low_half) + x);
+ return (val & 0xffff0000u) | low_half.x;
+}
+
+inline __device__ uint32 add_to_high_half(uint32 val, float x) {
+ Eigen::half high_half;
+ high_half.x = static_cast<uint16>(val >> 16);
+ high_half = static_cast<Eigen::half>(static_cast<float>(high_half) + x);
+ return (val & 0xffffu) | (high_half.x << 16);
+}
+
+// Custom implementation of atomicAdd for half. Note that we don't have
+// atomicCAS() for anything less than 32 bits, so we need to include the
+// other 16 bits in the operation.
+//
+// Unlike the other atomic adds, this version is going to be very slow
+// under high concurrency, since most threads will be spinning on failing
+// their compare-and-swap tests. (The fact that we get false sharing on the
+// neighboring fp16 makes this even worse.) If you are doing a large reduction,
+// you are much better off with doing the intermediate steps in fp32 and then
+// switching to fp16 as late as you can in the calculations.
+//
+// Note: Assumes little endian.
+CUDA_ATOMIC_WRAPPER(Add, Eigen::half) {
+ float val_as_float(val);
+ intptr_t address_int = reinterpret_cast<intptr_t>(address);
+ if ((address_int & 0x2) == 0) {
+ // The half is in the first part of the uint32 (lower 16 bits).
+ uint32* address_as_uint32 = reinterpret_cast<uint32*>(address);
+ assert(((intptr_t)address_as_uint32 & 0x3) == 0);
+ uint32 old = *address_as_uint32, assumed;
+
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_uint32, assumed,
+ add_to_low_half(assumed, val_as_float));
+
+ // Note: uses integer comparison to avoid hang in case of NaN
+ } while (assumed != old);
+
+ Eigen::half ret;
+ ret.x = old & 0xffffu;
+ return ret;
+ } else {
+ // The half is in the second part of the uint32 (upper 16 bits).
+ uint32* address_as_uint32 = reinterpret_cast<uint32*>(address_int - 2);
+ assert(((intptr_t)address_as_uint32 & 0x3) == 0);
+ uint32 old = *address_as_uint32, assumed;
+
+ do {
+ assumed = old;
+ old = atomicCAS(address_as_uint32, assumed,
+ add_to_high_half(assumed, val_as_float));
+
+ // Note: uses integer comparison to avoid hang in case of NaN
+ } while (assumed != old);
+
+ Eigen::half ret;
+ ret.x = old >> 16;
+ return ret;
+ }
+}
+
template <typename T>
__global__ void SetZero(const int nthreads, T* bottom_diff) {
CUDA_1D_KERNEL_LOOP(index, nthreads) { *(bottom_diff + index) = T(0); }
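
Note: the Eigen::half CudaAtomicAdd above is the usual recipe for a sub-word atomic: load the enclosing aligned 32-bit word, splice the updated 16 bits into it, and retry with atomicCAS until no other thread has modified the word in between. The same structure written for a plain 16-bit integer, so the splice logic stands alone; AtomicAddU16 is a hypothetical helper and, like the code above, assumes little-endian layout.

// Sketch: emulate a 16-bit atomic add with a 32-bit atomicCAS loop.
// AtomicAddU16 is illustrative only; it mirrors the structure of the
// Eigen::half CudaAtomicAdd wrapper above.
__device__ unsigned short AtomicAddU16(unsigned short* address,
                                       unsigned short inc) {
  // Locate the aligned 32-bit word containing *address.
  intptr_t addr = reinterpret_cast<intptr_t>(address);
  unsigned int* word =
      reinterpret_cast<unsigned int*>(addr & ~intptr_t(0x3));
  const bool is_low_half = (addr & 0x2) == 0;  // little endian

  unsigned int old = *word, assumed;
  do {
    assumed = old;
    unsigned short cur = static_cast<unsigned short>(
        is_low_half ? (assumed & 0xffffu) : (assumed >> 16));
    unsigned short updated = static_cast<unsigned short>(cur + inc);
    // Splice the new 16 bits back in, leaving the neighbor untouched.
    unsigned int replacement =
        is_low_half ? ((assumed & 0xffff0000u) | updated)
                    : ((assumed & 0xffffu) |
                       (static_cast<unsigned int>(updated) << 16));
    // Retry until no other thread changed the word between read and CAS.
    old = atomicCAS(word, assumed, replacement);
  } while (assumed != old);

  return static_cast<unsigned short>(is_low_half ? (old & 0xffffu)
                                                 : (old >> 16));
}
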
diff --git a/tensorflow/python/kernel_tests/bias_op_test.py b/tensorflow/python/kernel_tests/bias_op_test.py
index 1e75513d48..fadd9c5794 100644
--- a/tensorflow/python/kernel_tests/bias_op_test.py
+++ b/tensorflow/python/kernel_tests/bias_op_test.py
@@ -54,7 +54,7 @@ class BiasAddTest(tf.test.TestCase):
np_val = self._npBias(np_inputs, np_bias)
with self.test_session(use_gpu=use_gpu):
tf_val = tf.nn.bias_add(np_inputs, np_bias).eval()
- self.assertAllClose(np_val, tf_val)
+ self.assertAllCloseAccordingToType(np_val, tf_val)
def _NHWCToNCHW(self, np_value):
# fill the input value to at least 3-dimension
@@ -79,11 +79,11 @@ class BiasAddTest(tf.test.TestCase):
with self.test_session(use_gpu=use_gpu):
tf_val = tf.nn.bias_add(np_inputs, np_bias, data_format="NCHW").eval()
tf_val = self._NCHWToNHWC(tf_val)
- self.assertAllClose(np_val, tf_val)
+ self.assertAllCloseAccordingToType(np_val, tf_val)
def _testAll(self, np_inputs, np_bias):
self._testBias(np_inputs, np_bias, use_gpu=False)
- if np_inputs.dtype == np.float32 or np_inputs.dtype == np.float64:
+ if np_inputs.dtype in [np.float16, np.float32, np.float64]:
self._testBias(np_inputs, np_bias, use_gpu=True)
if tf.test.is_built_with_cuda():
self._testBiasNCHW(np_inputs, np_bias, use_gpu=True)
@@ -108,7 +108,7 @@ class BiasAddTest(tf.test.TestCase):
np.array([1, 2, 3]).astype(t))
def testFloatTypes(self):
- for t in [np.float32, np.float64]:
+ for t in [np.float16, np.float32, np.float64]:
self._testAll(np.random.rand(4, 3, 3).astype(t),
np.random.rand(3).astype(t))
@@ -120,39 +120,44 @@ class BiasAddTest(tf.test.TestCase):
bias_tensor = tf.constant(bias, shape=bias.shape, dtype=dtype)
output_tensor = tf.nn.bias_add(input_tensor, bias_tensor,
data_format=data_format)
- err = tf.test.compute_gradient_error(input_tensor, np_input.shape,
- output_tensor, np_input.shape)
+ tensor_jacob_t, tensor_jacob_n = tf.test.compute_gradient(
+ input_tensor, np_input.shape, output_tensor, np_input.shape)
+ bias_jacob_t, bias_jacob_n = tf.test.compute_gradient(
+ bias_tensor, bias.shape, output_tensor, np_input.shape)
+
+ if dtype == np.float16:
+ # Compare fp16 theoretical gradients to fp32 numerical gradients,
+ # since fp16 numerical gradients are too imprecise unless great
+ # care is taken with choosing the inputs and the delta. This is
+ # a weaker check (in particular, it does not test the op itself,
+ # only its gradient), but it's much better than nothing.
+ input_tensor = tf.constant(np_input, shape=np_input.shape,
+ dtype=np.float32)
+ bias_tensor = tf.constant(bias, shape=bias.shape, dtype=np.float32)
+ output_tensor = tf.nn.bias_add(input_tensor, bias_tensor,
+ data_format=data_format)
+ _, tensor_jacob_n = tf.test.compute_gradient(
+ input_tensor, np_input.shape, output_tensor, np_input.shape)
+ _, bias_jacob_n = tf.test.compute_gradient(
+ bias_tensor, bias.shape, output_tensor, np_input.shape)
+
threshold = 2e-3
if dtype == tf.float64:
threshold = 1e-10
- print("bias add tensor gradient err = ", err)
- self.assertLess(err, threshold)
- err = tf.test.compute_gradient_error(bias_tensor, bias.shape,
- output_tensor, np_input.shape)
- print("bias-add bias gradient err = ", err)
- self.assertLess(err, threshold)
+ self.assertAllClose(tensor_jacob_t, tensor_jacob_n, threshold, threshold)
+ self.assertAllClose(bias_jacob_t, bias_jacob_n, threshold, threshold)
def testGradientTensor(self):
for (data_format, use_gpu) in GetTestConfigs():
- for dtype in (tf.float32, tf.float64):
+ for dtype in (tf.float16, tf.float32, tf.float64):
np_input = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
dtype=dtype.as_numpy_dtype).reshape(3, 2)
bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
self._testGradient(np_input, bias, dtype, data_format, use_gpu)
- def testGradientBias(self):
- with self.test_session():
- t = tf.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[3, 2],
- dtype=tf.float64)
- b = tf.constant([1.3, 2.4], dtype=tf.float64)
- bo = tf.nn.bias_add(t, b)
- err = tf.test.compute_gradient_error(b, [2], bo, [3, 2])
- print("bias add bias gradient err = ", err)
- self.assertLess(err, 1e-10)
-
def testGradientTensor4D(self):
for (data_format, use_gpu) in GetTestConfigs():
- for dtype in (tf.float32, tf.float64):
+ for dtype in (tf.float16, tf.float32, tf.float64):
np_input = np.arange(1.0, 49.0, dtype=dtype.as_numpy_dtype).reshape(
[2, 3, 4, 2]).astype(np.float32)
bias = np.array([1.3, 2.4], dtype=dtype.as_numpy_dtype)
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.cc b/tensorflow/stream_executor/cuda/cuda_driver.cc
index 66cfccdd90..7f992c6073 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.cc
+++ b/tensorflow/stream_executor/cuda/cuda_driver.cc
@@ -118,6 +118,7 @@ PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemHostUnregister);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD32_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD32Async);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD8_v2);
+PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuMemsetD8Async);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleGetFunction);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleGetGlobal_v2);
PERFTOOLS_GPUTOOLS_LIBCUDA_WRAP(cuModuleLoadDataEx);
@@ -800,6 +801,22 @@ CUDADriver::ContextGetSharedMemConfig(CUcontext context) {
return true;
}
+/* static */ bool CUDADriver::AsynchronousMemsetUint8(CUcontext context,
+ CUdeviceptr location,
+ uint8 value,
+ size_t uint32_count,
+ CUstream stream) {
+ ScopedActivateContext activation{context};
+ CUresult res =
+ dynload::cuMemsetD8Async(location, value, uint32_count, stream);
+ if (res != CUDA_SUCCESS) {
+ LOG(ERROR) << "failed to enqueue async memset operation: " << ToString(res);
+ return false;
+ }
+ VLOG(2) << "successfully enqueued async memset operation";
+ return true;
+}
+
/* static */ bool CUDADriver::AsynchronousMemsetUint32(CUcontext context,
CUdeviceptr location,
uint32 value,
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.h b/tensorflow/stream_executor/cuda/cuda_driver.h
index 645fb66ada..b887d048a4 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.h
+++ b/tensorflow/stream_executor/cuda/cuda_driver.h
@@ -233,6 +233,13 @@ class CUDADriver {
uint32 value, size_t uint32_count);
// Performs an asynchronous memset of the device memory segment via
+ // cuMemsetD8Async.
+ // http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gaef08a7ccd61112f94e82f2b30d43627
+ static bool AsynchronousMemsetUint8(CUcontext context, CUdeviceptr location,
+ uint8 value, size_t uint32_count,
+ CUstream stream);
+
+ // Performs an asynchronous memset of the device memory segment via
// cuMemsetD32Async.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g58229da5d30f1c0cdf667b320ec2c0f5
static bool AsynchronousMemsetUint32(CUcontext context, CUdeviceptr location,
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index f4b31ad304..9757ef640d 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -533,7 +533,22 @@ bool CUDAExecutor::SynchronousMemcpyDeviceToDevice(
bool CUDAExecutor::MemZero(Stream *stream, DeviceMemoryBase *location,
uint64 size) {
- return Memset32(stream, location, 0x0, size);
+ if (reinterpret_cast<uintptr_t>(location->opaque()) % 4 == 0 &&
+ size % 4 == 0) {
+ return Memset32(stream, location, 0x0, size);
+ } else {
+ return Memset(stream, location, 0x0, size);
+ }
+}
+
+bool CUDAExecutor::Memset(Stream *stream, DeviceMemoryBase *location,
+ uint8 pattern, uint64 size) {
+ VLOG(2) << "enqueueing memset8 operation onto stream " << stream
+ << " at location " << location << " with size " << size
+ << " and pattern " << std::hex << pattern;
+ return CUDADriver::AsynchronousMemsetUint8(
+ context_, AsCudaDevicePtr(location), pattern, size,
+ AsCUDAStreamValue(stream));
}
bool CUDAExecutor::Memset32(Stream *stream, DeviceMemoryBase *location,
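
Note: the MemZero change above exists because cuMemsetD32Async needs a 4-byte-aligned destination and a count of 32-bit elements, while cuMemsetD8Async can clear any byte range; that is also why the 4-byte-multiple caveat disappears from ThenMemZero in the stream.h hunk below. A host-side sketch of the dispatch, with hypothetical stand-ins for the driver wrappers:

// Sketch: pick the wide memset only when alignment and size allow it.
// MemsetD32Async / MemsetD8Async are hypothetical stand-ins for the
// cuMemsetD32Async / cuMemsetD8Async driver calls, declared here only
// so the sketch is self-contained; this is not the real executor code.
#include <cstdint>

bool MemsetD32Async(void* dst, uint32_t value, size_t count_of_uint32);
bool MemsetD8Async(void* dst, uint8_t value, size_t count_of_bytes);

bool ZeroDeviceMemory(void* ptr, uint64_t size_in_bytes) {
  if (reinterpret_cast<uintptr_t>(ptr) % 4 == 0 && size_in_bytes % 4 == 0) {
    // Aligned and a multiple of 4: clear with 32-bit stores.
    return MemsetD32Async(ptr, 0u, size_in_bytes / 4);
  }
  // Otherwise fall back to byte-wise clearing, which has no restrictions.
  return MemsetD8Async(ptr, 0, size_in_bytes);
}
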
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
index 2a0c6dc456..ccbe6f26fd 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
@@ -120,6 +120,8 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
bool MemZero(Stream *stream, DeviceMemoryBase *location,
uint64 size) override;
+ bool Memset(Stream *stream, DeviceMemoryBase *location, uint8 pattern,
+ uint64 size) override;
bool Memset32(Stream *stream, DeviceMemoryBase *location, uint32 pattern,
uint64 size) override;
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index ca52949657..b800e03ae7 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -1231,7 +1231,6 @@ class Stream {
// Entrain onto the stream: a memset of zero at a GPU location of size
// bytes.
// The location must not be null.
- // TODO(leary) Presently the size must be a 4-byte multiple.
Stream &ThenMemZero(DeviceMemoryBase *location, uint64 size);
// Entrain onto the stream: a memset of a 32-bit pattern at a GPU location
diff --git a/tensorflow/stream_executor/stream_executor_internal.h b/tensorflow/stream_executor/stream_executor_internal.h
index dff756c8fc..8ff9532a55 100644
--- a/tensorflow/stream_executor/stream_executor_internal.h
+++ b/tensorflow/stream_executor/stream_executor_internal.h
@@ -209,6 +209,8 @@ class StreamExecutorInterface {
uint64 size) = 0;
virtual bool MemZero(Stream *stream, DeviceMemoryBase *location,
uint64 size) = 0;
+ virtual bool Memset(Stream *stream, DeviceMemoryBase *location,
+ uint8 pattern, uint64 size) = 0;
virtual bool Memset32(Stream *stream, DeviceMemoryBase *location,
uint32 pattern, uint64 size) = 0;
virtual bool Memcpy(Stream *stream, void *host_dst,