diff options
Diffstat (limited to 'tensorflow/python/autograph/converters/call_trees_test.py')
-rw-r--r-- | tensorflow/python/autograph/converters/call_trees_test.py | 138 |
1 files changed, 138 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/converters/call_trees_test.py b/tensorflow/python/autograph/converters/call_trees_test.py new file mode 100644 index 0000000000..0e50f42c6a --- /dev/null +++ b/tensorflow/python/autograph/converters/call_trees_test.py @@ -0,0 +1,138 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for call_trees module.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.autograph.converters import call_trees +from tensorflow.python.autograph.core import converter_testing +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class CallTreesTest(converter_testing.TestCase): + + def test_basic(self): + + def test_fn_1(_): + raise ValueError('This should not be called in the compiled version.') + + def other_test_fn_1(a): + return a + 1 + + def test_fn_2(a): + return test_fn_1(a) + 1 + + ns = {'test_fn_1': test_fn_1} + node, ctx = self.prepare(test_fn_2, ns) + node = call_trees.transform(node, ctx) + + with self.compiled(node, ns) as result: + new_name, _ = ctx.namer.compiled_function_name(('test_fn_1',)) + setattr(result, new_name, other_test_fn_1) + self.assertEquals(result.test_fn_2(1), 3) + + def test_dynamic_function(self): + + def test_fn_1(): + raise ValueError('This should be masked by the mock in self.compiled.') + + def test_fn_2(f): + return f() + 3 + + with self.converted(test_fn_2, call_trees, {}) as result: + # 10 = 7 (from the mock) + 3 (from test_fn_2) + self.assertEquals(10, result.test_fn_2(test_fn_1)) + + def test_basic_method(self): + + class TestClass(object): + + def test_fn_1(self, a): + return a + 1 + + def test_fn_2(self, a): + return self.test_fn_1(a) + 1 + + ns = {'TestClass': TestClass} + node, ctx = self.prepare( + TestClass.test_fn_2, + ns, + namer=converter_testing.FakeNoRenameNamer(), + arg_types={'self': (TestClass.__name__, TestClass)}) + node = call_trees.transform(node, ctx) + + with self.compiled(node, ns) as result: + tc = TestClass() + self.assertEquals(3, result.test_fn_2(tc, 1)) + + def test_py_func_no_retval(self): + + def test_fn(a): + setattr(a, 'foo', 'bar') + + with self.converted(test_fn, call_trees, {'setattr': setattr}) as result: + with self.cached_session() as sess: + + class Dummy(object): + pass + + a = Dummy() + result.test_fn(a) + py_func_op, = sess.graph.get_operations() + self.assertFalse(hasattr(a, 'foo')) + sess.run(py_func_op) + self.assertEquals('bar', a.foo) + + def test_py_func_known_function(self): + + def test_fn(): + return np.random.binomial(2, 0.5) + + with self.converted(test_fn, call_trees, {'np': np}, + dtypes.int64) as result: + with self.cached_session() as sess: + self.assertTrue(isinstance(result.test_fn(), ops.Tensor)) + self.assertIn(sess.run(result.test_fn()), (0, 1, 2)) + + def test_uncompiled_modules(self): + + def test_fn(a): + a = math_ops.multiply(a, constant_op.constant(2)) + a = math_ops.add(a, constant_op.constant(1)) + return a + + ns = {'math_ops': math_ops, 'constant_op': constant_op} + node, ctx = self.prepare( + test_fn, + ns, + arg_types=set(((math_ops.__name__,), (constant_op.__name__,)))) + node = call_trees.transform(node, ctx) + + with self.compiled(node, ns) as result: + with self.cached_session() as sess: + result_tensor = result.test_fn(constant_op.constant(1)) + self.assertEquals(sess.run(result_tensor), 3) + + +if __name__ == '__main__': + test.main() |