diff options
Diffstat (limited to 'tensorflow/cc/gradients/math_grad_test.cc')
-rw-r--r-- | tensorflow/cc/gradients/math_grad_test.cc | 10 |
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 1c9bdff5e1..c6c9262786 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -42,6 +42,7 @@ using ops::Placeholder; using ops::Pow; using ops::Prod; using ops::RealDiv; +using ops::SegmentSum; using ops::SquaredDifference; using ops::Sub; using ops::Sum; @@ -898,5 +899,14 @@ TEST_F(NaryGradTest, Prod) { RunTest({x}, {x_shape}, {y}, {y_shape}); } +TEST_F(NaryGradTest, SegmentSum) { + TensorShape x_shape({3, 4}); + auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape)); + auto y = SegmentSum(scope_, x, {0, 0, 1}); + // the sum is always on the first dimension + TensorShape y_shape({2, 4}); + RunTest({x}, {x_shape}, {y}, {y_shape}); +} + } // namespace } // namespace tensorflow |