diff options
Diffstat (limited to 'tensorflow/tools/api/tests/api_compatibility_test.py')
-rw-r--r-- | tensorflow/tools/api/tests/api_compatibility_test.py | 42 |
1 files changed, 32 insertions, 10 deletions
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index 90375a794f..d1b34fb242 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -34,6 +34,13 @@ import sys import unittest import tensorflow as tf +# pylint: disable=g-import-not-at-top +try: + from tensorflow.compat import v1 as tf_v1 + # We import compat.v1 as tf_v1 instead. + del tf.compat.v1 +except ImportError: + tf_v1 = None from google.protobuf import message from google.protobuf import text_format @@ -46,6 +53,7 @@ from tensorflow.tools.api.lib import api_objects_pb2 from tensorflow.tools.api.lib import python_object_to_proto_visitor from tensorflow.tools.common import public_api from tensorflow.tools.common import traverse +# pylint: enable=g-import-not-at-top # FLAGS defined at the bottom: @@ -215,25 +223,19 @@ class ApiCompatibilityTest(test.TestCase): visitor.do_not_descend_map['tf'].append('contrib') traverse.traverse(tf, visitor) - @unittest.skipUnless( - sys.version_info.major == 2, - 'API compabitility test goldens are generated using python2.') - def testAPIBackwardsCompatibility(self): - # Extract all API stuff. + def checkBackwardsCompatibility(self, root, golden_file_pattern): + # Extract all API stuff. visitor = python_object_to_proto_visitor.PythonObjectToProtoVisitor() public_api_visitor = public_api.PublicAPIVisitor(visitor) public_api_visitor.do_not_descend_map['tf'].append('contrib') public_api_visitor.do_not_descend_map['tf.GPUOptions'] = ['Experimental'] - traverse.traverse(tf, public_api_visitor) + traverse.traverse(root, public_api_visitor) proto_dict = visitor.GetProtos() # Read all golden files. - expression = os.path.join( - resource_loader.get_root_dir_with_all_resources(), - _KeyToFilePath('*')) - golden_file_list = file_io.get_matching_files(expression) + golden_file_list = file_io.get_matching_files(golden_file_pattern) def _ReadFileToProto(filename): """Read a filename, create a protobuf from its contents.""" @@ -254,6 +256,26 @@ class ApiCompatibilityTest(test.TestCase): verbose=FLAGS.verbose_diffs, update_goldens=FLAGS.update_goldens) + @unittest.skipUnless( + sys.version_info.major == 2, + 'API compabitility test goldens are generated using python2.') + def testAPIBackwardsCompatibility(self): + golden_file_pattern = os.path.join( + resource_loader.get_root_dir_with_all_resources(), + _KeyToFilePath('*')) + self.checkBackwardsCompatibility(tf, golden_file_pattern) + + @unittest.skipUnless( + sys.version_info.major == 2, + 'API compabitility test goldens are generated using python2.') + def testAPIBackwardsCompatibilityV1(self): + if not tf_v1: + return + golden_file_pattern = os.path.join( + resource_loader.get_root_dir_with_all_resources(), + _KeyToFilePath('*')) + self.checkBackwardsCompatibility(tf_v1, golden_file_pattern) + if __name__ == '__main__': parser = argparse.ArgumentParser() |