diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-08-25 07:59:14 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-25 09:04:38 -0700 |
commit | e78cbe072bb8fd50dd8be6033de9bcb5f62d59fd (patch) | |
tree | 2183e953675655acbafe9ff0a6571cb3ee67c7d7 | |
parent | a856685175f0919dd2ab03ac447d2708dc0fffe3 (diff) |
Fix gradient of pow for complex types
Change: 131294380
-rw-r--r-- | tensorflow/core/framework/tensor_testutil.cc | 6 | ||||
-rw-r--r-- | tensorflow/core/ops/math_grad.cc | 54 | ||||
-rw-r--r-- | tensorflow/core/ops/math_grad_test.cc | 19 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/cwise_ops_test.py | 23 | ||||
-rw-r--r-- | tensorflow/python/ops/math_grad.py | 8 |
5 files changed, 89 insertions, 21 deletions
diff --git a/tensorflow/core/framework/tensor_testutil.cc b/tensorflow/core/framework/tensor_testutil.cc index e307d25268..1bae7b08d5 100644 --- a/tensorflow/core/framework/tensor_testutil.cc +++ b/tensorflow/core/framework/tensor_testutil.cc @@ -50,6 +50,12 @@ void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) { case DT_DOUBLE: ExpectClose<double>(x, y, atol, rtol); break; + case DT_COMPLEX64: + ExpectClose<complex64>(x, y, atol, rtol); + break; + case DT_COMPLEX128: + ExpectClose<complex128>(x, y, atol, rtol); + break; default: LOG(FATAL) << "Unexpected type : " << DataTypeString(x.dtype()); } diff --git a/tensorflow/core/ops/math_grad.cc b/tensorflow/core/ops/math_grad.cc index 1d8f45ea7a..f74ce32cef 100644 --- a/tensorflow/core/ops/math_grad.cc +++ b/tensorflow/core/ops/math_grad.cc @@ -375,26 +375,42 @@ REGISTER_OP_GRADIENT("Div", DivGrad); Status PowGrad(const AttrSlice& attrs, FunctionDef* g) { // clang-format off - return GradForBinaryCwise(g, { - {{"z"}, "Pow", {"x", "y"}}, - // dz * y * Pow(x, y - 1) - FDH::Const("const_zero", 0.0f), - FDH::Const("const_one", 1.0f), - {{"zero"}, "Cast", {"const_zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}}, - {{"one"}, "Cast", {"const_one"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}}, - {{"t0"}, "Sub", {"y", "one"}, {}, {"dz"}}, - {{"t1"}, "Pow", {"x", "t0"}}, - {{"t2"}, "Mul", {"dz", "y"}}, - {{"gx"}, "Mul", {"t1", "t2"}}, - // dz * z * (x > 0 ? Log(x) : 0) + std::vector<FDH::Node> nodes = { + {{"z"}, "Pow", {"x", "y"}}, + // dz * y * Pow(x, y - 1) + FDH::Const("const_zero", 0.0f), + FDH::Const("const_one", 1.0f), + {{"zero"}, "Cast", {"const_zero"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}}, + {{"one"}, "Cast", {"const_one"}, {{"SrcT", DT_FLOAT}, {"DstT", "$T"}}}, + {{"t0"}, "Sub", {"y", "one"}, {}, {"dz"}}, + {{"t1"}, "Pow", {"x", "t0"}}, + {{"t2"}, "Mul", {"dz", "y"}}, + {{"gx"}, "Mul", {"t1", "t2"}}, + {{"unsafe_log"}, "Log", {"x"}, {}, {"dz"}}, + {{"zeros"}, "ZerosLike", {"x"}}}; + // clang-format on + std::vector<FDH::Node> log_x_handling; + DataType T; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "T", &T)); + if (T == DT_COMPLEX64 || T == DT_COMPLEX128) { + // dz * z * (x != 0 ? Log(x) : 0) + // clang-format off + log_x_handling = { + {{"nz_x"}, "NotEqual", {"x", "zero"}}, + {{"safe_log"}, "Select", {"nz_x", "unsafe_log", "zeros"}}}; + // clang-format on + } else { + // dz * z * (x > 0 ? Log(x) : 0) + // clang-format off + log_x_handling = { {{"pos_x"}, "Greater", {"x", "zero"}}, - {{"unsafe_log"}, "Log", {"x"}, {}, {"dz"}}, - {{"zeros"}, "ZerosLike", {"x"}}, - {{"safe_log"}, "Select", {"pos_x", "unsafe_log", "zeros"}}, - {{"t4"}, "Mul", {"dz", "z"}}, - {{"gy"}, "Mul", {"safe_log", "t4"}}, - }); - // clang-format on + {{"safe_log"}, "Select", {"pos_x", "unsafe_log", "zeros"}}}; + // clang-format on + } + nodes.insert(nodes.end(), log_x_handling.begin(), log_x_handling.end()); + nodes.push_back({{"t4"}, "Mul", {"dz", "z"}}); + nodes.push_back({{"gy"}, "Mul", {"safe_log", "t4"}}); + return GradForBinaryCwise(g, nodes); } REGISTER_OP_GRADIENT("Pow", PowGrad); diff --git a/tensorflow/core/ops/math_grad_test.cc b/tensorflow/core/ops/math_grad_test.cc index 9af73b2da0..e937fc5ab1 100644 --- a/tensorflow/core/ops/math_grad_test.cc +++ b/tensorflow/core/ops/math_grad_test.cc @@ -684,6 +684,25 @@ TEST_F(MathGradTest, Pow) { } } +TEST_F(MathGradTest, ComplexPow) { + auto x = test::AsTensor<complex64>({0.f, 2.f, -2.f}, TensorShape({3})); + auto y = test::AsTensor<complex64>({2.f, 2.f, 2.f}, TensorShape({3})); + Tensor dx; + Tensor dy; + auto g = [](complex64 x, complex64 y) { return y * std::pow(x, y - 1.f); }; + auto h = [](complex64 x, complex64 y) { + return std::pow(x, y) * (x != complex64(0) ? std::log(x) : 0); + }; + SymGrad("Pow", x, y, &dx, &dy); + + test::ExpectClose( + dx, test::AsTensor<complex64>({g(0.f, 2.f), g(2.f, 2.f), g(-2.f, 2.f)}, + TensorShape({3}))); + test::ExpectClose( + dy, test::AsTensor<complex64>({h(0.f, 2.f), h(2.f, 2.f), h(-2.f, 2.f)}, + TensorShape({3}))); +} + TEST_F(MathGradTest, Maximum) { auto x = test::AsTensor<float>({-3.f, -2.f, -1.f, 1.f, 2.f, 3.f}, TensorShape({2, 3})); diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index e5406ff87a..0890ea4a2a 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -992,13 +992,34 @@ class BinaryOpTest(tf.test.TestCase): def testZeroPowGrad(self): with self.test_session(): - for dtype in np.float16, np.float32, np.float64: + for dtype in (np.float16, np.float32, np.float64, np.complex64, + np.complex128): x = tf.constant(0.0, dtype=dtype) y = tf.constant(2.0, dtype=dtype) z = tf.pow(x, y) error = tf.test.compute_gradient_error(y, [], z, []) self.assertEqual(error, 0) + def testComplexPowGradPositiveBase(self): + with self.test_session(): + for dtype in np.complex64, np.complex128: + x = tf.constant(2.0, dtype=dtype) + y = tf.constant(2.0, dtype=dtype) + z = tf.pow(x, y) + error = tf.test.compute_gradient_error(y, [], z, []) + self.assertLess(error, 1e-4) + + def testComplexPowGradNegativeBase(self): + with self.test_session() as session: + for dtype in np.complex64, np.complex128: + x = tf.constant(-2.0, dtype=dtype) + y = tf.constant(2.0, dtype=dtype) + z = tf.pow(x, y) + expected_x_grad = -4 + expected_y_grad = (-2)**2 * (np.log(2) + np.pi * 1j) + self.assertAllClose([expected_x_grad, expected_y_grad], + session.run(tf.gradients(z, [x, y]))) + class ComparisonOpTest(tf.test.TestCase): diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py index 61d998ec4f..315fd4ffca 100644 --- a/tensorflow/python/ops/math_grad.py +++ b/tensorflow/python/ops/math_grad.py @@ -552,7 +552,13 @@ def _PowGrad(op, grad): gx = array_ops.reshape( math_ops.reduce_sum(grad * y * math_ops.pow(x, y - 1), rx), sx) # Avoid false singularity at x = 0 - log_x = math_ops.select(x > 0, math_ops.log(x), array_ops.zeros_like(x)) + if x.dtype.is_complex: + # real(x) < 0 is fine for the complex case + log_x = math_ops.select( + math_ops.not_equal(x, 0), math_ops.log(x), array_ops.zeros_like(x)) + else: + # There's no sensible real value to return if x < 0, so return 0 + log_x = math_ops.select(x > 0, math_ops.log(x), array_ops.zeros_like(x)) gy = array_ops.reshape( math_ops.reduce_sum(grad * z * log_x, ry), sy) return gx, gy |