aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/api/tests/api_compatibility_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tools/api/tests/api_compatibility_test.py')
-rw-r--r--tensorflow/tools/api/tests/api_compatibility_test.py14
1 files changed, 6 insertions, 8 deletions
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 43d19bc99c..99bed5714f 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -34,6 +34,7 @@ import sys
import unittest
import tensorflow as tf
+from tensorflow._api import v2 as tf_v2
from google.protobuf import message
from google.protobuf import text_format
@@ -232,14 +233,14 @@ class ApiCompatibilityTest(test.TestCase):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
- traverse.traverse(tf.compat.v1, visitor)
+ traverse.traverse(tf_v2.compat.v1, visitor)
def testNoSubclassOfMessageV2(self):
if not hasattr(tf.compat, 'v2'):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
- traverse.traverse(tf.compat.v2, visitor)
+ traverse.traverse(tf_v2, visitor)
def _checkBackwardsCompatibility(
self, root, golden_file_pattern, api_version,
@@ -300,27 +301,24 @@ class ApiCompatibilityTest(test.TestCase):
sys.version_info.major == 2,
'API compabitility test goldens are generated using python2.')
def testAPIBackwardsCompatibilityV1(self):
- if not hasattr(tf.compat, 'v1'):
- return
api_version = 1
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
self._checkBackwardsCompatibility(
- tf.compat.v1, golden_file_pattern, api_version)
+ tf_v2.compat.v1, golden_file_pattern, api_version)
@unittest.skipUnless(
sys.version_info.major == 2,
'API compabitility test goldens are generated using python2.')
def testAPIBackwardsCompatibilityV2(self):
- if not hasattr(tf.compat, 'v2'):
- return
api_version = 2
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
self._checkBackwardsCompatibility(
- tf.compat.v2, golden_file_pattern, api_version)
+ tf_v2, golden_file_pattern, api_version,
+ additional_private_map={'tf.compat': ['v1']})
if __name__ == '__main__':