diff options
27 files changed, 973 insertions, 140 deletions
diff --git a/eigen.BUILD b/eigen.BUILD index 79bafe65b6..e32f3aab49 100644 --- a/eigen.BUILD +++ b/eigen.BUILD @@ -1,6 +1,6 @@ package(default_visibility = ["//visibility:public"]) -archive_dir = "eigen-eigen-d02e6a705c30" +archive_dir = "eigen-eigen-0c0b79ecd74c" cc_library( name = "eigen", diff --git a/tensorflow/contrib/cmake/external/eigen.cmake b/tensorflow/contrib/cmake/external/eigen.cmake index db409760fa..d3075ab9d2 100644 --- a/tensorflow/contrib/cmake/external/eigen.cmake +++ b/tensorflow/contrib/cmake/external/eigen.cmake @@ -7,7 +7,7 @@ include (ExternalProject) -set(eigen_archive_hash "d02e6a705c30") +set(eigen_archive_hash "0c0b79ecd74c") set(eigen_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR} @@ -16,7 +16,7 @@ set(eigen_INCLUDE_DIRS ${tensorflow_source_dir}/third_party/eigen3 ) set(eigen_URL https://bitbucket.org/eigen/eigen/get/${eigen_archive_hash}.tar.gz) -set(eigen_HASH SHA256=532956172daa8aba87c750791ff89a5c38cdb07e2525afe17ecb4bef812d67cf) +set(eigen_HASH SHA256=b4b5884b03bd4bae114d02b36e2435ad1504ed8e51431d16c876b6f6a365882b) set(eigen_BUILD ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen) set(eigen_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/eigen/install) diff --git a/tensorflow/core/kernels/avgpooling_op.cc b/tensorflow/core/kernels/avgpooling_op.cc index 4378dd2fa4..d666546602 100644 --- a/tensorflow/core/kernels/avgpooling_op.cc +++ b/tensorflow/core/kernels/avgpooling_op.cc @@ -100,10 +100,12 @@ class AvgPoolingOp : public UnaryOp<T> { TensorFormat data_format_; }; -REGISTER_KERNEL_BUILDER(Name("AvgPool") - .Device(DEVICE_CPU) - .TypeConstraint<float>("T"), - AvgPoolingOp<CPUDevice, float>); +REGISTER_KERNEL_BUILDER( + Name("AvgPool").Device(DEVICE_CPU).TypeConstraint<float>("T"), + AvgPoolingOp<CPUDevice, float>); +REGISTER_KERNEL_BUILDER( + Name("AvgPool").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"), + AvgPoolingOp<CPUDevice, Eigen::half>); #if GOOGLE_CUDA template <typename T> @@ -182,14 +184,17 @@ namespace functor { const Eigen::PaddingType& padding); \ extern template struct SpatialAvgPooling<GPUDevice, T>; +DECLARE_GPU_SPEC(Eigen::half); DECLARE_GPU_SPEC(float); #undef DECLARE_GPU_SPEC } // namespace functor -REGISTER_KERNEL_BUILDER(Name("AvgPool") - .Device(DEVICE_GPU) - .TypeConstraint<float>("T"), - AvgPoolingOp<GPUDevice, float>); +REGISTER_KERNEL_BUILDER( + Name("AvgPool").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), + AvgPoolingOp<GPUDevice, Eigen::half>); +REGISTER_KERNEL_BUILDER( + Name("AvgPool").Device(DEVICE_GPU).TypeConstraint<float>("T"), + AvgPoolingOp<GPUDevice, float>); #endif // GOOGLE_CUDA // The operation to compute AvgPool gradients. @@ -301,7 +306,7 @@ class AvgPoolingGradOp : public OpKernel { GetBroadcastSize(c, in_cols, window_cols, col_stride, pad_cols, &cindex, &csize)); - T divide_coeff = 1.0 / (rsize * csize); + T divide_coeff(1.0 / (rsize * csize)); int64 output_index = (b * out_backprop_rows + r) * out_backprop_cols + c; for (int64 r_dst = rindex; r_dst < rindex + rsize; ++r_dst) { @@ -347,6 +352,7 @@ class AvgPoolingGradOp : public OpKernel { TF_CALL_float(REGISTER_CPU_KERNEL); TF_CALL_double(REGISTER_CPU_KERNEL); +TF_CALL_half(REGISTER_CPU_KERNEL); #if GOOGLE_CUDA @@ -416,6 +422,12 @@ REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad") .HostMemory("orig_input_shape") .Label("cudnn"), AvgPoolingGradOp<GPUDevice, float>); +REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad") + .Device(DEVICE_GPU) + .TypeConstraint<Eigen::half>("T") + .HostMemory("orig_input_shape") + .Label("cudnn"), + AvgPoolingGradOp<GPUDevice, Eigen::half>); // A custom GPU kernel based AvgPoolingGrad implementation. It includes the // padding as the candidates for the pooling operation. @@ -532,6 +544,11 @@ REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad") .TypeConstraint<float>("T") .HostMemory("orig_input_shape"), AvgPoolingGradOpCustomGPUKernel<float>); +REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad") + .Device(DEVICE_GPU) + .TypeConstraint<Eigen::half>("T") + .HostMemory("orig_input_shape"), + AvgPoolingGradOpCustomGPUKernel<Eigen::half>); #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc index 9e894b1734..a190b2168a 100644 --- a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc +++ b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc @@ -33,6 +33,7 @@ typedef Eigen::GpuDevice GPUDevice; #define DEFINE_GPU_KERNELS(T) \ template struct functor::SpatialAvgPooling<GPUDevice, T>; +DEFINE_GPU_KERNELS(Eigen::half) DEFINE_GPU_KERNELS(float) #undef DEFINE_GPU_KERNELS @@ -57,7 +58,7 @@ __global__ void AvePoolBackwardNHWC(const int nthreads, const int phend = min(h / stride_h + 1, pooled_height); const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; const int pwend = min(w / stride_w + 1, pooled_width); - dtype gradient = 0; + dtype gradient(0); const dtype* const top_diff_slice = top_diff + n * pooled_height * pooled_width * channels + c; for (int ph = phstart; ph < phend; ++ph) { @@ -104,6 +105,12 @@ template bool RunAvePoolBackwardNHWC( const int pooled_width, const int kernel_h, const int kernel_w, const int stride_h, const int stride_w, const int pad_t, const int pad_l, float* const bottom_diff, const GPUDevice& d); +template bool RunAvePoolBackwardNHWC( + const Eigen::half* const top_diff, const int num, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + Eigen::half* const bottom_diff, const GPUDevice& d); } // end namespace tensorflow diff --git a/tensorflow/core/kernels/eigen_pooling.h b/tensorflow/core/kernels/eigen_pooling.h index 349cbf9d0e..aa3b274893 100644 --- a/tensorflow/core/kernels/eigen_pooling.h +++ b/tensorflow/core/kernels/eigen_pooling.h @@ -309,7 +309,7 @@ struct AvgPoolMeanReducer { EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE AvgPoolMeanReducer() : scalarCount_(0) { typedef typename packet_traits<T>::type Packet; - packetCount_ = pset1<Packet>(0.0); + packetCount_ = pset1<Packet>(T(0.0)); } EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) { diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc index 5e3f219699..f883acf3d6 100644 --- a/tensorflow/core/kernels/maxpooling_op.cc +++ b/tensorflow/core/kernels/maxpooling_op.cc @@ -160,7 +160,7 @@ static void SpatialMaxPoolWithArgMaxHelper( const int in_end = limit * in_size; EigenMatrixMap in_shard(input_backprop_flat.data() + in_start, 1, in_end - in_start); - in_shard.setConstant(0); + in_shard.setConstant(T(0)); // Backpropagate. const int out_size = out_height * out_width * depth; @@ -187,8 +187,12 @@ static void SpatialMaxPoolWithArgMaxHelper( params.tensor_in_batch, shard_cost, shard); } -REGISTER_KERNEL_BUILDER(Name("MaxPool").Device(DEVICE_CPU), - MaxPoolingOp<CPUDevice, float>); +REGISTER_KERNEL_BUILDER( + Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<float>("T"), + MaxPoolingOp<CPUDevice, float>); +REGISTER_KERNEL_BUILDER( + Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"), + MaxPoolingOp<CPUDevice, Eigen::half>); #if GOOGLE_CUDA // Forward declarations for the functor specializations for GPU. @@ -212,6 +216,7 @@ DECLARE_GPU_SPEC(float); // kernel_label_map. REGISTER_KERNEL_BUILDER(Name("MaxPool") .Device(DEVICE_GPU) + .TypeConstraint<float>("T") .Label("eigen_tensor"), MaxPoolingOp<Eigen::GpuDevice, float>); #endif // GOOGLE_CUDA @@ -297,11 +302,16 @@ class MaxPoolingGradOp : public OpKernel { TensorFormat data_format_; }; -REGISTER_KERNEL_BUILDER(Name("MaxPoolGrad").Device(DEVICE_CPU), - MaxPoolingGradOp<CPUDevice, float>); +REGISTER_KERNEL_BUILDER( + Name("MaxPoolGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"), + MaxPoolingGradOp<CPUDevice, float>); +REGISTER_KERNEL_BUILDER( + Name("MaxPoolGrad").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"), + MaxPoolingGradOp<CPUDevice, Eigen::half>); #ifdef GOOGLE_CUDA +template <typename T> static void MaxPoolingBackwardCustomKernel( OpKernelContext* context, const std::vector<int32>& size, const std::vector<int32>& stride, Padding padding, const Tensor* tensor_in, @@ -318,12 +328,12 @@ static void MaxPoolingBackwardCustomKernel( } MaxPoolBackwardNoMask( - tensor_in->flat<float>().data(), params.tensor_in_batch, + tensor_in->flat<T>().data(), params.tensor_in_batch, params.tensor_in_rows, params.tensor_in_cols, params.depth, params.out_height, params.out_width, params.window_rows, params.window_cols, params.row_stride, params.col_stride, params.pad_rows, - params.pad_cols, out_backprop.flat<float>().data(), - output->flat<float>().data(), context->eigen_device<Eigen::GpuDevice>()); + params.pad_cols, out_backprop.flat<T>().data(), + output->flat<T>().data(), context->eigen_device<Eigen::GpuDevice>()); } template <class T> @@ -378,8 +388,8 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel { } else { CHECK(data_format_ == FORMAT_NHWC) << "Non-Cudnn MaxPoolGrad only supports NHWC format"; - MaxPoolingBackwardCustomKernel(context, ksize_, stride_, padding_, - &tensor_in, out_backprop, output_shape); + MaxPoolingBackwardCustomKernel<T>(context, ksize_, stride_, padding_, + &tensor_in, out_backprop, output_shape); } } @@ -391,8 +401,12 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel { bool use_dnn_; }; -REGISTER_KERNEL_BUILDER(Name("MaxPoolGrad").Device(DEVICE_GPU), - MaxPoolingGradOp<Eigen::GpuDevice, float>); +REGISTER_KERNEL_BUILDER( + Name("MaxPoolGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"), + MaxPoolingGradOp<Eigen::GpuDevice, float>); +REGISTER_KERNEL_BUILDER( + Name("MaxPoolGrad").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), + MaxPoolingGradOp<Eigen::GpuDevice, Eigen::half>); #endif // GOOGLE_CUDA @@ -625,8 +639,12 @@ struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> { } }; -REGISTER_KERNEL_BUILDER(Name("MaxPool").Device(DEVICE_GPU), - MaxPoolingNoMaskOp<Eigen::GpuDevice, float>); +REGISTER_KERNEL_BUILDER( + Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<float>("T"), + MaxPoolingNoMaskOp<Eigen::GpuDevice, float>); +REGISTER_KERNEL_BUILDER( + Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"), + MaxPoolingNoMaskOp<Eigen::GpuDevice, Eigen::half>); template <typename T> struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> { @@ -649,8 +667,14 @@ struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> { REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax") .Device(DEVICE_GPU) - .TypeConstraint<int64>("Targmax"), + .TypeConstraint<int64>("Targmax") + .TypeConstraint<float>("T"), MaxPoolingWithArgmaxOp<Eigen::GpuDevice, float>); +REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax") + .Device(DEVICE_GPU) + .TypeConstraint<int64>("Targmax") + .TypeConstraint<Eigen::half>("T"), + MaxPoolingWithArgmaxOp<Eigen::GpuDevice, Eigen::half>); template <typename T> struct LaunchMaxPoolingGradWithArgmax<Eigen::GpuDevice, T> { @@ -675,10 +699,18 @@ struct LaunchMaxPoolingGradWithArgmax<Eigen::GpuDevice, T> { } }; -REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax") - .Device(DEVICE_GPU) - .TypeConstraint<int64>("Targmax"), - MaxPoolingGradWithArgmaxOp<Eigen::GpuDevice, float>); +REGISTER_KERNEL_BUILDER( + Name("MaxPoolGradWithArgmax") + .Device(DEVICE_GPU) + .TypeConstraint<float>("T") + .TypeConstraint<int64>("Targmax"), + MaxPoolingGradWithArgmaxOp<Eigen::GpuDevice, float>); +REGISTER_KERNEL_BUILDER( + Name("MaxPoolGradWithArgmax") + .Device(DEVICE_GPU) + .TypeConstraint<Eigen::half>("T") + .TypeConstraint<int64>("Targmax"), + MaxPoolingGradWithArgmaxOp<Eigen::GpuDevice, Eigen::half>); #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc index 1bdca42f4e..91b50b1e11 100644 --- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc +++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc @@ -110,7 +110,7 @@ __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data, int wend = min(wstart + kernel_w, width); hstart = max(hstart, 0); wstart = max(wstart, 0); - dtype maxval = -FLT_MAX; + dtype maxval = Eigen::NumTraits<dtype>::lowest(); int maxidx = -1; const dtype* bottom_data_n = bottom_data + n * height * width * channels; for (int h = hstart; h < hend; ++h) { @@ -149,7 +149,7 @@ __global__ void MaxPoolBackwardNoMaskNHWC( int wend = min(wstart + kernel_w, width); hstart = max(hstart, 0); wstart = max(wstart, 0); - dtype maxval = -FLT_MAX; + dtype maxval = Eigen::NumTraits<dtype>::lowest(); int maxidx = -1; const dtype* bottom_data_n = bottom_data + n * height * width * channels; for (int h = hstart; h < hend; ++h) { @@ -165,8 +165,8 @@ __global__ void MaxPoolBackwardNoMaskNHWC( // Atomically accumulate the bottom diff. The index could still be // uninitialized, if all the bottom_data are NaN. if (maxidx != -1) { - atomicAdd(bottom_diff + n * height * width * channels + maxidx, - top_diff[index]); + CudaAtomicAdd(bottom_diff + n * height * width * channels + maxidx, + top_diff[index]); } } } @@ -185,8 +185,8 @@ __global__ void MaxPoolBackwardNoMaskNHWC( // bottom_offset: the pre-computed per-image offset of the maxpool input. // This is equal to H*W*C. // bottom_diff: the gradient with respect to the input. -// This function relies on atomicAdd to avoid race conditions. Also, before the -// kernel is run, you will need to make sure that bottom_diff is filled with +// This function relies on CudaAtomicAdd to avoid race conditions. Also, before +// the kernel is run, you will need to make sure that bottom_diff is filled with // zero first. template <typename dtype> __global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff, @@ -194,8 +194,8 @@ __global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff, const int bottom_offset, dtype* bottom_diff) { CUDA_1D_KERNEL_LOOP(index, nthreads) { int image_id = (index / top_offset); - atomicAdd(bottom_diff + image_id * bottom_offset + mask[index], - top_diff[index]); + CudaAtomicAdd(bottom_diff + image_id * bottom_offset + mask[index], + top_diff[index]); } } @@ -219,6 +219,23 @@ bool MaxPoolForwardWithOptionalArgmax( return d.ok(); } +bool MaxPoolForwardWithOptionalArgmax( + const Eigen::half* bottom_data, const int batch, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + Eigen::half* top_data, int64* mask, const Eigen::GpuDevice& d) { + const int kThreadsPerBlock = 1024; + const int output_size = batch * channels * pooled_height * pooled_width; + + MaxPoolForwardNHWC<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>( + output_size, bottom_data, height, width, channels, pooled_height, + pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, + top_data, mask); + return d.ok(); +} + bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch, const int height, const int width, const int channels, const int pooled_height, @@ -243,6 +260,30 @@ bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch, return d.ok(); } +bool MaxPoolBackwardNoMask(const Eigen::half* bottom_data, const int batch, + const int height, const int width, + const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, const int pad_l, + const Eigen::half* top_diff, Eigen::half* bottom_diff, + const Eigen::GpuDevice& d) { + const int kThreadsPerBlock = 1024; + const int bottom_size = batch * channels * height * width; + const int top_size = batch * channels * pooled_height * pooled_width; + + SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff); + + MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) / + kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>( + top_size, bottom_data, height, width, channels, pooled_height, + pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l, + top_diff, bottom_diff); + return d.ok(); +} + bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size, const float* top_diff, const int64* mask, const int top_offset, const int bottom_offset, @@ -256,12 +297,27 @@ bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size, return d.ok(); } +bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size, + const Eigen::half* top_diff, const int64* mask, + const int top_offset, const int bottom_offset, + Eigen::half* bottom_diff, + const Eigen::GpuDevice& d) { + const int kThreadsPerBlock = 1024; + SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff); + MaxPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock, + kThreadsPerBlock, 0, d.stream()>>>( + output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff); + return d.ok(); +} + typedef Eigen::GpuDevice GPUDevice; #define DEFINE_GPU_KERNELS(T) \ template struct functor::SpatialMaxPooling<GPUDevice, T>; DEFINE_GPU_KERNELS(float) +DEFINE_GPU_KERNELS(Eigen::half) #undef DEFINE_GPU_KERNELS diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.h b/tensorflow/core/kernels/maxpooling_op_gpu.h index 05e865f81c..d1c73a372e 100644 --- a/tensorflow/core/kernels/maxpooling_op_gpu.h +++ b/tensorflow/core/kernels/maxpooling_op_gpu.h @@ -37,11 +37,24 @@ bool MaxPoolForwardWithOptionalArgmax( const int stride_h, const int stride_w, const int pad_t, const int pad_l, float* top_data, int64* mask, const Eigen::GpuDevice& d); +bool MaxPoolForwardWithOptionalArgmax( + const Eigen::half* bottom_data, const int batch, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, const int pad_l, + Eigen::half* top_data, int64* mask, const Eigen::GpuDevice& d); + bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size, const float* top_diff, const int64* mask, const int top_offset, const int bottom_offset, float* bottom_diff, const Eigen::GpuDevice& d); +bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size, + const Eigen::half* top_diff, const int64* mask, + const int top_offset, const int bottom_offset, + Eigen::half* bottom_diff, + const Eigen::GpuDevice& d); + bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch, const int height, const int width, const int channels, const int pooled_height, @@ -51,6 +64,15 @@ bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch, const float* top_diff, float* bottom_diff, const Eigen::GpuDevice& d); +bool MaxPoolBackwardNoMask(const Eigen::half* bottom_data, const int batch, + const int height, const int width, + const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, + const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, const int pad_l, + const Eigen::half* top_diff, Eigen::half* bottom_diff, + const Eigen::GpuDevice& d); + } // namespace tensorflow #endif // TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_GPU_H_ diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index 3867cc824f..f5d7771af7 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -124,6 +124,7 @@ namespace functor { extern template struct TransformDepth<GPUDevice, T, Eigen::DenseIndex>; DECLARE_GPU_SPEC(float); +DECLARE_GPU_SPEC(Eigen::half); #undef DECLARE_GPU_SPEC } // namespace functor @@ -368,7 +369,9 @@ void DnnPoolingGradOp<T>::Compute( } } +template class DnnPoolingOp<Eigen::half>; template class DnnPoolingOp<float>; +template class DnnPoolingGradOp<Eigen::half>; template class DnnPoolingGradOp<float>; #endif // GOOGLE_CUDA diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h index 138d1cb2ca..593c90b009 100644 --- a/tensorflow/core/kernels/pooling_ops_common.h +++ b/tensorflow/core/kernels/pooling_ops_common.h @@ -311,7 +311,7 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output, } } } - DCHECK_GT(out_count.minCoeff(), 0); + DCHECK_GT(out_count.minCoeff(), T(0)); out_mat.array().rowwise() /= out_count.transpose().array(); } diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt index ed60c227a5..3224a1c9af 100644 --- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt +++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt @@ -3012,6 +3012,63 @@ op { } } op { + name: "AvgPool" + input_arg { + name: "value" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "ksize" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "strides" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + type: DT_DOUBLE + } + } + } +} +op { name: "AvgPool3D" input_arg { name: "input" @@ -3233,6 +3290,67 @@ op { } } op { + name: "AvgPoolGrad" + input_arg { + name: "orig_input_shape" + type: DT_INT32 + } + input_arg { + name: "grad" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "ksize" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "strides" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + type: DT_DOUBLE + } + } + } +} +op { name: "BatchCholesky" input_arg { name: "input" @@ -11802,6 +11920,124 @@ op { } } op { + name: "MaxPool" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "ksize" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "strides" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + } + } + } +} +op { + name: "MaxPool" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + } + } + } + attr { + name: "ksize" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "strides" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } + } +} +op { name: "MaxPool3D" input_arg { name: "input" @@ -12015,6 +12251,73 @@ op { } } op { + name: "MaxPoolGrad" + input_arg { + name: "orig_input" + type_attr: "T" + } + input_arg { + name: "orig_output" + type_attr: "T" + } + input_arg { + name: "grad" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "ksize" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "strides" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + } + } + } +} +op { name: "MaxPoolGradWithArgmax" input_arg { name: "input" @@ -12066,6 +12369,70 @@ op { } } op { + name: "MaxPoolGradWithArgmax" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "grad" + type_attr: "T" + } + input_arg { + name: "argmax" + type_attr: "Targmax" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "ksize" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "strides" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "Targmax" + type: "type" + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + } + } + } +} +op { name: "MaxPoolWithArgmax" input_arg { name: "input" @@ -12116,6 +12483,69 @@ op { } } op { + name: "MaxPoolWithArgmax" + input_arg { + name: "input" + type_attr: "T" + } + output_arg { + name: "output" + type_attr: "T" + } + output_arg { + name: "argmax" + type_attr: "Targmax" + } + attr { + name: "ksize" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "strides" + type: "list(int)" + has_minimum: true + minimum: 4 + } + attr { + name: "Targmax" + type: "type" + default_value { + type: DT_INT64 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + attr { + name: "padding" + type: "string" + allowed_values { + list { + s: "SAME" + s: "VALID" + } + } + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + } + } + } +} +op { name: "Maximum" input_arg { name: "x" diff --git a/tensorflow/core/ops/nn_grad.cc b/tensorflow/core/ops/nn_grad.cc index c1a42e74be..e3b876b240 100644 --- a/tensorflow/core/ops/nn_grad.cc +++ b/tensorflow/core/ops/nn_grad.cc @@ -154,22 +154,25 @@ Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off *g = FDH::Define( // Arg defs - {"input: float", "grad: float"}, + {"input: T", "grad: T"}, // Ret val defs - {"output: float"}, + {"output: T"}, // Attr defs - {"ksize: list(int) >= 4", + {"T: {float, half} = DT_FLOAT", + "ksize: list(int) >= 4", "strides: list(int) >= 4", GetPaddingAttrString()}, // Nodes { // Invoke MaxPool again to recompute the outputs (removed by CSE?). {{"maxpool"}, "MaxPool", {"input"}, - /*Attrs=*/{{"ksize", "$ksize"}, + /*Attrs=*/{{"T", "$T"}, + {"ksize", "$ksize"}, {"strides", "$strides"}, {"padding", "$padding"}}}, {{"output"}, "MaxPoolGrad", {"input", "maxpool", "grad"}, - /*Attrs=*/{{"ksize", "$ksize"}, + /*Attrs=*/{{"T", "$T"}, + {"ksize", "$ksize"}, {"strides", "$strides"}, {"padding", "$padding"}}} }); diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index fee145be53..b53945a4a0 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -28,7 +28,7 @@ REGISTER_OP("AvgPool") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) - .Attr("T: {float, double}") + .Attr("T: {float, half, double}") .Doc(R"doc( Performs average pooling on the input. @@ -55,7 +55,7 @@ REGISTER_OP("AvgPoolGrad") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) - .Attr("T: {float, double}") + .Attr("T: {float, half, double}") .Doc(R"doc( Computes gradients of the average pooling function. @@ -642,12 +642,13 @@ output: The gradients for LRN. // -------------------------------------------------------------------------- REGISTER_OP("MaxPool") + .Attr("T: {float, half} = DT_FLOAT") .Attr("ksize: list(int) >= 4") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) - .Input("input: float") - .Output("output: float") + .Input("input: T") + .Output("output: T") .Doc(R"doc( Performs max pooling on the input. @@ -669,10 +670,11 @@ REGISTER_OP("MaxPoolGrad") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) .Attr(GetConvnetDataFormatAttrString()) - .Input("orig_input: float") - .Input("orig_output: float") - .Input("grad: float") - .Output("output: float") + .Input("orig_input: T") + .Input("orig_output: T") + .Input("grad: T") + .Output("output: T") + .Attr("T: {float, half} = DT_FLOAT") .Doc(R"doc( Computes gradients of the maxpooling function. @@ -696,9 +698,10 @@ REGISTER_OP("MaxPoolWithArgmax") .Attr("strides: list(int) >= 4") .Attr("Targmax: {int32, int64} = DT_INT64") .Attr(GetPaddingAttrString()) - .Input("input: float") - .Output("output: float") + .Input("input: T") + .Output("output: T") .Output("argmax: Targmax") + .Attr("T: {float, half} = DT_FLOAT") .Doc(R"doc( Performs max pooling on the input and outputs both max values and indices. @@ -720,10 +723,11 @@ REGISTER_OP("MaxPoolGradWithArgmax") .Attr("strides: list(int) >= 4") .Attr(GetPaddingAttrString()) .Attr("Targmax: {int32, int64}") - .Input("input: float") - .Input("grad: float") + .Input("input: T") + .Input("grad: T") .Input("argmax: Targmax") - .Output("output: float") + .Output("output: T") + .Attr("T: {float, half} = DT_FLOAT") .Doc(R"doc( Computes gradients of the maxpooling function. diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt index 5fb34e79d1..18624418cb 100644 --- a/tensorflow/core/ops/ops.pbtxt +++ b/tensorflow/core/ops/ops.pbtxt @@ -1251,6 +1251,7 @@ op { allowed_values { list { type: DT_FLOAT + type: DT_HALF type: DT_DOUBLE } } @@ -1447,6 +1448,7 @@ op { allowed_values { list { type: DT_FLOAT + type: DT_HALF type: DT_DOUBLE } } @@ -6614,12 +6616,25 @@ op { input_arg { name: "input" description: "4-D input to pool over." - type: DT_FLOAT + type_attr: "T" } output_arg { name: "output" description: "The max pooled output tensor." - type: DT_FLOAT + type_attr: "T" + } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + } + } } attr { name: "ksize" @@ -6798,22 +6813,22 @@ op { input_arg { name: "orig_input" description: "The original input tensor." - type: DT_FLOAT + type_attr: "T" } input_arg { name: "orig_output" description: "The original output tensor." - type: DT_FLOAT + type_attr: "T" } input_arg { name: "grad" description: "4-D. Gradients w.r.t. the output of `max_pool`." - type: DT_FLOAT + type_attr: "T" } output_arg { name: "output" description: "Gradients w.r.t. the input to `max_pool`." - type: DT_FLOAT + type_attr: "T" } attr { name: "ksize" @@ -6854,6 +6869,19 @@ op { } } } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + } + } + } summary: "Computes gradients of the maxpooling function." } op { @@ -6861,12 +6889,12 @@ op { input_arg { name: "input" description: "The original input." - type: DT_FLOAT + type_attr: "T" } input_arg { name: "grad" description: "4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the\noutput of `max_pool`." - type: DT_FLOAT + type_attr: "T" } input_arg { name: "argmax" @@ -6876,7 +6904,7 @@ op { output_arg { name: "output" description: "Gradients w.r.t. the input of `max_pool`." - type: DT_FLOAT + type_attr: "T" } attr { name: "ksize" @@ -6913,6 +6941,19 @@ op { } } } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + } + } + } summary: "Computes gradients of the maxpooling function." } op { @@ -6920,12 +6961,12 @@ op { input_arg { name: "input" description: "4-D with shape `[batch, height, width, channels]`. Input to pool over." - type: DT_FLOAT + type_attr: "T" } output_arg { name: "output" description: "The max pooled output tensor." - type: DT_FLOAT + type_attr: "T" } output_arg { name: "argmax" @@ -6970,6 +7011,19 @@ op { } } } + attr { + name: "T" + type: "type" + default_value { + type: DT_FLOAT + } + allowed_values { + list { + type: DT_FLOAT + type: DT_HALF + } + } + } summary: "Performs max pooling on the input and outputs both max values and indices." description: "The indices in `argmax` are flattened, so that a maximum value at position\n`[b, y, x, c]` becomes flattened index\n`((b * height + y) * width + x) * channels + c`." } diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py index 333bfa17f9..011078036d 100644 --- a/tensorflow/python/kernel_tests/pooling_ops_test.py +++ b/tensorflow/python/kernel_tests/pooling_ops_test.py @@ -99,8 +99,8 @@ def GetShrunkInceptionMaxPoolShapes(shrink=30): class PoolingTest(tf.test.TestCase): - def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding, - data_format, expected, use_gpu): + def _VerifyOneType(self, pool_func, input_sizes, ksize, strides, padding, + data_format, data_type, expected, use_gpu): """Verifies the output values of the pooling function. Args: @@ -111,6 +111,7 @@ class PoolingTest(tf.test.TestCase): strides: The stride dimensions padding: Padding type. data_format: The data format we use to run the pooling operation. + data_type: The data type to use to run the pooling operation. expected: An array containing the expected operation outputs. use_gpu: Whether we are running on GPU. """ @@ -121,7 +122,7 @@ class PoolingTest(tf.test.TestCase): # numbers from 1. x = [f * 1.0 for f in range(1, total_size + 1)] with self.test_session(use_gpu=use_gpu) as sess: - t = tf.constant(x, shape=input_sizes) + t = tf.constant(x, shape=input_sizes, dtype=data_type) if data_format == "NCHW": t = NHWCToNCHW(t) ksize = NHWCToNCHW(ksize) @@ -131,9 +132,31 @@ class PoolingTest(tf.test.TestCase): if data_format == "NCHW": t = NCHWToNHWC(t) actual = t.eval() - self.assertAllClose(expected, actual.flatten()) + self.assertAllCloseAccordingToType(expected, actual.flatten()) self.assertShapeEqual(actual, t) + def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding, + data_format, expected, use_gpu): + """Verifies the output values of the pooling function. + + Args: + pool_func: Function to be called, co.MaxPool, co.AvgPool, + or the Lua version. + input_sizes: Input tensor dimensions. + ksize: The kernel size dimensions + strides: The stride dimensions + padding: Padding type. + data_format: The data format we use to run the pooling operation. + expected: An array containing the expected operation outputs. + use_gpu: Whether we are running on GPU. + """ + self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding, + data_format, tf.float32, expected, use_gpu) + + if not use_gpu or test_util.CudaSupportsHalfMatMulAndConv(): + self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding, + data_format, tf.float16, expected, use_gpu) + def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding, expected, use_gpu): """Verifies the output values of the pooling function. @@ -372,32 +395,40 @@ class PoolingTest(tf.test.TestCase): def testKernelSmallerThanStrideValid(self): for use_gpu in [True, False]: - self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 7, 7, 1], - ksize=[1, 2, 2, 1], strides=[1, 3, 3, 1], - padding="VALID", - expected=[9, 12, 30, 33], - use_gpu=use_gpu) - - self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 7, 7, 1], - ksize=[1, 2, 2, 1], strides=[1, 3, 3, 1], - padding="VALID", - expected=[5, 8, 26, 29], - use_gpu=use_gpu) + self._VerifyValues(tf.nn.max_pool, + input_sizes=[1, 7, 7, 1], + ksize=[1, 2, 2, 1], + strides=[1, 3, 3, 1], + padding="VALID", + expected=[9, 12, 30, 33], + use_gpu=use_gpu) + + self._VerifyValues(tf.nn.avg_pool, + input_sizes=[1, 7, 7, 1], + ksize=[1, 2, 2, 1], + strides=[1, 3, 3, 1], + padding="VALID", + expected=[5, 8, 26, 29], + use_gpu=use_gpu) def testKernelSmallerThanStrideSame(self): for use_gpu in [True, False]: - for pool_func in [tf.nn.max_pool, tf.nn.avg_pool]: - self._VerifyValues(pool_func, input_sizes=[1, 3, 3, 1], - ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1], - padding="SAME", - expected=[1, 3, 7, 9], - use_gpu=use_gpu) - - self._VerifyValues(pool_func, input_sizes=[1, 4, 4, 1], - ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1], - padding="SAME", - expected=[1, 3, 9, 11], - use_gpu=use_gpu) + for pool_func in [tf.nn.max_pool, tf.nn.avg_pool]: + self._VerifyValues(pool_func, + input_sizes=[1, 3, 3, 1], + ksize=[1, 1, 1, 1], + strides=[1, 2, 2, 1], + padding="SAME", + expected=[1, 3, 7, 9], + use_gpu=use_gpu) + + self._VerifyValues(pool_func, + input_sizes=[1, 4, 4, 1], + ksize=[1, 1, 1, 1], + strides=[1, 2, 2, 1], + padding="SAME", + expected=[1, 3, 9, 11], + use_gpu=use_gpu) def _testDepthwiseMaxPoolInvalidConfig(self, in_size, ksize, strides, error_msg, use_gpu=False): @@ -425,43 +456,49 @@ class PoolingTest(tf.test.TestCase): # The following are tests that verify that the CPU and GPU implementations # produce the same resuts. def _CompareMaxPoolingFwd(self, input_shape, ksize, strides, padding): - tensor_input = np.random.rand(*input_shape).astype(np.float32) - with self.test_session(use_gpu=True): - t = tf.constant(tensor_input, shape=input_shape) - out_op, _ = tf.nn.max_pool_with_argmax(t, ksize, strides, padding) - gpu_val = out_op.eval() - with self.test_session(use_gpu=False): - t = tf.constant(tensor_input, shape=input_shape) - out_op = tf.nn.max_pool(t, ksize, strides, padding) - cpu_val = out_op.eval() - self.assertAllClose(cpu_val, gpu_val, rtol=1e-5, atol=1e-5) + for dtype in np.float32, np.float16: + tensor_input = np.random.rand(*input_shape).astype(dtype) + with self.test_session(use_gpu=True): + t = tf.constant(tensor_input, shape=input_shape) + out_op, _ = tf.nn.max_pool_with_argmax(t, ksize, strides, padding) + gpu_val = out_op.eval() + with self.test_session(use_gpu=False): + t = tf.constant(tensor_input, shape=input_shape) + out_op = tf.nn.max_pool(t, ksize, strides, padding) + cpu_val = out_op.eval() + self.assertAllCloseAccordingToType(cpu_val, gpu_val) def _CompareMaxPoolingBk(self, input_shape, output_shape, ksize, strides, padding): - # Generate numbers in a narrow range, so that there are many duplicates - # in the input. - tensor_input = np.random.random_integers(0, 3, - input_shape).astype(np.float32) - tensor_output = np.random.rand(*output_shape).astype(np.float32) - with self.test_session(use_gpu=True): - t = tf.constant(tensor_input, shape=input_shape) - _, argmax_op = tf.nn.max_pool_with_argmax(t, ksize, strides, padding) - argmax = argmax_op.eval() - grad_in = tf.constant(tensor_output, shape=output_shape) - out_op = gen_nn_ops._max_pool_grad_with_argmax(t, grad_in, argmax, - ksize, strides, padding) - gpu_val = out_op.eval() - self.assertShapeEqual(gpu_val, out_op) - with self.test_session(use_gpu=False): - t = tf.constant(tensor_input, shape=input_shape) - out_op = tf.nn.max_pool(t, ksize, strides, padding) - orig_out = out_op.eval() - grad_in = tf.constant(tensor_output, shape=output_shape) - out_op = gen_nn_ops._max_pool_grad(t, orig_out, grad_in, ksize, - strides, padding) - cpu_val = out_op.eval() - self.assertShapeEqual(cpu_val, out_op) - self.assertAllClose(cpu_val, gpu_val, rtol=1e-5, atol=1e-5) + for dtype in np.float32, np.float16: + # Generate numbers in a narrow range, so that there are many duplicates + # in the input. + tensor_input = np.random.random_integers(0, 3, input_shape).astype(dtype) + tensor_output = np.random.rand(*output_shape).astype(dtype) + with self.test_session(use_gpu=True): + t = tf.constant(tensor_input, shape=input_shape) + _, argmax_op = tf.nn.max_pool_with_argmax(t, ksize, strides, padding) + argmax = argmax_op.eval() + grad_in = tf.constant(tensor_output, shape=output_shape) + out_op = gen_nn_ops._max_pool_grad_with_argmax(t, grad_in, argmax, + ksize, strides, padding) + gpu_val = out_op.eval() + self.assertShapeEqual(gpu_val, out_op) + with self.test_session(use_gpu=False): + t = tf.constant(tensor_input, shape=input_shape) + out_op = tf.nn.max_pool(t, ksize, strides, padding) + orig_out = out_op.eval() + grad_in = tf.constant(tensor_output, shape=output_shape) + out_op = gen_nn_ops._max_pool_grad(t, orig_out, grad_in, ksize, strides, + padding) + cpu_val = out_op.eval() + self.assertShapeEqual(cpu_val, out_op) + if dtype == np.float16: + # The CPU version accumulates its gradient on fp16, so it's less + # accurate than the GPU version that does the accumulation on fp32 + self.assertAllClose(cpu_val, gpu_val, rtol=0.01, atol=0.01) + else: + self.assertAllClose(cpu_val, gpu_val) def testMaxPoolingWithArgmax(self): # MaxPoolWithArgMax is implemented only on GPU. diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc index 23a8066e79..9d860e59a2 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.cc +++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc @@ -1876,6 +1876,40 @@ bool CudnnSupport::DoPoolForward( return true; } +bool CudnnSupport::DoPoolForward( + Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, + const dnn::BatchDescriptor& input_dimensions, + const DeviceMemory<Eigen::half>& input_data, + const dnn::BatchDescriptor& output_dimensions, + DeviceMemory<Eigen::half>* output_data) { + mutex_lock lock{dnn_handle_mutex_}; + auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_), + AsCUDAStreamValue(stream)); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status); + return false; + } + + // Alpha is the scaling factor for input. + float alpha = 1.0; + // Beta is the scaling factor for output. + float beta = 0.0; + + ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF}; + ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF}; + ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions}; + status = dynload::cudnnPoolingForward( + parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha, + src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(), + output_data->opaque()); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(ERROR) << "failed to enqueue forward pooling on stream: " + << ToString(status); + return false; + } + return true; +} + bool CudnnSupport::DoPoolBackward( Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, @@ -1914,6 +1948,43 @@ bool CudnnSupport::DoPoolBackward( return true; } +bool CudnnSupport::DoPoolBackward( + Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, + const dnn::BatchDescriptor& input_dimensions, + const DeviceMemory<Eigen::half>& input_data, + const dnn::BatchDescriptor& output_dimensions, + const DeviceMemory<Eigen::half>& output_data, + const DeviceMemory<Eigen::half>& input_diff_data, + DeviceMemory<Eigen::half>* output_diff_data) { + mutex_lock lock{dnn_handle_mutex_}; + auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_), + AsCUDAStreamValue(stream)); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status); + return false; + } + + // Alpha is the scaling factor for input. + float alpha = 1.0; + // Beta is the scaling factor for output. + float beta = 0.0; + + ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF}; + ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF}; + ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions}; + status = dynload::cudnnPoolingBackward( + parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha, + dest_desc.handle(), output_data.opaque(), dest_desc.handle(), + input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta, + src_desc.handle(), output_diff_data->opaque()); + if (status != CUDNN_STATUS_SUCCESS) { + LOG(ERROR) << "failed to enqueue backward pooling on stream: " + << ToString(status); + return false; + } + return true; +} + bool CudnnSupport::DoNormalize( Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) { diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h index 523a0c6c5d..434ab730a7 100644 --- a/tensorflow/stream_executor/cuda/cuda_dnn.h +++ b/tensorflow/stream_executor/cuda/cuda_dnn.h @@ -201,6 +201,13 @@ class CudnnSupport : public dnn::DnnSupport { const dnn::BatchDescriptor& output_dimensions, DeviceMemory<float>* output_data) override; + bool DoPoolForward(Stream* stream, + const dnn::PoolingDescriptor& pooling_dimensions, + const dnn::BatchDescriptor& input_dimensions, + const DeviceMemory<Eigen::half>& input_data, + const dnn::BatchDescriptor& output_dimensions, + DeviceMemory<Eigen::half>* output_data) override; + bool DoPoolBackward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, const dnn::BatchDescriptor& input_dimensions, @@ -210,6 +217,15 @@ class CudnnSupport : public dnn::DnnSupport { const DeviceMemory<float>& input_diff_data, DeviceMemory<float>* output_diff_data) override; + bool DoPoolBackward(Stream* stream, + const dnn::PoolingDescriptor& pooling_dimensions, + const dnn::BatchDescriptor& input_dimensions, + const DeviceMemory<Eigen::half>& input_data, + const dnn::BatchDescriptor& output_dimensions, + const DeviceMemory<Eigen::half>& output_data, + const DeviceMemory<Eigen::half>& input_diff_data, + DeviceMemory<Eigen::half>* output_diff_data) override; + bool DoNormalize(Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor, const DeviceMemory<float>& input_data, diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h index fbb44dc739..0ae482a73c 100644 --- a/tensorflow/stream_executor/dnn.h +++ b/tensorflow/stream_executor/dnn.h @@ -1011,6 +1011,13 @@ class DnnSupport { const dnn::BatchDescriptor& output_dimensions, DeviceMemory<float>* output_data) = 0; + virtual bool DoPoolForward(Stream* stream, + const dnn::PoolingDescriptor& pooling_dimensions, + const dnn::BatchDescriptor& input_dimensions, + const DeviceMemory<Eigen::half>& input_data, + const dnn::BatchDescriptor& output_dimensions, + DeviceMemory<Eigen::half>* output_data) = 0; + // Performs differentiation of the pooling operation. virtual bool DoPoolBackward(Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions, @@ -1021,6 +1028,15 @@ class DnnSupport { const DeviceMemory<float>& input_diff_data, DeviceMemory<float>* output_diff_data) = 0; + virtual bool DoPoolBackward(Stream* stream, + const dnn::PoolingDescriptor& pooling_dimensions, + const dnn::BatchDescriptor& input_dimensions, + const DeviceMemory<Eigen::half>& input_data, + const dnn::BatchDescriptor& output_dimensions, + const DeviceMemory<Eigen::half>& output_data, + const DeviceMemory<Eigen::half>& input_diff_data, + DeviceMemory<Eigen::half>* output_diff_data) = 0; + // Applies local response normalization to the values from // input_data and writes the result to output_data. See comments on // NormalizeDescriptor for a description of local response diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc index 446a3c9a7d..be823d9500 100644 --- a/tensorflow/stream_executor/stream.cc +++ b/tensorflow/stream_executor/stream.cc @@ -909,6 +909,30 @@ Stream &Stream::ThenPoolForward( return *this; } +Stream &Stream::ThenPoolForward( + const dnn::PoolingDescriptor &pooling_dimensions, + const dnn::BatchDescriptor &input_dimensions, + const DeviceMemory<Eigen::half> &input_data, + const dnn::BatchDescriptor &output_dimensions, + DeviceMemory<Eigen::half> *output_data) { + VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), + PARAM(input_data), PARAM(output_dimensions), PARAM(output_data)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions, + input_data, output_dimensions, + output_data)); + } else { + SetError(); + LOG(WARNING) + << "attempting to perform DNN operation using StreamExecutor " + "without DNN support"; + } + } + return *this; +} + Stream &Stream::ThenPoolBackward( const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, @@ -936,6 +960,33 @@ Stream &Stream::ThenPoolBackward( return *this; } +Stream &Stream::ThenPoolBackward( + const dnn::PoolingDescriptor &pooling_dimensions, + const dnn::BatchDescriptor &input_dimensions, + const DeviceMemory<Eigen::half> &input_data, + const dnn::BatchDescriptor &output_dimensions, + const DeviceMemory<Eigen::half> &output_data, + const DeviceMemory<Eigen::half> &input_diff_data, + DeviceMemory<Eigen::half> *output_diff_data) { + VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions), + PARAM(input_data), PARAM(output_dimensions), PARAM(output_data), + PARAM(input_diff_data), PARAM(output_diff_data)); + + if (ok()) { + if (dnn::DnnSupport *dnn = parent_->AsDnn()) { + CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions, + input_data, output_dimensions, output_data, + input_diff_data, output_diff_data)); + } else { + SetError(); + LOG(WARNING) + << "attempting to perform DNN operation using StreamExecutor " + "without DNN support"; + } + } + return *this; +} + Stream &Stream::ThenNormalize( const dnn::NormalizeDescriptor &normalize_descriptor, const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) { diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h index aac945c9e0..c131250de1 100644 --- a/tensorflow/stream_executor/stream.h +++ b/tensorflow/stream_executor/stream.h @@ -421,6 +421,12 @@ class Stream { const dnn::BatchDescriptor &output_dimensions, DeviceMemory<float> *output_data); + Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions, + const dnn::BatchDescriptor &input_dimensions, + const DeviceMemory<Eigen::half> &input_data, + const dnn::BatchDescriptor &output_dimensions, + DeviceMemory<Eigen::half> *output_data); + Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions, const dnn::BatchDescriptor &input_dimensions, const DeviceMemory<float> &input_data, @@ -429,6 +435,14 @@ class Stream { const DeviceMemory<float> &input_diff_data, DeviceMemory<float> *output_diff_data); + Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions, + const dnn::BatchDescriptor &input_dimensions, + const DeviceMemory<Eigen::half> &input_data, + const dnn::BatchDescriptor &output_dimensions, + const DeviceMemory<Eigen::half> &output_data, + const DeviceMemory<Eigen::half> &input_diff_data, + DeviceMemory<Eigen::half> *output_diff_data); + Stream &ThenNormalize(const dnn::NormalizeDescriptor &normalize_descriptor, const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data); diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 07f83651e0..d9cfb85fc3 100644 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -6,8 +6,8 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""): native.new_http_archive( name = "eigen_archive", - url = "https://bitbucket.org/eigen/eigen/get/d02e6a705c30.tar.gz", - sha256 = "532956172daa8aba87c750791ff89a5c38cdb07e2525afe17ecb4bef812d67cf", + url = "https://bitbucket.org/eigen/eigen/get/0c0b79ecd74c.tar.gz", + sha256 = "b4b5884b03bd4bae114d02b36e2435ad1504ed8e51431d16c876b6f6a365882b", build_file = path_prefix + "eigen.BUILD", ) diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky index 56059bcc61..7415ae4d0d 100644 --- a/third_party/eigen3/Eigen/Cholesky +++ b/third_party/eigen3/Eigen/Cholesky @@ -1 +1 @@ -#include "eigen-eigen-d02e6a705c30/Eigen/Cholesky" +#include "eigen-eigen-0c0b79ecd74c/Eigen/Cholesky" diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core index c1d4a2e0f8..787e1c076e 100644 --- a/third_party/eigen3/Eigen/Core +++ b/third_party/eigen3/Eigen/Core @@ -1 +1 @@ -#include "eigen-eigen-d02e6a705c30/Eigen/Core" +#include "eigen-eigen-0c0b79ecd74c/Eigen/Core" diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues index 0a0731ba19..b6e1b81eb5 100644 --- a/third_party/eigen3/Eigen/Eigenvalues +++ b/third_party/eigen3/Eigen/Eigenvalues @@ -1 +1 @@ -#include "eigen-eigen-d02e6a705c30/Eigen/Eigenvalues" +#include "eigen-eigen-0c0b79ecd74c/Eigen/Eigenvalues" diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU index d6b39b8d23..a0782af040 100644 --- a/third_party/eigen3/Eigen/LU +++ b/third_party/eigen3/Eigen/LU @@ -1 +1 @@ -#include "eigen-eigen-d02e6a705c30/Eigen/LU" +#include "eigen-eigen-0c0b79ecd74c/Eigen/LU" diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR index a5406e93bc..0a9bee2898 100644 --- a/third_party/eigen3/Eigen/QR +++ b/third_party/eigen3/Eigen/QR @@ -1 +1 @@ -#include "eigen-eigen-d02e6a705c30/Eigen/QR" +#include "eigen-eigen-0c0b79ecd74c/Eigen/QR" diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor index 4f730236b7..5228bcda62 100644 --- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor +++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor @@ -1 +1 @@ -#include "eigen-eigen-d02e6a705c30/unsupported/Eigen/CXX11/Tensor" +#include "eigen-eigen-0c0b79ecd74c/unsupported/Eigen/CXX11/Tensor" |