diff options
author | 2017-09-13 17:14:48 -0700 | |
---|---|---|
committer | 2017-09-13 17:21:45 -0700 | |
commit | 8570d4948a2c44cd6646cdc38f4fc17a915f6081 (patch) | |
tree | 7e24648fda97e227343463dae53453e5c8860573 /tensorflow/cc/gradients/math_grad_test.cc | |
parent | 75b061cc6d5d4c05c2e0be5a169f1dbb89459482 (diff) |
Use the numeric gradient checker for unary math gradient tests.
- Port CWiseUnaryGradTest to use ComputeGradientError.
- Move test values away from any poles or discontinuities in the
gradient to reduce numeric estimation errors.
- Consolidate CWiseUnaryComplexGradTest and CWiseUnaryGradTest.
PiperOrigin-RevId: 168619288
Diffstat (limited to 'tensorflow/cc/gradients/math_grad_test.cc')
-rw-r--r-- | tensorflow/cc/gradients/math_grad_test.cc | 523 |
1 files changed, 134 insertions, 389 deletions
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 97cd86eacb..047243aa6a 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -59,30 +59,24 @@ class CWiseUnaryGradTest : public ::testing::Test { ASIN, ACOS, TAN, - ATAN + ATAN, + REAL, + IMAG, + CONJ, + COMPLEX, + ANGLE }; - template <typename T> - void TestCWiseGrad(UnaryOpType op_type, const std::function<T(int)>& x_fn, - const std::function<T(const T&)>& dy_fn, - const std::function<T(const T&, const T&)>& dx_fn) { - DataType dtype = DataTypeToEnum<T>::v(); - Tensor x(dtype, {2, 3, 2}); - auto x_flat = x.flat<T>(); - for (int i = 0; i < x_flat.size(); ++i) { - x_flat(i) = x_fn(i); - } - - Tensor dy(dtype, {2, 3, 2}); - auto dy_flat = dy.flat<T>(); - for (int i = 0; i < dy_flat.size(); ++i) { - dy_flat(i) = dy_fn(x_flat(i)); - } - - Tensor dx(dtype, {2, 3, 2}); - auto dx_flat = dx.flat<T>(); - for (int i = 0; i < dx_flat.size(); ++i) { - dx_flat(i) = dx_fn(x_flat(i), dy_flat(i)); + template <typename X_T, typename Y_T> + void TestCWiseGrad(UnaryOpType op_type, const std::function<X_T(int)>& x_fn) { + TF_ASSERT_OK(scope_.status()); + DataType x_type = DataTypeToEnum<X_T>::v(); + TensorShape shape({2, 3, 2}); + auto x = Placeholder(scope_, x_type, Placeholder::Shape(shape)); + Tensor x_data(x_type, shape); + auto x_data_flat = x_data.flat<X_T>(); + for (int i = 0; i < x_data_flat.size(); ++i) { + x_data_flat(i) = x_fn(i); } Output y; @@ -159,14 +153,27 @@ class CWiseUnaryGradTest : public ::testing::Test { case ATAN: y = Atan(scope_, x); break; + case REAL: + y = Real(scope_, x); + break; + case IMAG: + y = Imag(scope_, x); + break; + case CONJ: + y = Conj(scope_, x); + break; + case COMPLEX: + y = Complex(scope_, x, x); + break; + case ANGLE: + y = Angle(scope_, x); + break; } - std::vector<Output> grad_outputs; - TF_ASSERT_OK(test::CallGradFunction( - scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs)); - Tensor output; - test::GetTensor(scope_, grad_outputs[0], &output); - test::ExpectClose(output, dx); + float max_error; + TF_ASSERT_OK((ComputeGradientError<X_T, Y_T, float>(scope_, x, x_data, y, + shape, &max_error))); + EXPECT_LT(max_error, 1e-3f); } float RV(const std::vector<float>& v) { @@ -181,581 +188,319 @@ class CWiseUnaryGradTest : public ::testing::Test { return complex64(val.real(), -val.imag()); } - const complex64 one_{1.0, 0}; - Scope scope_; }; TEST_F(CWiseUnaryGradTest, Abs) { auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { return x * dy; }; - TestCWiseGrad<float>(ABS, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(ABS, x_fn); } TEST_F(CWiseUnaryGradTest, Neg) { auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { return -dy; }; - TestCWiseGrad<float>(NEG, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(NEG, x_fn); } TEST_F(CWiseUnaryGradTest, Reciprocal) { auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; - auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - return -(1 / (x * x)) * dy; - }; - TestCWiseGrad<float>(INV, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(INV, x_fn); } TEST_F(CWiseUnaryGradTest, Reciprocal_Complex) { auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; - auto dy_fn = [this](const complex64 x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64 x, const complex64 dy) { - return -conjugate(one_ / (x * x)) * dy; - }; - TestCWiseGrad<complex64>(INV, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(INV, x_fn); } TEST_F(CWiseUnaryGradTest, Square) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; - auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); }; - auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; }; - TestCWiseGrad<float>(SQUARE, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(SQUARE, x_fn); } TEST_F(CWiseUnaryGradTest, Square_Complex) { auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return conjugate(complex64(2, 0) * x) * dy; - }; - TestCWiseGrad<complex64>(SQUARE, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(SQUARE, x_fn); } TEST_F(CWiseUnaryGradTest, Sqrt) { - auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); }; - auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); }; - auto dx_fn = [this](const float x, const float dy) { - return dy * 0.5 * (1.0 / std::sqrt(x)); - }; - TestCWiseGrad<float>(SQRT, x_fn, dy_fn, dx_fn); + auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4, 5, 6, 7}); }; + TestCWiseGrad<float, float>(SQRT, x_fn); } TEST_F(CWiseUnaryGradTest, Sqrt_Complex) { - auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return conjugate(complex64(0.5, 0) / std::sqrt(x)) * dy; + auto x_fn = [this](const int i) { + return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}}); }; - TestCWiseGrad<complex64>(SQRT, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(SQRT, x_fn); } TEST_F(CWiseUnaryGradTest, Rsqrt) { auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); }; - auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); }; - auto dx_fn = [this](const float x, const float dy) { - return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x); - }; - TestCWiseGrad<float>(RSQRT, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(RSQRT, x_fn); } TEST_F(CWiseUnaryGradTest, Rsqrt_Complex) { - auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return conjugate(complex64(-0.5, 0) / std::sqrt(x) / x) * dy; + auto x_fn = [this](const int i) { + return CRV({{-1.0f, 0.5f}, {1.0f, 0.5f}, {2, -1}}); }; - TestCWiseGrad<complex64>(RSQRT, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(RSQRT, x_fn); } TEST_F(CWiseUnaryGradTest, Exp) { - auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - return dy * std::exp(x); + auto x_fn = [this](const int i) { + return RV({0, -1, 1, -1.5f, 1.5f, -2, 2}); }; - TestCWiseGrad<float>(EXP, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(EXP, x_fn); } TEST_F(CWiseUnaryGradTest, Exp_Complex) { auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return dy * conjugate(std::exp(x)); - }; - TestCWiseGrad<complex64>(EXP, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(EXP, x_fn); } TEST_F(CWiseUnaryGradTest, Expm1) { - auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -2, 3, 100}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - return dy * std::exp(x); - }; - TestCWiseGrad<float>(EXPM1, x_fn, dy_fn, dx_fn); + auto x_fn = [this](const int i) { return RV({0, -1, 1e-6, 1, -1.5, 1.5}); }; + TestCWiseGrad<float, float>(EXPM1, x_fn); } TEST_F(CWiseUnaryGradTest, Expm1_Complex) { - auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return dy * conjugate(std::exp(x)); + auto x_fn = [this](const int i) { + return CRV({{-1, 0}, {1, 0}, {1.5, -1.5}}); }; - TestCWiseGrad<complex64>(EXPM1, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(EXPM1, x_fn); } TEST_F(CWiseUnaryGradTest, Log) { - auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); }; - TestCWiseGrad<float>(LOG, x_fn, dy_fn, dx_fn); + auto x_fn = [this](const int i) { return RV({0.5, 1, 2, 3, 4}); }; + TestCWiseGrad<float, float>(LOG, x_fn); } TEST_F(CWiseUnaryGradTest, Log_Complex) { - auto x_fn = [this](const int i) { return CRV({{-1, 0}, {1, 0}, {2, -1}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return dy * conjugate(one_ / x); + auto x_fn = [this](const int i) { + return CRV({{-1, 0.5f}, {1, 0.5f}, {2, -1}}); }; - TestCWiseGrad<complex64>(LOG, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(LOG, x_fn); } TEST_F(CWiseUnaryGradTest, Log1p) { auto x_fn = [this](const int i) { return RV({0, 1e-6, 1, 2, 3, 4, 100}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - return dy * (1.0 / (1.0 + x)); - }; - TestCWiseGrad<float>(LOG1P, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(LOG1P, x_fn); } TEST_F(CWiseUnaryGradTest, Log1p_Complex) { auto x_fn = [this](const int i) { return CRV({{0, 0}, {1e-6, 0}, {2, -1}, {1, 2}, {3, 4}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return dy / (one_ + conjugate(x)); - }; - TestCWiseGrad<complex64>(LOG1P, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(LOG1P, x_fn); } TEST_F(CWiseUnaryGradTest, Sinh) { - auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - return dy * std::cosh(x); - }; - TestCWiseGrad<float>(SINH, x_fn, dy_fn, dx_fn); + auto x_fn = [this](const int i) { return RV({0.5, -0.5, 1, -1, 1.5, -1.5}); }; + TestCWiseGrad<float, float>(SINH, x_fn); } TEST_F(CWiseUnaryGradTest, Sinh_Complex) { auto x_fn = [this](const int i) { - return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); - }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return dy * conjugate(std::cosh(x)); + return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}}); }; - TestCWiseGrad<complex64>(SINH, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(SINH, x_fn); } TEST_F(CWiseUnaryGradTest, Cosh) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - return dy * std::sinh(x); - }; - TestCWiseGrad<float>(COSH, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(COSH, x_fn); } TEST_F(CWiseUnaryGradTest, Cosh_Complex) { auto x_fn = [this](const int i) { - return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); - }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return dy * conjugate(std::sinh(x)); + return CRV({{0.5, 0.25}, {0.25, 0.5}, {1.5, -1}, {1, 1.5}}); }; - TestCWiseGrad<complex64>(COSH, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(COSH, x_fn); } TEST_F(CWiseUnaryGradTest, Tanh) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - const float y = std::tanh(x); - return dy * (1.0 - y * y); - }; - TestCWiseGrad<float>(TANH, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(TANH, x_fn); } TEST_F(CWiseUnaryGradTest, Tanh_Complex) { auto x_fn = [this](const int i) { return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - const complex64 y = std::tanh(x); - return dy * conjugate((one_ - y * y)); - }; - TestCWiseGrad<complex64>(TANH, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(TANH, x_fn); } TEST_F(CWiseUnaryGradTest, Asinh) { - auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - auto y = std::asinh(x); - return dy / std::cosh(y); - }; - TestCWiseGrad<float>(ASINH, x_fn, dy_fn, dx_fn); + auto x_fn = [this](const int i) { return RV({0.5, 1, -1, -1.5, 1.5}); }; + TestCWiseGrad<float, float>(ASINH, x_fn); } TEST_F(CWiseUnaryGradTest, Asinh_Complex) { auto x_fn = [this](const int i) { - return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); - }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - auto y = std::asinh(x); - return dy / conjugate(std::cosh(y)); + return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}}); }; - TestCWiseGrad<complex64>(ASINH, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(ASINH, x_fn); } TEST_F(CWiseUnaryGradTest, Acosh) { - auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7}); }; - auto dy_fn = [this](const float x) { - return x + RV({8, 9, 10, 11, 12, 13, 14}); - }; - auto dx_fn = [this](const float x, const float dy) { - auto y = std::acosh(x); - return dy / std::sinh(y); - }; - TestCWiseGrad<float>(ACOSH, x_fn, dy_fn, dx_fn); + auto x_fn = [this](const int i) { return RV({1.5, 2, 2.5}); }; + TestCWiseGrad<float, float>(ACOSH, x_fn); } TEST_F(CWiseUnaryGradTest, Acosh_Complex) { auto x_fn = [this](const int i) { - return CRV({{1, 1}, {2, 1}, {1, 4}, {1, 2}, {3, 4}}); - }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{2, 2}, {3, 3}, {1, 4}}); + return CRV({{1, 0.5}, {0.5, 1}, {0.5, -1}, {1, 1.5}}); }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - auto y = std::acosh(x); - return dy / conjugate(std::sinh(y)); - }; - TestCWiseGrad<complex64>(ACOSH, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(ACOSH, x_fn); } TEST_F(CWiseUnaryGradTest, Atanh) { auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.1, 0.1}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - return dy * (1. / (1. - x * x)); - }; - TestCWiseGrad<float>(ATANH, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(ATANH, x_fn); } TEST_F(CWiseUnaryGradTest, Atanh_Complex) { auto x_fn = [this](const int i) { return CRV({{0.1, 0}, {0, 0.1}, {0.2, -0.1}, {0.1, 0.2}, {0.3, 0.4}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return dy / conjugate(one_ - x * x); - }; - TestCWiseGrad<complex64>(ATANH, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(ATANH, x_fn); } TEST_F(CWiseUnaryGradTest, Sigmoid) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - const float y = 1.0 / (1.0 + std::exp(-x)); - return dy * y * (1.0 - y); - }; - TestCWiseGrad<float>(SIGMOID, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(SIGMOID, x_fn); } TEST_F(CWiseUnaryGradTest, Sigmoid_Complex) { auto x_fn = [this](const int i) { return CRV({{1, 0}, {0, 0}, {2, -1}, {1, 2}, {3, 4}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - const complex64 y = one_ / (one_ + std::exp(-x)); - return dy * conjugate(y * (one_ - y)); - }; - TestCWiseGrad<complex64>(SIGMOID, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(SIGMOID, x_fn); } TEST_F(CWiseUnaryGradTest, Sign) { - auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { return 0.0; }; - TestCWiseGrad<float>(SIGN, x_fn, dy_fn, dx_fn); + auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3}); }; + TestCWiseGrad<float, float>(SIGN, x_fn); } TEST_F(CWiseUnaryGradTest, Sin) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - return dy * std::cos(x); - }; - TestCWiseGrad<float>(SIN, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(SIN, x_fn); } TEST_F(CWiseUnaryGradTest, Sin_Complex) { auto x_fn = [this](const int i) { - return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return dy * conjugate(std::cos(x)); - }; - TestCWiseGrad<complex64>(SIN, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(SIN, x_fn); } TEST_F(CWiseUnaryGradTest, Cos) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - return dy * -1.0 * std::sin(x); - }; - TestCWiseGrad<float>(COS, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(COS, x_fn); } TEST_F(CWiseUnaryGradTest, Cos_Complex) { auto x_fn = [this](const int i) { - return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); + return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return dy * conjugate(-std::sin(x)); - }; - TestCWiseGrad<complex64>(COS, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(COS, x_fn); } TEST_F(CWiseUnaryGradTest, Asin) { - auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - return dy * (1.0 / std::sqrt(1.0 - x * x)); - }; - TestCWiseGrad<float>(ASIN, x_fn, dy_fn, dx_fn); + auto x_fn = [this](const int i) { return RV({0, 0.25, -0.25, -0.5, 0.5}); }; + TestCWiseGrad<float, float>(ASIN, x_fn); } TEST_F(CWiseUnaryGradTest, Asin_Complex) { auto x_fn = [this](const int i) { - return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); - }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return dy / conjugate(std::sqrt(one_ - x * x)); + return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}}); }; // TODO(kbsriram) // Enable test when the asin kernel supports complex numbers if (false) { - TestCWiseGrad<complex64>(ASIN, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(ASIN, x_fn); } } TEST_F(CWiseUnaryGradTest, Acos) { - auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - return dy * (-1.0 / std::sqrt(1.0 - x * x)); - }; - TestCWiseGrad<float>(ACOS, x_fn, dy_fn, dx_fn); + auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -0.75, 0.75}); }; + TestCWiseGrad<float, float>(ACOS, x_fn); } TEST_F(CWiseUnaryGradTest, Acos_Complex) { auto x_fn = [this](const int i) { - return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); - }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return dy / -conjugate(std::sqrt(one_ - x * x)); + return CRV({{0.5, 0}, {0, 0.5}, {0.25, -0.75}, {0.5, 0.25}}); }; // TODO(kbsriram) // Add test when the acos kernel supports complex numbers if (false) { - TestCWiseGrad<complex64>(ACOS, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(ACOS, x_fn); } } TEST_F(CWiseUnaryGradTest, Tan) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - const float cosx = std::cos(x); - return dy * (1 / (cosx * cosx)); - }; - TestCWiseGrad<float>(TAN, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(TAN, x_fn); } TEST_F(CWiseUnaryGradTest, Tan_Complex) { auto x_fn = [this](const int i) { return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - const complex64 cosx = std::cos(x); - return dy / conjugate(cosx * cosx); - }; // TODO(kbsriram) // Enable when tan kernel supports complex inputs if (false) { - TestCWiseGrad<complex64>(TAN, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(TAN, x_fn); } } TEST_F(CWiseUnaryGradTest, Atan) { auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; - auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; - auto dx_fn = [this](const float x, const float dy) { - return dy * (1 / (1 + x * x)); - }; - TestCWiseGrad<float>(ATAN, x_fn, dy_fn, dx_fn); + TestCWiseGrad<float, float>(ATAN, x_fn); } TEST_F(CWiseUnaryGradTest, Atan_Complex) { auto x_fn = [this](const int i) { return CRV({{1, 0}, {0, 1}, {2, -1}, {1, 2}, {3, 4}}); }; - auto dy_fn = [this](const complex64& x) { - return x + CRV({{-2, 2}, {-3, 3}, {1, -4}}); - }; - auto dx_fn = [this](const complex64& x, const complex64& dy) { - return dy / (one_ + x * x); - }; // TODO(kbsriram) // Add test when the atan kernel supports complex numbers if (false) { - TestCWiseGrad<complex64>(ATAN, x_fn, dy_fn, dx_fn); + TestCWiseGrad<complex64, complex64>(ATAN, x_fn); } } -class CWiseUnaryComplexGradTest : public ::testing::Test { - protected: - CWiseUnaryComplexGradTest() - : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {} - - enum UnaryOpType { REAL, IMAG, ANGLE, CONJ }; +TEST_F(CWiseUnaryGradTest, Real) { + auto x_fn = [this](const int i) { + return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}}); + }; + TestCWiseGrad<complex64, float>(REAL, x_fn); +} - void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x, - const Tensor& dy, const Tensor& dx_expected) { - Output y; - switch (op_type) { - case REAL: - y = Real(scope_, x); - break; - case IMAG: - y = Imag(scope_, x); - break; - case ANGLE: - y = Angle(scope_, x); - break; - case CONJ: - y = Conj(scope_, x); - break; - } +TEST_F(CWiseUnaryGradTest, Imag) { + auto x_fn = [this](const int i) { + return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}}); + }; + TestCWiseGrad<complex64, float>(IMAG, x_fn); +} - std::vector<Output> grad_outputs; - TF_ASSERT_OK(test::CallGradFunction( - scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs)); - Tensor dx; - test::GetTensor(scope_, grad_outputs[0], &dx); - test::ExpectClose(dx, dx_expected); - } +TEST_F(CWiseUnaryGradTest, Conj) { + auto x_fn = [this](const int i) { + return CRV({{1, -1}, {-2, 2}, {2, 3}, {-2, -3}}); + }; + TestCWiseGrad<complex64, complex64>(CONJ, x_fn); +} - Scope scope_; -}; +TEST_F(CWiseUnaryGradTest, Complex) { + auto x_fn = [this](const int i) { return RV({1, -1, 2, -2, 3, -3}); }; + TestCWiseGrad<float, complex64>(COMPLEX, x_fn); +} -TEST_F(CWiseUnaryComplexGradTest, Real) { - Tensor x = test::AsTensor<complex64>( - {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); - Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3}); - Tensor dx_expected = test::AsTensor<complex64>( - {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3}); - TestCWiseGradComplex(REAL, x, dy, dx_expected); -} - -TEST_F(CWiseUnaryComplexGradTest, Imag) { - Tensor x = test::AsTensor<complex64>( - {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); - Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3}); - Tensor dx_expected = test::AsTensor<complex64>( - {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3}); - TestCWiseGradComplex(IMAG, x, dy, dx_expected); -} - -TEST_F(CWiseUnaryComplexGradTest, Angle) { - Tensor x = test::AsTensor<complex64>( - {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); - Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3}); - Tensor dx_expected = - test::AsTensor<complex64>({{5.5, 5.5}, - {3, 3}, - {2.1666666666666665, 2.1666666666666665}, - {1.75, 1.75}, - {0.9375, 0.9375}, - {0.8888888888888888, 0.8888888888888888}}, - {2, 3}); - TestCWiseGradComplex(ANGLE, x, dy, dx_expected); -} - -TEST_F(CWiseUnaryComplexGradTest, Conj) { - Tensor x = test::AsTensor<complex64>( - {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); - Tensor dy = test::AsTensor<complex64>( - {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); - Tensor dx_expected = test::AsTensor<complex64>( - {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3}); - TestCWiseGradComplex(CONJ, x, dy, dx_expected); +TEST_F(CWiseUnaryGradTest, Angle) { + auto x_fn = [this](const int i) { + return CRV({{1.5, 1.5}, {1.5, -1.5}, {-1.5, 1.5}, {-1.5, -1.5}}); + }; + TestCWiseGrad<complex64, float>(ANGLE, x_fn); } class MathGradTest : public ::testing::Test { |