aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-02 11:41:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-02 11:45:47 -0700
commit1a13c4f2a0b4491ae3003ff0a400d5d8cb521c4a (patch)
tree7498e9a523ddc0f013a1aa270420b17e3427460e /tensorflow/contrib/autograph
parent61763fdd8e20bf1541dc12363d44318f81e06955 (diff)
Force conversion of directly decorated functions. Ignore the whitelist in this case.
PiperOrigin-RevId: 207137374
Diffstat (limited to 'tensorflow/contrib/autograph')
-rw-r--r--tensorflow/contrib/autograph/converters/call_trees.py2
-rw-r--r--tensorflow/contrib/autograph/impl/api.py9
-rw-r--r--tensorflow/contrib/autograph/impl/api_test.py21
3 files changed, 17 insertions, 15 deletions
diff --git a/tensorflow/contrib/autograph/converters/call_trees.py b/tensorflow/contrib/autograph/converters/call_trees.py
index a36b3d77a9..2d1bed3367 100644
--- a/tensorflow/contrib/autograph/converters/call_trees.py
+++ b/tensorflow/contrib/autograph/converters/call_trees.py
@@ -238,7 +238,7 @@ class CallTreeTransformer(converter.Base):
# Before we could convert all the time though, we'd need a reasonable
# caching mechanism.
template = """
- ag__.converted_call(func, True, False, {}, args)
+ ag__.converted_call(func, True, False, False, {}, args)
"""
call_expr = templates.replace(template, func=node.func, args=node.args)
new_call = call_expr[0].value
diff --git a/tensorflow/contrib/autograph/impl/api.py b/tensorflow/contrib/autograph/impl/api.py
index 0adff76a9f..4729c735c6 100644
--- a/tensorflow/contrib/autograph/impl/api.py
+++ b/tensorflow/contrib/autograph/impl/api.py
@@ -68,7 +68,8 @@ def convert(recursive=False, verbose=False, arg_types=None):
@wraps(f)
def wrapper(*args, **kwargs):
- return converted_call(f, recursive, verbose, arg_types, *args, **kwargs)
+ return converted_call(f, recursive, verbose, True, arg_types, *args,
+ **kwargs)
wrapper = tf_decorator.make_decorator(f, wrapper)
@@ -129,12 +130,12 @@ def do_not_convert(run_as=RunMode.GRAPH, return_dtypes=None):
return decorator
-def converted_call(f, recursive, verbose, arg_types, *args, **kwargs):
+def converted_call(f, recursive, verbose, force_conversion, arg_types, *args,
+ **kwargs):
"""Compiles a function call inline."""
# TODO(mdan): This needs cleanup.
# In particular, we may want to avoid renaming functions altogether.
-
- if conversion.is_whitelisted_for_graph(f):
+ if not force_conversion and conversion.is_whitelisted_for_graph(f):
return f(*args, **kwargs)
unknown_arg_value = object() # Sentinel for arguments of unknown value
diff --git a/tensorflow/contrib/autograph/impl/api_test.py b/tensorflow/contrib/autograph/impl/api_test.py
index 754baa87b0..803fde9089 100644
--- a/tensorflow/contrib/autograph/impl/api_test.py
+++ b/tensorflow/contrib/autograph/impl/api_test.py
@@ -183,8 +183,8 @@ class ApiTest(test.TestCase):
@api.convert(recursive=True)
def test_method(self, x, s, a):
while tf.reduce_sum(x) > s:
- x //= api.converted_call(self.called_member, False, False, {}, self,
- a)
+ x //= api.converted_call(self.called_member, False, False, False, {},
+ self, a)
return x
tc = TestClass()
@@ -195,7 +195,7 @@ class ApiTest(test.TestCase):
self.assertListEqual([0, 1], sess.run(x).tolist())
def test_converted_call_builtin(self):
- x = api.converted_call(range, False, False, {}, 3)
+ x = api.converted_call(range, False, False, False, {}, 3)
self.assertEqual((0, 1, 2), tuple(x))
def test_converted_call_function(self):
@@ -206,7 +206,7 @@ class ApiTest(test.TestCase):
return x
with self.test_session() as sess:
- x = api.converted_call(test_fn, False, False, {},
+ x = api.converted_call(test_fn, False, False, False, {},
constant_op.constant(-1))
self.assertEqual(1, sess.run(x))
@@ -224,7 +224,7 @@ class ApiTest(test.TestCase):
with self.test_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(tc.test_method, False, False, {}, tc)
+ x = api.converted_call(tc.test_method, False, False, False, {}, tc)
self.assertEqual(1, sess.run(x))
def test_converted_call_method_by_class(self):
@@ -241,7 +241,7 @@ class ApiTest(test.TestCase):
with self.test_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(TestClass.test_method, False, False, {}, tc)
+ x = api.converted_call(TestClass.test_method, False, False, False, {}, tc)
self.assertEqual(1, sess.run(x))
def test_converted_call_callable_object(self):
@@ -258,7 +258,7 @@ class ApiTest(test.TestCase):
with self.test_session() as sess:
tc = TestClass(constant_op.constant(-1))
- x = api.converted_call(tc, False, False, {})
+ x = api.converted_call(tc, False, False, False, {})
self.assertEqual(1, sess.run(x))
def test_converted_call_constructor(self):
@@ -274,7 +274,7 @@ class ApiTest(test.TestCase):
return self.x
with self.test_session() as sess:
- tc = api.converted_call(TestClass, False, False, {},
+ tc = api.converted_call(TestClass, False, False, False, {},
constant_op.constant(-1))
# tc is now a converted object.
x = tc.test_method()
@@ -286,11 +286,12 @@ class ApiTest(test.TestCase):
return x == 0
with self.test_session() as sess:
- x = api.converted_call(f, False, False, {}, constant_op.constant(0))
+ x = api.converted_call(f, False, False, False, {},
+ constant_op.constant(0))
self.assertTrue(sess.run(x))
converted_f = api.to_graph(f)
- x = api.converted_call(converted_f, False, False, {},
+ x = api.converted_call(converted_f, False, False, False, {},
constant_op.constant(0))
self.assertTrue(sess.run(x))