aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/api
diff options
context:
space:
mode:
authorGravatar Anna R <annarev@google.com>2018-09-07 12:20:37 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 12:24:57 -0700
commita65d3dd42122d3a58985d56118d58c5b4224f38f (patch)
tree9c7bf3a1a3dfa68d73af747b281a5ae327868332 /tensorflow/tools/api
parent0a375d94b6fd4c3cd0bd5d0a301b3acc65b96d78 (diff)
Add tf_api_version flag. If --define=tf_api_version=2 flag is passed in, then bazel will build TensorFlow API version 2.0. In all other cases, it would build API version 1.*.
PiperOrigin-RevId: 212016666
Diffstat (limited to 'tensorflow/tools/api')
-rw-r--r--tensorflow/tools/api/tests/BUILD5
-rw-r--r--tensorflow/tools/api/tests/api_compatibility_test.py14
2 files changed, 10 insertions, 9 deletions
diff --git a/tensorflow/tools/api/tests/BUILD b/tensorflow/tools/api/tests/BUILD
index 8764409e4d..4efa4a9651 100644
--- a/tensorflow/tools/api/tests/BUILD
+++ b/tensorflow/tools/api/tests/BUILD
@@ -15,7 +15,10 @@ load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
py_test(
name = "api_compatibility_test",
- srcs = ["api_compatibility_test.py"],
+ srcs = [
+ "api_compatibility_test.py",
+ "//tensorflow:tf_python_api_gen_v2",
+ ],
data = [
"//tensorflow/tools/api/golden:api_golden_v1",
"//tensorflow/tools/api/golden:api_golden_v2",
diff --git a/tensorflow/tools/api/tests/api_compatibility_test.py b/tensorflow/tools/api/tests/api_compatibility_test.py
index 43d19bc99c..99bed5714f 100644
--- a/tensorflow/tools/api/tests/api_compatibility_test.py
+++ b/tensorflow/tools/api/tests/api_compatibility_test.py
@@ -34,6 +34,7 @@ import sys
import unittest
import tensorflow as tf
+from tensorflow._api import v2 as tf_v2
from google.protobuf import message
from google.protobuf import text_format
@@ -232,14 +233,14 @@ class ApiCompatibilityTest(test.TestCase):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
- traverse.traverse(tf.compat.v1, visitor)
+ traverse.traverse(tf_v2.compat.v1, visitor)
def testNoSubclassOfMessageV2(self):
if not hasattr(tf.compat, 'v2'):
return
visitor = public_api.PublicAPIVisitor(_VerifyNoSubclassOfMessageVisitor)
visitor.do_not_descend_map['tf'].append('contrib')
- traverse.traverse(tf.compat.v2, visitor)
+ traverse.traverse(tf_v2, visitor)
def _checkBackwardsCompatibility(
self, root, golden_file_pattern, api_version,
@@ -300,27 +301,24 @@ class ApiCompatibilityTest(test.TestCase):
sys.version_info.major == 2,
'API compabitility test goldens are generated using python2.')
def testAPIBackwardsCompatibilityV1(self):
- if not hasattr(tf.compat, 'v1'):
- return
api_version = 1
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
self._checkBackwardsCompatibility(
- tf.compat.v1, golden_file_pattern, api_version)
+ tf_v2.compat.v1, golden_file_pattern, api_version)
@unittest.skipUnless(
sys.version_info.major == 2,
'API compabitility test goldens are generated using python2.')
def testAPIBackwardsCompatibilityV2(self):
- if not hasattr(tf.compat, 'v2'):
- return
api_version = 2
golden_file_pattern = os.path.join(
resource_loader.get_root_dir_with_all_resources(),
_KeyToFilePath('*', api_version))
self._checkBackwardsCompatibility(
- tf.compat.v2, golden_file_pattern, api_version)
+ tf_v2, golden_file_pattern, api_version,
+ additional_private_map={'tf.compat': ['v1']})
if __name__ == '__main__':