path: root/tensorflow/cc
author RJ Ryan <rjryan@google.com> 2017-07-10 16:41:33 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-07-10 16:51:17 -0700
commit a6773e98e97956b7adf3aa51eb3548261f51d6f7 (patch)
tree a4fa423385edabe441d7644c6df7a62803e7e2a3 /tensorflow/cc
parent 285f9766471e10fa9fee4299940225a33515c010 (diff)
Add a PadV2 op with support for specifying a pad value.
Added a `constant_values` keyword argument to the tf.pad Python API for compatibility with numpy.pad. For now, only scalar values are supported. Efficiently supporting a `[D, 2]` tensor for `constant_values`, to pick per-dimension pre/post constant values, will require adding Eigen and XLA support first. PiperOrigin-RevId: 161460091
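A hedged usage sketch of the new Python keyword (values here are illustrative, not taken from the patch): tf.pad now accepts a scalar constant_values, mirroring numpy.pad.

import tensorflow as tf

# Pad a 2x2 tensor with one row before/after and two columns after,
# filling the new border with -1 instead of the previous default of 0.
x = tf.constant([[1, 2], [3, 4]])
paddings = [[1, 1], [0, 2]]   # [pad_before, pad_after] per dimension
y = tf.pad(x, paddings, constant_values=-1)
# y has shape [4, 4]; per this change, only a scalar constant_values is supported.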
Diffstat (limited to 'tensorflow/cc')
-rw-r--r-- tensorflow/cc/gradients/array_grad.cc | 8
1 file changed, 7 insertions(+), 1 deletion(-)
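For context on the hunk below: the registered gradient for Pad slices the incoming gradient back to the shape of the original input, and the paddings input (plus, for PadV2, the new constant_values input) is given NoGradient. A minimal Python sketch of the same input-gradient relationship, assuming a TF 2.x eager runtime (much newer than this patch) and purely illustrative values:

import tensorflow as tf

x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
paddings = tf.constant([[1, 1], [2, 0]])   # pad_before / pad_after per dimension
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.pad(x, paddings, constant_values=7.0)   # dispatches to PadV2
    loss = tf.reduce_sum(y * y)
dx = tape.gradient(loss, x)
# dx is the slice of d(loss)/dy over the original (unpadded) region, i.e. 2*x here;
# the padded border contributes nothing to the input gradient.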
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc
index 48185db3cb..6545e4ee3e 100644
--- a/tensorflow/cc/gradients/array_grad.cc
+++ b/tensorflow/cc/gradients/array_grad.cc
@@ -269,6 +269,7 @@ Status ScatterNdNonAliasingAddGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("ScatterNdNonAliasingAdd", ScatterNdNonAliasingAddGrad);
+template <bool IsPadV2>
Status PadGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {
@@ -281,9 +282,14 @@ Status PadGrad(const Scope& scope, const Operation& op,
auto begin = Reshape(scope, pad_before, {-1});
grad_outputs->push_back(Slice(scope, grad_inputs[0], begin, Shape(scope, x)));
grad_outputs->push_back(NoGradient());
+ // PadV2 adds a "constant_values" input.
+ if (IsPadV2) {
+ grad_outputs->push_back(NoGradient());
+ }
return scope.status();
}
-REGISTER_GRADIENT_OP("Pad", PadGrad);
+REGISTER_GRADIENT_OP("Pad", PadGrad<false>);
+REGISTER_GRADIENT_OP("PadV2", PadGrad<true>);
Status SpaceToBatchGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,