diff options
author | 2018-09-13 07:26:18 -0700 | |
---|---|---|
committer | 2018-09-13 07:30:24 -0700 | |
commit | 226cc7c47e2df8682b384aef5c54836948caecb3 (patch) | |
tree | 402869c67b0daad7cd959ba8712d7f7c936f9ab5 /tensorflow/python/autograph | |
parent | 46aa7cf45c62d193f56f55d7d2ffc5baf7af3b65 (diff) |
Allow unsupported comparison operators to be passed through and scale back the coverage of overloads.
It's up for discussion whether we allow overloading everything or let the users rely on the existing operator overloading mechanisms instead. The one case that we do want to support is the equality operator.
PiperOrigin-RevId: 212809447
Diffstat (limited to 'tensorflow/python/autograph')
5 files changed, 19 insertions, 53 deletions
diff --git a/tensorflow/python/autograph/converters/logical_expressions.py b/tensorflow/python/autograph/converters/logical_expressions.py index ac42ee2c33..8c4d53f9a8 100644 --- a/tensorflow/python/autograph/converters/logical_expressions.py +++ b/tensorflow/python/autograph/converters/logical_expressions.py @@ -57,8 +57,6 @@ class LogicalExpressionTransformer(converter.Base): gast.NotEq: 'tf.not_equal', gast.Or: 'tf.logical_or', gast.USub: 'tf.negative', - gast.Is: 'ag__.utils.dynamic_is', - gast.IsNot: 'ag__.utils.dynamic_is_not' } def _expect_simple_symbol(self, operand): @@ -72,12 +70,13 @@ class LogicalExpressionTransformer(converter.Base): '"a.x or b"; for a workaround, assign the expression to a local ' 'variable and use that instead, for example "tmp = a.x", "tmp or b"') + def _has_matching_func(self, operator): + op_type = type(operator) + return op_type in self.op_mapping + def _matching_func(self, operator): op_type = type(operator) - mapped_op = self.op_mapping.get(op_type) - if not mapped_op: - raise NotImplementedError('operator %s is not yet supported' % op_type) - return mapped_op + return self.op_mapping[op_type] def _as_function(self, func_name, args): template = """ @@ -90,6 +89,16 @@ class LogicalExpressionTransformer(converter.Base): def visit_Compare(self, node): node = self.generic_visit(node) + + if not all(self._has_matching_func(op) for op in node.ops): + if len(node.ops) == 1: + # Basic expressions are safe to leave as they are. + return node + else: + raise NotImplementedError( + 'compound expression with at least one unsupported ' + 'operator: {}'.format(node.ops)) + ops_and_comps = list(zip(node.ops, node.comparators)) left = node.left op_tree = None diff --git a/tensorflow/python/autograph/converters/logical_expressions_test.py b/tensorflow/python/autograph/converters/logical_expressions_test.py index 5fb3fb992f..b78b4d3a6a 100644 --- a/tensorflow/python/autograph/converters/logical_expressions_test.py +++ b/tensorflow/python/autograph/converters/logical_expressions_test.py @@ -47,14 +47,12 @@ class GradientsFunctionTest(converter_testing.TestCase): with self.cached_session() as sess: self.assertTrue(sess.run(result.test_fn(True, False, True))) - def test_ag_utils_lookup(self): + def test_unsupported_ops(self): def test_fn(a, b): - return a is b or a is not b + return a in b - with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or - ) as result: - with self.cached_session() as sess: - self.assertTrue(sess.run(result.test_fn(True, False))) + with self.converted(test_fn, logical_expressions, {}) as result: + self.assertTrue(result.test_fn('a', ('a',))) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/utils/__init__.py b/tensorflow/python/autograph/utils/__init__.py index e38c82a079..c781958481 100644 --- a/tensorflow/python/autograph/utils/__init__.py +++ b/tensorflow/python/autograph/utils/__init__.py @@ -20,8 +20,6 @@ from __future__ import print_function from tensorflow.python.autograph.utils.context_managers import control_dependency_on_returns from tensorflow.python.autograph.utils.misc import alias_tensors -from tensorflow.python.autograph.utils.multiple_dispatch import dynamic_is -from tensorflow.python.autograph.utils.multiple_dispatch import dynamic_is_not from tensorflow.python.autograph.utils.multiple_dispatch import run_cond from tensorflow.python.autograph.utils.py_func import wrap_py_func from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append diff --git a/tensorflow/python/autograph/utils/multiple_dispatch.py b/tensorflow/python/autograph/utils/multiple_dispatch.py index 33f521db2c..107c8f7a68 100644 --- a/tensorflow/python/autograph/utils/multiple_dispatch.py +++ b/tensorflow/python/autograph/utils/multiple_dispatch.py @@ -22,16 +22,6 @@ from tensorflow.python.autograph.utils.type_check import is_tensor from tensorflow.python.ops import control_flow_ops -def dynamic_is(left, right): - # TODO(alexbw) if we're sure we should leave 'is' in place, - # then change the semantics in converters/logical_expressions.py - return left is right - - -def dynamic_is_not(left, right): - return left is not right - - def run_cond(condition, true_fn, false_fn): """Type-dependent functional conditional. diff --git a/tensorflow/python/autograph/utils/multiple_dispatch_test.py b/tensorflow/python/autograph/utils/multiple_dispatch_test.py index ed20822529..2a77c895ce 100644 --- a/tensorflow/python/autograph/utils/multiple_dispatch_test.py +++ b/tensorflow/python/autograph/utils/multiple_dispatch_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.python.autograph.utils import multiple_dispatch from tensorflow.python.client.session import Session from tensorflow.python.framework.constant_op import constant @@ -28,33 +26,6 @@ from tensorflow.python.platform import test class MultipleDispatchTest(test.TestCase): - def test_dynamic_is_python(self): - a = np.eye(3) - also_a = a - not_actually_a = np.eye(3) - should_be_true1 = multiple_dispatch.dynamic_is(a, also_a) - should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a) - should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a) - should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a) - self.assertTrue(should_be_true1) - self.assertTrue(should_be_true2) - self.assertFalse(should_be_false1) - self.assertFalse(should_be_false2) - - def test_dynamic_is_tf(self): - with Session().as_default(): - a = constant([2.0]) - also_a = a - not_actually_a = constant([2.0]) - should_be_true1 = multiple_dispatch.dynamic_is(a, also_a) - should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a) - should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a) - should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a) - self.assertTrue(should_be_true1) - self.assertTrue(should_be_true2) - self.assertFalse(should_be_false1) - self.assertFalse(should_be_false2) - def test_run_cond_python(self): true_fn = lambda: (2,) false_fn = lambda: (3,) |