aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/api/tests/api_compatibility_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tools/api/tests/api_compatibility_test.py')
-rw-r--r--tensorflow/tools/api/tests/api_compatibility_test.py42
1 files changed, 32 insertions, 10 deletions
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 90375a794f..d1b34fb242 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -34,6 +34,13 @@ import sys
import unittest
import tensorflow as tf
+# pylint: disable=g-import-not-at-top
+try:
+ from tensorflow.compat import v1 as tf_v1
+ # We import compat.v1 as tf_v1 instead.
+ del tf.compat.v1
+except ImportError:
+ tf_v1 = None
from google.protobuf import message
from google.protobuf import text_format
@@ -46,6 +53,7 @@ from tensorflow.tools.api.lib import api_objects_pb2
from tensorflow.tools.api.lib import python_object_to_proto_visitor
from tensorflow.tools.common import public_api
from tensorflow.tools.common import traverse
+# pylint: enable=g-import-not-at-top
# FLAGS defined at the bottom:
@@ -215,25 +223,19 @@ class ApiCompatibilityTest(test.TestCase):
visitor.do_not_descend_map['tf'].append('contrib')
traverse.traverse(tf, visitor)
- @unittest.skipUnless(
- sys.version_info.major == 2,
- 'API compabitility test goldens are generated using python2.')
- def testAPIBackwardsCompatibility(self):
- # Extract all API stuff.
+ def checkBackwardsCompatibility(self, root, golden_file_pattern):
+ # Extract all API stuff.
visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor()
public_api_visitor = public_api.PublicAPIVisitor(visitor)
public_api_visitor.do_not_descend_map['tf'].append('contrib')
public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental']
- traverse.traverse(tf, public_api_visitor)
+ traverse.traverse(root, public_api_visitor)
proto_dict = visitor.GetProtos()
# Read all golden files.
- expression = os.path.join(
- resource_loader.get_root_dir_with_all_resources(),
- _KeyToFilePath('*'))
- golden_file_list = file_io.get_matching_files(expression)
+ golden_file_list = file_io.get_matching_files(golden_file_pattern)
def _ReadFileToProto(filename):
"""Read a filename, create a protobuf from its contents."""
@@ -254,6 +256,26 @@ class ApiCompatibilityTest(test.TestCase):
verbose=FLAGS.verbose_diffs,
update_goldens=FLAGS.update_goldens)
+ @unittest.skipUnless(
+ sys.version_info.major == 2,
+ 'API compabitility test goldens are generated using python2.')
+ def testAPIBackwardsCompatibility(self):
+ golden_file_pattern = os.path.join(
+ resource_loader.get_root_dir_with_all_resources(),
+ _KeyToFilePath('*'))
+ self.checkBackwardsCompatibility(tf, golden_file_pattern)
+
+ @unittest.skipUnless(
+ sys.version_info.major == 2,
+ 'API compabitility test goldens are generated using python2.')
+ def testAPIBackwardsCompatibilityV1(self):
+ if not tf_v1:
+ return
+ golden_file_pattern = os.path.join(
+ resource_loader.get_root_dir_with_all_resources(),
+ _KeyToFilePath('*'))
+ self.checkBackwardsCompatibility(tf_v1, golden_file_pattern)
+
if __name__ == '__main__':
parser = argparse.ArgumentParser()