author     A. Unique TensorFlower <nobody@tensorflow.org>  2015-12-17 16:39:02 -0800
committer  Vijay Vasudevan <vrv@google.com>                2015-12-17 16:39:02 -0800
commit     1e218e764a342fde41253b29337bfb5efd5c59ce (patch)
tree       68994e827c8294700a0f3ce00ea3ce6c8a26450e
parent     9c23536c2ead0bb2e8d9904a4260fe0bf188bbb5 (diff)
Improve TensorFlow convnet performance.
+ Change the default workspace limit to 4GB.
+ Only change the padding of the forward convolution when the requested padding is incompatible with cuDNN's symmetric padding.
+ Make workspace allocation failure non-fatal.
Change: 110497322
-rw-r--r--  tensorflow/core/kernels/conv_grad_ops.cc                4
-rw-r--r--  tensorflow/core/kernels/conv_ops.cc                    45
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu.h                  5
-rw-r--r--  tensorflow/models/image/alexnet/alexnet_benchmark.py    6
4 files changed, 33 insertions, 27 deletions
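
Note on the workspace limit: the new default handed to GetCudnnWorkspaceLimit is 1LL << 32 bytes (4GB), and the first argument names an environment variable, TF_CUDNN_WORKSPACE_LIMIT_IN_MB, that can override it. Below is a minimal sketch of how such a helper could resolve the limit, illustrating only the apparent contract (override given in megabytes, default given in bytes); it is not the actual TensorFlow implementation.

    #include <cstdint>
    #include <cstdlib>

    // Hypothetical stand-in for GetCudnnWorkspaceLimit: read an override in MB
    // from the environment, otherwise fall back to the compiled-in byte count.
    int64_t GetWorkspaceLimitSketch(const char* env_var, int64_t default_bytes) {
      const char* value = std::getenv(env_var);
      if (value == nullptr) return default_bytes;   // e.g. 1LL << 32 after this change
      return std::strtoll(value, nullptr, 10) * (1LL << 20);  // megabytes -> bytes
    }
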
diff --git a/tensorflow/core/kernels/conv_grad_ops.cc b/tensorflow/core/kernels/conv_grad_ops.cc
index ac0ed0a404..1968aa7082 100644
--- a/tensorflow/core/kernels/conv_grad_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_ops.cc
@@ -948,7 +948,7 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
pre_transformed_in_backprop.template flat<T>().size());
static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
- "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30 // 1GB by default
+ "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
);
CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
context);
@@ -1243,7 +1243,7 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
transformed_input.template flat<T>().size());
static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
- "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30 // 1GB by default
+ "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
);
CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
context);
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index 6af13cb6cd..d1dc8a7f97 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -261,6 +261,8 @@ struct LaunchConvOp<GPUDevice, T> {
}
return;
}
+ int padding_rows = 0;
+ int padding_cols = 0;
if (padding == Eigen::PADDING_SAME) {
const int64 out_rows = output->dim_size(1);
const int64 out_cols = output->dim_size(2);
@@ -276,23 +278,26 @@ struct LaunchConvOp<GPUDevice, T> {
// We pad Pr/2 on the left and Pr - Pr/2 on the right, Pc/2 on the top
// and Pc - Pc/2 on the bottom. When Pr or Pc is odd, this means
// we pad more on the right and bottom than on the top and left.
- const int padding_rows = (out_rows - 1) * stride + patch_rows - in_rows;
- const int padding_cols = (out_cols - 1) * stride + patch_cols - in_cols;
- Tensor transformed_input;
- OP_REQUIRES_OK(
- ctx, ctx->allocate_temp(
- DataTypeToEnum<T>::value,
- TensorShape(
- {input.dim_size(0), input.dim_size(1) + padding_rows,
- input.dim_size(2) + padding_cols, input.dim_size(3)}),
- &transformed_input));
-
- functor::PadInput<GPUDevice, T, int>()(
- ctx->eigen_device<GPUDevice>(), To32Bit(input_param.tensor<T, 4>()),
- padding_rows / 2, padding_rows - padding_rows / 2, padding_cols / 2,
- padding_cols - padding_cols / 2,
- To32Bit(transformed_input.tensor<T, 4>()));
- input = transformed_input;
+ padding_rows = (out_rows - 1) * stride + patch_rows - in_rows;
+ padding_cols = (out_cols - 1) * stride + patch_cols - in_cols;
+ const bool rows_odd = (padding_rows % 2 != 0);
+ const bool cols_odd = (padding_cols % 2 != 0);
+ if (rows_odd || cols_odd) {
+ Tensor transformed_input;
+ OP_REQUIRES_OK(
+ ctx, ctx->allocate_temp(
+ DataTypeToEnum<T>::value,
+ TensorShape(
+ {input.dim_size(0), input.dim_size(1) + rows_odd,
+ input.dim_size(2) + cols_odd, input.dim_size(3)}),
+ &transformed_input));
+
+ functor::PadInput<GPUDevice, T, int>()(
+ ctx->eigen_device<GPUDevice>(),
+ To32Bit(input_param.tensor<T, 4>()), 0, rows_odd, 0, cols_odd,
+ To32Bit(transformed_input.tensor<T, 4>()));
+ input = transformed_input;
+ }
}
{
@@ -330,7 +335,9 @@ struct LaunchConvOp<GPUDevice, T> {
.set_output_feature_map_count(filter.dim_size(3));
perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_filter_stride(stride)
- .set_horizontal_filter_stride(stride);
+ .set_horizontal_filter_stride(stride)
+ .set_zero_padding_height(padding_rows / 2)
+ .set_zero_padding_width(padding_cols / 2);
Tensor transformed_filter;
OP_REQUIRES_OK(ctx,
@@ -362,7 +369,7 @@ struct LaunchConvOp<GPUDevice, T> {
transformed_output.template flat<T>().size());
static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
- "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30 // 1GB by default
+ "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 32 // 4GB by default
);
CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
bool cudnn_launch_status =
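
The hunk above computes the total SAME padding from the output size, stride, and filter size, pre-pads the input by one extra row/column on the bottom/right only when that total is odd, and passes the symmetric half (padding / 2) to cuDNN instead of padding the whole input as before. A small self-contained sketch of the arithmetic, using assumed example numbers (224x224 input, 11x11 filter, stride 4) rather than TensorFlow code:

    #include <cstdio>

    int main() {
      const int in_rows = 224, patch_rows = 11, stride = 4;
      const int out_rows = (in_rows + stride - 1) / stride;        // SAME output: ceil(224/4) = 56
      const int padding_rows =
          (out_rows - 1) * stride + patch_rows - in_rows;          // total padding needed: 7
      const bool rows_odd = (padding_rows % 2 != 0);               // odd -> pre-pad one extra row
      std::printf("out=%d total_pad=%d odd=%d cudnn_pad=%d\n",
                  out_rows, padding_rows, rows_odd, padding_rows / 2);
      // With one extra bottom row (225) plus symmetric padding of 3 on each side,
      // the convolution still produces (225 + 6 - 11) / 4 + 1 = 56 output rows.
      return 0;
    }
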
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index bbe06cb6a1..bcdc1c3510 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -61,11 +61,10 @@ class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
DT_UINT8, TensorShape({byte_size}), &temporary_memory));
if (!allocation_status.ok()) {
LOG(WARNING) << allocation_status;
- context_->SetStatus(allocation_status);
return perftools::gputools::port::StatusOr<
- perftools::gputools::DeviceMemory<uint8>>();
+ perftools::gputools::DeviceMemory<uint8>>(
+ AsDeviceMemory<uint8>(nullptr, 0));
}
-
return perftools::gputools::port::StatusOr<
perftools::gputools::DeviceMemory<uint8>>(
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
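
This hunk is what makes workspace allocation failure non-fatal: instead of recording the error on the op's context (which would fail the whole kernel), the allocator logs a warning and returns an empty DeviceMemory, so the convolution can presumably fall back to an algorithm that needs no scratch space. A rough sketch of the same "soft failure" pattern, using plain types rather than the perftools::gputools wrappers:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    struct ScratchBuffer {            // stand-in for DeviceMemory<uint8>
      uint8_t* ptr = nullptr;
      int64_t size = 0;
    };

    // Try to allocate scratch space; on failure, warn and return an empty
    // buffer instead of surfacing a fatal error to the caller.
    ScratchBuffer AllocateScratch(std::vector<uint8_t>& backing,
                                  int64_t byte_size, int64_t limit_bytes) {
      if (byte_size > limit_bytes) {
        std::fprintf(stderr, "scratch request of %lld bytes exceeds the limit\n",
                     static_cast<long long>(byte_size));
        return ScratchBuffer{};       // empty: caller falls back to a no-workspace path
      }
      backing.resize(byte_size);
      return ScratchBuffer{backing.data(), byte_size};
    }
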
diff --git a/tensorflow/models/image/alexnet/alexnet_benchmark.py b/tensorflow/models/image/alexnet/alexnet_benchmark.py
index 1f8c4df110..d70f213708 100644
--- a/tensorflow/models/image/alexnet/alexnet_benchmark.py
+++ b/tensorflow/models/image/alexnet/alexnet_benchmark.py
@@ -70,7 +70,7 @@ def inference(images):
with tf.name_scope('conv1') as scope:
kernel = tf.Variable(tf.truncated_normal([11, 11, 3, 64], dtype=tf.float32,
stddev=1e-1), name='weights')
- conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding='VALID')
+ conv = tf.nn.conv2d(images, kernel, [1, 4, 4, 1], padding='SAME')
biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32),
trainable=True, name='biases')
bias = tf.nn.bias_add(conv, biases)
@@ -200,8 +200,8 @@ def run_benchmark():
# In order to force the model to start with the same activations sizes,
# we add 3 to the image_size and employ VALID padding above.
images = tf.Variable(tf.random_normal([FLAGS.batch_size,
- image_size + 3,
- image_size + 3, 3],
+ image_size,
+ image_size, 3],
dtype=tf.float32,
stddev=1e-1))
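
With the forward convolution now able to pass SAME padding straight to cuDNN, the benchmark switches conv1 to SAME padding and drops the old workaround of adding 3 pixels and using VALID padding. For reference, the two setups give slightly different first-layer activation sizes for the 11x11, stride-4 kernel; a small sketch of the arithmetic, assuming the benchmark's image_size of 224:

    #include <cstdio>

    int main() {
      const int image_size = 224, kernel = 11, stride = 4;
      // Old setup: 227x227 input with VALID padding.
      const int valid_out = (image_size + 3 - kernel) / stride + 1;   // (227 - 11) / 4 + 1 = 55
      // New setup: 224x224 input with SAME padding.
      const int same_out = (image_size + stride - 1) / stride;        // ceil(224 / 4) = 56
      std::printf("VALID(227)=%d  SAME(224)=%d\n", valid_out, same_out);
      return 0;
    }
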