diff options
3 files changed, 11 insertions, 5 deletions
diff --git a/tensorflow/contrib/autograph/converters/logical_expressions.py b/tensorflow/contrib/autograph/converters/logical_expressions.py index 16eb1f0e3f..41c3424fa3 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions.py @@ -57,8 +57,8 @@ class LogicalExpressionTransformer(converter.Base): gast.NotEq: 'tf.not_equal', gast.Or: 'tf.logical_or', gast.USub: 'tf.negative', - gast.Is: 'autograph_utils.dynamic_is', - gast.IsNot: 'autograph_utils.dynamic_is_not' + gast.Is: 'ag__.utils.dynamic_is', + gast.IsNot: 'ag__.utils.dynamic_is_not' } def _expect_simple_symbol(self, operand): diff --git a/tensorflow/contrib/autograph/converters/logical_expressions_test.py b/tensorflow/contrib/autograph/converters/logical_expressions_test.py index 8f9eee7081..409a73afba 100644 --- a/tensorflow/contrib/autograph/converters/logical_expressions_test.py +++ b/tensorflow/contrib/autograph/converters/logical_expressions_test.py @@ -47,6 +47,15 @@ class GradientsFunctionTest(converter_testing.TestCase): with self.cached_session() as sess: self.assertTrue(sess.run(result.test_fn(True, False, True))) + def test_ag_utils_lookup(self): + def test_fn(a, b): + return a is b or a is not b + + with self.converted(test_fn, logical_expressions, {}, math_ops.logical_or + ) as result: + with self.cached_session() as sess: + self.assertTrue(sess.run(result.test_fn(True, False))) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py index 803fde9089..a4c6fed265 100644 --- a/tensorflow/contrib/autograph/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -38,9 +38,6 @@ class ApiTest(test.TestCase): def setUp(self): config.COMPILED_IMPORT_STATEMENTS = ( 'from __future__ import print_function', - 'from tensorflow.contrib.autograph import utils' - ' as autograph_utils', - 'tf = autograph_utils.fake_tf()', ) def test_decorator_recurses(self): |