diff options
Diffstat (limited to 'tensorflow/tools/api/tests/api_compatibility_test.py')
-rw-r--r-- | tensorflow/tools/api/tests/api_compatibility_test.py | 14 |
1 files changed, 6 insertions, 8 deletions
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index 43d19bc99c..99bed5714f 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -34,6 +34,7 @@ import sys import unittest import tensorflow as tf +from tensorflow._api import v2 as tf_v2 from google.protobuf import message from google.protobuf import text_format @@ -232,14 +233,14 @@ class ApiCompatibilityTest(test.TestCase): return visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) visitor.do_not_descend_map['tf'].append('contrib') - traverse.traverse(tf.compat.v1, visitor) + traverse.traverse(tf_v2.compat.v1, visitor) def testNoSubclassOfMessageV2(self): if not hasattr(tf.compat, 'v2'): return visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) visitor.do_not_descend_map['tf'].append('contrib') - traverse.traverse(tf.compat.v2, visitor) + traverse.traverse(tf_v2, visitor) def _checkBackwardsCompatibility( self, root, golden_file_pattern, api_version, @@ -300,27 +301,24 @@ class ApiCompatibilityTest(test.TestCase): sys.version_info.major == 2, 'API compabitility test goldens are generated using python2.') def testAPIBackwardsCompatibilityV1(self): - if not hasattr(tf.compat, 'v1'): - return api_version = 1 golden_file_pattern = os.path.join( resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*', api_version)) self._checkBackwardsCompatibility( - tf.compat.v1, golden_file_pattern, api_version) + tf_v2.compat.v1, golden_file_pattern, api_version) @unittest.skipUnless( sys.version_info.major == 2, 'API compabitility test goldens are generated using python2.') def testAPIBackwardsCompatibilityV2(self): - if not hasattr(tf.compat, 'v2'): - return api_version = 2 golden_file_pattern = os.path.join( resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*', api_version)) self._checkBackwardsCompatibility( - tf.compat.v2, golden_file_pattern, api_version) + tf_v2, golden_file_pattern, api_version, + additional_private_map={'tf.compat': ['v1']}) if __name__ == '__main__': |