diff options
Diffstat (limited to 'tensorflow/cc/gradients/nn_grad.cc')
-rw-r--r-- | tensorflow/cc/gradients/nn_grad.cc | 15 |
1 files changed, 15 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/nn_grad.cc b/tensorflow/cc/gradients/nn_grad.cc index f9d69ff896..6fc73c3fa1 100644 --- a/tensorflow/cc/gradients/nn_grad.cc +++ b/tensorflow/cc/gradients/nn_grad.cc @@ -95,6 +95,21 @@ Status SeluGradHelper(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Selu", SeluGradHelper); +Status BiasAddGradHelper(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + string data_format; + BiasAddGrad::Attrs input_attrs; + TF_RETURN_IF_ERROR( + GetNodeAttr(op.output(0).node()->attrs(), "data_format", &data_format)); + input_attrs.DataFormat(data_format); + auto dx_1 = BiasAddGrad(scope, grad_inputs[0], input_attrs); + grad_outputs->push_back(Identity(scope, grad_inputs[0])); + grad_outputs->push_back(dx_1); + return scope.status(); +} +REGISTER_GRADIENT_OP("BiasAdd", BiasAddGradHelper); + } // anonymous namespace } // namespace ops } // namespace tensorflow |