diff options
author | 2018-07-17 18:08:14 -0700 | |
---|---|---|
committer | 2018-07-17 18:11:21 -0700 | |
commit | 8238266c4fd433107f38eb126a5c5da05a4d338b (patch) | |
tree | fc4f923e52e8df2aedde5bc180766d501b9b61bf /tensorflow/contrib/fused_conv/kernels | |
parent | 07cc6474b219ee3ad9f55860e621f61b34bb6bd1 (diff) |
Support identity activation function in Cudnn implementation of fused conv2d bias activation.
PiperOrigin-RevId: 205008958
Diffstat (limited to 'tensorflow/contrib/fused_conv/kernels')
-rw-r--r-- | tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc | 30 |
1 files changed, 22 insertions, 8 deletions
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 2458f7554a..4554a3d89a 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -135,9 +135,12 @@ class FusedConv2DBiasActivationOp : public OpKernel { context->GetAttr("activation_mode", &activation_mode_str)); OP_REQUIRES_OK(context, GetActivationModeFromString(activation_mode_str, &activation_mode_)); - OP_REQUIRES(context, activation_mode_ == ActivationMode::RELU, - errors::InvalidArgument("Current implementation only supports " - "RELU as the activation function.")); + OP_REQUIRES(context, + activation_mode_ == ActivationMode::RELU || + activation_mode_ == ActivationMode::NONE, + errors::InvalidArgument( + "Current implementation only supports RELU or NONE " + "as the activation function.")); cudnn_use_autotune_ = CudnnUseAutotune(); } @@ -538,6 +541,18 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>:: activation_mode, }; + dnn::ActivationMode dnn_activation_mode; + switch (activation_mode) { + case ActivationMode::NONE: + dnn_activation_mode = dnn::ActivationMode::kNone; + break; + case ActivationMode::RELU: + dnn_activation_mode = dnn::ActivationMode::kRelu; + break; + default: + LOG(FATAL) << "Activation mode " << activation_mode << " not supported"; + } + dnn::AlgorithmConfig algorithm_config; if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find( fused_conv_parameters, &algorithm_config)) { @@ -558,10 +573,9 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>:: ->ThenFusedConvolveWithAlgorithm( conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc, filter_ptr, conv_desc, side_input_ptr, - side_input_scale, bias_desc, bias_ptr, - dnn::ActivationMode::kRelu, output_desc, &output_ptr, - &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm), - &profile_result) + side_input_scale, bias_desc, bias_ptr, dnn_activation_mode, + output_desc, &output_ptr, &scratch_allocator, + dnn::AlgorithmConfig(profile_algorithm), &profile_result) .ok(); if (cudnn_launch_status) { if (profile_result.is_valid()) { @@ -597,7 +611,7 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>:: ->ThenFusedConvolveWithAlgorithm( conv_input_desc, conv_input_ptr, conv_input_scale, filter_desc, filter_ptr, conv_desc, side_input_ptr, side_input_scale, - bias_desc, bias_ptr, dnn::ActivationMode::kRelu, output_desc, + bias_desc, bias_ptr, dnn_activation_mode, output_desc, &output_ptr, &scratch_allocator, algorithm_config, /*output_profile_result=*/nullptr) .ok(); |