aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tools/pip_package/pip_smoke_test.py
diff options
context:
space:
mode:
authorGravatar Andrew Selle <aselle@google.com>2018-05-02 21:15:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-02 21:17:42 -0700
commitf6000468263c5db7befbf5c320e8b3af7d90b819 (patch)
tree42ce43c08ed205308d3faa15a306ab5aabc7ac0b /tensorflow/tools/pip_package/pip_smoke_test.py
parent71f97c8cd9304a8e1cf8e309e15000d5831b212a (diff)
Expose Interpreter to tensorflow.contrib.lite
PiperOrigin-RevId: 195198645
Diffstat (limited to 'tensorflow/tools/pip_package/pip_smoke_test.py')
-rw-r--r--tensorflow/tools/pip_package/pip_smoke_test.py73
1 files changed, 49 insertions, 24 deletions
diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py
index b23dde2019..401f833dbd 100644
--- a/tensorflow/tools/pip_package/pip_smoke_test.py
+++ b/tensorflow/tools/pip_package/pip_smoke_test.py
@@ -30,15 +30,42 @@ os.chdir(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../..")))
PIP_PACKAGE_QUERY_EXPRESSION = (
"deps(//tensorflow/tools/pip_package:build_pip_package)")
-# pylint: disable=g-backslash-continuation
-PY_TEST_QUERY_EXPRESSION = 'deps(\
- filter("^((?!benchmark).)*$",\
- kind(py_test,\
- //tensorflow/python/... \
- + //tensorflow/contrib/... \
- - //tensorflow/contrib/tensorboard/... \
- - attr(tags, "manual|no_pip", //tensorflow/...))), 1)'
-# pylint: enable=g-backslash-continuation
+
+def GetBuild(dir_base):
+ """Get the list of BUILD file all targets recursively startind at dir_base."""
+ items = []
+ for root, _, files in os.walk(dir_base):
+ for name in files:
+ if (name == "BUILD" and
+ root.find("tensorflow/contrib/lite/examples/android") == -1):
+ items.append("//" + root + ":all")
+ return items
+
+
+def BuildPyTestDependencies():
+ python_targets = GetBuild("tensorflow/python")
+ contrib_targets = GetBuild("tensorflow/contrib")
+ tensorboard_targets = GetBuild("tensorflow/contrib/tensorboard")
+ tensorflow_targets = GetBuild("tensorflow")
+ # Build list of test targets,
+ # python + contrib - tensorboard - attr(manual|pno_pip)
+ targets = " + ".join(python_targets)
+ for t in contrib_targets:
+ targets += " + " + t
+ for t in tensorboard_targets:
+ targets += " - " + t
+ targets += ' - attr(tags, "manual|no_pip", %s)' % " + ".join(
+ tensorflow_targets)
+ query_kind = "kind(py_test, %s)" % targets
+ # Skip benchmarks etc.
+ query_filter = 'filter("^((?!benchmark).)*$", %s)' % query_kind
+ # Get the dependencies
+ query_deps = "deps(%s, 1)" % query_filter
+
+ return python_targets, query_deps
+
+
+PYTHON_TARGETS, PY_TEST_QUERY_EXPRESSION = BuildPyTestDependencies()
# Hard-coded blacklist of files if not included in pip package
# TODO(amitpatankar): Clean up blacklist.
@@ -79,16 +106,6 @@ BLACKLIST = [
]
-def bazel_query(query_target):
- """Run bazel query on target."""
- try:
- output = subprocess.check_output(
- ["bazel", "query", "--keep_going", query_target])
- except subprocess.CalledProcessError as e:
- output = e.output
- return output
-
-
def main():
"""This script runs the pip smoke test.
@@ -103,14 +120,22 @@ def main():
"""
# pip_package_dependencies_list is the list of included files in pip packages
- pip_package_dependencies = bazel_query(PIP_PACKAGE_QUERY_EXPRESSION)
+ pip_package_dependencies = subprocess.check_output(
+ ["bazel", "cquery", PIP_PACKAGE_QUERY_EXPRESSION])
pip_package_dependencies_list = pip_package_dependencies.strip().split("\n")
+ pip_package_dependencies_list = [
+ x.split()[0] for x in pip_package_dependencies_list
+ ]
print("Pip package superset size: %d" % len(pip_package_dependencies_list))
# tf_py_test_dependencies is the list of dependencies for all python
# tests in tensorflow
- tf_py_test_dependencies = bazel_query(PY_TEST_QUERY_EXPRESSION)
+ tf_py_test_dependencies = subprocess.check_output(
+ ["bazel", "cquery", PY_TEST_QUERY_EXPRESSION])
tf_py_test_dependencies_list = tf_py_test_dependencies.strip().split("\n")
+ tf_py_test_dependencies_list = [
+ x.split()[0] for x in tf_py_test_dependencies.strip().split("\n")
+ ]
print("Pytest dependency subset size: %d" % len(tf_py_test_dependencies_list))
missing_dependencies = []
@@ -141,9 +166,9 @@ def main():
for missing_dependency in missing_dependencies:
print("\nMissing dependency: %s " % missing_dependency)
print("Affected Tests:")
- rdep_query = ("rdeps(kind(py_test, //tensorflow/python/...), %s)" %
- missing_dependency)
- affected_tests = bazel_query(rdep_query)
+ rdep_query = ("rdeps(kind(py_test, %s), %s)" %
+ (" + ".join(PYTHON_TARGETS), missing_dependency))
+ affected_tests = subprocess.check_output(["bazel", "cquery", rdep_query])
affected_tests_list = affected_tests.split("\n")[:-2]
print("\n".join(affected_tests_list))