From e989ba261618f10b3d473ce61b9137d32b4703f6 Mon Sep 17 00:00:00 2001 From: Siddharth Agrawal Date: Thu, 30 Jun 2016 03:31:27 +0530 Subject: Enable tf.tanh() for SparseTensor (#2998) --- tensorflow/python/kernel_tests/cwise_ops_test.py | 6 ++++++ tensorflow/python/ops/math_ops.py | 15 +++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 4a21d1acc6..7f1be574bb 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -213,6 +213,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.negative, tf.neg) self._compareBothSparse(x, np.square, tf.square) self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3) + self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(y, np.sign, tf.sign) def testFloatTanhEdge(self): @@ -251,6 +252,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.negative, tf.neg) self._compareBothSparse(x, np.square, tf.square) self._compareBothSparse(x, np.sqrt, tf.sqrt, tol=1e-3) + self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(x, np.sign, tf.sign) def testDoubleBasic(self): @@ -288,6 +290,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.negative, tf.neg) self._compareBothSparse(x, np.square, tf.square) self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3) + self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(y, np.sign, tf.sign) def testHalfBasic(self): @@ -320,6 +323,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.negative, tf.neg) self._compareBothSparse(x, np.square, tf.square) self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3) + self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(y, np.sign, tf.sign) def testInt32Basic(self): @@ -374,6 +378,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.negative, tf.neg) self._compareBothSparse(x, np.square, tf.square) self._compareBothSparse(x, np.sqrt, tf.sqrt, 1e-3) + self._compareBothSparse(x, np.tanh, tf.tanh) # Numpy uses an incorrect definition of sign; use the right one instead. def complex_sign(x): @@ -404,6 +409,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.negative, tf.neg) self._compareBothSparse(x, np.square, tf.square) self._compareBothSparse(x, np.sqrt, tf.sqrt, 1e-3) + self._compareBothSparse(x, np.tanh, tf.tanh) # Numpy uses an incorrect definition of sign; use the right one instead. def complex_sign(x): diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index d27cefc61d..0a76450c5b 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1548,17 +1548,20 @@ def tanh(x, name=None): """Computes hyperbolic tangent of `x` element-wise. Args: - x: A Tensor with type `float32`, `float64`, `int32`, `complex64`, `int64`, - or `qint32`. + x: A Tensor or SparseTensor with type `float`, `double`, `int32`, + `complex64`, `int64`, or `qint32`. name: A name for the operation (optional). Returns: - A Tensor with the same type as `x` if `x.dtype != qint32` otherwise - the return type is `quint8`. + A Tensor or SparseTensor respectively with the same type as `x` if + `x.dtype != qint32` otherwise the return type is `quint8`. """ with ops.op_scope([x], name, "Tanh") as name: - x = ops.convert_to_tensor(x, name="x") - return gen_math_ops._tanh(x, name=name) + if isinstance(x, ops.SparseTensor): + x_tanh = gen_math_ops._tanh(x.values, name=name) + return ops.SparseTensor(indices=x.indices, values=x_tanh, shape=x.shape) + else: + return gen_math_ops._tanh(x, name=name) ops.RegisterShape("Abs")(common_shapes.unchanged_shape) -- cgit v1.2.3