aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/autograph/converters/call_trees_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/autograph/converters/call_trees_test.py')
-rw-r--r--tensorflow/contrib/autograph/converters/call_trees_test.py76
1 files changed, 31 insertions, 45 deletions
diff --git a/tensorflow/contrib/autograph/converters/call_trees_test.py b/tensorflow/contrib/autograph/converters/call_trees_test.py
index 27d8281b85..8cdba659ee 100644
--- a/tensorflow/contrib/autograph/converters/call_trees_test.py
+++ b/tensorflow/contrib/autograph/converters/call_trees_test.py
@@ -36,37 +36,34 @@ class CallTreesTest(converter_testing.TestCase):
def test_fn_1(_):
raise ValueError('This should not be called in the compiled version.')
- def renamed_test_fn_1(a):
+ def other_test_fn_1(a):
return a + 1
def test_fn_2(a):
return test_fn_1(a) + 1
- node = self.parse_and_analyze(test_fn_2, {'test_fn_1': test_fn_1})
- node = call_trees.transform(node, self.ctx)
+ ns = {'test_fn_1': test_fn_1}
+ node, ctx = self.prepare(test_fn_2, ns)
+ node = call_trees.transform(node, ctx)
- with self.compiled(node) as result:
- # Only test_fn_2 is transformed, so we'll insert renamed_test_fn_1
- # manually.
- result.renamed_test_fn_1 = renamed_test_fn_1
- self.assertEquals(3, result.test_fn_2(1))
+ with self.compiled(node, ns) as result:
+ new_name, _ = ctx.namer.compiled_function_name(('test_fn_1',))
+ setattr(result, new_name, other_test_fn_1)
+ self.assertEquals(result.test_fn_2(1), 3)
def test_dynamic_function(self):
def test_fn_1():
- raise ValueError('This should be masked by the mock.')
+ raise ValueError('This should be masked by the mock in self.compiled.')
def test_fn_2(f):
return f() + 3
- node = self.parse_and_analyze(test_fn_2, {})
- node = call_trees.transform(node, self.ctx)
-
- with self.compiled(node) as result:
+ with self.converted(test_fn_2, call_trees, {}) as result:
# 10 = 7 (from the mock) + 3 (from test_fn_2)
self.assertEquals(10, result.test_fn_2(test_fn_1))
- def test_simple_methods(self):
+ def test_basic_method(self):
class TestClass(object):
@@ -76,49 +73,43 @@ class CallTreesTest(converter_testing.TestCase):
def test_fn_2(self, a):
return self.test_fn_1(a) + 1
- node = self.parse_and_analyze(
- TestClass.test_fn_2, {'TestClass': TestClass},
+ ns = {'TestClass': TestClass}
+ node, ctx = self.prepare(
+ TestClass.test_fn_2,
+ ns,
namer=converter_testing.FakeNoRenameNamer(),
arg_types={'self': (TestClass.__name__, TestClass)})
- node = call_trees.transform(node, self.ctx)
+ node = call_trees.transform(node, ctx)
- with self.compiled(node) as result:
+ with self.compiled(node, ns) as result:
tc = TestClass()
self.assertEquals(3, result.test_fn_2(tc, 1))
- def test_py_func_wrap_no_retval(self):
+ def test_py_func_no_retval(self):
def test_fn(a):
setattr(a, 'foo', 'bar')
- node = self.parse_and_analyze(test_fn, {'setattr': setattr})
- node = call_trees.transform(node, self.ctx)
-
- with self.compiled(node) as result:
+ with self.converted(test_fn, call_trees, {'setattr': setattr}) as result:
with self.test_session() as sess:
- # The function has no return value, so we do some tricks to grab the
- # generated py_func node and ensure its effect only happens at graph
- # execution.
class Dummy(object):
pass
a = Dummy()
result.test_fn(a)
+ py_func_op, = sess.graph.get_operations()
self.assertFalse(hasattr(a, 'foo'))
- sess.run(sess.graph.get_operations()[0])
+ sess.run(py_func_op)
self.assertEquals('bar', a.foo)
- def test_py_func_wrap_known_function(self):
+ def test_py_func_known_function(self):
def test_fn():
return np.random.binomial(2, 0.5)
- node = self.parse_and_analyze(test_fn, {'np': np})
- node = call_trees.transform(node, self.ctx)
-
- with self.compiled(node, dtypes.int64) as result:
- result.np = np
+ with self.converted(test_fn, call_trees, {'np': np},
+ dtypes.int64) as result:
with self.test_session() as sess:
self.assertTrue(isinstance(result.test_fn(), ops.Tensor))
self.assertIn(sess.run(result.test_fn()), (0, 1, 2))
@@ -130,22 +121,17 @@ class CallTreesTest(converter_testing.TestCase):
a = math_ops.add(a, constant_op.constant(1))
return a
- node = self.parse_and_analyze(
- test_fn, {
- 'math_ops': math_ops,
- 'constant_op': constant_op
- },
+ ns = {'math_ops': math_ops, 'constant_op': constant_op}
+ node, ctx = self.prepare(
+ test_fn,
+ ns,
arg_types=set(((math_ops.__name__,), (constant_op.__name__,))))
- node = call_trees.transform(node, self.ctx)
+ node = call_trees.transform(node, ctx)
- with self.compiled(node) as result:
- result.math_ops = math_ops
- result.constant_op = constant_op
+ with self.compiled(node, ns) as result:
with self.test_session() as sess:
- # Not renamed, because the converter doesn't rename the definition
- # itself (the caller is responsible for that).
result_tensor = result.test_fn(constant_op.constant(1))
- self.assertEquals(3, sess.run(result_tensor))
+ self.assertEquals(sess.run(result_tensor), 3)
if __name__ == '__main__':