aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/gradients/math_grad_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc/gradients/math_grad_test.cc')
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc10
1 files changed, 10 insertions, 0 deletions
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 1c9bdff5e1..c6c9262786 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -42,6 +42,7 @@ using ops::Placeholder;
using ops::Pow;
using ops::Prod;
using ops::RealDiv;
+using ops::SegmentSum;
using ops::SquaredDifference;
using ops::Sub;
using ops::Sum;
@@ -898,5 +899,14 @@ TEST_F(NaryGradTest, Prod) {
RunTest({x}, {x_shape}, {y}, {y_shape});
}
+TEST_F(NaryGradTest, SegmentSum) {
+ TensorShape x_shape({3, 4});
+ auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(x_shape));
+ auto y = SegmentSum(scope_, x, {0, 0, 1});
+ // the sum is always on the first dimension
+ TensorShape y_shape({2, 4});
+ RunTest({x}, {x_shape}, {y}, {y_shape});
+}
+
} // namespace
} // namespace tensorflow