diff options
author | Eugene Brevdo <ebrevdo@google.com> | 2018-09-11 10:41:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-11 10:51:01 -0700 |
commit | 36e1a5ea5ba2dd5eaa7f4cfc84a61f8ce3ea20e1 (patch) | |
tree | 4f1671f78f5971b02dc2af66f57eabbf01005112 /tensorflow/contrib/nccl | |
parent | 36d7b12357df667dcd427c070e21779ed83f4ec9 (diff) |
[TF] Variant improvements.
1. Change Variant Decode to accept VariantTensorData by value (non-ref).
This should allow some optimization in the future.
In the meantime it means removing the variant.h include from tensor.h, since
variant_encode_decode.h now relies on tensor.h, and variant.h in turn relies on
variant_encode_decode.h.
It also means we found a number of places where tensor.pb.h, variant.h, and
mutex.h were being included only transitively through tensor.h (along with many
other unrelated headers); those places now include them directly in order to compile.
2. Move Variant registry to use TypeIndex instead of a TypeName string; this should
speed up registry lookups.
PiperOrigin-RevId: 212478896
Diffstat (limited to 'tensorflow/contrib/nccl')
-rw-r--r-- | tensorflow/contrib/nccl/BUILD | 24 | ||||
-rw-r--r-- | tensorflow/contrib/nccl/kernels/nccl_rewrite.cc | 1 |
2 files changed, 13 insertions, 12 deletions
```diff
diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD
index 62996d1fd8..225025e995 100644
--- a/tensorflow/contrib/nccl/BUILD
+++ b/tensorflow/contrib/nccl/BUILD
@@ -25,15 +25,17 @@ tf_custom_op_library(
     name = "python/ops/_nccl_ops.so",
     srcs = [
         "ops/nccl_ops.cc",
-    ],
+    ] + if_cuda(["kernels/nccl_rewrite.cc"]),
     gpu_srcs = if_not_windows_cuda([
         "kernels/nccl_manager.cc",
         "kernels/nccl_manager.h",
         "kernels/nccl_ops.cc",
     ]),
-    deps = if_cuda([
+    deps = [] + if_cuda([
         "@local_config_nccl//:nccl",
         "//tensorflow/core:gpu_headers_lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:protos_all_proto_text",
     ]),
 )
 
@@ -57,32 +59,30 @@ tf_cuda_cc_test(
         "notap",
     ],
     deps =
-        [
+        if_cuda([
+            "@local_config_nccl//:nccl",
             "//tensorflow/core:cuda",
             "//tensorflow/core:test",
             "//tensorflow/core:test_main",
             "//tensorflow/core:testlib",
-            "@local_config_nccl//:nccl",
-        ],
+        ]),
 )
 
 tf_kernel_library(
     name = "nccl_kernels",
-    srcs = [
+    srcs = if_cuda([
         "kernels/nccl_manager.cc",
         "kernels/nccl_manager.h",
         "kernels/nccl_ops.cc",
-        "kernels/nccl_rewrite.cc",
-    ],
-    deps = [
+    ]),
+    deps = if_cuda([
+        "@local_config_nccl//:nccl",
         "//tensorflow/core:core_cpu",
         "//tensorflow/core:framework",
         "//tensorflow/core:gpu_headers_lib",
         "//tensorflow/core:lib",
-        "//tensorflow/core:proto_text",
         "//tensorflow/core:stream_executor",
-        "@local_config_nccl//:nccl",
-    ],
+    ]),
     alwayslink = 1,
 )
 
diff --git a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
index 4676e937e5..06ff86e6d8 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include <vector>
 
 #include "tensorflow/core/common_runtime/optimization_registry.h"
+#include "tensorflow/core/framework/tensor.pb.h"
 #include "tensorflow/core/graph/node_builder.h"
 
 namespace tensorflow {
```