aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/converters/logical_expressions.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/converters/logical_expressions.py')
-rw-r--r--tensorflow/python/autograph/converters/logical_expressions.py21
1 files changed, 15 insertions, 6 deletions
diff --git a/tensorflow/python/autograph/converters/logical_expressions.py b/tensorflow/python/autograph/converters/logical_expressions.py
index ac42ee2c33..8c4d53f9a8 100644
--- a/tensorflow/python/autograph/converters/logical_expressions.py
+++ b/tensorflow/python/autograph/converters/logical_expressions.py
@@ -57,8 +57,6 @@ class LogicalExpressionTransformer(converter.Base):
gast.NotEq: 'tf.not_equal',
gast.Or: 'tf.logical_or',
gast.USub: 'tf.negative',
- gast.Is: 'ag__.utils.dynamic_is',
- gast.IsNot: 'ag__.utils.dynamic_is_not'
}
def _expect_simple_symbol(self, operand):
@@ -72,12 +70,13 @@ class LogicalExpressionTransformer(converter.Base):
'"a.x or b"; for a workaround, assign the expression to a local '
'variable and use that instead, for example "tmp = a.x", "tmp or b"')
+ def _has_matching_func(self, operator):
+ op_type = type(operator)
+ return op_type in self.op_mapping
+
def _matching_func(self, operator):
op_type = type(operator)
- mapped_op = self.op_mapping.get(op_type)
- if not mapped_op:
- raise NotImplementedError('operator %s is not yet supported' % op_type)
- return mapped_op
+ return self.op_mapping[op_type]
def _as_function(self, func_name, args):
template = """
@@ -90,6 +89,16 @@ class LogicalExpressionTransformer(converter.Base):
def visit_Compare(self, node):
node = self.generic_visit(node)
+
+ if not all(self._has_matching_func(op) for op in node.ops):
+ if len(node.ops) == 1:
+ # Basic expressions are safe to leave as they are.
+ return node
+ else:
+ raise NotImplementedError(
+ 'compound expression with at least one unsupported '
+ 'operator: {}'.format(node.ops))
+
ops_and_comps = list(zip(node.ops, node.comparators))
left = node.left
op_tree = None