diff options
-rw-r--r-- | tensorflow/contrib/BUILD | 7 | ||||
-rw-r--r-- | tensorflow/contrib/__init__.py | 1 | ||||
-rw-r--r-- | tensorflow/contrib/cmake/tf_core_kernels.cmake | 3 | ||||
-rw-r--r-- | tensorflow/contrib/cmake/tf_core_ops.cmake | 1 | ||||
-rwxr-xr-x | tensorflow/contrib/cmake/tf_python.cmake | 7 | ||||
-rw-r--r-- | tensorflow/contrib/nccl/__init__.py | 22 | ||||
-rwxr-xr-x | tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh | 11 |
7 files changed, 32 insertions, 20 deletions
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 595d899738..a726471d0f 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -7,8 +7,6 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) -load("//tensorflow:tensorflow.bzl", "if_not_windows") - py_library( name = "contrib_py", srcs = glob(["**/*.py"]), @@ -46,6 +44,7 @@ py_library( "//tensorflow/contrib/losses:losses_py", "//tensorflow/contrib/memory_stats:memory_stats_py", "//tensorflow/contrib/metrics:metrics_py", + "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/contrib/ndlstm", "//tensorflow/contrib/nn:nn_py", "//tensorflow/contrib/opt:opt_py", @@ -65,9 +64,7 @@ py_library( "//tensorflow/contrib/tfprof", "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", - ] + if_not_windows([ - "//tensorflow/contrib/nccl:nccl_py", - ]), + ], ) cc_library( diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index d4ddd1cf6a..9b703cf090 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -45,6 +45,7 @@ from tensorflow.contrib import lookup from tensorflow.contrib import losses from tensorflow.contrib import memory_stats from tensorflow.contrib import metrics +from tensorflow.contrib import nccl from tensorflow.contrib import nn from tensorflow.contrib import opt from tensorflow.contrib import quantization diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index 33384eed48..0663ba1637 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -37,6 +37,9 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" + "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/blas_gemm.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/gru_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/lstm_ops.cc" diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index 4e30005629..126ef6c00c 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -58,6 +58,7 @@ GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/i GENERATE_CONTRIB_OP_LIBRARY(layers_bucketization "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/bucketization_op.cc") GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(rnn_gru "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/gru_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(rnn_lstm "${tensorflow_source_dir}/tensorflow/contrib/rnn/ops/lstm_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(tensor_forest "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/ops/tensor_forest_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 02038da7f8..37bdcec086 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -405,6 +405,11 @@ add_python_module("tensorflow/contrib/ndlstm/python") add_python_module("tensorflow/contrib/nn") add_python_module("tensorflow/contrib/nn/python") add_python_module("tensorflow/contrib/nn/python/ops") +add_python_module("tensorflow/contrib/nccl") +add_python_module("tensorflow/contrib/nccl/kernels") +add_python_module("tensorflow/contrib/nccl/ops") +add_python_module("tensorflow/contrib/nccl/python") +add_python_module("tensorflow/contrib/nccl/python/ops") add_python_module("tensorflow/contrib/opt") add_python_module("tensorflow/contrib/opt/python") add_python_module("tensorflow/contrib/opt/python/training") @@ -599,6 +604,8 @@ GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py) GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/memory_stats/ops/gen_memory_stats_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_nccl_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/nccl/ops/gen_nccl_ops.py) GENERATE_PYTHON_OP_LIB("contrib_rnn_gru_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/rnn/ops/gen_gru_ops.py) GENERATE_PYTHON_OP_LIB("contrib_rnn_lstm_ops" diff --git a/tensorflow/contrib/nccl/__init__.py b/tensorflow/contrib/nccl/__init__.py index 0275ed6079..d851c522c0 100644 --- a/tensorflow/contrib/nccl/__init__.py +++ b/tensorflow/contrib/nccl/__init__.py @@ -12,13 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Ops for nccl AllReduce.""" +"""Functions for using NVIDIA nccl collective ops. + +@@all_max +@@all_min +@@all_prod +@@all_sum +@@broadcast + +""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import -# pylint: disable=wildcard-import -from tensorflow.contrib.nccl.python.ops.nccl_ops import * -# pylint: enable=wildcard-import +from tensorflow.contrib.nccl.python.ops.nccl_ops import all_max +from tensorflow.contrib.nccl.python.ops.nccl_ops import all_min +from tensorflow.contrib.nccl.python.ops.nccl_ops import all_prod +from tensorflow.contrib.nccl.python.ops.nccl_ops import all_sum +from tensorflow.contrib.nccl.python.ops.nccl_ops import broadcast + +from tensorflow.python.util.all_util import remove_undocumented +remove_undocumented(__name__) diff --git a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh index b428bebc6f..a08a3c2874 100755 --- a/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh +++ b/tensorflow/tools/ci_build/windows/libtensorflow_cpu.sh @@ -31,15 +31,6 @@ if [ ! -e "WORKSPACE" ]; then exit 1 fi -#### BEGIN HACKS TO BE RESOLVED WITH NEWER BAZEL VERSIONS #### -# Disable nccl. -# This can be removed once we switch to a bazel release that includes -# https://github.com/bazelbuild/bazel/commit/8e0991cb19eadfcb651cd6987255d5f7c0a58e0a -# (the fix for https://github.com/bazelbuild/bazel/issues/2494). -# Most likley bazel 0.4.5 will contain that. -sed -i -e "s/\"@nccl_archive/#\"@nccl_archive/" ./tensorflow/contrib/nccl/BUILD -sed -i -e "s/\"@nccl_archive/#\"@nccl_archive/" ./tensorflow/tools/pip_package/BUILD - # Enable JNI support for Windows in Bazel. # This can be removed once # https://github.com/bazelbuild/bazel/pull/2599 @@ -66,7 +57,7 @@ bazel build -c opt ${BUILD_OPTS} \ tensorflow/tools/lib_package:jnilicenses_generate # Revert the hacks above -git checkout ./tensorflow/contrib/nccl/BUILD ./tensorflow/tools/pip_package/BUILD +git checkout ./tensorflow/tools/pip_package/BUILD git checkout ./tensorflow/java/src/main/native/BUILD rm -f ./tensorflow/java/src/main/native/windows_jni_md.h |