aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--eigen.BUILD2
-rw-r--r--tensorflow/contrib/cmake/external/eigen.cmake4
-rw-r--r--tensorflow/core/kernels/avgpooling_op.cc35
-rw-r--r--tensorflow/core/kernels/avgpooling_op_gpu.cu.cc9
-rw-r--r--tensorflow/core/kernels/eigen_pooling.h2
-rw-r--r--tensorflow/core/kernels/maxpooling_op.cc70
-rw-r--r--tensorflow/core/kernels/maxpooling_op_gpu.cu.cc72
-rw-r--r--tensorflow/core/kernels/maxpooling_op_gpu.h22
-rw-r--r--tensorflow/core/kernels/pooling_ops_common.cc3
-rw-r--r--tensorflow/core/kernels/pooling_ops_common.h2
-rw-r--r--tensorflow/core/ops/compat/ops_history.v0.pbtxt430
-rw-r--r--tensorflow/core/ops/nn_grad.cc13
-rw-r--r--tensorflow/core/ops/nn_ops.cc30
-rw-r--r--tensorflow/core/ops/ops.pbtxt76
-rw-r--r--tensorflow/python/kernel_tests/pooling_ops_test.py159
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc71
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.h16
-rw-r--r--tensorflow/stream_executor/dnn.h16
-rw-r--r--tensorflow/stream_executor/stream.cc51
-rw-r--r--tensorflow/stream_executor/stream.h14
-rw-r--r--tensorflow/workspace.bzl4
-rw-r--r--third_party/eigen3/Eigen/Cholesky2
-rw-r--r--third_party/eigen3/Eigen/Core2
-rw-r--r--third_party/eigen3/Eigen/Eigenvalues2
-rw-r--r--third_party/eigen3/Eigen/LU2
-rw-r--r--third_party/eigen3/Eigen/QR2
-rw-r--r--third_party/eigen3/unsupported/Eigen/CXX11/Tensor2
27 files changed, 973 insertions, 140 deletions
diff --git a/eigen.BUILD b/eigen.BUILD
index 79bafe65b6..e32f3aab49 100644
--- a/eigen.BUILD
+++ b/eigen.BUILD
@@ -1,6 +1,6 @@
package(default_visibility = ["//visibility:public"])
-archive_dir = "eigen-eigen-d02e6a705c30"
+archive_dir = "eigen-eigen-0c0b79ecd74c"
cc_library(
name = "eigen",
diff --git a/tensorflow/contrib/cmake/external/eigen.cmake b/tensorflow/contrib/cmake/external/eigen.cmake
index db409760fa..d3075ab9d2 100644
--- a/tensorflow/contrib/cmake/external/eigen.cmake
+++ b/tensorflow/contrib/cmake/external/eigen.cmake
@@ -7,7 +7,7 @@
include (ExternalProject)
-set(eigen_archive_hash "d02e6a705c30")
+set(eigen_archive_hash "0c0b79ecd74c")
set(eigen_INCLUDE_DIRS
${CMAKE_CURRENT_BINARY_DIR}
@@ -16,7 +16,7 @@ set(eigen_INCLUDE_DIRS
${tensorflow_source_dir}/third_party/eigen3
)
set(eigen_URL https://bitbucket.org/eigen/eigen/get/${eigen_archive_hash}.tar.gz)
-set(eigen_HASH SHA256=532956172daa8aba87c750791ff89a5c38cdb07e2525afe17ecb4bef812d67cf)
+set(eigen_HASH SHA256=b4b5884b03bd4bae114d02b36e2435ad1504ed8e51431d16c876b6f6a365882b)
set(eigen_BUILD ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen)
set(eigen_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/eigen/install)
diff --git a/tensorflow/core/kernels/avgpooling_op.cc b/tensorflow/core/kernels/avgpooling_op.cc
index 4378dd2fa4..d666546602 100644
--- a/tensorflow/core/kernels/avgpooling_op.cc
+++ b/tensorflow/core/kernels/avgpooling_op.cc
@@ -100,10 +100,12 @@ class AvgPoolingOp : public UnaryOp<T> {
TensorFormat data_format_;
};
-REGISTER_KERNEL_BUILDER(Name("AvgPool")
- .Device(DEVICE_CPU)
- .TypeConstraint<float>("T"),
- AvgPoolingOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("AvgPool").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ AvgPoolingOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("AvgPool").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"),
+ AvgPoolingOp<CPUDevice, Eigen::half>);
#if GOOGLE_CUDA
template <typename T>
@@ -182,14 +184,17 @@ namespace functor {
const Eigen::PaddingType& padding); \
extern template struct SpatialAvgPooling<GPUDevice, T>;
+DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
#undef DECLARE_GPU_SPEC
} // namespace functor
-REGISTER_KERNEL_BUILDER(Name("AvgPool")
- .Device(DEVICE_GPU)
- .TypeConstraint<float>("T"),
- AvgPoolingOp<GPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("AvgPool").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
+ AvgPoolingOp<GPUDevice, Eigen::half>);
+REGISTER_KERNEL_BUILDER(
+ Name("AvgPool").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ AvgPoolingOp<GPUDevice, float>);
#endif // GOOGLE_CUDA
// The operation to compute AvgPool gradients.
@@ -301,7 +306,7 @@ class AvgPoolingGradOp : public OpKernel {
GetBroadcastSize(c, in_cols, window_cols, col_stride,
pad_cols, &cindex, &csize));
- T divide_coeff = 1.0 / (rsize * csize);
+ T divide_coeff(1.0 / (rsize * csize));
int64 output_index =
(b * out_backprop_rows + r) * out_backprop_cols + c;
for (int64 r_dst = rindex; r_dst < rindex + rsize; ++r_dst) {
@@ -347,6 +352,7 @@ class AvgPoolingGradOp : public OpKernel {
TF_CALL_float(REGISTER_CPU_KERNEL);
TF_CALL_double(REGISTER_CPU_KERNEL);
+TF_CALL_half(REGISTER_CPU_KERNEL);
#if GOOGLE_CUDA
@@ -416,6 +422,12 @@ REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
.HostMemory("orig_input_shape")
.Label("cudnn"),
AvgPoolingGradOp<GPUDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<Eigen::half>("T")
+ .HostMemory("orig_input_shape")
+ .Label("cudnn"),
+ AvgPoolingGradOp<GPUDevice, Eigen::half>);
// A custom GPU kernel based AvgPoolingGrad implementation. It includes the
// padding as the candidates for the pooling operation.
@@ -532,6 +544,11 @@ REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
.TypeConstraint<float>("T")
.HostMemory("orig_input_shape"),
AvgPoolingGradOpCustomGPUKernel<float>);
+REGISTER_KERNEL_BUILDER(Name("AvgPoolGrad")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<Eigen::half>("T")
+ .HostMemory("orig_input_shape"),
+ AvgPoolingGradOpCustomGPUKernel<Eigen::half>);
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc
index 9e894b1734..a190b2168a 100644
--- a/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/avgpooling_op_gpu.cu.cc
@@ -33,6 +33,7 @@ typedef Eigen::GpuDevice GPUDevice;
#define DEFINE_GPU_KERNELS(T) \
template struct functor::SpatialAvgPooling<GPUDevice, T>;
+DEFINE_GPU_KERNELS(Eigen::half)
DEFINE_GPU_KERNELS(float)
#undef DEFINE_GPU_KERNELS
@@ -57,7 +58,7 @@ __global__ void AvePoolBackwardNHWC(const int nthreads,
const int phend = min(h / stride_h + 1, pooled_height);
const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
const int pwend = min(w / stride_w + 1, pooled_width);
- dtype gradient = 0;
+ dtype gradient(0);
const dtype* const top_diff_slice =
top_diff + n * pooled_height * pooled_width * channels + c;
for (int ph = phstart; ph < phend; ++ph) {
@@ -104,6 +105,12 @@ template bool RunAvePoolBackwardNHWC(
const int pooled_width, const int kernel_h, const int kernel_w,
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
float* const bottom_diff, const GPUDevice& d);
+template bool RunAvePoolBackwardNHWC(
+ const Eigen::half* const top_diff, const int num, const int height,
+ const int width, const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h, const int kernel_w,
+ const int stride_h, const int stride_w, const int pad_t, const int pad_l,
+ Eigen::half* const bottom_diff, const GPUDevice& d);
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/eigen_pooling.h b/tensorflow/core/kernels/eigen_pooling.h
index 349cbf9d0e..aa3b274893 100644
--- a/tensorflow/core/kernels/eigen_pooling.h
+++ b/tensorflow/core/kernels/eigen_pooling.h
@@ -309,7 +309,7 @@ struct AvgPoolMeanReducer {
EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE AvgPoolMeanReducer() : scalarCount_(0) {
typedef typename packet_traits<T>::type Packet;
- packetCount_ = pset1<Packet>(0.0);
+ packetCount_ = pset1<Packet>(T(0.0));
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) {
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc
index 5e3f219699..f883acf3d6 100644
--- a/tensorflow/core/kernels/maxpooling_op.cc
+++ b/tensorflow/core/kernels/maxpooling_op.cc
@@ -160,7 +160,7 @@ static void SpatialMaxPoolWithArgMaxHelper(
const int in_end = limit * in_size;
EigenMatrixMap in_shard(input_backprop_flat.data() + in_start, 1,
in_end - in_start);
- in_shard.setConstant(0);
+ in_shard.setConstant(T(0));
// Backpropagate.
const int out_size = out_height * out_width * depth;
@@ -187,8 +187,12 @@ static void SpatialMaxPoolWithArgMaxHelper(
params.tensor_in_batch, shard_cost, shard);
}
-REGISTER_KERNEL_BUILDER(Name("MaxPool").Device(DEVICE_CPU),
- MaxPoolingOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ MaxPoolingOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPool").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"),
+ MaxPoolingOp<CPUDevice, Eigen::half>);
#if GOOGLE_CUDA
// Forward declarations for the functor specializations for GPU.
@@ -212,6 +216,7 @@ DECLARE_GPU_SPEC(float);
// kernel_label_map.
REGISTER_KERNEL_BUILDER(Name("MaxPool")
.Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
.Label("eigen_tensor"),
MaxPoolingOp<Eigen::GpuDevice, float>);
#endif // GOOGLE_CUDA
@@ -297,11 +302,16 @@ class MaxPoolingGradOp : public OpKernel {
TensorFormat data_format_;
};
-REGISTER_KERNEL_BUILDER(Name("MaxPoolGrad").Device(DEVICE_CPU),
- MaxPoolingGradOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPoolGrad").Device(DEVICE_CPU).TypeConstraint<float>("T"),
+ MaxPoolingGradOp<CPUDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPoolGrad").Device(DEVICE_CPU).TypeConstraint<Eigen::half>("T"),
+ MaxPoolingGradOp<CPUDevice, Eigen::half>);
#ifdef GOOGLE_CUDA
+template <typename T>
static void MaxPoolingBackwardCustomKernel(
OpKernelContext* context, const std::vector<int32>& size,
const std::vector<int32>& stride, Padding padding, const Tensor* tensor_in,
@@ -318,12 +328,12 @@ static void MaxPoolingBackwardCustomKernel(
}
MaxPoolBackwardNoMask(
- tensor_in->flat<float>().data(), params.tensor_in_batch,
+ tensor_in->flat<T>().data(), params.tensor_in_batch,
params.tensor_in_rows, params.tensor_in_cols, params.depth,
params.out_height, params.out_width, params.window_rows,
params.window_cols, params.row_stride, params.col_stride, params.pad_rows,
- params.pad_cols, out_backprop.flat<float>().data(),
- output->flat<float>().data(), context->eigen_device<Eigen::GpuDevice>());
+ params.pad_cols, out_backprop.flat<T>().data(),
+ output->flat<T>().data(), context->eigen_device<Eigen::GpuDevice>());
}
template <class T>
@@ -378,8 +388,8 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
} else {
CHECK(data_format_ == FORMAT_NHWC)
<< "Non-Cudnn MaxPoolGrad only supports NHWC format";
- MaxPoolingBackwardCustomKernel(context, ksize_, stride_, padding_,
- &tensor_in, out_backprop, output_shape);
+ MaxPoolingBackwardCustomKernel<T>(context, ksize_, stride_, padding_,
+ &tensor_in, out_backprop, output_shape);
}
}
@@ -391,8 +401,12 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
bool use_dnn_;
};
-REGISTER_KERNEL_BUILDER(Name("MaxPoolGrad").Device(DEVICE_GPU),
- MaxPoolingGradOp<Eigen::GpuDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPoolGrad").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ MaxPoolingGradOp<Eigen::GpuDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPoolGrad").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
+ MaxPoolingGradOp<Eigen::GpuDevice, Eigen::half>);
#endif // GOOGLE_CUDA
@@ -625,8 +639,12 @@ struct LaunchMaxPoolingNoMask<Eigen::GpuDevice, T> {
}
};
-REGISTER_KERNEL_BUILDER(Name("MaxPool").Device(DEVICE_GPU),
- MaxPoolingNoMaskOp<Eigen::GpuDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<float>("T"),
+ MaxPoolingNoMaskOp<Eigen::GpuDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<Eigen::half>("T"),
+ MaxPoolingNoMaskOp<Eigen::GpuDevice, Eigen::half>);
template <typename T>
struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
@@ -649,8 +667,14 @@ struct LaunchMaxPoolingWithArgmax<Eigen::GpuDevice, T> {
REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")
.Device(DEVICE_GPU)
- .TypeConstraint<int64>("Targmax"),
+ .TypeConstraint<int64>("Targmax")
+ .TypeConstraint<float>("T"),
MaxPoolingWithArgmaxOp<Eigen::GpuDevice, float>);
+REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<int64>("Targmax")
+ .TypeConstraint<Eigen::half>("T"),
+ MaxPoolingWithArgmaxOp<Eigen::GpuDevice, Eigen::half>);
template <typename T>
struct LaunchMaxPoolingGradWithArgmax<Eigen::GpuDevice, T> {
@@ -675,10 +699,18 @@ struct LaunchMaxPoolingGradWithArgmax<Eigen::GpuDevice, T> {
}
};
-REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax")
- .Device(DEVICE_GPU)
- .TypeConstraint<int64>("Targmax"),
- MaxPoolingGradWithArgmaxOp<Eigen::GpuDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPoolGradWithArgmax")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<float>("T")
+ .TypeConstraint<int64>("Targmax"),
+ MaxPoolingGradWithArgmaxOp<Eigen::GpuDevice, float>);
+REGISTER_KERNEL_BUILDER(
+ Name("MaxPoolGradWithArgmax")
+ .Device(DEVICE_GPU)
+ .TypeConstraint<Eigen::half>("T")
+ .TypeConstraint<int64>("Targmax"),
+ MaxPoolingGradWithArgmaxOp<Eigen::GpuDevice, Eigen::half>);
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
index 1bdca42f4e..91b50b1e11 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.cu.cc
@@ -110,7 +110,7 @@ __global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data,
int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
- dtype maxval = -FLT_MAX;
+ dtype maxval = Eigen::NumTraits<dtype>::lowest();
int maxidx = -1;
const dtype* bottom_data_n = bottom_data + n * height * width * channels;
for (int h = hstart; h < hend; ++h) {
@@ -149,7 +149,7 @@ __global__ void MaxPoolBackwardNoMaskNHWC(
int wend = min(wstart + kernel_w, width);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
- dtype maxval = -FLT_MAX;
+ dtype maxval = Eigen::NumTraits<dtype>::lowest();
int maxidx = -1;
const dtype* bottom_data_n = bottom_data + n * height * width * channels;
for (int h = hstart; h < hend; ++h) {
@@ -165,8 +165,8 @@ __global__ void MaxPoolBackwardNoMaskNHWC(
// Atomically accumulate the bottom diff. The index could still be
// uninitialized, if all the bottom_data are NaN.
if (maxidx != -1) {
- atomicAdd(bottom_diff + n * height * width * channels + maxidx,
- top_diff[index]);
+ CudaAtomicAdd(bottom_diff + n * height * width * channels + maxidx,
+ top_diff[index]);
}
}
}
@@ -185,8 +185,8 @@ __global__ void MaxPoolBackwardNoMaskNHWC(
// bottom_offset: the pre-computed per-image offset of the maxpool input.
// This is equal to H*W*C.
// bottom_diff: the gradient with respect to the input.
-// This function relies on atomicAdd to avoid race conditions. Also, before the
-// kernel is run, you will need to make sure that bottom_diff is filled with
+// This function relies on CudaAtomicAdd to avoid race conditions. Also, before
+// the kernel is run, you will need to make sure that bottom_diff is filled with
// zero first.
template <typename dtype>
__global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
@@ -194,8 +194,8 @@ __global__ void MaxPoolBackward(const int nthreads, const dtype* top_diff,
const int bottom_offset, dtype* bottom_diff) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int image_id = (index / top_offset);
- atomicAdd(bottom_diff + image_id * bottom_offset + mask[index],
- top_diff[index]);
+ CudaAtomicAdd(bottom_diff + image_id * bottom_offset + mask[index],
+ top_diff[index]);
}
}
@@ -219,6 +219,23 @@ bool MaxPoolForwardWithOptionalArgmax(
return d.ok();
}
+bool MaxPoolForwardWithOptionalArgmax(
+ const Eigen::half* bottom_data, const int batch, const int height,
+ const int width, const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h, const int kernel_w,
+ const int stride_h, const int stride_w, const int pad_t, const int pad_l,
+ Eigen::half* top_data, int64* mask, const Eigen::GpuDevice& d) {
+ const int kThreadsPerBlock = 1024;
+ const int output_size = batch * channels * pooled_height * pooled_width;
+
+ MaxPoolForwardNHWC<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(
+ output_size, bottom_data, height, width, channels, pooled_height,
+ pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
+ top_data, mask);
+ return d.ok();
+}
+
bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
const int height, const int width,
const int channels, const int pooled_height,
@@ -243,6 +260,30 @@ bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
return d.ok();
}
+bool MaxPoolBackwardNoMask(const Eigen::half* bottom_data, const int batch,
+ const int height, const int width,
+ const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h,
+ const int kernel_w, const int stride_h,
+ const int stride_w, const int pad_t, const int pad_l,
+ const Eigen::half* top_diff, Eigen::half* bottom_diff,
+ const Eigen::GpuDevice& d) {
+ const int kThreadsPerBlock = 1024;
+ const int bottom_size = batch * channels * height * width;
+ const int top_size = batch * channels * pooled_height * pooled_width;
+
+ SetZero<<<(bottom_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(bottom_size, bottom_diff);
+
+ MaxPoolBackwardNoMaskNHWC<<<(top_size + kThreadsPerBlock - 1) /
+ kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(
+ top_size, bottom_data, height, width, channels, pooled_height,
+ pooled_width, kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l,
+ top_diff, bottom_diff);
+ return d.ok();
+}
+
bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
const float* top_diff, const int64* mask,
const int top_offset, const int bottom_offset,
@@ -256,12 +297,27 @@ bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
return d.ok();
}
+bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
+ const Eigen::half* top_diff, const int64* mask,
+ const int top_offset, const int bottom_offset,
+ Eigen::half* bottom_diff,
+ const Eigen::GpuDevice& d) {
+ const int kThreadsPerBlock = 1024;
+ SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(input_size, bottom_diff);
+ MaxPoolBackward<<<(output_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
+ kThreadsPerBlock, 0, d.stream()>>>(
+ output_size, top_diff, mask, top_offset, bottom_offset, bottom_diff);
+ return d.ok();
+}
+
typedef Eigen::GpuDevice GPUDevice;
#define DEFINE_GPU_KERNELS(T) \
template struct functor::SpatialMaxPooling<GPUDevice, T>;
DEFINE_GPU_KERNELS(float)
+DEFINE_GPU_KERNELS(Eigen::half)
#undef DEFINE_GPU_KERNELS
diff --git a/tensorflow/core/kernels/maxpooling_op_gpu.h b/tensorflow/core/kernels/maxpooling_op_gpu.h
index 05e865f81c..d1c73a372e 100644
--- a/tensorflow/core/kernels/maxpooling_op_gpu.h
+++ b/tensorflow/core/kernels/maxpooling_op_gpu.h
@@ -37,11 +37,24 @@ bool MaxPoolForwardWithOptionalArgmax(
const int stride_h, const int stride_w, const int pad_t, const int pad_l,
float* top_data, int64* mask, const Eigen::GpuDevice& d);
+bool MaxPoolForwardWithOptionalArgmax(
+ const Eigen::half* bottom_data, const int batch, const int height,
+ const int width, const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h, const int kernel_w,
+ const int stride_h, const int stride_w, const int pad_t, const int pad_l,
+ Eigen::half* top_data, int64* mask, const Eigen::GpuDevice& d);
+
bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
const float* top_diff, const int64* mask,
const int top_offset, const int bottom_offset,
float* bottom_diff, const Eigen::GpuDevice& d);
+bool MaxPoolBackwardWithArgmax(const int output_size, const int input_size,
+ const Eigen::half* top_diff, const int64* mask,
+ const int top_offset, const int bottom_offset,
+ Eigen::half* bottom_diff,
+ const Eigen::GpuDevice& d);
+
bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
const int height, const int width,
const int channels, const int pooled_height,
@@ -51,6 +64,15 @@ bool MaxPoolBackwardNoMask(const float* bottom_data, const int batch,
const float* top_diff, float* bottom_diff,
const Eigen::GpuDevice& d);
+bool MaxPoolBackwardNoMask(const Eigen::half* bottom_data, const int batch,
+ const int height, const int width,
+ const int channels, const int pooled_height,
+ const int pooled_width, const int kernel_h,
+ const int kernel_w, const int stride_h,
+ const int stride_w, const int pad_t, const int pad_l,
+ const Eigen::half* top_diff, Eigen::half* bottom_diff,
+ const Eigen::GpuDevice& d);
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MAXPOOLING_OP_GPU_H_
diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc
index 3867cc824f..f5d7771af7 100644
--- a/tensorflow/core/kernels/pooling_ops_common.cc
+++ b/tensorflow/core/kernels/pooling_ops_common.cc
@@ -124,6 +124,7 @@ namespace functor {
extern template struct TransformDepth<GPUDevice, T, Eigen::DenseIndex>;
DECLARE_GPU_SPEC(float);
+DECLARE_GPU_SPEC(Eigen::half);
#undef DECLARE_GPU_SPEC
} // namespace functor
@@ -368,7 +369,9 @@ void DnnPoolingGradOp<T>::Compute(
}
}
+template class DnnPoolingOp<Eigen::half>;
template class DnnPoolingOp<float>;
+template class DnnPoolingGradOp<Eigen::half>;
template class DnnPoolingGradOp<float>;
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/pooling_ops_common.h b/tensorflow/core/kernels/pooling_ops_common.h
index 138d1cb2ca..593c90b009 100644
--- a/tensorflow/core/kernels/pooling_ops_common.h
+++ b/tensorflow/core/kernels/pooling_ops_common.h
@@ -311,7 +311,7 @@ void SpatialAvgPool(OpKernelContext* context, Tensor* output,
}
}
}
- DCHECK_GT(out_count.minCoeff(), 0);
+ DCHECK_GT(out_count.minCoeff(), T(0));
out_mat.array().rowwise() /= out_count.transpose().array();
}
diff --git a/tensorflow/core/ops/compat/ops_history.v0.pbtxt b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
index ed60c227a5..3224a1c9af 100644
--- a/tensorflow/core/ops/compat/ops_history.v0.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v0.pbtxt
@@ -3012,6 +3012,63 @@ op {
}
}
op {
+ name: "AvgPool"
+ input_arg {
+ name: "value"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksize"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "AvgPool3D"
input_arg {
name: "input"
@@ -3233,6 +3290,67 @@ op {
}
}
op {
+ name: "AvgPoolGrad"
+ input_arg {
+ name: "orig_input_shape"
+ type: DT_INT32
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksize"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ type: DT_DOUBLE
+ }
+ }
+ }
+}
+op {
name: "BatchCholesky"
input_arg {
name: "input"
@@ -11802,6 +11920,124 @@ op {
}
}
op {
+ name: "MaxPool"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksize"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ }
+ }
+ }
+}
+op {
+ name: "MaxPool"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ }
+ }
+ }
+ attr {
+ name: "ksize"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+}
+op {
name: "MaxPool3D"
input_arg {
name: "input"
@@ -12015,6 +12251,73 @@ op {
}
}
op {
+ name: "MaxPoolGrad"
+ input_arg {
+ name: "orig_input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "orig_output"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksize"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ }
+ }
+ }
+}
+op {
name: "MaxPoolGradWithArgmax"
input_arg {
name: "input"
@@ -12066,6 +12369,70 @@ op {
}
}
op {
+ name: "MaxPoolGradWithArgmax"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "argmax"
+ type_attr: "Targmax"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "ksize"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "Targmax"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ }
+ }
+ }
+}
+op {
name: "MaxPoolWithArgmax"
input_arg {
name: "input"
@@ -12116,6 +12483,69 @@ op {
}
}
op {
+ name: "MaxPoolWithArgmax"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "argmax"
+ type_attr: "Targmax"
+ }
+ attr {
+ name: "ksize"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "strides"
+ type: "list(int)"
+ has_minimum: true
+ minimum: 4
+ }
+ attr {
+ name: "Targmax"
+ type: "type"
+ default_value {
+ type: DT_INT64
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "padding"
+ type: "string"
+ allowed_values {
+ list {
+ s: "SAME"
+ s: "VALID"
+ }
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ }
+ }
+ }
+}
+op {
name: "Maximum"
input_arg {
name: "x"
diff --git a/tensorflow/core/ops/nn_grad.cc b/tensorflow/core/ops/nn_grad.cc
index c1a42e74be..e3b876b240 100644
--- a/tensorflow/core/ops/nn_grad.cc
+++ b/tensorflow/core/ops/nn_grad.cc
@@ -154,22 +154,25 @@ Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
*g = FDH::Define(
// Arg defs
- {"input: float", "grad: float"},
+ {"input: T", "grad: T"},
// Ret val defs
- {"output: float"},
+ {"output: T"},
// Attr defs
- {"ksize: list(int) >= 4",
+ {"T: {float, half} = DT_FLOAT",
+ "ksize: list(int) >= 4",
"strides: list(int) >= 4",
GetPaddingAttrString()},
// Nodes
{
// Invoke MaxPool again to recompute the outputs (removed by CSE?).
{{"maxpool"}, "MaxPool", {"input"},
- /*Attrs=*/{{"ksize", "$ksize"},
+ /*Attrs=*/{{"T", "$T"},
+ {"ksize", "$ksize"},
{"strides", "$strides"},
{"padding", "$padding"}}},
{{"output"}, "MaxPoolGrad", {"input", "maxpool", "grad"},
- /*Attrs=*/{{"ksize", "$ksize"},
+ /*Attrs=*/{{"T", "$T"},
+ {"ksize", "$ksize"},
{"strides", "$strides"},
{"padding", "$padding"}}}
});
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index fee145be53..b53945a4a0 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -28,7 +28,7 @@ REGISTER_OP("AvgPool")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
- .Attr("T: {float, double}")
+ .Attr("T: {float, half, double}")
.Doc(R"doc(
Performs average pooling on the input.
@@ -55,7 +55,7 @@ REGISTER_OP("AvgPoolGrad")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
- .Attr("T: {float, double}")
+ .Attr("T: {float, half, double}")
.Doc(R"doc(
Computes gradients of the average pooling function.
@@ -642,12 +642,13 @@ output: The gradients for LRN.
// --------------------------------------------------------------------------
REGISTER_OP("MaxPool")
+ .Attr("T: {float, half} = DT_FLOAT")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
- .Input("input: float")
- .Output("output: float")
+ .Input("input: T")
+ .Output("output: T")
.Doc(R"doc(
Performs max pooling on the input.
@@ -669,10 +670,11 @@ REGISTER_OP("MaxPoolGrad")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
- .Input("orig_input: float")
- .Input("orig_output: float")
- .Input("grad: float")
- .Output("output: float")
+ .Input("orig_input: T")
+ .Input("orig_output: T")
+ .Input("grad: T")
+ .Output("output: T")
+ .Attr("T: {float, half} = DT_FLOAT")
.Doc(R"doc(
Computes gradients of the maxpooling function.
@@ -696,9 +698,10 @@ REGISTER_OP("MaxPoolWithArgmax")
.Attr("strides: list(int) >= 4")
.Attr("Targmax: {int32, int64} = DT_INT64")
.Attr(GetPaddingAttrString())
- .Input("input: float")
- .Output("output: float")
+ .Input("input: T")
+ .Output("output: T")
.Output("argmax: Targmax")
+ .Attr("T: {float, half} = DT_FLOAT")
.Doc(R"doc(
Performs max pooling on the input and outputs both max values and indices.
@@ -720,10 +723,11 @@ REGISTER_OP("MaxPoolGradWithArgmax")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr("Targmax: {int32, int64}")
- .Input("input: float")
- .Input("grad: float")
+ .Input("input: T")
+ .Input("grad: T")
.Input("argmax: Targmax")
- .Output("output: float")
+ .Output("output: T")
+ .Attr("T: {float, half} = DT_FLOAT")
.Doc(R"doc(
Computes gradients of the maxpooling function.
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 5fb34e79d1..18624418cb 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -1251,6 +1251,7 @@ op {
allowed_values {
list {
type: DT_FLOAT
+ type: DT_HALF
type: DT_DOUBLE
}
}
@@ -1447,6 +1448,7 @@ op {
allowed_values {
list {
type: DT_FLOAT
+ type: DT_HALF
type: DT_DOUBLE
}
}
@@ -6614,12 +6616,25 @@ op {
input_arg {
name: "input"
description: "4-D input to pool over."
- type: DT_FLOAT
+ type_attr: "T"
}
output_arg {
name: "output"
description: "The max pooled output tensor."
- type: DT_FLOAT
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ }
+ }
}
attr {
name: "ksize"
@@ -6798,22 +6813,22 @@ op {
input_arg {
name: "orig_input"
description: "The original input tensor."
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "orig_output"
description: "The original output tensor."
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "grad"
description: "4-D. Gradients w.r.t. the output of `max_pool`."
- type: DT_FLOAT
+ type_attr: "T"
}
output_arg {
name: "output"
description: "Gradients w.r.t. the input to `max_pool`."
- type: DT_FLOAT
+ type_attr: "T"
}
attr {
name: "ksize"
@@ -6854,6 +6869,19 @@ op {
}
}
}
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ }
+ }
+ }
summary: "Computes gradients of the maxpooling function."
}
op {
@@ -6861,12 +6889,12 @@ op {
input_arg {
name: "input"
description: "The original input."
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "grad"
description: "4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the\noutput of `max_pool`."
- type: DT_FLOAT
+ type_attr: "T"
}
input_arg {
name: "argmax"
@@ -6876,7 +6904,7 @@ op {
output_arg {
name: "output"
description: "Gradients w.r.t. the input of `max_pool`."
- type: DT_FLOAT
+ type_attr: "T"
}
attr {
name: "ksize"
@@ -6913,6 +6941,19 @@ op {
}
}
}
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ }
+ }
+ }
summary: "Computes gradients of the maxpooling function."
}
op {
@@ -6920,12 +6961,12 @@ op {
input_arg {
name: "input"
description: "4-D with shape `[batch, height, width, channels]`. Input to pool over."
- type: DT_FLOAT
+ type_attr: "T"
}
output_arg {
name: "output"
description: "The max pooled output tensor."
- type: DT_FLOAT
+ type_attr: "T"
}
output_arg {
name: "argmax"
@@ -6970,6 +7011,19 @@ op {
}
}
}
+ attr {
+ name: "T"
+ type: "type"
+ default_value {
+ type: DT_FLOAT
+ }
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_HALF
+ }
+ }
+ }
summary: "Performs max pooling on the input and outputs both max values and indices."
description: "The indices in `argmax` are flattened, so that a maximum value at position\n`[b, y, x, c]` becomes flattened index\n`((b * height + y) * width + x) * channels + c`."
}
diff --git a/tensorflow/python/kernel_tests/pooling_ops_test.py b/tensorflow/python/kernel_tests/pooling_ops_test.py
index 333bfa17f9..011078036d 100644
--- a/tensorflow/python/kernel_tests/pooling_ops_test.py
+++ b/tensorflow/python/kernel_tests/pooling_ops_test.py
@@ -99,8 +99,8 @@ def GetShrunkInceptionMaxPoolShapes(shrink=30):
class PoolingTest(tf.test.TestCase):
- def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
- data_format, expected, use_gpu):
+ def _VerifyOneType(self, pool_func, input_sizes, ksize, strides, padding,
+ data_format, data_type, expected, use_gpu):
"""Verifies the output values of the pooling function.
Args:
@@ -111,6 +111,7 @@ class PoolingTest(tf.test.TestCase):
strides: The stride dimensions
padding: Padding type.
data_format: The data format we use to run the pooling operation.
+ data_type: The data type to use to run the pooling operation.
expected: An array containing the expected operation outputs.
use_gpu: Whether we are running on GPU.
"""
@@ -121,7 +122,7 @@ class PoolingTest(tf.test.TestCase):
# numbers from 1.
x = [f * 1.0 for f in range(1, total_size + 1)]
with self.test_session(use_gpu=use_gpu) as sess:
- t = tf.constant(x, shape=input_sizes)
+ t = tf.constant(x, shape=input_sizes, dtype=data_type)
if data_format == "NCHW":
t = NHWCToNCHW(t)
ksize = NHWCToNCHW(ksize)
@@ -131,9 +132,31 @@ class PoolingTest(tf.test.TestCase):
if data_format == "NCHW":
t = NCHWToNHWC(t)
actual = t.eval()
- self.assertAllClose(expected, actual.flatten())
+ self.assertAllCloseAccordingToType(expected, actual.flatten())
self.assertShapeEqual(actual, t)
+ def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
+ data_format, expected, use_gpu):
+ """Verifies pooling outputs for float32 and, where supported, float16.
+
+ Args:
+ pool_func: Function to be called, co.MaxPool, co.AvgPool,
+ or the Lua version.
+ input_sizes: Input tensor dimensions.
+ ksize: The kernel size dimensions
+ strides: The stride dimensions
+ padding: Padding type.
+ data_format: The data format we use to run the pooling operation.
+ expected: An array containing the expected operation outputs.
+ use_gpu: Whether we are running on GPU.
+ """
+ self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
+ data_format, tf.float32, expected, use_gpu)
+
+ if not use_gpu or test_util.CudaSupportsHalfMatMulAndConv():
+ self._VerifyOneType(pool_func, input_sizes, ksize, strides, padding,
+ data_format, tf.float16, expected, use_gpu)
+
def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding,
expected, use_gpu):
"""Verifies the output values of the pooling function.
@@ -372,32 +395,40 @@ class PoolingTest(tf.test.TestCase):
def testKernelSmallerThanStrideValid(self):
for use_gpu in [True, False]:
- self._VerifyValues(tf.nn.max_pool, input_sizes=[1, 7, 7, 1],
- ksize=[1, 2, 2, 1], strides=[1, 3, 3, 1],
- padding="VALID",
- expected=[9, 12, 30, 33],
- use_gpu=use_gpu)
-
- self._VerifyValues(tf.nn.avg_pool, input_sizes=[1, 7, 7, 1],
- ksize=[1, 2, 2, 1], strides=[1, 3, 3, 1],
- padding="VALID",
- expected=[5, 8, 26, 29],
- use_gpu=use_gpu)
+ self._VerifyValues(tf.nn.max_pool,
+ input_sizes=[1, 7, 7, 1],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 3, 3, 1],
+ padding="VALID",
+ expected=[9, 12, 30, 33],
+ use_gpu=use_gpu)
+
+ self._VerifyValues(tf.nn.avg_pool,
+ input_sizes=[1, 7, 7, 1],
+ ksize=[1, 2, 2, 1],
+ strides=[1, 3, 3, 1],
+ padding="VALID",
+ expected=[5, 8, 26, 29],
+ use_gpu=use_gpu)
def testKernelSmallerThanStrideSame(self):
for use_gpu in [True, False]:
- for pool_func in [tf.nn.max_pool, tf.nn.avg_pool]:
- self._VerifyValues(pool_func, input_sizes=[1, 3, 3, 1],
- ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1],
- padding="SAME",
- expected=[1, 3, 7, 9],
- use_gpu=use_gpu)
-
- self._VerifyValues(pool_func, input_sizes=[1, 4, 4, 1],
- ksize=[1, 1, 1, 1], strides=[1, 2, 2, 1],
- padding="SAME",
- expected=[1, 3, 9, 11],
- use_gpu=use_gpu)
+ for pool_func in [tf.nn.max_pool, tf.nn.avg_pool]:
+ self._VerifyValues(pool_func,
+ input_sizes=[1, 3, 3, 1],
+ ksize=[1, 1, 1, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=[1, 3, 7, 9],
+ use_gpu=use_gpu)
+
+ self._VerifyValues(pool_func,
+ input_sizes=[1, 4, 4, 1],
+ ksize=[1, 1, 1, 1],
+ strides=[1, 2, 2, 1],
+ padding="SAME",
+ expected=[1, 3, 9, 11],
+ use_gpu=use_gpu)
def _testDepthwiseMaxPoolInvalidConfig(self, in_size, ksize, strides,
error_msg, use_gpu=False):
@@ -425,43 +456,49 @@ class PoolingTest(tf.test.TestCase):
# The following are tests that verify that the CPU and GPU implementations
# produce the same resuts.
def _CompareMaxPoolingFwd(self, input_shape, ksize, strides, padding):
- tensor_input = np.random.rand(*input_shape).astype(np.float32)
- with self.test_session(use_gpu=True):
- t = tf.constant(tensor_input, shape=input_shape)
- out_op, _ = tf.nn.max_pool_with_argmax(t, ksize, strides, padding)
- gpu_val = out_op.eval()
- with self.test_session(use_gpu=False):
- t = tf.constant(tensor_input, shape=input_shape)
- out_op = tf.nn.max_pool(t, ksize, strides, padding)
- cpu_val = out_op.eval()
- self.assertAllClose(cpu_val, gpu_val, rtol=1e-5, atol=1e-5)
+ for dtype in np.float32, np.float16:
+ tensor_input = np.random.rand(*input_shape).astype(dtype)
+ with self.test_session(use_gpu=True):
+ t = tf.constant(tensor_input, shape=input_shape)
+ out_op, _ = tf.nn.max_pool_with_argmax(t, ksize, strides, padding)
+ gpu_val = out_op.eval()
+ with self.test_session(use_gpu=False):
+ t = tf.constant(tensor_input, shape=input_shape)
+ out_op = tf.nn.max_pool(t, ksize, strides, padding)
+ cpu_val = out_op.eval()
+ self.assertAllCloseAccordingToType(cpu_val, gpu_val)
def _CompareMaxPoolingBk(self, input_shape, output_shape, ksize, strides,
padding):
- # Generate numbers in a narrow range, so that there are many duplicates
- # in the input.
- tensor_input = np.random.random_integers(0, 3,
- input_shape).astype(np.float32)
- tensor_output = np.random.rand(*output_shape).astype(np.float32)
- with self.test_session(use_gpu=True):
- t = tf.constant(tensor_input, shape=input_shape)
- _, argmax_op = tf.nn.max_pool_with_argmax(t, ksize, strides, padding)
- argmax = argmax_op.eval()
- grad_in = tf.constant(tensor_output, shape=output_shape)
- out_op = gen_nn_ops._max_pool_grad_with_argmax(t, grad_in, argmax,
- ksize, strides, padding)
- gpu_val = out_op.eval()
- self.assertShapeEqual(gpu_val, out_op)
- with self.test_session(use_gpu=False):
- t = tf.constant(tensor_input, shape=input_shape)
- out_op = tf.nn.max_pool(t, ksize, strides, padding)
- orig_out = out_op.eval()
- grad_in = tf.constant(tensor_output, shape=output_shape)
- out_op = gen_nn_ops._max_pool_grad(t, orig_out, grad_in, ksize,
- strides, padding)
- cpu_val = out_op.eval()
- self.assertShapeEqual(cpu_val, out_op)
- self.assertAllClose(cpu_val, gpu_val, rtol=1e-5, atol=1e-5)
+ for dtype in np.float32, np.float16:
+ # Generate numbers in a narrow range, so that there are many duplicates
+ # in the input.
+ tensor_input = np.random.random_integers(0, 3, input_shape).astype(dtype)
+ tensor_output = np.random.rand(*output_shape).astype(dtype)
+ with self.test_session(use_gpu=True):
+ t = tf.constant(tensor_input, shape=input_shape)
+ _, argmax_op = tf.nn.max_pool_with_argmax(t, ksize, strides, padding)
+ argmax = argmax_op.eval()
+ grad_in = tf.constant(tensor_output, shape=output_shape)
+ out_op = gen_nn_ops._max_pool_grad_with_argmax(t, grad_in, argmax,
+ ksize, strides, padding)
+ gpu_val = out_op.eval()
+ self.assertShapeEqual(gpu_val, out_op)
+ with self.test_session(use_gpu=False):
+ t = tf.constant(tensor_input, shape=input_shape)
+ out_op = tf.nn.max_pool(t, ksize, strides, padding)
+ orig_out = out_op.eval()
+ grad_in = tf.constant(tensor_output, shape=output_shape)
+ out_op = gen_nn_ops._max_pool_grad(t, orig_out, grad_in, ksize, strides,
+ padding)
+ cpu_val = out_op.eval()
+ self.assertShapeEqual(cpu_val, out_op)
+ if dtype == np.float16:
+ # The CPU version accumulates its gradient in fp16, so it is less
+ # accurate than the GPU version, which accumulates in fp32.
+ self.assertAllClose(cpu_val, gpu_val, rtol=0.01, atol=0.01)
+ else:
+ self.assertAllClose(cpu_val, gpu_val)
def testMaxPoolingWithArgmax(self):
# MaxPoolWithArgMax is implemented only on GPU.
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 23a8066e79..9d860e59a2 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -1876,6 +1876,40 @@ bool CudnnSupport::DoPoolForward(
return true;
}
+bool CudnnSupport::DoPoolForward(
+ Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
+ const dnn::BatchDescriptor& input_dimensions,
+ const DeviceMemory<Eigen::half>& input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<Eigen::half>* output_data) {
+ mutex_lock lock{dnn_handle_mutex_};
+ auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
+ AsCUDAStreamValue(stream));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
+ return false;
+ }
+
+ // Alpha is the scaling factor for input.
+ float alpha = 1.0;
+ // Beta is the scaling factor for output.
+ float beta = 0.0;
+
+ ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF};
+ ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF};
+ ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
+ status = dynload::cudnnPoolingForward(
+ parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
+ src_desc.handle(), input_data.opaque(), &beta, dest_desc.handle(),
+ output_data->opaque());
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to enqueue forward pooling on stream: "
+ << ToString(status);
+ return false;
+ }
+ return true;
+}
+
bool CudnnSupport::DoPoolBackward(
Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
@@ -1914,6 +1948,43 @@ bool CudnnSupport::DoPoolBackward(
return true;
}
+bool CudnnSupport::DoPoolBackward(
+ Stream* stream, const dnn::PoolingDescriptor& pooling_dimensions,
+ const dnn::BatchDescriptor& input_dimensions,
+ const DeviceMemory<Eigen::half>& input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ const DeviceMemory<Eigen::half>& output_data,
+ const DeviceMemory<Eigen::half>& input_diff_data,
+ DeviceMemory<Eigen::half>* output_diff_data) {
+ mutex_lock lock{dnn_handle_mutex_};
+ auto status = dynload::cudnnSetStream(parent_, ToHandle(dnn_handle_),
+ AsCUDAStreamValue(stream));
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to set stream for cudnn handle: " << ToString(status);
+ return false;
+ }
+
+ // Alpha is the scaling factor for input.
+ float alpha = 1.0;
+ // Beta is the scaling factor for output.
+ float beta = 0.0;
+
+ ScopedTensorDescriptor src_desc{parent_, input_dimensions, CUDNN_DATA_HALF};
+ ScopedTensorDescriptor dest_desc{parent_, output_dimensions, CUDNN_DATA_HALF};
+ ScopedPoolingDescriptor pooling_desc{parent_, pooling_dimensions};
+ status = dynload::cudnnPoolingBackward(
+ parent_, ToHandle(dnn_handle_), pooling_desc.handle(), &alpha,
+ dest_desc.handle(), output_data.opaque(), dest_desc.handle(),
+ input_diff_data.opaque(), src_desc.handle(), input_data.opaque(), &beta,
+ src_desc.handle(), output_diff_data->opaque());
+ if (status != CUDNN_STATUS_SUCCESS) {
+ LOG(ERROR) << "failed to enqueue backward pooling on stream: "
+ << ToString(status);
+ return false;
+ }
+ return true;
+}
+
bool CudnnSupport::DoNormalize(
Stream* stream, const dnn::NormalizeDescriptor& normalize_descriptor,
const DeviceMemory<float>& input_data, DeviceMemory<float>* output_data) {
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 523a0c6c5d..434ab730a7 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -201,6 +201,13 @@ class CudnnSupport : public dnn::DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) override;
+ bool DoPoolForward(Stream* stream,
+ const dnn::PoolingDescriptor& pooling_dimensions,
+ const dnn::BatchDescriptor& input_dimensions,
+ const DeviceMemory<Eigen::half>& input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<Eigen::half>* output_data) override;
+
bool DoPoolBackward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
const dnn::BatchDescriptor& input_dimensions,
@@ -210,6 +217,15 @@ class CudnnSupport : public dnn::DnnSupport {
const DeviceMemory<float>& input_diff_data,
DeviceMemory<float>* output_diff_data) override;
+ bool DoPoolBackward(Stream* stream,
+ const dnn::PoolingDescriptor& pooling_dimensions,
+ const dnn::BatchDescriptor& input_dimensions,
+ const DeviceMemory<Eigen::half>& input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ const DeviceMemory<Eigen::half>& output_data,
+ const DeviceMemory<Eigen::half>& input_diff_data,
+ DeviceMemory<Eigen::half>* output_diff_data) override;
+
bool DoNormalize(Stream* stream,
const dnn::NormalizeDescriptor& normalize_descriptor,
const DeviceMemory<float>& input_data,
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index fbb44dc739..0ae482a73c 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -1011,6 +1011,13 @@ class DnnSupport {
const dnn::BatchDescriptor& output_dimensions,
DeviceMemory<float>* output_data) = 0;
+ virtual bool DoPoolForward(Stream* stream,
+ const dnn::PoolingDescriptor& pooling_dimensions,
+ const dnn::BatchDescriptor& input_dimensions,
+ const DeviceMemory<Eigen::half>& input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ DeviceMemory<Eigen::half>* output_data) = 0;
+
// Performs differentiation of the pooling operation.
virtual bool DoPoolBackward(Stream* stream,
const dnn::PoolingDescriptor& pooling_dimensions,
@@ -1021,6 +1028,15 @@ class DnnSupport {
const DeviceMemory<float>& input_diff_data,
DeviceMemory<float>* output_diff_data) = 0;
+ virtual bool DoPoolBackward(Stream* stream,
+ const dnn::PoolingDescriptor& pooling_dimensions,
+ const dnn::BatchDescriptor& input_dimensions,
+ const DeviceMemory<Eigen::half>& input_data,
+ const dnn::BatchDescriptor& output_dimensions,
+ const DeviceMemory<Eigen::half>& output_data,
+ const DeviceMemory<Eigen::half>& input_diff_data,
+ DeviceMemory<Eigen::half>* output_diff_data) = 0;
+
// Applies local response normalization to the values from
// input_data and writes the result to output_data. See comments on
// NormalizeDescriptor for a description of local response
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index 446a3c9a7d..be823d9500 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -909,6 +909,30 @@ Stream &Stream::ThenPoolForward(
return *this;
}
+Stream &Stream::ThenPoolForward(
+ const dnn::PoolingDescriptor &pooling_dimensions,
+ const dnn::BatchDescriptor &input_dimensions,
+ const DeviceMemory<Eigen::half> &input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<Eigen::half> *output_data) {
+ VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
+ PARAM(input_data), PARAM(output_dimensions), PARAM(output_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoPoolForward(this, pooling_dimensions, input_dimensions,
+ input_data, output_dimensions,
+ output_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
Stream &Stream::ThenPoolBackward(
const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
@@ -936,6 +960,33 @@ Stream &Stream::ThenPoolBackward(
return *this;
}
+Stream &Stream::ThenPoolBackward(
+ const dnn::PoolingDescriptor &pooling_dimensions,
+ const dnn::BatchDescriptor &input_dimensions,
+ const DeviceMemory<Eigen::half> &input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ const DeviceMemory<Eigen::half> &output_data,
+ const DeviceMemory<Eigen::half> &input_diff_data,
+ DeviceMemory<Eigen::half> *output_diff_data) {
+ VLOG_CALL(PARAM(pooling_dimensions), PARAM(input_dimensions),
+ PARAM(input_data), PARAM(output_dimensions), PARAM(output_data),
+ PARAM(input_diff_data), PARAM(output_diff_data));
+
+ if (ok()) {
+ if (dnn::DnnSupport *dnn = parent_->AsDnn()) {
+ CheckError(dnn->DoPoolBackward(this, pooling_dimensions, input_dimensions,
+ input_data, output_dimensions, output_data,
+ input_diff_data, output_diff_data));
+ } else {
+ SetError();
+ LOG(WARNING)
+ << "attempting to perform DNN operation using StreamExecutor "
+ "without DNN support";
+ }
+ }
+ return *this;
+}
+
Stream &Stream::ThenNormalize(
const dnn::NormalizeDescriptor &normalize_descriptor,
const DeviceMemory<float> &input_data, DeviceMemory<float> *output_data) {
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index aac945c9e0..c131250de1 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -421,6 +421,12 @@ class Stream {
const dnn::BatchDescriptor &output_dimensions,
DeviceMemory<float> *output_data);
+ Stream &ThenPoolForward(const dnn::PoolingDescriptor &pooling_dimensions,
+ const dnn::BatchDescriptor &input_dimensions,
+ const DeviceMemory<Eigen::half> &input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ DeviceMemory<Eigen::half> *output_data);
+
Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
const dnn::BatchDescriptor &input_dimensions,
const DeviceMemory<float> &input_data,
@@ -429,6 +435,14 @@ class Stream {
const DeviceMemory<float> &input_diff_data,
DeviceMemory<float> *output_diff_data);
+ Stream &ThenPoolBackward(const dnn::PoolingDescriptor &pooling_dimensions,
+ const dnn::BatchDescriptor &input_dimensions,
+ const DeviceMemory<Eigen::half> &input_data,
+ const dnn::BatchDescriptor &output_dimensions,
+ const DeviceMemory<Eigen::half> &output_data,
+ const DeviceMemory<Eigen::half> &input_diff_data,
+ DeviceMemory<Eigen::half> *output_diff_data);
+
Stream &ThenNormalize(const dnn::NormalizeDescriptor &normalize_descriptor,
const DeviceMemory<float> &input_data,
DeviceMemory<float> *output_data);
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 07f83651e0..d9cfb85fc3 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -6,8 +6,8 @@
def tf_workspace(path_prefix = "", tf_repo_name = ""):
native.new_http_archive(
name = "eigen_archive",
- url = "https://bitbucket.org/eigen/eigen/get/d02e6a705c30.tar.gz",
- sha256 = "532956172daa8aba87c750791ff89a5c38cdb07e2525afe17ecb4bef812d67cf",
+ url = "https://bitbucket.org/eigen/eigen/get/0c0b79ecd74c.tar.gz",
+ sha256 = "b4b5884b03bd4bae114d02b36e2435ad1504ed8e51431d16c876b6f6a365882b",
build_file = path_prefix + "eigen.BUILD",
)
diff --git a/third_party/eigen3/Eigen/Cholesky b/third_party/eigen3/Eigen/Cholesky
index 56059bcc61..7415ae4d0d 100644
--- a/third_party/eigen3/Eigen/Cholesky
+++ b/third_party/eigen3/Eigen/Cholesky
@@ -1 +1 @@
-#include "eigen-eigen-d02e6a705c30/Eigen/Cholesky"
+#include "eigen-eigen-0c0b79ecd74c/Eigen/Cholesky"
diff --git a/third_party/eigen3/Eigen/Core b/third_party/eigen3/Eigen/Core
index c1d4a2e0f8..787e1c076e 100644
--- a/third_party/eigen3/Eigen/Core
+++ b/third_party/eigen3/Eigen/Core
@@ -1 +1 @@
-#include "eigen-eigen-d02e6a705c30/Eigen/Core"
+#include "eigen-eigen-0c0b79ecd74c/Eigen/Core"
diff --git a/third_party/eigen3/Eigen/Eigenvalues b/third_party/eigen3/Eigen/Eigenvalues
index 0a0731ba19..b6e1b81eb5 100644
--- a/third_party/eigen3/Eigen/Eigenvalues
+++ b/third_party/eigen3/Eigen/Eigenvalues
@@ -1 +1 @@
-#include "eigen-eigen-d02e6a705c30/Eigen/Eigenvalues"
+#include "eigen-eigen-0c0b79ecd74c/Eigen/Eigenvalues"
diff --git a/third_party/eigen3/Eigen/LU b/third_party/eigen3/Eigen/LU
index d6b39b8d23..a0782af040 100644
--- a/third_party/eigen3/Eigen/LU
+++ b/third_party/eigen3/Eigen/LU
@@ -1 +1 @@
-#include "eigen-eigen-d02e6a705c30/Eigen/LU"
+#include "eigen-eigen-0c0b79ecd74c/Eigen/LU"
diff --git a/third_party/eigen3/Eigen/QR b/third_party/eigen3/Eigen/QR
index a5406e93bc..0a9bee2898 100644
--- a/third_party/eigen3/Eigen/QR
+++ b/third_party/eigen3/Eigen/QR
@@ -1 +1 @@
-#include "eigen-eigen-d02e6a705c30/Eigen/QR"
+#include "eigen-eigen-0c0b79ecd74c/Eigen/QR"
diff --git a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
index 4f730236b7..5228bcda62 100644
--- a/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
+++ b/third_party/eigen3/unsupported/Eigen/CXX11/Tensor
@@ -1 +1 @@
-#include "eigen-eigen-d02e6a705c30/unsupported/Eigen/CXX11/Tensor"
+#include "eigen-eigen-0c0b79ecd74c/unsupported/Eigen/CXX11/Tensor"