diff options
author | 2017-12-26 18:51:40 -0800 | |
---|---|---|
committer | 2017-12-26 18:51:40 -0800 | |
commit | f5a27328adafacb8d88bb62df835fc34cd7ed46c (patch) | |
tree | afb0b38b14b26351bd76be0d2d12b814c3c9159a /tensorflow/contrib/mpi_collectives | |
parent | 1f94604944ce48036d8a5fd6f8b1b24ca36be953 (diff) |
mpi_collectives: Refactor to fix build issues (#15534)
* mpi_collectives: Refactor to fix build issues
After TF commit 5c7f9e3, the mpi_collectives package would no longer build
ops and kernels. This build issue caused mpi_collectives import to fail in
Python with the following error: "NameError: Could not find operator MPISize
in dynamic library mpi_collectives.so".
To fix this issue, add build targets to ensure both ops and kernels are built.
Note: this change also refactors the build targets and directory structure to
more closely match other contrib packages.
* mpi_collectives: Minor BUILD fix
* Minor: Buildifier fix
* mpi_collectives: Correct preprocessor defines
* mpi_collectives: Clearer defines inclusion
Diffstat (limited to 'tensorflow/contrib/mpi_collectives')
-rw-r--r-- | tensorflow/contrib/mpi_collectives/BUILD | 124 | ||||
-rw-r--r-- | tensorflow/contrib/mpi_collectives/__init__.py | 18 | ||||
-rw-r--r-- | tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc (renamed from tensorflow/contrib/mpi_collectives/mpi_ops.cc) | 110 | ||||
-rw-r--r-- | tensorflow/contrib/mpi_collectives/kernels/ring.cc (renamed from tensorflow/contrib/mpi_collectives/ring.cc) | 6 | ||||
-rw-r--r-- | tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc (renamed from tensorflow/contrib/mpi_collectives/ring.cu.cc) | 6 | ||||
-rw-r--r-- | tensorflow/contrib/mpi_collectives/kernels/ring.h (renamed from tensorflow/contrib/mpi_collectives/ring.h) | 4 | ||||
-rw-r--r-- | tensorflow/contrib/mpi_collectives/mpi_message.proto | 2 | ||||
-rw-r--r-- | tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc | 132 | ||||
-rw-r--r-- | tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py (renamed from tensorflow/contrib/mpi_collectives/mpi_ops.py) | 53 |
9 files changed, 257 insertions, 198 deletions
diff --git a/tensorflow/contrib/mpi_collectives/BUILD b/tensorflow/contrib/mpi_collectives/BUILD index 11c5d6e776..9f9802b8fe 100644 --- a/tensorflow/contrib/mpi_collectives/BUILD +++ b/tensorflow/contrib/mpi_collectives/BUILD @@ -6,20 +6,9 @@ package(default_visibility = [ licenses(["notice"]) # Apache 2.0 -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), - visibility = ["//tensorflow:__subpackages__"], -) - load( "//tensorflow/core:platform/default/build_config.bzl", + "tf_additional_mpi_lib_defines", "tf_proto_library_cc", ) @@ -33,26 +22,98 @@ tf_proto_library_cc( ], ) -load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") -load("//tensorflow:tensorflow.bzl", "tf_py_test") +cc_library( + name = "mpi_defines", + defines = tf_additional_mpi_lib_defines(), +) + +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_py_library", + "tf_custom_op_library", + "tf_gen_op_wrapper_py", + "tf_gen_op_libs", + "tf_kernel_library", + "tf_py_test", +) tf_custom_op_library( - name = "mpi_collectives.so", + name = "python/ops/_mpi_ops.so", srcs = [ - "mpi_ops.cc", - "ring.cc", - "ring.h", + "kernels/mpi_ops.cc", + "kernels/ring.cc", + "kernels/ring.h", + "ops/mpi_ops.cc", ], gpu_srcs = [ - "ring.cu.cc", - "ring.h", + "kernels/ring.cu.cc", + "kernels/ring.h", ], deps = [ + ":mpi_defines", ":mpi_message_proto_cc", "//third_party/mpi", ], ) +tf_kernel_library( + name = "mpi_ops_kernels", + srcs = [ + "kernels/mpi_ops.cc", + "kernels/ring.cc", + ], + hdrs = [ + "kernels/ring.h", + ], + gpu_srcs = [ + "kernels/ring.cu.cc", + ], + deps = [ + ":mpi_defines", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:lib", + "//tensorflow/core:proto_text", + "//tensorflow/core:stream_executor", + ], + # TODO: Include? 
alwayslink = 1, +) + +tf_gen_op_libs( + op_lib_names = ["mpi_ops"], +) + +tf_gen_op_wrapper_py( + name = "mpi_ops", + deps = [":mpi_ops_op_lib"], +) + +tf_custom_op_py_library( + name = "mpi_collectives_py", + srcs = [ + "__init__.py", + "python/ops/mpi_ops.py", + ], + dso = [ + ":python/ops/_mpi_ops.so", + ], + kernels = [ + ":mpi_ops_kernels", + ":mpi_ops_op_lib", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":mpi_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:device", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:util", + ], +) + tf_py_test( name = "mpi_ops_test", srcs = ["mpi_ops_test.py"], @@ -61,20 +122,19 @@ tf_py_test( "//tensorflow/python:platform", ], data = [ - ":mpi_collectives.so", + ":python/ops/_mpi_ops.so", ], tags = ["manual"], ) -py_library( - name = "mpi_ops_py", - srcs = [ - "__init__.py", - "mpi_ops.py", - ], - data = [ - ":mpi_collectives.so", - ], - srcs_version = "PY2AND3", - visibility = ["//visibility:public"], +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], ) diff --git a/tensorflow/contrib/mpi_collectives/__init__.py b/tensorflow/contrib/mpi_collectives/__init__.py index 9ed16a6f07..52029cbc36 100644 --- a/tensorflow/contrib/mpi_collectives/__init__.py +++ b/tensorflow/contrib/mpi_collectives/__init__.py @@ -37,7 +37,7 @@ for detecting the running MPI configuration. 
Example: ```python -from tensorflow.contrib import mpi +import tensorflow.contrib.mpi_collectives as mpi # Use `mpi.Session` instead of `tf.Session` with mpi.Session() as session: @@ -48,8 +48,10 @@ with mpi.Session() as session: print("MPI Size:", session.run(mpi.size())) ``` -@@rank +@@init @@size +@@rank +@@local_rank ### Ring Allreduce and Allgather @@ -123,12 +125,12 @@ from __future__ import print_function import tensorflow as tf -from tensorflow.contrib.mpi_collectives.mpi_ops import size -from tensorflow.contrib.mpi_collectives.mpi_ops import rank -from tensorflow.contrib.mpi_collectives.mpi_ops import local_rank -from tensorflow.contrib.mpi_collectives.mpi_ops import allgather -from tensorflow.contrib.mpi_collectives.mpi_ops import _allreduce -from tensorflow.contrib.mpi_collectives.mpi_ops import init +from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import init +from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import size +from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import rank +from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import local_rank +from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import allgather +from tensorflow.contrib.mpi_collectives.python.ops.mpi_ops import _allreduce def allreduce(tensor, average=True): diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc index a051ab0004..2d5b98022c 100644 --- a/tensorflow/contrib/mpi_collectives/mpi_ops.cc +++ b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc @@ -21,7 +21,6 @@ limitations under the License. #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/mutex.h" @@ -37,7 +36,7 @@ limitations under the License. 
#define OMPI_SKIP_MPICXX #include "third_party/mpi/mpi.h" #include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h" -#include "tensorflow/contrib/mpi_collectives/ring.h" +#include "tensorflow/contrib/mpi_collectives/kernels/ring.h" /* * MPI Allreduce and Allgather Ops for TensorFlow. @@ -81,7 +80,7 @@ using GPUDevice = Eigen::GpuDevice; namespace tensorflow { namespace contrib { -namespace mpi { +namespace mpi_collectives { // Make sure template specializations are generated in the ring.cu.cc and the // ring.cc file, not in this file. @@ -877,14 +876,6 @@ REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU), MPIInitOp<GPUDevice>); #endif -REGISTER_OP("MPIInit").Doc(R"doc( -Initialize MPI for the current process. - -If this is run on a GPU, then that GPU must be used for all future MPI -operations. If it is run on CPU, then all future MPI operations must also -run on CPU. -)doc"); - // Op to get the current MPI Size. template <typename Device> class MPISizeOp : public OpKernel { @@ -911,21 +902,6 @@ REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"), MPISizeOp<GPUDevice>); #endif -REGISTER_OP("MPISize") - .Output("size: int32") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->Scalar()); - return Status::OK(); - }) - .Doc(R"doc( -Returns the number of running MPI processes. - -More precisely, returns the number of MPI processes in the group associated -with the MPI_COMM_WORLD communicator. - -size: Size of the MPI group. -)doc"); - // Op to get the current MPI Rank. template <typename Device> class MPIRankOp : public OpKernel { @@ -952,21 +928,6 @@ REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"), MPIRankOp<GPUDevice>); #endif -REGISTER_OP("MPIRank") - .Output("rank: int32") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->Scalar()); - return Status::OK(); - }) - .Doc(R"doc( -Returns the index of the current process in the MPI group. 
- -More precisely, returns the rank of the calling process in the MPI_COMM_WORLD -communicator. - -rank: Rank of the calling process. -)doc"); - // Op to get the current local MPI Rank. template <typename Device> class MPILocalRankOp : public OpKernel { @@ -994,21 +955,6 @@ REGISTER_KERNEL_BUILDER( MPILocalRankOp<GPUDevice>); #endif -REGISTER_OP("MPILocalRank") - .Output("rank: int32") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->Scalar()); - return Status::OK(); - }) - .Doc(R"doc( -Returns the index of the current process in the node it is on. - -More precisely, returns the rank of the calling process in communicator that -only spans the MPI processes running on that node. - -rank: Rank of the calling process on the node it is on. -)doc"); - template <typename Device> class MPIAllreduceOp : public AsyncOpKernel { public: @@ -1083,28 +1029,6 @@ REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU), MPIAllreduceOp<GPUDevice>); #endif -REGISTER_OP("MPIAllreduce") - .Attr("T: {int32, int64, float32}") - .Input("tensor: T") - .Output("sum: T") - .SetShapeFn([](shape_inference::InferenceContext* c) { - c->set_output(0, c->input(0)); - return Status::OK(); - }) - .Doc(R"doc( -Perform an MPI Allreduce on a tensor. All other processes that do a reduction -on a tensor with the same name must have the same dimension for that tensor. -Tensors are reduced with other tensors that have the same node name for the -allreduce. - -Arguments - tensor: A tensor to reduce. - -Output - sum: A tensor with the same shape as `tensor`, summed across all - MPI processes. 
-)doc"); - template <typename Device> class MPIAllgatherOp : public AsyncOpKernel { public: @@ -1192,34 +1116,6 @@ class MPIAllgatherOp : public AsyncOpKernel { } }; -REGISTER_OP("MPIAllgather") - .Attr("T: {int32, int64, float32}") - .Attr("S: {int64}") - .Input("tensor: T") - .Input("sizes: S") - .Output("gathered: T") - .SetShapeFn([](shape_inference::InferenceContext* c) { - shape_inference::ShapeHandle output; - TF_RETURN_IF_ERROR( - c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output)); - c->set_output(0, output); - return Status::OK(); - }) - .Doc(R"doc( -Perform an MPI Allgather on a tensor. All other processes that do a gather on a -tensor with the same name must have the same rank for that tensor, and have the -same dimension on all but the first dimension. - -Arguments - tensor: A tensor to gather. - sizes: A tensor containing the first-dimension sizes of tensors to be - gathered from other ranks - -Output - gathered: A tensor with the same shape as `tensor` except for the first - dimension, which is the sum of dimensions in `sizes`. -)doc"); - REGISTER_KERNEL_BUILDER( Name("MPIAllgather").Device(DEVICE_CPU).HostMemory("sizes"), MPIAllgatherOp<CPUDevice>); @@ -1229,7 +1125,7 @@ REGISTER_KERNEL_BUILDER( MPIAllgatherOp<GPUDevice>); #endif -} // namespace mpi +} // namespace mpi_collectives } // namespace contrib } // namespace tensorflow diff --git a/tensorflow/contrib/mpi_collectives/ring.cc b/tensorflow/contrib/mpi_collectives/kernels/ring.cc index d93233eb21..8970ceb1a2 100644 --- a/tensorflow/contrib/mpi_collectives/ring.cc +++ b/tensorflow/contrib/mpi_collectives/kernels/ring.cc @@ -17,11 +17,11 @@ limitations under the License. 
#define EIGEN_USE_THREADS -#include "tensorflow/contrib/mpi_collectives/ring.h" +#include "tensorflow/contrib/mpi_collectives/kernels/ring.h" namespace tensorflow { namespace contrib { -namespace mpi { +namespace mpi_collectives { using CPUDevice = Eigen::ThreadPoolDevice; @@ -73,7 +73,7 @@ GENERATE_ACCUMULATE(long long); GENERATE_ACCUMULATE(float); #undef GENERATE_ACCUMULATE -} // namespace mpi +} // namespace mpi_collectives } // namespace contrib } // namespace tensorflow diff --git a/tensorflow/contrib/mpi_collectives/ring.cu.cc b/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc index 2f3eef366a..b04abde469 100644 --- a/tensorflow/contrib/mpi_collectives/ring.cu.cc +++ b/tensorflow/contrib/mpi_collectives/kernels/ring.cu.cc @@ -19,11 +19,11 @@ limitations under the License. #define EIGEN_USE_GPU -#include "tensorflow/contrib/mpi_collectives/ring.h" +#include "tensorflow/contrib/mpi_collectives/kernels/ring.h" namespace tensorflow { namespace contrib { -namespace mpi { +namespace mpi_collectives { using CPUDevice = Eigen::ThreadPoolDevice; @@ -109,7 +109,7 @@ GENERATE_ACCUMULATE(long long); GENERATE_ACCUMULATE(float); #undef GENERATE_ACCUMULATE -} // namespace mpi +} // namespace mpi_collectives } // namespace contrib } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/mpi_collectives/ring.h b/tensorflow/contrib/mpi_collectives/kernels/ring.h index cae57ce60e..1d56d588bc 100644 --- a/tensorflow/contrib/mpi_collectives/ring.h +++ b/tensorflow/contrib/mpi_collectives/kernels/ring.h @@ -37,7 +37,7 @@ limitations under the License. 
namespace tensorflow { namespace contrib { -namespace mpi { +namespace mpi_collectives { using CPUDevice = Eigen::ThreadPoolDevice; using GPUDevice = Eigen::GpuDevice; @@ -317,7 +317,7 @@ Status RingAllgather(OpKernelContext* context, const Tensor* input, return Status::OK(); } -} // namespace mpi +} // namespace mpi_collectives } // namespace contrib } // namespace tensorflow diff --git a/tensorflow/contrib/mpi_collectives/mpi_message.proto b/tensorflow/contrib/mpi_collectives/mpi_message.proto index 7fa5e20301..afbce981ae 100644 --- a/tensorflow/contrib/mpi_collectives/mpi_message.proto +++ b/tensorflow/contrib/mpi_collectives/mpi_message.proto @@ -15,7 +15,7 @@ limitations under the License. syntax = "proto3"; -package tensorflow.contrib.mpi; +package tensorflow.contrib.mpi_collectives; import "tensorflow/core/framework/tensor_shape.proto"; import "tensorflow/core/framework/types.proto"; diff --git a/tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc new file mode 100644 index 0000000000..18e6bb61cf --- /dev/null +++ b/tensorflow/contrib/mpi_collectives/ops/mpi_ops.cc @@ -0,0 +1,132 @@ +/* Copyright 2016 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifdef TENSORFLOW_USE_MPI + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { +namespace contrib { +namespace mpi_collectives { + +REGISTER_OP("MPIInit").Doc(R"doc( +Initialize MPI for the current process. + +If this is run on a GPU, then that GPU must be used for all future MPI +operations. If it is run on CPU, then all future MPI operations must also +run on CPU. +)doc"); + +REGISTER_OP("MPISize") + .Output("size: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Returns the number of running MPI processes. + +More precisely, returns the number of MPI processes in the group associated +with the MPI_COMM_WORLD communicator. + +size: Size of the MPI group. +)doc"); + +REGISTER_OP("MPIRank") + .Output("rank: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Returns the index of the current process in the MPI group. + +More precisely, returns the rank of the calling process in the MPI_COMM_WORLD +communicator. + +rank: Rank of the calling process. +)doc"); + +REGISTER_OP("MPILocalRank") + .Output("rank: int32") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->Scalar()); + return Status::OK(); + }) + .Doc(R"doc( +Returns the index of the current process in the node it is on. + +More precisely, returns the rank of the calling process in communicator that +only spans the MPI processes running on that node. + +rank: Rank of the calling process on the node it is on. 
+)doc"); + +REGISTER_OP("MPIAllreduce") + .Attr("T: {int32, int64, float32}") + .Input("tensor: T") + .Output("sum: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }) + .Doc(R"doc( +Perform an MPI Allreduce on a tensor. All other processes that do a reduction +on a tensor with the same name must have the same dimension for that tensor. +Tensors are reduced with other tensors that have the same node name for the +allreduce. + +Arguments + tensor: A tensor to reduce. + +Output + sum: A tensor with the same shape as `tensor`, summed across all + MPI processes. +)doc"); + +REGISTER_OP("MPIAllgather") + .Attr("T: {int32, int64, float32}") + .Attr("S: {int64}") + .Input("tensor: T") + .Input("sizes: S") + .Output("gathered: T") + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle output; + TF_RETURN_IF_ERROR( + c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output)); + c->set_output(0, output); + return Status::OK(); + }) + .Doc(R"doc( +Perform an MPI Allgather on a tensor. All other processes that do a gather on a +tensor with the same name must have the same rank for that tensor, and have the +same dimension on all but the first dimension. + +Arguments + tensor: A tensor to gather. + sizes: A tensor containing the first-dimension sizes of tensors to be + gathered from other ranks + +Output + gathered: A tensor with the same shape as `tensor` except for the first + dimension, which is the sum of dimensions in `sizes`. 
+)doc"); + +} // namespace mpi_collectives +} // namespace contrib +} // namespace tensorflow + +#endif // TENSORFLOW_USE_MPI diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.py b/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py index 81567cc688..f0a116239d 100644 --- a/tensorflow/contrib/mpi_collectives/mpi_ops.py +++ b/tensorflow/contrib/mpi_collectives/python/ops/mpi_ops.py @@ -20,44 +20,13 @@ from __future__ import print_function import tensorflow as tf -from tensorflow.python.framework import errors -from tensorflow.python.framework import load_library +from tensorflow.contrib.mpi_collectives.ops import gen_mpi_ops +from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader -from tensorflow.python.platform import tf_logging as logging - - -def _load_library(name, op_list=None): - """Loads a .so file containing the specified operators. - - Args: - name: The name of the .so file to load. - op_list: A list of names of operators that the library should have. If None - then the .so file's contents will not be verified. - - Raises: - NameError if one of the required ops is missing. - """ - try: - filename = resource_loader.get_path_to_datafile(name) - library = load_library.load_op_library(filename) - for expected_op in (op_list or []): - for lib_op in library.OP_LIST.op: - if lib_op.name == expected_op: - break - else: - raise NameError( - 'Could not find operator %s in dynamic library %s' % - (expected_op, name)) - return library - except errors.NotFoundError: - logging.warning('%s file could not be loaded.', name) - - -MPI_LIB = _load_library('mpi_collectives.so', ['MPISize', 'MPIRank', - 'MPILocalRank', 'MPIAllgather', - 'MPIAllreduce']) +_mpi_ops_so = loader.load_op_library( + resource_loader.get_path_to_datafile("_mpi_ops.so")) def size(name=None): """An op which returns the number of MPI processes. 
@@ -68,7 +37,7 @@ def size(name=None): Returns: An integer scalar containing the number of MPI processes. """ - return MPI_LIB.mpi_size(name=name) + return gen_mpi_ops.mpi_size(name=name) ops.NotDifferentiable('MPISize') @@ -83,7 +52,7 @@ def rank(name=None): Returns: An integer scalar with the MPI rank of the calling process. """ - return MPI_LIB.mpi_rank(name=name) + return gen_mpi_ops.mpi_rank(name=name) ops.NotDifferentiable('MPIRank') @@ -95,7 +64,7 @@ def init(name=None): All future MPI ops must be run on the same device that the `init` op was run on. """ - return MPI_LIB.mpi_init(name=name) + return gen_mpi_ops.mpi_init(name=name) ops.NotDifferentiable('MPIInit') @@ -112,7 +81,7 @@ def local_rank(name=None): Returns: An integer scalar with the local MPI rank of the calling process. """ - return MPI_LIB.mpi_local_rank(name=name) + return gen_mpi_ops.mpi_local_rank(name=name) ops.NotDifferentiable('MPILocalRank') @@ -129,7 +98,7 @@ def _allreduce(tensor, name=None): A tensor of the same shape and type as `tensor`, summed across all processes. """ - return MPI_LIB.mpi_allreduce(tensor, name=name) + return gen_mpi_ops.mpi_allreduce(tensor, name=name) ops.NotDifferentiable('MPIAllreduce') @@ -156,8 +125,8 @@ def allgather(tensor, name=None): if name is None: name = "allgather" sizing_name = "{}_sizing".format(name) - sizes = MPI_LIB.mpi_allgather(my_size, sizes_flag, name=sizing_name) - return MPI_LIB.mpi_allgather(tensor, sizes, name=name) + sizes = gen_mpi_ops.mpi_allgather(my_size, sizes_flag, name=sizing_name) + return gen_mpi_ops.mpi_allgather(tensor, sizes, name=name) ops.NotDifferentiable('MPIAllgather') |