aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/nccl
diff options
context:
space:
mode:
author: Eugene Brevdo <ebrevdo@google.com> 2018-09-11 10:41:44 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-11 10:51:01 -0700
commit: 36e1a5ea5ba2dd5eaa7f4cfc84a61f8ce3ea20e1 (patch)
tree: 4f1671f78f5971b02dc2af66f57eabbf01005112 /tensorflow/contrib/nccl
parent: 36d7b12357df667dcd427c070e21779ed83f4ec9 (diff)
[TF] Variant improvements.
1. Change Variant Decode to accept VariantTensorData (non-ref). This should allow some optimization in the future. In the meantime it means removing the variant.h include from tensor.h, since variant_encode_decode.h now relies on tensor.h and variant.h now relies on that. It also means we found a bunch of places where tensor.proto.h, variant.h, and mutex.h were being imported through tensor.h (along with a bunch of other crap); so now we directly import them in order to compile.

2. Move Variant registry to use TypeIndex instead of a TypeName string; this should speed up registry lookups.

PiperOrigin-RevId: 212478896
Diffstat (limited to 'tensorflow/contrib/nccl')
-rw-r--r-- tensorflow/contrib/nccl/BUILD 24
-rw-r--r-- tensorflow/contrib/nccl/kernels/nccl_rewrite.cc 1
2 files changed, 13 insertions, 12 deletions
diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD
index 62996d1fd8..225025e995 100644
--- a/tensorflow/contrib/nccl/BUILD
+++ b/tensorflow/contrib/nccl/BUILD
@@ -25,15 +25,17 @@ tf_custom_op_library(
name = "python/ops/_nccl_ops.so",
srcs = [
"ops/nccl_ops.cc",
- ],
+ ] + if_cuda(["kernels/nccl_rewrite.cc"]),
gpu_srcs = if_not_windows_cuda([
"kernels/nccl_manager.cc",
"kernels/nccl_manager.h",
"kernels/nccl_ops.cc",
]),
- deps = if_cuda([
+ deps = [] + if_cuda([
"@local_config_nccl//:nccl",
"//tensorflow/core:gpu_headers_lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:protos_all_proto_text",
]),
)
@@ -57,32 +59,30 @@ tf_cuda_cc_test(
"notap",
],
deps =
- [
+ if_cuda([
+ "@local_config_nccl//:nccl",
"//tensorflow/core:cuda",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
- "@local_config_nccl//:nccl",
- ],
+ ]),
)
tf_kernel_library(
name = "nccl_kernels",
- srcs = [
+ srcs = if_cuda([
"kernels/nccl_manager.cc",
"kernels/nccl_manager.h",
"kernels/nccl_ops.cc",
- "kernels/nccl_rewrite.cc",
- ],
- deps = [
+ ]),
+ deps = if_cuda([
+ "@local_config_nccl//:nccl",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib",
- "//tensorflow/core:proto_text",
"//tensorflow/core:stream_executor",
- "@local_config_nccl//:nccl",
- ],
+ ]),
alwayslink = 1,
)
diff --git a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
index 4676e937e5..06ff86e6d8 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/graph/node_builder.h"
namespace tensorflow {