aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/autograph/impl/api_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/autograph/impl/api_test.py')
-rw-r--r--tensorflow/python/autograph/impl/api_test.py329
1 files changed, 329 insertions, 0 deletions
diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py
new file mode 100644
index 0000000000..54e12f0223
--- /dev/null
+++ b/tensorflow/python/autograph/impl/api_test.py
@@ -0,0 +1,329 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for api module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.autograph import utils
+from tensorflow.python.autograph.core import config
+from tensorflow.python.autograph.impl import api
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.autograph.utils import py_func
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+from tensorflow.python.util import tf_inspect
+
+
+tf = utils.fake_tf()
+
+
+class ApiTest(test.TestCase):
+
+ def setUp(self):
+ config.COMPILED_IMPORT_STATEMENTS = (
+ 'from __future__ import print_function',
+ )
+
+ def test_decorator_recurses(self):
+
+ class TestClass(object):
+
+ def called_member(self, a):
+ if a < 0:
+ a = -a
+ return a
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while tf.reduce_sum(x) > s:
+ x //= self.called_member(a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_decorator_does_not_recurse(self):
+
+ class TestClass(object):
+
+ def called_member(self, a):
+ return tf.negative(a)
+
+ @api.convert(recursive=False)
+ def test_method(self, x, s, a):
+ while tf.reduce_sum(x) > s:
+ x //= self.called_member(a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_decorator_calls_unconverted_graph(self):
+
+ class TestClass(object):
+
+ @api.do_not_convert(api.RunMode.GRAPH)
+ def called_member(self, a):
+ return tf.negative(a)
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while tf.reduce_sum(x) > s:
+ x //= self.called_member(a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_decorator_calls_unconverted_py_func(self):
+
+ class TestClass(object):
+
+ @api.do_not_convert(
+ api.RunMode.PY_FUNC, return_dtypes=py_func.MatchDType(1))
+ def called_member(self, a):
+ return np.negative(a)
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while tf.reduce_sum(x) > s:
+ y = self.called_member(a)
+ # set_shape works around while_loop's limitations.
+ # TODO(mdan): Allow specifying shapes (or ShapeLike) instead.
+ y.set_shape(a.shape)
+ x //= y
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_decorator_calls_decorated(self):
+
+ class TestClass(object):
+
+ @api.convert()
+ def called_member(self, a):
+ if a < 0:
+ a = -a
+ return a
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while tf.reduce_sum(x) > s:
+ x //= self.called_member(a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_decorator_preserves_argspec(self):
+
+ class TestClass(object):
+
+ def called_member(self, a):
+ if a < 0:
+ a = -a
+ return a
+
+ called_member_converted = api.convert()(called_member)
+
+ tc = TestClass()
+ self.assertListEqual(
+ list(tf_inspect.getfullargspec(tc.called_member)),
+ list(tf_inspect.getfullargspec(tc.called_member_converted)))
+
+ def test_convert_call_site_decorator(self):
+
+ class TestClass(object):
+
+ def called_member(self, a):
+ if a < 0:
+ a = -a
+ return a
+
+ @api.convert(recursive=True)
+ def test_method(self, x, s, a):
+ while tf.reduce_sum(x) > s:
+ x //= api.converted_call(self.called_member, False, False, False, {},
+ self, a)
+ return x
+
+ tc = TestClass()
+ with self.test_session() as sess:
+ x = tc.test_method(
+ constant_op.constant([2, 4]), constant_op.constant(1),
+ constant_op.constant(-2))
+ self.assertListEqual([0, 1], sess.run(x).tolist())
+
+ def test_converted_call_builtin(self):
+ x = api.converted_call(range, False, False, False, {}, 3)
+ self.assertEqual((0, 1, 2), tuple(x))
+
+ def test_converted_call_function(self):
+
+ def test_fn(x):
+ if x < 0:
+ return -x
+ return x
+
+ with self.test_session() as sess:
+ x = api.converted_call(test_fn, False, False, False, {},
+ constant_op.constant(-1))
+ self.assertEqual(1, sess.run(x))
+
+ def test_converted_call_method(self):
+
+ class TestClass(object):
+
+ def __init__(self, x):
+ self.x = x
+
+ def test_method(self):
+ if self.x < 0:
+ return -self.x
+ return self.x
+
+ with self.test_session() as sess:
+ tc = TestClass(constant_op.constant(-1))
+ x = api.converted_call(tc.test_method, False, False, False, {}, tc)
+ self.assertEqual(1, sess.run(x))
+
+ def test_converted_call_method_by_class(self):
+
+ class TestClass(object):
+
+ def __init__(self, x):
+ self.x = x
+
+ def test_method(self):
+ if self.x < 0:
+ return -self.x
+ return self.x
+
+ with self.test_session() as sess:
+ tc = TestClass(constant_op.constant(-1))
+ x = api.converted_call(TestClass.test_method, False, False, False, {}, tc)
+ self.assertEqual(1, sess.run(x))
+
+ def test_converted_call_callable_object(self):
+
+ class TestClass(object):
+
+ def __init__(self, x):
+ self.x = x
+
+ def __call__(self):
+ if self.x < 0:
+ return -self.x
+ return self.x
+
+ with self.test_session() as sess:
+ tc = TestClass(constant_op.constant(-1))
+ x = api.converted_call(tc, False, False, False, {})
+ self.assertEqual(1, sess.run(x))
+
+ def test_converted_call_constructor(self):
+
+ class TestClass(object):
+
+ def __init__(self, x):
+ self.x = x
+
+ def test_method(self):
+ if self.x < 0:
+ return -self.x
+ return self.x
+
+ with self.test_session() as sess:
+ tc = api.converted_call(TestClass, False, False, False, {},
+ constant_op.constant(-1))
+ # tc is now a converted object.
+ x = tc.test_method()
+ self.assertEqual(1, sess.run(x))
+
+ def test_converted_call_already_converted(self):
+
+ def f(x):
+ return x == 0
+
+ with self.test_session() as sess:
+ x = api.converted_call(f, False, False, False, {},
+ constant_op.constant(0))
+ self.assertTrue(sess.run(x))
+
+ converted_f = api.to_graph(f)
+ x = api.converted_call(converted_f, False, False, False, {},
+ constant_op.constant(0))
+ self.assertTrue(sess.run(x))
+
+ def test_to_graph_basic(self):
+
+ def test_fn(x, s):
+ while tf.reduce_sum(x) > s:
+ x //= 2
+ return x
+
+ compiled_fn = api.to_graph(test_fn)
+
+ with self.test_session() as sess:
+ x = compiled_fn(constant_op.constant([4, 8]), 4)
+ self.assertListEqual([1, 2], sess.run(x).tolist())
+
+ def test_to_code_basic(self):
+
+ def test_fn(x, s):
+ while tf.reduce_sum(x) > s:
+ x /= 2
+ return x
+
+ compiled_code = api.to_code(test_fn)
+
+ # Just check that it is parseable Python code.
+ self.assertIsNotNone(parser.parse_str(compiled_code))
+
+ def test_source_map_attribute_present(self):
+
+ def test_fn(y):
+ return y**2
+
+ self.assertTrue(hasattr(api.to_graph(test_fn), 'ag_source_map'))
+
+
+if __name__ == '__main__':
+ test.main()