-rw-r--r--  tensorflow/cc/gradients/math_grad.cc       | 243
-rw-r--r--  tensorflow/cc/gradients/math_grad_test.cc  | 321
2 files changed, 563 insertions(+), 1 deletion(-)
diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc
index 369bdf570f..d841899bbd 100644
--- a/tensorflow/cc/gradients/math_grad.cc
+++ b/tensorflow/cc/gradients/math_grad.cc
@@ -21,6 +21,248 @@ namespace tensorflow {
namespace ops {
namespace {
+// TODO(andydavis) Add control dependencies to gradient functions (as needed).
+
+Status AbsGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // dx = dy * sign(x)
+ grad_outputs->push_back(Mul(scope, grad_inputs[0], Sign(scope, op.input(0))));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Abs", AbsGrad);
+
+Status NegGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+  // dx = -dy
+ grad_outputs->push_back(Neg(scope, grad_inputs[0]));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Neg", NegGrad);
+
+Status InvGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+  // y = 1 / x
+  // dy/dx = -1 / x^2 = -y^2
+  // dx = dy * (-y^2)
+ grad_outputs->push_back(
+ Mul(scope, grad_inputs[0], Neg(scope, Square(scope, op.output(0)))));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Inv", InvGrad);
+
+Status SquareGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // dx = dy * (2 * x)
+ auto two = Cast(scope, Const(scope, 2), op.input(0).type());
+ grad_outputs->push_back(
+ Mul(scope, grad_inputs[0], Mul(scope, two, op.input(0))));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Square", SquareGrad);
+
+Status SqrtGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = sqrt(x)
+ // dy/dx = 0.5 * (1 / sqrt(x)) = 0.5 * (1 / y)
+ // dx = dy * (0.5 * (1 / y))
+ auto y_inv = Inv(scope, op.output(0));
+ auto half = Cast(scope, Const(scope, 0.5), op.input(0).type());
+ auto dx = Mul(scope, grad_inputs[0], Mul(scope, half, y_inv));
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Sqrt", SqrtGrad);
+
+Status RsqrtGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+  // y = 1 / x^(1/2) = x^(-1/2)
+  // dy/dx = -1/2 * x^(-3/2) = -1/2 * x^(-1/2) * x^(-1) = -1/2 * y * (1/x)
+  // dx = dy * (-1/2 * y * (1/x))
+ auto x_inv = Inv(scope, op.input(0));
+ auto y = op.output(0);
+ auto neghalf = Cast(scope, Const(scope, -0.5), op.input(0).type());
+ auto a = Mul(scope, neghalf, x_inv);
+ auto b = Mul(scope, a, y);
+ auto dx = Mul(scope, grad_inputs[0], b);
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Rsqrt", RsqrtGrad);
+
+Status ExpGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = exp(x)
+ // dy/dx = exp(x)
+ // dx = dy * y
+ grad_outputs->push_back(Mul(scope, grad_inputs[0], op.output(0)));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Exp", ExpGrad);
+
+Status LogGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // f(x) = log(x) = y
+ // df/dx = 1 / x
+ // dx = dy * (1 / x)
+ grad_outputs->push_back(Mul(scope, grad_inputs[0], Inv(scope, op.input(0))));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Log", LogGrad);
+
+Status TanhGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = tanh(x)
+ // dy/dx = 1 - (tanh(x))^2 = 1 - y^2
+ // dx = dy * (1 - y^2)
+ auto y2 = Square(scope, op.output(0));
+ auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
+ auto dx = Mul(scope, grad_inputs[0], Sub(scope, one, y2));
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Tanh", TanhGrad);
+
+Status SigmoidGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = 1 / (1 + exp(-x))
+ // dy/dx = y * (1 - y)
+ // dx = dy * y * (1 - y)
+ auto y = op.output(0);
+ auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
+ auto dx = Mul(scope, grad_inputs[0], Mul(scope, y, Sub(scope, one, y)));
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad);
+
+Status SignGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
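+  // y = sign(x)
+  // dy/dx = 0 (undefined at x = 0)
+  // dx = a tensor of zeros with the same shape as x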
+ auto shape = Shape(scope, op.input(0));
+ auto zero = Cast(scope, Const(scope, 0.0), op.input(0).type());
+ auto dx = Fill(scope, shape, zero);
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Sign", SignGrad);
+
+Status SinGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = sin(x)
+ // dy/dx = cos(x)
+ // dx = dy * cos(x)
+ auto dx = Mul(scope, grad_inputs[0], Cos(scope, op.input(0)));
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Sin", SinGrad);
+
+Status CosGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = cos(x)
+ // dy/dx = -sin(x)
+ // dx = dy * -sin(x)
+ auto dx = Mul(scope, grad_inputs[0], Neg(scope, Sin(scope, op.input(0))));
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Cos", CosGrad);
+
+Status AsinGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = asin(x)
+  // dy/dx = 1 / (1 - x * x)^(1/2)
+  // dx = dy * (1 / (1 - x * x)^(1/2))
+ auto x2 = Square(scope, op.input(0));
+ auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
+ auto dydx = Inv(scope, Sqrt(scope, Sub(scope, one, x2)));
+ auto dx = Mul(scope, grad_inputs[0], dydx);
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Asin", AsinGrad);
+
+Status AcosGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = acos(x)
+  // dy/dx = -1 / (1 - x * x)^(1/2)
+  // dx = dy * (-1 / (1 - x * x)^(1/2))
+ auto x2 = Square(scope, op.input(0));
+ auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
+ auto dydx = Neg(scope, Inv(scope, Sqrt(scope, Sub(scope, one, x2))));
+ auto dx = Mul(scope, grad_inputs[0], dydx);
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Acos", AcosGrad);
+
+Status TanGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = tan(x)
+ // dy/dx = sec(x)^2 = 1 / cos(x)^2
+ // dx = dy * (1 / cos(x)^2)
+ auto dydx = Square(scope, Inv(scope, Cos(scope, op.input(0))));
+ auto dx = Mul(scope, grad_inputs[0], dydx);
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Tan", TanGrad);
+
+Status AtanGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
+ // y = arctan(x)
+ // dy/dx = 1 / (1 + x^2)
+  // dx = dy * (1 / (1 + x^2))
+ auto one = Cast(scope, Const(scope, 1.0), op.input(0).type());
+ auto dydx = Inv(scope, Add(scope, one, Square(scope, op.input(0))));
+ auto dx = Mul(scope, grad_inputs[0], dydx);
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Atan", AtanGrad);
+
+Status RealGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
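+  // y = real(x)
+  // dx = dy, re-expressed as a complex tensor with zero imaginary part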
+ auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
+ auto dx = Complex(scope, grad_inputs[0], zero);
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Real", RealGrad);
+
+Status ImagGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
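+  // y = imag(x)
+  // dx = dy, re-expressed as a complex tensor with zero real part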
+ auto zero = Cast(scope, Const(scope, 0.0), op.output(0).type());
+ auto dx = Complex(scope, zero, grad_inputs[0]);
+ grad_outputs->push_back(dx);
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Imag", ImagGrad);
+
+Status ConjGrad(const Scope& scope, const Operation& op,
+ const std::vector<Output>& grad_inputs,
+ std::vector<Output>* grad_outputs) {
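+  // y = conj(x)
+  // dx = conj(dy)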
+ grad_outputs->push_back(Conj(scope, grad_inputs[0]));
+ return scope.status();
+}
+REGISTER_GRADIENT_OP("Conj", ConjGrad);
+
// MatMulGrad helper function used to compute two MatMul operations
// based on input matrix transposition combinations.
Status MatMulGradHelper(const Scope& scope, const bool is_batch,
@@ -91,7 +333,6 @@ Status MatMulGrad(const Scope& scope, const Operation& op,
return MatMulGradCommon(scope, op, false, grad_inputs, "transpose_a",
"transpose_b", grad_outputs);
}
-
REGISTER_GRADIENT_OP("MatMul", MatMulGrad);
Status BatchMatMulGrad(const Scope& scope, const Operation& op,
diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc
index 1248c0aa32..456a036111 100644
--- a/tensorflow/cc/gradients/math_grad_test.cc
+++ b/tensorflow/cc/gradients/math_grad_test.cc
@@ -29,6 +29,327 @@ namespace {
// TODO(andydavis) Test gradient function against numeric gradients output.
// TODO(andydavis) As more gradients are added move common test functions
// to a testutil library.
+
+class CWiseUnaryGradTest : public ::testing::Test {
+ protected:
+ CWiseUnaryGradTest() : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
+
+ enum UnaryOpType {
+ ABS,
+ NEG,
+ INV,
+ SQUARE,
+ SQRT,
+ RSQRT,
+ EXP,
+ LOG,
+ TANH,
+ SIGMOID,
+ SIGN,
+ SIN,
+ COS,
+ ASIN,
+ ACOS,
+ TAN,
+ ATAN
+ };
+
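+  // Builds the unary op 'op_type' over a fixed-shape float tensor and checks
+  // its registered gradient against an expected value computed element-wise:
+  //   x_fn(i)      - value of the i-th input element
+  //   dy_fn(x)     - incoming gradient for an element with input value x
+  //   dx_fn(x, dy) - expected backpropagated gradient for that element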
+ void TestCWiseGrad(UnaryOpType op_type, std::function<float(int)> x_fn,
+ std::function<float(float)> dy_fn,
+ std::function<float(float, float)> dx_fn) {
+ Tensor x(DT_FLOAT, {2, 3, 2});
+ auto x_flat = x.flat<float>();
+ for (int i = 0; i < x_flat.size(); ++i) {
+ x_flat(i) = x_fn(i);
+ }
+
+ Tensor dy(DT_FLOAT, {2, 3, 2});
+ auto dy_flat = dy.flat<float>();
+ for (int i = 0; i < dy_flat.size(); ++i) {
+ dy_flat(i) = dy_fn(x_flat(i));
+ }
+
+ Tensor dx(DT_FLOAT, {2, 3, 2});
+ auto dx_flat = dx.flat<float>();
+ for (int i = 0; i < dx_flat.size(); ++i) {
+ dx_flat(i) = dx_fn(x_flat(i), dy_flat(i));
+ }
+
+ Output y;
+ switch (op_type) {
+ case ABS:
+ y = Abs(scope_, x);
+ break;
+ case NEG:
+ y = Neg(scope_, x);
+ break;
+ case INV:
+ y = Inv(scope_, x);
+ break;
+ case SQUARE:
+ y = Square(scope_, x);
+ break;
+ case SQRT:
+ y = Sqrt(scope_, x);
+ break;
+ case RSQRT:
+ y = Rsqrt(scope_, x);
+ break;
+ case EXP:
+ y = Exp(scope_, x);
+ break;
+ case LOG:
+ y = Log(scope_, x);
+ break;
+ case TANH:
+ y = Tanh(scope_, x);
+ break;
+ case SIGMOID:
+ y = Sigmoid(scope_, x);
+ break;
+ case SIGN:
+ y = Sign(scope_, x);
+ break;
+ case SIN:
+ y = Sin(scope_, x);
+ break;
+ case COS:
+ y = Cos(scope_, x);
+ break;
+ case ASIN:
+ y = Asin(scope_, x);
+ break;
+ case ACOS:
+ y = Acos(scope_, x);
+ break;
+ case TAN:
+ y = Tan(scope_, x);
+ break;
+ case ATAN:
+ y = Atan(scope_, x);
+ break;
+ }
+
+ std::vector<Output> grad_outputs;
+ TF_ASSERT_OK(test::CallGradFunction(
+ scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
+ Tensor output;
+ test::GetTensor(scope_, grad_outputs[0], &output);
+ test::ExpectClose(output, dx);
+ }
+
+ float RV(std::vector<float> v) { return v[random::New64() % v.size()]; }
+
+ Scope scope_;
+};
+
+TEST_F(CWiseUnaryGradTest, Abs) {
+ auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) { return x * dy; };
+ TestCWiseGrad(ABS, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Neg) {
+ auto x_fn = [this](const int i) { return RV({-1, 0, 1}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) { return -dy; };
+ TestCWiseGrad(NEG, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Inv) {
+ auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
+ auto dy_fn = [this](const float x) { return RV({0, -2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ return -(1 / (x * x)) * dy;
+ };
+ TestCWiseGrad(INV, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Square) {
+ auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
+ auto dy_fn = [this](const float x) { return RV({0, -7, 7, -8, 8, -9, 9}); };
+ auto dx_fn = [this](const float x, const float dy) { return 2 * x * dy; };
+ TestCWiseGrad(SQUARE, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Sqrt) {
+ auto x_fn = [this](const int i) { return RV({0, 1, 2, 3, 4, 5, 6, 7}); };
+ auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ return dy * 0.5 * (1.0 / std::sqrt(x));
+ };
+ TestCWiseGrad(SQRT, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Rsqrt) {
+ auto x_fn = [this](const int i) { return RV({1, 2, 3, 4, 5, 6, 7, 8}); };
+ auto dy_fn = [this](const float x) { return x + RV({8, 9, 10, 11, 12, 13}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ return dy * -0.5 * (1 / std::sqrt(x)) * (1 / x);
+ };
+ TestCWiseGrad(RSQRT, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Exp) {
+ auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ return dy * std::exp(x);
+ };
+ TestCWiseGrad(EXP, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Log) {
+ auto x_fn = [this](const int i) { return RV({-1, 1, -2, 2, -3, 3, -4, 4}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) { return dy * (1.0 / x); };
+ TestCWiseGrad(LOG, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Tanh) {
+ auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ const float y = std::tanh(x);
+ return dy * (1.0 - y * y);
+ };
+ TestCWiseGrad(TANH, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Sigmoid) {
+ auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ const float y = 1.0 / (1.0 + std::exp(-x));
+ return dy * y * (1.0 - y);
+ };
+ TestCWiseGrad(SIGMOID, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Sign) {
+ auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) { return 0.0; };
+ TestCWiseGrad(SIGN, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Sin) {
+ auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ return dy * std::cos(x);
+ };
+ TestCWiseGrad(SIN, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Cos) {
+ auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ return dy * -1.0 * std::sin(x);
+ };
+ TestCWiseGrad(COS, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Asin) {
+ auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ return dy * (1.0 / std::sqrt(1.0 - x * x));
+ };
+ TestCWiseGrad(ASIN, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Acos) {
+ auto x_fn = [this](const int i) { return RV({0, -0.5, 0.5, -1, 1}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ return dy * (-1.0 / std::sqrt(1.0 - x * x));
+ };
+ TestCWiseGrad(ACOS, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Tan) {
+ auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ const float cosx = std::cos(x);
+ return dy * (1 / (cosx * cosx));
+ };
+ TestCWiseGrad(TAN, x_fn, dy_fn, dx_fn);
+}
+
+TEST_F(CWiseUnaryGradTest, Atan) {
+ auto x_fn = [this](const int i) { return RV({0, -1, 1, -2, 2, -3, 3}); };
+ auto dy_fn = [this](const float x) { return x + RV({-2, 2, -3, 3, -4, 4}); };
+ auto dx_fn = [this](const float x, const float dy) {
+ return dy * (1 / (1 + x * x));
+ };
+ TestCWiseGrad(ATAN, x_fn, dy_fn, dx_fn);
+}
+
+class CWiseUnaryComplexGradTest : public ::testing::Test {
+ protected:
+ CWiseUnaryComplexGradTest()
+ : scope_(Scope::NewRootScope().WithDevice("/cpu:0")) {}
+
+ enum UnaryOpType { REAL, IMAG, CONJ };
+
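+  // Builds the complex unary op 'op_type' on 'x', runs its registered
+  // gradient function with incoming gradient 'dy', and compares the result
+  // against 'dx_expected'.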
+ void TestCWiseGradComplex(UnaryOpType op_type, const Tensor& x,
+ const Tensor& dy, const Tensor& dx_expected) {
+ Output y;
+ switch (op_type) {
+ case REAL:
+ y = Real(scope_, x);
+ break;
+ case IMAG:
+ y = Imag(scope_, x);
+ break;
+ case CONJ:
+ y = Conj(scope_, x);
+ break;
+ }
+
+ std::vector<Output> grad_outputs;
+ TF_ASSERT_OK(test::CallGradFunction(
+ scope_, Operation(y.node()), {ops::Const(scope_, dy)}, &grad_outputs));
+ Tensor dx;
+ test::GetTensor(scope_, grad_outputs[0], &dx);
+ test::ExpectClose(dx, dx_expected);
+ }
+
+ Scope scope_;
+};
+
+TEST_F(CWiseUnaryComplexGradTest, Real) {
+ Tensor x = test::AsTensor<complex64>(
+ {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
+ Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
+ Tensor dx_expected = test::AsTensor<complex64>(
+ {{11, 0}, {-12, 0}, {13, 0}, {-14, 0}, {15, 0}, {-16, 0}}, {2, 3});
+ TestCWiseGradComplex(REAL, x, dy, dx_expected);
+}
+
+TEST_F(CWiseUnaryComplexGradTest, Imag) {
+ Tensor x = test::AsTensor<complex64>(
+ {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
+ Tensor dy = test::AsTensor<float>({11, -12, 13, -14, 15, -16}, {2, 3});
+ Tensor dx_expected = test::AsTensor<complex64>(
+ {{0, 11}, {0, -12}, {0, 13}, {0, -14}, {0, 15}, {0, -16}}, {2, 3});
+ TestCWiseGradComplex(IMAG, x, dy, dx_expected);
+}
+
+TEST_F(CWiseUnaryComplexGradTest, Conj) {
+ Tensor x = test::AsTensor<complex64>(
+ {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
+ Tensor dy = test::AsTensor<complex64>(
+ {{1, -1}, {-2, 2}, {3, -3}, {-4, 4}, {8, -8}, {-9, 9}}, {2, 3});
+ Tensor dx_expected = test::AsTensor<complex64>(
+ {{1, 1}, {-2, -2}, {3, 3}, {-4, -4}, {8, 8}, {-9, -9}}, {2, 3});
+ TestCWiseGradComplex(CONJ, x, dy, dx_expected);
+}
+
class MathGradTest : public ::testing::Test {
protected:
MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {}