diff options
author | 2016-06-30 03:31:27 +0530 | |
---|---|---|
committer | 2016-06-29 15:01:27 -0700 | |
commit | e989ba261618f10b3d473ce61b9137d32b4703f6 (patch) | |
tree | a3af34a73bda98b132f03727f1354c303f419b73 | |
parent | 51e803e43c298ce5a47a60ff8b6c824244a8948f (diff) |
Enable tf.tanh() for SparseTensor (#2998)
-rw-r--r-- | tensorflow/python/kernel_tests/cwise_ops_test.py | 6 | ||||
-rw-r--r-- | tensorflow/python/ops/math_ops.py | 15 |
2 files changed, 15 insertions, 6 deletions
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py index 4a21d1acc6..7f1be574bb 100644 --- a/tensorflow/python/kernel_tests/cwise_ops_test.py +++ b/tensorflow/python/kernel_tests/cwise_ops_test.py @@ -213,6 +213,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.negative, tf.neg) self._compareBothSparse(x, np.square, tf.square) self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3) + self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(y, np.sign, tf.sign) def testFloatTanhEdge(self): @@ -251,6 +252,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.negative, tf.neg) self._compareBothSparse(x, np.square, tf.square) self._compareBothSparse(x, np.sqrt, tf.sqrt, tol=1e-3) + self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(x, np.sign, tf.sign) def testDoubleBasic(self): @@ -288,6 +290,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.negative, tf.neg) self._compareBothSparse(x, np.square, tf.square) self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3) + self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(y, np.sign, tf.sign) def testHalfBasic(self): @@ -320,6 +323,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.negative, tf.neg) self._compareBothSparse(x, np.square, tf.square) self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3) + self._compareBothSparse(x, np.tanh, tf.tanh) self._compareBothSparse(y, np.sign, tf.sign) def testInt32Basic(self): @@ -374,6 +378,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.negative, tf.neg) self._compareBothSparse(x, np.square, tf.square) self._compareBothSparse(x, np.sqrt, tf.sqrt, 1e-3) + self._compareBothSparse(x, np.tanh, tf.tanh) # Numpy uses an incorrect definition of sign; use the right one instead. def complex_sign(x): @@ -404,6 +409,7 @@ class UnaryOpTest(tf.test.TestCase): self._compareBothSparse(x, np.negative, tf.neg) self._compareBothSparse(x, np.square, tf.square) self._compareBothSparse(x, np.sqrt, tf.sqrt, 1e-3) + self._compareBothSparse(x, np.tanh, tf.tanh) # Numpy uses an incorrect definition of sign; use the right one instead. def complex_sign(x): diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py index d27cefc61d..0a76450c5b 100644 --- a/tensorflow/python/ops/math_ops.py +++ b/tensorflow/python/ops/math_ops.py @@ -1548,17 +1548,20 @@ def tanh(x, name=None): """Computes hyperbolic tangent of `x` element-wise. Args: - x: A Tensor with type `float32`, `float64`, `int32`, `complex64`, `int64`, - or `qint32`. + x: A Tensor or SparseTensor with type `float`, `double`, `int32`, + `complex64`, `int64`, or `qint32`. name: A name for the operation (optional). Returns: - A Tensor with the same type as `x` if `x.dtype != qint32` otherwise - the return type is `quint8`. + A Tensor or SparseTensor respectively with the same type as `x` if + `x.dtype != qint32` otherwise the return type is `quint8`. """ with ops.op_scope([x], name, "Tanh") as name: - x = ops.convert_to_tensor(x, name="x") - return gen_math_ops._tanh(x, name=name) + if isinstance(x, ops.SparseTensor): + x_tanh = gen_math_ops._tanh(x.values, name=name) + return ops.SparseTensor(indices=x.indices, values=x_tanh, shape=x.shape) + else: + return gen_math_ops._tanh(x, name=name) ops.RegisterShape("Abs")(common_shapes.unchanged_shape) |