author     Joel Hestness <jthestness@gmail.com>      2017-12-26 18:51:40 -0800
committer  drpngx <drpngx@users.noreply.github.com>  2017-12-26 18:51:40 -0800
commit     f5a27328adafacb8d88bb62df835fc34cd7ed46c (patch)
tree       afb0b38b14b26351bd76be0d2d12b814c3c9159a /tensorflow/contrib/mpi_collectives
parent     1f94604944ce48036d8a5fd6f8b1b24ca36be953 (diff)
mpi_collectives: Refactor to fix build issues (#15534)
* mpi_collectives: Refactor to fix build issues

  After TF commit 5c7f9e3, the mpi_collectives package would no longer build
  ops and kernels. This build issue caused the mpi_collectives import to fail
  in Python with the following error: "NameError: Could not find operator
  MPISize in dynamic library mpi_collectives.so". To fix this issue, add build
  targets to ensure both ops and kernels are built. Note: the build targets
  and directory structure were also refactored to more closely match other
  contrib packages.

* mpi_collectives: Minor BUILD fix
* Minor: Buildifier fix
* mpi_collectives: Correct preprocessor defines
* mpi_collectives: Clearer defines inclusion
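The failure mode described in the first bullet can be checked with a short script. This is only a sketch: it assumes TensorFlow 1.x built with MPI support and an mpirun launch; the script name and process count are illustrative, not part of the patch.

```python
# check_mpi.py -- sanity check for the rebuilt package (illustrative name).
# Run as e.g.: mpirun -np 2 python check_mpi.py
import tensorflow.contrib.mpi_collectives as mpi

# Before this patch, the import above failed at module load time with:
#   NameError: Could not find operator MPISize in dynamic library mpi_collectives.so
with mpi.Session() as session:
    print("MPI Rank:", session.run(mpi.rank()))
    print("MPI Size:", session.run(mpi.size()))
```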
Diffstat (limited to 'tensorflow/contrib/mpi_collectives')
 -rw-r--r--  tensorflow/contrib/mpi_collectives/BUILD                                                                   | 124
 -rw-r--r--  tensorflow/contrib/mpi_collectives/__init__.py                                                             |  18
 -rw-r--r--  tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc (renamed from tensorflow/contrib/mpi_collectives/mpi_ops.cc) | 110
 -rw-r--r--  tensorflow/contrib/mpi_collectives/kernels/ring.cc (renamed from tensorflow/contrib/mpi_collectives/ring.cc)       |   6
 -rw-r--r--  tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc (renamed from tensorflow/contrib/mpi_collectives/ring.cu.cc) |   6
 -rw-r--r--  tensorflow/contrib/mpi_collectives/kernels/ring.h (renamed from tensorflow/contrib/mpi_collectives/ring.h)         |   4
 -rw-r--r--  tensorflow/contrib/mpi_collectives/mpi_message.proto                                                       |   2
 -rw-r--r--  tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc                                                          | 132
 -rw-r--r--  tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py (renamed from tensorflow/contrib/mpi_collectives/mpi_ops.py) |  53
 9 files changed, 257 insertions, 198 deletions
diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD
index 11c5d6e776..9f9802b8fe 100644
--- a/tensorflow/contrib/mpi_collectives/BUILD
+++ b/tensorflow/contrib/mpi_collectives/BUILD
@@ -6,20 +6,9 @@ package(default_visibility = [
licenses(["notice"]) # Apache 2.0
-filegroup(
- name = "all_files",
- srcs = glob(
- ["**/*"],
- exclude = [
- "**/METADATA",
- "**/OWNERS",
- ],
- ),
- visibility = ["//tensorflow:__subpackages__"],
-)
-
load(
"//tensorflow/core:platform/default/build_config.bzl",
+ "tf_additional_mpi_lib_defines",
"tf_proto_library_cc",
)
@@ -33,26 +22,98 @@ tf_proto_library_cc(
],
)
-load("//tensorflow:tensorflow.bzl", "tf_custom_op_library")
-load("//tensorflow:tensorflow.bzl", "tf_py_test")
+cc_library(
+ name = "mpi_defines",
+ defines = tf_additional_mpi_lib_defines(),
+)
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_custom_op_py_library",
+ "tf_custom_op_library",
+ "tf_gen_op_wrapper_py",
+ "tf_gen_op_libs",
+ "tf_kernel_library",
+ "tf_py_test",
+)
tf_custom_op_library(
- name = "mpi_collectives.so",
+ name = "python/ops/_mpi_ops.so",
srcs = [
- "mpi_ops.cc",
- "ring.cc",
- "ring.h",
+ "kernels/mpi_ops.cc",
+ "kernels/ring.cc",
+ "kernels/ring.h",
+ "ops/mpi_ops.cc",
],
gpu_srcs = [
- "ring.cu.cc",
- "ring.h",
+ "kernels/ring.cu.cc",
+ "kernels/ring.h",
],
deps = [
+ ":mpi_defines",
":mpi_message_proto_cc",
"//third_party/mpi",
],
)
+tf_kernel_library(
+ name = "mpi_ops_kernels",
+ srcs = [
+ "kernels/mpi_ops.cc",
+ "kernels/ring.cc",
+ ],
+ hdrs = [
+ "kernels/ring.h",
+ ],
+ gpu_srcs = [
+ "kernels/ring.cu.cc",
+ ],
+ deps = [
+ ":mpi_defines",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:gpu_headers_lib",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:proto_text",
+ "//tensorflow/core:stream_executor",
+ ],
+ # TODO: Include? alwayslink = 1,
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["mpi_ops"],
+)
+
+tf_gen_op_wrapper_py(
+ name = "mpi_ops",
+ deps = [":mpi_ops_op_lib"],
+)
+
+tf_custom_op_py_library(
+ name = "mpi_collectives_py",
+ srcs = [
+ "__init__.py",
+ "python/ops/mpi_ops.py",
+ ],
+ dso = [
+ ":python/ops/_mpi_ops.so",
+ ],
+ kernels = [
+ ":mpi_ops_kernels",
+ ":mpi_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":mpi_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:device",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:util",
+ ],
+)
+
tf_py_test(
name = "mpi_ops_test",
srcs = ["mpi_ops_test.py"],
@@ -61,20 +122,19 @@ tf_py_test(
"//tensorflow/python:platform",
],
data = [
- ":mpi_collectives.so",
+ ":python/ops/_mpi_ops.so",
],
tags = ["manual"],
)
-py_library(
- name = "mpi_ops_py",
- srcs = [
- "__init__.py",
- "mpi_ops.py",
- ],
- data = [
- ":mpi_collectives.so",
- ],
- srcs_version = "PY2AND3",
- visibility = ["//visibility:public"],
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
)
diff --git a/tensorflow/contrib/mpi_collectives/__init__.py b/tensorflow/contrib/mpi_collectives/__init__.py
index 9ed16a6f07..52029cbc36 100644
--- a/tensorflow/contrib/mpi_collectives/__init__.py
+++ b/tensorflow/contrib/mpi_collectives/__init__.py
@@ -37,7 +37,7 @@ for detecting the running MPI configuration.
Example:
```python
-from tensorflow.contrib import mpi
+import tensorflow.contrib.mpi_collectives as mpi
# Use `mpi.Session` instead of `tf.Session`
with mpi.Session() as session:
@@ -48,8 +48,10 @@ with mpi.Session() as session:
print("MPI Size:", session.run(mpi.size()))
```
-@@rank
+@@init
@@size
+@@rank
+@@local_rank
### Ring Allreduce and Allgather
@@ -123,12 +125,12 @@ from __future__ import print_function
import tensorflow as tf
-from tensorflow.contrib.mpi_collectives.mpi_ops import size
-from tensorflow.contrib.mpi_collectives.mpi_ops import rank
-from tensorflow.contrib.mpi_collectives.mpi_ops import local_rank
-from tensorflow.contrib.mpi_collectives.mpi_ops import allgather
-from tensorflow.contrib.mpi_collectives.mpi_ops import _allreduce
-from tensorflow.contrib.mpi_collectives.mpi_ops import init
+from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import init
+from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import size
+from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import rank
+from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import local_rank
+from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import allgather
+from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import _allreduce
def allreduce(tensor, average=True):
diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc
index a051ab0004..2d5b98022c 100644
--- a/tensorflow/contrib/mpi_collectives/mpi_ops.cc
+++ b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc
@@ -21,7 +21,6 @@ limitations under the License.
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
-#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/mutex.h"
@@ -37,7 +36,7 @@ limitations under the License.
#define OMPI_SKIP_MPICXX
#include "third_party/mpi/mpi.h"
#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h"
-#include "tensorflow/contrib/mpi_collectives/ring.h"
+#include "tensorflow/contrib/mpi_collectives/kernels/ring.h"
/*
* MPI Allreduce and Allgather Ops for TensorFlow.
@@ -81,7 +80,7 @@ using GPUDevice = Eigen::GpuDevice;
namespace tensorflow {
namespace contrib {
-namespace mpi {
+namespace mpi_collectives {
// Make sure template specializations are generated in the ring.cu.cc and the
// ring.cc file, not in this file.
@@ -877,14 +876,6 @@ REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU),
MPIInitOp<GPUDevice>);
#endif
-REGISTER_OP("MPIInit").Doc(R"doc(
-Initialize MPI for the current process.
-
-If this is run on a GPU, then that GPU must be used for all future MPI
-operations. If it is run on CPU, then all future MPI operations must also
-run on CPU.
-)doc");
-
// Op to get the current MPI Size.
template <typename Device>
class MPISizeOp : public OpKernel {
@@ -911,21 +902,6 @@ REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"),
MPISizeOp<GPUDevice>);
#endif
-REGISTER_OP("MPISize")
- .Output("size: int32")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- c->set_output(0, c->Scalar());
- return Status::OK();
- })
- .Doc(R"doc(
-Returns the number of running MPI processes.
-
-More precisely, returns the number of MPI processes in the group associated
-with the MPI_COMM_WORLD communicator.
-
-size: Size of the MPI group.
-)doc");
-
// Op to get the current MPI Rank.
template <typename Device>
class MPIRankOp : public OpKernel {
@@ -952,21 +928,6 @@ REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"),
MPIRankOp<GPUDevice>);
#endif
-REGISTER_OP("MPIRank")
- .Output("rank: int32")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- c->set_output(0, c->Scalar());
- return Status::OK();
- })
- .Doc(R"doc(
-Returns the index of the current process in the MPI group.
-
-More precisely, returns the rank of the calling process in the MPI_COMM_WORLD
-communicator.
-
-rank: Rank of the calling process.
-)doc");
-
// Op to get the current local MPI Rank.
template <typename Device>
class MPILocalRankOp : public OpKernel {
@@ -994,21 +955,6 @@ REGISTER_KERNEL_BUILDER(
MPILocalRankOp<GPUDevice>);
#endif
-REGISTER_OP("MPILocalRank")
- .Output("rank: int32")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- c->set_output(0, c->Scalar());
- return Status::OK();
- })
- .Doc(R"doc(
-Returns the index of the current process in the node it is on.
-
-More precisely, returns the rank of the calling process in communicator that
-only spans the MPI processes running on that node.
-
-rank: Rank of the calling process on the node it is on.
-)doc");
-
template <typename Device>
class MPIAllreduceOp : public AsyncOpKernel {
public:
@@ -1083,28 +1029,6 @@ REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU),
MPIAllreduceOp<GPUDevice>);
#endif
-REGISTER_OP("MPIAllreduce")
- .Attr("T: {int32, int64, float32}")
- .Input("tensor: T")
- .Output("sum: T")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- c->set_output(0, c->input(0));
- return Status::OK();
- })
- .Doc(R"doc(
-Perform an MPI Allreduce on a tensor. All other processes that do a reduction
-on a tensor with the same name must have the same dimension for that tensor.
-Tensors are reduced with other tensors that have the same node name for the
-allreduce.
-
-Arguments
- tensor: A tensor to reduce.
-
-Output
- sum: A tensor with the same shape as `tensor`, summed across all
- MPI processes.
-)doc");
-
template <typename Device>
class MPIAllgatherOp : public AsyncOpKernel {
public:
@@ -1192,34 +1116,6 @@ class MPIAllgatherOp : public AsyncOpKernel {
}
};
-REGISTER_OP("MPIAllgather")
- .Attr("T: {int32, int64, float32}")
- .Attr("S: {int64}")
- .Input("tensor: T")
- .Input("sizes: S")
- .Output("gathered: T")
- .SetShapeFn([](shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle output;
- TF_RETURN_IF_ERROR(
- c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output));
- c->set_output(0, output);
- return Status::OK();
- })
- .Doc(R"doc(
-Perform an MPI Allgather on a tensor. All other processes that do a gather on a
-tensor with the same name must have the same rank for that tensor, and have the
-same dimension on all but the first dimension.
-
-Arguments
- tensor: A tensor to gather.
- sizes: A tensor containing the first-dimension sizes of tensors to be
- gathered from other ranks
-
-Output
- gathered: A tensor with the same shape as `tensor` except for the first
- dimension, which is the sum of dimensions in `sizes`.
-)doc");
-
REGISTER_KERNEL_BUILDER(
Name("MPIAllgather").Device(DEVICE_CPU).HostMemory("sizes"),
MPIAllgatherOp<CPUDevice>);
@@ -1229,7 +1125,7 @@ REGISTER_KERNEL_BUILDER(
MPIAllgatherOp<GPUDevice>);
#endif
-} // namespace mpi
+} // namespace mpi_collectives
} // namespace contrib
} // namespace tensorflow
diff --git a/tensorflow/contrib/mpi_collectives/ring.cc b/tensorflow/contrib/mpi_collectives/kernels/ring.cc
index d93233eb21..8970ceb1a2 100644
--- a/tensorflow/contrib/mpi_collectives/ring.cc
+++ b/tensorflow/contrib/mpi_collectives/kernels/ring.cc
@@ -17,11 +17,11 @@ limitations under the License.
#define EIGEN_USE_THREADS
-#include "tensorflow/contrib/mpi_collectives/ring.h"
+#include "tensorflow/contrib/mpi_collectives/kernels/ring.h"
namespace tensorflow {
namespace contrib {
-namespace mpi {
+namespace mpi_collectives {
using CPUDevice = Eigen::ThreadPoolDevice;
@@ -73,7 +73,7 @@ GENERATE_ACCUMULATE(long long);
GENERATE_ACCUMULATE(float);
#undef GENERATE_ACCUMULATE
-} // namespace mpi
+} // namespace mpi_collectives
} // namespace contrib
} // namespace tensorflow
diff --git a/tensorflow/contrib/mpi_collectives/ring.cu.cc b/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc
index 2f3eef366a..b04abde469 100644
--- a/tensorflow/contrib/mpi_collectives/ring.cu.cc
+++ b/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc
@@ -19,11 +19,11 @@ limitations under the License.
#define EIGEN_USE_GPU
-#include "tensorflow/contrib/mpi_collectives/ring.h"
+#include "tensorflow/contrib/mpi_collectives/kernels/ring.h"
namespace tensorflow {
namespace contrib {
-namespace mpi {
+namespace mpi_collectives {
using CPUDevice = Eigen::ThreadPoolDevice;
@@ -109,7 +109,7 @@ GENERATE_ACCUMULATE(long long);
GENERATE_ACCUMULATE(float);
#undef GENERATE_ACCUMULATE
-} // namespace mpi
+} // namespace mpi_collectives
} // namespace contrib
} // namespace tensorflow
#endif // GOOGLE_CUDA
diff --git a/tensorflow/contrib/mpi_collectives/ring.h b/tensorflow/contrib/mpi_collectives/kernels/ring.h
index cae57ce60e..1d56d588bc 100644
--- a/tensorflow/contrib/mpi_collectives/ring.h
+++ b/tensorflow/contrib/mpi_collectives/kernels/ring.h
@@ -37,7 +37,7 @@ limitations under the License.
namespace tensorflow {
namespace contrib {
-namespace mpi {
+namespace mpi_collectives {
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
@@ -317,7 +317,7 @@ Status RingAllgather(OpKernelContext* context, const Tensor* input,
return Status::OK();
}
-} // namespace mpi
+} // namespace mpi_collectives
} // namespace contrib
} // namespace tensorflow
diff --git a/tensorflow/contrib/mpi_collectives/mpi_message.proto b/tensorflow/contrib/mpi_collectives/mpi_message.proto
index 7fa5e20301..afbce981ae 100644
--- a/tensorflow/contrib/mpi_collectives/mpi_message.proto
+++ b/tensorflow/contrib/mpi_collectives/mpi_message.proto
@@ -15,7 +15,7 @@ limitations under the License.
syntax = "proto3";
-package tensorflow.contrib.mpi;
+package tensorflow.contrib.mpi_collectives;
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
diff --git a/tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc
new file mode 100644
index 0000000000..18e6bb61cf
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc
@@ -0,0 +1,132 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi_collectives {
+
+REGISTER_OP("MPIInit").Doc(R"doc(
+Initialize MPI for the current process.
+
+If this is run on a GPU, then that GPU must be used for all future MPI
+operations. If it is run on CPU, then all future MPI operations must also
+run on CPU.
+)doc");
+
+REGISTER_OP("MPISize")
+ .Output("size: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the number of running MPI processes.
+
+More precisely, returns the number of MPI processes in the group associated
+with the MPI_COMM_WORLD communicator.
+
+size: Size of the MPI group.
+)doc");
+
+REGISTER_OP("MPIRank")
+ .Output("rank: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the index of the current process in the MPI group.
+
+More precisely, returns the rank of the calling process in the MPI_COMM_WORLD
+communicator.
+
+rank: Rank of the calling process.
+)doc");
+
+REGISTER_OP("MPILocalRank")
+ .Output("rank: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the index of the current process in the node it is on.
+
+More precisely, returns the rank of the calling process in communicator that
+only spans the MPI processes running on that node.
+
+rank: Rank of the calling process on the node it is on.
+)doc");
+
+REGISTER_OP("MPIAllreduce")
+ .Attr("T: {int32, int64, float32}")
+ .Input("tensor: T")
+ .Output("sum: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Perform an MPI Allreduce on a tensor. All other processes that do a reduction
+on a tensor with the same name must have the same dimension for that tensor.
+Tensors are reduced with other tensors that have the same node name for the
+allreduce.
+
+Arguments
+ tensor: A tensor to reduce.
+
+Output
+ sum: A tensor with the same shape as `tensor`, summed across all
+ MPI processes.
+)doc");
+
+REGISTER_OP("MPIAllgather")
+ .Attr("T: {int32, int64, float32}")
+ .Attr("S: {int64}")
+ .Input("tensor: T")
+ .Input("sizes: S")
+ .Output("gathered: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle output;
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output));
+ c->set_output(0, output);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Perform an MPI Allgather on a tensor. All other processes that do a gather on a
+tensor with the same name must have the same rank for that tensor, and have the
+same dimension on all but the first dimension.
+
+Arguments
+ tensor: A tensor to gather.
+ sizes: A tensor containing the first-dimension sizes of tensors to be
+ gathered from other ranks
+
+Output
+ gathered: A tensor with the same shape as `tensor` except for the first
+ dimension, which is the sum of dimensions in `sizes`.
+)doc");
+
+} // namespace mpi_collectives
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.py b/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py
index 81567cc688..f0a116239d 100644
--- a/tensorflow/contrib/mpi_collectives/mpi_ops.py
+++ b/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py
@@ -20,44 +20,13 @@ from __future__ import print_function
import tensorflow as tf
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import load_library
+from tensorflow.contrib.mpi_collectives.ops import gen_mpi_ops
+from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
-from tensorflow.python.platform import tf_logging as logging
-
-
-def _load_library(name, op_list=None):
- """Loads a .so file containing the specified operators.
-
- Args:
- name: The name of the .so file to load.
- op_list: A list of names of operators that the library should have. If None
- then the .so file's contents will not be verified.
-
- Raises:
- NameError if one of the required ops is missing.
- """
- try:
- filename = resource_loader.get_path_to_datafile(name)
- library = load_library.load_op_library(filename)
- for expected_op in (op_list or []):
- for lib_op in library.OP_LIST.op:
- if lib_op.name == expected_op:
- break
- else:
- raise NameError(
- 'Could not find operator %s in dynamic library %s' %
- (expected_op, name))
- return library
- except errors.NotFoundError:
- logging.warning('%s file could not be loaded.', name)
-
-
-MPI_LIB = _load_library('mpi_collectives.so', ['MPISize', 'MPIRank',
- 'MPILocalRank', 'MPIAllgather',
- 'MPIAllreduce'])
+_mpi_ops_so = loader.load_op_library(
+ resource_loader.get_path_to_datafile("_mpi_ops.so"))
def size(name=None):
"""An op which returns the number of MPI processes.
@@ -68,7 +37,7 @@ def size(name=None):
Returns:
An integer scalar containing the number of MPI processes.
"""
- return MPI_LIB.mpi_size(name=name)
+ return gen_mpi_ops.mpi_size(name=name)
ops.NotDifferentiable('MPISize')
@@ -83,7 +52,7 @@ def rank(name=None):
Returns:
An integer scalar with the MPI rank of the calling process.
"""
- return MPI_LIB.mpi_rank(name=name)
+ return gen_mpi_ops.mpi_rank(name=name)
ops.NotDifferentiable('MPIRank')
@@ -95,7 +64,7 @@ def init(name=None):
All future MPI ops must be run on the same device that the `init` op was run
on.
"""
- return MPI_LIB.mpi_init(name=name)
+ return gen_mpi_ops.mpi_init(name=name)
ops.NotDifferentiable('MPIInit')
@@ -112,7 +81,7 @@ def local_rank(name=None):
Returns:
An integer scalar with the local MPI rank of the calling process.
"""
- return MPI_LIB.mpi_local_rank(name=name)
+ return gen_mpi_ops.mpi_local_rank(name=name)
ops.NotDifferentiable('MPILocalRank')
@@ -129,7 +98,7 @@ def _allreduce(tensor, name=None):
A tensor of the same shape and type as `tensor`, summed across all
processes.
"""
- return MPI_LIB.mpi_allreduce(tensor, name=name)
+ return gen_mpi_ops.mpi_allreduce(tensor, name=name)
ops.NotDifferentiable('MPIAllreduce')
@@ -156,8 +125,8 @@ def allgather(tensor, name=None):
if name is None:
name = "allgather"
sizing_name = "{}_sizing".format(name)
- sizes = MPI_LIB.mpi_allgather(my_size, sizes_flag, name=sizing_name)
- return MPI_LIB.mpi_allgather(tensor, sizes, name=name)
+ sizes = gen_mpi_ops.mpi_allgather(my_size, sizes_flag, name=sizing_name)
+ return gen_mpi_ops.mpi_allgather(tensor, sizes, name=name)
ops.NotDifferentiable('MPIAllgather')
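The refactor leaves the public Python surface (init, size, rank, local_rank, allreduce, allgather) unchanged; only the wrappers now call into the generated gen_mpi_ops module instead of the hand-loaded MPI_LIB. A minimal usage sketch follows, assuming a two-process mpirun launch and a build with MPI support; the shapes and values are illustrative and not taken from the patch.

```python
# example.py -- illustrative use of the unchanged public API.
# Run as e.g.: mpirun -np 2 python example.py
import tensorflow as tf
import tensorflow.contrib.mpi_collectives as mpi

with mpi.Session() as session:
    rank = session.run(mpi.rank())
    # Allreduce: every process contributes a tensor of the same shape;
    # average=False returns the elementwise sum across ranks.
    summed = mpi.allreduce(tf.ones([4]) * float(rank), average=False)
    # Allgather: first dimensions may differ across ranks; results are
    # concatenated along dimension 0 (here rank r contributes r + 1 rows).
    gathered = mpi.allgather(tf.fill([rank + 1, 3], float(rank)))
    print("sum:", session.run(summed))
    print("gathered shape:", session.run(gathered).shape)
```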