aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/fused_conv/kernels
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-17 18:08:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-17 18:11:21 -0700
commit8238266c4fd433107f38eb126a5c5da05a4d338b (patch)
treefc4f923e52e8df2aedde5bc180766d501b9b61bf /tensorflow/contrib/fused_conv/kernels
parent07cc6474b219ee3ad9f55860e621f61b34bb6bd1 (diff)
Support identity activation function in Cudnn implementation of fused conv2d bias activation.
PiperOrigin-RevId: 205008958
Diffstat (limited to 'tensorflow/contrib/fused_conv/kernels')
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc30
1 files changed, 22 insertions, 8 deletions
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 2458f7554a..4554a3d89a 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -135,9 +135,12 @@ class FusedConv2DBiasActivationOp : public OpKernel {
context->GetAttr("activation_mode", &activation_mode_str));
OP_REQUIRES_OK(context, GetActivationModeFromString(activation_mode_str,
&activation_mode_));
- OP_REQUIRES(context, activation_mode_ == ActivationMode::RELU,
- errors::InvalidArgument("Current implementation only supports "
- "RELU as the activation function."));
+ OP_REQUIRES(context,
+ activation_mode_ == ActivationMode::RELU ||
+ activation_mode_ == ActivationMode::NONE,
+ errors::InvalidArgument(
+ "Current implementation only supports RELU or NONE "
+ "as the activation function."));
cudnn_use_autotune_ = CudnnUseAutotune();
}
@@ -538,6 +541,18 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
activation_mode,
};
+ dnn::ActivationMode dnn_activation_mode;
+ switch (activation_mode) {
+ case ActivationMode::NONE:
+ dnn_activation_mode = dnn::ActivationMode::kNone;
+ break;
+ case ActivationMode::RELU:
+ dnn_activation_mode = dnn::ActivationMode::kRelu;
+ break;
+ default:
+ LOG(FATAL) << "Activation mode " << activation_mode << " not supported";
+ }
+
dnn::AlgorithmConfig algorithm_config;
if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find(
fused_conv_parameters, &algorithm_config)) {
@@ -558,10 +573,9 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
->ThenFusedConvolveWithAlgorithm(
conv_input_desc, conv_input_ptr, conv_input_scale,
filter_desc, filter_ptr, conv_desc, side_input_ptr,
- side_input_scale, bias_desc, bias_ptr,
- dnn::ActivationMode::kRelu, output_desc, &output_ptr,
- &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm),
- &profile_result)
+ side_input_scale, bias_desc, bias_ptr, dnn_activation_mode,
+ output_desc, &output_ptr, &scratch_allocator,
+ dnn::AlgorithmConfig(profile_algorithm), &profile_result)
.ok();
if (cudnn_launch_status) {
if (profile_result.is_valid()) {
@@ -597,7 +611,7 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
->ThenFusedConvolveWithAlgorithm(
conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc,
filter_ptr, conv_desc, side_input_ptr, side_input_scale,
- bias_desc, bias_ptr, dnn::ActivationMode::kRelu, output_desc,
+ bias_desc, bias_ptr, dnn_activation_mode, output_desc,
&output_ptr, &scratch_allocator, algorithm_config,
/*output_profile_result=*/nullptr)
.ok();