author    Ian Langmore <langmore@google.com>    2017-06-23 14:37:45 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2017-06-23 14:41:27 -0700
commit    dee19ca4dd0510499b7da9ebb97c92910638b4f2
tree      52658121169633d687df4d3a64e3db333e04a829
parent    0ecae4fb42b4333b98f67461bb15e825816f32ad
Sinh, ArcSinh, Cosh, LogCosh functions added to distributions/python/ops/trig.
Care is taken to ensure a fair bit of stability.

PiperOrigin-RevId: 159995514
-rw-r--r--  tensorflow/contrib/distributions/BUILD                              |  17
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/trig_test.py   | 256
-rw-r--r--  tensorflow/contrib/distributions/python/ops/trig.py                 | 173
-rw-r--r--  tensorflow/python/kernel_tests/distributions/util_test.py           |  12
4 files changed, 458 insertions, 0 deletions
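A minimal usage sketch of the four new functions, assuming the TF 1.x session API of the time and the contrib import path added in this change:

    import numpy as np
    import tensorflow as tf
    from tensorflow.contrib.distributions.python.ops import trig

    # Evaluate the four helpers on a small grid of values.
    x = tf.constant(np.linspace(-5., 5., num=11), dtype=tf.float32)
    with tf.Session() as sess:
      sinh_, arcsinh_, cosh_, log_cosh_ = sess.run(
          [trig.sinh(x), trig.arcsinh(x), trig.cosh(x), trig.log_cosh(x)])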
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index ca2b256f6a..218308f20a 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -625,6 +625,23 @@ cuda_py_test(
],
)
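+# Tests for the trig helpers added in python/ops/trig.py; runnable with, e.g.,
+#   bazel test //tensorflow/contrib/distributions:trig_test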
+cuda_py_test(
+ name = "trig_test",
+ size = "small",
+ srcs = ["python/kernel_tests/trig_test.py"],
+ additional_deps = [
+ ":distributions_py",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
filegroup(
name = "all_files",
srcs = glob(
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/trig_test.py b/tensorflow/contrib/distributions/python/kernel_tests/trig_test.py
new file mode 100644
index 0000000000..0d40f477ca
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/trig_test.py
@@ -0,0 +1,256 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Trigonometric functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops import trig
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.platform import test
+
+
+class SinhTest(test.TestCase):
+
+ def test_versus_numpy_scalar_positive_x(self):
+ x = 1.23
+ with self.test_session():
+ self.assertAllClose(np.sinh(x), trig.sinh(x).eval())
+
+ def test_versus_numpy_multidim_x(self):
+ x = [[0., -1.], [0.5, -2.]]
+ with self.test_session():
+ self.assertAllClose(np.sinh(x), trig.sinh(x).eval())
+
+
+class ArcSinhTest(test.TestCase):
+ """Test arcsinh.
+
+ Note that many tests below were run over a range of values using np.logspace.
+ The accuracy was highly dependent on the exact values used within that range,
+ and on whether our approximation matched numpy's. For that reason, we use at
+ least 1000 values over each range, which ensures we hit most of the "tricky"
+ values.
+ """
+
+ def _assert_all_finite(self, values):
+ self.assertAllEqual(np.ones_like(values).astype(np.bool),
+ np.isfinite(values))
+
+ def test_versus_numpy_scalar_positive_x(self):
+ x = 1.23
+ with self.test_session():
+ self.assertAllClose(np.arcsinh(x), trig.arcsinh(x).eval())
+
+ def test_versus_numpy_at_zero(self):
+ # Zero is especially difficult.
+ with self.test_session() as sess:
+ x = constant_op.constant(0.)
+ y = trig.arcsinh(x)
+ grad = gradients_impl.gradients(y, x)[0]
+ x_, y_, grad_ = sess.run([x, y, grad])
+ self.assertAllClose(np.arcsinh(x_), y_)
+ self._assert_all_finite(y_)
+ self._assert_all_finite(grad_)
+
+ def test_versus_numpy_multidim_x(self):
+ with self.test_session() as sess:
+ x = constant_op.constant([[0., -1.], [0.5, -2.]])
+ y = trig.arcsinh(x)
+ grad = gradients_impl.gradients(y, x)[0]
+ x_, y_, grad_ = sess.run([x, y, grad])
+ self.assertAllClose(np.arcsinh(x_), y_)
+ self._assert_all_finite(y_)
+ self._assert_all_finite(grad_)
+
+ def test_versus_numpy_positive_values_32_bit(self):
+ with self.test_session() as sess:
+ # Exponents larger than about 38 overflow to Inf in float32.
+ x = constant_op.constant(np.logspace(0, 38, num=2000).astype(np.float32))
+ y = trig.arcsinh(x)
+ grad = gradients_impl.gradients(y, x)[0]
+ x_, y_, grad_ = sess.run([x, y, grad])
+ self.assertAllClose(
+ np.arcsinh(x_), # numpy does this in 64bit
+ y_,
+ rtol=1e-6)
+ self._assert_all_finite(y_)
+ self._assert_all_finite(grad_)
+
+ def test_versus_numpy_moderate_negative_values_32_bit(self):
+ # The moderate negative values were the most difficult to get close to
+ # numpy.
+ with self.test_session() as sess:
+ x = constant_op.constant(-np.logspace(0, 10, num=1000).astype(np.float32))
+ y = trig.arcsinh(x)
+ grad = gradients_impl.gradients(y, x)[0]
+ x_, y_, grad_ = sess.run([x, y, grad])
+ self.assertAllClose(
+ np.arcsinh(x_), # numpy does this in 64bit
+ y_,
+ rtol=1e-4)
+ self._assert_all_finite(y_)
+ self._assert_all_finite(grad_)
+
+ def test_versus_numpy_extreme_negative_values_32_bit(self):
+ # For these extreme values arcsinh uses the approximation z = 1 / (2 * |x|),
+ # and 1 / 10**38 underflows in float32, so we stop at 10**37.
+ with self.test_session() as sess:
+ x = constant_op.constant(
+ -np.logspace(10, 37, num=1000).astype(np.float32))
+ y = trig.arcsinh(x)
+ grad = gradients_impl.gradients(y, x)[0]
+ x_, y_, grad_ = sess.run([x, y, grad])
+ self.assertAllClose(
+ np.arcsinh(x_), # numpy does this in 64bit
+ y_,
+ rtol=1e-6)
+ self._assert_all_finite(y_)
+ self._assert_all_finite(grad_)
+
+ def test_versus_numpy_positive_values_64_bit(self):
+ with self.test_session() as sess:
+ x = constant_op.constant(np.logspace(0, 200, num=1000))
+ y = trig.arcsinh(x)
+ grad = gradients_impl.gradients(y, x)[0]
+ x_, y_, grad_ = sess.run([x, y, grad])
+ self.assertAllClose(
+ np.arcsinh(x_),
+ y_,
+ rtol=1e-6)
+ self._assert_all_finite(y_)
+ self._assert_all_finite(grad_)
+
+ def test_versus_numpy_negative_values_64_bit(self):
+ with self.test_session() as sess:
+ x = constant_op.constant(-np.logspace(0, 200, num=1000))
+ y = trig.arcsinh(x)
+ grad = gradients_impl.gradients(y, x)[0]
+ x_, y_, grad_ = sess.run([x, y, grad])
+ self.assertAllClose(
+ np.arcsinh(x_),
+ y_,
+ rtol=1e-5)
+ self._assert_all_finite(y_)
+ self._assert_all_finite(grad_)
+
+ def test_arcsinh_is_inverse_to_sinh_near_zero(self):
+ sinh = trig.sinh
+ arcsinh = trig.arcsinh
+ with self.test_session() as sess:
+ x = np.linspace(-1.1, 1.1, num=1000).astype(np.float32)
+ arcsinh_x = arcsinh(x)
+ sinh_arcsinh_x = sinh(arcsinh_x)
+ arcsinh_sinh_arcsinh_x = arcsinh(sinh_arcsinh_x)
+
+ arcsinh_x_, sinh_arcsinh_x_, arcsinh_sinh_arcsinh_x_ = sess.run(
+ [arcsinh_x, sinh_arcsinh_x, arcsinh_sinh_arcsinh_x])
+
+ self.assertAllClose(x, sinh_arcsinh_x_)
+ self.assertAllClose(arcsinh_x_, arcsinh_sinh_arcsinh_x_)
+
+ def test_arcsinh_is_inverse_to_sinh_where_x_is_very_small(self):
+ sinh = trig.sinh
+ arcsinh = trig.arcsinh
+ # Exactly the same cutoff as in the code.
+ very_small_cutoff = -0.01 / np.sqrt(np.finfo(np.float32).eps)
+ with self.test_session() as sess:
+ x = -np.logspace(
+ np.log(-very_small_cutoff),
+ np.log(-1000 * very_small_cutoff),
+ num=1000).astype(np.float32)
+ arcsinh_x = arcsinh(x)
+ sinh_arcsinh_x = sinh(arcsinh_x)
+ arcsinh_sinh_arcsinh_x = arcsinh(sinh_arcsinh_x)
+
+ arcsinh_x_, sinh_arcsinh_x_, arcsinh_sinh_arcsinh_x_ = sess.run(
+ [arcsinh_x, sinh_arcsinh_x, arcsinh_sinh_arcsinh_x])
+
+ self.assertAllClose(x, sinh_arcsinh_x_, rtol=1e-5)
+ self.assertAllClose(arcsinh_x_, arcsinh_sinh_arcsinh_x_)
+
+ def test_arcsinh_is_inverse_to_sinh_where_x_is_moderate_or_big(self):
+ sinh = trig.sinh
+ arcsinh = trig.arcsinh
+ very_big_cutoff = np.sqrt(np.finfo(np.float32).max)
+ with self.test_session() as sess:
+ x = np.linspace(1., very_big_cutoff, num=1000).astype(np.float32)
+ arcsinh_x = arcsinh(x)
+ sinh_arcsinh_x = sinh(arcsinh_x)
+ arcsinh_sinh_arcsinh_x = arcsinh(sinh_arcsinh_x)
+
+ arcsinh_x_, sinh_arcsinh_x_, arcsinh_sinh_arcsinh_x_ = sess.run(
+ [arcsinh_x, sinh_arcsinh_x, arcsinh_sinh_arcsinh_x])
+
+ self.assertAllClose(x, sinh_arcsinh_x_, rtol=1e-5)
+ self.assertAllClose(arcsinh_x_, arcsinh_sinh_arcsinh_x_)
+
+ def test_arcsinh_is_inverse_to_sinh_where_x_is_very_big(self):
+ sinh = trig.sinh
+ arcsinh = trig.arcsinh
+ very_big_cutoff = np.sqrt(np.finfo(np.float32).max)
+ with self.test_session() as sess:
+ x = np.logspace(
+ np.log(very_big_cutoff),
+ 5 + np.log(very_big_cutoff), num=1000).astype(np.float32)
+ arcsinh_x = arcsinh(x)
+ sinh_arcsinh_x = sinh(arcsinh_x)
+ arcsinh_sinh_arcsinh_x = arcsinh(sinh_arcsinh_x)
+
+ arcsinh_x_, sinh_arcsinh_x_, arcsinh_sinh_arcsinh_x_ = sess.run(
+ [arcsinh_x, sinh_arcsinh_x, arcsinh_sinh_arcsinh_x])
+
+ self.assertAllClose(x, sinh_arcsinh_x_, rtol=1e-5)
+ self.assertAllClose(arcsinh_x_, arcsinh_sinh_arcsinh_x_)
+
+
+class CoshTest(test.TestCase):
+
+ def test_versus_numpy_scalar_positive_x(self):
+ x = 1.23
+ with self.test_session():
+ self.assertAllClose(np.cosh(x), trig.cosh(x).eval())
+
+ def test_versus_numpy_multidim_x(self):
+ x = [[0., -1.], [0.5, -2.]]
+ with self.test_session():
+ self.assertAllClose(np.cosh(x), trig.cosh(x).eval())
+
+
+class LogCoshTest(test.TestCase):
+
+ def _assert_all_finite(self, values):
+ self.assertAllEqual(np.ones_like(values).astype(np.bool),
+ np.isfinite(values))
+
+ def test_versus_numpy_scalar_32bit(self):
+ with self.test_session() as sess:
+ x_64 = np.linspace(-200, 200, 2000)
+ x = constant_op.constant(x_64.astype(np.float32))
+
+ y = trig.log_cosh(x)
+ grad = gradients_impl.gradients(y, x)[0]
+ y_, grad_ = sess.run([y, grad])
+
+ self.assertAllClose(np.log(np.cosh(x_64)), y_)
+ self._assert_all_finite(y_)
+ self._assert_all_finite(grad_)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/trig.py b/tensorflow/contrib/distributions/python/ops/trig.py
new file mode 100644
index 0000000000..e49288e5b1
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/trig.py
@@ -0,0 +1,173 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Trigonometric functions."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import math_ops
+
+__all__ = [
+ "sinh",
+ "arcsinh",
+ "cosh",
+ "log_cosh",
+]
+
+
+def sinh(x, name="sinh"):
+ """Hyperbolic sin: `sinh(x) = (e**x - e**-x) / 2`.
+
+ For `x in (-inf, inf)`, `arcsinh(sinh(x)) = sinh(arcsinh(x)) = x.`
+
+ Args:
+ x: Numeric `Tensor`.
+ name: A string name to prepend to created Ops.
+
+ Returns:
+ Numeric `Tensor` of same `shape` and `dtype` as `x`.
+ """
+ with ops.name_scope(name):
+ # For x near zero, this could be replaced with the identity x --> x as a
+ # Taylor approximation. The Taylor approximation would be more accurate,
+ # and the current implementation will equal zero "too often", specifically
+ # for |x| small enough so that exp(x) = 1 (up to floating point precision).
+ # Although less accurate, the present implementation is still safe for
+ # anticipated use. Why? No one would use this function and expect that zero
+ # was not a possible output value.
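+ # For example, this implementation returns exactly 0. for sinh(1e-20), since
+ # exp(1e-20) rounds to 1., whereas the Taylor approximation would return 1e-20.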
+ x = ops.convert_to_tensor(x, name="x")
+ return 0.5 * (math_ops.exp(x) - math_ops.exp(-x))
+
+
+def arcsinh(x, name="arcsinh"):
+ """Inverse hyperbolic sin: `arcsinh(x) = log(x + sqrt(x**2 + 1))`.
+
+ For `x in (-inf, inf)`, `arcsinh(sinh(x)) = sinh(arcsinh(x)) = x.`
+
+ Args:
+ x: Numeric `Tensor`.
+ name: A string name to prepend to created Ops.
+
+ Returns:
+ Numeric `Tensor` of same `shape` and `dtype` as `x`.
+ """
+ with ops.name_scope(name):
+ x = ops.convert_to_tensor(x, name="x")
+ finfo = np.finfo(x.dtype.as_numpy_dtype)
+
+ # To compute stable arcsinh(x), we will compute various approximations of
+ # z := x + sqrt(x**2 + 1), and then arcsinh(x) = log(z).
+ # Different approximations are used over different ranges of x, and the
+ # result is pieced together using tf.where. Since NaNs propagate through the
+ # unselected branch when taking gradients of tf.where, care is taken to
+ # ensure that every approximation is finite for all input values.
+
+ # For x near zero, the straightforward formula for z is fine.
+ # This formula will have trouble once x < 0 and x**2 + 1 = x**2, since then
+ # z = 0 (numerically), and then we have log(z) = log(0) = -inf
+ # This formula also has trouble once x > sqrt(finfo.max), since then
+ # x**2 = inf, and thus z = inf. Therefore we clip.
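+ # Note that only the argument of the sqrt below is clipped: where this branch
+ # is selected (|x| < 1) the clip is a no-op; elsewhere it merely keeps the
+ # unselected value and its gradient finite.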
+ x_near_zero = clip_ops.clip_by_value(x, -1., 1.)
+ x_is_near_zero = math_ops.abs(x) < 1.
+ z_for_x_near_zero = x + math_ops.sqrt(x_near_zero**2 + 1)
+
+ # Some cutoffs.
+ # Important! Keep these cutoffs in sync with the tests, which use cutoffs
+ # of the exact same name.
+ # very_big_cutoff**2 = finfo.max, the maximum representable value.
+ # very_small_cutoff could have been defined as -1 / sqrt(eps), since for
+ # x < -1 / sqrt(eps), 1 + 1 / x**2 = 1 (numerically), which causes trouble.
+ # The 0.01 factor was added in order to match numpy in 32bit as closely as
+ # possible. Any factor < 1 should be stable.
+ very_small_cutoff = -0.01 / np.sqrt(finfo.eps)
+ very_big_cutoff = np.sqrt(finfo.max)
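+ # For float32, eps ~ 1.2e-7 and max ~ 3.4e38, so very_small_cutoff ~ -29
+ # and very_big_cutoff ~ 1.8e19.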
+
+ # For very_small_cutoff < x < -1, and 1 < x < very_big_cutoff, and
+ # x != 0, we can use
+ # z = sqrt(x**2 + 1) = |x| * sqrt(1 + 1 / x**2).
+ # This formula has trouble if x < -sqrt(eps) since then 1 + 1 / x**2 = 1,
+ # and then we get z = x + |x| = 0, thus returning log(0) = -inf.
+ # This formula also has trouble if x**2 = Inf. Therefore we clip.
+ # This formula also has trouble if x = 0, since then we have 1 / 0**2 = inf.
+ x_not_near_zero = array_ops.where(
+ x >= 0.,
+ math_ops.maximum(x, 1.),
+ math_ops.minimum(x, -1.))
+ x_clipped_moderate_or_big = clip_ops.clip_by_value(
+ x_not_near_zero, very_small_cutoff, very_big_cutoff)
+ z_for_moderate_or_big_x = x + math_ops.abs(x) * math_ops.sqrt(
+ 1. + 1. / x_clipped_moderate_or_big**2)
+
+ # For x < very_small_cutoff, we use the first order Taylor series,
+ # sqrt(1 + 1 / x**2) approx 1 + 1 / (2 * x**2)
+ # This formula has trouble for x = 0.
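+ # Since x < very_small_cutoff < 0 here, x + |x| = 0, so
+ # z = x + |x| * (1 + 1 / (2 * x**2)) = 1 / (2 * |x|).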
+ x_is_very_small = x < very_small_cutoff
+ z_for_very_small_x = 1 / (2. * math_ops.abs(x_not_near_zero))
+
+ z = array_ops.where(
+ x_is_near_zero,
+ z_for_x_near_zero,
+ array_ops.where(
+ x_is_very_small,
+ z_for_very_small_x,
+ z_for_moderate_or_big_x))
+
+ return math_ops.log(z)
+
+
+def cosh(x, name="cosh"):
+ """Hyperbolic cosine: `cosh(x) = (e**x + e**-x) / 2`.
+
+ For `x in (-inf, inf)`, `arccosh(cosh(x)) = cosh(arccosh(x)) = x.`
+
+ Args:
+ x: Numeric `Tensor`.
+ name: A string name to prepend to created Ops.
+
+ Returns:
+ Numeric `Tensor` of same `shape` and `dtype` as `x`.
+ """
+ with ops.name_scope(name):
+ x = ops.convert_to_tensor(x, name="x")
+ return 0.5 * (math_ops.exp(x) + math_ops.exp(-x))
+
+
+def log_cosh(x, name="log_cosh"):
+ """Logarithm of hyperbolic cosine: `log_cosh(x) = Log[(e**x + e**-x) / 2]`.
+
+ Args:
+ x: Numeric `Tensor`.
+ name: A string name to prepend to created Ops.
+
+ Returns:
+ Numeric `Tensor` of same `shape` and `dtype` as `x`.
+ """
+ # For large |x| >> 1, e**x will become Inf. So we need to approximate
+ # Log[e**x + e**-x] approx |x|.
+ # We also need to ensure that large |x| is never fed to the exponential func.
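+ # More precisely, log(cosh(x)) = |x| + log(1 + e**(-2|x|)) - log(2), and the
+ # log(1 + e**(-2|x|)) term vanishes for large |x|, which leaves |x| - log(2).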
+ with ops.name_scope(name):
+ x = ops.convert_to_tensor(x, name="x")
+ large_x_value = 0.9 * np.log(np.finfo(x.dtype.as_numpy_dtype).max)
+ x_capped = clip_ops.clip_by_value(x, -large_x_value, large_x_value)
+ return array_ops.where(
+ math_ops.abs(x) > large_x_value,
+ math_ops.abs(x) - np.log(2).astype(x.dtype.as_numpy_dtype),
+ math_ops.log(cosh(x_capped)))
diff --git a/tensorflow/python/kernel_tests/distributions/util_test.py b/tensorflow/python/kernel_tests/distributions/util_test.py
index fb1960f672..07b1ba71bb 100644
--- a/tensorflow/python/kernel_tests/distributions/util_test.py
+++ b/tensorflow/python/kernel_tests/distributions/util_test.py
@@ -703,5 +703,17 @@ class SoftplusTest(test.TestCase):
# Equivalent to `assertAllFalse` (if it existed).
self.assertAllEqual(np.zeros_like(grads).astype(np.bool), np.isnan(grads))
+ def testInverseSoftplusGradientFinite(self):
+ with self.test_session():
+ # This range of x is all finite, and so is 1 / x. So the
+ # gradient and its approximations should be finite as well.
+ x = constant_op.constant(np.logspace(-4.8, 4.5).astype(np.float16))
+ y = distribution_util.softplus_inverse(x)
+ grads = gradients_impl.gradients(y, x)[0].eval()
+ # Equivalent to `assertAllTrue` (if it existed).
+ self.assertAllEqual(
+ np.ones_like(grads).astype(np.bool), np.isfinite(grads))
+
+
if __name__ == "__main__":
test.main()