path: root/tensorflow/python/ops
author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-09-26 17:42:47 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-09-26 17:46:19 -0700
commit     7b88cabfec45c9e04ab3d9cf1c2411c6dce4c694 (patch)
tree       9bdc598fa33808d8689299438a50ad7445ebdec5 /tensorflow/python/ops
parent     bfda65cc70526c919c57ef8321dd282e463ed8a3 (diff)
Add xlogy and xdivy ops.
PiperOrigin-RevId: 214700693
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--  tensorflow/python/ops/math_grad.py        34
-rw-r--r--  tensorflow/python/ops/math_grad_test.py   88
-rw-r--r--  tensorflow/python/ops/math_ops_test.py    71
3 files changed, 193 insertions(+), 0 deletions(-)
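
For reference, xlogy(x, y) computes x * log(y) and xdivy(x, y) computes x / y, with the convention that both return exactly 0 wherever x == 0, even where log(y) or x / y alone would be -inf, inf, or NaN (this is what the "WithZero" tests below exercise). A minimal NumPy sketch of that intended semantics; the helpers are reference-only, not the TensorFlow kernels added by this change:

import numpy as np

def xlogy_reference(x, y):
  # x * log(y), forced to exactly 0 wherever x == 0, even if log(y) is -inf/NaN.
  with np.errstate(divide="ignore", invalid="ignore"):
    return np.where(x == 0., 0., x * np.log(y))

def xdivy_reference(x, y):
  # x / y, forced to exactly 0 wherever x == 0, even if y == 0.
  with np.errstate(divide="ignore", invalid="ignore"):
    return np.where(x == 0., 0., x / y)
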
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 8e11c4bce1..35278d9680 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -516,6 +516,40 @@ def _Log1pGrad(op, grad):
return grad * math_ops.reciprocal(1 + x)
+@ops.RegisterGradient("Xlogy")
+def _XLogyGrad(op, grad):
+ """Returns gradient of xlogy(x, y) with respect to x and y."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+ with ops.control_dependencies([grad]):
+ not_zero_x = math_ops.cast(
+ math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
+ partial_x = gen_math_ops.xlogy(not_zero_x, y)
+ partial_y = gen_math_ops.xdivy(x, y)
+ return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
+
+
+@ops.RegisterGradient("Xdivy")
+def _XDivyGrad(op, grad):
+ """Returns gradient of xdivy(x, y) with respect to x and y."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+ with ops.control_dependencies([grad]):
+ not_zero_x = math_ops.cast(
+ math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
+ partial_x = gen_math_ops.xdivy(not_zero_x, y)
+ partial_y = gen_math_ops.xdivy(math_ops.negative(x), y**2)
+ return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
+
+
@ops.RegisterGradient("Sinh")
def _SinhGrad(op, grad):
"""Returns grad * cosh(x)."""
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py
index 7110e0958c..9cfb050942 100644
--- a/tensorflow/python/ops/math_grad_test.py
+++ b/tensorflow/python/ops/math_grad_test.py
@@ -256,5 +256,93 @@ class DivNoNanGradientTest(test.TestCase):
self.assertAllClose(dy.eval(), np.zeros(y.shape.as_list()))
+class XlogyTest(test.TestCase):
+
+ def _xlogy_gradients(self, x, y):
+ xlogy_xgrad = self.evaluate(gradients.gradients(math_ops.xlogy(x, y), x)[0])
+ xlogy_ygrad = self.evaluate(gradients.gradients(math_ops.xlogy(x, y), y)[0])
+ return xlogy_xgrad, xlogy_ygrad
+
+ def testNonZeroValuesGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ xlogy_expected_xgrad = self.evaluate(math_ops.log(y))
+ xlogy_expected_ygrad = self.evaluate(x / y)
+ self.assertAllClose(xlogy_expected_xgrad, xlogy_xgrad)
+ self.assertAllClose(xlogy_expected_ygrad, xlogy_ygrad)
+
+ def testZeroXGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xlogy_xgrad)
+ self.assertAllClose(zero, xlogy_ygrad)
+
+ def testZeroYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ self.assertAllClose(-np.inf, xlogy_xgrad)
+ self.assertAllClose(np.inf, xlogy_ygrad)
+
+ def testZeroXYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xlogy_xgrad)
+ self.assertAllClose(zero, xlogy_ygrad)
+
+
+class XdivyTest(test.TestCase):
+
+ def _xdivy_gradients(self, x, y):
+ xdivy_xgrad = self.evaluate(gradients.gradients(math_ops.xdivy(x, y), x)[0])
+ xdivy_ygrad = self.evaluate(gradients.gradients(math_ops.xdivy(x, y), y)[0])
+ return xdivy_xgrad, xdivy_ygrad
+
+ def testNonZeroValuesGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ xdivy_expected_xgrad = self.evaluate(1 / y)
+ xdivy_expected_ygrad = self.evaluate(-x / y**2)
+ self.assertAllClose(xdivy_expected_xgrad, xdivy_xgrad)
+ self.assertAllClose(xdivy_expected_ygrad, xdivy_ygrad)
+
+ def testZeroXGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xdivy_xgrad)
+ self.assertAllClose(zero, xdivy_ygrad)
+
+ def testZeroYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ self.assertAllClose(np.inf, xdivy_xgrad)
+ self.assertAllClose(-np.inf, xdivy_ygrad)
+
+ def testZeroXYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xdivy_xgrad)
+ self.assertAllClose(zero, xdivy_ygrad)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 1b01d1d37f..f051850d92 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -21,6 +21,7 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -488,5 +489,75 @@ class DivNoNanTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tf_result, np_result)
+class XlogyTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXlogyNoZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.1, 0.2, 3.5], [-2., -5., 30.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [3.1, 4., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xlogy = self.evaluate(math_ops.xlogy(x, y))
+ xtimeslogy = self.evaluate(x * math_ops.log(y))
+ self.assertAllClose(xlogy, xtimeslogy)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXlogyWithZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(np.zeros((2, 3)), dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xlogy_tf_np = self.evaluate(math_ops.xlogy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y))
+ self.assertAllClose(xlogy_tf_np, zeros_np)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXlogyWithZeroBroadcast(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.], [1.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xlogy_tf_np = self.evaluate(math_ops.xlogy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y[0]))
+ xtimes_logy = self.evaluate(math_ops.log(y[1]))
+ self.assertAllClose(zeros_np, xlogy_tf_np[0])
+ self.assertAllClose(xtimes_logy, xlogy_tf_np[1])
+
+
+class XdivyTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXdivyNoZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.1, 0.2, 3.5], [-2., -5., 30.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [3.1, 4., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xdivy = self.evaluate(math_ops.xdivy(x, y))
+ x_over_y = self.evaluate(x / y)
+ self.assertAllClose(xdivy, x_over_y)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXdivyWithZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(np.zeros((2, 3)), dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xdivy_tf_np = self.evaluate(math_ops.xdivy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y))
+ self.assertAllClose(xdivy_tf_np, zeros_np)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXdivyWithZeroBroadcast(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.], [1.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xdivy_tf_np = self.evaluate(math_ops.xdivy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y[0]))
+ x_over_y = self.evaluate(1 / y[1])
+ self.assertAllClose(zeros_np, xdivy_tf_np[0])
+ self.assertAllClose(x_over_y, xdivy_tf_np[1])
+
+
if __name__ == "__main__":
googletest.main()
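
As a usage note, the zero-at-x==0 convention is what makes these ops convenient for information-theoretic losses, where a term like p * log(q) must contribute exactly 0 when p == 0 instead of propagating NaN through the sum and its gradient. A hedged sketch, assuming the op is exposed publicly as tf.math.xlogy; the entropy helper itself is illustrative, not part of TensorFlow:

import tensorflow as tf

def entropy(p):
  # H(p) = -sum_i p_i * log(p_i); xlogy makes the p_i == 0 terms exactly 0.
  return -tf.reduce_sum(tf.math.xlogy(p, p))

probs = tf.constant([0.5, 0.5, 0.0])
# entropy(probs) evaluates to log(2) rather than NaN.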