diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emission_utils.h')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emission_utils.h | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 744346abf3..a64a616ab1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -44,9 +44,11 @@ namespace gpu { // "connectivity" (i.e. which elements of the input affect which elements of // the output) are concerned. enum class CudnnConvKind { - kForward, // input + filter => output - kBackwardInput, // filter + output => input - kBackwardFilter, // input + output => filter + kForward, // input + filter => output + kBackwardInput, // filter + output => input + kBackwardFilter, // input + output => filter + kForwardActivation, // activation(conv(input, filter) + broadcast(bias) + + // (optionally) side_input) => output }; StatusOr<CudnnConvKind> GetCudnnConvKind(const HloCustomCallInstruction* instr); @@ -119,6 +121,7 @@ bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo); extern const char* const kCudnnConvForwardCallTarget; extern const char* const kCudnnConvBackwardInputCallTarget; extern const char* const kCudnnConvBackwardFilterCallTarget; +extern const char* const kCudnnConvBiasActivationForwardCallTarget; // Returns true if `hlo` will be implemented as a call to a cuDNN convolution // routine. |