aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/examples')
-rw-r--r--tensorflow/examples/autograph/integration_tests/BUILD53
-rw-r--r--tensorflow/examples/autograph/integration_tests/errors_test.py152
-rw-r--r--tensorflow/examples/autograph/integration_tests/keras_test.py103
-rw-r--r--tensorflow/examples/autograph/integration_tests/list_literals_test.py41
4 files changed, 349 insertions, 0 deletions
diff --git a/tensorflow/examples/autograph/integration_tests/BUILD b/tensorflow/examples/autograph/integration_tests/BUILD
new file mode 100644
index 0000000000..3630b41fc8
--- /dev/null
+++ b/tensorflow/examples/autograph/integration_tests/BUILD
@@ -0,0 +1,53 @@
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_test(
+ name = "errors_test",
+ srcs = [
+ "errors_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "keras_test",
+ srcs = [
+ "keras_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "list_literals_test",
+ srcs = [
+ "list_literals_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ tags = ["no_windows"],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/examples/autograph/integration_tests/errors_test.py b/tensorflow/examples/autograph/integration_tests/errors_test.py
new file mode 100644
index 0000000000..69e5936832
--- /dev/null
+++ b/tensorflow/examples/autograph/integration_tests/errors_test.py
@@ -0,0 +1,152 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Error traceback rewriting integration tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python import autograph as ag
+
+
+class ErrorsTest(tf.test.TestCase):
+
+ def test_graph_construction_error_rewriting_call_tree(self):
+
+ def test_fn():
+ return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)
+
+ def inner_caller():
+ return test_fn()
+
+ def caller():
+ return inner_caller()
+
+ with self.assertRaises(ag.GraphConstructionError) as error:
+ graph = ag.to_graph(caller)
+ graph()
+ expected = error.exception
+ custom_traceback = expected.custom_traceback
+ found_correct_filename = False
+ num_test_fn_names = 0
+ num_inner_caller_names = 0
+ num_caller_names = 0
+ for frame in custom_traceback:
+ filename, _, fn_name, _ = frame
+ self.assertFalse('/tmp/' in filename)
+ found_correct_filename |= __file__ in filename
+ self.assertNotEqual('tf__test_fn', fn_name)
+ num_test_fn_names += int('test_fn' == fn_name)
+ self.assertNotEqual('tf__inner_caller', fn_name)
+ num_inner_caller_names += int('inner_caller' == fn_name)
+ self.assertNotEqual('tf__caller', fn_name)
+ num_caller_names += int('caller' == fn_name)
+ self.assertTrue(found_correct_filename)
+ self.assertEqual(num_test_fn_names, 1)
+ self.assertEqual(num_inner_caller_names, 1)
+ self.assertEqual(num_caller_names, 1)
+
+ def test_graph_construction_error_rewriting_class(self):
+
+ class TestClass(object):
+
+ def test_fn(self):
+ return tf.random_normal((2, 3), mean=0.0, dtype=tf.int32)
+
+ def inner_caller(self):
+ return self.test_fn()
+
+ def caller(self):
+ return self.inner_caller()
+
+ # Note we expect a TypeError here because the traceback will not be
+ # rewritten for classes.
+ with self.assertRaises(TypeError):
+ graph = ag.to_graph(TestClass)
+ graph().caller()
+
+ def test_runtime_error_rewriting(self):
+
+ def g(x, s):
+ while tf.reduce_sum(x) > s:
+ x //= 0
+ return x
+
+ def test_fn(x):
+ return g(x, 10)
+
+ compiled_fn = ag.to_graph(test_fn)
+
+ with self.assertRaises(ag.TfRuntimeError) as error:
+ with self.test_session() as sess:
+ x = compiled_fn(tf.constant([4, 8]))
+ with ag.improved_errors(compiled_fn):
+ sess.run(x)
+ expected = error.exception
+ custom_traceback = expected.custom_traceback
+ found_correct_filename = False
+ num_test_fn_frames = 0
+ num_g_frames = 0
+ for frame in custom_traceback:
+ filename, _, fn_name, source_code = frame
+ self.assertFalse('/tmp/' in filename)
+ self.assertFalse('control_flow.py' in filename)
+ self.assertFalse('ag__.' in fn_name)
+ found_correct_filename |= __file__ in filename
+ num_test_fn_frames += int('test_fn' == fn_name and
+ 'return g(x, 10)' in source_code)
+ num_g_frames += int('g' == fn_name and 'x //= 0' in source_code)
+ self.assertTrue(found_correct_filename)
+ self.assertEqual(num_test_fn_frames, 1)
+ self.assertEqual(num_g_frames, 1)
+
+ def test_runtime_error_rewriting_nested(self):
+
+ def test_fn(x):
+
+ def g(y):
+ return y**2 // 0
+
+ s = 0
+ for xi in x:
+ s += g(xi)
+ return s
+
+ compiled_fn = ag.to_graph(test_fn)
+
+ # TODO(b/111408261): Nested functions currently do not rewrite correctly,
+ # when they do we should change this test to check for the same traceback
+ # properties as the other tests. This should throw a runtime error with a
+ # frame with "g" as the function name but because we don't yet add
+ # try/except blocks to inner functions the name is "tf__g".
+ with self.assertRaises(ag.TfRuntimeError) as error:
+ with self.test_session() as sess:
+ x = compiled_fn(tf.constant([4, 8]))
+ with ag.improved_errors(compiled_fn):
+ sess.run(x)
+ expected = error.exception
+ custom_traceback = expected.custom_traceback
+ num_tf_g_frames = 0
+ for frame in custom_traceback:
+ _, _, fn_name, _ = frame
+ self.assertNotEqual('g', fn_name)
+ num_tf_g_frames += int('tf__g' == fn_name)
+ self.assertEqual(num_tf_g_frames, 1)
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/examples/autograph/integration_tests/keras_test.py b/tensorflow/examples/autograph/integration_tests/keras_test.py
new file mode 100644
index 0000000000..dca7c07b47
--- /dev/null
+++ b/tensorflow/examples/autograph/integration_tests/keras_test.py
@@ -0,0 +1,103 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Keras integration tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python import autograph
+
+
+class MinimalKeras(tf.keras.Model):
+
+ def call(self, x):
+ return x * 3
+
+
+class ModelWithStaticConditional(object):
+
+ def __init__(self, initial):
+ self.initial = initial
+ if self.initial:
+ self.h = 15
+
+ @autograph.convert()
+ def call(self):
+ x = 10
+ if self.initial:
+ x += self.h
+ return x
+
+
+class BasicBlock(tf.keras.Model):
+
+ def __init__(self):
+ super(BasicBlock, self).__init__()
+ self.conv1 = tf.keras.layers.Conv2D(8, 3)
+ self.pool = tf.keras.layers.GlobalAveragePooling2D()
+ self.dense = tf.keras.layers.Dense(3)
+
+ def call(self, x):
+ x = self.conv1(x)
+ x = self.pool(x)
+ x = self.dense(x)
+ return x
+
+
+class CompoundModel(tf.keras.Model):
+
+ def __init__(self):
+ super(CompoundModel, self).__init__()
+ self.block = BasicBlock()
+
+ @autograph.convert(recursive=True)
+ def call(self, x):
+ x = self.block(x) # pylint: disable=not-callable
+ return x
+
+
+class KerasTest(tf.test.TestCase):
+
+ def test_basic(self):
+ MinimalKeras()
+
+ def test_conditional_attributes_False(self):
+ model = ModelWithStaticConditional(False)
+ self.assertEqual(model.call(), 10)
+
+ def test_conditional_attributes_True(self):
+ model = ModelWithStaticConditional(True)
+ self.assertEqual(model.call(), 25)
+
+ def test_recursive_true(self):
+ with self.assertRaisesRegexp(NotImplementedError,
+ 'Object conversion is not yet supported.'):
+ with tf.Graph().as_default():
+ model = CompoundModel()
+ model.build(tf.TensorShape((None, 10, 10, 1)))
+ init = tf.global_variables_initializer()
+
+ with tf.Session() as sess:
+ sess.run(init)
+ sample_input = tf.random_uniform((1, 10, 10, 1))
+ output = model(sample_input) # pylint: disable=not-callable
+ self.assertEqual(sess.run(output).shape, (1, 3))
+
+
+if __name__ == '__main__':
+ tf.test.main()
diff --git a/tensorflow/examples/autograph/integration_tests/list_literals_test.py b/tensorflow/examples/autograph/integration_tests/list_literals_test.py
new file mode 100644
index 0000000000..917f5ff9d8
--- /dev/null
+++ b/tensorflow/examples/autograph/integration_tests/list_literals_test.py
@@ -0,0 +1,41 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests of functions that use list literals."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python import autograph as ag
+
+
+def list_used_as_tuple():
+ return tf.constant([1, 2, 3])
+
+
+class ListLiteralsTest(tf.test.TestCase):
+
+ def test_basic(self):
+ converted = ag.to_graph(list_used_as_tuple)
+ result = converted()
+
+ with self.cached_session() as sess:
+ self.assertAllEqual(sess.run(result), [1, 2, 3])
+
+
+if __name__ == '__main__':
+ tf.test.main()