diff options
Diffstat (limited to 'tensorflow/cc/gradients/array_grad.cc')
-rw-r--r-- | tensorflow/cc/gradients/array_grad.cc | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc index 48185db3cb..6545e4ee3e 100644 --- a/tensorflow/cc/gradients/array_grad.cc +++ b/tensorflow/cc/gradients/array_grad.cc @@ -269,6 +269,7 @@ Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad); +template <bool IsPadV2> Status PadGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, std::vector<Output>* grad_outputs) { @@ -281,9 +282,14 @@ Status PadGrad(const Scope& scope, const Operation& op, auto begin = Reshape(scope, pad_before, {-1}); grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x))); grad_outputs->push_back(NoGradient()); + // PadV2 adds a "constant_values" input. + if (IsPadV2) { + grad_outputs->push_back(NoGradient()); + } return scope.status(); } -REGISTER_GRADIENT_OP("Pad", PadGrad); +REGISTER_GRADIENT_OP("Pad", PadGrad<false>); +REGISTER_GRADIENT_OP("PadV2", PadGrad<true>); Status SpaceToBatchGrad(const Scope& scope, const Operation& op, const std::vector<Output>& grad_inputs, |