aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/gradients/math_grad_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/cc/gradients/math_grad_test.cc')
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc58
1 files changed, 45 insertions, 13 deletions
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index a174f223ad..6313f41da5 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -64,7 +64,9 @@ class CWiseUnaryGradTest : public ::testing::Test {
IMAG,
CONJ,
COMPLEX,
- ANGLE
+ ANGLE,
+ LGAMMA,
+ ERF
};
template <typename X_T, typename Y_T>
@@ -168,6 +170,12 @@ class CWiseUnaryGradTest : public ::testing::Test {
case ANGLE:
y = Angle(scope_, x);
break;
+ case LGAMMA:
+ y = Lgamma(scope_, x);
+ break;
+ case ERF:
+ y = Erf(scope_, x);
+ break;
}
float max_error;
@@ -503,6 +511,42 @@ TEST_F(CWiseUnaryGradTest, Angle) {
TestCWiseGrad<complex64, float>(ANGLE, x_fn);
}
+TEST_F(CWiseUnaryGradTest, Lgamma) {
+ auto x_fn = [this](const int i) {
+ return RV({-3.5, -2.5, -1.5, 1.0, 2.0, 3.5});
+ };
+ TestCWiseGrad<float, float>(LGAMMA, x_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Lgamma_Complex) {
+ auto x_fn = [this](const int i) {
+ return CRV({{-3.5, 0.5}, {-1.5, -0.5}, {1.5, -1.0}, {3.5, 1.0}});
+ };
+ // TODO(kbsriram)
+ // Add test when the lgamma kernel supports complex numbers
+ if (false) {
+ TestCWiseGrad<complex64, complex64>(LGAMMA, x_fn);
+ }
+}
+
+TEST_F(CWiseUnaryGradTest, Erf) {
+ auto x_fn = [this](const int i) {
+ return RV({-1.2, -1.0, -0.5, 0.3, 0.5, 1.3});
+ };
+ TestCWiseGrad<float, float>(ERF, x_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Erf_Complex) {
+ auto x_fn = [this](const int i) {
+ return CRV({{-1.2, 0.5}, {-0.5, -0.5}, {0.5, 0.5}, {1.2, -0.5}});
+ };
+ // TODO(kbsriram)
+ // Add test when the erf kernel supports complex numbers
+ if (false) {
+ TestCWiseGrad<complex64, complex64>(ERF, x_fn);
+ }
+}
+
class MathGradTest : public ::testing::Test {
protected:
MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
@@ -821,17 +865,5 @@ TEST_F(NaryGradTest, Minimum) {
RunTest(x, x_init_value, y, shape);
}
-TEST_F(NaryGradTest, Lgamma) {
- TensorShape shape({3, 2});
- auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape));
- auto y = Lgamma(scope_, x);
- // Select values to avoid instability when computing finite differences.
- // Ref: https://en.wikipedia.org/wiki/File:Gamma_plot.svg
- Tensor x_init_value =
- test::AsTensor<float>({-3.5f, -2.5f, -1.5f, 1.0f, 2.0f, 3.5f}, {3, 2});
- RunTest(x, x_init_value, y, shape);
- // TODO(suharshs): add test case for complex values
-}
-
} // namespace
} // namespace tensorflow