aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-10 11:43:15 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-10 11:47:16 -0700
commit2db78d20af06f256b86889c3f7d202ae88d6a896 (patch)
treeacfb2388df350de933cd700664ac057de2bd2c1e /tensorflow/contrib/autograph
parent8314a59b275c828d969a33952c7e611d58beac1d (diff)
Add support for builtin abs() to Autograph
PiperOrigin-RevId: 208243676
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/utils/builtins.py9
-rw-r--r--tensorflow/contrib/autograph/utils/builtins_test.py17
2 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py
index ccbe5fc954..4dd440ef19 100644
--- a/tensorflow/contrib/autograph/utils/builtins.py
+++ b/tensorflow/contrib/autograph/utils/builtins.py
@@ -44,6 +44,8 @@ def dynamic_builtin(f, *args, **kwargs):
return dynamic_int(*args, **kwargs)
if f is float:
return dynamic_float(*args, **kwargs)
+ if f is abs:
+ return dynamic_abs(*args, **kwargs)
raise NotImplementedError(
'The "%s" builtin is not yet supported.' % f.__name__)
@@ -81,6 +83,13 @@ def dynamic_float(num_or_tensor, **kwargs):
return float(num_or_tensor)
+def dynamic_abs(num_or_tensor, **kwargs):
+ if tensor_util.is_tensor(num_or_tensor):
+ return math_ops.abs(num_or_tensor, **kwargs)
+ else:
+ return abs(num_or_tensor, **kwargs)
+
+
def dynamic_range(start_or_stop, stop=None, step=None):
"""Implementation of range using dynamic dispatch."""
if type_check.is_tensor(start_or_stop, stop, step):
diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py
index b4821f36fc..b1cd5253bc 100644
--- a/tensorflow/contrib/autograph/utils/builtins_test.py
+++ b/tensorflow/contrib/autograph/utils/builtins_test.py
@@ -44,6 +44,23 @@ class BuiltinsTest(test.TestCase):
with self.test_session() as sess:
self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a)))
+ def test_dynamic_abs_tf_scalar(self):
+ a = constant_op.constant(-1)
+
+ with self.test_session() as sess:
+ self.assertEqual(1, sess.run(builtins.dynamic_builtin(abs, a)))
+
+ def test_dynamic_abs_tf_array(self):
+ a = constant_op.constant([-1, 2, -3])
+
+ with self.test_session() as sess:
+ self.assertListEqual([1, 2, 3],
+ list(sess.run(builtins.dynamic_builtin(abs, a))))
+
+ def test_dynamic_abs_py_scalar(self):
+ a = -1
+ self.assertEqual(1, builtins.dynamic_builtin(abs, a))
+
def test_dynamic_len_tf_matrix(self):
a = constant_op.constant([[1, 2], [3, 4]])