diff options
Diffstat (limited to 'tensorflow/python/autograph')
5 files changed, 19 insertions, 53 deletions
diff --git a/tensorflow/python/autograph/converters/logical_expressions.py b/tensorflow/python/autograph/converters/logical_expressions.py index ac42ee2c33..8c4d53f9a8 100644 --- a/tensorflow/python/autograph/converters/logical_expressions.py +++ b/tensorflow/python/autograph/converters/logical_expressions.py @@ -57,8 +57,6 @@ class LogicalExpressionTransformer(converter.Base): gast.NotEq: 'tf.not_equal', gast.Or: 'tf.logical_or', gast.USub: 'tf.negative', - gast.Is: 'ag__.utils.dynamic_is', - gast.IsNot: 'ag__.utils.dynamic_is_not' } def _expect_simple_symbol(self, operand): @@ -72,12 +70,13 @@ class LogicalExpressionTransformer(converter.Base): '"a.x or b"; for a workaround, assign the expression to a local ' 'variable and use that instead, for example "tmp = a.x", "tmp or b"') + def _has_matching_func(self, operator): + op_type = type(operator) + return op_type in self.op_mapping + def _matching_func(self, operator): op_type = type(operator) - mapped_op = self.op_mapping.get(op_type) - if not mapped_op: - raise NotImplementedError('operator %s is not yet supported' % op_type) - return mapped_op + return self.op_mapping[op_type] def _as_function(self, func_name, args): template = """ @@ -90,6 +89,16 @@ class LogicalExpressionTransformer(converter.Base): def visit_Compare(self, node): node = self.generic_visit(node) + + if not all(self._has_matching_func(op) for op in node.ops): + if len(node.ops) == 1: + # Basic expressions are safe to leave as they are. + return node + else: + raise NotImplementedError( + 'compound expression with at least one unsupported ' + 'operator: {}'.format(node.ops)) + ops_and_comps = list(zip(node.ops, node.comparators)) left = node.left op_tree = None diff --git a/tensorflow/python/autograph/converters/logical_expressions_test.py b/tensorflow/python/autograph/converters/logical_expressions_test.py index 5fb3fb992f..b78b4d3a6a 100644 --- a/tensorflow/python/autograph/converters/logical_expressions_test.py +++ b/tensorflow/python/autograph/converters/logical_expressions_test.py @@ -47,14 +47,12 @@ class GradientsFunctionTest(converter_testing.TestCase): with self.cached_session() as sess: self.assertTrue(sess.run(result.test_fn(True, False, True))) - def test_ag_utils_lookup(self): + def test_unsupported_ops(self): def test_fn(a, b): - return a is b or a is not b + return a in b - with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or - ) as result: - with self.cached_session() as sess: - self.assertTrue(sess.run(result.test_fn(True, False))) + with self.converted(test_fn, logical_expressions, {}) as result: + self.assertTrue(result.test_fn('a', ('a',))) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/utils/__init__.py b/tensorflow/python/autograph/utils/__init__.py index e38c82a079..c781958481 100644 --- a/tensorflow/python/autograph/utils/__init__.py +++ b/tensorflow/python/autograph/utils/__init__.py @@ -20,8 +20,6 @@ from __future__ import print_function from tensorflow.python.autograph.utils.context_managers import control_dependency_on_returns from tensorflow.python.autograph.utils.misc import alias_tensors -from tensorflow.python.autograph.utils.multiple_dispatch import dynamic_is -from tensorflow.python.autograph.utils.multiple_dispatch import dynamic_is_not from tensorflow.python.autograph.utils.multiple_dispatch import run_cond from tensorflow.python.autograph.utils.py_func import wrap_py_func from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append diff --git a/tensorflow/python/autograph/utils/multiple_dispatch.py b/tensorflow/python/autograph/utils/multiple_dispatch.py index 33f521db2c..107c8f7a68 100644 --- a/tensorflow/python/autograph/utils/multiple_dispatch.py +++ b/tensorflow/python/autograph/utils/multiple_dispatch.py @@ -22,16 +22,6 @@ from tensorflow.python.autograph.utils.type_check import is_tensor from tensorflow.python.ops import control_flow_ops -def dynamic_is(left, right): - # TODO(alexbw) if we're sure we should leave 'is' in place, - # then change the semantics in converters/logical_expressions.py - return left is right - - -def dynamic_is_not(left, right): - return left is not right - - def run_cond(condition, true_fn, false_fn): """Type-dependent functional conditional. diff --git a/tensorflow/python/autograph/utils/multiple_dispatch_test.py b/tensorflow/python/autograph/utils/multiple_dispatch_test.py index ed20822529..2a77c895ce 100644 --- a/tensorflow/python/autograph/utils/multiple_dispatch_test.py +++ b/tensorflow/python/autograph/utils/multiple_dispatch_test.py @@ -18,8 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.python.autograph.utils import multiple_dispatch from tensorflow.python.client.session import Session from tensorflow.python.framework.constant_op import constant @@ -28,33 +26,6 @@ from tensorflow.python.platform import test class MultipleDispatchTest(test.TestCase): - def test_dynamic_is_python(self): - a = np.eye(3) - also_a = a - not_actually_a = np.eye(3) - should_be_true1 = multiple_dispatch.dynamic_is(a, also_a) - should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a) - should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a) - should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a) - self.assertTrue(should_be_true1) - self.assertTrue(should_be_true2) - self.assertFalse(should_be_false1) - self.assertFalse(should_be_false2) - - def test_dynamic_is_tf(self): - with Session().as_default(): - a = constant([2.0]) - also_a = a - not_actually_a = constant([2.0]) - should_be_true1 = multiple_dispatch.dynamic_is(a, also_a) - should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a) - should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a) - should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a) - self.assertTrue(should_be_true1) - self.assertTrue(should_be_true2) - self.assertFalse(should_be_false1) - self.assertFalse(should_be_false2) - def test_run_cond_python(self): true_fn = lambda: (2,) false_fn = lambda: (3,) |