aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Avijit <30507445+avijit-nervana@users.noreply.github.com>2018-07-24 23:35:27 -0700
committerGravatar GitHub <noreply@github.com>2018-07-24 23:35:27 -0700
commit121e0161c5a7273c5a59f1e10a8577428c685796 (patch)
treefb6b99b4af3accd9c68d05442be95f9250b59604 /tensorflow
parent80fb8679ab14ba3d180e8eb22da11509a15b9219 (diff)
nGraph integration with TensorFlow
* Added nGraph bridge as a third_party to be built with TensorFlow based on user selection. * Added a limited set of C++ unit tests to verify the correctness of the computation
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/BUILD14
-rw-r--r--tensorflow/core/BUILD5
-rw-r--r--tensorflow/core/common_runtime/threadpool_device_factory.cc1
-rw-r--r--tensorflow/core/platform/default/build_config.bzl2
-rw-r--r--tensorflow/python/BUILD4
-rw-r--r--tensorflow/tensorflow.bzl5
-rw-r--r--tensorflow/workspace.bzl33
7 files changed, 59 insertions, 5 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 51eea94847..6d443eb9f2 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -24,6 +24,8 @@ load(
"gen_api_init_files", # @unused
)
+load("//third_party/ngraph:build_defs.bzl", "if_ngraph")
+
# Config setting for determining if we are building for Android.
config_setting(
name = "android",
@@ -408,6 +410,14 @@ config_setting(
visibility = ["//visibility:public"],
)
+# This flag is set from the configure step when the user selects with nGraph option.
+# By default it should be false
+config_setting(
+ name = "with_ngraph_support",
+ values = {"define": "with_ngraph_support=true"},
+ visibility = ["//visibility:public"],
+)
+
package_group(
name = "internal",
packages = [
@@ -540,7 +550,7 @@ tf_cc_shared_object(
"//tensorflow/c:version_script.lds",
"//tensorflow/c/eager:c_api",
"//tensorflow/core:tensorflow",
- ],
+ ]
)
tf_cc_shared_object(
@@ -568,7 +578,7 @@ tf_cc_shared_object(
"//tensorflow/cc:scope",
"//tensorflow/cc/profiler",
"//tensorflow/core:tensorflow",
- ],
+ ] + if_ngraph(["@ngraph_tf//:ngraph_tf"])
)
exports_files(
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index dbe87a6dbb..19060c5ce7 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -2325,6 +2325,7 @@ tf_generate_proto_text_sources(
":lib_internal",
":protos_all_proto_cc",
],
+ visibility = ["//visibility:public"],
)
cc_library(
@@ -2435,6 +2436,7 @@ cc_header_only_library(
deps = [
":core_cpu_lib",
],
+ visibility = ["//visibility:public"],
)
tf_cuda_library(
@@ -2502,7 +2504,7 @@ tf_cuda_library(
"//third_party/mkl:intel_binary_blob",
"@mkl_dnn",
],
- ),
+ ) ,
alwayslink = 1,
)
@@ -2560,6 +2562,7 @@ tf_cuda_library(
cc_library(
name = "protos_cc",
deps = ["//tensorflow/core/platform/default/build_config:protos_cc"],
+ visibility = ["//visibility:public"],
)
# Library containing all of the graph construction code that is
diff --git a/tensorflow/core/common_runtime/threadpool_device_factory.cc b/tensorflow/core/common_runtime/threadpool_device_factory.cc
index 6a900c02c0..61a62621ba 100644
--- a/tensorflow/core/common_runtime/threadpool_device_factory.cc
+++ b/tensorflow/core/common_runtime/threadpool_device_factory.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 28891320c4..9f6bc70f04 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -707,4 +707,4 @@ def tf_additional_binary_deps():
[
"//third_party/mkl:intel_binary_blob",
],
- )
+ ) \ No newline at end of file
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index d60d37df50..f2ab2f80e6 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -44,6 +44,7 @@ load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_mpi_deps")
load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_gdr_deps")
load("//tensorflow/core:platform/default/build_config_root.bzl", "if_static")
+load("//third_party/ngraph:build_defs.bzl","if_ngraph")
py_library(
name = "python",
@@ -3669,7 +3670,8 @@ tf_py_wrap_cc(
tf_additional_plugin_deps() +
tf_additional_verbs_deps() +
tf_additional_mpi_deps() +
- tf_additional_gdr_deps()),
+ tf_additional_gdr_deps())+
+ if_ngraph(["@ngraph_tf//:ngraph_tf"])
)
# ** Targets for Windows build (start) **
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index e4241667ad..5884870daa 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -24,6 +24,10 @@ load(
"if_mkl",
"if_mkl_lnx_x64"
)
+load(
+ "//third_party/ngraph:build_defs.bzl",
+ "if_ngraph",
+)
def register_extension_info(**kwargs):
pass
@@ -214,6 +218,7 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False):
+ if_cuda(["-DGOOGLE_CUDA=1"])
+ if_tensorrt(["-DGOOGLE_TENSORRT=1"])
+ if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML"])
+ + if_ngraph(["-DINTEL_NGRAPH=1"])
+ if_mkl_lnx_x64(["-fopenmp"])
+ if_android_arm(["-mfpu=neon"])
+ if_linux_x86_64(["-msse3"])
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index b712954d6d..8953edf8a6 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -803,6 +803,39 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
strip_prefix = "rules_android-0.1.1",
)
+ tf_http_archive(
+ name = "ngraph",
+ urls = [
+ "https://mirror.bazel.build/github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
+ "https://github.com/NervanaSystems/ngraph/archive/v0.5.0.tar.gz",
+ ],
+ sha256 = "cb35d3d98836f615408afd18371fb13e3400711247e0d822ba7f306c45e9bb2c",
+ strip_prefix = "ngraph-0.5.0",
+ build_file = clean_dep("//third_party/ngraph:ngraph.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "nlohmann_json_lib",
+ urls = [
+ "https://mirror.bazel.build/github.com/nlohmann/json/archive/v3.1.1.tar.gz",
+ "https://github.com/nlohmann/json/archive/v3.1.1.tar.gz",
+ ],
+ sha256 = "9f3549824af3ca7e9707a2503959886362801fb4926b869789d6929098a79e47",
+ strip_prefix = "json-3.1.1",
+ build_file = clean_dep("//third_party/ngraph:nlohmann_json.BUILD"),
+ )
+
+ tf_http_archive(
+ name = "ngraph_tf",
+ urls = [
+ "https://mirror.bazel.build/github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc0.tar.gz",
+ "https://github.com/NervanaSystems/ngraph-tf/archive/v0.3.0-rc0.tar.gz"
+ ],
+ sha256 = "c09a35d0a605afeeaf5aad81181a6abc7e9b9e39312e8fdfbae20cbd8eb58523",
+ strip_prefix = "ngraph-tf-0.3.0-rc0",
+ build_file = clean_dep("//third_party/ngraph:ngraph_tf.BUILD"),
+ )
+
##############################################################################
# BIND DEFINITIONS
#