diff options
author | Anna R <annarev@google.com> | 2018-09-07 12:20:37 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-07 12:24:57 -0700 |
commit | a65d3dd42122d3a58985d56118d58c5b4224f38f (patch) | |
tree | 9c7bf3a1a3dfa68d73af747b281a5ae327868332 /tensorflow/tools/api | |
parent | 0a375d94b6fd4c3cd0bd5d0a301b3acc65b96d78 (diff) |
Add tf_api_version flag. If --define=tf_api_version=2 flag is passed in, then bazel will build TensorFlow API version 2.0. In all other cases, it would build API version 1.*.
PiperOrigin-RevId: 212016666
Diffstat (limited to 'tensorflow/tools/api')
-rw-r--r-- | tensorflow/tools/api/tests/BUILD | 5 | ||||
-rw-r--r-- | tensorflow/tools/api/tests/api_compatibility_test.py | 14 |
2 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD index 8764409e4d..4efa4a9651 100644 --- a/tensorflow/tools/api/tests/BUILD +++ b/tensorflow/tools/api/tests/BUILD @@ -15,7 +15,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary") py_test( name = "api_compatibility_test", - srcs = ["api_compatibility_test.py"], + srcs = [ + "api_compatibility_test.py", + "//tensorflow:tf_python_api_gen_v2", + ], data = [ "//tensorflow/tools/api/golden:api_golden_v1", "//tensorflow/tools/api/golden:api_golden_v2", diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py index 43d19bc99c..99bed5714f 100644 --- a/tensorflow/tools/api/tests/api_compatibility_test.py +++ b/tensorflow/tools/api/tests/api_compatibility_test.py @@ -34,6 +34,7 @@ import sys import unittest import tensorflow as tf +from tensorflow._api import v2 as tf_v2 from google.protobuf import message from google.protobuf import text_format @@ -232,14 +233,14 @@ class ApiCompatibilityTest(test.TestCase): return visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) visitor.do_not_descend_map['tf'].append('contrib') - traverse.traverse(tf.compat.v1, visitor) + traverse.traverse(tf_v2.compat.v1, visitor) def testNoSubclassOfMessageV2(self): if not hasattr(tf.compat, 'v2'): return visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor) visitor.do_not_descend_map['tf'].append('contrib') - traverse.traverse(tf.compat.v2, visitor) + traverse.traverse(tf_v2, visitor) def _checkBackwardsCompatibility( self, root, golden_file_pattern, api_version, @@ -300,27 +301,24 @@ class ApiCompatibilityTest(test.TestCase): sys.version_info.major == 2, 'API compabitility test goldens are generated using python2.') def testAPIBackwardsCompatibilityV1(self): - if not hasattr(tf.compat, 'v1'): - return api_version = 1 golden_file_pattern = os.path.join( resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*', api_version)) self._checkBackwardsCompatibility( - tf.compat.v1, golden_file_pattern, api_version) + tf_v2.compat.v1, golden_file_pattern, api_version) @unittest.skipUnless( sys.version_info.major == 2, 'API compabitility test goldens are generated using python2.') def testAPIBackwardsCompatibilityV2(self): - if not hasattr(tf.compat, 'v2'): - return api_version = 2 golden_file_pattern = os.path.join( resource_loader.get_root_dir_with_all_resources(), _KeyToFilePath('*', api_version)) self._checkBackwardsCompatibility( - tf.compat.v2, golden_file_pattern, api_version) + tf_v2, golden_file_pattern, api_version, + additional_private_map={'tf.compat': ['v1']}) if __name__ == '__main__': |