author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-08-17 14:22:14 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-08-17 15:34:07 -0700
commit    204eac513d875a4fb437d4985b2ad520f79c262c (patch)
tree      fadcebb02c4c78363449f28ea56fe67aa2bcd63b
parent    699154cedccef7a8a44de4be2d11dc737513f941 (diff)
Relax static shape information requirements for depthwise and separable conv2d.
The current implementation requires that, if the rank of a tensor is statically known, certain dimensions need to be known statically as well.
Change: 130570253
-rw-r--r--  tensorflow/core/kernels/BUILD                                 2
-rw-r--r--  tensorflow/core/kernels/conv_ops.cc                         519
-rw-r--r--  tensorflow/core/kernels/conv_ops.h                           58
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu.h                        5
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op.cc                 56
-rw-r--r--  tensorflow/python/kernel_tests/depthwise_conv_op_test.py    10
-rw-r--r--  tensorflow/python/ops/nn.py                                  61
7 files changed, 389 insertions(+), 322 deletions(-)
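
Before the diff itself, an illustrative sketch of the kind of call the relaxed checks are meant to accept. This is not part of the commit; it assumes the 2016-era Python frontend, and the tensor names are invented. The point is that the rank is statically known while the channel dimensions are not.

    import tensorflow as tf

    # Rank 4 is statically known, but the channel dimensions are left dynamic.
    images = tf.placeholder(tf.float32, shape=[None, 64, 64, None], name="images")
    filters = tf.placeholder(tf.float32, shape=[3, 3, None, 2], name="filters")

    # With the relaxed checks, the call below no longer needs the channel
    # dimension to be statically known at graph-construction time; shape
    # validation is deferred to the depthwise_conv2d_native kernel.
    out = tf.nn.depthwise_conv2d(images, filters,
                                 strides=[1, 1, 1, 1], padding="SAME")
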
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 04316fb085..fa494f117e 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1328,6 +1328,7 @@ tf_kernel_library(
name = "depthwise_conv_op",
prefix = "depthwise_conv_op",
deps = [
+ ":conv_ops",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
@@ -1936,6 +1937,7 @@ filegroup(
"batch_norm_op.h",
"control_flow_ops.h",
"conv_2d.h",
+ "conv_ops.h",
"image_resizer_state.h",
"maxpooling_op.h",
"pad_op.h",
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index d49d859e30..e2182d0ec0 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -18,6 +18,7 @@ limitations under the License.
#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS
+#include "tensorflow/core/kernels/conv_ops.h"
#include <string.h>
#include <map>
#include <vector>
@@ -50,6 +51,7 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+namespace {
template <typename Device, typename T>
struct LaunchGeneric {
static void launch(OpKernelContext* ctx, const Tensor& input,
@@ -87,12 +89,10 @@ struct LaunchGeneric {
}
}
};
-
-template <typename Device, typename T>
-class LaunchConvOp;
+} // namespace
template <typename T>
-class LaunchConvOp<CPUDevice, T> {
+class LaunchConv2DOp<CPUDevice, T> {
public:
void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& input, const Tensor& filter, int row_stride,
@@ -231,7 +231,7 @@ class Conv2DOp : public BinaryOp<T> {
bool use_cudnn_;
Padding padding_;
TensorFormat data_format_;
- LaunchConvOp<Device, T> launcher_;
+ LaunchConv2DOp<Device, T> launcher_;
bool cudnn_use_autotune_;
TF_DISALLOW_COPY_AND_ASSIGN(Conv2DOp);
@@ -244,8 +244,11 @@ class Conv2DOp : public BinaryOp<T> {
TF_CALL_half(REGISTER_CPU);
TF_CALL_float(REGISTER_CPU);
-#if GOOGLE_CUDA
+// To be used inside depthwise_conv_op.cc.
+template class LaunchConv2DOp<CPUDevice, float>;
+
+#if GOOGLE_CUDA
int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
int64 default_value_in_bytes) {
const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
@@ -264,276 +267,265 @@ int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
}
template <typename T>
-class LaunchConvOp<GPUDevice, T> {
- public:
- void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input_param, const Tensor& filter, int row_stride,
- int col_stride, const Eigen::PaddingType& padding, Tensor* output,
- TensorFormat data_format) {
- using perftools::gputools::dnn::AlgorithmConfig;
- using perftools::gputools::dnn::AlgorithmType;
- using perftools::gputools::dnn::ProfileResult;
- using perftools::gputools::dnn::kDefaultAlgorithm;
- auto* stream = ctx->op_device_context()->stream();
- OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
-
- if (!use_cudnn) {
- ctx->SetStatus(
- errors::Unimplemented("Conv2D for GPU is not currently supported "
- "without cudnn"));
- return;
- }
-
- Tensor input = input_param;
- if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
- col_stride == 1 && data_format == FORMAT_NHWC) {
- // 1x1 filter, so call cublas directly.
- const uint64 m =
- input.dim_size(0) * input.dim_size(1) * input.dim_size(2);
- const uint64 k = filter.dim_size(2);
- const uint64 n = filter.dim_size(3);
-
- auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
- input.template flat<T>().size());
- auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
- filter.template flat<T>().size());
- auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
- output->template flat<T>().size());
-
- auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
- bool blas_launch_status =
- stream
- ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr,
- n, a_ptr, k, 0.0f, &c_ptr, n)
- .ok();
- if (!blas_launch_status) {
- ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
- ", n=", n, ", k=", k));
- }
+void LaunchConv2DOp<GPUDevice, T>::launch(
+ OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& input_param, const Tensor& filter, int row_stride,
+ int col_stride, const Eigen::PaddingType& padding, Tensor* output,
+ TensorFormat data_format) {
+ using perftools::gputools::dnn::AlgorithmConfig;
+ using perftools::gputools::dnn::AlgorithmType;
+ using perftools::gputools::dnn::ProfileResult;
+ using perftools::gputools::dnn::kDefaultAlgorithm;
+ auto* stream = ctx->op_device_context()->stream();
+ OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
+
+ if (!use_cudnn) {
+ ctx->SetStatus(
+ errors::Unimplemented("Conv2D for GPU is not currently supported "
+ "without cudnn"));
+ return;
+ }
- return;
+ Tensor input = input_param;
+ if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
+ col_stride == 1 && data_format == FORMAT_NHWC) {
+ // 1x1 filter, so call cublas directly.
+ const uint64 m = input.dim_size(0) * input.dim_size(1) * input.dim_size(2);
+ const uint64 k = filter.dim_size(2);
+ const uint64 n = filter.dim_size(3);
+
+ auto a_ptr = AsDeviceMemory(input.template flat<T>().data(),
+ input.template flat<T>().size());
+ auto b_ptr = AsDeviceMemory(filter.template flat<T>().data(),
+ filter.template flat<T>().size());
+ auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
+ output->template flat<T>().size());
+
+ auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+ bool blas_launch_status =
+ stream
+ ->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n,
+ a_ptr, k, 0.0f, &c_ptr, n)
+ .ok();
+ if (!blas_launch_status) {
+ ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
+ ", n=", n, ", k=", k));
}
- int padding_rows = 0;
- int padding_cols = 0;
- const int64 in_batch = GetTensorDim(input, data_format, 'N');
- int64 in_rows = GetTensorDim(input, data_format, 'H');
- int64 in_cols = GetTensorDim(input, data_format, 'W');
- const int64 in_depths = GetTensorDim(input, data_format, 'C');
- const int64 out_batch = GetTensorDim(*output, data_format, 'N');
- const int64 out_rows = GetTensorDim(*output, data_format, 'H');
- const int64 out_cols = GetTensorDim(*output, data_format, 'W');
- const int64 out_depths = GetTensorDim(*output, data_format, 'C');
- const int64 patch_rows = filter.dim_size(0);
- const int64 patch_cols = filter.dim_size(1);
- if (padding == Eigen::PADDING_SAME) {
- // Total padding on rows and cols is
- // Pr = (R' - 1) * S + Kr - R
- // Pc = (C' - 1) * S + Kc - C
- // where (R', C') are output dimensions, (R, C) are input dimensions, S
- // is stride, (Kr, Kc) are filter dimensions.
- // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
- // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
- // we pad more on the right and bottom than on the top and left.
- padding_rows =
- std::max<int>(0, (out_rows - 1) * row_stride + patch_rows - in_rows);
- padding_cols =
- std::max<int>(0, (out_cols - 1) * col_stride + patch_cols - in_cols);
- const bool rows_odd = (padding_rows % 2 != 0);
- const bool cols_odd = (padding_cols % 2 != 0);
- if (rows_odd || cols_odd) {
- Tensor transformed_input;
- int64 new_in_rows = in_rows + rows_odd;
- int64 new_in_cols = in_cols + cols_odd;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_temp(
- DataTypeToEnum<T>::value,
- ShapeFromFormat(data_format, in_batch, new_in_rows,
- new_in_cols, in_depths),
- &transformed_input));
-
- functor::PadInput<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
- {{0, 0}}, {{rows_odd, cols_odd}},
- To32Bit(transformed_input.tensor<T, 4>()), data_format);
-
- input = transformed_input;
- in_rows = new_in_rows;
- in_cols = new_in_cols;
- }
+
+ return;
+ }
+ int padding_rows = 0;
+ int padding_cols = 0;
+ const int64 in_batch = GetTensorDim(input, data_format, 'N');
+ int64 in_rows = GetTensorDim(input, data_format, 'H');
+ int64 in_cols = GetTensorDim(input, data_format, 'W');
+ const int64 in_depths = GetTensorDim(input, data_format, 'C');
+ const int64 out_batch = GetTensorDim(*output, data_format, 'N');
+ const int64 out_rows = GetTensorDim(*output, data_format, 'H');
+ const int64 out_cols = GetTensorDim(*output, data_format, 'W');
+ const int64 out_depths = GetTensorDim(*output, data_format, 'C');
+ const int64 patch_rows = filter.dim_size(0);
+ const int64 patch_cols = filter.dim_size(1);
+ if (padding == Eigen::PADDING_SAME) {
+ // Total padding on rows and cols is
+ // Pr = (R' - 1) * S + Kr - R
+ // Pc = (C' - 1) * S + Kc - C
+ // where (R', C') are output dimensions, (R, C) are input dimensions, S
+ // is stride, (Kr, Kc) are filter dimensions.
+ // We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
+ // and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
+ // we pad more on the right and bottom than on the top and left.
+ padding_rows =
+ std::max<int>(0, (out_rows - 1) * row_stride + patch_rows - in_rows);
+ padding_cols =
+ std::max<int>(0, (out_cols - 1) * col_stride + patch_cols - in_cols);
+ const bool rows_odd = (padding_rows % 2 != 0);
+ const bool cols_odd = (padding_cols % 2 != 0);
+ if (rows_odd || cols_odd) {
+ Tensor transformed_input;
+ int64 new_in_rows = in_rows + rows_odd;
+ int64 new_in_cols = in_cols + cols_odd;
+ OP_REQUIRES_OK(
+ ctx,
+ ctx->allocate_temp(DataTypeToEnum<T>::value,
+ ShapeFromFormat(data_format, in_batch, new_in_rows,
+ new_in_cols, in_depths),
+ &transformed_input));
+
+ functor::PadInput<GPUDevice, T, int, 4>()(
+ ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
+ {{0, 0}}, {{rows_odd, cols_odd}},
+ To32Bit(transformed_input.tensor<T, 4>()), data_format);
+
+ input = transformed_input;
+ in_rows = new_in_rows;
+ in_cols = new_in_cols;
}
+ }
- if (data_format == FORMAT_NHWC) {
- // Convert the input tensor from NHWC to NCHW.
- TensorShape nchw_shape =
- ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
- if (in_depths > 1) {
- Tensor transformed_input;
- OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
- nchw_shape, &transformed_input));
- functor::NHWCToNCHW<GPUDevice, T, 4>()(
- ctx->eigen_device<GPUDevice>(),
- const_cast<const Tensor&>(input).tensor<T, 4>(),
- transformed_input.tensor<T, 4>());
- input = transformed_input;
- } else {
- // If depth <= 1, then just reshape.
- CHECK(input.CopyFrom(input, nchw_shape));
- }
+ if (data_format == FORMAT_NHWC) {
+ // Convert the input tensor from NHWC to NCHW.
+ TensorShape nchw_shape =
+ ShapeFromFormat(FORMAT_NCHW, in_batch, in_rows, in_cols, in_depths);
+ if (in_depths > 1) {
+ Tensor transformed_input;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ nchw_shape, &transformed_input));
+ functor::NHWCToNCHW<GPUDevice, T, 4>()(
+ ctx->eigen_device<GPUDevice>(),
+ const_cast<const Tensor&>(input).tensor<T, 4>(),
+ transformed_input.tensor<T, 4>());
+ input = transformed_input;
+ } else {
+ // If depth <= 1, then just reshape.
+ CHECK(input.CopyFrom(input, nchw_shape));
}
+ }
- CHECK(padding_rows >= 0 && padding_cols >= 0)
- << "Negative row or col paddings: (" << padding_rows << ", "
- << padding_cols << ")";
- perftools::gputools::dnn::BatchDescriptor input_desc;
- input_desc.set_count(in_batch)
- .set_feature_map_count(in_depths)
- .set_height(in_rows)
- .set_width(in_cols)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor output_desc;
- output_desc.set_count(out_batch)
- .set_height(out_rows)
- .set_width(out_cols)
- .set_feature_map_count(out_depths)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::FilterDescriptor filter_desc;
- filter_desc.set_input_filter_height(filter.dim_size(0))
- .set_input_filter_width(filter.dim_size(1))
- .set_input_feature_map_count(filter.dim_size(2))
- .set_output_feature_map_count(filter.dim_size(3));
- perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
- conv_desc.set_vertical_filter_stride(row_stride)
- .set_horizontal_filter_stride(col_stride)
- .set_zero_padding_height(padding_rows / 2)
- .set_zero_padding_width(padding_cols / 2);
-
- Tensor transformed_filter;
- OP_REQUIRES_OK(ctx,
- ctx->allocate_temp(
- DataTypeToEnum<T>::value,
- TensorShape({filter.dim_size(3), filter.dim_size(2),
- filter.dim_size(0), filter.dim_size(1)}),
- &transformed_filter));
-
- functor::TransformFilter<GPUDevice, T, int, 4>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
- To32Bit(transformed_filter.tensor<T, 4>()));
-
- Tensor transformed_output;
- OP_REQUIRES_OK(
- ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
- ShapeFromFormat(FORMAT_NCHW, out_batch,
- out_rows, out_cols, out_depths),
- &transformed_output));
-
- auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
- input.template flat<T>().size());
- auto filter_ptr =
- AsDeviceMemory(transformed_filter.template flat<T>().data(),
- transformed_filter.template flat<T>().size());
- auto output_ptr =
- AsDeviceMemory(transformed_output.template flat<T>().data(),
- transformed_output.template flat<T>().size());
-
- static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
- "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
- );
-
- int device_id = stream->parent()->device_ordinal();
- ConvParameters conv_parameters = {
- in_batch, // batch
- in_depths, // in_depths
- in_rows, // in_rows
- in_cols, // in_cols
- out_depths, // out_depths
- patch_rows, // filter_rows
- patch_cols, // filter_cols
- row_stride, // stride_rows
- col_stride, // stride_cols
- padding_rows, // padding_rows
- padding_cols, // padding_cols
- device_id, // device_id
- };
- AlgorithmConfig algorithm_config;
- if (cudnn_use_autotune &&
- !autotune_results_.Find(conv_parameters, &algorithm_config)) {
- std::vector<AlgorithmType> algorithms;
- CHECK(stream->parent()->GetConvolveAlgorithms(&algorithms));
- ProfileResult best_result;
- ProfileResult best_result_no_scratch;
- for (auto profile_algorithm : algorithms) {
- // TODO(zhengxq): profile each algorithm multiple times to better
- // accuracy.
- CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
- ProfileResult profile_result;
- bool cudnn_launch_status =
- stream
- ->ThenConvolveWithAlgorithm(
- input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
- output_desc, &output_ptr, &scratch_allocator,
- AlgorithmConfig(profile_algorithm), &profile_result)
- .ok();
- if (cudnn_launch_status) {
- if (profile_result.is_valid()) {
- if (profile_result.elapsed_time_in_ms() <
- best_result.elapsed_time_in_ms()) {
- best_result = profile_result;
- }
- if (scratch_allocator.TotalByteSize() == 0 &&
- profile_result.elapsed_time_in_ms() <
- best_result_no_scratch.elapsed_time_in_ms()) {
- best_result_no_scratch = profile_result;
- }
+ CHECK(padding_rows >= 0 && padding_cols >= 0)
+ << "Negative row or col paddings: (" << padding_rows << ", "
+ << padding_cols << ")";
+ perftools::gputools::dnn::BatchDescriptor input_desc;
+ input_desc.set_count(in_batch)
+ .set_feature_map_count(in_depths)
+ .set_height(in_rows)
+ .set_width(in_cols)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::BatchDescriptor output_desc;
+ output_desc.set_count(out_batch)
+ .set_height(out_rows)
+ .set_width(out_cols)
+ .set_feature_map_count(out_depths)
+ .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ perftools::gputools::dnn::FilterDescriptor filter_desc;
+ filter_desc.set_input_filter_height(filter.dim_size(0))
+ .set_input_filter_width(filter.dim_size(1))
+ .set_input_feature_map_count(filter.dim_size(2))
+ .set_output_feature_map_count(filter.dim_size(3));
+ perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
+ conv_desc.set_vertical_filter_stride(row_stride)
+ .set_horizontal_filter_stride(col_stride)
+ .set_zero_padding_height(padding_rows / 2)
+ .set_zero_padding_width(padding_cols / 2);
+
+ Tensor transformed_filter;
+ OP_REQUIRES_OK(ctx, ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape({filter.dim_size(3), filter.dim_size(2),
+ filter.dim_size(0), filter.dim_size(1)}),
+ &transformed_filter));
+
+ functor::TransformFilter<GPUDevice, T, int, 4>()(
+ ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
+ To32Bit(transformed_filter.tensor<T, 4>()));
+
+ Tensor transformed_output;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(DataTypeToEnum<T>::value,
+ ShapeFromFormat(FORMAT_NCHW, out_batch, out_rows,
+ out_cols, out_depths),
+ &transformed_output));
+
+ auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
+ input.template flat<T>().size());
+ auto filter_ptr =
+ AsDeviceMemory(transformed_filter.template flat<T>().data(),
+ transformed_filter.template flat<T>().size());
+ auto output_ptr =
+ AsDeviceMemory(transformed_output.template flat<T>().data(),
+ transformed_output.template flat<T>().size());
+
+ static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
+ "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
+ );
+
+ int device_id = stream->parent()->device_ordinal();
+ ConvParameters conv_parameters = {
+ in_batch, // batch
+ in_depths, // in_depths
+ in_rows, // in_rows
+ in_cols, // in_cols
+ out_depths, // out_depths
+ patch_rows, // filter_rows
+ patch_cols, // filter_cols
+ row_stride, // stride_rows
+ col_stride, // stride_cols
+ padding_rows, // padding_rows
+ padding_cols, // padding_cols
+ device_id, // device_id
+ };
+ AlgorithmConfig algorithm_config;
+ if (cudnn_use_autotune &&
+ !autotune_results_.Find(conv_parameters, &algorithm_config)) {
+ std::vector<AlgorithmType> algorithms;
+ CHECK(stream->parent()->GetConvolveAlgorithms(&algorithms));
+ ProfileResult best_result;
+ ProfileResult best_result_no_scratch;
+ for (auto profile_algorithm : algorithms) {
+ // TODO(zhengxq): profile each algorithm multiple times to better
+ // accuracy.
+ CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+ ProfileResult profile_result;
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveWithAlgorithm(
+ input_desc, input_ptr, filter_desc, filter_ptr, conv_desc,
+ output_desc, &output_ptr, &scratch_allocator,
+ AlgorithmConfig(profile_algorithm), &profile_result)
+ .ok();
+ if (cudnn_launch_status) {
+ if (profile_result.is_valid()) {
+ if (profile_result.elapsed_time_in_ms() <
+ best_result.elapsed_time_in_ms()) {
+ best_result = profile_result;
+ }
+ if (scratch_allocator.TotalByteSize() == 0 &&
+ profile_result.elapsed_time_in_ms() <
+ best_result_no_scratch.elapsed_time_in_ms()) {
+ best_result_no_scratch = profile_result;
}
}
}
- OP_REQUIRES(ctx, best_result.is_valid() &&
- best_result.algorithm() != kDefaultAlgorithm,
- errors::NotFound("No algorithm worked!"));
- OP_REQUIRES(ctx,
- best_result_no_scratch.is_valid() &&
- best_result_no_scratch.algorithm() != kDefaultAlgorithm,
- errors::NotFound("No algorithm without scratch worked!"));
- algorithm_config.set_algorithm(best_result.algorithm());
- algorithm_config.set_algorithm_no_scratch(
- best_result_no_scratch.algorithm());
- autotune_results_.Insert(conv_parameters, algorithm_config);
- }
-
- CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
- bool cudnn_launch_status =
- stream
- ->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
- filter_ptr, conv_desc, output_desc,
- &output_ptr, &scratch_allocator,
- algorithm_config, nullptr)
- .ok();
-
- if (!cudnn_launch_status) {
- ctx->SetStatus(errors::Internal(
- "cuDNN launch failure : input shape(", input.shape().DebugString(),
- ") filter shape(", filter.shape().DebugString(), ")"));
- }
-
- // Convert the output tensor back from NHWC to NCHW.
- if (data_format == FORMAT_NHWC) {
- functor::NCHWToNHWC<GPUDevice, T, 4>()(
- ctx->eigen_device<GPUDevice>(),
- const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
- output->tensor<T, 4>());
- } else {
- *output = transformed_output;
}
+ OP_REQUIRES(ctx, best_result.is_valid() &&
+ best_result.algorithm() != kDefaultAlgorithm,
+ errors::NotFound("No algorithm worked!"));
+ OP_REQUIRES(ctx,
+ best_result_no_scratch.is_valid() &&
+ best_result_no_scratch.algorithm() != kDefaultAlgorithm,
+ errors::NotFound("No algorithm without scratch worked!"));
+ algorithm_config.set_algorithm(best_result.algorithm());
+ algorithm_config.set_algorithm_no_scratch(
+ best_result_no_scratch.algorithm());
+ autotune_results_.Insert(conv_parameters, algorithm_config);
}
- private:
- AutoTuneMap<ConvParameters, perftools::gputools::dnn::AlgorithmConfig>
- autotune_results_;
-};
+ CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
+ bool cudnn_launch_status =
+ stream
+ ->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
+ filter_ptr, conv_desc, output_desc,
+ &output_ptr, &scratch_allocator,
+ algorithm_config, nullptr)
+ .ok();
+
+ if (!cudnn_launch_status) {
+ ctx->SetStatus(errors::Internal(
+ "cuDNN launch failure : input shape(", input.shape().DebugString(),
+ ") filter shape(", filter.shape().DebugString(), ")"));
+ }
-#endif // GOOGLE_CUDA
+ // Convert the output tensor back from NHWC to NCHW.
+ if (data_format == FORMAT_NHWC) {
+ functor::NCHWToNHWC<GPUDevice, T, 4>()(
+ ctx->eigen_device<GPUDevice>(),
+ const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
+ output->tensor<T, 4>());
+ } else {
+ *output = transformed_output;
+ }
+}
-#if GOOGLE_CUDA
// Forward declarations of the functor specializations for GPU.
namespace functor {
#define DECLARE_GPU_SPEC(T) \
@@ -577,6 +569,9 @@ REGISTER_KERNEL_BUILDER(
Name("Conv2D").Device(DEVICE_GPU).TypeConstraint<float>("T"),
Conv2DOp<GPUDevice, float>);
+// To be used inside depthwise_conv_op.cc.
+template class LaunchConv2DOp<GPUDevice, float>;
+
#endif // GOOGLE_CUDA
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
new file mode 100644
index 0000000000..d09db3dc15
--- /dev/null
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -0,0 +1,58 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_KERNELS_CONV_OPS_H_
+#define TENSORFLOW_KERNELS_CONV_OPS_H_
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/util/tensor_format.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
+#include "tensorflow/core/platform/stream_executor.h"
+#endif // GOOGLE_CUDA
+
+namespace tensorflow {
+
+// Forward declaration.
+class OpKernelContext;
+
+template <typename Device, typename T>
+class LaunchConv2DOp {
+ public:
+ void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& input, const Tensor& filter, int row_stride,
+ int col_stride, const Eigen::PaddingType& padding, Tensor* output,
+ TensorFormat data_format);
+};
+
+#ifdef GOOGLE_CUDA
+template <typename T>
+class LaunchConv2DOp<Eigen::GpuDevice, T> {
+ public:
+ void launch(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
+ const Tensor& input, const Tensor& filter, int row_stride,
+ int col_stride, const Eigen::PaddingType& padding, Tensor* output,
+ TensorFormat data_format);
+
+ private:
+ AutoTuneMap<ConvParameters, perftools::gputools::dnn::AlgorithmConfig>
+ autotune_results_;
+};
+#endif // GOOGLE_CUDA
+
+} // namespace tensorflow
+
+#endif  // TENSORFLOW_KERNELS_CONV_OPS_H_
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index 6cfa541d06..26d03908d3 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -19,6 +19,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include <tuple>
+#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/stream_executor.h"
namespace tensorflow {
@@ -26,8 +27,8 @@ namespace tensorflow {
// TODO(zhengxq): move this to gpu_util.h. The use of such wrappers is wide
// spread.
template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
- uint64 size) {
+inline perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
+ uint64 size) {
perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
size * sizeof(T));
perftools::gputools::DeviceMemory<T> typed(wrapped);
diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc
index 88d2651938..172bc1bcba 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
#include <cmath>
+#include <type_traits>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -25,12 +26,14 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/conv_ops.h"
#include "tensorflow/core/kernels/depthwise_conv_op.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/use_cudnn.h"
#include "tensorflow/core/util/work_sharder.h"
#if GOOGLE_CUDA
@@ -225,6 +228,9 @@ struct LaunchDepthwiseConvOp<CPUDevice, T> {
}
};
+// Extern template instantiated in conv_ops.cc.
+extern template class LaunchConv2DOp<CPUDevice, float>;
+
#if GOOGLE_CUDA
template <typename T>
@@ -247,6 +253,9 @@ struct LaunchDepthwiseConvOp<GPUDevice, T> {
}
};
+// Extern template instantiated in conv_ops.cc.
+extern template class LaunchConv2DOp<GPUDevice, float>;
+
#endif
template <typename Device, typename T>
@@ -267,18 +276,20 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+
+ // For special case when in_depth == 1.
+ use_cudnn_ = CanUseCudnn();
+ cudnn_use_autotune_ = CudnnUseAutotune();
}
void Compute(OpKernelContext* context) override {
// Input tensor is of the following dimensions:
// [ batch, in_rows, in_cols, in_depth ]
const Tensor& input = context->input(0);
- auto input_ptr = input.template flat<T>().data();
// Input filter is of the following dimensions:
// [ filter_rows, filter_cols, in_depth, depth_multiplier]
const Tensor& filter = context->input(1);
- auto filter_ptr = filter.template flat<T>().data();
// For 2D convolution, there should be 4 dimensions.
OP_REQUIRES(context, input.dims() == 4,
@@ -338,7 +349,26 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
// [ in_batch, out_rows, out_cols, out_depth ]
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
- auto output_ptr = output->template flat<T>().data();
+
+ VLOG(2) << "DepthwiseConv2dNative: "
+ << " Input: [" << batch << ", " << input_rows << ", " << input_cols
+ << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
+ << filter_cols << ", " << in_depth << ", " << depth_multiplier
+ << "]; stride = " << stride << ", pad_rows = " << pad_rows
+ << ", pad_cols = " << pad_cols << ", output: [" << batch << ", "
+ << out_rows << ", " << out_cols << ", " << out_depth << "]";
+
+ // If there is nothing to compute, return.
+ if (out_shape.num_elements() == 0) {
+ return;
+ }
+
+ if (std::is_same<T, float>::value && in_depth == 1) {
+ launcher_.launch(context, use_cudnn_, cudnn_use_autotune_, input, filter,
+ stride, stride, BrainPadding2EigenPadding(padding_),
+ output, FORMAT_NHWC);
+ return;
+ }
DepthwiseArgs args;
args.batch = batch;
@@ -355,18 +385,9 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
args.out_cols = out_cols;
args.out_depth = out_depth;
- VLOG(2) << "DepthwiseConv2dNative: "
- << " Input: [" << batch << ", " << input_rows << ", " << input_cols
- << ", " << in_depth << "]; Filter: [" << filter_rows << ", "
- << filter_cols << ", " << in_depth << ", " << depth_multiplier
- << "]; stride = " << stride << ", pad_rows = " << pad_rows
- << ", pad_cols = " << pad_cols << ", output: [" << batch << ", "
- << out_rows << ", " << out_cols << ", " << out_depth << "]";
-
- // If there is nothing to compute, return.
- if (out_shape.num_elements() == 0) {
- return;
- }
+ auto input_ptr = input.template flat<T>().data();
+ auto filter_ptr = filter.template flat<T>().data();
+ auto output_ptr = output->template flat<T>().data();
LaunchDepthwiseConvOp<Device, T>::launch(context, args, input_ptr,
filter_ptr, output_ptr);
}
@@ -375,6 +396,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
std::vector<int32> strides_;
Padding padding_;
+ // For the case in_depth == 1.
+ LaunchConv2DOp<Device, T> launcher_;
+ bool use_cudnn_;
+ bool cudnn_use_autotune_;
+
TF_DISALLOW_COPY_AND_ASSIGN(DepthwiseConv2dNativeOp);
};
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 96c00ae01b..7a1545daca 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -29,17 +29,17 @@ def ConfigsToTest():
convolution parameters.
"""
input_sizes = [[4, 5, 5, 48], [4, 8, 8, 84], [4, 17, 17, 48], [4, 35, 35, 2],
- [4, 147, 147, 2], [3, 299, 299, 3]]
+ [4, 147, 147, 2], [3, 299, 299, 3], [5, 183, 183, 1]]
filter_sizes = [[1, 1, 48, 2], [1, 3, 84, 1], [3, 1, 48, 4], [5, 5, 2, 1],
- [3, 3, 2, 8], [2, 2, 3, 8]]
+ [3, 3, 2, 8], [2, 2, 3, 8], [5, 5, 1, 2]]
out_sizes = [[4, 5, 5, 96], [4, 8, 8, 84], [4, 17, 17, 192], [4, 35, 35, 2],
- [4, 49, 49, 16], [3, 150, 150, 24]]
- strides = [1, 1, 1, 1, 3, 2]
+ [4, 49, 49, 16], [3, 150, 150, 24], [5, 92, 92, 2]]
+ strides = [1, 1, 1, 1, 3, 2, 2]
# pylint: disable=invalid-name
VALID = "VALID"
SAME = "SAME"
# pylint: enable=invalid-name
- paddings = [SAME, SAME, SAME, SAME, VALID, SAME, SAME]
+ paddings = [SAME, SAME, SAME, SAME, VALID, SAME, SAME, SAME]
for i, f, o, s, p in zip(input_sizes, filter_sizes, out_sizes, strides,
paddings):
yield i, f, o, s, p
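
The added configuration exercises the in_depth == 1 case that the kernel change above routes through LaunchConv2DOp. As a rough standalone illustration (not taken from the commit; random data and a plain session are assumed), the same shapes can be run directly:

    import numpy as np
    import tensorflow as tf

    x = np.random.rand(5, 183, 183, 1).astype(np.float32)  # [batch, rows, cols, in_depth = 1]
    f = np.random.rand(5, 5, 1, 2).astype(np.float32)      # [rows, cols, in_depth, depth_multiplier]

    y = tf.nn.depthwise_conv2d(x, f, strides=[1, 2, 2, 1], padding="SAME")
    with tf.Session() as sess:
        # SAME padding with stride 2 gives ceil(183 / 2) = 92 on each spatial
        # axis, matching the new [5, 92, 92, 2] entry in out_sizes.
        print(sess.run(y).shape)
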
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 303932f9dd..fa3fba3375 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -586,7 +586,7 @@ def zero_fraction(value, name=None):
math_ops.cast(math_ops.equal(value, zero), dtypes.float32))
-# pylint: disable=redefined-builtin,line-too-long
+# pylint: disable=redefined-builtin
def depthwise_conv2d(input, filter, strides, padding, name=None):
"""Depthwise 2-D convolution.
@@ -625,28 +625,10 @@ def depthwise_conv2d(input, filter, strides, padding, name=None):
with ops.name_scope(name, "depthwise", [input, filter]) as name:
input = ops.convert_to_tensor(input, name="tensor_in")
filter = ops.convert_to_tensor(filter, name="filter_in")
- # A shape is required to statically compute the number of separable filters.
- if filter.get_shape().ndims is not None:
- assert len(filter.get_shape()) == 4
- in_channels = filter.get_shape()[2]
- # Sanity checks, if shape information is available for the inputs.
- if input.get_shape().ndims is not None:
- assert len(input.get_shape()) == 4
- assert input.get_shape()[3] == in_channels, (
- "Mismatched input depth %d and number of depthwise filters %d." %
- (input.get_shape()[3].value, in_channels))
- else:
- assert input.get_shape().ndims is not None, (
- "Either tensor must provide static shape information.")
- assert input.get_shape().ndims == 4
- in_channels = input.get_shape()[3]
-
- if in_channels == 1:
- return nn_ops.conv2d(input, filter, strides, padding, name=name)
- else:
- return nn_ops.depthwise_conv2d_native(
- input, filter, strides, padding, name=name)
-# pylint: enable=redefined-builtin,line-too-long
+
+ return nn_ops.depthwise_conv2d_native(
+ input, filter, strides, padding, name=name)
+# pylint: enable=redefined-builtin
# pylint: disable=redefined-builtin,line-too-long
@@ -702,21 +684,24 @@ def separable_conv2d(input, depthwise_filter, pointwise_filter, strides,
pointwise_filter = ops.convert_to_tensor(
pointwise_filter, name="pointwise_filter")
- if pointwise_filter.get_shape().ndims is not None:
- assert len(pointwise_filter.get_shape()) == 4
- assert pointwise_filter.get_shape()[0] == 1
- assert pointwise_filter.get_shape()[1] == 1
- if depthwise_filter.get_shape().ndims and input.get_shape().ndims:
- channel_multiplier = depthwise_filter.get_shape()[3]
- in_channels = input.get_shape()[3]
- out_channels = pointwise_filter.get_shape()[3]
- if channel_multiplier * in_channels > out_channels:
- raise ValueError(
- ("Refusing to perform an overparameterized separable "
- "convolution: channel_multiplier * in_channels = "
- "%d * %d = %d > %d = out_channels" %
- (channel_multiplier, in_channels,
- channel_multiplier * in_channels, out_channels)))
+ pointwise_filter_shape = pointwise_filter.get_shape().with_rank(4)
+ pointwise_filter_shape[0].assert_is_compatible_with(1)
+ pointwise_filter_shape[1].assert_is_compatible_with(1)
+
+ channel_multiplier = depthwise_filter.get_shape().with_rank(4)[3]
+ in_channels = input.get_shape().with_rank(4)[3]
+ out_channels = pointwise_filter_shape[3]
+
+ # If any of channel numbers is unknown, then the comparison below returns
+ # None. See TensorShape.__gt__().
+ if channel_multiplier * in_channels > out_channels:
+ raise ValueError(
+ "Refusing to perform an overparameterized separable "
+ "convolution: channel_multiplier * in_channels = "
+ "%d * %d = %d > %d = out_channels" %
+ (channel_multiplier, in_channels,
+ channel_multiplier * in_channels, out_channels))
+
# The layout of the ops in the graph are expected to be as follows:
# depthwise_conv2d // Conv2D op corresponding to native deptwise conv.
# separable_conv2d // Conv2D op corresponding to the pointwise conv.
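
For completeness, a small sketch of a separable_conv2d call that the rewritten checks are intended to accept. It is not part of the diff; the placeholder shapes are contrived for illustration, with known ranks but some unknown dimensions.

    import tensorflow as tf

    # Ranks are known, some individual dimensions are not.
    images = tf.placeholder(tf.float32, shape=[None, 32, 32, None])
    depthwise_filter = tf.placeholder(tf.float32, shape=[3, 3, None, 1])
    pointwise_filter = tf.placeholder(tf.float32, shape=[1, None, None, 16])

    # assert_is_compatible_with(1) accepts the unknown second dimension of
    # pointwise_filter, and the overparameterization comparison yields None
    # when in_channels is unknown, so graph construction succeeds; the old
    # assert-based checks would have tripped on these unknown dimensions.
    out = tf.nn.separable_conv2d(images, depthwise_filter, pointwise_filter,
                                 strides=[1, 2, 2, 1], padding="SAME")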