Diffstat (limited to 'tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc')
-rw-r--r--  tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc | 33
1 file changed, 25 insertions(+), 8 deletions(-)
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 2458f7554a..0ccb4583ab 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -135,9 +135,12 @@ class FusedConv2DBiasActivationOp : public OpKernel {
context->GetAttr("activation_mode", &activation_mode_str));
OP_REQUIRES_OK(context, GetActivationModeFromString(activation_mode_str,
&activation_mode_));
- OP_REQUIRES(context, activation_mode_ == ActivationMode::RELU,
- errors::InvalidArgument("Current implementation only supports "
- "RELU as the activation function."));
+ OP_REQUIRES(context,
+ activation_mode_ == ActivationMode::RELU ||
+ activation_mode_ == ActivationMode::NONE,
+ errors::InvalidArgument(
+ "Current implementation only supports RELU or NONE "
+ "as the activation function."));
cudnn_use_autotune_ = CudnnUseAutotune();
}
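
The hunk above relaxes the constructor-time check so that either RELU or NONE passes op construction. As a minimal, self-contained sketch of that check, the following uses stand-in enum values, string spellings, and exceptions rather than the real OpKernelConstruction/OP_REQUIRES machinery (those names and spellings are assumptions for illustration only):

    #include <stdexcept>
    #include <string>

    enum class ActivationMode { NONE, RELU, RELU6 };

    // Stand-in for GetActivationModeFromString; the "None"/"Relu" spellings
    // are illustrative, not taken from the diff.
    ActivationMode ActivationModeFromString(const std::string& s) {
      if (s == "None") return ActivationMode::NONE;
      if (s == "Relu") return ActivationMode::RELU;
      if (s == "Relu6") return ActivationMode::RELU6;
      throw std::invalid_argument("Unknown activation_mode: " + s);
    }

    // Mirrors the relaxed OP_REQUIRES above: only RELU and NONE are accepted.
    void ValidateActivationMode(ActivationMode m) {
      if (m != ActivationMode::RELU && m != ActivationMode::NONE) {
        throw std::invalid_argument(
            "Current implementation only supports RELU or NONE "
            "as the activation function.");
      }
    }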
@@ -440,6 +443,8 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
: dnn::DataLayout::kBatchDepthYX;
constexpr auto filter_layout = is_int8x4 ? dnn::FilterLayout::kOutputInputYX4
: dnn::FilterLayout::kOutputInputYX;
+ constexpr auto compute_data_format =
+ is_int8x4 ? FORMAT_NCHW_VECT_C : FORMAT_NCHW;
dnn::BatchDescriptor conv_input_desc;
conv_input_desc.set_count(batch_size)
@@ -526,6 +531,7 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
batch_size,
conv_input_depth,
{{conv_input_rows, conv_input_cols}},
+ compute_data_format,
output_depth,
{{filter_rows, filter_cols}},
// TODO(yangzihao): Add support for arbitrary dilations for fused conv.
@@ -538,6 +544,18 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
activation_mode,
};
+ dnn::ActivationMode dnn_activation_mode;
+ switch (activation_mode) {
+ case ActivationMode::NONE:
+ dnn_activation_mode = dnn::ActivationMode::kNone;
+ break;
+ case ActivationMode::RELU:
+ dnn_activation_mode = dnn::ActivationMode::kRelu;
+ break;
+ default:
+ LOG(FATAL) << "Activation mode " << activation_mode << " not supported";
+ }
+
dnn::AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find(
fused_conv_parameters, &algorithm_config)) {
@@ -558,10 +576,9 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
->ThenFusedConvolveWithAlgorithm(
conv_input_desc, conv_input_ptr, conv_input_scale,
filter_desc, filter_ptr, conv_desc, side_input_ptr,
- side_input_scale, bias_desc, bias_ptr,
- dnn::ActivationMode::kRelu, output_desc, &output_ptr,
- &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm),
- &profile_result)
+ side_input_scale, bias_desc, bias_ptr, dnn_activation_mode,
+ output_desc, &output_ptr, &scratch_allocator,
+ dnn::AlgorithmConfig(profile_algorithm), &profile_result)
.ok();
if (cudnn_launch_status) {
if (profile_result.is_valid()) {
@@ -597,7 +614,7 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
->ThenFusedConvolveWithAlgorithm(
conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc,
filter_ptr, conv_desc, side_input_ptr, side_input_scale,
- bias_desc, bias_ptr, dnn::ActivationMode::kRelu, output_desc,
+ bias_desc, bias_ptr, dnn_activation_mode, output_desc,
&output_ptr, &scratch_allocator, algorithm_config,
/*output_profile_result=*/nullptr)
.ok();
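
Taken together, the new switch statement and the two ThenFusedConvolveWithAlgorithm call sites (the autotune profiling loop and the regular launch) now forward one mapped activation value instead of a hard-coded dnn::ActivationMode::kRelu. A minimal sketch of that mapping, using stand-in enums rather than the real TensorFlow/StreamExecutor types:

    #include <cstdlib>
    #include <iostream>

    enum class ActivationMode { NONE, RELU };        // op-level attribute (stand-in)
    enum class DnnActivationMode { kNone, kRelu };   // backend (cuDNN) enum (stand-in)

    // Same shape as the switch added in the kernel: any other mode is fatal.
    DnnActivationMode ToDnnActivationMode(ActivationMode mode) {
      switch (mode) {
        case ActivationMode::NONE:
          return DnnActivationMode::kNone;
        case ActivationMode::RELU:
          return DnnActivationMode::kRelu;
        default:
          std::cerr << "Activation mode not supported\n";
          std::abort();
      }
    }

    int main() {
      // Both RELU and NONE now reach the fused convolution call unchanged.
      std::cout << static_cast<int>(ToDnnActivationMode(ActivationMode::RELU)) << " "
                << static_cast<int>(ToDnnActivationMode(ActivationMode::NONE)) << "\n";
    }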