diff options
author | Florian Courtial <floriancourtial@gmail.com> | 2018-05-30 23:36:52 +0200 |
---|---|---|
committer | Florian Courtial <floriancourtial@gmail.com> | 2018-05-30 23:36:52 +0200 |
commit | 2701ab910894da95c25bcf6f2e30f0a6c2c20552 (patch) | |
tree | bc665c08d1fb2cfb364de03b6ceb89424923cef1 /tensorflow/cc | |
parent | 2716bfff551591297a4ba6e61299e8147ac27c05 (diff) |
Add C++ SegmentSum gradient operation.
Diffstat (limited to 'tensorflow/cc')
-rw-r--r-- | tensorflow/cc/gradients/math_grad.cc | 20 | ||||
-rw-r--r-- | tensorflow/cc/gradients/math_grad_test.cc | 10 |
2 files changed, 30 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 52c177212a..62404fff09 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -1006,6 +1006,26 @@ Status ProdGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Prod", ProdGrad); +Status SegmentSumGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // The SegmentSum operation sums segments of the Tensor that have the same + // index in the segment_ids parameter. + // i.e z = [2, 3, 4, 5], segment_ids [0, 0, 0, 1] + // will produce [2 + 3 + 4, 5] = [9, 5] + // The gradient that will flow back to the gather operation will look like + // [x1, x2], it will have the same shape as the output of the SegmentSum + // operation. The differentiation step of the SegmentSum operation just + // broadcast the gradient in order to retrieve the z's shape. + // dy/dz = [x1, x1, x1, x2] + grad_outputs->push_back(Gather(scope, grad_inputs[0], op.input(1))); + + // stop propagation along segment_ids + grad_outputs->push_back(NoGradient()); + return scope.status(); +} +REGISTER_GRADIENT_OP("SegmentSum", SegmentSumGrad); + // MatMulGrad helper function used to compute two MatMul operations // based on input matrix transposition combinations. Status MatMulGradHelper(const Scope& scope, const bool is_batch, diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index fd7b6fe662..acc100d144 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -41,6 +41,7 @@ using ops::Mul; using ops::Placeholder; using ops::Pow; using ops::Prod; +using ops::SegmentSum; using ops::RealDiv; using ops::SquaredDifference; using ops::Sub; @@ -902,5 +903,14 @@ TEST_F(NaryGradTest, Prod) { RunTest({x}, {x_shape}, {y}, {y_shape}); } +TEST_F(NaryGradTest, SegmentSum) { + TensorShape x_shape({3, 4}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = SegmentSum(scope_, x, {0, 0, 1}); + // the sum is always on the first dimension + TensorShape y_shape({2, 4}); + RunTest({x}, {x_shape}, {y}, {y_shape}); +} + } // namespace } // namespace tensorflow |