aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Siddharth Agrawal <siddharth.950@gmail.com>2016-07-01 01:59:14 +0530
committerGravatar Martin Wicke <martin.wicke@gmail.com>2016-06-30 13:29:14 -0700
commitaa2cacd6627ffb296bedc910c957a0fd4a2f957f (patch)
treed9965911579741dc5d2513447088aba20c9bb92a /tensorflow/python
parentac90ecb08d80fc92147fdb3ed852fbb1ddbecf5f (diff)
Enable tf.erf() for SparseTensor (#3122)
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/kernel_tests/cwise_ops_test.py4
-rw-r--r--tensorflow/python/ops/math_ops.py19
2 files changed, 23 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/cwise_ops_test.py b/tensorflow/python/kernel_tests/cwise_ops_test.py
index 7f1be574bb..093da97469 100644
--- a/tensorflow/python/kernel_tests/cwise_ops_test.py
+++ b/tensorflow/python/kernel_tests/cwise_ops_test.py
@@ -215,6 +215,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(y, np.sign, tf.sign)
+ self._compareBothSparse(x, np.vectorize(math.erf), tf.erf)
def testFloatTanhEdge(self):
x = np.arange(40, 40 + 6).reshape(6).astype(np.float32)
@@ -254,6 +255,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(x, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(x, np.sign, tf.sign)
+ self._compareBothSparse(x, np.sign, tf.erf)
def testDoubleBasic(self):
x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float64)
@@ -292,6 +294,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(y, np.sign, tf.sign)
+ self._compareBothSparse(x, np.vectorize(math.erf), tf.erf)
def testHalfBasic(self):
x = np.arange(-3, 3).reshape(1, 3, 2).astype(np.float16)
@@ -325,6 +328,7 @@ class UnaryOpTest(tf.test.TestCase):
self._compareBothSparse(z, np.sqrt, tf.sqrt, tol=1e-3)
self._compareBothSparse(x, np.tanh, tf.tanh)
self._compareBothSparse(y, np.sign, tf.sign)
+ self._compareBothSparse(x, np.vectorize(math.erf), tf.erf, tol=1e-3)
def testInt32Basic(self):
x = np.arange(-6, 6, 2).reshape(1, 3, 2).astype(np.int32)
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 0a76450c5b..38c7e51594 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -348,6 +348,25 @@ def sqrt(x, name=None):
return gen_math_ops.sqrt(x, name=name)
+def erf(x, name=None):
+ """Computes the Gauss error function of `x` element-wise.
+
+ Args:
+ x: A `Tensor` of `SparseTensor`. Must be one of the following types: `half`,
+ `float32`, `float64`.
+ name: A name for the operation (optional).
+
+ Returns:
+ A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.
+ """
+ with ops.op_scope([x], name, "Erf") as name:
+ if isinstance(x, ops.SparseTensor):
+ x_erf = gen_math_ops.erf(x.values, name=name)
+ return ops.SparseTensor(indices=x.indices, values=x_erf, shape=x.shape)
+ else:
+ return gen_math_ops.erf(x, name=name)
+
+
def complex_abs(x, name=None):
r"""Computes the complex absolute value of a tensor.