aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/gradients/array_grad.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc/gradients/array_grad.cc')
-rw-r--r--tensorflow/cc/gradients/array_grad.cc18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/array_grad.cc b/tensorflow/cc/gradients/array_grad.cc
index b353accddc..e9173227aa 100644
--- a/tensorflow/cc/gradients/array_grad.cc
+++ b/tensorflow/cc/gradients/array_grad.cc
@@ -120,6 +120,24 @@ Status SplitGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Split", SplitGrad);
+Status FillGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = fill(fill_shape, x)
+ // No gradient returned for the fill_shape argument.
+ grad_outputs->push_back(NoGradient());
+ // The gradient for x (which must be a scalar) is just the sum of
+ // all the gradients from the shape it fills.
+ // We use ReduceSum to implement this, which needs an argument providing
+ // the indices of all the dimensions of the incoming gradient.
+ // grad(x) = reduce_sum(grad(y), [0..rank(grad(y))])
+ auto all_dims = Range(scope, Const(scope, 0), Rank(scope, grad_inputs[0]),
+ Const(scope, 1));
+ grad_outputs->push_back(ReduceSum(scope, grad_inputs[0], all_dims));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Fill", FillGrad);
+
Status DiagGrad(const Scope& scope, const Operation& op,
const std::vector<Output>& grad_inputs,
std::vector<Output>* grad_outputs) {