diff options
author | 2018-08-10 11:43:15 -0700 | |
---|---|---|
committer | 2018-08-10 11:47:16 -0700 | |
commit | 2db78d20af06f256b86889c3f7d202ae88d6a896 (patch) | |
tree | acfb2388df350de933cd700664ac057de2bd2c1e /tensorflow/contrib/autograph | |
parent | 8314a59b275c828d969a33952c7e611d58beac1d (diff) |
Add support for builtin abs() to Autograph
PiperOrigin-RevId: 208243676
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r-- | tensorflow/contrib/autograph/utils/builtins.py | 9 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/utils/builtins_test.py | 17 |
2 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/contrib/autograph/utils/builtins.py b/tensorflow/contrib/autograph/utils/builtins.py index ccbe5fc954..4dd440ef19 100644 --- a/tensorflow/contrib/autograph/utils/builtins.py +++ b/tensorflow/contrib/autograph/utils/builtins.py @@ -44,6 +44,8 @@ def dynamic_builtin(f, *args, **kwargs): return dynamic_int(*args, **kwargs) if f is float: return dynamic_float(*args, **kwargs) + if f is abs: + return dynamic_abs(*args, **kwargs) raise NotImplementedError( 'The "%s" builtin is not yet supported.' % f.__name__) @@ -81,6 +83,13 @@ def dynamic_float(num_or_tensor, **kwargs): return float(num_or_tensor) +def dynamic_abs(num_or_tensor, **kwargs): + if tensor_util.is_tensor(num_or_tensor): + return math_ops.abs(num_or_tensor, **kwargs) + else: + return abs(num_or_tensor, **kwargs) + + def dynamic_range(start_or_stop, stop=None, step=None): """Implementation of range using dynamic dispatch.""" if type_check.is_tensor(start_or_stop, stop, step): diff --git a/tensorflow/contrib/autograph/utils/builtins_test.py b/tensorflow/contrib/autograph/utils/builtins_test.py index b4821f36fc..b1cd5253bc 100644 --- a/tensorflow/contrib/autograph/utils/builtins_test.py +++ b/tensorflow/contrib/autograph/utils/builtins_test.py @@ -44,6 +44,23 @@ class BuiltinsTest(test.TestCase): with self.test_session() as sess: self.assertEqual(3, sess.run(builtins.dynamic_builtin(len, a))) + def test_dynamic_abs_tf_scalar(self): + a = constant_op.constant(-1) + + with self.test_session() as sess: + self.assertEqual(1, sess.run(builtins.dynamic_builtin(abs, a))) + + def test_dynamic_abs_tf_array(self): + a = constant_op.constant([-1, 2, -3]) + + with self.test_session() as sess: + self.assertListEqual([1, 2, 3], + list(sess.run(builtins.dynamic_builtin(abs, a)))) + + def test_dynamic_abs_py_scalar(self): + a = -1 + self.assertEqual(1, builtins.dynamic_builtin(abs, a)) + def test_dynamic_len_tf_matrix(self): a = constant_op.constant([[1, 2], [3, 4]]) |