diff options
author | 2018-08-02 11:41:41 -0700 | |
---|---|---|
committer | 2018-08-02 11:45:47 -0700 | |
commit | 1a13c4f2a0b4491ae3003ff0a400d5d8cb521c4a (patch) | |
tree | 7498e9a523ddc0f013a1aa270420b17e3427460e /tensorflow/contrib/autograph | |
parent | 61763fdd8e20bf1541dc12363d44318f81e06955 (diff) |
Force conversion of directly decorated functions. Ignore the whitelist in this case.
PiperOrigin-RevId: 207137374
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r-- | tensorflow/contrib/autograph/converters/call_trees.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/impl/api.py | 9 | ||||
-rw-r--r-- | tensorflow/contrib/autograph/impl/api_test.py | 21 |
3 files changed, 17 insertions, 15 deletions
diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py index a36b3d77a9..2d1bed3367 100644 --- a/tensorflow/contrib/autograph/converters/call_trees.py +++ b/tensorflow/contrib/autograph/converters/call_trees.py @@ -238,7 +238,7 @@ class CallTreeTransformer(converter.Base): # Before we could convert all the time though, we'd need a reasonable # caching mechanism. template = """ - ag__.converted_call(func, True, False, {}, args) + ag__.converted_call(func, True, False, False, {}, args) """ call_expr = templates.replace(template, func=node.func, args=node.args) new_call = call_expr[0].value diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py index 0adff76a9f..4729c735c6 100644 --- a/tensorflow/contrib/autograph/impl/api.py +++ b/tensorflow/contrib/autograph/impl/api.py @@ -68,7 +68,8 @@ def convert(recursive=False, verbose=False, arg_types=None): @wraps(f) def wrapper(*args, **kwargs): - return converted_call(f, recursive, verbose, arg_types, *args, **kwargs) + return converted_call(f, recursive, verbose, True, arg_types, *args, + **kwargs) wrapper = tf_decorator.make_decorator(f, wrapper) @@ -129,12 +130,12 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None): return decorator -def converted_call(f, recursive, verbose, arg_types, *args, **kwargs): +def converted_call(f, recursive, verbose, force_conversion, arg_types, *args, + **kwargs): """Compiles a function call inline.""" # TODO(mdan): This needs cleanup. # In particular, we may want to avoid renaming functions altogether. - - if conversion.is_whitelisted_for_graph(f): + if not force_conversion and conversion.is_whitelisted_for_graph(f): return f(*args, **kwargs) unknown_arg_value = object() # Sentinel for arguments of unknown value diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py index 754baa87b0..803fde9089 100644 --- a/tensorflow/contrib/autograph/impl/api_test.py +++ b/tensorflow/contrib/autograph/impl/api_test.py @@ -183,8 +183,8 @@ class ApiTest(test.TestCase): @api.convert(recursive=True) def test_method(self, x, s, a): while tf.reduce_sum(x) > s: - x //= api.converted_call(self.called_member, False, False, {}, self, - a) + x //= api.converted_call(self.called_member, False, False, False, {}, + self, a) return x tc = TestClass() @@ -195,7 +195,7 @@ class ApiTest(test.TestCase): self.assertListEqual([0, 1], sess.run(x).tolist()) def test_converted_call_builtin(self): - x = api.converted_call(range, False, False, {}, 3) + x = api.converted_call(range, False, False, False, {}, 3) self.assertEqual((0, 1, 2), tuple(x)) def test_converted_call_function(self): @@ -206,7 +206,7 @@ class ApiTest(test.TestCase): return x with self.test_session() as sess: - x = api.converted_call(test_fn, False, False, {}, + x = api.converted_call(test_fn, False, False, False, {}, constant_op.constant(-1)) self.assertEqual(1, sess.run(x)) @@ -224,7 +224,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc.test_method, False, False, {}, tc) + x = api.converted_call(tc.test_method, False, False, False, {}, tc) self.assertEqual(1, sess.run(x)) def test_converted_call_method_by_class(self): @@ -241,7 +241,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(TestClass.test_method, False, False, {}, tc) + x = api.converted_call(TestClass.test_method, False, False, False, {}, tc) self.assertEqual(1, sess.run(x)) def test_converted_call_callable_object(self): @@ -258,7 +258,7 @@ class ApiTest(test.TestCase): with self.test_session() as sess: tc = TestClass(constant_op.constant(-1)) - x = api.converted_call(tc, False, False, {}) + x = api.converted_call(tc, False, False, False, {}) self.assertEqual(1, sess.run(x)) def test_converted_call_constructor(self): @@ -274,7 +274,7 @@ class ApiTest(test.TestCase): return self.x with self.test_session() as sess: - tc = api.converted_call(TestClass, False, False, {}, + tc = api.converted_call(TestClass, False, False, False, {}, constant_op.constant(-1)) # tc is now a converted object. x = tc.test_method() @@ -286,11 +286,12 @@ class ApiTest(test.TestCase): return x == 0 with self.test_session() as sess: - x = api.converted_call(f, False, False, {}, constant_op.constant(0)) + x = api.converted_call(f, False, False, False, {}, + constant_op.constant(0)) self.assertTrue(sess.run(x)) converted_f = api.to_graph(f) - x = api.converted_call(converted_f, False, False, {}, + x = api.converted_call(converted_f, False, False, False, {}, constant_op.constant(0)) self.assertTrue(sess.run(x)) |