aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--tensorflow/python/autograph/converters/logical_expressions.py21
-rw-r--r--tensorflow/python/autograph/converters/logical_expressions_test.py10
-rw-r--r--tensorflow/python/autograph/utils/__init__.py2
-rw-r--r--tensorflow/python/autograph/utils/multiple_dispatch.py10
-rw-r--r--tensorflow/python/autograph/utils/multiple_dispatch_test.py29
5 files changed, 19 insertions, 53 deletions
diff --git a/tensorflow/python/autograph/converters/logical_expressions.py b/tensorflow/python/autograph/converters/logical_expressions.py
index ac42ee2c33..8c4d53f9a8 100644
--- a/tensorflow/python/autograph/converters/logical_expressions.py
+++ b/tensorflow/python/autograph/converters/logical_expressions.py
@@ -57,8 +57,6 @@ class LogicalExpressionTransformer(converter.Base):
gast.NotEq: 'tf.not_equal',
gast.Or: 'tf.logical_or',
gast.USub: 'tf.negative',
- gast.Is: 'ag__.utils.dynamic_is',
- gast.IsNot: 'ag__.utils.dynamic_is_not'
}
def _expect_simple_symbol(self, operand):
@@ -72,12 +70,13 @@ class LogicalExpressionTransformer(converter.Base):
'"a.x or b"; for a workaround, assign the expression to a local '
'variable and use that instead, for example "tmp = a.x", "tmp or b"')
+ def _has_matching_func(self, operator):
+ op_type = type(operator)
+ return op_type in self.op_mapping
+
def _matching_func(self, operator):
op_type = type(operator)
- mapped_op = self.op_mapping.get(op_type)
- if not mapped_op:
- raise NotImplementedError('operator %s is not yet supported' % op_type)
- return mapped_op
+ return self.op_mapping[op_type]
def _as_function(self, func_name, args):
template = """
@@ -90,6 +89,16 @@ class LogicalExpressionTransformer(converter.Base):
def visit_Compare(self, node):
node = self.generic_visit(node)
+
+ if not all(self._has_matching_func(op) for op in node.ops):
+ if len(node.ops) == 1:
+ # Basic expressions are safe to leave as they are.
+ return node
+ else:
+ raise NotImplementedError(
+ 'compound expression with at least one unsupported '
+ 'operator: {}'.format(node.ops))
+
ops_and_comps = list(zip(node.ops, node.comparators))
left = node.left
op_tree = None
diff --git a/tensorflow/python/autograph/converters/logical_expressions_test.py b/tensorflow/python/autograph/converters/logical_expressions_test.py
index 5fb3fb992f..b78b4d3a6a 100644
--- a/tensorflow/python/autograph/converters/logical_expressions_test.py
+++ b/tensorflow/python/autograph/converters/logical_expressions_test.py
@@ -47,14 +47,12 @@ class GradientsFunctionTest(converter_testing.TestCase):
with self.cached_session() as sess:
self.assertTrue(sess.run(result.test_fn(True, False, True)))
- def test_ag_utils_lookup(self):
+ def test_unsupported_ops(self):
def test_fn(a, b):
- return a is b or a is not b
+ return a in b
- with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or
- ) as result:
- with self.cached_session() as sess:
- self.assertTrue(sess.run(result.test_fn(True, False)))
+ with self.converted(test_fn, logical_expressions, {}) as result:
+ self.assertTrue(result.test_fn('a', ('a',)))
if __name__ == '__main__':
diff --git a/tensorflow/python/autograph/utils/__init__.py b/tensorflow/python/autograph/utils/__init__.py
index e38c82a079..c781958481 100644
--- a/tensorflow/python/autograph/utils/__init__.py
+++ b/tensorflow/python/autograph/utils/__init__.py
@@ -20,8 +20,6 @@ from __future__ import print_function
from tensorflow.python.autograph.utils.context_managers import control_dependency_on_returns
from tensorflow.python.autograph.utils.misc import alias_tensors
-from tensorflow.python.autograph.utils.multiple_dispatch import dynamic_is
-from tensorflow.python.autograph.utils.multiple_dispatch import dynamic_is_not
from tensorflow.python.autograph.utils.multiple_dispatch import run_cond
from tensorflow.python.autograph.utils.py_func import wrap_py_func
from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append
diff --git a/tensorflow/python/autograph/utils/multiple_dispatch.py b/tensorflow/python/autograph/utils/multiple_dispatch.py
index 33f521db2c..107c8f7a68 100644
--- a/tensorflow/python/autograph/utils/multiple_dispatch.py
+++ b/tensorflow/python/autograph/utils/multiple_dispatch.py
@@ -22,16 +22,6 @@ from tensorflow.python.autograph.utils.type_check import is_tensor
from tensorflow.python.ops import control_flow_ops
-def dynamic_is(left, right):
- # TODO(alexbw) if we're sure we should leave 'is' in place,
- # then change the semantics in converters/logical_expressions.py
- return left is right
-
-
-def dynamic_is_not(left, right):
- return left is not right
-
-
def run_cond(condition, true_fn, false_fn):
"""Type-dependent functional conditional.
diff --git a/tensorflow/python/autograph/utils/multiple_dispatch_test.py b/tensorflow/python/autograph/utils/multiple_dispatch_test.py
index ed20822529..2a77c895ce 100644
--- a/tensorflow/python/autograph/utils/multiple_dispatch_test.py
+++ b/tensorflow/python/autograph/utils/multiple_dispatch_test.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
from tensorflow.python.autograph.utils import multiple_dispatch
from tensorflow.python.client.session import Session
from tensorflow.python.framework.constant_op import constant
@@ -28,33 +26,6 @@ from tensorflow.python.platform import test
class MultipleDispatchTest(test.TestCase):
- def test_dynamic_is_python(self):
- a = np.eye(3)
- also_a = a
- not_actually_a = np.eye(3)
- should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
- should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
- should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
- should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
- self.assertTrue(should_be_true1)
- self.assertTrue(should_be_true2)
- self.assertFalse(should_be_false1)
- self.assertFalse(should_be_false2)
-
- def test_dynamic_is_tf(self):
- with Session().as_default():
- a = constant([2.0])
- also_a = a
- not_actually_a = constant([2.0])
- should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
- should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
- should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
- should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
- self.assertTrue(should_be_true1)
- self.assertTrue(should_be_true2)
- self.assertFalse(should_be_false1)
- self.assertFalse(should_be_false2)
-
def test_run_cond_python(self):
true_fn = lambda: (2,)
false_fn = lambda: (3,)