author     A. Unique TensorFlower <gardener@tensorflow.org>  2016-08-25 07:59:14 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2016-08-25 09:04:38 -0700
commit     e78cbe072bb8fd50dd8be6033de9bcb5f62d59fd (patch)
tree       2183e953675655acbafe9ff0a6571cb3ee67c7d7
parent     a856685175f0919dd2ab03ac447d2708dc0fffe3 (diff)
Fix gradient of pow for complex types
Change: 131294380
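
For reference, the gradients being fixed are the standard ones for z = pow(x, y):

    \frac{\partial z}{\partial x} = y \, x^{y-1}, \qquad
    \frac{\partial z}{\partial y} = x^{y} \log x

The log x factor has no sensible real value for x <= 0, so the real-typed gradient substitutes 0 there. For complex types only x = 0 is actually problematic: the log of a negative base is well defined on the principal branch (log(-2) = log 2 + i*pi), so the guard below is relaxed from x > 0 to x != 0.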
-rw-r--r--  tensorflow/core/framework/tensor_testutil.cc       6
-rw-r--r--  tensorflow/core/ops/math_grad.cc                  54
-rw-r--r--  tensorflow/core/ops/math_grad_test.cc             19
-rw-r--r--  tensorflow/python/kernel_tests/cwise_ops_test.py  23
-rw-r--r--  tensorflow/python/ops/math_grad.py                 8
5 files changed, 89 insertions(+), 21 deletions(-)
diff --git a/tensorflow/core/framework/tensor_testutil.cc b/tensorflow/core/framework/tensor_testutil.cc
index e307d25268..1bae7b08d5 100644
--- a/tensorflow/core/framework/tensor_testutil.cc
+++ b/tensorflow/core/framework/tensor_testutil.cc
@@ -50,6 +50,12 @@ void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
     case DT_DOUBLE:
       ExpectClose<double>(x, y, atol, rtol);
       break;
+    case DT_COMPLEX64:
+      ExpectClose<complex64>(x, y, atol, rtol);
+      break;
+    case DT_COMPLEX128:
+      ExpectClose<complex128>(x, y, atol, rtol);
+      break;
     default:
       LOG(FATAL) << "Unexpected type : " << DataTypeString(x.dtype());
   }
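
The two new cases simply dispatch to the existing templated helper for complex64/complex128, so the new C++ test further down can compare complex tensors. As a rough sketch of the semantics (the exact tolerance formula is whatever the templated ExpectClose<T> uses; the |x - y| <= atol + rtol * |y| form here is an illustrative assumption):

    import numpy as np

    def expect_close(x, y, atol, rtol):
      # Illustrative assumption: elementwise closeness under the complex
      # modulus, |x - y| <= atol + rtol * |y|; the real ExpectClose<T>
      # template may use a different formula.
      assert np.all(np.abs(x - y) <= atol + rtol * np.abs(y))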
diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc
index 1d8f45ea7a..f74ce32cef 100644
--- a/tensorflow/core/ops/math_grad.cc
+++ b/tensorflow/core/ops/math_grad.cc
@@ -375,26 +375,42 @@ REGISTER_OP_GRADIENT("Div", DivGrad);
 Status PowGrad(const AttrSlice& attrs, FunctionDef* g) {
   // clang-format off
-  return GradForBinaryCwise(g, {
-      {{"z"}, "Pow", {"x", "y"}},
-      // dz * y * Pow(x, y - 1)
-      FDH::Const("const_zero", 0.0f),
-      FDH::Const("const_one", 1.0f),
-      {{"zero"}, "Cast", {"const_zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
-      {{"one"}, "Cast", {"const_one"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
-      {{"t0"}, "Sub", {"y", "one"}, {}, {"dz"}},
-      {{"t1"}, "Pow", {"x", "t0"}},
-      {{"t2"}, "Mul", {"dz", "y"}},
-      {{"gx"}, "Mul", {"t1", "t2"}},
-      // dz * z * (x > 0 ? Log(x) : 0)
+  std::vector<FDH::Node> nodes = {
+      {{"z"}, "Pow", {"x", "y"}},
+      // dz * y * Pow(x, y - 1)
+      FDH::Const("const_zero", 0.0f),
+      FDH::Const("const_one", 1.0f),
+      {{"zero"}, "Cast", {"const_zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
+      {{"one"}, "Cast", {"const_one"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}},
+      {{"t0"}, "Sub", {"y", "one"}, {}, {"dz"}},
+      {{"t1"}, "Pow", {"x", "t0"}},
+      {{"t2"}, "Mul", {"dz", "y"}},
+      {{"gx"}, "Mul", {"t1", "t2"}},
+      {{"unsafe_log"}, "Log", {"x"}, {}, {"dz"}},
+      {{"zeros"}, "ZerosLike", {"x"}}};
+  // clang-format on
+  std::vector<FDH::Node> log_x_handling;
+  DataType T;
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T));
+  if (T == DT_COMPLEX64 || T == DT_COMPLEX128) {
+    // dz * z * (x != 0 ? Log(x) : 0)
+    // clang-format off
+    log_x_handling = {
+      {{"nz_x"}, "NotEqual", {"x", "zero"}},
+      {{"safe_log"}, "Select", {"nz_x", "unsafe_log", "zeros"}}};
+    // clang-format on
+  } else {
+    // dz * z * (x > 0 ? Log(x) : 0)
+    // clang-format off
+    log_x_handling = {
       {{"pos_x"}, "Greater", {"x", "zero"}},
-      {{"unsafe_log"}, "Log", {"x"}, {}, {"dz"}},
-      {{"zeros"}, "ZerosLike", {"x"}},
-      {{"safe_log"}, "Select", {"pos_x", "unsafe_log", "zeros"}},
-      {{"t4"}, "Mul", {"dz", "z"}},
-      {{"gy"}, "Mul", {"safe_log", "t4"}},
-  });
-  // clang-format on
+      {{"safe_log"}, "Select", {"pos_x", "unsafe_log", "zeros"}}};
+    // clang-format on
+  }
+  nodes.insert(nodes.end(), log_x_handling.begin(), log_x_handling.end());
+  nodes.push_back({{"t4"}, "Mul", {"dz", "z"}});
+  nodes.push_back({{"gy"}, "Mul", {"safe_log", "t4"}});
+  return GradForBinaryCwise(g, nodes);
 }
 REGISTER_OP_GRADIENT("Pow", PowGrad);
diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc
index 9af73b2da0..e937fc5ab1 100644
--- a/tensorflow/core/ops/math_grad_test.cc
+++ b/tensorflow/core/ops/math_grad_test.cc
@@ -684,6 +684,25 @@ TEST_F(MathGradTest, Pow) {
   }
 }
 
+TEST_F(MathGradTest, ComplexPow) {
+  auto x = test::AsTensor<complex64>({0.f, 2.f, -2.f}, TensorShape({3}));
+  auto y = test::AsTensor<complex64>({2.f, 2.f, 2.f}, TensorShape({3}));
+  Tensor dx;
+  Tensor dy;
+  auto g = [](complex64 x, complex64 y) { return y * std::pow(x, y - 1.f); };
+  auto h = [](complex64 x, complex64 y) {
+    return std::pow(x, y) * (x != complex64(0) ? std::log(x) : 0);
+  };
+  SymGrad("Pow", x, y, &dx, &dy);
+
+  test::ExpectClose(
+      dx, test::AsTensor<complex64>({g(0.f, 2.f), g(2.f, 2.f), g(-2.f, 2.f)},
+                                    TensorShape({3})));
+  test::ExpectClose(
+      dy, test::AsTensor<complex64>({h(0.f, 2.f), h(2.f, 2.f), h(-2.f, 2.f)},
+                                    TensorShape({3})));
+}
+
 TEST_F(MathGradTest, Maximum) {
   auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f},
                                  TensorShape({2, 3}));
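
For the record, the values the ComplexPow test expects at its three points (y = 2 throughout): g(0, 2) = 2 * 0^1 = 0, g(2, 2) = 2 * 2^1 = 4, and g(-2, 2) = 2 * (-2)^1 = -4 for dx; h(0, 2) = 0 (the x != 0 guard), h(2, 2) = 4 * log 2, and h(-2, 2) = 4 * (log 2 + i*pi) for dy, using the principal-branch value log(-2) = log 2 + i*pi. That last value is exactly what the Python test below checks as expected_y_grad.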
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index e5406ff87a..0890ea4a2a 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -992,13 +992,34 @@ class BinaryOpTest(tf.test.TestCase):
   def testZeroPowGrad(self):
     with self.test_session():
-      for dtype in np.float16, np.float32, np.float64:
+      for dtype in (np.float16, np.float32, np.float64, np.complex64,
+                    np.complex128):
         x = tf.constant(0.0, dtype=dtype)
         y = tf.constant(2.0, dtype=dtype)
         z = tf.pow(x, y)
         error = tf.test.compute_gradient_error(y, [], z, [])
         self.assertEqual(error, 0)
 
+  def testComplexPowGradPositiveBase(self):
+    with self.test_session():
+      for dtype in np.complex64, np.complex128:
+        x = tf.constant(2.0, dtype=dtype)
+        y = tf.constant(2.0, dtype=dtype)
+        z = tf.pow(x, y)
+        error = tf.test.compute_gradient_error(y, [], z, [])
+        self.assertLess(error, 1e-4)
+
+  def testComplexPowGradNegativeBase(self):
+    with self.test_session() as session:
+      for dtype in np.complex64, np.complex128:
+        x = tf.constant(-2.0, dtype=dtype)
+        y = tf.constant(2.0, dtype=dtype)
+        z = tf.pow(x, y)
+        expected_x_grad = -4
+        expected_y_grad = (-2)**2 * (np.log(2) + np.pi * 1j)
+        self.assertAllClose([expected_x_grad, expected_y_grad],
+                            session.run(tf.gradients(z, [x, y])))
+
 
 class ComparisonOpTest(tf.test.TestCase):
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 61d998ec4f..315fd4ffca 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -552,7 +552,13 @@ def _PowGrad(op, grad):
   gx = array_ops.reshape(
       math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx)
   # Avoid false singularity at x = 0
-  log_x = math_ops.select(x > 0, math_ops.log(x), array_ops.zeros_like(x))
+  if x.dtype.is_complex:
+    # real(x) < 0 is fine for the complex case
+    log_x = math_ops.select(
+        math_ops.not_equal(x, 0), math_ops.log(x), array_ops.zeros_like(x))
+  else:
+    # There's no sensible real value to return if x < 0, so return 0
+    log_x = math_ops.select(x > 0, math_ops.log(x), array_ops.zeros_like(x))
   gy = array_ops.reshape(
       math_ops.reduce_sum(grad * z * log_x, ry), sy)
   return gx, gy
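
End to end, the Python-side change means complex pow gradients now come back finite and correct instead of hitting the real-only x > 0 guard. A usage sketch against the contemporary (2016-era) TensorFlow API, mirroring testComplexPowGradNegativeBase above:

    import tensorflow as tf

    with tf.Session() as sess:
      x = tf.constant(-2.0, dtype=tf.complex64)
      y = tf.constant(2.0, dtype=tf.complex64)
      z = tf.pow(x, y)
      gx, gy = sess.run(tf.gradients(z, [x, y]))
      print(gx)  # -4, i.e. y * x**(y - 1)
      print(gy)  # 4 * (log(2) + pi*j), i.e. z * log(x)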