aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 14:19:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 14:19:31 -0700
commitd85b570c677a429370221a183ccb4f0e0ab86943 (patch)
treee59c2f77e1b1c9a5147e984d22d76013331050b7 /tensorflow/cc
parentf7cb9d504f1a675375171fa19cee70fa64c28c64 (diff)
parent2701ab910894da95c25bcf6f2e30f0a6c2c20552 (diff)
Merge pull request #19653 from theflofly:segment-sum-prod
PiperOrigin-RevId: 208266242
Diffstat (limited to 'tensorflow/cc')
-rw-r--r--tensorflow/cc/gradients/math_grad.cc20
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc10
2 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 35a01e0341..d95dd879b4 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -1007,6 +1007,26 @@ Status ProdGrad(const Scope& scope, const Operation& op,
}
REGISTER_GRADIENT_OP("Prod", ProdGrad);
+Status SegmentSumGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // The SegmentSum operation sums segments of the Tensor that have the same
+ // index in the segment_ids parameter.
+ // i.e z = [2, 3, 4, 5], segment_ids [0, 0, 0, 1]
+ // will produce [2 + 3 + 4, 5] = [9, 5]
+ // The gradient that will flow back to the gather operation will look like
+ // [x1, x2], it will have the same shape as the output of the SegmentSum
+ // operation. The differentiation step of the SegmentSum operation just
+ // broadcast the gradient in order to retrieve the z's shape.
+ // dy/dz = [x1, x1, x1, x2]
+ grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1)));
+
+ // stop propagation along segment_ids
+ grad_outputs->push_back(NoGradient());
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad);
+
// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 1c9bdff5e1..c6c9262786 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -42,6 +42,7 @@ using ops::Placeholder;
using ops::Pow;
using ops::Prod;
using ops::RealDiv;
+using ops::SegmentSum;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
@@ -898,5 +899,14 @@ TEST_F(NaryGradTest, Prod) {
RunTest({x}, {x_shape}, {y}, {y_shape});
}
+TEST_F(NaryGradTest, SegmentSum) {
+ TensorShape x_shape({3, 4});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
+ auto y = SegmentSum(scope_, x, {0, 0, 1});
+ // the sum is always on the first dimension
+ TensorShape y_shape({2, 4});
+ RunTest({x}, {x_shape}, {y}, {y_shape});
+}
+
} // namespace
} // namespace tensorflow