aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Siddharth Agrawal <siddharth.950@gmail.com>2016-06-30 03:31:27 +0530
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2016-06-29 15:01:27 -0700
commite989ba261618f10b3d473ce61b9137d32b4703f6 (patch)
treea3af34a73bda98b132f03727f1354c303f419b73
parent51e803e43c298ce5a47a60ff8b6c824244a8948f (diff)
Enable tf.tanh() for SparseTensor (#2998)
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py6
-rw-r--r--tensorflow/python/ops/math_ops.py15
2 files changed, 15 insertions, 6 deletions
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 4a21d1acc6..7f1be574bb 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -213,6 +213,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.negative, tf.neg)
self._compareBothSparse(x, np.square, tf.square)
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
+ self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(y, np.sign, tf.sign)
def testFloatTanhEdge(self):
@@ -251,6 +252,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.negative, tf.neg)
self._compareBothSparse(x, np.square, tf.square)
self._compareBothSparse(x, np.sqrt, tf.sqrt, tol=1e-3)
+ self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(x, np.sign, tf.sign)
def testDoubleBasic(self):
@@ -288,6 +290,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.negative, tf.neg)
self._compareBothSparse(x, np.square, tf.square)
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
+ self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(y, np.sign, tf.sign)
def testHalfBasic(self):
@@ -320,6 +323,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.negative, tf.neg)
self._compareBothSparse(x, np.square, tf.square)
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
+ self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(y, np.sign, tf.sign)
def testInt32Basic(self):
@@ -374,6 +378,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.negative, tf.neg)
self._compareBothSparse(x, np.square, tf.square)
self._compareBothSparse(x, np.sqrt, tf.sqrt, 1e-3)
+ self._compareBothSparse(x, np.tanh, tf.tanh)
# Numpy uses an incorrect definition of sign; use the right one instead.
def complex_sign(x):
@@ -404,6 +409,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.negative, tf.neg)
self._compareBothSparse(x, np.square, tf.square)
self._compareBothSparse(x, np.sqrt, tf.sqrt, 1e-3)
+ self._compareBothSparse(x, np.tanh, tf.tanh)
# Numpy uses an incorrect definition of sign; use the right one instead.
def complex_sign(x):
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index d27cefc61d..0a76450c5b 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -1548,17 +1548,20 @@ def tanh(x, name=None):
"""Computes hyperbolic tangent of `x` element-wise.
Args:
- x: A Tensor with type `float32`, `float64`, `int32`, `complex64`, `int64`,
- or `qint32`.
+ x: A Tensor or SparseTensor with type `float`, `double`, `int32`,
+ `complex64`, `int64`, or `qint32`.
name: A name for the operation (optional).
Returns:
- A Tensor with the same type as `x` if `x.dtype != qint32` otherwise
- the return type is `quint8`.
+ A Tensor or SparseTensor respectively with the same type as `x` if
+ `x.dtype != qint32` otherwise the return type is `quint8`.
"""
with ops.op_scope([x], name, "Tanh") as name:
- x = ops.convert_to_tensor(x, name="x")
- return gen_math_ops._tanh(x, name=name)
+ if isinstance(x, ops.SparseTensor):
+ x_tanh = gen_math_ops._tanh(x.values, name=name)
+ return ops.SparseTensor(indices=x.indices, values=x_tanh, shape=x.shape)
+ else:
+ return gen_math_ops._tanh(x, name=name)
ops.RegisterShape("Abs")(common_shapes.unchanged_shape)