diff options
author | 2017-09-12 10:16:26 -0700 | |
---|---|---|
committer | 2017-09-12 10:20:30 -0700 | |
commit | e6b011763a60d239972c8c6c0f36536ab6f885a3 (patch) | |
tree | 8930a1e6f5efa50c860683ea86807335c7470cbf /tensorflow/cc/gradients/math_grad_test.cc | |
parent | f63aa7f49f81a66112bfef6670a18658d5a479e5 (diff) |
Extend c++ gradient_checker to complex types.
PiperOrigin-RevId: 168392949
Diffstat (limited to 'tensorflow/cc/gradients/math_grad_test.cc')
-rw-r--r-- | tensorflow/cc/gradients/math_grad_test.cc | 20 |
1 files changed, 12 insertions, 8 deletions
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 5b1558dd82..97cd86eacb 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -737,10 +737,14 @@ TEST_F(CWiseUnaryComplexGradTest, Angle) { Tensor x = test::AsTensor<complex64>( {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3}); - Tensor dx_expected = test::AsTensor<complex64>( - {{5.5, 5.5}, {3, 3}, - {2.1666666666666665, 2.1666666666666665}, {1.75, 1.75}, - {0.9375, 0.9375}, {0.8888888888888888, 0.8888888888888888}}, {2, 3}); + Tensor dx_expected = + test::AsTensor<complex64>({{5.5, 5.5}, + {3, 3}, + {2.1666666666666665, 2.1666666666666665}, + {1.75, 1.75}, + {0.9375, 0.9375}, + {0.8888888888888888, 0.8888888888888888}}, + {2, 3}); TestCWiseGradComplex(ANGLE, x, dy, dx_expected); } @@ -920,8 +924,8 @@ class NaryGradTest : public ::testing::Test { const OutputList& ys, const std::vector<TensorShape>& y_shapes) { TF_ASSERT_OK(scope_.status()); float max_error; - TF_ASSERT_OK( - ComputeGradientError(scope_, xs, x_shapes, ys, y_shapes, &max_error)); + TF_ASSERT_OK((ComputeGradientError<float, float, float>( + scope_, xs, x_shapes, ys, y_shapes, &max_error))); EXPECT_LT(max_error, 1e-3); } @@ -929,8 +933,8 @@ class NaryGradTest : public ::testing::Test { const TensorShape& y_shape) { TF_ASSERT_OK(scope_.status()); float max_error; - TF_ASSERT_OK( - ComputeGradientError(scope_, x, x_init_value, y, y_shape, &max_error)); + TF_ASSERT_OK((ComputeGradientError<float, float, float>( + scope_, x, x_init_value, y, y_shape, &max_error))); EXPECT_LT(max_error, 1e-3); } |