diff options
-rw-r--r-- | tensorflow/cc/gradients/math_grad.cc | 243 | ||||
-rw-r--r-- | tensorflow/cc/gradients/math_grad_test.cc | 321 |
2 files changed, 563 insertions, 1 deletions
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index 369bdf570f..d841899bbd 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -21,6 +21,248 @@ namespace tensorflow { namespace ops { namespace { +// TODO(andydavis) Add control dependencies to gradient functions (as needed). + +Status AbsGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // dx = dy * sign(x) + grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0)))); + return scope.status(); +} +REGISTER_GRADIENT_OP("Abs", AbsGrad); + +Status NegGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // dx = -dy; + grad_outputs->push_back(Neg(scope, grad_inputs[0])); + return scope.status(); +} +REGISTER_GRADIENT_OP("Neg", NegGrad); + +Status InvGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // dx = dy * (-1 * (y * y)) + grad_outputs->push_back( + Mul(scope, grad_inputs[0], Neg(scope, Square(scope, op.output(0))))); + return scope.status(); +} +REGISTER_GRADIENT_OP("Inv", InvGrad); + +Status SquareGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // dx = dy * (2 * x) + auto two = Cast(scope, Const(scope, 2), op.input(0).type()); + grad_outputs->push_back( + Mul(scope, grad_inputs[0], Mul(scope, two, op.input(0)))); + return scope.status(); +} +REGISTER_GRADIENT_OP("Square", SquareGrad); + +Status SqrtGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // y = sqrt(x) + // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y) + // dx = dy * (0.5 * (1 / y)) + auto y_inv = Inv(scope, op.output(0)); + auto half = Cast(scope, Const(scope, 0.5), op.input(0).type()); + auto dx = Mul(scope, grad_inputs[0], Mul(scope, half, y_inv)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Sqrt", SqrtGrad); + +Status RsqrtGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // y = 1/x^1/2 = x^-1/2 + // dy/dx = -1/2 * x^-3/2 = -1/2 * x^-1/2 * x^-1 = -1/2 * y * x^-1 + // dx = dy * (-1/2 * y * x^-1) + auto x_inv = Inv(scope, op.input(0)); + auto y = op.output(0); + auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type()); + auto a = Mul(scope, neghalf, x_inv); + auto b = Mul(scope, a, y); + auto dx = Mul(scope, grad_inputs[0], b); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad); + +Status ExpGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // y = exp(x) + // dy/dx = exp(x) + // dx = dy * y + grad_outputs->push_back(Mul(scope, grad_inputs[0], op.output(0))); + return scope.status(); +} +REGISTER_GRADIENT_OP("Exp", ExpGrad); + +Status LogGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // f(x) = log(x) = y + // df/dx = 1 / x + // dx = dy * (1 / x) + grad_outputs->push_back(Mul(scope, grad_inputs[0], Inv(scope, op.input(0)))); + return scope.status(); +} +REGISTER_GRADIENT_OP("Log", LogGrad); + +Status TanhGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // y = tanh(x) + // dy/dx = 1 - (tanh(x))^2 = 1 - y^2 + // dx = dy * (1 - y^2) + auto y2 = Square(scope, op.output(0)); + auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); + auto dx = Mul(scope, grad_inputs[0], Sub(scope, one, y2)); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Tanh", TanhGrad); + +Status SigmoidGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // y = 1 / (1 + exp(-x)) + // dy/dx = y * (1 - y) + // dx = dy * y * (1 - y) + auto y = op.output(0); + auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); + auto dx = Mul(scope, grad_inputs[0], Mul(scope, y, Sub(scope, one, y))); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad); + +Status SignGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + auto shape = Shape(scope, op.input(0)); + auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type()); + auto dx = Fill(scope, shape, zero); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Sign", SignGrad); + +Status SinGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // y = sin(x) + // dy/dx = cos(x) + // dx = dy * cos(x) + auto dx = Mul(scope, grad_inputs[0], Cos(scope, op.input(0))); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Sin", SinGrad); + +Status CosGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // y = cos(x) + // dy/dx = -sin(x) + // dx = dy * -sin(x) + auto dx = Mul(scope, grad_inputs[0], Neg(scope, Sin(scope, op.input(0)))); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Cos", CosGrad); + +Status AsinGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // y = asin(x) + // dy/dx = 1 / (1 - x * x)^1/2 + // dx = dy * (1 / (1 - x * x)^1/2) + auto x2 = Square(scope, op.input(0)); + auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); + auto dydx = Inv(scope, Sqrt(scope, Sub(scope, one, x2))); + auto dx = Mul(scope, grad_inputs[0], dydx); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Asin", AsinGrad); + +Status AcosGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // y = acos(x) + // dy/dx = - 1 / (1 - x * x)^1/2 + // dx = dy * (- 1 / (1 - x * x)^1/2) + auto x2 = Square(scope, op.input(0)); + auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); + auto dydx = Neg(scope, Inv(scope, Sqrt(scope, Sub(scope, one, x2)))); + auto dx = Mul(scope, grad_inputs[0], dydx); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Acos", AcosGrad); + +Status TanGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // y = tan(x) + // dy/dx = sec(x)^2 = 1 / cos(x)^2 + // dx = dy * (1 / cos(x)^2) + auto dydx = Square(scope, Inv(scope, Cos(scope, op.input(0)))); + auto dx = Mul(scope, grad_inputs[0], dydx); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Tan", TanGrad); + +Status AtanGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + // y = arctan(x) + // dy/dx = 1 / (1 + x^2) + // dx = dy * (1 / (1 + x^2) + auto one = Cast(scope, Const(scope, 1.0), op.input(0).type()); + auto dydx = Inv(scope, Add(scope, one, Square(scope, op.input(0)))); + auto dx = Mul(scope, grad_inputs[0], dydx); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Atan", AtanGrad); + +Status RealGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type()); + auto dx = Complex(scope, grad_inputs[0], zero); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Real", RealGrad); + +Status ImagGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type()); + auto dx = Complex(scope, zero, grad_inputs[0]); + grad_outputs->push_back(dx); + return scope.status(); +} +REGISTER_GRADIENT_OP("Imag", ImagGrad); + +Status ConjGrad(const Scope& scope, const Operation& op, + const std::vector<Output>& grad_inputs, + std::vector<Output>* grad_outputs) { + grad_outputs->push_back(Conj(scope, grad_inputs[0])); + return scope.status(); +} +REGISTER_GRADIENT_OP("Conj", ConjGrad); + // MatMulGrad helper function used to compute two MatMul operations // based on input matrix transposition combinations. Status MatMulGradHelper(const Scope& scope, const bool is_batch, @@ -91,7 +333,6 @@ Status MatMulGrad(const Scope& scope, const Operation& op, return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a", "transpose_b", grad_outputs); } - REGISTER_GRADIENT_OP("MatMul", MatMulGrad); Status BatchMatMulGrad(const Scope& scope, const Operation& op, diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index 1248c0aa32..456a036111 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -29,6 +29,327 @@ namespace { // TODO(andydavis) Test gradient function against numeric gradients output. // TODO(andydavis) As more gradients are added move common test functions // to a testutil library. + +class CWiseUnaryGradTest : public ::testing::Test { + protected: + CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {} + + enum UnaryOpType { + ABS, + NEG, + INV, + SQUARE, + SQRT, + RSQRT, + EXP, + LOG, + TANH, + SIGMOID, + SIGN, + SIN, + COS, + ASIN, + ACOS, + TAN, + ATAN + }; + + void TestCWiseGrad(UnaryOpType op_type, std::function<float(int)> x_fn, + std::function<float(float)> dy_fn, + std::function<float(float, float)> dx_fn) { + Tensor x(DT_FLOAT, {2, 3, 2}); + auto x_flat = x.flat<float>(); + for (int i = 0; i < x_flat.size(); ++i) { + x_flat(i) = x_fn(i); + } + + Tensor dy(DT_FLOAT, {2, 3, 2}); + auto dy_flat = dy.flat<float>(); + for (int i = 0; i < dy_flat.size(); ++i) { + dy_flat(i) = dy_fn(x_flat(i)); + } + + Tensor dx(DT_FLOAT, {2, 3, 2}); + auto dx_flat = dx.flat<float>(); + for (int i = 0; i < dx_flat.size(); ++i) { + dx_flat(i) = dx_fn(x_flat(i), dy_flat(i)); + } + + Output y; + switch (op_type) { + case ABS: + y = Abs(scope_, x); + break; + case NEG: + y = Neg(scope_, x); + break; + case INV: + y = Inv(scope_, x); + break; + case SQUARE: + y = Square(scope_, x); + break; + case SQRT: + y = Sqrt(scope_, x); + break; + case RSQRT: + y = Rsqrt(scope_, x); + break; + case EXP: + y = Exp(scope_, x); + break; + case LOG: + y = Log(scope_, x); + break; + case TANH: + y = Tanh(scope_, x); + break; + case SIGMOID: + y = Sigmoid(scope_, x); + break; + case SIGN: + y = Sign(scope_, x); + break; + case SIN: + y = Sin(scope_, x); + break; + case COS: + y = Cos(scope_, x); + break; + case ASIN: + y = Asin(scope_, x); + break; + case ACOS: + y = Acos(scope_, x); + break; + case TAN: + y = Tan(scope_, x); + break; + case ATAN: + y = Atan(scope_, x); + break; + } + + std::vector<Output> grad_outputs; + TF_ASSERT_OK(test::CallGradFunction( + scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs)); + Tensor output; + test::GetTensor(scope_, grad_outputs[0], &output); + test::ExpectClose(output, dx); + } + + float RV(std::vector<float> v) { return v[random::New64() % v.size()]; } + + Scope scope_; +}; + +TEST_F(CWiseUnaryGradTest, Abs) { + auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { return x * dy; }; + TestCWiseGrad(ABS, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Neg) { + auto x_fn = [this](const int i) { return RV({-1, 0, 1}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { return -dy; }; + TestCWiseGrad(NEG, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Inv) { + auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; + auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + return -(1 / (x * x)) * dy; + }; + TestCWiseGrad(INV, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Square) { + auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; + auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); }; + auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; }; + TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Sqrt) { + auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); }; + auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); }; + auto dx_fn = [this](const float x, const float dy) { + return dy * 0.5 * (1.0 / std::sqrt(x)); + }; + TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Rsqrt) { + auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); }; + auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); }; + auto dx_fn = [this](const float x, const float dy) { + return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x); + }; + TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Exp) { + auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + return dy * std::exp(x); + }; + TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Log) { + auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); }; + TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Tanh) { + auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + const float y = std::tanh(x); + return dy * (1.0 - y * y); + }; + TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Sigmoid) { + auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + const float y = 1.0 / (1.0 + std::exp(-x)); + return dy * y * (1.0 - y); + }; + TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Sign) { + auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { return 0.0; }; + TestCWiseGrad(SIGN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Sin) { + auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + return dy * std::cos(x); + }; + TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Cos) { + auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + return dy * -1.0 * std::sin(x); + }; + TestCWiseGrad(COS, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Asin) { + auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + return dy * (1.0 / std::sqrt(1.0 - x * x)); + }; + TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Acos) { + auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + return dy * (-1.0 / std::sqrt(1.0 - x * x)); + }; + TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Tan) { + auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + const float cosx = std::cos(x); + return dy * (1 / (cosx * cosx)); + }; + TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn); +} + +TEST_F(CWiseUnaryGradTest, Atan) { + auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); }; + auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); }; + auto dx_fn = [this](const float x, const float dy) { + return dy * (1 / (1 + x * x)); + }; + TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn); +} + +class CWiseUnaryComplexGradTest : public ::testing::Test { + protected: + CWiseUnaryComplexGradTest() + : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {} + + enum UnaryOpType { REAL, IMAG, CONJ }; + + void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x, + const Tensor& dy, const Tensor& dx_expected) { + Output y; + switch (op_type) { + case REAL: + y = Real(scope_, x); + break; + case IMAG: + y = Imag(scope_, x); + break; + case CONJ: + y = Conj(scope_, x); + break; + } + + std::vector<Output> grad_outputs; + TF_ASSERT_OK(test::CallGradFunction( + scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs)); + Tensor dx; + test::GetTensor(scope_, grad_outputs[0], &dx); + test::ExpectClose(dx, dx_expected); + } + + Scope scope_; +}; + +TEST_F(CWiseUnaryComplexGradTest, Real) { + Tensor x = test::AsTensor<complex64>( + {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); + Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3}); + Tensor dx_expected = test::AsTensor<complex64>( + {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3}); + TestCWiseGradComplex(REAL, x, dy, dx_expected); +} + +TEST_F(CWiseUnaryComplexGradTest, Imag) { + Tensor x = test::AsTensor<complex64>( + {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); + Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3}); + Tensor dx_expected = test::AsTensor<complex64>( + {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3}); + TestCWiseGradComplex(IMAG, x, dy, dx_expected); +} + +TEST_F(CWiseUnaryComplexGradTest, Conj) { + Tensor x = test::AsTensor<complex64>( + {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); + Tensor dy = test::AsTensor<complex64>( + {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3}); + Tensor dx_expected = test::AsTensor<complex64>( + {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3}); + TestCWiseGradComplex(CONJ, x, dy, dx_expected); +} + class MathGradTest : public ::testing::Test { protected: MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {} |