From a65d3dd42122d3a58985d56118d58c5b4224f38f Mon Sep 17 00:00:00 2001 From: Anna R Date: Fri, 7 Sep 2018 12:20:37 -0700 Subject: Add tf_api_version flag. If --define=tf_api_version=2 flag is passed in, then bazel will build TensorFlow API version 2.0. In all other cases, it would build API version 1.*. PiperOrigin-RevId: 212016666 --- tensorflow/BUILD | 50 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) (limited to 'tensorflow/BUILD') diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 2926789953..386e0096ff 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -24,6 +24,11 @@ load( "//tensorflow/python/tools/api/generator:api_gen.bzl", "gen_api_init_files", # @unused ) +load("//tensorflow/python/tools/api/generator:api_gen.bzl", "get_compat_files") +load( + "//tensorflow/python/tools/api/generator:api_init_files.bzl", + "TENSORFLOW_API_INIT_FILES", # @unused +) load( "//tensorflow/python/tools/api/generator:api_init_files_v1.bzl", "TENSORFLOW_API_INIT_FILES_V1", # @unused @@ -33,6 +38,11 @@ load( "if_ngraph", ) +# @unused +TENSORFLOW_API_INIT_FILES_V2 = ( + TENSORFLOW_API_INIT_FILES + get_compat_files(TENSORFLOW_API_INIT_FILES_V1, 1) +) + # Config setting used when building for products # which requires restricted licenses to be avoided. config_setting( @@ -428,6 +438,13 @@ config_setting( visibility = ["//visibility:public"], ) +# This flag specifies whether TensorFlow 2.0 API should be built instead +# of 1.* API. Note that TensorFlow 2.0 API is currently under development. +config_setting( + name = "api_version_2", + define_values = {"tf_api_version": "2"}, +) + package_group( name = "internal", packages = [ @@ -592,13 +609,39 @@ exports_files( ) gen_api_init_files( - name = "tensorflow_python_api_gen", + name = "tf_python_api_gen_v1", srcs = ["api_template.__init__.py"], api_version = 1, + output_dir = "_api/v1/", output_files = TENSORFLOW_API_INIT_FILES_V1, + output_package = "tensorflow._api.v1", + root_init_template = "api_template.__init__.py", +) + +gen_api_init_files( + name = "tf_python_api_gen_v2", + srcs = ["api_template.__init__.py"], + api_version = 2, + compat_api_versions = [1], + output_dir = "_api/v2/", + output_files = TENSORFLOW_API_INIT_FILES_V2, + output_package = "tensorflow._api.v2", root_init_template = "api_template.__init__.py", ) +genrule( + name = "root_init_gen", + srcs = select({ + "api_version_2": [":tf_python_api_gen_v2"], + "//conditions:default": [":tf_python_api_gen_v1"], + }), + outs = ["__init__.py"], + cmd = select({ + "api_version_2": "cp $(@D)/_api/v2/__init__.py $(OUTS)", + "//conditions:default": "cp $(@D)/_api/v1/__init__.py $(OUTS)", + }), +) + py_library( name = "tensorflow_py", srcs = ["//tensorflow/python/estimator/api:estimator_python_api_gen"], @@ -613,7 +656,10 @@ py_library( py_library( name = "tensorflow_py_no_contrib", - srcs = [":tensorflow_python_api_gen"], + srcs = select({ + "api_version_2": [":tf_python_api_gen_v2"], + "//conditions:default": [":tf_python_api_gen_v1"], + }) + [":root_init_gen"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = ["//tensorflow/python:no_contrib"], -- cgit v1.2.3