diff options
Diffstat (limited to 'tensorflow/python/autograph/converters/logical_expressions.py')
-rw-r--r-- | tensorflow/python/autograph/converters/logical_expressions.py | 21 |
1 files changed, 15 insertions, 6 deletions
diff --git a/tensorflow/python/autograph/converters/logical_expressions.py b/tensorflow/python/autograph/converters/logical_expressions.py index ac42ee2c33..8c4d53f9a8 100644 --- a/tensorflow/python/autograph/converters/logical_expressions.py +++ b/tensorflow/python/autograph/converters/logical_expressions.py @@ -57,8 +57,6 @@ class LogicalExpressionTransformer(converter.Base): gast.NotEq: 'tf.not_equal', gast.Or: 'tf.logical_or', gast.USub: 'tf.negative', - gast.Is: 'ag__.utils.dynamic_is', - gast.IsNot: 'ag__.utils.dynamic_is_not' } def _expect_simple_symbol(self, operand): @@ -72,12 +70,13 @@ class LogicalExpressionTransformer(converter.Base): '"a.x or b"; for a workaround, assign the expression to a local ' 'variable and use that instead, for example "tmp = a.x", "tmp or b"') + def _has_matching_func(self, operator): + op_type = type(operator) + return op_type in self.op_mapping + def _matching_func(self, operator): op_type = type(operator) - mapped_op = self.op_mapping.get(op_type) - if not mapped_op: - raise NotImplementedError('operator %s is not yet supported' % op_type) - return mapped_op + return self.op_mapping[op_type] def _as_function(self, func_name, args): template = """ @@ -90,6 +89,16 @@ class LogicalExpressionTransformer(converter.Base): def visit_Compare(self, node): node = self.generic_visit(node) + + if not all(self._has_matching_func(op) for op in node.ops): + if len(node.ops) == 1: + # Basic expressions are safe to leave as they are. + return node + else: + raise NotImplementedError( + 'compound expression with at least one unsupported ' + 'operator: {}'.format(node.ops)) + ops_and_comps = list(zip(node.ops, node.comparators)) left = node.left op_tree = None |