aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph
diff options
context:
space:
mode:
authorGravatar Dan Moldovan <mdan@google.com>2018-09-13 07:26:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-13 07:30:24 -0700
commit226cc7c47e2df8682b384aef5c54836948caecb3 (patch)
tree402869c67b0daad7cd959ba8712d7f7c936f9ab5 /tensorflow/python/autograph
parent46aa7cf45c62d193f56f55d7d2ffc5baf7af3b65 (diff)
Allow unsupported comparison operators to be passed through and scale back the coverage of overloads.
It's up for discussion whether we allow overloading everything or let the users rely on the existing operator overloading mechanisms instead. The one case that we do want to support is the equality operator. PiperOrigin-RevId: 212809447
Diffstat (limited to 'tensorflow/python/autograph')
-rw-r--r--tensorflow/python/autograph/converters/logical_expressions.py21
-rw-r--r--tensorflow/python/autograph/converters/logical_expressions_test.py10
-rw-r--r--tensorflow/python/autograph/utils/__init__.py2
-rw-r--r--tensorflow/python/autograph/utils/multiple_dispatch.py10
-rw-r--r--tensorflow/python/autograph/utils/multiple_dispatch_test.py29
5 files changed, 19 insertions, 53 deletions
diff --git a/tensorflow/python/autograph/converters/logical_expressions.py b/tensorflow/python/autograph/converters/logical_expressions.py
index ac42ee2c33..8c4d53f9a8 100644
--- a/tensorflow/python/autograph/converters/logical_expressions.py
+++ b/tensorflow/python/autograph/converters/logical_expressions.py
@@ -57,8 +57,6 @@ class LogicalExpressionTransformer(converter.Base):
gast.NotEq: 'tf.not_equal',
gast.Or: 'tf.logical_or',
gast.USub: 'tf.negative',
- gast.Is: 'ag__.utils.dynamic_is',
- gast.IsNot: 'ag__.utils.dynamic_is_not'
}
def _expect_simple_symbol(self, operand):
@@ -72,12 +70,13 @@ class LogicalExpressionTransformer(converter.Base):
'"a.x or b"; for a workaround, assign the expression to a local '
'variable and use that instead, for example "tmp = a.x", "tmp or b"')
+ def _has_matching_func(self, operator):
+ op_type = type(operator)
+ return op_type in self.op_mapping
+
def _matching_func(self, operator):
op_type = type(operator)
- mapped_op = self.op_mapping.get(op_type)
- if not mapped_op:
- raise NotImplementedError('operator %s is not yet supported' % op_type)
- return mapped_op
+ return self.op_mapping[op_type]
def _as_function(self, func_name, args):
template = """
@@ -90,6 +89,16 @@ class LogicalExpressionTransformer(converter.Base):
def visit_Compare(self, node):
node = self.generic_visit(node)
+
+ if not all(self._has_matching_func(op) for op in node.ops):
+ if len(node.ops) == 1:
+ # Basic expressions are safe to leave as they are.
+ return node
+ else:
+ raise NotImplementedError(
+ 'compound expression with at least one unsupported '
+ 'operator: {}'.format(node.ops))
+
ops_and_comps = list(zip(node.ops, node.comparators))
left = node.left
op_tree = None
diff --git a/tensorflow/python/autograph/converters/logical_expressions_test.py b/tensorflow/python/autograph/converters/logical_expressions_test.py
index 5fb3fb992f..b78b4d3a6a 100644
--- a/tensorflow/python/autograph/converters/logical_expressions_test.py
+++ b/tensorflow/python/autograph/converters/logical_expressions_test.py
@@ -47,14 +47,12 @@ class GradientsFunctionTest(converter_testing.TestCase):
with self.cached_session() as sess:
self.assertTrue(sess.run(result.test_fn(True, False, True)))
- def test_ag_utils_lookup(self):
+ def test_unsupported_ops(self):
def test_fn(a, b):
- return a is b or a is not b
+ return a in b
- with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or
- ) as result:
- with self.cached_session() as sess:
- self.assertTrue(sess.run(result.test_fn(True, False)))
+ with self.converted(test_fn, logical_expressions, {}) as result:
+ self.assertTrue(result.test_fn('a', ('a',)))
if __name__ == '__main__':
diff --git a/tensorflow/python/autograph/utils/__init__.py b/tensorflow/python/autograph/utils/__init__.py
index e38c82a079..c781958481 100644
--- a/tensorflow/python/autograph/utils/__init__.py
+++ b/tensorflow/python/autograph/utils/__init__.py
@@ -20,8 +20,6 @@ from __future__ import print_function
from tensorflow.python.autograph.utils.context_managers import control_dependency_on_returns
from tensorflow.python.autograph.utils.misc import alias_tensors
-from tensorflow.python.autograph.utils.multiple_dispatch import dynamic_is
-from tensorflow.python.autograph.utils.multiple_dispatch import dynamic_is_not
from tensorflow.python.autograph.utils.multiple_dispatch import run_cond
from tensorflow.python.autograph.utils.py_func import wrap_py_func
from tensorflow.python.autograph.utils.tensor_list import dynamic_list_append
diff --git a/tensorflow/python/autograph/utils/multiple_dispatch.py b/tensorflow/python/autograph/utils/multiple_dispatch.py
index 33f521db2c..107c8f7a68 100644
--- a/tensorflow/python/autograph/utils/multiple_dispatch.py
+++ b/tensorflow/python/autograph/utils/multiple_dispatch.py
@@ -22,16 +22,6 @@ from tensorflow.python.autograph.utils.type_check import is_tensor
from tensorflow.python.ops import control_flow_ops
-def dynamic_is(left, right):
- # TODO(alexbw) if we're sure we should leave 'is' in place,
- # then change the semantics in converters/logical_expressions.py
- return left is right
-
-
-def dynamic_is_not(left, right):
- return left is not right
-
-
def run_cond(condition, true_fn, false_fn):
"""Type-dependent functional conditional.
diff --git a/tensorflow/python/autograph/utils/multiple_dispatch_test.py b/tensorflow/python/autograph/utils/multiple_dispatch_test.py
index ed20822529..2a77c895ce 100644
--- a/tensorflow/python/autograph/utils/multiple_dispatch_test.py
+++ b/tensorflow/python/autograph/utils/multiple_dispatch_test.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
from tensorflow.python.autograph.utils import multiple_dispatch
from tensorflow.python.client.session import Session
from tensorflow.python.framework.constant_op import constant
@@ -28,33 +26,6 @@ from tensorflow.python.platform import test
class MultipleDispatchTest(test.TestCase):
- def test_dynamic_is_python(self):
- a = np.eye(3)
- also_a = a
- not_actually_a = np.eye(3)
- should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
- should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
- should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
- should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
- self.assertTrue(should_be_true1)
- self.assertTrue(should_be_true2)
- self.assertFalse(should_be_false1)
- self.assertFalse(should_be_false2)
-
- def test_dynamic_is_tf(self):
- with Session().as_default():
- a = constant([2.0])
- also_a = a
- not_actually_a = constant([2.0])
- should_be_true1 = multiple_dispatch.dynamic_is(a, also_a)
- should_be_false1 = multiple_dispatch.dynamic_is_not(a, also_a)
- should_be_true2 = multiple_dispatch.dynamic_is_not(a, not_actually_a)
- should_be_false2 = multiple_dispatch.dynamic_is(a, not_actually_a)
- self.assertTrue(should_be_true1)
- self.assertTrue(should_be_true2)
- self.assertFalse(should_be_false1)
- self.assertFalse(should_be_false2)
-
def test_run_cond_python(self):
true_fn = lambda: (2,)
false_fn = lambda: (3,)