aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/cc/gradients/math_grad_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-12 10:16:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-12 10:20:30 -0700
commite6b011763a60d239972c8c6c0f36536ab6f885a3 (patch)
tree8930a1e6f5efa50c860683ea86807335c7470cbf /tensorflow/cc/gradients/math_grad_test.cc
parentf63aa7f49f81a66112bfef6670a18658d5a479e5 (diff)
Extend c++ gradient_checker to complex types.
PiperOrigin-RevId: 168392949
Diffstat (limited to 'tensorflow/cc/gradients/math_grad_test.cc')
-rw-r--r--tensorflow/cc/gradients/math_grad_test.cc20
1 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 5b1558dd82..97cd86eacb 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -737,10 +737,14 @@ TEST_F(CWiseUnaryComplexGradTest, Angle) {
Tensor x = test::AsTensor<complex64>(
{{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
- Tensor dx_expected = test::AsTensor<complex64>(
- {{5.5, 5.5}, {3, 3},
- {2.1666666666666665, 2.1666666666666665}, {1.75, 1.75},
- {0.9375, 0.9375}, {0.8888888888888888, 0.8888888888888888}}, {2, 3});
+ Tensor dx_expected =
+ test::AsTensor<complex64>({{5.5, 5.5},
+ {3, 3},
+ {2.1666666666666665, 2.1666666666666665},
+ {1.75, 1.75},
+ {0.9375, 0.9375},
+ {0.8888888888888888, 0.8888888888888888}},
+ {2, 3});
TestCWiseGradComplex(ANGLE, x, dy, dx_expected);
}
@@ -920,8 +924,8 @@ class NaryGradTest : public ::testing::Test {
const OutputList& ys, const std::vector<TensorShape>& y_shapes) {
TF_ASSERT_OK(scope_.status());
float max_error;
- TF_ASSERT_OK(
- ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error));
+ TF_ASSERT_OK((ComputeGradientError<float, float, float>(
+ scope_, xs, x_shapes, ys, y_shapes, &max_error)));
EXPECT_LT(max_error, 1e-3);
}
@@ -929,8 +933,8 @@ class NaryGradTest : public ::testing::Test {
const TensorShape& y_shape) {
TF_ASSERT_OK(scope_.status());
float max_error;
- TF_ASSERT_OK(
- ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error));
+ TF_ASSERT_OK((ComputeGradientError<float, float, float>(
+ scope_, x, x_init_value, y, y_shape, &max_error)));
EXPECT_LT(max_error, 1e-3);
}